# Benchmarks several "merge lists that share common elements" implementations
# (selected at runtime via the ALGO environment variable) on transaction data
# read from Cassandra through Spark, and prints the elapsed wall-clock time.
import json
from typing import Iterable, List, Set
import networkx
import heapq
from itertools import chain
from collections import deque
import os
from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F
import time

start = time.time()

with open("./settings.json") as settings_file:
    config = json.load(settings_file)
debug = config['debug']
# Name of the merge implementation to benchmark (see union_find below).
union_find_algo_name = os.environ['ALGO']


class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)

# end class Master
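# For reference, a minimal settings.json that provides the keys this script
# reads might look like the sketch below. The key names come from this file;
# the values are placeholders and will differ per deployment.
#
#   {
#       "debug": false,
#       "cassandra_catalog": "cassandra",
#       "cassandra_keyspace": "my_keyspace",
#       "clusters_table_name": "clusters",
#       "tx_table_name": "transactions",
#       "spark_checkpoint_dir": "/tmp/spark-checkpoints"
#   }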
# Candidate merge implementations, each named after its author (see the
# docstrings); union_find below dispatches to one of them by name.

def rik_merge(lsts):
    """Rik. Poggi"""
    sets = (set(e) for e in lsts if e)
    results = [next(sets)]
    for e_set in sets:
        to_update = []
        for i, res in enumerate(results):
            if not e_set.isdisjoint(res):
                to_update.insert(0, i)

        if not to_update:
            results.append(e_set)
        else:
            last = results[to_update.pop(-1)]
            for i in to_update:
                last |= results[i]
                del results[i]
            last |= e_set

    return results


def sve_merge(lsts):
    """Sven Marnach"""
    sets = {}
    for lst in lsts:
        s = set(lst)
        t = set()
        for x in s:
            if x in sets:
                t.update(sets[x])
            else:
                sets[x] = s
        for y in t:
            sets[y] = s
        s.update(t)
    ids = set()
    result = []
    for s in sets.values():
        if id(s) not in ids:
            ids.add(id(s))
            result.append(s)
    return result


def hoc_merge(lsts):  # modified a bit to make it return sets
    """hochl"""
    s = [set(lst) for lst in lsts if lst]
    i, n = 0, len(s)
    while i < n - 1:
        for j in range(i + 1, n):
            if s[i].intersection(s[j]):
                s[i].update(s[j])
                del s[j]
                n -= 1
                break
        else:
            i += 1
    return [set(i) for i in s]


def nik_merge(lsts):
    """Niklas B."""
    sets = [set(lst) for lst in lsts if lst]
    merged = 1
    while merged:
        merged = 0
        results = []
        while sets:
            common, rest = sets[0], sets[1:]
            sets = []
            for x in rest:
                if x.isdisjoint(common):
                    sets.append(x)
                else:
                    merged = 1
                    common |= x
            results.append(common)
        sets = results
    return sets


def rob_merge(lsts):
    """robert king"""
    lsts = [sorted(l) for l in lsts]  # I changed this line
    one_list = heapq.merge(*[zip(l, [i] * len(l)) for i, l in enumerate(lsts)])
    previous = next(one_list)

    d = {i: i for i in range(len(lsts))}
    for current in one_list:
        if current[0] == previous[0]:
            d[current[1]] = d[previous[1]]
        previous = current

    groups = [[] for i in range(len(lsts))]
    for k in d:
        groups[d[k]].append(lsts[k])

    return [set(chain(*g)) for g in groups if g]


def agf_merge(lsts):
    """agf"""
    newsets, sets = [set(lst) for lst in lsts if lst], []
    while len(sets) != len(newsets):
        sets, newsets = newsets, []
        for aset in sets:
            for eachset in newsets:
                if not aset.isdisjoint(eachset):
                    eachset.update(aset)
                    break
            else:
                newsets.append(aset)
    return newsets


def agf_opt_merge(lists):
    """agf (optimized)"""
    sets = deque(set(lst) for lst in lists if lst)
    results = []
    disjoint = 0
    current = sets.pop()
    while True:
        merged = False
        newsets = deque()
        for _ in range(disjoint, len(sets)):
            this = sets.pop()
            if not current.isdisjoint(this):
                current.update(this)
                merged = True
                disjoint = 0
            else:
                newsets.append(this)
                disjoint += 1
        if sets:
            newsets.extendleft(sets)
        if not merged:
            results.append(current)
            try:
                current = newsets.pop()
            except IndexError:
                break
            disjoint = 0
        sets = newsets
    return results


def che_merge(lsts):
    """ChessMaster"""
    results, sets = [], [set(lst) for lst in lsts if lst]
    upd, isd, pop = set.update, set.isdisjoint, sets.pop
    while sets:
        if not [upd(sets[0], pop(i)) for i in range(len(sets) - 1, 0, -1) if not isd(sets[0], sets[i])]:
            results.append(pop(0))
    return results


def locatebin(bins, n):
    """Find the bin where list n has ended up: Follow bin references until
    we find a bin that has not moved.
    """
    while bins[n] != n:
        n = bins[n]
    return n


def ale_merge(data):
    """alexis"""
    bins = list(range(len(data)))  # Initialize each bin[n] == n
    nums = dict()

    data = [set(m) for m in data]  # Convert to sets

    for r, row in enumerate(data):
        for num in row:
            if num not in nums:
                # New number: tag it with a pointer to this row's bin
                nums[num] = r
                continue
            else:
                dest = locatebin(bins, nums[num])
                if dest == r:
                    continue  # already in the same bin

                if dest > r:
                    dest, r = r, dest  # always merge into the smallest bin

                data[dest].update(data[r])
                data[r] = None
                # Update our indices to reflect the move
                bins[r] = dest
                r = dest

    # Filter out the empty bins
    have = [m for m in data if m]
    #print len(have), "groups in result"  #removed this line
    return have


def nik_rew_merge_skip(lsts):
    """Nik's rewrite"""
    sets = list(map(set, lsts))
    results = []
    while sets:
        first, rest = sets[0], sets[1:]
        merged = False
        sets = []
        for s in rest:
            if s and s.isdisjoint(first):
                sets.append(s)
            else:
                first |= s
                merged = True
        if merged:
            sets.append(first)
        else:
            results.append(first)
    return results


def union_find(clusters: "List[List[str]]", addresses: "List[List[str]]"):
    # Merge the two collections of address lists with the implementation
    # selected through the ALGO environment variable.
    data = clusters + addresses
    match union_find_algo_name:
        case 'rik_merge':
            return rik_merge(data)
        case 'sve_merge':
            return sve_merge(data)
        case 'hoc_merge':
            return hoc_merge(data)
        case 'nik_merge':
            return nik_merge(data)
        case 'rob_merge':
            return rob_merge(data)
        case 'agf_merge':
            return agf_merge(data)
        case 'agf_opt_merge':
            return agf_opt_merge(data)
        case 'che_merge':
            return che_merge(data)
        case 'ale_merge':
            return ale_merge(data)
        case 'nik_rew_merge_skip':
            return nik_rew_merge_skip(data)
        case _:
            raise NameError("Unset or unknown algorithm")


def cluster_partition(iter: "Iterable[Row]") -> Iterable:
    # Cluster all address lists within a single partition; yields one list
    # of merged address sets per partition.
    yield union_find([], list(map(lambda row: row['addresses'], iter)))


master = Master(config)
master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])

tx_df = master.get_tx_dataframe()

# Turn transactions into a list of ('id', [addr, addr, ...])
tx_grouped = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses'))

# Cluster each partition locally, then fold the per-partition clusterings
# together on the driver using the same union_find step.
res = tx_grouped \
    .repartition(5) \
    .rdd \
    .mapPartitions(cluster_partition) \
    .fold([], union_find)

"""
for cluster in res:
    print()
    print(sorted(cluster))
"""

end = time.time()
print(end - start, end='')
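# A minimal sanity-check sketch, not part of the original benchmark and not
# invoked by the job: it assumes ALGO is set to one of the names handled in
# union_find and illustrates the behaviour every merge implementation above is
# expected to share: address lists that overlap end up in the same cluster.
# The function name and sample data are hypothetical.
def _merge_smoke_test():
    sample = [['a1', 'a2'], ['a2', 'a3'], ['b1'], ['a3', 'a4']]
    # Expected result (order of sets may vary): [{'a1', 'a2', 'a3', 'a4'}, {'b1'}]
    return union_find([], sample)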