Browse Source

add db write to graph impl

master
nitowa 1 year ago
parent
commit
d42f70d33c
5 changed files with 45 additions and 39 deletions
  1. 2
    1
      .gitignore
  2. 9
    5
      README.md
  3. 1
    1
      config/db/tables/clusters/CREATE.sql
  4. 9
    11
      src/spark/main.py
  5. 24
    21
      src/spark/main_graphs.py

+ 2
- 1
.gitignore View File

@@ -1,4 +1,5 @@
1 1
 __pycache__
2 2
 .vscode
3 3
 checkpoints
4
-spark-warehouse
4
+spark-warehouse
5
+scratchpad.py

+ 9
- 5
README.md View File

@@ -8,13 +8,17 @@ TODO
8 8
 
9 9
 - Python3
10 10
 - Apache spark 3.2 (https://spark.apache.org/downloads.html)
11
-- Cassandra DB (https://cassandra.apache.org/_/index.html, locally the docker build is recommended: https://hub.docker.com/_/cassandra)
11
+- Cassandra DB (https://cassandra.apache.org/\_/index.html, locally the docker build is recommended: https://hub.docker.com/\_/cassandra)
12 12
 
13
-For the graph implementation specifically you need to install `graphframes` manually since the official release is incompatible with `spark 3.x` (pull request pending). A prebuilt copy is supplied in the `spark-packages` directory. 
13
+For the graph implementation specifically you need to install `graphframes` manually from a third party since the official release is incompatible with `spark 3.x` ([pull request pending](https://github.com/graphframes/graphframes/pull/415)). A prebuilt copy is supplied in the `spark-packages` directory. 
14 14
 - graphframes (https://github.com/eejbyfeldt/graphframes/tree/spark-3.3)
15 15
 
16 16
 ## Setting up
17 17
 
18
-- Modify `settings.json` to reflect your setup. If you are running everything locally you can use `start_services.sh` to turn everything on in one swoop.
19
-- Load the development database by running `python3 setup.py` from the project root.
20
-- Start the spark workload by either running `submit.sh` (slow) or `submit_graph.sh` (faster)
18
+- Modify `settings.json` to reflect your setup. If you are running everything locally you can use `start_services.sh` to turn everything on in one swoop. It might take a few minutes for Cassandra to become available.
19
+- Load the development database by running `python3 setup.py` from the project root. Per default this will move `small_test_data.csv` into the transactions table.
20
+
21
+# Deploying:
22
+
23
+- Start the spark workload by either running `submit.sh` (slow) or `submit_graph.sh` (faster)
24
+- If you need to clean out the Database you can run `python3 clean.py`. Be wary that this wipes all data.

+ 1
- 1
config/db/tables/clusters/CREATE.sql View File

@@ -1,5 +1,5 @@
1 1
 CREATE TABLE clusters(
2 2
     address TEXT,
3
-    parent TEXT,
3
+    id TEXT,
4 4
     PRIMARY KEY (address)
5 5
 );

+ 9
- 11
src/spark/main.py View File

@@ -42,13 +42,13 @@ class Master:
42 42
         return self.spark \
43 43
             .read \
44 44
             .table(self.CLUSTERS_TABLE) \
45
-            .groupBy("parent") \
45
+            .groupBy("id") \
46 46
             .agg(F.collect_set('address').alias('addresses'))
47 47
 
48 48
     def insertNewCluster (self, addrs: Iterable[str], root: str | None = None) -> str:
49 49
         if(root == None):
50 50
             root = addrs[0]
51
-        df = self.spark.createDataFrame(map(lambda addr: (addr, root), addrs), schema=['address', 'parent'])
51
+        df = self.spark.createDataFrame(map(lambda addr: (addr, root), addrs), schema=['address', 'id'])
52 52
         df.writeTo(self.CLUSTERS_TABLE).append()
53 53
         return root
54 54
 
@@ -58,14 +58,14 @@ class Master:
58 58
             .zipWithIndex() \
59 59
             .toDF(["tx_group", "index"])
60 60
 
61
-    def rewrite_cluster_parent(self, cluster_roots: Iterable[str], new_cluster_root: str) -> None:
61
+    def rewrite_cluster_id(self, cluster_roots: Iterable[str], new_cluster_root: str) -> None:
62 62
         cluster_rewrite = self.spark \
63 63
             .table(self.CLUSTERS_TABLE) \
64
-            .where(F.col('parent').isin(cluster_roots)) \
64
+            .where(F.col('id').isin(cluster_roots)) \
65 65
             .select('address') \
66 66
             .rdd \
67 67
             .map(lambda addr: (addr['address'], new_cluster_root)) \
68
-            .toDF(['address', 'parent']) \
68
+            .toDF(['address', 'id']) \
69 69
         
70 70
         if(debug):
71 71
             print("REWRITE JOB")
@@ -73,24 +73,22 @@ class Master:
73 73
             print()
74 74
 
75 75
         cluster_rewrite.writeTo(self.CLUSTERS_TABLE).append()
76
-        
77
-        
78 76
 # end class Master
79 77
 
80 78
 
81 79
 """
82 80
     tuple structure:
83
-        Row => Row(parent=addr, addresses=list[addr] | the cluster
81
+        Row => Row(id=addr, addresses=list[addr] | the cluster
84 82
         Iterable[str] => list[addr] | the transaction addresses
85 83
 """
86 84
 def find(data: tuple[Row, Iterable[str]]) -> str | None:
87 85
     cluster = data[0]
88 86
     tx = data[1]
89 87
 
90
-    clusteraddresses = cluster['addresses'] + [cluster['parent']]
88
+    clusteraddresses = cluster['addresses'] + [cluster['id']]
91 89
 
92 90
     if any(x in tx for x in clusteraddresses):
93
-        return cluster['parent']
91
+        return cluster['id']
94 92
     else:
95 93
         return None
96 94
 
@@ -149,7 +147,7 @@ for i in range(0, tx_addr_groups.count()):
149 147
     elif(len(matched_roots) == 1):
150 148
         master.insertNewCluster(tx_addrs, matched_roots[0])
151 149
     else:
152
-        master.rewrite_cluster_parent(matched_roots[1:], matched_roots[0])
150
+        master.rewrite_cluster_id(matched_roots[1:], matched_roots[0])
153 151
         master.insertNewCluster(tx_addrs, matched_roots[0])
154 152
 
155 153
     if(debug):

+ 24
- 21
src/spark/main_graphs.py View File

@@ -29,52 +29,55 @@ class Master:
29 29
             .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
30 30
             .getOrCreate()
31 31
 
32
-    def empty_dataframe(self, schema) -> DataFrame:
33
-        return self.spark.createDataFrame(self.spark.sparkContext.emptyRDD(), schema)
34
-
35 32
     def get_tx_dataframe(self) -> DataFrame:
36 33
         return self.spark.table(self.TX_TABLE)
37 34
 
38 35
     def get_cluster_dataframe(self) -> DataFrame:
39 36
         return self.spark.table(self.CLUSTERS_TABLE)
40 37
 
41
-# end class Master
38
+    def write_connected_components_as_clusters(self, conn_comp: DataFrame) -> None:
39
+        conn_comp \
40
+            .withColumnRenamed('id', 'address') \
41
+            .withColumnRenamed('component', 'id') \
42
+            .writeTo(self.CLUSTERS_TABLE) \
43
+            .append()
42 44
 
45
+# end class Master
43 46
 
44 47
 master = Master(config)
45
-master.spark.sparkContext.setCheckpointDir(
46
-    './checkpoints')  # spark is really adamant it needs this
48
+master.spark.sparkContext.setCheckpointDir('./checkpoints')  # spark is really adamant it needs this even if the algorithm is set to the non-checkpointed version
47 49
 
48
-# Vertex DataFrame
49
-transaction_as_vertices = master.get_tx_dataframe() \
50
+tx_df = master.get_tx_dataframe()
51
+
52
+transaction_as_vertices =  tx_df \
50 53
     .select('address') \
51 54
     .withColumnRenamed('address', 'id') \
52 55
     .distinct()
53 56
 
54 57
 def explode_row(row: Row) -> List[Row]:
55 58
     addresses = row['addresses']
56
-    if(len(addresses) == 1):
57
-        return []
58
-
59 59
     return list(map(lambda addr: (addr, addresses[0]), addresses[1:]))
60 60
 
61
-
62
-tx_groups = master.get_tx_dataframe() \
61
+transactions_as_edges = tx_df \
63 62
     .groupBy("tx_id") \
64
-    .agg(F.collect_set('address').alias('addresses'))
65
-
66
-transactions_as_edges = tx_groups \
63
+    .agg(F.collect_set('address').alias('addresses')) \
67 64
     .rdd \
68 65
     .flatMap(explode_row) \
69 66
     .toDF(['src', 'dst'])
70 67
 
71
-
72
-# Create a GraphFrame
73 68
 g = GraphFrame(transaction_as_vertices, transactions_as_edges)
74
-res = g.connectedComponents().groupBy('component').agg(F.collect_list('id')).collect()
69
+components = g.connectedComponents(algorithm='graphframes')
70
+
71
+master.write_connected_components_as_clusters(components)
72
+
73
+if(debug):
74
+    clusters = components \
75
+        .groupBy('component') \
76
+        .agg(F.collect_list('id')) \
77
+        .collect()
75 78
 
76
-for row in res:
77
-    print(sorted(row['collect_list(id)']))
79
+    for cluster in clusters:
80
+        print(sorted(cluster['collect_list(id)'])) 
78 81
 
79 82
 end = time.time()
80 83
 print("ELAPSED TIME:", end-start)

Loading…
Cancel
Save