- import json
- from typing import Iterable, List, Set
-
- from pyspark.sql import SparkSession, DataFrame, Row
- from pyspark.sql import functions as F
-
import time

# Wall-clock start; reported as "ELAPSED TIME" at the end of the script.
start = time.time()


# Load runtime configuration. A context manager closes the file handle
# deterministically (the original `json.load(open(...))` leaked it).
with open("./settings.json") as settings_file:
    config = json.load(settings_file)
debug = config['debug']
-
-
class Master:
    """Owns the SparkSession and the fully-qualified Cassandra table names.

    Table names are built as ``<catalog>.<keyspace>.<table>`` from the
    supplied config dict.
    """

    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        catalog = config['cassandra_catalog']
        keyspace = config['cassandra_keyspace']
        self.CLUSTERS_TABLE = f"{catalog}.{keyspace}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{catalog}.{keyspace}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        """Create (or reuse) a SparkSession wired to the Cassandra catalog."""
        catalog_property = f"spark.sql.catalog.{config['cassandra_catalog']}"
        builder = SparkSession.builder.appName('SparkCassandraApp')
        builder = builder.config(catalog_property, "com.datastax.spark.connector.datasource.CassandraCatalog")
        return builder.getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        """Return the transactions table as a Spark DataFrame."""
        return self.spark.table(self.TX_TABLE)

# end class Master
-
def merge_lists_distinct(*lists: "Iterable[List[str]]") -> List[str]:
    """Union any number of lists into one list of distinct elements.

    Element order of the result is unspecified (set semantics). Returns an
    empty list when called with no arguments.
    """
    accum: Set[str] = set()
    for lst in lists:
        # update() mutates in place; the original `accum.union(set(lst))`
        # allocated two fresh sets per input list.
        accum.update(lst)
    return list(accum)
-
def check_lists_overlap(list1, list2):
    """Return True iff list1 and list2 share at least one element.

    Uses set.isdisjoint for an O(len(list1) + len(list2)) membership check
    instead of the original quadratic `any(x in list1 ...)` scan. Elements
    are address strings here, so they are hashable.
    """
    return not set(list1).isdisjoint(list2)
-
def cluster_step(clusters: "List[List[str]]", addresses: "List[List[str]]"):
    """Fold each address list in `addresses` into `clusters`.

    For every transaction's address list, all existing clusters that share an
    address with it are merged (together with the transaction itself) into a
    single cluster; non-overlapping clusters pass through unchanged.

    Iterative rewrite of the original tail recursion: CPython does not
    optimize tail calls, so one recursive frame per transaction could hit the
    recursion limit on large partitions, and `addresses[1:]` copied the
    remaining list on every step (O(n^2) overall).
    """
    for tx in addresses:
        matching_clusters = []
        new_clusters = []
        for cluster in clusters:
            if check_lists_overlap(tx, cluster):
                matching_clusters.append(cluster)
            else:
                new_clusters.append(cluster)
        # Merge the transaction with every cluster it touched.
        new_clusters.append(merge_lists_distinct(tx, *matching_clusters))
        clusters = new_clusters
    # When `addresses` is empty the input clusters are returned unchanged,
    # matching the original base case.
    return clusters
-
def cluster_partition(rows: "Iterable[Row]") -> Iterable:
    """Cluster all transactions within one Spark partition.

    Yields exactly one element: the list of address clusters built from this
    partition's rows (each row must carry an 'addresses' column).

    The parameter was renamed from `iter`, which shadowed the builtin;
    mapPartitions passes the iterator positionally, so callers are unaffected.
    """
    yield cluster_step([], [row['addresses'] for row in rows])
-
# --- Driver ---------------------------------------------------------------
master = Master(config)
master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])

tx_df = master.get_tx_dataframe()

# Turn transactions into a list of ('id', [addr, addr, ...])
tx_grouped = (
    tx_df
    .groupBy('tx_id')
    .agg(F.collect_set('address').alias('addresses'))
    .orderBy('tx_id')
)

# Cluster each partition independently, then fold the per-partition results
# into a single global clustering.
res = (
    tx_grouped
    .repartition(5)
    .rdd
    .mapPartitions(cluster_partition)
    .fold([], cluster_step)
)

for cluster in res:
    print()
    print(sorted(cluster))

end = time.time()
print("ELAPSED TIME:", end-start)