You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main_graphs.py 2.3KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. from typing import List
  2. from graphframes import GraphFrame
  3. import json
  4. from pyspark.sql import SparkSession, DataFrame, Row
  5. from pyspark.sql import functions as F
  6. import time
  7. start = time.time()
  8. config = json.load(open("./settings.json"))
  9. debug = config['debug']
  10. class Master:
  11. spark: SparkSession
  12. CLUSTERS_TABLE: str
  13. TX_TABLE: str
  14. def __init__(self, config):
  15. self.spark = self.makeSparkContext(config)
  16. self.config = config
  17. self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
  18. self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"
  19. def makeSparkContext(self, config) -> SparkSession:
  20. return SparkSession.builder \
  21. .appName('DistributedUnionFindWithGraphs') \
  22. .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
  23. .getOrCreate()
  24. def empty_dataframe(self, schema) -> DataFrame:
  25. return self.spark.createDataFrame(self.spark.sparkContext.emptyRDD(), schema)
  26. def get_tx_dataframe(self) -> DataFrame:
  27. return self.spark.table(self.TX_TABLE)
  28. def get_cluster_dataframe(self) -> DataFrame:
  29. return self.spark.table(self.CLUSTERS_TABLE)
  30. # end class Master
  31. master = Master(config)
  32. master.spark.sparkContext.setCheckpointDir(
  33. './checkpoints') # spark is really adamant it needs this
  34. # Vertex DataFrame
  35. transaction_as_vertices = master.get_tx_dataframe() \
  36. .select('address') \
  37. .withColumnRenamed('address', 'id') \
  38. .distinct()
  39. def explode_row(row: Row) -> List[Row]:
  40. addresses = row['addresses']
  41. if(len(addresses) == 1):
  42. return []
  43. return list(map(lambda addr: (addr, addresses[0]), addresses[1:]))
  44. tx_groups = master.get_tx_dataframe() \
  45. .groupBy("tx_id") \
  46. .agg(F.collect_set('address').alias('addresses'))
  47. transactions_as_edges = tx_groups \
  48. .rdd \
  49. .flatMap(explode_row) \
  50. .toDF(['src', 'dst'])
  51. # Create a GraphFrame
  52. g = GraphFrame(transaction_as_vertices, transactions_as_edges)
  53. res = g.connectedComponents().groupBy('component').agg(F.collect_list('id')).collect()
  54. for row in res:
  55. print(sorted(row['collect_list(id)']))
  56. end = time.time()
  57. print("ELAPSED TIME:", end-start)