
main.py 5.8KB

import json
from typing import Iterable
from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F
import time

start = time.time()

config = json.load(open("./settings.json"))
debug = config['debug']
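
# Example of the expected settings.json (key names are taken from the lookups
# in this script; the values shown are illustrative placeholders only):
# {
#     "debug": false,
#     "cassandra_catalog": "cassandra",
#     "cassandra_keyspace": "my_keyspace",
#     "clusters_table_name": "clusters",
#     "tx_table_name": "transactions",
#     "spark_checkpoint_dir": "/tmp/spark_checkpoints"
# }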

class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)

    def union_single_col(self, df1: DataFrame, df2: DataFrame, column: str) -> DataFrame:
        return df1 \
            .select(column) \
            .union(df2.select(column))

    def reduce_concat_array_column(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
        df = self.explode_array_col(df.select(column), column)
        return self.collect_col_to_array(df, column, distinct)

    def collect_col_to_array(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
        if distinct:
            return df.select(F.collect_set(column).alias(column))
        else:
            return df.select(F.collect_list(column).alias(column))
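
    # NOTE (added): explode_array_col below round-trips through the RDD API.
    # Assuming only the flattening behaviour matters, an equivalent
    # DataFrame-only sketch would be:
    #     df.select(F.explode(column).alias(column))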
    def explode_array_col(self, df: DataFrame, column: str) -> DataFrame:
        return df \
            .rdd \
            .flatMap(lambda row: list(map(lambda elem: (elem,), row[column]))) \
            .toDF([column])
# end class Master


def cluster_id_addresses_rows(iter: "Iterable[Row]") -> Iterable:
    # placeholder, currently not used anywhere in this script
    return iter
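

# Driver flow (added summary): group transactions by tx_id into address sets,
# build an initial cluster from the first transaction and every transaction
# whose address set overlaps it, then repeatedly fold the remaining
# transactions into the clusters they overlap (take_tx_and_cluster below).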

master = Master(config)

master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])

tx_df = master.get_tx_dataframe()

# Turn transactions into a list of ('id', [addr, addr, ...])
tx_grouped = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses'))

# TODO: Load clusters from DB and check if any exist; if none, make the
#       initial cluster, else proceed with the loaded data.

# find the initial cluster:
# take the first tx
tx_zero = tx_grouped \
    .select('*') \
    .limit(1)

# Find txs with overlapping addresses. The join has no condition, so this is a
# cross join against the single-row tx_zero DataFrame.
overlapping_txs = tx_grouped \
    .join(
        tx_zero \
            .withColumnRenamed('addresses', 'tx_addresses') \
            .withColumnRenamed('tx_id', 'overlap_id')
    ) \
    .select(
        tx_grouped.tx_id,
        tx_grouped.addresses,
        F.arrays_overlap(tx_grouped.addresses, 'tx_addresses').alias('overlap')
    ) \
    .where(F.col('overlap') == True) \
    .drop('overlap')

# Overlapped txs must not be considered any more, so remove them from the
# candidate dataframe.
tx_grouped = tx_grouped \
    .join(
        overlapping_txs.drop('addresses'),
        'tx_id',
        'leftanti'
    )

# get the distinct addresses of all overlaps in a single array
distinct_addresses = master.reduce_concat_array_column(
    master.union_single_col(
        overlapping_txs,
        tx_zero,
        column='addresses'
    ),
    column='addresses',
    distinct=True,
)

# Pick an arbitrary representative (the first address) for this cluster and
# pair it with every address.
cluster = distinct_addresses \
    .rdd \
    .flatMap(lambda row: list(map(lambda addr: (addr, row['addresses'][0]), row['addresses']))) \
    .toDF(['address', 'id'])

# done finding initial cluster

# group cluster by representative and transform the result into a list of
# shape ('id', ['addr', 'addr', ...])
clusters_grouped = cluster \
    .groupBy('id') \
    .agg(F.collect_list('address').alias('addresses'))
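
# Added note: take_tx_and_cluster below implements the iterative merge step.
# Each call takes one remaining transaction, finds all clusters whose address
# sets overlap it, unions their addresses into one new cluster, and recurses
# until no transactions are left. Very large inputs may hit Python's default
# recursion limit, since each transaction adds one level of recursion.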

def take_tx_and_cluster(txs: DataFrame, clusters: DataFrame, n=0):
    if txs.count() == 0:  # done!
        return clusters

    # take a random tx
    tx = txs \
        .select('*') \
        .limit(1)

    # find clusters with overlapping addresses from tx
    overlapping_clusters = clusters \
        .join(tx.withColumnRenamed('addresses', 'tx_addresses')) \
        .select(
            clusters.id,
            clusters.addresses,
            'tx_addresses',
            F.arrays_overlap(clusters.addresses, 'tx_addresses').alias('overlap')
        ) \
        .where(F.col('overlap') == True)

    clusters_union_tx = master.union_single_col(tx, overlapping_clusters, 'addresses')

    # collect all addresses into a single array field
    new_cluster_arrays = master.reduce_concat_array_column(
        clusters_union_tx,
        column='addresses',
        distinct=True
    )

    # declare the cluster representative
    new_cluster = new_cluster_arrays \
        .rdd \
        .flatMap(lambda row: list(map(lambda addr: (addr, row['addresses'][0]), row['addresses']))) \
        .toDF(['address', 'id']) \
        .groupBy('id') \
        .agg(F.collect_list('address').alias('addresses'))

    txs = txs.join(tx, 'tx_id', 'leftanti')
    clusters = clusters.join(overlapping_clusters, 'id', 'leftanti').union(new_cluster)

    # the RDD lineage (internal history tracker) gets too big as iterations
    # continue, so use checkpoint() to prune it regularly
    if n % 3 == 0:
        txs = txs.checkpoint()
        clusters = clusters.checkpoint()

    # start a new round with txs minus the one just used, and the updated clusters
    return take_tx_and_cluster(txs, clusters, n + 1)


result = take_tx_and_cluster(tx_grouped, clusters_grouped).collect()

for row in result:
    print(sorted(row['addresses']))

end = time.time()
print("ELAPSED TIME:", end - start)