import json
import time
from typing import Iterable, List

from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F

start = time.time()

# Runtime settings: Cassandra catalog/keyspace/table names, checkpoint dir, debug flag, ...
with open("./settings.json") as settings_file:
    config = json.load(settings_file)
debug = config['debug']

class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        # Register the Cassandra catalog; connection host/credentials are assumed
        # to be supplied elsewhere (e.g. spark-defaults.conf or spark-submit --conf).
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)

# end class Master

def merge_lists_distinct(*lists: Iterable[str]) -> List[str]:
    """Union any number of address lists into one list without duplicates."""
    accum = set()
    for lst in lists:
        accum.update(lst)
    return list(accum)


def check_lists_overlap(list1: List[str], list2: List[str]) -> bool:
    """True if the two lists share at least one element."""
    return any(x in list1 for x in list2)

def cluster_step(clusters: List[List[str]], addresses: List[List[str]]) -> List[List[str]]:
    """Fold each set of addresses into the running list of clusters.

    Iterative rather than recursive, so a long list of address sets does not
    hit Python's recursion limit.
    """
    for tx in addresses:
        # find existing clusters that share an address with this set
        matching_clusters = [cluster for cluster in clusters if check_lists_overlap(tx, cluster)]

        # drop those clusters from the running list
        clusters = [cluster for cluster in clusters if not check_lists_overlap(tx, cluster)]

        # replace them with a single cluster: the union of the matches and the new addresses
        clusters.append(merge_lists_distinct(tx, *matching_clusters))

    return clusters
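
# Illustrative sanity check (an addition, not part of the original pipeline):
# when the debug flag is set in settings.json, exercise the pure-Python helpers
# on a tiny made-up input and verify that overlapping address sets are merged.
if debug:
    _sample = [["a", "b"], ["b", "c"], ["d"]]
    _clusters = cluster_step([], _sample)
    assert {frozenset(c) for c in _clusters} == {frozenset({"a", "b", "c"}), frozenset({"d"})}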


def cluster_id_addresses_rows(rows: "Iterable[Row]") -> Iterable[List[List[str]]]:
    """Cluster all address lists found in one partition of grouped transactions."""
    address_lists = [row['addresses'] for row in rows]
    yield cluster_step([], address_lists)


def dud(rows: "Iterable[Row]"):
    """Debug variant: collect a partition's address lists without clustering them."""
    address_lists = [row['addresses'] for row in rows]
    yield address_lists

master = Master(config)
master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])
tx_df = master.get_tx_dataframe()

# Turn transactions into one row of ('tx_id', [addr, addr, ...]) per transaction
tx_grouped = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses')) \
    .orderBy('tx_id')

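# Debug preview (an addition): peek at the grouped transactions before clustering.
# DataFrame.show() is standard Spark; the 5-row limit is arbitrary.
if debug:
    tx_grouped.show(5, truncate=False)
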
print()
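# Note (an addition): mapPartitions clusters each partition locally, and fold then
# applies the same cluster_step routine to merge the per-partition cluster lists
# on the driver, so the final clusters should not depend on the repartition count.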
res = tx_grouped \
    .repartition(5) \
    .rdd \
    .mapPartitions(cluster_id_addresses_rows) \
    .fold([], cluster_step)

for cluster in res:
    print(sorted(cluster))

end = time.time()
print("ELAPSED TIME:", end - start)