# Distributed connected-components clustering of addresses via GraphFrames,
# reading transactions from and writing clusters to Cassandra.
- from typing import List
- from graphframes import GraphFrame
- import json
- from pyspark.sql import SparkSession, DataFrame, Row
- from pyspark.sql import functions as F
-
- import time
- start = time.time()
-
-
# Load runtime settings. Use a context manager so the file handle is
# closed deterministically instead of leaking until garbage collection
# (the original `json.load(open(...))` never closed the file).
with open("./settings.json") as settings_file:
    config = json.load(settings_file)
debug = config['debug']
-
class Master:
    """Owner of the Spark session and the Cassandra-backed table names.

    Bundles session construction with the small set of table reads and
    writes that the clustering script performs.
    """

    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        # Fully-qualified Cassandra table names: <catalog>.<keyspace>.<table>
        qualified = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}"
        self.CLUSTERS_TABLE = f"{qualified}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{qualified}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        """Build (or reuse) a SparkSession wired to the Cassandra catalog."""
        builder = SparkSession.builder.appName('DistributedUnionFindWithGraphs')
        catalog_key = f"spark.sql.catalog.{config['cassandra_catalog']}"
        builder = builder.config(catalog_key, "com.datastax.spark.connector.datasource.CassandraCatalog")
        return builder.getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        """Return the transactions table as a DataFrame."""
        return self.spark.table(self.TX_TABLE)

    def get_cluster_dataframe(self) -> DataFrame:
        """Return the clusters table as a DataFrame."""
        return self.spark.table(self.CLUSTERS_TABLE)

    def write_connected_components_as_clusters(self, conn_comp: DataFrame) -> None:
        """Append connected-component labels to the clusters table.

        Renames the GraphFrames output columns (id -> address,
        component -> id) to match the clusters table schema before
        appending.
        """
        renamed = conn_comp \
            .withColumnRenamed('id', 'address') \
            .withColumnRenamed('component', 'id')
        renamed.writeTo(self.CLUSTERS_TABLE).append()

# end class Master
-
master = Master(config)
# Spark is really adamant it needs a checkpoint dir even if the
# connectedComponents algorithm is set to the non-checkpointed version.
master.spark.sparkContext.setCheckpointDir('./checkpoints')

tx_df = master.get_tx_dataframe()

# Vertices: one row per distinct address, renamed to 'id' —
# the column name GraphFrames requires for its vertex DataFrame.
transaction_as_vertices = tx_df \
    .select('address') \
    .withColumnRenamed('address', 'id') \
    .distinct()
-
def explode_row(row: "Row") -> List[tuple]:
    """Expand one grouped transaction row into star-shaped edge tuples.

    `row['addresses']` holds the set of addresses seen in one transaction.
    The first address acts as a hub: every other address is paired with it,
    producing (src, dst) 2-tuples for the edge DataFrame. A row with zero
    or one address yields no edges.

    Note: the original annotation claimed List[Row], but the function
    returns plain 2-tuples (which `toDF` accepts) — the annotation is
    corrected here.
    """
    addresses = row['addresses']
    # addresses[0] is only evaluated when there is at least one other
    # element, so an empty list safely yields [].
    return [(addr, addresses[0]) for addr in addresses[1:]]
-
# Edges: collect each transaction's distinct addresses, then link every
# address to the transaction's first address (a star topology). That is
# sufficient for connected components — full pairwise edges are not needed.
transactions_as_edges = tx_df \
    .groupBy("tx_id") \
    .agg(F.collect_set('address').alias('addresses')) \
    .rdd \
    .flatMap(explode_row) \
    .toDF(['src', 'dst'])

g = GraphFrame(transaction_as_vertices, transactions_as_edges)
# algorithm='graphframes' selects the GraphFrames-native implementation
# (as opposed to the GraphX-backed one).
components = g.connectedComponents(algorithm='graphframes')

master.write_connected_components_as_clusters(components)
-
if debug:  # idiomatic `if debug:` instead of `if(debug):`
    # Alias the aggregate explicitly: the original indexed the row by the
    # auto-generated column name 'collect_list(id)', which is fragile and
    # can differ across Spark versions.
    clusters = components \
        .groupBy('component') \
        .agg(F.collect_list('id').alias('addresses')) \
        .collect()

    for cluster in clusters:
        print(sorted(cluster['addresses']))

end = time.time()
print("ELAPSED TIME:", end - start)