"""Address-clustering job.

Groups transaction addresses into clusters (multi-input heuristic):
for each transaction, find every existing cluster that shares an
address with it, merge those clusters under one root, and file the
transaction's addresses under that root. State lives in Cassandra;
all heavy lifting runs through Spark.

Reads runtime settings from ./settings.json.
"""

from gc import collect  # NOTE(review): appears unused here — confirm before removing
import json
import time
from sqlite3 import Row  # NOTE(review): shadowed by the pyspark Row import below
from typing import Iterable, List, Sequence

from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql import functions as F

start = time.time()

# Load job configuration; `with` closes the handle deterministically
# (the original leaked an open file object).
with open("./settings.json") as settings_file:
    config = json.load(settings_file)
debug = config['debug']


class Master:
    """Owns the Spark session and all reads/writes against Cassandra."""

    spark: SparkSession
    CLUSTERS_TABLE: str  # fully-qualified clusters table name (catalog.keyspace.table)
    TX_TABLE: str        # fully-qualified transactions table name

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        """Build (or reuse) a SparkSession wired to the Cassandra catalog."""
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}",
                    "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def group_tx_addrs(self) -> DataFrame:
        """One row per transaction: (tx_id, distinct addresses in that tx)."""
        return self.spark \
            .read \
            .table(self.TX_TABLE) \
            .groupBy("tx_id") \
            .agg(F.collect_set('address').alias('addresses'))

    def group_cluster_addrs(self) -> DataFrame:
        """One row per known cluster: (id, distinct member addresses)."""
        return self.spark \
            .read \
            .table(self.CLUSTERS_TABLE) \
            .groupBy("id") \
            .agg(F.collect_set('address').alias('addresses'))

    def insertNewCluster(self, addrs: Sequence[str], root: str | None = None) -> str:
        """Append (address, id) rows for *addrs* to the clusters table.

        When *root* is None the first address becomes the cluster root.
        Returns the root actually used.
        """
        if root is None:  # identity check, not ==
            root = addrs[0]
        # Materialize: some PySpark versions reject a bare iterator here.
        rows = [(addr, root) for addr in addrs]
        df = self.spark.createDataFrame(rows, schema=['address', 'id'])
        df.writeTo(self.CLUSTERS_TABLE).append()
        return root

    def enumerate(self, data: DataFrame) -> DataFrame:
        """Attach a stable 0-based index to each row: (tx_group, index).

        NOTE(review): shadows the builtin `enumerate`; name kept for
        caller compatibility.
        """
        return data \
            .rdd \
            .zipWithIndex() \
            .toDF(["tx_group", "index"])

    def rewrite_cluster_id(self, cluster_roots: Iterable[str], new_cluster_root: str) -> None:
        """Re-parent every address belonging to *cluster_roots* onto
        *new_cluster_root* (cluster merge step).

        Appends overriding rows; assumes the table's primary key makes
        the new (address, id) row win — TODO confirm against schema.
        """
        cluster_rewrite = self.spark \
            .table(self.CLUSTERS_TABLE) \
            .where(F.col('id').isin(cluster_roots)) \
            .select('address') \
            .rdd \
            .map(lambda addr: (addr['address'], new_cluster_root)) \
            .toDF(['address', 'id'])
        if debug:
            print("REWRITE JOB")
            cluster_rewrite.show(truncate=False, vertical=True)
            print()
        cluster_rewrite.writeTo(self.CLUSTERS_TABLE).append()
# end class Master


def find(data: tuple[Row, Iterable[str]]) -> str | None:
    """Return the cluster id if the transaction touches the cluster, else None.

    data[0]: Row(id=addr, addresses=list[addr]) — the cluster
    data[1]: list[addr] — the transaction's addresses

    A match is any overlap between the transaction's addresses and the
    cluster's member addresses (the root id itself included).
    """
    cluster, tx = data
    cluster_addresses = cluster['addresses'] + [cluster['id']]
    return cluster['id'] if any(x in tx for x in cluster_addresses) else None


master = Master(config)

tx_addr_groups = master.group_tx_addrs()
# Cache: the indexed view is re-filtered once per loop iteration below.
tx_groups_indexed = master.enumerate(tx_addr_groups).cache()

# Process transactions one at a time so each iteration sees the clusters
# written by the previous ones.
for i in range(0, tx_addr_groups.count()):
    cluster_addr_groups = master.group_cluster_addrs()

    if debug:
        print("KNOWN CLUSTERS")
        cluster_addr_groups.show(truncate=True)
        print()

    tx_addrs: Iterable[str] = tx_groups_indexed \
        .where(tx_groups_indexed.index == i) \
        .select('tx_group') \
        .collect()[0]['tx_group']['addresses']

    if debug:
        print("CURRENT TX")
        print(tx_addrs)
        print()

    # No clusters yet: this transaction seeds the first one.
    if cluster_addr_groups.count() == 0:
        master.insertNewCluster(tx_addrs)
        continue

    # Pair every known cluster with the current transaction's addresses.
    cluster_tx_mapping = cluster_addr_groups \
        .rdd \
        .map(lambda cluster: (cluster, tx_addrs))

    if debug:
        print("cluster_tx_mapping")
        cluster_tx_mapping \
            .toDF(['cluster', 'tx']) \
            .show(truncate=True)
        print()

    matched_roots: List[str] = cluster_tx_mapping \
        .map(find) \
        .filter(lambda root: root is not None) \
        .collect()

    if debug:
        print("FOUND ROOTS")
        print(matched_roots)
        print()

    if len(matched_roots) == 0:
        # Touches no existing cluster: start a new one.
        master.insertNewCluster(tx_addrs)
    elif len(matched_roots) == 1:
        # Touches exactly one cluster: join it.
        master.insertNewCluster(tx_addrs, matched_roots[0])
    else:
        # Bridges several clusters: merge them under the first root,
        # then file the transaction's addresses there.
        master.rewrite_cluster_id(matched_roots[1:], matched_roots[0])
        master.insertNewCluster(tx_addrs, matched_roots[0])

    if debug:
        print("======================================================================")

end = time.time()
print("ELAPSED TIME:", end - start)