Browse Source

working non-graph implementation

master
nitowa 1 year ago
parent
commit
9cb6827c5e
1 changed files with 40 additions and 53 deletions
  1. 40
    53
      src/spark/main.py

+ 40
- 53
src/spark/main.py View File

@@ -1,9 +1,4 @@
1
-from gc import collect
2 1
 import json
3
-from select import select
4
-
5
-from sqlite3 import Row
6
-from typing import Iterable, List
7 2
 
8 3
 from pyspark.sql import SparkSession, DataFrame, Row
9 4
 from pyspark.sql import functions as F
@@ -56,46 +51,18 @@ class Master:
56 51
             .rdd \
57 52
             .flatMap(lambda row: list(map(lambda elem: (elem,), row[column]))) \
58 53
             .toDF([column])
59
-
60
-    def array_col_to_elements(self, df: DataFrame, column: str, distinct:bool = False) -> DataFrame:
61
-        exploded = master.explode_array_col(
62
-            df,
63
-            column
64
-        )
65
-
66
-        #this is likely redundant
67
-        collected = master.collect_col_to_array(
68
-            exploded, 
69
-            column, 
70
-            distinct
71
-        )
72
-
73
-        return self.explode_array_col(
74
-            collected,
75
-            column
76
-        )
77
-
78 54
 # end class Master
79 55
 
80 56
 
81 57
 master = Master(config)
58
+master.spark.catalog.clearCache()
59
+master.spark.sparkContext.setCheckpointDir('./checkpoints')
82 60
 tx_df = master.get_tx_dataframe()
83 61
 
84
-
85 62
 #Turn transactions into a list of ('id', [addr, addr, ...])
86 63
 tx_grouped = tx_df \
87 64
     .groupBy('tx_id') \
88
-    .agg(F.collect_set('address').alias('addresses')) \
89
-    .rdd \
90
-    .zipWithIndex() \
91
-    .toDF(['tx', 'index']) \
92
-    .select(
93
-        F.col('tx.tx_id').alias('tx_id'),
94
-        F.col('tx.addresses').alias('addresses'),
95
-        'index'
96
-    ) \
97
-    .cache()
98
-
65
+    .agg(F.collect_set('address').alias('addresses'))
99 66
 
100 67
 # TODO: Load clusters from DB, check if any exist, if no make initial cluster, else proceed with loaded data
101 68
 
@@ -103,29 +70,39 @@ tx_grouped = tx_df \
103 70
 
104 71
 # take the first tx
105 72
 tx_zero = tx_grouped \
106
-    .select(tx_grouped.tx_id, tx_grouped.addresses) \
107
-    .where(tx_grouped.index == 0)
73
+    .select('*') \
74
+    .where('tx_id = 3') \
75
+    .limit(1)
108 76
 
109 77
 # find txs with overlapping addresses
110 78
 overlapping_txs = tx_grouped \
111
-    .where((tx_grouped.index != 0)) \
112
-    .join(tx_zero.withColumnRenamed('addresses', 'tx_addresses')) \
79
+    .join(
80
+        tx_zero \
81
+            .withColumnRenamed('addresses', 'tx_addresses') \
82
+            .withColumnRenamed('tx_id', 'overlap_id')
83
+    ) \
113 84
     .select(
114
-        tx_grouped.index,
85
+        tx_grouped.tx_id,
115 86
         tx_grouped.addresses,
116 87
         F.arrays_overlap(tx_grouped.addresses, 'tx_addresses').alias('overlap')
117 88
     ) \
118 89
     .where(F.col('overlap') == True) \
90
+    .drop('overlap')
119 91
 
120 92
 # overlapped txs must not be considered anymore, so remove them candidate dataframe
121 93
 tx_grouped = tx_grouped \
122
-    .join(overlapping_txs, 'index', 'leftanti') \
123
-    .filter(tx_grouped.index != 0)
94
+    .join(
95
+        overlapping_txs.drop('addresses'), 
96
+        'tx_id', 
97
+        'leftanti'
98
+    )
124 99
 
125 100
 # get the distinct addresses of all overlaps in a single array
126 101
 distinct_addresses = master.reduce_concat_array_column(
127 102
     master.union_single_col(
128
-        overlapping_txs, tx_zero, column='addresses'
103
+        overlapping_txs, 
104
+        tx_zero, 
105
+        column='addresses'
129 106
     ), 
130 107
     column='addresses',
131 108
     distinct=True,
@@ -144,7 +121,7 @@ clusters_grouped = cluster \
144 121
     .groupBy('id') \
145 122
     .agg(F.collect_list('address').alias('addresses'))
146 123
 
147
-def take_tx_and_cluster(txs: DataFrame, clusters: DataFrame):
124
+def take_tx_and_cluster(txs: DataFrame, clusters: DataFrame, n=0):
148 125
     if (txs.count() == 0):  # done!
149 126
         return clusters
150 127
 
@@ -158,33 +135,43 @@ def take_tx_and_cluster(txs: DataFrame, clusters: DataFrame):
158 135
         .select(
159 136
             clusters.id,
160 137
             clusters.addresses,
138
+            'tx_addresses',
161 139
             F.arrays_overlap(clusters.addresses,'tx_addresses').alias('overlap')
162 140
         ) \
163 141
         .where(F.col('overlap') == True)
164 142
 
143
+    clusters_union_tx = master.union_single_col(tx, overlapping_clusters, 'addresses')
144
+
165 145
     #collect all addresses into single array field
166
-    new_cluster_arr = master.reduce_concat_array_column(
167
-        master.union_single_col(tx, overlapping_clusters, 'addresses'),
146
+    new_cluster_arrays = master.reduce_concat_array_column(
147
+        clusters_union_tx,
168 148
         column='addresses',
169 149
         distinct=True
170 150
     )
171 151
 
172 152
     #declare cluster representative
173
-    new_cluster = new_cluster_arr \
153
+    new_cluster = new_cluster_arrays \
174 154
         .rdd \
175 155
         .flatMap(lambda row: list(map(lambda addr: (addr, row['addresses'][0]), row['addresses']))) \
176 156
         .toDF(['address', 'id']) \
177 157
         .groupBy('id') \
178 158
         .agg(F.collect_list('address').alias('addresses'))
179 159
 
160
+    txs = txs.join(tx, 'tx_id', 'leftanti')
161
+    clusters = clusters.join(overlapping_clusters, 'id', 'leftanti').union(new_cluster)
162
+
163
+    #the RDD legacy (internal history tracker) gets too big as iterations continue, use checkpoint to prune it regularly
164
+    if(n % 3 == 0):
165
+        txs = txs.checkpoint()
166
+        clusters = clusters.checkpoint()
167
+
180 168
     #start new round with txs minus the one just used, and updated clusters
181
-    return take_tx_and_cluster(
182
-        txs.join(tx, 'index', 'leftanti'), 
183
-        clusters.join(overlapping_clusters, 'id', 'leftanti').union(new_cluster)
184
-    )
169
+    return take_tx_and_cluster(txs,clusters,n+1)
185 170
 
186 171
 
187
-take_tx_and_cluster(tx_grouped, clusters_grouped).show()
172
+result = take_tx_and_cluster(tx_grouped, clusters_grouped).collect()
173
+for row in result:
174
+    print(sorted(row['addresses']))
188 175
 
189 176
 end = time.time()
190 177
 print("ELAPSED TIME:", end-start)

Loading…
Cancel
Save