import json
import time
from typing import Iterable

from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F

# wall-clock timer for the whole job
start = time.time()


# load the job settings from disk
with open("./settings.json") as settings_file:
    config = json.load(settings_file)
debug = config['debug']
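# settings.json is expected to provide at least the keys read in this script
# (the values shown here are illustrative only):
#   {
#       "debug": false,
#       "cassandra_catalog": "cass",
#       "cassandra_keyspace": "ks",
#       "clusters_table_name": "clusters",
#       "tx_table_name": "transactions",
#       "spark_checkpoint_dir": "/tmp/spark-checkpoints"
#   }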


class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.make_spark_session(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def make_spark_session(self, config) -> SparkSession:
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()
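    # NOTE: the Cassandra contact points are assumed to be configured outside
    # this script, e.g. at submit time:
    #   spark-submit --conf spark.cassandra.connection.host=<host> ...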

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)

    def union_single_col(self, df1: DataFrame, df2: DataFrame, column: str) -> DataFrame:
        # stack the given column of both dataframes on top of each other
        return df1 \
            .select(column) \
            .union(df2.select(column))

    def reduce_concat_array_column(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
        # flatten all arrays in 'column', then collect them back into one array
        df = self.explode_array_col(df.select(column), column)
        return self.collect_col_to_array(df, column, distinct)

    def collect_col_to_array(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
        if distinct:
            return df.select(F.collect_set(column).alias(column))
        return df.select(F.collect_list(column).alias(column))
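    # e.g. collect_col_to_array turns the rows ('a',), ('b',), ('a',) into a
    # single row holding ['a', 'b'] when distinct=True (element order is not
    # guaranteed), or ['a', 'b', 'a'] otherwise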

    def explode_array_col(self, df: DataFrame, column: str) -> DataFrame:
        return df \
            .rdd \
            .flatMap(lambda row: [(elem,) for elem in row[column]]) \
            .toDF([column])
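    # note: df.select(F.explode(column).alias(column)) would be a
    # DataFrame-native equivalent that avoids the round-trip through the RDD API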
# end class Master


def cluster_id_addresses_rows(rows: "Iterable[Row]") -> Iterable:
    # placeholder partition handler; currently passes rows through unchanged
    return rows


master = Master(config)
master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])
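# checkpoint files are written below this directory; when running on a cluster
# it should point at storage visible to all executors (e.g. HDFS)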
tx_df = master.get_tx_dataframe()

# turn transactions into rows of shape ('tx_id', [addr, addr, ...])
tx_grouped = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses'))
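# e.g. Row(tx_id=1, addresses=['addr_a', 'addr_b'])  (values illustrative)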

# NOTE: mapPartitions is lazy and its result is discarded, so this call is
# currently a no-op placeholder
tx_grouped.rdd.mapPartitions(cluster_id_addresses_rows)

# TODO: load clusters from the DB; if none exist, build the initial cluster,
# otherwise proceed with the loaded data

# find the initial cluster

# take the first tx
tx_zero = tx_grouped.limit(1)

# find txs with overlapping addresses; the join has no condition, so make the
# cartesian product explicit with crossJoin
overlapping_txs = tx_grouped \
    .crossJoin(
        tx_zero \
            .withColumnRenamed('addresses', 'tx_addresses') \
            .withColumnRenamed('tx_id', 'overlap_id')
    ) \
    .select(
        tx_grouped.tx_id,
        tx_grouped.addresses,
        F.arrays_overlap(tx_grouped.addresses, 'tx_addresses').alias('overlap')
    ) \
    .where(F.col('overlap')) \
    .drop('overlap')
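# arrays_overlap is true when the two arrays share at least one non-null
# element, so this keeps every tx (including tx_zero itself) that shares an
# address with tx_zero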

# the overlapping txs must not be considered again, so remove them from the
# candidate dataframe (a left anti join keeps only the txs whose tx_id is
# absent from overlapping_txs)
tx_grouped = tx_grouped \
    .join(
        overlapping_txs.drop('addresses'),
        'tx_id',
        'leftanti'
    )

# get the distinct addresses of all overlaps in a single array
distinct_addresses = master.reduce_concat_array_column(
    master.union_single_col(
        overlapping_txs,
        tx_zero,
        column='addresses'
    ),
    column='addresses',
    distinct=True,
)

# pick an arbitrary representative (the first address) for this cluster and
# pair it with every address
cluster = distinct_addresses \
    .rdd \
    .flatMap(lambda row: [(addr, row['addresses'][0]) for addr in row['addresses']]) \
    .toDF(['address', 'id'])
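# e.g. [Row(address='addr_a', id='addr_a'), Row(address='addr_b', id='addr_a')]
# (values illustrative)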

# done finding the initial cluster

# group the cluster by representative and transform the result into rows of
# shape ('id', ['addr', 'addr', ...])
clusters_grouped = cluster \
    .groupBy('id') \
    .agg(F.collect_list('address').alias('addresses'))
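# the 'debug' flag from settings.json is otherwise unused; one possible use
# (an assumption, not part of the original flow) is to inspect the seed cluster:
if debug:
    clusters_grouped.show(truncate=False)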


def take_tx_and_cluster(txs: DataFrame, clusters: DataFrame, n=0) -> DataFrame:
    # loop until every tx has been merged into a cluster; iterating instead of
    # recursing keeps large inputs clear of Python's recursion limit
    while txs.count() > 0:
        # take an arbitrary tx
        tx = txs.limit(1)

        # find clusters with addresses overlapping the tx's addresses
        # (crossJoin: the join is deliberately unconditioned)
        overlapping_clusters = clusters \
            .crossJoin(tx.withColumnRenamed('addresses', 'tx_addresses')) \
            .select(
                clusters.id,
                clusters.addresses,
                'tx_addresses',
                F.arrays_overlap(clusters.addresses, 'tx_addresses').alias('overlap')
            ) \
            .where(F.col('overlap'))

        clusters_union_tx = master.union_single_col(tx, overlapping_clusters, 'addresses')

        # collect all addresses into a single array field
        new_cluster_arrays = master.reduce_concat_array_column(
            clusters_union_tx,
            column='addresses',
            distinct=True
        )

        # declare the first address the cluster representative, then group
        # back into ('id', ['addr', 'addr', ...]) shape
        new_cluster = new_cluster_arrays \
            .rdd \
            .flatMap(lambda row: [(addr, row['addresses'][0]) for addr in row['addresses']]) \
            .toDF(['address', 'id']) \
            .groupBy('id') \
            .agg(F.collect_list('address').alias('addresses'))

        # drop the consumed tx; replace the merged clusters with the new one
        txs = txs.join(tx, 'tx_id', 'leftanti')
        clusters = clusters.join(overlapping_clusters, 'id', 'leftanti').union(new_cluster)

        # the lineage (Spark's internal history of transformations) grows with
        # every round; checkpoint regularly to prune it
        if n % 3 == 0:
            txs = txs.checkpoint()
            clusters = clusters.checkpoint()

        # the next round uses txs minus the one just consumed, plus the
        # updated clusters
        n += 1

    return clusters


result = take_tx_and_cluster(tx_grouped, clusters_grouped).collect()
for row in result:
    print(sorted(row['addresses']))

end = time.time()
print("ELAPSED TIME:", end - start)
|