# NOTE(review): this line was pasted line-number residue (1..157) from the
# source of this diff; removed so the file content starts with the imports.
- from gc import collect
- import json
-
- from sqlite3 import Row
- from typing import Iterable, List
-
- from pyspark.sql import SparkSession, DataFrame, Row
- from pyspark.sql import functions as F
-
import time

# Wall-clock timestamp at script start; the elapsed time is printed at the
# very end of the run.
start = time.time()
-
-
# Load the runtime configuration. Use a context manager so the file handle
# is closed deterministically instead of being leaked until GC.
with open("./settings.json") as settings_file:
    config = json.load(settings_file)

# Global debug flag: when truthy, intermediate DataFrames/RDDs are printed.
debug = config['debug']
-
class Master:
    """Facade over the Spark session and the two Cassandra-backed tables
    (clusters and transactions) used by the clustering job."""

    spark: SparkSession
    CLUSTERS_TABLE: str   # fully-qualified catalog.keyspace.table for clusters
    TX_TABLE: str         # fully-qualified catalog.keyspace.table for transactions

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        """Create (or reuse) a SparkSession with the Cassandra catalog registered."""
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}",
                    "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def group_tx_addrs(self) -> DataFrame:
        """Return one row per transaction: (tx_id, set of addresses in that tx)."""
        return self.spark \
            .read \
            .table(self.TX_TABLE) \
            .groupBy("tx_id") \
            .agg(F.collect_set('address').alias('addresses'))

    def group_cluster_addrs(self) -> DataFrame:
        """Return one row per known cluster: (id, set of member addresses)."""
        return self.spark \
            .read \
            .table(self.CLUSTERS_TABLE) \
            .groupBy("id") \
            .agg(F.collect_set('address').alias('addresses'))

    def insertNewCluster(self, addrs: Iterable[str], root: str | None = None) -> str:
        """Append (address, root) rows for *addrs* to the clusters table.

        When *root* is None the first address becomes the cluster root.
        Returns the root that was used.
        """
        # Materialize so arbitrary iterables (not just sequences) can be
        # indexed for the default root and then iterated again below.
        addrs = list(addrs)
        if root is None:
            root = addrs[0]
        df = self.spark.createDataFrame(
            ((addr, root) for addr in addrs), schema=['address', 'id'])
        df.writeTo(self.CLUSTERS_TABLE).append()
        return root

    def enumerate(self, data: DataFrame) -> DataFrame:
        """Attach a stable 0-based index to each row of *data*.

        NOTE: shadows the builtin ``enumerate`` inside this class, kept for
        interface compatibility with existing callers.
        """
        return data \
            .rdd \
            .zipWithIndex() \
            .toDF(["tx_group", "index"])

    def rewrite_cluster_id(self, cluster_roots: Iterable[str], new_cluster_root: str) -> None:
        """Re-point every address belonging to any of *cluster_roots* at
        *new_cluster_root* (cluster merge) by appending fresh rows."""
        # NOTE: the original line ended with a stray '\' continuation into a
        # blank line; removed here — the chain ends at .toDF(...).
        cluster_rewrite = self.spark \
            .table(self.CLUSTERS_TABLE) \
            .where(F.col('id').isin(list(cluster_roots))) \
            .select('address') \
            .rdd \
            .map(lambda addr: (addr['address'], new_cluster_root)) \
            .toDF(['address', 'id'])

        if debug:
            print("REWRITE JOB")
            cluster_rewrite.show(truncate=False, vertical=True)
            print()

        cluster_rewrite.writeTo(self.CLUSTERS_TABLE).append()
# end class Master
-
-
- """
- tuple structure:
- Row => Row(id=addr, addresses=list[addr] | the cluster
- Iterable[str] => list[addr] | the transaction addresses
- """
- def find(data: tuple[Row, Iterable[str]]) -> str | None:
- cluster = data[0]
- tx = data[1]
-
- clusteraddresses = cluster['addresses'] + [cluster['id']]
-
- if any(x in tx for x in clusteraddresses):
- return cluster['id']
- else:
- return None
-
master = Master(config)

# All transactions, each with its distinct address set, indexed so we can
# process them one at a time on the driver.
tx_addr_groups = master.group_tx_addrs()
tx_groups_indexed = master.enumerate(tx_addr_groups).cache()

for i in range(tx_addr_groups.count()):
    # Re-read the clusters table each round: previous iterations append rows.
    cluster_addr_groups = master.group_cluster_addrs()

    if debug:
        print("KNOWN CLUSTERS")
        cluster_addr_groups.show(truncate=True)
        print()

    # Pull the i-th transaction's address list to the driver.
    tx_addrs: Iterable[str] = tx_groups_indexed \
        .where(tx_groups_indexed.index == i) \
        .select('tx_group') \
        .collect()[0]['tx_group']['addresses']

    if debug:
        print("CURRENT TX")
        print(tx_addrs)
        print()

    # No clusters known yet: this tx founds the first one.
    if cluster_addr_groups.count() == 0:
        master.insertNewCluster(tx_addrs)
        continue

    # Pair every known cluster with the current tx's addresses so `find`
    # can score each cluster independently on the executors.
    cluster_tx_mapping = cluster_addr_groups \
        .rdd \
        .map(lambda cluster: (cluster, tx_addrs))

    if debug:
        print("cluster_tx_mapping")
        cluster_tx_mapping \
            .toDF(['cluster', 'tx']) \
            .show(truncate=True)
        print()

    # Roots of every existing cluster that shares an address with this tx.
    matched_roots: List[str] = cluster_tx_mapping \
        .map(find) \
        .filter(lambda root: root is not None) \
        .collect()

    if debug:
        print("FOUND ROOTS")
        print(matched_roots)
        print()

    if not matched_roots:
        # Disjoint from every cluster: start a new one.
        master.insertNewCluster(tx_addrs)
    elif len(matched_roots) == 1:
        # Extends exactly one cluster: join it.
        master.insertNewCluster(tx_addrs, matched_roots[0])
    else:
        # Bridges several clusters: merge them all under the first root,
        # then add this tx's addresses to the merged cluster.
        master.rewrite_cluster_id(matched_roots[1:], matched_roots[0])
        master.insertNewCluster(tx_addrs, matched_roots[0])

    if debug:
        print("======================================================================")

end = time.time()
print("ELAPSED TIME:", end - start)
|