import heapq
import json
import os
import time
from collections import deque
from itertools import chain
from typing import Iterable, List

from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F

start = time.time()

with open("./settings.json") as f:
    config = json.load(f)
debug = config['debug']
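
# Expected shape of settings.json (the keys are the ones this script reads;
# the values below are purely illustrative):
# {
#   "debug": false,
#   "cassandra_catalog": "cass",
#   "cassandra_keyspace": "ks",
#   "clusters_table_name": "clusters",
#   "tx_table_name": "transactions",
#   "spark_checkpoint_dir": "/tmp/spark-checkpoints"
# }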

# ALGO selects one of the merge implementations defined below by name
# (see union_find for the accepted names).
union_find_algo_name = os.environ['ALGO']


class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.make_spark_session(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def make_spark_session(self, config) -> SparkSession:
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)
# end class Master
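
# Illustrative note: the job below only needs get_tx_dataframe(); the
# CLUSTERS_TABLE name is resolved against the same catalog, so e.g.
# master.spark.table(master.CLUSTERS_TABLE) would read the clusters table.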


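# All of the *_merge implementations below solve the same task: given lists of
# hashable items, merge every pair of lists that shares at least one element,
# i.e. compute the connected components of the "shares an element" relation.
# Illustrative example (the ordering of the returned sets may differ between
# implementations):
#   rik_merge([[1, 2], [2, 3], [4]])  ->  [{1, 2, 3}, {4}]
# Each docstring credits the author of that implementation.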
def rik_merge(lsts):
    """Rik Poggi"""
    sets = (set(e) for e in lsts if e)
    try:
        results = [next(sets)]
    except StopIteration:
        return []  # no non-empty input lists
    for e_set in sets:
        to_update = []
        for i, res in enumerate(results):
            if not e_set.isdisjoint(res):
                to_update.insert(0, i)

        if not to_update:
            results.append(e_set)
        else:
            last = results[to_update.pop(-1)]
            for i in to_update:
                last |= results[i]
                del results[i]
            last |= e_set

    return results


def sve_merge(lsts):
    """Sven Marnach"""
    sets = {}
    for lst in lsts:
        s = set(lst)
        t = set()
        for x in s:
            if x in sets:
                t.update(sets[x])
            else:
                sets[x] = s
        for y in t:
            sets[y] = s
        s.update(t)
    ids = set()
    result = []
    for s in sets.values():
        if id(s) not in ids:
            ids.add(id(s))
            result.append(s)
    return result


def hoc_merge(lsts):  # modified a bit to make it return sets
    """hochl"""
    s = [set(lst) for lst in lsts if lst]
    i, n = 0, len(s)
    while i < n - 1:
        for j in range(i + 1, n):
            if s[i].intersection(s[j]):
                s[i].update(s[j])
                del s[j]
                n -= 1
                break
        else:
            i += 1
    return [set(g) for g in s]


def nik_merge(lsts):
    """Niklas B."""
    sets = [set(lst) for lst in lsts if lst]
    merged = 1
    while merged:
        merged = 0
        results = []
        while sets:
            common, rest = sets[0], sets[1:]
            sets = []
            for x in rest:
                if x.isdisjoint(common):
                    sets.append(x)
                else:
                    merged = 1
                    common |= x
            results.append(common)
        sets = results
    return sets


def rob_merge(lsts):
    """robert king"""
    lsts = [sorted(l) for l in lsts]  # I changed this line
    one_list = heapq.merge(*[zip(l, [i] * len(l)) for i, l in enumerate(lsts)])
    previous = next(one_list)

    d = {i: i for i in range(len(lsts))}
    for current in one_list:
        if current[0] == previous[0]:
            d[current[1]] = d[previous[1]]
        previous = current

    groups = [[] for i in range(len(lsts))]
    for k in d:
        groups[d[k]].append(lsts[k])

    return [set(chain(*g)) for g in groups if g]


def agf_merge(lsts):
    """agf"""
    newsets, sets = [set(lst) for lst in lsts if lst], []
    while len(sets) != len(newsets):
        sets, newsets = newsets, []
        for aset in sets:
            for eachset in newsets:
                if not aset.isdisjoint(eachset):
                    eachset.update(aset)
                    break
            else:
                newsets.append(aset)
    return newsets


def agf_opt_merge(lists):
    """agf (optimized)"""
    sets = deque(set(lst) for lst in lists if lst)
    if not sets:
        return []  # no non-empty input lists
    results = []
    disjoint = 0
    current = sets.pop()
    while True:
        merged = False
        newsets = deque()
        for _ in range(disjoint, len(sets)):
            this = sets.pop()
            if not current.isdisjoint(this):
                current.update(this)
                merged = True
                disjoint = 0
            else:
                newsets.append(this)
                disjoint += 1
        if sets:
            newsets.extendleft(sets)
        if not merged:
            results.append(current)
            try:
                current = newsets.pop()
            except IndexError:
                break
            disjoint = 0
        sets = newsets
    return results


def che_merge(lsts):
    """ChessMaster"""
    results, sets = [], [set(lst) for lst in lsts if lst]
    upd, isd, pop = set.update, set.isdisjoint, sets.pop
    while sets:
        if not [upd(sets[0], pop(i)) for i in range(len(sets) - 1, 0, -1) if not isd(sets[0], sets[i])]:
            results.append(pop(0))
    return results


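# locatebin is the "find" half of a minimal union-find: it follows the
# forwarding pointers that ale_merge leaves behind when it merges bins
# (no path compression is performed).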
def locatebin(bins, n):
    """Find the bin where list n has ended up: follow bin references until
    we find a bin that has not moved.
    """
    while bins[n] != n:
        n = bins[n]
    return n


def ale_merge(data):
    """alexis"""
    bins = list(range(len(data)))  # Initialize each bin[n] == n
    nums = dict()

    data = [set(m) for m in data]  # Convert to sets
    for r, row in enumerate(data):
        for num in row:
            if num not in nums:
                # New number: tag it with a pointer to this row's bin
                nums[num] = r
                continue
            else:
                dest = locatebin(bins, nums[num])
                if dest == r:
                    continue  # already in the same bin

                if dest > r:
                    dest, r = r, dest  # always merge into the smallest bin

                data[dest].update(data[r])
                data[r] = None
                # Update our indices to reflect the move
                bins[r] = dest
                r = dest

    # Filter out the empty bins
    return [m for m in data if m]


def nik_rew_merge_skip(lsts):
    """Nik's rewrite"""
    sets = list(map(set, lsts))
    results = []
    while sets:
        first, rest = sets[0], sets[1:]
        merged = False
        sets = []
        for s in rest:
            if s and s.isdisjoint(first):
                sets.append(s)
            else:
                first |= s
                merged = True
        if merged:
            sets.append(first)
        else:
            results.append(first)
    return results


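# A small self-test sketch (not invoked by the job): every implementation
# above should agree on a tiny input. The function name and the data are
# illustrative only; call it manually when experimenting locally.
def _check_merge_impls():
    data = [[1, 2], [2, 3], [4]]
    expected = [{1, 2, 3}, {4}]
    impls = (rik_merge, sve_merge, hoc_merge, nik_merge, rob_merge, agf_merge,
             agf_opt_merge, che_merge, ale_merge, nik_rew_merge_skip)
    for fn in impls:
        got = sorted(fn(data), key=min)  # canonical order for comparison
        assert got == expected, (fn.__name__, got)

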
def union_find(clusters: List[List[str]], addresses: List[List[str]]):
    """Merge the two collections with the algorithm selected via $ALGO."""
    data = clusters + addresses
    if not data:
        return []
    match union_find_algo_name:
        case 'rik_merge':
            return rik_merge(data)
        case 'sve_merge':
            return sve_merge(data)
        case 'hoc_merge':
            return hoc_merge(data)
        case 'nik_merge':
            return nik_merge(data)
        case 'rob_merge':
            return rob_merge(data)
        case 'agf_merge':
            return agf_merge(data)
        case 'agf_opt_merge':
            return agf_opt_merge(data)
        case 'che_merge':
            return che_merge(data)
        case 'ale_merge':
            return ale_merge(data)
        case 'nik_rew_merge_skip':
            return nik_rew_merge_skip(data)
        case _:
            raise ValueError(f"Unset or unknown algorithm: {union_find_algo_name!r}")

def cluster_partition(rows: Iterable[Row]) -> Iterable:
    # Yield exactly one element per partition: the list of merged address sets.
    yield union_find([], [row['addresses'] for row in rows])


master = Master(config)
master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])
tx_df = master.get_tx_dataframe()


# Group transactions into one row per tx_id with the set of its addresses:
# Row(tx_id=..., addresses=[addr, addr, ...])
tx_grouped = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses'))
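
# cluster_partition can be exercised without a cluster (illustrative; assumes
# ALGO is set to one of the names accepted by union_find):
#   list(cluster_partition(iter([Row(addresses=['a', 'b']),
#                                Row(addresses=['b', 'c'])])))
#   -> [[{'a', 'b', 'c'}]]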

res = tx_grouped \
    .repartition(5) \
    .rdd \
    .mapPartitions(cluster_partition) \
    .fold([], union_find)
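# fold first combines the zero value [] with each partition's merged clusters
# via union_find, then combines the per-partition results on the driver, so
# res ends up as one flat list of address sets.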

if debug:
    for cluster in res:
        print()
        print(sorted(cluster))

end = time.time()
print(end - start, end='')  # elapsed wall-clock seconds
|