import json
import time
from typing import Iterable, List, Set

from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F

start = time.time()

config = json.load(open("./settings.json"))
debug = config['debug']


class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}",
                    "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)
# end class Master


def merge_lists_distinct(*lists: "Iterable[List[str]]") -> List[str]:
    # union an arbitrary number of address lists into one deduplicated list
    accum: Set[str] = set()
    for lst in lists:
        accum = accum.union(set(lst))
    return list(accum)


def check_lists_overlap(list1, list2):
    # True if the two lists share at least one element
    return any(x in list1 for x in list2)


def cluster_step(clusters: "List[List[str]]", addresses: "List[List[str]]"):
    # if there are no more sets of addresses to consider, we are done
    if len(addresses) == 0:
        return clusters

    # take one set of addresses and remove it from the list of candidates
    tx = addresses[0]
    addresses = addresses[1:]

    # find the clusters that overlap with these addresses
    matching_clusters = [c for c in clusters if check_lists_overlap(tx, c)]
    # keep only the clusters that do not overlap
    clusters = [c for c in clusters if not check_lists_overlap(tx, c)]

    # add a new cluster that is the union of the matching clusters and the inspected addresses
    clusters.append(merge_lists_distinct(tx, *matching_clusters))

    # recurse on the remaining address sets (one call per transaction)
    return cluster_step(clusters, addresses)


def cluster_id_addresses_rows(iter: "Iterable[Row]") -> Iterable:
    # cluster all address lists of one partition locally
    address_lists = [row['addresses'] for row in iter]
    yield cluster_step([], address_lists)


def dud(iter):
    # debugging helper: pass the address lists through without clustering
    address_lists = [row['addresses'] for row in iter]
    yield address_lists


master = Master(config)
master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])

tx_df = master.get_tx_dataframe()

# Turn transactions into a list of ('id', [addr, addr, ...])
tx_grouped = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses')) \
    .orderBy('tx_id')

print()

# Cluster each partition locally, then merge the partial clusterings on the driver
res = tx_grouped \
    .repartition(5) \
    .rdd \
    .mapPartitions(cluster_id_addresses_rows) \
    .fold([], cluster_step)

for cluster in res:
    print(sorted(cluster))

end = time.time()
print("ELAPSED TIME:", end - start)
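
# Illustration: a minimal sketch of what cluster_step computes, using the
# made-up toy address sets in _toy_txs below (hypothetical data, not read
# from Cassandra). Address sets that share at least one address are merged
# transitively into the same cluster. Guarded by the 'debug' flag from
# settings.json so it only runs when debugging.
if debug:
    _toy_txs = [['a', 'b'], ['b', 'c'], ['d']]
    # expected: one cluster with 'a', 'b', 'c' and one cluster with 'd'
    for _c in cluster_step([], _toy_txs):
        print("toy cluster:", sorted(_c))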