
main_with_collect.py

import json
import time
from typing import Iterable, List

from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F

start = time.time()

config = json.load(open("./settings.json"))
debug = config['debug']
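
# Expected shape of settings.json (illustrative only: the key names are the ones
# read above and below; the values are placeholders, not taken from this repo):
# {
#   "debug": true,
#   "cassandra_catalog": "cassandra",
#   "cassandra_keyspace": "clustering",
#   "clusters_table_name": "clusters",
#   "tx_table_name": "transactions"
# }
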
class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()
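
    # Note: only the catalog class is configured here. The Cassandra contact
    # points (e.g. spark.cassandra.connection.host) and the
    # spark-cassandra-connector package are assumed to be supplied externally,
    # for example on the spark-submit command line; that is an assumption, not
    # something this file sets.
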
    def group_tx_addrs(self) -> DataFrame:
        # One row per transaction: (tx_id, set of addresses appearing in it).
        return self.spark \
            .read \
            .table(self.TX_TABLE) \
            .groupBy("tx_id") \
            .agg(F.collect_set('address').alias('addresses'))

    def group_cluster_addrs(self) -> DataFrame:
        # One row per known cluster: (id, set of addresses assigned to it).
        return self.spark \
            .read \
            .table(self.CLUSTERS_TABLE) \
            .groupBy("id") \
            .agg(F.collect_set('address').alias('addresses'))
    def insertNewCluster(self, addrs: Iterable[str], root: str | None = None) -> str:
        addrs = list(addrs)
        if root is None:
            root = addrs[0]
        df = self.spark.createDataFrame([(addr, root) for addr in addrs], schema=['address', 'id'])
        df.writeTo(self.CLUSTERS_TABLE).append()
        return root
    def enumerate(self, data: DataFrame) -> DataFrame:
        # Attach a stable 0-based index to every row so the driver loop can
        # pull out one transaction group at a time.
        return data \
            .rdd \
            .zipWithIndex() \
            .toDF(["tx_group", "index"])
    def rewrite_cluster_id(self, cluster_roots: Iterable[str], new_cluster_root: str) -> None:
        # Re-point every address of the given clusters to the surviving root.
        cluster_rewrite = self.spark \
            .table(self.CLUSTERS_TABLE) \
            .where(F.col('id').isin(cluster_roots)) \
            .select('address') \
            .rdd \
            .map(lambda addr: (addr['address'], new_cluster_root)) \
            .toDF(['address', 'id'])
        if debug:
            print("REWRITE JOB")
            cluster_rewrite.show(truncate=False, vertical=True)
            print()
        cluster_rewrite.writeTo(self.CLUSTERS_TABLE).append()
# end class Master
"""
tuple structure:
Row           => Row(id=addr, addresses=list[addr]) | the cluster
Iterable[str] => list[addr]                         | the transaction addresses
"""
def find(data: tuple[Row, Iterable[str]]) -> str | None:
    cluster, tx = data
    cluster_addresses = cluster['addresses'] + [cluster['id']]
    if any(x in tx for x in cluster_addresses):
        return cluster['id']
    else:
        return None
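
# Example (illustrative addresses): for a cluster Row(id='a1', addresses=['a2', 'a3'])
# and a transaction ['a3', 'a9'], find returns 'a1' because 'a3' is shared; for a
# transaction ['a7', 'a9'] it returns None, i.e. the tx does not touch that cluster.
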
master = Master(config)

tx_addr_groups = master.group_tx_addrs()
tx_groups_indexed = master.enumerate(tx_addr_groups).cache()

for i in range(0, tx_addr_groups.count()):
    cluster_addr_groups = master.group_cluster_addrs()
    if debug:
        print("KNOWN CLUSTERS")
        cluster_addr_groups.show(truncate=True)
        print()

    tx_addrs: Iterable[str] = tx_groups_indexed \
        .where(tx_groups_indexed.index == i) \
        .select('tx_group') \
        .collect()[0]['tx_group']['addresses']
    if debug:
        print("CURRENT TX")
        print(tx_addrs)
        print()

    if cluster_addr_groups.count() == 0:
        master.insertNewCluster(tx_addrs)
        continue
    cluster_tx_mapping = cluster_addr_groups \
        .rdd \
        .map(lambda cluster: (cluster, tx_addrs))
    if debug:
        print("cluster_tx_mapping")
        cluster_tx_mapping \
            .toDF(['cluster', 'tx']) \
            .show(truncate=True)
        print()

    matched_roots: List[str] = cluster_tx_mapping \
        .map(find) \
        .filter(lambda root: root is not None) \
        .collect()
    if debug:
        print("FOUND ROOTS")
        print(matched_roots)
        print()
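    # Three cases: no overlap -> start a new cluster; exactly one matching
    # cluster -> add the tx addresses to it; several matches -> this tx links
    # them, so merge the extra clusters into the first root before inserting.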
    if len(matched_roots) == 0:
        master.insertNewCluster(tx_addrs)
    elif len(matched_roots) == 1:
        master.insertNewCluster(tx_addrs, matched_roots[0])
    else:
        master.rewrite_cluster_id(matched_roots[1:], matched_roots[0])
        master.insertNewCluster(tx_addrs, matched_roots[0])
    if debug:
        print("======================================================================")

end = time.time()
print("ELAPSED TIME:", end - start)
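
# Example launch (illustrative; the connector coordinates/version and the
# connection host are assumptions, not taken from this repository):
#   spark-submit \
#     --packages com.datastax.spark:spark-cassandra-connector_2.12:3.4.1 \
#     --conf spark.cassandra.connection.host=127.0.0.1 \
#     main_with_collect.py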