
non graph solution but extremely memory intensive

master
nitowa, 2 years ago
commit 183723e46f
2 changed files with 308 additions and 118 deletions

  src/spark/main.py      +151  −118
  src/spark/main_bak.py  +157  −0
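
The commit message above sums up the change: instead of building an address graph, the new main.py clusters addresses with DataFrame operations alone, repeatedly folding each transaction's address set into whatever clusters it overlaps. As a rough orientation before the diff, here is a minimal, non-Spark sketch of that merge-by-overlap idea using plain Python sets; the function and variable names are hypothetical and not part of the repository:

# Illustrative only: merge-by-overlap clustering on plain Python sets.
# All names here (cluster_by_overlap, txs, ...) are hypothetical, not from the repo.
from typing import Dict, Iterable, List, Set

def cluster_by_overlap(txs: Iterable[List[str]]) -> Dict[str, Set[str]]:
    clusters: Dict[str, Set[str]] = {}  # representative address -> member addresses
    for addrs in txs:
        tx_set = set(addrs)
        # every existing cluster sharing at least one address with this tx gets merged
        hits = [rep for rep, members in clusters.items() if members & tx_set]
        merged = tx_set.union(*(clusters.pop(rep) for rep in hits))
        rep = hits[0] if hits else next(iter(merged))  # keep an arbitrary representative
        clusters[rep] = merged
    return clusters

if __name__ == "__main__":
    # 'a', 'b', 'c' end up in one cluster because the third tx overlaps both earlier ones; 'd' stays alone
    print(cluster_by_overlap([["a", "b"], ["c"], ["b", "c"], ["d"]]))

The Spark code in the diff below performs the same merge with arrays_overlap joins and leftanti joins instead of in-memory sets.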

src/spark/main.py  (+151 −118)

@@ -1,5 +1,6 @@
 from gc import collect
 import json
+from select import select
 
 from sqlite3 import Row
 from typing import Iterable, List
@@ -14,6 +15,7 @@ start = time.time()
 config = json.load(open("./settings.json"))
 debug = config['debug']
 
+
 class Master:
     spark: SparkSession
     CLUSTERS_TABLE: str
@@ -25,133 +27,164 @@ class Master:
         self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
         self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"
 
-    def makeSparkContext(self,config) -> SparkSession:
+    def makeSparkContext(self, config) -> SparkSession:
         return SparkSession.builder \
-        .appName('SparkCassandraApp') \
-        .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
-        .getOrCreate()
-
-    def group_tx_addrs(self) -> DataFrame:
-        return self.spark \
-            .read \
-            .table(self.TX_TABLE) \
-            .groupBy("tx_id") \
-            .agg(F.collect_set('address').alias('addresses'))
-
-    def group_cluster_addrs(self) -> DataFrame:
-        return self.spark \
-            .read \
-            .table(self.CLUSTERS_TABLE) \
-            .groupBy("id") \
-            .agg(F.collect_set('address').alias('addresses'))
-
-    def insertNewCluster (self, addrs: Iterable[str], root: str | None = None) -> str:
-        if(root == None):
-            root = addrs[0]
-        df = self.spark.createDataFrame(map(lambda addr: (addr, root), addrs), schema=['address', 'id'])
-        df.writeTo(self.CLUSTERS_TABLE).append()
-        return root
-
-    def enumerate(self, data: DataFrame) -> DataFrame:
-        return data \
-            .rdd \
-            .zipWithIndex() \
-            .toDF(["tx_group", "index"])
-
-    def rewrite_cluster_id(self, cluster_roots: Iterable[str], new_cluster_root: str) -> None:
-        cluster_rewrite = self.spark \
-            .table(self.CLUSTERS_TABLE) \
-            .where(F.col('id').isin(cluster_roots)) \
-            .select('address') \
+            .appName('SparkCassandraApp') \
+            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
+            .getOrCreate()
+
+    def get_tx_dataframe(self) -> DataFrame:
+        return self.spark.table(self.TX_TABLE)
+
+    def union_single_col(self, df1: DataFrame, df2: DataFrame, column: str) -> DataFrame:
+        return df1 \
+            .select(column) \
+            .union(df2.select(column))
+
+    def reduce_concat_array_column(self, df: DataFrame, column: str, distinct:bool = False) -> DataFrame:
+        df = self.explode_array_col(df.select(column), column)
+        return self.collect_col_to_array(df, column, distinct)
+
+    def collect_col_to_array(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
+        if(distinct):
+            return df.select(F.collect_set(column).alias(column))
+        else:
+            return df.select(F.collect_list(column).alias(column))
+
+    def explode_array_col(self, df: DataFrame, column: str) -> DataFrame:
+        return df \
             .rdd \
-            .map(lambda addr: (addr['address'], new_cluster_root)) \
-            .toDF(['address', 'id']) \
-
-        if(debug):
-            print("REWRITE JOB")
-            cluster_rewrite.show(truncate=False, vertical=True)
-            print()
-
-        cluster_rewrite.writeTo(self.CLUSTERS_TABLE).append()
-# end class Master
-
-
-"""
-    tuple structure:
-        Row => Row(id=addr, addresses=list[addr] | the cluster
-        Iterable[str] => list[addr] | the transaction addresses
-"""
-def find(data: tuple[Row, Iterable[str]]) -> str | None:
-    cluster = data[0]
-    tx = data[1]
+            .flatMap(lambda row: list(map(lambda elem: (elem,), row[column]))) \
+            .toDF([column])
+
+    def array_col_to_elements(self, df: DataFrame, column: str, distinct:bool = False) -> DataFrame:
+        exploded = master.explode_array_col(
+            df,
+            column
+        )
+
+        #this is likely redundant
+        collected = master.collect_col_to_array(
+            exploded,
+            column,
+            distinct
+        )
+
+        return self.explode_array_col(
+            collected,
+            column
+        )
 
-    clusteraddresses = cluster['addresses'] + [cluster['id']]
+# end class Master
 
-    if any(x in tx for x in clusteraddresses):
-        return cluster['id']
-    else:
-        return None
 
 master = Master(config)
-
-tx_addr_groups = master.group_tx_addrs()
-tx_groups_indexed = master.enumerate(tx_addr_groups).cache()
-
-for i in range(0, tx_addr_groups.count()):
-    cluster_addr_groups = master.group_cluster_addrs()
-
-    if(debug):
-        print("KNOWN CLUSTERS")
-        cluster_addr_groups.show(truncate=True)
-        print()
-
-    tx_addrs: Iterable[str] = tx_groups_indexed \
-        .where(tx_groups_indexed.index == i) \
-        .select('tx_group') \
-        .collect()[0]['tx_group']['addresses']
-
-    if(debug):
-        print("CURRENT TX")
-        print(tx_addrs)
-        print()
-
-    if (cluster_addr_groups.count() == 0):
-        master.insertNewCluster(tx_addrs)
-        continue
-
-    cluster_tx_mapping = cluster_addr_groups \
+tx_df = master.get_tx_dataframe()
+
+
+#Turn transactions into a list of ('id', [addr, addr, ...])
+tx_grouped = tx_df \
+    .groupBy('tx_id') \
+    .agg(F.collect_set('address').alias('addresses')) \
+    .rdd \
+    .zipWithIndex() \
+    .toDF(['tx', 'index']) \
+    .select(
+        F.col('tx.tx_id').alias('tx_id'),
+        F.col('tx.addresses').alias('addresses'),
+        'index'
+    ) \
+    .cache()
+
+
+# TODO: Load clusters from DB, check if any exist, if no make initial cluster, else proceed with loaded data
+
+# find initial cluster
+
+# take the first tx
+tx_zero = tx_grouped \
+    .select(tx_grouped.tx_id, tx_grouped.addresses) \
+    .where(tx_grouped.index == 0)
+
+# find txs with overlapping addresses
+overlapping_txs = tx_grouped \
+    .where((tx_grouped.index != 0)) \
+    .join(tx_zero.withColumnRenamed('addresses', 'tx_addresses')) \
+    .select(
+        tx_grouped.index,
+        tx_grouped.addresses,
+        F.arrays_overlap(tx_grouped.addresses, 'tx_addresses').alias('overlap')
+    ) \
+    .where(F.col('overlap') == True) \
+
+# overlapped txs must not be considered anymore, so remove them candidate dataframe
+tx_grouped = tx_grouped \
+    .join(overlapping_txs, 'index', 'leftanti') \
+    .filter(tx_grouped.index != 0)
+
+# get the distinct addresses of all overlaps in a single array
+distinct_addresses = master.reduce_concat_array_column(
+    master.union_single_col(
+        overlapping_txs, tx_zero, column='addresses'
+    ),
+    column='addresses',
+    distinct=True,
+)
+
+#pick out a random representative for this cluster and add it to every address
+cluster = distinct_addresses \
+    .rdd \
+    .flatMap(lambda row: list(map(lambda addr: (addr, row['addresses'][0]), row['addresses']))) \
+    .toDF(['address', 'id'])
+
+# done finding initial cluster
+
+#group cluster by representative and transform the result into a list of shape ('id', ['addr', 'addr', ...])
+clusters_grouped = cluster \
+    .groupBy('id') \
+    .agg(F.collect_list('address').alias('addresses'))
+
+def take_tx_and_cluster(txs: DataFrame, clusters: DataFrame):
+    if (txs.count() == 0):  # done!
+        return clusters
+
+    # take a random tx
+    tx = txs \
+        .select('*').limit(1)
+
+    # find clusters with overlapping addresses from tx
+    overlapping_clusters = clusters \
+        .join(tx.withColumnRenamed('addresses', 'tx_addresses')) \
+        .select(
+            clusters.id,
+            clusters.addresses,
+            F.arrays_overlap(clusters.addresses,'tx_addresses').alias('overlap')
+        ) \
+        .where(F.col('overlap') == True)
+
+    #collect all addresses into single array field
+    new_cluster_arr = master.reduce_concat_array_column(
+        master.union_single_col(tx, overlapping_clusters, 'addresses'),
+        column='addresses',
+        distinct=True
+    )
+
+    #declare cluster representative
+    new_cluster = new_cluster_arr \
         .rdd \
-        .map(lambda cluster: (cluster, tx_addrs))
-
-    if(debug):
-        print("cluster_tx_mapping")
-        cluster_tx_mapping \
-            .toDF(['cluster', 'tx']) \
-            .show(truncate=True)
-        print()
-
-
-    matched_roots: "List[str]" = cluster_tx_mapping \
-        .map(find) \
-        .filter(lambda root: root != None) \
-        .collect()
-
-    if(debug):
-        print("FOUND ROOTS")
-        print(matched_roots)
-        print()
+        .flatMap(lambda row: list(map(lambda addr: (addr, row['addresses'][0]), row['addresses']))) \
+        .toDF(['address', 'id']) \
+        .groupBy('id') \
+        .agg(F.collect_list('address').alias('addresses'))
 
+    #start new round with txs minus the one just used, and updated clusters
+    return take_tx_and_cluster(
+        txs.join(tx, 'index', 'leftanti'),
+        clusters.join(overlapping_clusters, 'id', 'leftanti').union(new_cluster)
+    )
 
-    if(len(matched_roots) == 0):
-        master.insertNewCluster(tx_addrs)
-    elif(len(matched_roots) == 1):
-        master.insertNewCluster(tx_addrs, matched_roots[0])
-    else:
-        master.rewrite_cluster_id(matched_roots[1:], matched_roots[0])
-        master.insertNewCluster(tx_addrs, matched_roots[0])
 
-    if(debug):
-        print("======================================================================")
+take_tx_and_cluster(tx_grouped, clusters_grouped).show()
 
 end = time.time()
-print("ELAPSED TIME:", end-start)
+print("ELAPSED TIME:", end-start)

src/spark/main_bak.py  (new file, +157 −0)

@@ -0,0 +1,157 @@
+from gc import collect
+import json
+
+from sqlite3 import Row
+from typing import Iterable, List
+
+from pyspark.sql import SparkSession, DataFrame, Row
+from pyspark.sql import functions as F
+
+import time
+start = time.time()
+
+
+config = json.load(open("./settings.json"))
+debug = config['debug']
+
+class Master:
+    spark: SparkSession
+    CLUSTERS_TABLE: str
+    TX_TABLE: str
+
+    def __init__(self, config):
+        self.spark = self.makeSparkContext(config)
+        self.config = config
+        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
+        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"
+
+    def makeSparkContext(self,config) -> SparkSession:
+        return SparkSession.builder \
+        .appName('SparkCassandraApp') \
+        .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
+        .getOrCreate()
+
+    def group_tx_addrs(self) -> DataFrame:
+        return self.spark \
+            .read \
+            .table(self.TX_TABLE) \
+            .groupBy("tx_id") \
+            .agg(F.collect_set('address').alias('addresses'))
+
+    def group_cluster_addrs(self) -> DataFrame:
+        return self.spark \
+            .read \
+            .table(self.CLUSTERS_TABLE) \
+            .groupBy("id") \
+            .agg(F.collect_set('address').alias('addresses'))
+
+    def insertNewCluster (self, addrs: Iterable[str], root: str | None = None) -> str:
+        if(root == None):
+            root = addrs[0]
+        df = self.spark.createDataFrame(map(lambda addr: (addr, root), addrs), schema=['address', 'id'])
+        df.writeTo(self.CLUSTERS_TABLE).append()
+        return root
+
+    def enumerate(self, data: DataFrame) -> DataFrame:
+        return data \
+            .rdd \
+            .zipWithIndex() \
+            .toDF(["tx_group", "index"])
+
+    def rewrite_cluster_id(self, cluster_roots: Iterable[str], new_cluster_root: str) -> None:
+        cluster_rewrite = self.spark \
+            .table(self.CLUSTERS_TABLE) \
+            .where(F.col('id').isin(cluster_roots)) \
+            .select('address') \
+            .rdd \
+            .map(lambda addr: (addr['address'], new_cluster_root)) \
+            .toDF(['address', 'id']) \
+
+        if(debug):
+            print("REWRITE JOB")
+            cluster_rewrite.show(truncate=False, vertical=True)
+            print()
+
+        cluster_rewrite.writeTo(self.CLUSTERS_TABLE).append()
+# end class Master
+
+
+"""
+    tuple structure:
+        Row => Row(id=addr, addresses=list[addr] | the cluster
+        Iterable[str] => list[addr] | the transaction addresses
+"""
+def find(data: tuple[Row, Iterable[str]]) -> str | None:
+    cluster = data[0]
+    tx = data[1]
+
+    clusteraddresses = cluster['addresses'] + [cluster['id']]
+
+    if any(x in tx for x in clusteraddresses):
+        return cluster['id']
+    else:
+        return None
+
+master = Master(config)
+
+tx_addr_groups = master.group_tx_addrs()
+tx_groups_indexed = master.enumerate(tx_addr_groups).cache()
+
+for i in range(0, tx_addr_groups.count()):
+    cluster_addr_groups = master.group_cluster_addrs()
+
+    if(debug):
+        print("KNOWN CLUSTERS")
+        cluster_addr_groups.show(truncate=True)
+        print()
+
+    tx_addrs: Iterable[str] = tx_groups_indexed \
+        .where(tx_groups_indexed.index == i) \
+        .select('tx_group') \
+        .collect()[0]['tx_group']['addresses']
+
+    if(debug):
+        print("CURRENT TX")
+        print(tx_addrs)
+        print()
+
+    if (cluster_addr_groups.count() == 0):
+        master.insertNewCluster(tx_addrs)
+        continue
+
+    cluster_tx_mapping = cluster_addr_groups \
+        .rdd \
+        .map(lambda cluster: (cluster, tx_addrs))
+
+    if(debug):
+        print("cluster_tx_mapping")
+        cluster_tx_mapping \
+            .toDF(['cluster', 'tx']) \
+            .show(truncate=True)
+        print()
+
+
+    matched_roots: "List[str]" = cluster_tx_mapping \
+        .map(find) \
+        .filter(lambda root: root != None) \
+        .collect()
+
+    if(debug):
+        print("FOUND ROOTS")
+        print(matched_roots)
+        print()
+
+
+    if(len(matched_roots) == 0):
+        master.insertNewCluster(tx_addrs)
+    elif(len(matched_roots) == 1):
+        master.insertNewCluster(tx_addrs, matched_roots[0])
+    else:
+        master.rewrite_cluster_id(matched_roots[1:], matched_roots[0])
+        master.insertNewCluster(tx_addrs, matched_roots[0])
+
+    if(debug):
+        print("======================================================================")
+
+end = time.time()
+print("ELAPSED TIME:", end-start)
