
main_partition.py

import json
from typing import Iterable, List, Set
from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F
import time

start = time.time()

config = json.load(open("./settings.json"))
debug = config['debug']
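
# Illustrative settings.json (an assumption: only the keys below are actually read by
# this script; the example values are invented):
# {
#     "debug": false,
#     "cassandra_catalog": "cassandra",
#     "cassandra_keyspace": "my_keyspace",
#     "clusters_table_name": "clusters",
#     "tx_table_name": "transactions",
#     "spark_checkpoint_dir": "/tmp/spark_checkpoints"
# }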

class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)
# end class Master


def merge_lists_distinct(*lists: "Iterable[List[str]]") -> List[str]:
    # Union all input lists into one list of unique addresses.
    accum = set()
    for lst in lists:
        accum = accum.union(set(lst))
    return list(accum)


def check_lists_overlap(list1, list2):
    return any(x in list1 for x in list2)


def cluster_step(clusters: "List[List[str]]", addresses: "List[List[str]]"):
    # if there are no more sets of addresses to consider, we are done
    if len(addresses) == 0:
        return clusters
    # Merge the first address set with every existing cluster it overlaps,
    # keep the non-overlapping clusters, then recurse on the remaining sets.
    tx = addresses[0]
    matching_clusters = []
    new_clusters = []
    for cluster in clusters:
        if check_lists_overlap(tx, cluster):
            matching_clusters.append(cluster)
        else:
            new_clusters.append(cluster)
    new_clusters.append(merge_lists_distinct(tx, *matching_clusters))
    return cluster_step(new_clusters, addresses[1:])


def cluster_partition(iter: "Iterable[Row]") -> Iterable:
    # Cluster all address sets within one partition and yield a single local result.
    yield cluster_step([], list(map(lambda row: row['addresses'], iter)))
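
# Worked example of the clustering step (illustrative input; the ordering of addresses
# inside each resulting cluster is not deterministic because merge_lists_distinct goes
# through a set):
#
#     cluster_step([], [['a', 'b'], ['b', 'c'], ['x', 'y']])
#
# ['a', 'b'] and ['b', 'c'] share the address 'b', so they collapse into one cluster,
# while ['x', 'y'] stays on its own -- the result is [['a', 'b', 'c'], ['x', 'y']]
# up to ordering.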

master = Master(config)
master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])

tx_df = master.get_tx_dataframe()

# Turn transactions into a list of ('id', [addr, addr, ...])
tx_grouped = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses')) \
    .orderBy('tx_id')

# Cluster each partition locally, then fold the per-partition results into one
# global clustering.
res = tx_grouped \
    .repartition(5) \
    .rdd \
    .mapPartitions(cluster_partition) \
    .fold([], cluster_step)

for cluster in res:
    print()
    print(sorted(cluster))

end = time.time()
print("ELAPSED TIME:", end - start)