@@ -1,9 +1,4 @@
-from gc import collect
 import json
-from select import select
-
-from sqlite3 import Row
-from typing import Iterable, List

 from pyspark.sql import SparkSession, DataFrame, Row
 from pyspark.sql import functions as F

@@ -56,46 +51,18 @@ class Master:
             .rdd \
             .flatMap(lambda row: list(map(lambda elem: (elem,), row[column]))) \
             .toDF([column])
-
-    def array_col_to_elements(self, df: DataFrame, column: str, distinct:bool = False) -> DataFrame:
-        exploded = master.explode_array_col(
-            df,
-            column
-        )
-
-        #this is likely redundant
-        collected = master.collect_col_to_array(
-            exploded,
-            column,
-            distinct
-        )
-
-        return self.explode_array_col(
-            collected,
-            column
-        )
-
 # end class Master


 master = Master(config)
+master.spark.catalog.clearCache()
+master.spark.sparkContext.setCheckpointDir('./checkpoints')
 tx_df = master.get_tx_dataframe()

-
 #Turn transactions into a list of ('id', [addr, addr, ...])
 tx_grouped = tx_df \
     .groupBy('tx_id') \
-    .agg(F.collect_set('address').alias('addresses')) \
-    .rdd \
-    .zipWithIndex() \
-    .toDF(['tx', 'index']) \
-    .select(
-        F.col('tx.tx_id').alias('tx_id'),
-        F.col('tx.addresses').alias('addresses'),
-        'index'
-    ) \
-    .cache()
-
+    .agg(F.collect_set('address').alias('addresses'))

 # TODO: Load clusters from DB, check if any exist; if none, make an initial cluster, else proceed with loaded data

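Note on the simplified grouping above: dropping the zipWithIndex/cache stage leaves a plain one-row-per-transaction DataFrame. A minimal sketch of the shape it produces, on made-up data (the local SparkSession and the toy values are illustrative only, not from the repo):

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.master('local[*]').getOrCreate()

    # toy (tx_id, address) pairs; tx 1 and tx 2 share address 'b'
    tx_df = spark.createDataFrame(
        [(1, 'a'), (1, 'b'), (2, 'b'), (2, 'c'), (3, 'd')],
        ['tx_id', 'address'])

    # one row per tx_id, addresses collected into a de-duplicated array
    tx_grouped = tx_df.groupBy('tx_id').agg(F.collect_set('address').alias('addresses'))
    tx_grouped.show()  # e.g. |1|[a, b]| ... (order inside the array is not guaranteed)
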
@@ -103,29 +70,39 @@ tx_grouped = tx_df \

 # take the first tx
 tx_zero = tx_grouped \
-    .select(tx_grouped.tx_id, tx_grouped.addresses) \
-    .where(tx_grouped.index == 0)
+    .select('*') \
+    .where('tx_id = 3') \
+    .limit(1)

 # find txs with overlapping addresses
 overlapping_txs = tx_grouped \
-    .where((tx_grouped.index != 0)) \
-    .join(tx_zero.withColumnRenamed('addresses', 'tx_addresses')) \
+    .join(
+        tx_zero \
+            .withColumnRenamed('addresses', 'tx_addresses') \
+            .withColumnRenamed('tx_id', 'overlap_id')
+    ) \
     .select(
-        tx_grouped.index,
+        tx_grouped.tx_id,
         tx_grouped.addresses,
         F.arrays_overlap(tx_grouped.addresses, 'tx_addresses').alias('overlap')
     ) \
     .where(F.col('overlap') == True) \
+    .drop('overlap')

 # overlapped txs must not be considered anymore, so remove them from the candidate dataframe
 tx_grouped = tx_grouped \
-    .join(overlapping_txs, 'index', 'leftanti') \
-    .filter(tx_grouped.index != 0)
+    .join(
+        overlapping_txs.drop('addresses'),
+        'tx_id',
+        'leftanti'
+    )

 # get the distinct addresses of all overlaps in a single array
 distinct_addresses = master.reduce_concat_array_column(
     master.union_single_col(
-        overlapping_txs, tx_zero, column='addresses'
+        overlapping_txs,
+        tx_zero,
+        column='addresses'
     ),
     column='addresses',
     distinct=True,

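The overlap/removal pattern above (arrays_overlap to find transactions sharing an address, then a leftanti join to drop them from the candidates) can be seen in isolation. A rough sketch on toy DataFrames, reusing the `spark` session from the previous snippet; the names and values here are illustrative assumptions:

    left = spark.createDataFrame([(1, ['a', 'b']), (2, ['c'])], ['tx_id', 'addresses'])
    seed = spark.createDataFrame([(9, ['b', 'x'])], ['overlap_id', 'tx_addresses'])

    # cross join the single seed row against every candidate,
    # keep only rows whose address arrays intersect
    hits = left.crossJoin(seed) \
        .where(F.arrays_overlap('addresses', 'tx_addresses'))   # keeps tx_id 1 only

    # leftanti keeps only rows of `left` that have no match in `hits` on tx_id
    remaining = left.join(hits.select('tx_id'), 'tx_id', 'leftanti')   # keeps tx_id 2
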
@@ -144,7 +121,7 @@ clusters_grouped = cluster \
     .groupBy('id') \
     .agg(F.collect_list('address').alias('addresses'))

-def take_tx_and_cluster(txs: DataFrame, clusters: DataFrame):
+def take_tx_and_cluster(txs: DataFrame, clusters: DataFrame, n=0):
     if (txs.count() == 0): # done!
         return clusters

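The extra `n` parameter only exists to gate the periodic checkpointing added in the next hunk. As a reminder of the mechanics (a minimal sketch under the same toy assumptions as above; `DataFrame.checkpoint()` requires the checkpoint directory configured earlier in the script):

    spark.sparkContext.setCheckpointDir('./checkpoints')
    df = spark.range(10)
    for i in range(1, 10):
        df = df.withColumn('step', F.lit(i))   # the logical plan grows every iteration
        if i % 3 == 0:
            # checkpoint() materializes the data to the checkpoint dir and returns
            # a DataFrame whose lineage restarts there, keeping the plan small
            df = df.checkpoint()
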
@@ -158,33 +135,43 @@ def take_tx_and_cluster(txs: DataFrame, clusters: DataFrame):
         .select(
             clusters.id,
             clusters.addresses,
+            'tx_addresses',
             F.arrays_overlap(clusters.addresses, 'tx_addresses').alias('overlap')
         ) \
         .where(F.col('overlap') == True)

+    clusters_union_tx = master.union_single_col(tx, overlapping_clusters, 'addresses')
+
     #collect all addresses into a single array field
-    new_cluster_arr = master.reduce_concat_array_column(
-        master.union_single_col(tx, overlapping_clusters, 'addresses'),
+    new_cluster_arrays = master.reduce_concat_array_column(
+        clusters_union_tx,
         column='addresses',
         distinct=True
     )

     #declare cluster representative
-    new_cluster = new_cluster_arr \
+    new_cluster = new_cluster_arrays \
         .rdd \
         .flatMap(lambda row: list(map(lambda addr: (addr, row['addresses'][0]), row['addresses']))) \
         .toDF(['address', 'id']) \
         .groupBy('id') \
         .agg(F.collect_list('address').alias('addresses'))

+    txs = txs.join(tx, 'tx_id', 'leftanti')
+    clusters = clusters.join(overlapping_clusters, 'id', 'leftanti').union(new_cluster)
+
+    #the RDD lineage (Spark's internal history of transformations) gets too big as iterations continue, use checkpoint to prune it regularly
+    if (n % 3 == 0):
+        txs = txs.checkpoint()
+        clusters = clusters.checkpoint()
+
     #start new round with txs minus the one just used, and updated clusters
-    return take_tx_and_cluster(
-        txs.join(tx, 'index', 'leftanti'),
-        clusters.join(overlapping_clusters, 'id', 'leftanti').union(new_cluster)
-    )
+    return take_tx_and_cluster(txs, clusters, n+1)


-take_tx_and_cluster(tx_grouped, clusters_grouped).show()
+result = take_tx_and_cluster(tx_grouped, clusters_grouped).collect()
+for row in result:
+    print(sorted(row['addresses']))

 end = time.time()
 print("ELAPSED TIME:", end-start)
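For clarity, the "cluster representative" step above elects the first address of each merged array as the cluster id and fans every address out to it before regrouping. A toy walk-through of that flatMap/regroup for a single merged row (values invented, `spark` as in the earlier sketches):

    row = {'addresses': ['a', 'b', 'c']}
    pairs = [(addr, row['addresses'][0]) for addr in row['addresses']]
    # -> [('a', 'a'), ('b', 'a'), ('c', 'a')]: every address points to representative 'a'

    new_cluster = spark.createDataFrame(pairs, ['address', 'id']) \
        .groupBy('id') \
        .agg(F.collect_list('address').alias('addresses'))
    # -> one row: id = 'a', addresses = ['a', 'b', 'c']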