
main.py 5.8KB

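# Address-clustering job: transactions read from Cassandra are grouped by tx_id into
# address sets; an initial cluster is seeded from the first transaction, and the remaining
# transactions are then merged iteratively into clusters whenever their address sets overlap.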
import json
import time
from typing import Iterable

from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F

start = time.time()

with open("./settings.json") as settings_file:
    config = json.load(settings_file)
debug = config['debug']
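# settings.json is expected to provide at least the keys read in this file:
#   debug, cassandra_catalog, cassandra_keyspace, clusters_table_name,
#   tx_table_name, spark_checkpoint_dir
# A minimal sketch (the values below are placeholders, not real deployment values):
#   {
#     "debug": false,
#     "cassandra_catalog": "cassandra",
#     "cassandra_keyspace": "my_keyspace",
#     "clusters_table_name": "clusters",
#     "tx_table_name": "transactions",
#     "spark_checkpoint_dir": "/tmp/spark_checkpoints"
#   }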
class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        # Register the Cassandra catalog so its tables can be read via spark.table()
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)

    def union_single_col(self, df1: DataFrame, df2: DataFrame, column: str) -> DataFrame:
        # Stack the given column of both dataframes into one single-column dataframe
        return df1 \
            .select(column) \
            .union(df2.select(column))

    def reduce_concat_array_column(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
        # Flatten an array column across all rows, then re-collect it into a single array row
        df = self.explode_array_col(df.select(column), column)
        return self.collect_col_to_array(df, column, distinct)

    def collect_col_to_array(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
        if distinct:
            return df.select(F.collect_set(column).alias(column))
        else:
            return df.select(F.collect_list(column).alias(column))

    def explode_array_col(self, df: DataFrame, column: str) -> DataFrame:
        # Emit one row per array element
        return df \
            .rdd \
            .flatMap(lambda row: [(elem,) for elem in row[column]]) \
            .toDF([column])
# end class Master
def cluster_id_addresses_rows(iter: "Iterable[Row]") -> Iterable:
    # Placeholder partition mapper: currently just passes the rows through unchanged
    return iter
master = Master(config)
master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])

tx_df = master.get_tx_dataframe()

# Turn transactions into a list of ('id', [addr, addr, ...])
tx_grouped = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses'))

# Note: this transformation is lazy and its result is never assigned, so it currently has no effect
tx_grouped.rdd.mapPartitions(cluster_id_addresses_rows)
# TODO: Load clusters from the DB; if none exist, build the initial cluster, else proceed with the loaded data

# Find the initial cluster
# Take the first tx
tx_zero = tx_grouped \
    .select('*') \
    .limit(1)

# Find txs with addresses overlapping tx_zero (join against the single tx_zero row, then filter on arrays_overlap)
overlapping_txs = tx_grouped \
    .join(
        tx_zero \
            .withColumnRenamed('addresses', 'tx_addresses') \
            .withColumnRenamed('tx_id', 'overlap_id')
    ) \
    .select(
        tx_grouped.tx_id,
        tx_grouped.addresses,
        F.arrays_overlap(tx_grouped.addresses, 'tx_addresses').alias('overlap')
    ) \
    .where(F.col('overlap') == True) \
    .drop('overlap')

# Overlapping txs must not be considered any further, so remove them from the candidate dataframe
tx_grouped = tx_grouped \
    .join(
        overlapping_txs.drop('addresses'),
        'tx_id',
        'leftanti'
    )

# Collect the distinct addresses of all overlapping txs into a single array
distinct_addresses = master.reduce_concat_array_column(
    master.union_single_col(
        overlapping_txs,
        tx_zero,
        column='addresses'
    ),
    column='addresses',
    distinct=True,
)

# Pick a representative for this cluster (the first address) and attach it to every address
cluster = distinct_addresses \
    .rdd \
    .flatMap(lambda row: [(addr, row['addresses'][0]) for addr in row['addresses']]) \
    .toDF(['address', 'id'])

# Done finding the initial cluster

# Group the cluster by representative and transform the result into rows of shape ('id', ['addr', 'addr', ...])
clusters_grouped = cluster \
    .groupBy('id') \
    .agg(F.collect_list('address').alias('addresses'))
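# Iterative merging: repeatedly take one remaining transaction, collect every cluster whose
# address set overlaps it, and replace those clusters (plus the tx itself) with a single
# merged, de-duplicated cluster under one representative address. The recursion ends once
# no transactions remain.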
def take_tx_and_cluster(txs: DataFrame, clusters: DataFrame, n=0):
    if txs.count() == 0:  # done!
        return clusters

    # Take one tx
    tx = txs \
        .select('*').limit(1)

    # Find clusters with addresses overlapping this tx's addresses
    overlapping_clusters = clusters \
        .join(tx.withColumnRenamed('addresses', 'tx_addresses')) \
        .select(
            clusters.id,
            clusters.addresses,
            'tx_addresses',
            F.arrays_overlap(clusters.addresses, 'tx_addresses').alias('overlap')
        ) \
        .where(F.col('overlap') == True)

    clusters_union_tx = master.union_single_col(tx, overlapping_clusters, 'addresses')

    # Collect all addresses into a single array field
    new_cluster_arrays = master.reduce_concat_array_column(
        clusters_union_tx,
        column='addresses',
        distinct=True
    )

    # Declare a cluster representative (the first address) and regroup into ('id', [addresses])
    new_cluster = new_cluster_arrays \
        .rdd \
        .flatMap(lambda row: [(addr, row['addresses'][0]) for addr in row['addresses']]) \
        .toDF(['address', 'id']) \
        .groupBy('id') \
        .agg(F.collect_list('address').alias('addresses'))

    txs = txs.join(tx, 'tx_id', 'leftanti')
    clusters = clusters.join(overlapping_clusters, 'id', 'leftanti').union(new_cluster)

    # The RDD lineage (internal history) grows with every iteration; checkpoint regularly to prune it
    if n % 3 == 0:
        txs = txs.checkpoint()
        clusters = clusters.checkpoint()

    # Start a new round with the txs minus the one just used, and the updated clusters
    return take_tx_and_cluster(txs, clusters, n + 1)
result = take_tx_and_cluster(tx_grouped, clusters_grouped).collect()
for row in result:
    print(sorted(row['addresses']))

end = time.time()
print("ELAPSED TIME:", end - start)
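# One possible way to launch this script (the connector version below is an assumption,
# not something pinned by this file; Cassandra connection settings must be supplied separately):
#   spark-submit --packages com.datastax.spark:spark-cassandra-connector_2.12:3.4.1 main.py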