from typing import List
import json
import time

from graphframes import GraphFrame
from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F

start = time.time()

with open("./settings.json") as f:
    config = json.load(f)
debug = config['debug']  # currently unused


class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        # Register the Cassandra catalog so its tables can be read with spark.table().
        return SparkSession.builder \
            .appName('DistributedUnionFindWithGraphs') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}",
                    "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)

    def get_cluster_dataframe(self) -> DataFrame:
        return self.spark.table(self.CLUSTERS_TABLE)

    def write_connected_components_as_clusters(self, conn_comp: DataFrame) -> None:
        # Persist each (component, address) pair, using the component id as the cluster id.
        conn_comp \
            .withColumnRenamed('id', 'address') \
            .withColumnRenamed('component', 'id') \
            .writeTo(self.CLUSTERS_TABLE) \
            .append()
# end class Master


master = Master(config)
# connectedComponents() requires a checkpoint directory to be set.
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])

tx_df = master.get_tx_dataframe()

# Every distinct address becomes a vertex.
addresses_as_vertices = tx_df \
    .select('address') \
    .withColumnRenamed('address', 'id') \
    .distinct()


def explode_row(row: Row) -> List[tuple]:
    # Star pattern: connect every address in a transaction to the transaction's
    # first address. This yields n-1 edges per transaction instead of n*(n-1)/2,
    # which is sufficient for computing connected components.
    addresses = row['addresses']
    return [(addr, addresses[0]) for addr in addresses[1:]]


# Addresses that co-occur in a transaction are linked by edges.
transactions_as_edges = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses')) \
    .rdd \
    .flatMap(explode_row) \
    .toDF(['src', 'dst'])

g = GraphFrame(addresses_as_vertices, transactions_as_edges)
components = g.connectedComponents(algorithm='graphframes')

#master.write_connected_components_as_clusters(components)

clusters = components \
    .groupBy('component') \
    .agg(F.collect_list('id')) \
    .collect()

#print(len(clusters))
#for cluster in clusters:
#    print(sorted(cluster['collect_list(id)']))

end = time.time()
print(end - start, end='')
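
# A minimal settings.json this script expects. The key names come from the
# config lookups above; the values below are hypothetical placeholders:
#
# {
#     "debug": false,
#     "cassandra_catalog": "cassandra",
#     "cassandra_keyspace": "clustering",
#     "clusters_table_name": "clusters",
#     "tx_table_name": "transactions",
#     "spark_checkpoint_dir": "/tmp/spark-checkpoints"
# }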
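
# A script like this is typically launched through spark-submit with the
# GraphFrames and Cassandra connector packages on the classpath. The package
# coordinates, versions, and file name below are assumptions and must match
# the cluster's Spark/Scala build:
#
# spark-submit \
#     --packages graphframes:graphframes:0.8.2-spark3.2-s_2.12,com.datastax.spark:spark-cassandra-connector_2.12:3.2.0 \
#     path/to/this_script.py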