Browse Source

UPDATE TABLE is not supported temporarily XD

master
nitowa 1 year ago
parent
commit
8c4b598043
4 changed files with 167 additions and 49 deletions
  1. 5
    1
      settings.json
  2. 158
    47
      src/spark/main.py
  3. 3
    0
      start_services.sh
  4. 1
    1
      submit.sh

+ 5
- 1
settings.json View File

@@ -2,12 +2,16 @@
2 2
     "cassandra_addresses": ["127.0.0.1"],
3 3
     "cassandra_port": 9042,
4 4
     "cassandra_keyspace": "distributedunionfind",
5
+    "cassandra_catalog": "DUFCatalog",
5 6
 
6 7
     "setup_db_dir": "config/db",
7 8
     "setup_tables_dir": "config/db/tables",
8 9
     "setup_keyspace_dir": "config/db/keyspace",
9 10
 
10 11
     "tx_table_name": "transactions",
11
-    "clusters_table_name": "clusters"
12
+    "clusters_table_name": "clusters",
12 13
 
14
+    "spark_master": "spark://osboxes:7077",
15
+
16
+    "debug": true
13 17
 }

+ 158
- 47
src/spark/main.py View File

@@ -1,44 +1,89 @@
1 1
 from gc import collect
2
-from sqlite3 import Row
3
-from typing import Iterable
4
-from operator import add
5
-
6
-from pyspark.sql import SparkSession
7
-from pyspark.sql import functions as F
8
-
2
+import json
9 3
 
4
+from sqlite3 import Row
5
+from typing import Iterable, List
10 6
 
11
-spark = SparkSession.builder \
12
-    .appName('SparkCassandraApp') \
13
-    .config('spark.cassandra.connection.host', 'localhost') \
14
-    .config('spark.cassandra.connection.port', '9042') \
15
-    .config('spark.cassandra.output.consistency.level', 'ONE') \
16
-    .config("spark.sql.extensions",  "com.datastax.spark.connector.CassandraSparkExtensions") \
17
-    .config('directJoinSetting', 'on') \
18
-    .master('spark://osboxes:7077') \
19
-    .getOrCreate()
20
-
21
-spark.conf.set("spark.sql.catalog.myCatalog",
22
-               "com.datastax.spark.connector.datasource.CassandraCatalog")
23
-
7
+from pyspark import RDD
24 8
 
25
-tx_addr_groups = spark.read.table("myCatalog.distributedunionfind.transactions") \
26
-    .groupBy("tx_id") \
27
-    .agg(F.collect_set('address').alias('addresses')) \
28
-    .toLocalIterator()
9
+from pyspark.sql import SparkSession, DataFrame, Row
10
+from pyspark.sql import functions as F
29 11
 
30
-def insertCluster (row):
31
-    addrs: Iterable[str] = row['addresses']
32
-    df = spark.createDataFrame(map(lambda addr: (addr, addrs[0]), addrs), schema=['address', 'parent'])
12
+config = json.load(open("./settings.json"))
13
+debug = config['debug']
14
+
15
+class Master:
16
+    spark: SparkSession
17
+    CLUSTERS_TABLE: str
18
+    TX_TABLE: str
19
+
20
+    def __init__(self, config):
21
+        self.spark = self.makeSparkContext(config)
22
+        self.config = config
23
+        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
24
+        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"
25
+
26
+    def makeSparkContext(self,config) -> SparkSession:
27
+        return SparkSession.builder \
28
+        .appName('SparkCassandraApp') \
29
+        .config('spark.cassandra.connection.host', ','.join(config['cassandra_addresses'])) \
30
+        .config('spark.cassandra.connection.port', config["cassandra_port"]) \
31
+        .config('spark.cassandra.output.consistency.level', 'ONE') \
32
+        .config("spark.sql.extensions",  "com.datastax.spark.connector.CassandraSparkExtensions") \
33
+        .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
34
+        .config('directJoinSetting', 'on') \
35
+        .master(config['spark_master']) \
36
+        .getOrCreate()
37
+
38
+    def group_tx_addrs(self) -> DataFrame:
39
+        return self.spark \
40
+            .read \
41
+            .table(self.TX_TABLE) \
42
+            .groupBy("tx_id") \
43
+            .agg(F.collect_set('address').alias('addresses'))
44
+
45
+    def group_cluster_addrs(self) -> DataFrame:
46
+        return self.spark \
47
+            .read \
48
+            .table(self.CLUSTERS_TABLE) \
49
+            .groupBy("parent") \
50
+            .agg(F.collect_set('address').alias('addresses'))
51
+
52
+    def insertNewCluster (self, addrs: Iterable[str], root: str | None = None) -> str:
53
+        if(root == None):
54
+            root = addrs[0]
55
+        df = self.spark.createDataFrame(map(lambda addr: (addr, root), addrs), schema=['address', 'parent'])
56
+        df.writeTo(self.CLUSTERS_TABLE).append()
57
+        return root
58
+
59
+    def enumerate(self, data: DataFrame) -> DataFrame:
60
+        return data \
61
+            .rdd \
62
+            .zipWithIndex() \
63
+            .toDF(["tx_group", "index"])
64
+
65
+    def rewrite_cluster_parent(self, cluster_roots: Iterable[str], new_cluster_root: str) -> None:
66
+        sqlstr = f"""
67
+            UPDATE {self.CLUSTERS_TABLE} 
68
+            SET parent='{new_cluster_root}' 
69
+            WHERE parent IN ({','.join(map(lambda r: f"'{r}'", cluster_roots))})"""
70
+        
71
+        if(debug):
72
+            print("UPDATE SQL")
73
+            print(sqlstr)
74
+            print()
75
+
76
+        self.spark.sql(sqlstr)
77
+
78
+# end class Master
33 79
 
34
-    df.writeTo("myCatalog.distributedunionfind.clusters").overwrite()
35 80
 
36 81
 """
37 82
     tuple structure:
38
-        Row => Row(parent=addr, addresses=list[addr]
39
-        Iterable[str] => list[addr]
83
+        Row => Row(parent=addr, addresses=list[addr] | the cluster
84
+        Iterable[str] => list[addr] | the transaction addresses
40 85
 """
41
-def find(data: tuple[Row, Iterable[str]]):
86
+def find(data: tuple[Row, Iterable[str]]) -> str | None:
42 87
     cluster = data[0]
43 88
     tx = data[1]
44 89
 
@@ -49,28 +94,94 @@ def find(data: tuple[Row, Iterable[str]]):
49 94
     else:
50 95
         return None
51 96
 
52
-for addr_group in tx_addr_groups:
53
-    clusters_df = spark.read.table("myCatalog.distributedunionfind.clusters")
97
+def handleTx(tx_addr_group: Row):
54 98
 
55
-    clusters = clusters_df \
56
-        .groupBy("parent") \
57
-        .agg(F.collect_set('address').alias('addresses'))
58 99
 
59
-    if (clusters.count() == 0):
60
-        insertCluster(addr_group)
100
+    found_clusters: "RDD[str]" = clusters.rdd \
101
+        .map(lambda cluster: (cluster, tx_addr_group['addresses'])) \
102
+        .map(find) \
103
+        .filter(lambda x: x != None)
104
+        
105
+
106
+    if(found_clusters.count() == 0):
107
+        insertNewCluster(tx_addr_group)
108
+        return
109
+
110
+    cluster_roots = found_clusters.collect()
111
+
112
+    cl = clusters \
113
+        .select('addresses') \
114
+        .where(
115
+            F.col('parent').isin(cluster_roots)
116
+        ) \
117
+        .agg(F.collect_set('addresses').alias('agg')) \
118
+        .select(F.flatten('agg').alias('addresses')) \
119
+        .select(F.explode('addresses')) \
120
+        .rdd \
121
+        .map(lambda addr: (addr, cluster_roots[0])) \
122
+        .toDF(['address', 'parent']) \
123
+        .show()
124
+        #.writeTo(CLUSTERS_TABLE) \
125
+        #.append()
126
+
127
+
128
+master = Master(config)
129
+
130
+tx_addr_groups = master.group_tx_addrs()
131
+tx_groups_indexed = master.enumerate(tx_addr_groups)
132
+
133
+for i in range(0, tx_addr_groups.count()):
134
+    cluster_addr_groups = master.group_cluster_addrs()
135
+
136
+    if(debug):
137
+        print("KNOWN CLUSTERS")
138
+        cluster_addr_groups.show(truncate=False)
139
+        print()
140
+
141
+    tx_addrs: Iterable[str] = tx_groups_indexed \
142
+        .where(tx_groups_indexed.index == i) \
143
+        .select('tx_group') \
144
+        .collect()[0]['tx_group']['addresses']
145
+
146
+    if(debug):
147
+        print("CURRENT TX")
148
+        print(tx_addrs)
149
+        print()
150
+
151
+    if (cluster_addr_groups.count() == 0):
152
+        master.insertNewCluster(tx_addrs)
61 153
         continue
62 154
 
63
-    df = clusters.rdd \
64
-        .map(lambda cluster: (cluster, addr_group['addresses'])) \
155
+    cluster_tx_mapping = cluster_addr_groups \
156
+        .rdd \
157
+        .map(lambda cluster: (cluster, tx_addrs)) 
158
+
159
+    if(debug):
160
+        print("cluster_tx_mapping")
161
+        cluster_tx_mapping \
162
+            .toDF(['cluster', 'tx']) \
163
+            .show(truncate=False)
164
+        print()
165
+
166
+
167
+    matched_roots: "List[str]" = cluster_tx_mapping \
65 168
         .map(find) \
66
-        .filter(lambda x: x != None) \
169
+        .filter(lambda root: root != None) \
67 170
         .collect()
68 171
 
69
-    if(len(df) == 0):
70
-        insertCluster(addr_group)
71
-        continue
172
+    if(debug):
173
+        print("FOUND ROOTS")
174
+        print(matched_roots)
175
+        print()
176
+
72 177
 
73
-    print(addr_group)
74
-    print(df)
178
+    if(len(matched_roots) == 0):
179
+        new_root = master.insertNewCluster(tx_addrs)
180
+    elif(len(matched_roots) == 1):
181
+        master.insertNewCluster(tx_addrs, matched_roots[0])
182
+    else:
183
+        master.rewrite_cluster_parent(matched_roots[1:], matched_roots[0])
184
+        master.insertNewCluster(tx_addrs, matched_roots[0])
75 185
 
76
-    break
186
+    if(debug):
187
+        print("==============")

+ 3
- 0
start_services.sh View File

@@ -1,6 +1,9 @@
1 1
 SPARK_HOME="/home/osboxes/Downloads/spark-3.2.2-bin-hadoop3.2"
2 2
 SPARK_MASTER="spark://osboxes:7077"
3 3
 
4
+echo "Starting spark master..."
4 5
 "$SPARK_HOME"/sbin/start-master.sh
6
+echo "Starting spark workers..."
5 7
 SPARK_WORKER_INSTANCES=5 "$SPARK_HOME"/sbin/start-worker.sh "$SPARK_MASTER"
8
+echo "Starting cassandra container..."
6 9
 docker run -d -p 9042:9042 cassandra

+ 1
- 1
submit.sh View File

@@ -1,5 +1,5 @@
1 1
 SPARK_HOME="/home/osboxes/Downloads/spark-3.2.2-bin-hadoop3.2"
2
-MEMORY="4g"
2
+MEMORY="1g"
3 3
 SPARK_MASTER="spark://osboxes:7077"
4 4
 CASSANDRA_HOST="localhost"
5 5
 

Loading…
Cancel
Save