Browse Source

add db write to graph impl

master
nitowa 1 year ago
parent
commit
d42f70d33c
5 changed files with 45 additions and 39 deletions
  1. 2
    1
      .gitignore
  2. 9
    5
      README.md
  3. 1
    1
      config/db/tables/clusters/CREATE.sql
  4. 9
    11
      src/spark/main.py
  5. 24
    21
      src/spark/main_graphs.py

+ 2
- 1
.gitignore View File

1
 __pycache__
1
 __pycache__
2
 .vscode
2
 .vscode
3
 checkpoints
3
 checkpoints
4
-spark-warehouse
4
+spark-warehouse
5
+scratchpad.py

+ 9
- 5
README.md View File

8
 
8
 
9
 - Python3
9
 - Python3
10
 - Apache spark 3.2 (https://spark.apache.org/downloads.html)
10
 - Apache spark 3.2 (https://spark.apache.org/downloads.html)
11
-- Cassandra DB (https://cassandra.apache.org/_/index.html, locally the docker build is recommended: https://hub.docker.com/_/cassandra)
11
+- Cassandra DB (https://cassandra.apache.org/\_/index.html, locally the docker build is recommended: https://hub.docker.com/\_/cassandra)
12
 
12
 
13
-For the graph implementation specifically you need to install `graphframes` manually since the official release is incompatible with `spark 3.x` (pull request pending). A prebuilt copy is supplied in the `spark-packages` directory. 
13
+For the graph implementation specifically you need to install `graphframes` manually from a third party since the official release is incompatible with `spark 3.x` ([pull request pending](https://github.com/graphframes/graphframes/pull/415)). A prebuilt copy is supplied in the `spark-packages` directory. 
14
 - graphframes (https://github.com/eejbyfeldt/graphframes/tree/spark-3.3)
14
 - graphframes (https://github.com/eejbyfeldt/graphframes/tree/spark-3.3)
15
 
15
 
16
 ## Setting up
16
 ## Setting up
17
 
17
 
18
-- Modify `settings.json` to reflect your setup. If you are running everything locally you can use `start_services.sh` to turn everything on in one swoop.
19
-- Load the development database by running `python3 setup.py` from the project root.
20
-- Start the spark workload by either running `submit.sh` (slow) or `submit_graph.sh` (faster)
18
+- Modify `settings.json` to reflect your setup. If you are running everything locally you can use `start_services.sh` to turn everything on in one swoop. It might take a few minutes for Cassandra to become available.
19
+- Load the development database by running `python3 setup.py` from the project root. Per default this will move `small_test_data.csv` into the transactions table.
20
+
21
+# Deploying:
22
+
23
+- Start the spark workload by either running `submit.sh` (slow) or `submit_graph.sh` (faster)
24
+- If you need to clean out the Database you can run `python3 clean.py`. Be wary that this wipes all data.

+ 1
- 1
config/db/tables/clusters/CREATE.sql View File

1
 CREATE TABLE clusters(
1
 CREATE TABLE clusters(
2
     address TEXT,
2
     address TEXT,
3
-    parent TEXT,
3
+    id TEXT,
4
     PRIMARY KEY (address)
4
     PRIMARY KEY (address)
5
 );
5
 );

+ 9
- 11
src/spark/main.py View File

42
         return self.spark \
42
         return self.spark \
43
             .read \
43
             .read \
44
             .table(self.CLUSTERS_TABLE) \
44
             .table(self.CLUSTERS_TABLE) \
45
-            .groupBy("parent") \
45
+            .groupBy("id") \
46
             .agg(F.collect_set('address').alias('addresses'))
46
             .agg(F.collect_set('address').alias('addresses'))
47
 
47
 
48
     def insertNewCluster (self, addrs: Iterable[str], root: str | None = None) -> str:
48
     def insertNewCluster (self, addrs: Iterable[str], root: str | None = None) -> str:
49
         if(root == None):
49
         if(root == None):
50
             root = addrs[0]
50
             root = addrs[0]
51
-        df = self.spark.createDataFrame(map(lambda addr: (addr, root), addrs), schema=['address', 'parent'])
51
+        df = self.spark.createDataFrame(map(lambda addr: (addr, root), addrs), schema=['address', 'id'])
52
         df.writeTo(self.CLUSTERS_TABLE).append()
52
         df.writeTo(self.CLUSTERS_TABLE).append()
53
         return root
53
         return root
54
 
54
 
58
             .zipWithIndex() \
58
             .zipWithIndex() \
59
             .toDF(["tx_group", "index"])
59
             .toDF(["tx_group", "index"])
60
 
60
 
61
-    def rewrite_cluster_parent(self, cluster_roots: Iterable[str], new_cluster_root: str) -> None:
61
+    def rewrite_cluster_id(self, cluster_roots: Iterable[str], new_cluster_root: str) -> None:
62
         cluster_rewrite = self.spark \
62
         cluster_rewrite = self.spark \
63
             .table(self.CLUSTERS_TABLE) \
63
             .table(self.CLUSTERS_TABLE) \
64
-            .where(F.col('parent').isin(cluster_roots)) \
64
+            .where(F.col('id').isin(cluster_roots)) \
65
             .select('address') \
65
             .select('address') \
66
             .rdd \
66
             .rdd \
67
             .map(lambda addr: (addr['address'], new_cluster_root)) \
67
             .map(lambda addr: (addr['address'], new_cluster_root)) \
68
-            .toDF(['address', 'parent']) \
68
+            .toDF(['address', 'id']) \
69
         
69
         
70
         if(debug):
70
         if(debug):
71
             print("REWRITE JOB")
71
             print("REWRITE JOB")
73
             print()
73
             print()
74
 
74
 
75
         cluster_rewrite.writeTo(self.CLUSTERS_TABLE).append()
75
         cluster_rewrite.writeTo(self.CLUSTERS_TABLE).append()
76
-        
77
-        
78
 # end class Master
76
 # end class Master
79
 
77
 
80
 
78
 
81
 """
79
 """
82
     tuple structure:
80
     tuple structure:
83
-        Row => Row(parent=addr, addresses=list[addr] | the cluster
81
+        Row => Row(id=addr, addresses=list[addr] | the cluster
84
         Iterable[str] => list[addr] | the transaction addresses
82
         Iterable[str] => list[addr] | the transaction addresses
85
 """
83
 """
86
 def find(data: tuple[Row, Iterable[str]]) -> str | None:
84
 def find(data: tuple[Row, Iterable[str]]) -> str | None:
87
     cluster = data[0]
85
     cluster = data[0]
88
     tx = data[1]
86
     tx = data[1]
89
 
87
 
90
-    clusteraddresses = cluster['addresses'] + [cluster['parent']]
88
+    clusteraddresses = cluster['addresses'] + [cluster['id']]
91
 
89
 
92
     if any(x in tx for x in clusteraddresses):
90
     if any(x in tx for x in clusteraddresses):
93
-        return cluster['parent']
91
+        return cluster['id']
94
     else:
92
     else:
95
         return None
93
         return None
96
 
94
 
149
     elif(len(matched_roots) == 1):
147
     elif(len(matched_roots) == 1):
150
         master.insertNewCluster(tx_addrs, matched_roots[0])
148
         master.insertNewCluster(tx_addrs, matched_roots[0])
151
     else:
149
     else:
152
-        master.rewrite_cluster_parent(matched_roots[1:], matched_roots[0])
150
+        master.rewrite_cluster_id(matched_roots[1:], matched_roots[0])
153
         master.insertNewCluster(tx_addrs, matched_roots[0])
151
         master.insertNewCluster(tx_addrs, matched_roots[0])
154
 
152
 
155
     if(debug):
153
     if(debug):

+ 24
- 21
src/spark/main_graphs.py View File

29
             .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
29
             .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
30
             .getOrCreate()
30
             .getOrCreate()
31
 
31
 
32
-    def empty_dataframe(self, schema) -> DataFrame:
33
-        return self.spark.createDataFrame(self.spark.sparkContext.emptyRDD(), schema)
34
-
35
     def get_tx_dataframe(self) -> DataFrame:
32
     def get_tx_dataframe(self) -> DataFrame:
36
         return self.spark.table(self.TX_TABLE)
33
         return self.spark.table(self.TX_TABLE)
37
 
34
 
38
     def get_cluster_dataframe(self) -> DataFrame:
35
     def get_cluster_dataframe(self) -> DataFrame:
39
         return self.spark.table(self.CLUSTERS_TABLE)
36
         return self.spark.table(self.CLUSTERS_TABLE)
40
 
37
 
41
-# end class Master
38
+    def write_connected_components_as_clusters(self, conn_comp: DataFrame) -> None:
39
+        conn_comp \
40
+            .withColumnRenamed('id', 'address') \
41
+            .withColumnRenamed('component', 'id') \
42
+            .writeTo(self.CLUSTERS_TABLE) \
43
+            .append()
42
 
44
 
45
+# end class Master
43
 
46
 
44
 master = Master(config)
47
 master = Master(config)
45
-master.spark.sparkContext.setCheckpointDir(
46
-    './checkpoints')  # spark is really adamant it needs this
48
+master.spark.sparkContext.setCheckpointDir('./checkpoints')  # spark is really adamant it needs this even if the algorithm is set to the non-checkpointed version
47
 
49
 
48
-# Vertex DataFrame
49
-transaction_as_vertices = master.get_tx_dataframe() \
50
+tx_df = master.get_tx_dataframe()
51
+
52
+transaction_as_vertices =  tx_df \
50
     .select('address') \
53
     .select('address') \
51
     .withColumnRenamed('address', 'id') \
54
     .withColumnRenamed('address', 'id') \
52
     .distinct()
55
     .distinct()
53
 
56
 
54
 def explode_row(row: Row) -> List[Row]:
57
 def explode_row(row: Row) -> List[Row]:
55
     addresses = row['addresses']
58
     addresses = row['addresses']
56
-    if(len(addresses) == 1):
57
-        return []
58
-
59
     return list(map(lambda addr: (addr, addresses[0]), addresses[1:]))
59
     return list(map(lambda addr: (addr, addresses[0]), addresses[1:]))
60
 
60
 
61
-
62
-tx_groups = master.get_tx_dataframe() \
61
+transactions_as_edges = tx_df \
63
     .groupBy("tx_id") \
62
     .groupBy("tx_id") \
64
-    .agg(F.collect_set('address').alias('addresses'))
65
-
66
-transactions_as_edges = tx_groups \
63
+    .agg(F.collect_set('address').alias('addresses')) \
67
     .rdd \
64
     .rdd \
68
     .flatMap(explode_row) \
65
     .flatMap(explode_row) \
69
     .toDF(['src', 'dst'])
66
     .toDF(['src', 'dst'])
70
 
67
 
71
-
72
-# Create a GraphFrame
73
 g = GraphFrame(transaction_as_vertices, transactions_as_edges)
68
 g = GraphFrame(transaction_as_vertices, transactions_as_edges)
74
-res = g.connectedComponents().groupBy('component').agg(F.collect_list('id')).collect()
69
+components = g.connectedComponents(algorithm='graphframes')
70
+
71
+master.write_connected_components_as_clusters(components)
72
+
73
+if(debug):
74
+    clusters = components \
75
+        .groupBy('component') \
76
+        .agg(F.collect_list('id')) \
77
+        .collect()
75
 
78
 
76
-for row in res:
77
-    print(sorted(row['collect_list(id)']))
79
+    for cluster in clusters:
80
+        print(sorted(cluster['collect_list(id)'])) 
78
 
81
 
79
 end = time.time()
82
 end = time.time()
80
 print("ELAPSED TIME:", end-start)
83
 print("ELAPSED TIME:", end-start)

Loading…
Cancel
Save