
main.py

import json
import time

from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F

start = time.time()

with open("./settings.json") as f:
    config = json.load(f)
debug = config['debug']
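
# The keys read from settings.json elsewhere in this script are listed below; the values shown
# are placeholders, not the project's real configuration:
#
# {
#     "debug": false,
#     "cassandra_catalog": "cassandra",
#     "cassandra_keyspace": "my_keyspace",
#     "clusters_table_name": "clusters",
#     "tx_table_name": "transactions",
#     "spark_checkpoint_dir": "/tmp/spark-checkpoints"
# }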


class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        # Register the configured Cassandra catalog so tables are reachable as catalog.keyspace.table
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)

    def union_single_col(self, df1: DataFrame, df2: DataFrame, column: str) -> DataFrame:
        # Stack the given column of both dataframes into one single-column dataframe
        return df1 \
            .select(column) \
            .union(df2.select(column))

    def reduce_concat_array_column(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
        # Flatten every array in `column`, then collect all elements back into a single array (one row)
        df = self.explode_array_col(df.select(column), column)
        return self.collect_col_to_array(df, column, distinct)

    def collect_col_to_array(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
        if distinct:
            return df.select(F.collect_set(column).alias(column))
        else:
            return df.select(F.collect_list(column).alias(column))

    def explode_array_col(self, df: DataFrame, column: str) -> DataFrame:
        # One output row per array element
        return df \
            .rdd \
            .flatMap(lambda row: [(elem,) for elem in row[column]]) \
            .toDF([column])
# end class Master
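
# Illustrative example of the helpers above (made-up values): given two dataframes whose 'addresses'
# column holds ['a', 'b'] and ['b', 'c'], union_single_col stacks the two arrays into one column,
# explode_array_col flattens them into rows ('a',), ('b',), ('b',), ('c',), and
# collect_col_to_array(..., distinct=True) gathers them back into a single row ['a', 'b', 'c']
# (collect_set does not guarantee element order).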

master = Master(config)
master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])

tx_df = master.get_tx_dataframe()

# Turn transactions into rows of shape ('tx_id', [addr, addr, ...])
tx_grouped = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses'))
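# Illustrative shape of tx_grouped (made-up values):
#   Row(tx_id='tx0', addresses=['a', 'b'])
#   Row(tx_id='tx1', addresses=['b', 'c'])
#   Row(tx_id='tx2', addresses=['d'])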

# TODO: Load clusters from DB; if none exist, build the initial cluster, otherwise proceed with the loaded data

# Find the initial cluster
# Take the first tx
tx_zero = tx_grouped \
    .select('*') \
    .limit(1)

# Find txs with overlapping addresses: cross join every tx against tx_zero and keep the rows
# whose address arrays intersect
overlapping_txs = tx_grouped \
    .crossJoin(
        tx_zero \
            .withColumnRenamed('addresses', 'tx_addresses') \
            .withColumnRenamed('tx_id', 'overlap_id')
    ) \
    .select(
        tx_grouped.tx_id,
        tx_grouped.addresses,
        F.arrays_overlap(tx_grouped.addresses, 'tx_addresses').alias('overlap')
    ) \
    .where(F.col('overlap')) \
    .drop('overlap')
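# Continuing the illustration above: if tx_zero is ('tx0', ['a', 'b']), then both 'tx0' (its array
# overlaps itself) and 'tx1' (shares address 'b') end up in overlapping_txs, while 'tx2' does not.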

# Overlapping txs must not be considered again, so remove them from the candidate dataframe
tx_grouped = tx_grouped \
    .join(
        overlapping_txs.drop('addresses'),
        'tx_id',
        'leftanti'
    )

# Get the distinct addresses of all overlapping txs (plus tx_zero itself) in a single array
distinct_addresses = master.reduce_concat_array_column(
    master.union_single_col(
        overlapping_txs,
        tx_zero,
        column='addresses'
    ),
    column='addresses',
    distinct=True,
)

# Pick a representative for this cluster (the first address in the array) and pair it with every address
cluster = distinct_addresses \
    .rdd \
    .flatMap(lambda row: [(addr, row['addresses'][0]) for addr in row['addresses']]) \
    .toDF(['address', 'id'])
# Done finding the initial cluster

# Group the cluster by representative and transform the result into rows of shape ('id', ['addr', 'addr', ...])
clusters_grouped = cluster \
    .groupBy('id') \
    .agg(F.collect_list('address').alias('addresses'))
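# In the running illustration the initial cluster covers ['a', 'b', 'c'], keyed by one of those
# addresses as representative, e.g. Row(id='a', addresses=['a', 'b', 'c']) (the representative
# depends on array order, which collect_set does not guarantee).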


def take_tx_and_cluster(txs: DataFrame, clusters: DataFrame, n=0):
    if txs.count() == 0:  # done!
        return clusters

    # Take one tx
    tx = txs \
        .select('*').limit(1)

    # Find clusters whose addresses overlap this tx's addresses
    overlapping_clusters = clusters \
        .crossJoin(tx.withColumnRenamed('addresses', 'tx_addresses')) \
        .select(
            clusters.id,
            clusters.addresses,
            'tx_addresses',
            F.arrays_overlap(clusters.addresses, 'tx_addresses').alias('overlap')
        ) \
        .where(F.col('overlap'))

    clusters_union_tx = master.union_single_col(tx, overlapping_clusters, 'addresses')

    # Collect all addresses into a single array field
    new_cluster_arrays = master.reduce_concat_array_column(
        clusters_union_tx,
        column='addresses',
        distinct=True
    )

    # Declare a cluster representative (the first address) and regroup into ('id', [addresses])
    new_cluster = new_cluster_arrays \
        .rdd \
        .flatMap(lambda row: [(addr, row['addresses'][0]) for addr in row['addresses']]) \
        .toDF(['address', 'id']) \
        .groupBy('id') \
        .agg(F.collect_list('address').alias('addresses'))

    # Drop the consumed tx and replace the merged clusters with the new one
    txs = txs.join(tx, 'tx_id', 'leftanti')
    clusters = clusters.join(overlapping_clusters, 'id', 'leftanti').union(new_cluster)

    # The RDD lineage (internal history tracker) gets too big as iterations continue;
    # checkpoint regularly to prune it
    if n % 3 == 0:
        txs = txs.checkpoint()
        clusters = clusters.checkpoint()

    # Start a new round with txs minus the one just used, and the updated clusters
    return take_tx_and_cluster(txs, clusters, n + 1)
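
# Illustrative iteration (made-up data): a tx with addresses ['c', 'd'] overlaps the cluster
# containing 'c', so that cluster and the tx merge into one new cluster covering ['a', 'b', 'c', 'd'];
# a tx sharing no address with any existing cluster simply becomes a cluster of its own.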

result = take_tx_and_cluster(tx_grouped, clusters_grouped).collect()
for row in result:
    print(sorted(row['addresses']))

end = time.time()
print("ELAPSED TIME:", end - start)