
main_graphs.py

from typing import List
from graphframes import GraphFrame
import json
from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F
import time

start = time.time()

# Load runtime settings (Cassandra catalog/keyspace, table names, etc.).
with open("./settings.json") as settings_file:
    config = json.load(settings_file)
debug = config['debug']
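# A minimal settings.json sketch listing the keys this script reads
# (the values below are illustrative placeholders, not taken from this repo):
# {
#     "debug": false,
#     "cassandra_catalog": "cassandra",
#     "cassandra_keyspace": "clustering",
#     "clusters_table_name": "clusters",
#     "tx_table_name": "transactions",
#     "spark_checkpoint_dir": "/tmp/spark-checkpoints"
# }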
class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        # Register the Cassandra catalog so the keyspace tables can be read
        # directly with spark.table().
        return SparkSession.builder \
            .appName('DistributedUnionFindWithGraphs') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)

    def get_cluster_dataframe(self) -> DataFrame:
        return self.spark.table(self.CLUSTERS_TABLE)

    def write_connected_components_as_clusters(self, conn_comp: DataFrame) -> None:
        # Each component id becomes a cluster id; each vertex id is an address.
        conn_comp \
            .withColumnRenamed('id', 'address') \
            .withColumnRenamed('component', 'id') \
            .writeTo(self.CLUSTERS_TABLE) \
            .append()
# end class Master
master = Master(config)
# GraphFrames' connectedComponents requires a checkpoint directory to be set.
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])

tx_df = master.get_tx_dataframe()

# Every distinct address that appears in a transaction becomes a vertex.
addresses_as_vertices = tx_df \
    .select('address') \
    .withColumnRenamed('address', 'id') \
    .distinct()
def explode_row(row: Row) -> List[tuple]:
    # For one transaction, connect every co-occurring address to the first
    # address in the set (a star topology). This is sufficient for computing
    # connected components without materialising a full clique of edges.
    addresses = row['addresses']
    return list(map(lambda addr: (addr, addresses[0]), addresses[1:]))

# One edge set per transaction: collect the addresses of each tx_id and
# expand them into (src, dst) pairs.
transactions_as_edges = tx_df \
    .groupBy("tx_id") \
    .agg(F.collect_set('address').alias('addresses')) \
    .rdd \
    .flatMap(explode_row) \
    .toDF(['src', 'dst'])
g = GraphFrame(addresses_as_vertices, transactions_as_edges)
components = g.connectedComponents(algorithm='graphframes')
#master.write_connected_components_as_clusters(components)

# Gather the addresses of each connected component as one cluster.
clusters = components \
    .groupBy('component') \
    .agg(F.collect_list('id')) \
    .collect()
#print(len(clusters))
#for cluster in clusters:
#    print(sorted(cluster['collect_list(id)']))

end = time.time()
print(end - start, end='')
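# Submitting this script needs GraphFrames and the Spark Cassandra Connector
# on the classpath. One possible invocation (the package versions here are an
# assumption; pick coordinates matching your Spark/Scala build):
#   spark-submit \
#       --packages graphframes:graphframes:0.8.2-spark3.2-s_2.12,com.datastax.spark:spark-cassandra-connector_2.12:3.2.0 \
#       main_graphs.py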