
main_partition.py

import json
import time
from typing import Iterable, List, Set

from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F

start = time.time()

config = json.load(open("./settings.json"))
debug = config['debug']
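# settings.json is expected to provide at least these keys (inferred from the
# lookups in this script): 'debug', 'cassandra_catalog', 'cassandra_keyspace',
# 'clusters_table_name', 'tx_table_name' and 'spark_checkpoint_dir'.
# The Cassandra catalog configured in Master.makeSparkContext comes from the
# DataStax spark-cassandra-connector, which has to be available on the Spark
# classpath (e.g. via spark.jars.packages).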
class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)
# end class Master
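# --- Address clustering helpers --------------------------------------------
# Each transaction contributes one list of addresses. Two address lists belong
# to the same cluster if they share at least one address, and membership is
# transitive: cluster_step below folds every incoming address list into the
# current set of clusters, merging all clusters it overlaps with.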
def merge_lists_distinct(*lists: "Iterable[List[str]]") -> List[str]:
    accum = set()
    for lst in lists:
        accum = accum.union(set(lst))
    return list(accum)

def check_lists_overlap(list1, list2):
    return any(x in list1 for x in list2)
def cluster_step(clusters: "List[List[str]]", addresses: "List[List[str]]"):
    # if there are no more sets of addresses to consider, we are done
    if len(addresses) == 0:
        return clusters
    # take a set of addresses
    tx = addresses[0]
    # remove it from the list of candidates
    addresses = addresses[1:]
    # find clusters that match these addresses
    matching_clusters = filter(lambda cluster: check_lists_overlap(tx, cluster), clusters)
    # remove all clusters that match these addresses
    clusters = list(filter(lambda cluster: not check_lists_overlap(tx, cluster), clusters))
    # add a new cluster that is the union of the matching clusters and the inspected list of addresses
    clusters.append(merge_lists_distinct(tx, *matching_clusters))
    return cluster_step(clusters, addresses)
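# A tiny worked example of cluster_step (assuming string addresses):
#   cluster_step([], [['a', 'b'], ['b', 'c'], ['d']])
# yields two clusters, one containing {'a', 'b', 'c'} and one containing {'d'}
# (element order inside each cluster is arbitrary because sets are used).
# Note that the recursion depth equals the number of address lists processed,
# so very large partitions could run into Python's default recursion limit.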
def cluster_id_addresses_rows(iter: "Iterable[Row]") -> Iterable:
    address_lists = list(map(lambda row: row['addresses'], iter))
    yield cluster_step([], address_lists)

def dud(iter):
    address_lists = list(map(lambda row: row['addresses'], iter))
    yield address_lists
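# dud mirrors cluster_id_addresses_rows without the clustering step; it looks
# like a leftover debugging helper and is not used below.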
master = Master(config)
master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])

tx_df = master.get_tx_dataframe()
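# The transactions table is expected to expose at least a 'tx_id' and an
# 'address' column, given the grouping and aggregation below.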
# Turn transactions into a list of ('id', [addr, addr, ...])
tx_grouped = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses')) \
    .orderBy('tx_id')

print()
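# Cluster each partition independently, then merge the per-partition results.
# mapPartitions emits one list of clusters per partition, and RDD.fold applies
# cluster_step both inside and across partitions (starting from the empty
# cluster list), so the final result is a single merged list of clusters.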
res = tx_grouped \
    .repartition(5) \
    .rdd \
    .mapPartitions(cluster_id_addresses_rows) \
    .fold([], cluster_step)
for cluster in res:
    print(sorted(cluster))

end = time.time()
print("ELAPSED TIME:", end - start)