import json
import time
from typing import Iterable

from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F

start = time.time()

config = json.load(open("./settings.json"))
debug = config['debug']


class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)

    def union_single_col(self, df1: DataFrame, df2: DataFrame, column: str) -> DataFrame:
        # stack the given column of both dataframes into one single-column dataframe
        return df1 \
            .select(column) \
            .union(df2.select(column))

    def reduce_concat_array_column(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
        # flatten all array values of `column` and collect them back into a single array row
        df = self.explode_array_col(df.select(column), column)
        return self.collect_col_to_array(df, column, distinct)

    def collect_col_to_array(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
        if distinct:
            return df.select(F.collect_set(column).alias(column))
        else:
            return df.select(F.collect_list(column).alias(column))

    def explode_array_col(self, df: DataFrame, column: str) -> DataFrame:
        # turn each array element into its own single-column row
        return df \
            .rdd \
            .flatMap(lambda row: list(map(lambda elem: (elem,), row[column]))) \
            .toDF([column])
# end class Master


def cluster_id_addresses_rows(iter: "Iterable[Row]") -> Iterable:
    # currently a pass-through; kept as a hook for per-partition processing
    return iter


master = Master(config)
master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])

tx_df = master.get_tx_dataframe()

# Turn transactions into a list of ('tx_id', [addr, addr, ...])
tx_grouped = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses'))

tx_grouped.rdd.mapPartitions(cluster_id_addresses_rows)

# TODO: Load clusters from DB, check if any exist; if none, make the initial cluster, else proceed with the loaded data
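# Sketch for the TODO above, not wired in yet: load previously persisted clusters through the
# catalog already configured in Master and test whether any exist. This assumes the clusters
# table exposes the same ('id', 'addresses') shape used below; the helper names are illustrative.
def load_existing_clusters(master: Master) -> DataFrame:
    return master.spark.table(master.CLUSTERS_TABLE)


def has_existing_clusters(clusters: DataFrame) -> bool:
    # take(1) avoids a full count just to test for emptiness
    return len(clusters.take(1)) > 0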
# find initial cluster
# take the first tx
tx_zero = tx_grouped \
    .select('*') \
    .limit(1)

# find txs with overlapping addresses (cartesian join against the single tx_zero row)
overlapping_txs = tx_grouped \
    .join(
        tx_zero \
            .withColumnRenamed('addresses', 'tx_addresses') \
            .withColumnRenamed('tx_id', 'overlap_id')
    ) \
    .select(
        tx_grouped.tx_id,
        tx_grouped.addresses,
        F.arrays_overlap(tx_grouped.addresses, 'tx_addresses').alias('overlap')
    ) \
    .where(F.col('overlap') == True) \
    .drop('overlap')

# overlapped txs must not be considered anymore, so remove them from the candidate dataframe
tx_grouped = tx_grouped \
    .join(
        overlapping_txs.drop('addresses'),
        'tx_id',
        'leftanti'
    )

# get the distinct addresses of all overlaps in a single array
distinct_addresses = master.reduce_concat_array_column(
    master.union_single_col(
        overlapping_txs,
        tx_zero,
        column='addresses'
    ),
    column='addresses',
    distinct=True,
)

# pick out a representative for this cluster (the first address) and pair it with every address
cluster = distinct_addresses \
    .rdd \
    .flatMap(lambda row: list(map(lambda addr: (addr, row['addresses'][0]), row['addresses']))) \
    .toDF(['address', 'id'])
# done finding initial cluster

# group cluster by representative and transform the result into a list of shape ('id', ['addr', 'addr', ...])
clusters_grouped = cluster \
    .groupBy('id') \
    .agg(F.collect_list('address').alias('addresses'))


def take_tx_and_cluster(txs: DataFrame, clusters: DataFrame, n=0):
    if txs.count() == 0:
        # done!
        return clusters

    # take the next tx
    tx = txs \
        .select('*').limit(1)

    # find clusters with overlapping addresses from tx (cartesian join against the single tx row)
    overlapping_clusters = clusters \
        .join(tx.withColumnRenamed('addresses', 'tx_addresses')) \
        .select(
            clusters.id,
            clusters.addresses,
            'tx_addresses',
            F.arrays_overlap(clusters.addresses, 'tx_addresses').alias('overlap')
        ) \
        .where(F.col('overlap') == True)

    clusters_union_tx = master.union_single_col(tx, overlapping_clusters, 'addresses')

    # collect all addresses into a single array field
    new_cluster_arrays = master.reduce_concat_array_column(
        clusters_union_tx,
        column='addresses',
        distinct=True
    )

    # declare a cluster representative and regroup into ('id', ['addr', 'addr', ...])
    new_cluster = new_cluster_arrays \
        .rdd \
        .flatMap(lambda row: list(map(lambda addr: (addr, row['addresses'][0]), row['addresses']))) \
        .toDF(['address', 'id']) \
        .groupBy('id') \
        .agg(F.collect_list('address').alias('addresses'))

    txs = txs.join(tx, 'tx_id', 'leftanti')
    clusters = clusters.join(overlapping_clusters, 'id', 'leftanti').union(new_cluster)

    # the RDD lineage (internal history of transformations) gets too long as iterations continue,
    # so checkpoint regularly to prune it
    if n % 3 == 0:
        txs = txs.checkpoint()
        clusters = clusters.checkpoint()

    # start a new round with txs minus the one just used, and the updated clusters
    return take_tx_and_cluster(txs, clusters, n + 1)


result = take_tx_and_cluster(tx_grouped, clusters_grouped).collect()

for row in result:
    print(sorted(row['addresses']))

end = time.time()
print("ELAPSED TIME:", end - start)
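
# Sketch only, not called above: persist the final clusters back to the Cassandra clusters table
# through the same catalog. Assumes the table's columns match this DataFrame ('id', 'addresses')
# and that the connector's catalog supports the DataFrameWriterV2 append path.
def save_clusters(clusters: DataFrame) -> None:
    clusters.writeTo(master.CLUSTERS_TABLE).append()

# usage sketch: save_clusters(take_tx_and_cluster(tx_grouped, clusters_grouped))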