import json
import time
from typing import Iterable, List

from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql import functions as F

start = time.time()

with open("./settings.json") as f:
    config = json.load(f)
debug = config['debug']


class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        # Register the Cassandra catalog so its tables can be read via spark.table().
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}",
                    "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)
# end class Master


def merge_lists_distinct(*lists: Iterable[str]) -> List[str]:
    """Union any number of lists into a single list of distinct elements."""
    accum = set()
    for lst in lists:
        accum = accum.union(set(lst))
    return list(accum)


def check_lists_overlap(list1, list2) -> bool:
    """True if the two lists share at least one element."""
    return any(x in list1 for x in list2)


def cluster_step(clusters: List[List[str]], addresses: List[List[str]]) -> List[List[str]]:
    """Fold the address sets into `clusters`: each address set is merged with every
    existing cluster it overlaps; otherwise it starts a new cluster."""
    # If there are no more sets of addresses to consider, we are done.
    if len(addresses) == 0:
        return clusters

    tx = addresses[0]
    matching_clusters = []
    new_clusters = []

    for cluster in clusters:
        if check_lists_overlap(tx, cluster):
            matching_clusters.append(cluster)
        else:
            new_clusters.append(cluster)

    new_clusters.append(merge_lists_distinct(tx, *matching_clusters))

    return cluster_step(new_clusters, addresses[1:])


def cluster_partition(rows: Iterable[Row]) -> Iterable[List[List[str]]]:
    # Cluster all address sets of one partition locally; yield a single element
    # (the partition's cluster list) so fold() can merge partitions afterwards.
    yield cluster_step([], [row['addresses'] for row in rows])


master = Master(config)
master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])

tx_df = master.get_tx_dataframe()

# Turn transactions into a list of ('tx_id', [addr, addr, ...]).
tx_grouped = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses')) \
    .orderBy('tx_id')

# Cluster each partition independently, then fold the per-partition results
# together; cluster_step doubles as the combine function because a cluster is
# itself just a list of addresses.
res = tx_grouped \
    .repartition(5) \
    .rdd \
    .mapPartitions(cluster_partition) \
    .fold([], cluster_step)

for cluster in res:
    print()
    print(sorted(cluster))

end = time.time()
print("ELAPSED TIME:", end - start)
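
# For reference, a sketch of the settings this script reads from ./settings.json.
# Only the key names are taken from the code above; the values shown here are
# placeholder assumptions, not the project's actual configuration.
#
# {
#     "debug": false,
#     "cassandra_catalog": "cassandra",
#     "cassandra_keyspace": "my_keyspace",
#     "clusters_table_name": "clusters",
#     "tx_table_name": "transactions",
#     "spark_checkpoint_dir": "/tmp/spark_checkpoints"
# }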