more merging algorithms, add bench script

Branch: master
Author: nitowa, 1 year ago
Commit: 681b09e5e0

.gitignore  (+2, -1)

@@ -2,4 +2,5 @@ __pycache__
 .vscode
 checkpoints
 spark-warehouse
-scratchpad.py
+scratchpad.py
+benchmarks

bench.py  (+48, -0)

@@ -0,0 +1,48 @@
+import sys
+import json
+from cassandra.cluster import Cluster
+
+sys.path.append("config/db")
+
+from db_read_csv_txs import db_insert_csv_txs
+import os
+
+
+config = json.load(open("./settings.json"))
+
+cluster = Cluster(config['cassandra_addresses'],
+                    port=config['cassandra_port'])
+session = cluster.connect(config['cassandra_keyspace'])
+print(f"Connection OK")
+
+file = "/home/osboxes/Downloads/zec_tx_inputs.csv"
+num_rows = 128
+
+db_insert_csv_txs(config, file, skip=0, limit=num_rows)
+
+algorithms = [
+    'rik_merge',
+    'sve_merge',
+    'hoc_merge',
+    'nik_merge',
+    'rob_merge',
+    'agf_merge',
+    'agf_opt_merge',
+    'che_merge',
+    'ale_merge',
+    'nik_rew_merge_skip'
+]
+
+for algo in algorithms:
+    os.system(f"mkdir -p benchmarks/partition/{algo}")
+
+for i in range(16):
+    for algo in algorithms:
+        os.system(f"ALGO={algo} ./submit_partition.sh | sed '1d' | sed '2d' > benchmarks/partition/{algo}/{num_rows}.txt")
+        os.system(f"rm -rf ./checkpoints")
+    #os.system(f"./submit_graph.sh | sed '1d' | sed '2d' > benchmarks/graph/{num_rows}.txt")
+    #os.system(f"rm -rf ./checkpoints")
+
+    db_insert_csv_txs(config, file, skip=num_rows, limit=num_rows*2)
+    num_rows = num_rows*2
+    print(num_rows)
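Because the modified Spark scripts now end with `print(end-start, end='')`, each `benchmarks/partition/<algo>/<rows>.txt` written by the loop above is expected to end up containing just the elapsed-time number once the `sed` filters drop the submit script's leading output lines. A minimal sketch for collecting those files into a readable summary; the directory layout comes from the script above, but the aggregation itself is illustrative and not part of the commit:

```python
# Illustrative aggregation of the benchmark files written by bench.py.
# Assumes each benchmarks/partition/<algo>/<rows>.txt holds a bare float
# (elapsed seconds) printed by the Spark job.
import glob
import os

results = {}  # algo -> {rows: seconds}
for path in glob.glob("benchmarks/partition/*/*.txt"):
    algo = os.path.basename(os.path.dirname(path))
    rows = int(os.path.splitext(os.path.basename(path))[0])
    with open(path) as f:
        try:
            results.setdefault(algo, {})[rows] = float(f.read().strip())
        except ValueError:
            pass  # run failed or the output was not a bare number

for algo, timings in sorted(results.items()):
    for rows, secs in sorted(timings.items()):
        print(f"{algo:20s} {rows:8d} rows  {secs:8.2f} s")
```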

config/db/__pycache__/db_read_csv_txs.cpython-310.pyc  (BIN)


config/db/db_read_csv_txs.py  (+20, -6)

@@ -3,7 +3,7 @@ from cassandra.query import BoundStatement, BatchStatement
 import csv
 
 
-def db_insert_csv_txs(config, tx_file):
+def db_insert_csv_txs(config, tx_file, skip=0, limit=-1):
     print(" == DB TX INSERTION SCRIPT == ")
 
     print(
@@ -13,18 +13,32 @@ def db_insert_csv_txs(config, tx_file):
     session = cluster.connect(config['cassandra_keyspace'])
     print(f"Connection OK")
 
-    with open(tx_file, newline='') as tx_csv:
+    statement = session.prepare(
+            f"INSERT INTO {config['tx_table_name']} (tx_id,address,value,tx_hash,block_id,timestamp) VALUES(?,?,?,?,?,?);")
+    boundStatement = BoundStatement(statement)
+
+    with open(tx_file, newline='') as (tx_csv):
         rowreader = csv.reader(tx_csv, dialect="excel")
         next(rowreader)  # skip header
 
-        statement = session.prepare(
-            f"INSERT INTO {config['tx_table_name']} (tx_id,address,value,tx_hash,block_id,timestamp) VALUES(?,?,?,?,?,?);")
-        boundStatement = BoundStatement(statement)
         batchStatement = BatchStatement()
 
-        for row in rowreader:
+        batch_count = 0
+
+        for i, row in enumerate(rowreader):
+            if i < skip:
+                continue
+            if i == limit:
+                break
+
             batchStatement.add(boundStatement.bind(
                 [int(row[0]), str(row[1]), int(row[2]), str(row[3]), int(row[4]), int(row[5])]))
+            batch_count += 1
+            if batch_count > 256:
+                session.execute(batchStatement)
+                batchStatement = BatchStatement()
+                batch_count = 0
+
         session.execute(batchStatement)
 
     print("Done!")

config/db/tables/transactions/CREATE.sql  (+3, -3)

@@ -1,9 +1,9 @@
 CREATE TABLE transactions(
-    tx_id INT,
+    tx_id bigint,
     address TEXT,
-    value INT,
+    value bigint,
     tx_hash TEXT,
-    block_id INT,
+    block_id bigint,
     timestamp TIMESTAMP,
     PRIMARY KEY (tx_id, address)
 ) WITH CLUSTERING ORDER BY (address DESC);
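The schema change moves these columns from CQL `int` (32-bit signed) to `bigint` (64-bit), which matters once `tx_id`, `value`, or `block_id` exceed 2^31 - 1. A quick, illustrative check of whether a CSV's numeric columns still fit in 32 bits; the column positions follow the loader above, and the file path is a placeholder:

```python
# Illustrative: report the maximum of each numeric column bound by the
# loader (tx_id, value, block_id) and whether it would overflow CQL `int`.
import csv

INT32_MAX = 2**31 - 1
csv_path = "./small_test_data.csv"  # placeholder path

maxima = {"tx_id": 0, "value": 0, "block_id": 0}
with open(csv_path, newline='') as f:
    reader = csv.reader(f, dialect="excel")
    next(reader)  # skip header
    for row in reader:
        maxima["tx_id"] = max(maxima["tx_id"], int(row[0]))
        maxima["value"] = max(maxima["value"], int(row[2]))
        maxima["block_id"] = max(maxima["block_id"], int(row[4]))

for col, val in maxima.items():
    print(f"{col}: max={val}  fits in int32: {val <= INT32_MAX}")
```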

setup.py  (+11, -1)

@@ -1,5 +1,6 @@
 import sys
 import json
+from cassandra.cluster import Cluster
 
 sys.path.append("config/db")
 
@@ -9,4 +10,13 @@ from db_read_csv_txs import db_insert_csv_txs
 config = json.load(open("./settings.json"))
 
 db_setup(config)
-db_insert_csv_txs(config, "./small_test_data.csv")
+
+
+cluster = Cluster(config['cassandra_addresses'],
+                    port=config['cassandra_port'])
+session = cluster.connect(config['cassandra_keyspace'])
+print(f"Connection OK")
+
+#db_insert_csv_txs(config, "./small_test_data.csv", skip=0, limit=1500)
+#res = session.execute('SELECT COUNT(*) FROM transactions')
+#print(res.one()[0])

src/spark/main_graphs.py  (+12, -9)

@@ -68,16 +68,19 @@ transactions_as_edges = tx_df \
 g = GraphFrame(addresses_as_vertices, transactions_as_edges)
 components = g.connectedComponents(algorithm='graphframes')
 
-master.write_connected_components_as_clusters(components)
+#master.write_connected_components_as_clusters(components)
 
-if(debug):
-    clusters = components \
-        .groupBy('component') \
-        .agg(F.collect_list('id')) \
-        .collect()
 
-    for cluster in clusters:
-        print(sorted(cluster['collect_list(id)'])) 
+
+clusters = components \
+    .groupBy('component') \
+    .agg(F.collect_list('id')) \
+    .collect()
+
+#print(len(clusters))
+
+#for cluster in clusters:
+#    print(sorted(cluster['collect_list(id)'])) 
 
 end = time.time()
-print("ELAPSED TIME:", end-start)
+print(end-start, end='')

src/spark/main_partition.py  (+252, -28)

@@ -1,16 +1,20 @@
 import json
 from typing import Iterable, List, Set
-
+import networkx
+import heapq
+from itertools import chain
+from collections import deque
+import os
 from pyspark.sql import SparkSession, DataFrame, Row
 from pyspark.sql import functions as F
 
 import time
 start = time.time()
 
-
 config = json.load(open("./settings.json"))
 debug = config['debug']
 
+union_find_algo_name = os.environ['ALGO']
 
 class Master:
     spark: SparkSession
@@ -31,60 +35,280 @@ class Master:
 
     def get_tx_dataframe(self) -> DataFrame:
         return self.spark.table(self.TX_TABLE)
-
 # end class Master
 
-def merge_lists_distinct(*lists: "Iterable[List[str]]") -> List[str]:
-    accum = set()
-    for lst in lists:
-        accum = accum.union(set(lst))
-    return list(accum)
 
-def check_lists_overlap(list1, list2):
-    return any(x in list1 for x in list2)
+def rik_merge(lsts):
+    """Rik. Poggi"""
+    sets = (set(e) for e in lsts if e)
+    results = [next(sets)]
+    for e_set in sets:
+        to_update = []
+        for i,res in enumerate(results):
+            if not e_set.isdisjoint(res):
+                to_update.insert(0,i)
+
+        if not to_update:
+            results.append(e_set)
+        else:
+            last = results[to_update.pop(-1)]
+            for i in to_update:
+                last |= results[i]
+                del results[i]
+            last |= e_set
 
-def cluster_step(clusters: "List[List[str]]", addresses: "List[List[str]]"):
-    #if there are no more sets of addresses to consider, we are done
-    if(len(addresses) == 0):
-        return clusters
+    return results
 
-    tx = addresses[0]
-    matching_clusters = []
-    new_clusters = []
 
-    for cluster in clusters:
-        if(check_lists_overlap(tx, cluster)):
-            matching_clusters.append(cluster)
+def sve_merge(lsts):
+    """Sven Marnach"""
+    sets = {}
+    for lst in lsts:
+        s = set(lst)
+        t = set()
+        for x in s:
+            if x in sets:
+                t.update(sets[x])
+            else:
+                sets[x] = s
+        for y in t:
+            sets[y] = s
+        s.update(t)
+    ids = set()
+    result = []
+    for s in sets.values():
+        if id(s) not in ids:
+            ids.add(id(s))
+            result.append(s)
+    return result
+
+
+def hoc_merge(lsts):    # modified a bit to make it return sets
+    """hochl"""
+    s = [set(lst) for lst in lsts if lst]
+    i,n = 0,len(s)
+    while i < n-1:
+        for j in range(i+1, n):
+            if s[i].intersection(s[j]):
+                s[i].update(s[j])
+                del s[j]
+                n -= 1
+                break
         else:
-            new_clusters.append(cluster)
+            i += 1
+    return [set(i) for i in s]
+
+
+def nik_merge(lsts):
+    """Niklas B."""
+    sets = [set(lst) for lst in lsts if lst]
+    merged = 1
+    while merged:
+        merged = 0
+        results = []
+        while sets:
+            common, rest = sets[0], sets[1:]
+            sets = []
+            for x in rest:
+                if x.isdisjoint(common):
+                    sets.append(x)
+                else:
+                    merged = 1
+                    common |= x
+            results.append(common)
+        sets = results
+    return sets
+
+
+
+def rob_merge(lsts):
+    """robert king"""
+    lsts = [sorted(l) for l in lsts]   # I changed this line
+    one_list = heapq.merge(*[zip(l,[i]*len(l)) for i,l in enumerate(lsts)])
+    previous = next(one_list)
+
+    d = {i:i for i in range(len(lsts))}
+    for current in one_list:
+        if current[0]==previous[0]:
+            d[current[1]] = d[previous[1]]
+        previous=current
+
+    groups=[[] for i in range(len(lsts))]
+    for k in d:
+        groups[d[k]].append(lsts[k])
+
+    return [set(chain(*g)) for g in groups if g]
+
+
+def agf_merge(lsts):
+    """agf"""
+    newsets, sets = [set(lst) for lst in lsts if lst], []
+    while len(sets) != len(newsets):
+        sets, newsets = newsets, []
+        for aset in sets:
+            for eachset in newsets:
+                if not aset.isdisjoint(eachset):
+                    eachset.update(aset)
+                    break
+            else:
+                newsets.append(aset)
+    return newsets
 
-    new_clusters.append(merge_lists_distinct(tx, *matching_clusters))
 
-    return cluster_step(new_clusters,addresses[1:])
+def agf_opt_merge(lists):
+    """agf (optimized)"""
+    sets = deque(set(lst) for lst in lists if lst)
+    results = []
+    disjoint = 0
+    current = sets.pop()
+    while True:
+        merged = False
+        newsets = deque()
+        for _ in range(disjoint, len(sets)):
+            this = sets.pop()
+            if not current.isdisjoint(this):
+                current.update(this)
+                merged = True
+                disjoint = 0
+            else:
+                newsets.append(this)
+                disjoint += 1
+        if sets:
+            newsets.extendleft(sets)
+        if not merged:
+            results.append(current)
+            try:
+                current = newsets.pop()
+            except IndexError:
+                break
+            disjoint = 0
+        sets = newsets
+    return results
+
+
+def che_merge(lsts):
+    """ChessMaster"""
+    results, sets = [], [set(lst) for lst in lsts if lst]
+    upd, isd, pop = set.update, set.isdisjoint, sets.pop
+    while sets:
+        if not [upd(sets[0],pop(i)) for i in range(len(sets)-1,0,-1) if not isd(sets[0],sets[i])]:
+            results.append(pop(0))
+    return results
+
+
+def locatebin(bins, n):
+    """Find the bin where list n has ended up: Follow bin references until
+    we find a bin that has not moved.
+
+    """
+    while bins[n] != n:
+        n = bins[n]
+    return n
+
+
+def ale_merge(data):
+    """alexis"""
+    bins = list(range(len(data)))  # Initialize each bin[n] == n
+    nums = dict()
+
+    data = [set(m) for m in data ]  # Convert to sets
+    for r, row in enumerate(data):
+        for num in row:
+            if num not in nums:
+                # New number: tag it with a pointer to this row's bin
+                nums[num] = r
+                continue
+            else:
+                dest = locatebin(bins, nums[num])
+                if dest == r:
+                    continue # already in the same bin
+
+                if dest > r:
+                    dest, r = r, dest   # always merge into the smallest bin
+
+                data[dest].update(data[r])
+                data[r] = None
+                # Update our indices to reflect the move
+                bins[r] = dest
+                r = dest
+
+    # Filter out the empty bins
+    have = [ m for m in data if m ]
+    #print len(have), "groups in result"  #removed this line
+    return have
+
+
+def nik_rew_merge_skip(lsts):
+    """Nik's rewrite"""
+    sets = list(map(set,lsts))
+    results = []
+    while sets:
+        first, rest = sets[0], sets[1:]
+        merged = False
+        sets = []
+        for s in rest:
+            if s and s.isdisjoint(first):
+                sets.append(s)
+            else:
+                first |= s
+                merged = True
+        if merged:
+            sets.append(first)
+        else:
+            results.append(first)
+    return results
+
+def union_find(clusters: "List[List[str]]", addresses: "List[List[str]]"):
+    data = clusters + addresses
+    match union_find_algo_name:
+        case 'rik_merge':
+            return rik_merge(data)
+        case 'sve_merge':
+            return sve_merge(data)
+        case 'hoc_merge':
+            return hoc_merge(data)
+        case 'nik_merge':
+            return nik_merge(data)
+        case 'rob_merge':
+            return rob_merge(data)
+        case 'agf_merge':
+            return agf_merge(data)
+        case 'agf_opt_merge':
+            return agf_opt_merge(data)
+        case 'che_merge':
+            return che_merge(data)
+        case 'ale_merge':
+            return ale_merge(data)
+        case 'nik_rew_merge_skip':
+            return nik_rew_merge_skip(data)
+        case _:
+            raise NameError("Unset or unknown algorithm")
 
 def cluster_partition(iter: "Iterable[Row]") -> Iterable:
-    yield cluster_step([], list(map(lambda row: row['addresses'], iter)))
+    yield union_find([], list(map(lambda row: row['addresses'], iter)))
     
 master = Master(config)
 master.spark.catalog.clearCache()
 master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])
 tx_df = master.get_tx_dataframe()
 
+
 #Turn transactions into a list of ('id', [addr, addr, ...])
 tx_grouped = tx_df \
     .groupBy('tx_id') \
-    .agg(F.collect_set('address').alias('addresses')) \
-    .orderBy('tx_id') \
+    .agg(F.collect_set('address').alias('addresses'))
 
 res = tx_grouped \
     .repartition(5) \
     .rdd \
     .mapPartitions(cluster_partition) \
-    .fold([], cluster_step)
+    .fold([], union_find)
 
+"""
 for cluster in res:
     print()
     print(sorted(cluster))
+"""
 
 end = time.time()
-print("ELAPSED TIME:", end-start)
+print(end-start, end='')
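All ten `*_merge` variants selected through `ALGO` solve the same problem: given lists of addresses, merge every pair of lists that share an element until only disjoint clusters remain (connected components of the address-overlap graph). A minimal, self-contained sketch of that semantics on toy data, useful as a sanity check against any of the variants; the reference implementation below is illustrative and not part of the commit:

```python
# Illustrative reference: what every *_merge / union_find variant above is
# expected to return, computed here with a simple union-find structure.
def reference_merge(lists):
    parent = {}

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[rb] = ra

    for lst in lists:
        for x in lst:
            parent.setdefault(x, x)
        for a, b in zip(lst, lst[1:]):
            union(a, b)  # every element in a list belongs to one cluster

    clusters = {}
    for x in parent:
        clusters.setdefault(find(x), set()).add(x)
    return list(clusters.values())


if __name__ == "__main__":
    toy = [['addr1', 'addr2'], ['addr2', 'addr3'], ['addr4']]
    print(reference_merge(toy))
    # -> [{'addr1', 'addr2', 'addr3'}, {'addr4'}]  (cluster order may vary)
    # With e.g. ALGO=nik_merge set in the environment, union_find([], toy)
    # should produce the same grouping.
```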

src/spark/naive_implementation.py  (+68, -0)

@@ -0,0 +1,68 @@
+from typing import Iterable, List, Set
+
+def merge_lists_distinct(*lists: "Iterable[List[str]]") -> List[str]:
+    accum = set()
+    for lst in lists:
+        accum = accum.union(set(lst))
+    return list(accum)
+
+def check_lists_overlap(list1, list2):
+    return any(x in list1 for x in list2)
+
+def cluster_step(clusters: "List[List[str]]", addresses: "List[List[str]]"):
+    #if there are no more sets of addresses to consider, we are done
+    if(len(addresses) == 0):
+        return clusters
+
+    tx = addresses[0]
+    matching_clusters = []
+    new_clusters = []
+
+    for cluster in clusters:
+        if(check_lists_overlap(tx, cluster)):
+            matching_clusters.append(cluster)
+        else:
+            new_clusters.append(cluster)
+
+    new_clusters.append(merge_lists_distinct(tx, *matching_clusters))
+
+    return cluster_step(new_clusters,addresses[1:])
+
+def cluster_step_iter(clusters: "List[List[str]]", addresses: "List[List[str]]"):
+
+    clstr = clusters
+    addrs = addresses
+
+    while True:
+        if(len(addrs) == 0):
+            break
+
+        tx = addrs[0]
+        matching_clusters = []
+        new_clusters = []
+
+        for cluster in clstr:
+            if(check_lists_overlap(tx, cluster)):
+                matching_clusters.append(cluster)
+            else:
+                new_clusters.append(cluster)
+
+        new_clusters.append(merge_lists_distinct(tx, *matching_clusters))
+        clstr = new_clusters
+        addrs = addrs[1:]
+
+    return clstr
+
+def cluster_n(clusters: "List[List[str]]", addresses: "List[List[str]]"):
+    tx_sets = map(set, clusters+addresses)
+    unions = []
+    for tx in tx_sets:
+        temp = []
+        for s in unions:
+            if not s.isdisjoint(tx):
+                tx = s.union(tx)
+            else:
+                temp.append(s)
+        temp.append(tx)
+        unions = temp
+    return unions
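`cluster_step` recurses once per address list, so on large partitions it can hit CPython's default recursion limit (roughly 1000 frames), while `cluster_step_iter` and `cluster_n` compute the same clustering iteratively. A quick, illustrative comparison on toy data, assuming the script is run with `src/spark` on the import path; this snippet is not part of the commit:

```python
# Illustrative: the three naive variants agree on small inputs, but only
# the iterative ones survive chains deeper than the recursion limit.
from naive_implementation import cluster_step, cluster_step_iter, cluster_n

toy = [['a', 'b'], ['b', 'c'], ['d']]
print(sorted(map(sorted, cluster_step([], toy))))       # [['a', 'b', 'c'], ['d']]
print(sorted(map(sorted, cluster_step_iter([], toy))))  # same grouping
print(sorted(map(sorted, cluster_n([], toy))))          # same grouping (as sets)

# A chain of ~2000 overlapping pairs would exhaust the recursion limit for
# cluster_step; the iterative versions handle it fine.
long_chain = [[i, i + 1] for i in range(2000)]
print(len(cluster_n([], long_chain)))  # 1 cluster covering 0..2000
```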
