You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long. 2.6 KB

  1. import json
  2. from typing import Iterable, List, Set
  3. from pyspark.sql import SparkSession, DataFrame, Row
  4. from pyspark.sql import functions as F
  5. import time
  6. start = time.time()
  7. config = json.load(open("./settings.json"))
  8. debug = config['debug']
  9. class Master:
  10. spark: SparkSession
  12. TX_TABLE: str
  13. def __init__(self, config):
  14. self.spark = self.makeSparkContext(config)
  15. self.config = config
  16. self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
  17. self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"
  18. def makeSparkContext(self, config) -> SparkSession:
  19. return SparkSession.builder \
  20. .appName('SparkCassandraApp') \
  21. .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
  22. .getOrCreate()
  23. def get_tx_dataframe(self) -> DataFrame:
  24. return self.spark.table(self.TX_TABLE)
  25. # end class Master
  26. def merge_lists_distinct(*lists: "Iterable[List[str]]") -> List[str]:
  27. accum = set()
  28. for lst in lists:
  29. accum = accum.union(set(lst))
  30. return list(accum)
  31. def check_lists_overlap(list1, list2):
  32. return any(x in list1 for x in list2)
  33. def cluster_step(clusters: "List[List[str]]", addresses: "List[List[str]]"):
  34. #if there are no more sets of addresses to consider, we are done
  35. if(len(addresses) == 0):
  36. return clusters
  37. tx = addresses[0]
  38. matching_clusters = []
  39. new_clusters = []
  40. for cluster in clusters:
  41. if(check_lists_overlap(tx, cluster)):
  42. matching_clusters.append(cluster)
  43. else:
  44. new_clusters.append(cluster)
  45. new_clusters.append(merge_lists_distinct(tx, *matching_clusters))
  46. return cluster_step(new_clusters,addresses[1:])
  47. def cluster_partition(iter: "Iterable[Row]") -> Iterable:
  48. yield cluster_step([], list(map(lambda row: row['addresses'], iter)))
  49. master = Master(config)
  50. master.spark.catalog.clearCache()
  51. master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])
  52. tx_df = master.get_tx_dataframe()
  53. #Turn transactions into a list of ('id', [addr, addr, ...])
  54. tx_grouped = tx_df \
  55. .groupBy('tx_id') \
  56. .agg(F.collect_set('address').alias('addresses')) \
  57. .orderBy('tx_id') \
  58. res = tx_grouped \
  59. .repartition(5) \
  60. .rdd \
  61. .mapPartitions(cluster_partition) \
  62. .fold([], cluster_step)
  63. for cluster in res:
  64. print()
  65. print(sorted(cluster))
  66. end = time.time()
  67. print("ELAPSED TIME:", end-start)