@@ -1,5 +1,6 @@
 from gc import collect
 import json
+from select import select
 
 from sqlite3 import Row
 from typing import Iterable, List
@@ -14,6 +15,7 @@ start = time.time()
 config = json.load(open("./settings.json"))
 debug = config['debug']
 
+
 class Master:
     spark: SparkSession
     CLUSTERS_TABLE: str
@@ -25,133 +27,164 @@ class Master:
         self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
         self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"
 
-    def makeSparkContext(self,config) -> SparkSession:
+    def makeSparkContext(self, config) -> SparkSession:
         return SparkSession.builder \
-            .appName('SparkCassandraApp') \
-            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
-            .getOrCreate()
-
-    def group_tx_addrs(self) -> DataFrame:
-        return self.spark \
-            .read \
-            .table(self.TX_TABLE) \
-            .groupBy("tx_id") \
-            .agg(F.collect_set('address').alias('addresses'))
-
-    def group_cluster_addrs(self) -> DataFrame:
-        return self.spark \
-            .read \
-            .table(self.CLUSTERS_TABLE) \
-            .groupBy("id") \
-            .agg(F.collect_set('address').alias('addresses'))
-
-    def insertNewCluster (self, addrs: Iterable[str], root: str | None = None) -> str:
-        if(root == None):
-            root = addrs[0]
-        df = self.spark.createDataFrame(map(lambda addr: (addr, root), addrs), schema=['address', 'id'])
-        df.writeTo(self.CLUSTERS_TABLE).append()
-        return root
-
-    def enumerate(self, data: DataFrame) -> DataFrame:
-        return data \
-            .rdd \
-            .zipWithIndex() \
-            .toDF(["tx_group", "index"])
-
-    def rewrite_cluster_id(self, cluster_roots: Iterable[str], new_cluster_root: str) -> None:
-        cluster_rewrite = self.spark \
-            .table(self.CLUSTERS_TABLE) \
-            .where(F.col('id').isin(cluster_roots)) \
-            .select('address') \
+            .appName('SparkCassandraApp') \
+            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
+            .getOrCreate()
+
+    def get_tx_dataframe(self) -> DataFrame:
+        return self.spark.table(self.TX_TABLE)
+
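+    # helper: take the same single column from two dataframes and stack them into one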
+    def union_single_col(self, df1: DataFrame, df2: DataFrame, column: str) -> DataFrame:
+        return df1 \
+            .select(column) \
+            .union(df2.select(column))
+
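+    # helper: flatten every array in the given column and re-collect the elements into one combined array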
+    def reduce_concat_array_column(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
+        df = self.explode_array_col(df.select(column), column)
+        return self.collect_col_to_array(df, column, distinct)
+
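+    # helper: collapse a column into a single row holding all its values as one array (a set if distinct)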
+    def collect_col_to_array(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
+        if(distinct):
+            return df.select(F.collect_set(column).alias(column))
+        else:
+            return df.select(F.collect_list(column).alias(column))
+
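+    # helper: expand an array column into one row per element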
+    def explode_array_col(self, df: DataFrame, column: str) -> DataFrame:
+        return df \
             .rdd \
-            .map(lambda addr: (addr['address'], new_cluster_root)) \
-            .toDF(['address', 'id']) \
-
-        if(debug):
-            print("REWRITE JOB")
-            cluster_rewrite.show(truncate=False, vertical=True)
-            print()
-
-        cluster_rewrite.writeTo(self.CLUSTERS_TABLE).append()
-# end class Master
-
-
-"""
-    tuple structure:
-    Row => Row(id=addr, addresses=list[addr] | the cluster
-    Iterable[str] => list[addr] | the transaction addresses
-"""
-def find(data: tuple[Row, Iterable[str]]) -> str | None:
-    cluster = data[0]
-    tx = data[1]
+            .flatMap(lambda row: list(map(lambda elem: (elem,), row[column]))) \
+            .toDF([column])
+
+    def array_col_to_elements(self, df: DataFrame, column: str, distinct: bool = False) -> DataFrame:
+        exploded = self.explode_array_col(
+            df,
+            column
+        )
+
+        #this is likely redundant
+        collected = self.collect_col_to_array(
+            exploded,
+            column,
+            distinct
+        )
+
+        return self.explode_array_col(
+            collected,
+            column
+        )
 
-    clusteraddresses = cluster['addresses'] + [cluster['id']]
+# end class Master
 
-    if any(x in tx for x in clusteraddresses):
-        return cluster['id']
-    else:
-        return None
 
 master = Master(config)
-
-tx_addr_groups = master.group_tx_addrs()
-tx_groups_indexed = master.enumerate(tx_addr_groups).cache()
-
-for i in range(0, tx_addr_groups.count()):
-    cluster_addr_groups = master.group_cluster_addrs()
-
-    if(debug):
-        print("KNOWN CLUSTERS")
-        cluster_addr_groups.show(truncate=True)
-        print()
-
-    tx_addrs: Iterable[str] = tx_groups_indexed \
-        .where(tx_groups_indexed.index == i) \
-        .select('tx_group') \
-        .collect()[0]['tx_group']['addresses']
-
-    if(debug):
-        print("CURRENT TX")
-        print(tx_addrs)
-        print()
-
-    if (cluster_addr_groups.count() == 0):
-        master.insertNewCluster(tx_addrs)
-        continue
-
-    cluster_tx_mapping = cluster_addr_groups \
+tx_df = master.get_tx_dataframe()
+
+
+#Turn transactions into a list of ('id', [addr, addr, ...])
+tx_grouped = tx_df \
+    .groupBy('tx_id') \
+    .agg(F.collect_set('address').alias('addresses')) \
+    .rdd \
+    .zipWithIndex() \
+    .toDF(['tx', 'index']) \
+    .select(
+        F.col('tx.tx_id').alias('tx_id'),
+        F.col('tx.addresses').alias('addresses'),
+        'index'
+    ) \
+    .cache()
+
+
+# TODO: Load clusters from DB, check if any exist, if no make initial cluster, else proceed with loaded data
+
+# find initial cluster
+
+# take the first tx
+tx_zero = tx_grouped \
+    .select(tx_grouped.tx_id, tx_grouped.addresses) \
+    .where(tx_grouped.index == 0)
+
+# find txs with overlapping addresses
+overlapping_txs = tx_grouped \
+    .where((tx_grouped.index != 0)) \
+    .join(tx_zero.withColumnRenamed('addresses', 'tx_addresses')) \
+    .select(
+        tx_grouped.index,
+        tx_grouped.addresses,
+        F.arrays_overlap(tx_grouped.addresses, 'tx_addresses').alias('overlap')
+    ) \
+    .where(F.col('overlap') == True) \
+
+# overlapped txs must not be considered anymore, so remove them from the candidate dataframe
+tx_grouped = tx_grouped \
+    .join(overlapping_txs, 'index', 'leftanti') \
+    .filter(tx_grouped.index != 0)
+
+# get the distinct addresses of all overlaps in a single array
+distinct_addresses = master.reduce_concat_array_column(
+    master.union_single_col(
+        overlapping_txs, tx_zero, column='addresses'
+    ),
+    column='addresses',
+    distinct=True,
+)
+
+#pick out a random representative for this cluster and add it to every address
+cluster = distinct_addresses \
+    .rdd \
+    .flatMap(lambda row: list(map(lambda addr: (addr, row['addresses'][0]), row['addresses']))) \
+    .toDF(['address', 'id'])
+
+# done finding initial cluster
+
+#group cluster by representative and transform the result into a list of shape ('id', ['addr', 'addr', ...])
+clusters_grouped = cluster \
+    .groupBy('id') \
+    .agg(F.collect_list('address').alias('addresses'))
+
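+# recursively fold the remaining txs into the clusters: each step merges every cluster
+# that shares an address with the chosen tx into one new cluster under a single representative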
+def take_tx_and_cluster(txs: DataFrame, clusters: DataFrame):
+    if (txs.count() == 0): # done!
+        return clusters
+
+    # take a random tx
+    tx = txs \
+        .select('*').limit(1)
+
+    # find clusters with overlapping addresses from tx
+    overlapping_clusters = clusters \
+        .join(tx.withColumnRenamed('addresses', 'tx_addresses')) \
+        .select(
+            clusters.id,
+            clusters.addresses,
+            F.arrays_overlap(clusters.addresses,'tx_addresses').alias('overlap')
+        ) \
+        .where(F.col('overlap') == True)
+
+    #collect all addresses into single array field
+    new_cluster_arr = master.reduce_concat_array_column(
+        master.union_single_col(tx, overlapping_clusters, 'addresses'),
+        column='addresses',
+        distinct=True
+    )
+
+    #declare cluster representative
+    new_cluster = new_cluster_arr \
         .rdd \
-        .map(lambda cluster: (cluster, tx_addrs))
-
-    if(debug):
-        print("cluster_tx_mapping")
-        cluster_tx_mapping \
-            .toDF(['cluster', 'tx']) \
-            .show(truncate=True)
-        print()
-
-
-    matched_roots: "List[str]" = cluster_tx_mapping \
-        .map(find) \
-        .filter(lambda root: root != None) \
-        .collect()
-
-    if(debug):
-        print("FOUND ROOTS")
-        print(matched_roots)
-        print()
+        .flatMap(lambda row: list(map(lambda addr: (addr, row['addresses'][0]), row['addresses']))) \
+        .toDF(['address', 'id']) \
+        .groupBy('id') \
+        .agg(F.collect_list('address').alias('addresses'))
 
+    #start new round with txs minus the one just used, and updated clusters
+    return take_tx_and_cluster(
+        txs.join(tx, 'index', 'leftanti'),
+        clusters.join(overlapping_clusters, 'id', 'leftanti').union(new_cluster)
+    )
 
-    if(len(matched_roots) == 0):
-        master.insertNewCluster(tx_addrs)
-    elif(len(matched_roots) == 1):
-        master.insertNewCluster(tx_addrs, matched_roots[0])
-    else:
-        master.rewrite_cluster_id(matched_roots[1:], matched_roots[0])
-        master.insertNewCluster(tx_addrs, matched_roots[0])
 
-    if(debug):
-        print("======================================================================")
+take_tx_and_cluster(tx_grouped, clusters_grouped).show()
 
 end = time.time()
-print("ELAPSED TIME:", end-start)
+print("ELAPSED TIME:", end-start)