import json
import time
from typing import List, Tuple

from graphframes import GraphFrame
from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F

start = time.time()


# Load runtime settings; use a context manager so the file handle is closed.
with open("./settings.json") as settings_file:
    config = json.load(settings_file)
debug = config['debug']
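# Keys read from settings.json throughout this script. A minimal example
# file (the key names come from the code; the values are illustrative only):
#
#   {
#       "debug": false,
#       "cassandra_catalog": "cassandra",
#       "cassandra_keyspace": "blockchain",
#       "clusters_table_name": "clusters",
#       "tx_table_name": "transactions"
#   }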


class Master:
    """Owns the SparkSession and the fully qualified Cassandra table names."""

    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.make_spark_session(config)
        self.config = config
        # Three-part identifiers (catalog.keyspace.table), resolved through
        # the Cassandra catalog registered below.
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def make_spark_session(self, config) -> SparkSession:
        # Register the DataStax connector as a Spark SQL catalog so the
        # Cassandra tables can be read with spark.table().
        return SparkSession.builder \
            .appName('DistributedUnionFindWithGraphs') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}",
                    "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()
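    # No Cassandra contact point is configured here, so
    # spark.cassandra.connection.host is assumed to be supplied externally.
    # A hypothetical spark-submit invocation (the package versions are
    # assumptions; match them to your Spark/Scala build):
    #
    #   spark-submit \
    #       --packages graphframes:graphframes:0.8.2-spark3.2-s_2.12,com.datastax.spark:spark-cassandra-connector_2.12:3.2.0 \
    #       --conf spark.cassandra.connection.host=127.0.0.1 \
    #       main.py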

    def empty_dataframe(self, schema) -> DataFrame:
        # Helper: an empty DataFrame with the given schema.
        return self.spark.createDataFrame(self.spark.sparkContext.emptyRDD(), schema)

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)

    def get_cluster_dataframe(self) -> DataFrame:
        return self.spark.table(self.CLUSTERS_TABLE)

# end class Master


master = Master(config)
# GraphFrames' connectedComponents() checkpoints intermediate results, so a
# checkpoint directory must be set before it is called.
master.spark.sparkContext.setCheckpointDir('./checkpoints')

# Vertex DataFrame: one row per distinct address. GraphFrames expects the
# vertex column to be named 'id'.
transactions_as_vertices = master.get_tx_dataframe() \
    .select('address') \
    .withColumnRenamed('address', 'id') \
    .distinct()
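# Optional peek at the vertex set, using the debug flag loaded from
# settings.json (assuming that flag is intended for this kind of
# diagnostic output):
if debug:
    transactions_as_vertices.show(10, truncate=False)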


def explode_row(row: Row) -> List[Tuple[str, str]]:
    # Turn one transaction's address set into (src, dst) edge tuples.
    addresses = row['addresses']
    if len(addresses) == 1:
        return []

    return [(addr, addresses[0]) for addr in addresses[1:]]
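# Worked example: for addresses = ['a1', 'a2', 'a3'], explode_row returns
# [('a2', 'a1'), ('a3', 'a1')], i.e. a star of edges anchoring every address
# in the transaction to the first one. Which address collect_set puts first
# is not deterministic, but any anchor yields the same connected components.
# Single-address transactions contribute no edges.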


# Group each transaction's addresses into one set per tx_id.
tx_groups = master.get_tx_dataframe() \
    .groupBy("tx_id") \
    .agg(F.collect_set('address').alias('addresses'))
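# Shape of tx_groups at this point (values are illustrative only):
#
#   tx_id | addresses
#   ------+--------------------
#   t1    | ['a1', 'a2', 'a3']
#   t2    | ['a4']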


# Edge DataFrame: GraphFrames expects 'src' and 'dst' columns.
transactions_as_edges = tx_groups \
    .rdd \
    .flatMap(explode_row) \
    .toDF(['src', 'dst'])


# Build the graph and cluster addresses via connected components.
g = GraphFrame(transactions_as_vertices, transactions_as_edges)
res = g.connectedComponents() \
    .groupBy('component') \
    .agg(F.collect_list('id').alias('ids')) \
    .collect()
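# connectedComponents() defaults to GraphFrames' iterative algorithm, which
# is what requires the checkpoint directory set above. The GraphX-based
# implementation can be selected instead:
#
#   res = g.connectedComponents(algorithm='graphx')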

# Print the members of each cluster. Note that collect() above pulls every
# component to the driver, which is fine for debugging but not for large graphs.
for row in res:
    print(sorted(row['ids']))
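# CLUSTERS_TABLE is resolved in Master but never written to. A minimal
# sketch of persisting the clusters instead of printing them, assuming the
# table has (address, cluster_id) columns:
#
#   g.connectedComponents() \
#       .withColumnRenamed('id', 'address') \
#       .withColumnRenamed('component', 'cluster_id') \
#       .writeTo(master.CLUSTERS_TABLE).append()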

end = time.time()
print("ELAPSED TIME:", end - start)