
main_partition.py

import json
from typing import Iterable, List, Set
import networkx
import heapq
from itertools import chain
from collections import deque
import os
from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import functions as F
import time

start = time.time()

config = json.load(open("./settings.json"))
debug = config['debug']
union_find_algo_name = os.environ['ALGO']
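
# The settings file is expected to provide at least the keys read in this
# script. A minimal sketch; the concrete values and the invocation below are
# assumptions, not taken from the original repository:
#
#   {
#       "debug": false,
#       "cassandra_catalog": "cassandra",
#       "cassandra_keyspace": "my_keyspace",
#       "clusters_table_name": "clusters",
#       "tx_table_name": "transactions",
#       "spark_checkpoint_dir": "/tmp/spark-checkpoints"
#   }
#
# The merge algorithm is selected at runtime through the ALGO environment
# variable, e.g.:  ALGO=nik_merge spark-submit main_partition.py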

class Master:
    spark: SparkSession
    CLUSTERS_TABLE: str
    TX_TABLE: str

    def __init__(self, config):
        self.spark = self.makeSparkContext(config)
        self.config = config
        self.CLUSTERS_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['clusters_table_name']}"
        self.TX_TABLE = f"{config['cassandra_catalog']}.{config['cassandra_keyspace']}.{config['tx_table_name']}"

    def makeSparkContext(self, config) -> SparkSession:
        return SparkSession.builder \
            .appName('SparkCassandraApp') \
            .config(f"spark.sql.catalog.{config['cassandra_catalog']}", "com.datastax.spark.connector.datasource.CassandraCatalog") \
            .getOrCreate()

    def get_tx_dataframe(self) -> DataFrame:
        return self.spark.table(self.TX_TABLE)
# end class Master
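
# Added note: registering com.datastax.spark.connector.datasource.CassandraCatalog
# as spark.sql.catalog.<catalog> lets spark.table("<catalog>.<keyspace>.<table>")
# read Cassandra tables directly, which is what get_tx_dataframe() relies on.
# Connection settings such as spark.cassandra.connection.host are assumed to be
# supplied outside this file (spark-submit options or spark-defaults.conf).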

def rik_merge(lsts):
    """Rik. Poggi"""
    sets = (set(e) for e in lsts if e)
    results = [next(sets)]
    for e_set in sets:
        to_update = []
        for i, res in enumerate(results):
            if not e_set.isdisjoint(res):
                to_update.insert(0, i)
        if not to_update:
            results.append(e_set)
        else:
            last = results[to_update.pop(-1)]
            for i in to_update:
                last |= results[i]
                del results[i]
            last |= e_set
    return results
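
# All *_merge functions in this file implement the same contract: given an
# iterable of iterables, return a list of disjoint sets in which any two
# inputs that share an element end up in the same set. Illustrative example
# (values are made up, not from the original file):
#
#   rik_merge([[1, 2], [2, 3], [4]])  ->  [{1, 2, 3}, {4}]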

def sve_merge(lsts):
    """Sven Marnach"""
    sets = {}
    for lst in lsts:
        s = set(lst)
        t = set()
        for x in s:
            if x in sets:
                t.update(sets[x])
            else:
                sets[x] = s
        for y in t:
            sets[y] = s
        s.update(t)
    ids = set()
    result = []
    for s in sets.values():
        if id(s) not in ids:
            ids.add(id(s))
            result.append(s)
    return result

def hoc_merge(lsts):  # modified a bit to make it return sets
    """hochl"""
    s = [set(lst) for lst in lsts if lst]
    i, n = 0, len(s)
    while i < n-1:
        for j in range(i+1, n):
            if s[i].intersection(s[j]):
                s[i].update(s[j])
                del s[j]
                n -= 1
                break
        else:
            i += 1
    return [set(i) for i in s]

def nik_merge(lsts):
    """Niklas B."""
    sets = [set(lst) for lst in lsts if lst]
    merged = 1
    while merged:
        merged = 0
        results = []
        while sets:
            common, rest = sets[0], sets[1:]
            sets = []
            for x in rest:
                if x.isdisjoint(common):
                    sets.append(x)
                else:
                    merged = 1
                    common |= x
            results.append(common)
        sets = results
    return sets

def rob_merge(lsts):
    """robert king"""
    lsts = [sorted(l) for l in lsts]  # I changed this line
    one_list = heapq.merge(*[zip(l, [i]*len(l)) for i, l in enumerate(lsts)])
    previous = next(one_list)
    d = {i: i for i in range(len(lsts))}
    for current in one_list:
        if current[0] == previous[0]:
            d[current[1]] = d[previous[1]]
        previous = current
    groups = [[] for i in range(len(lsts))]
    for k in d:
        groups[d[k]].append(lsts[k])
    return [set(chain(*g)) for g in groups if g]

def agf_merge(lsts):
    """agf"""
    newsets, sets = [set(lst) for lst in lsts if lst], []
    while len(sets) != len(newsets):
        sets, newsets = newsets, []
        for aset in sets:
            for eachset in newsets:
                if not aset.isdisjoint(eachset):
                    eachset.update(aset)
                    break
            else:
                newsets.append(aset)
    return newsets

def agf_opt_merge(lists):
    """agf (optimized)"""
    sets = deque(set(lst) for lst in lists if lst)
    results = []
    disjoint = 0
    current = sets.pop()
    while True:
        merged = False
        newsets = deque()
        for _ in range(disjoint, len(sets)):
            this = sets.pop()
            if not current.isdisjoint(this):
                current.update(this)
                merged = True
                disjoint = 0
            else:
                newsets.append(this)
                disjoint += 1
        if sets:
            newsets.extendleft(sets)
        if not merged:
            results.append(current)
            try:
                current = newsets.pop()
            except IndexError:
                break
            disjoint = 0
        sets = newsets
    return results

def che_merge(lsts):
    """ChessMaster"""
    results, sets = [], [set(lst) for lst in lsts if lst]
    upd, isd, pop = set.update, set.isdisjoint, sets.pop
    while sets:
        if not [upd(sets[0], pop(i)) for i in range(len(sets)-1, 0, -1) if not isd(sets[0], sets[i])]:
            results.append(pop(0))
    return results

def locatebin(bins, n):
    """Find the bin where list n has ended up: Follow bin references until
    we find a bin that has not moved.
    """
    while bins[n] != n:
        n = bins[n]
    return n
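
# locatebin() is the "find" half of a union-find structure: bins holds parent
# pointers, and merged rows leave a forwarding reference behind. Illustrative
# example with assumed values: for bins == [0, 0, 1], locatebin(bins, 2)
# follows 2 -> 1 -> 0 and returns 0. No path compression is performed here.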

def ale_merge(data):
    """alexis"""
    bins = list(range(len(data)))  # Initialize each bin[n] == n
    nums = dict()
    data = [set(m) for m in data]  # Convert to sets
    for r, row in enumerate(data):
        for num in row:
            if num not in nums:
                # New number: tag it with a pointer to this row's bin
                nums[num] = r
                continue
            else:
                dest = locatebin(bins, nums[num])
                if dest == r:
                    continue  # already in the same bin
                if dest > r:
                    dest, r = r, dest  # always merge into the smallest bin
                data[dest].update(data[r])
                data[r] = None
                # Update our indices to reflect the move
                bins[r] = dest
                r = dest
    # Filter out the empty bins
    have = [m for m in data if m]
    # print(len(have), "groups in result")  # removed this line
    return have

def nik_rew_merge_skip(lsts):
    """Nik's rewrite"""
    sets = list(map(set, lsts))
    results = []
    while sets:
        first, rest = sets[0], sets[1:]
        merged = False
        sets = []
        for s in rest:
            if s and s.isdisjoint(first):
                sets.append(s)
            else:
                first |= s
                merged = True
        if merged:
            sets.append(first)
        else:
            results.append(first)
    return results

def union_find(clusters: "List[List[str]]", addresses: "List[List[str]]"):
    data = clusters + addresses
    match union_find_algo_name:
        case 'rik_merge':
            return rik_merge(data)
        case 'sve_merge':
            return sve_merge(data)
        case 'hoc_merge':
            return hoc_merge(data)
        case 'nik_merge':
            return nik_merge(data)
        case 'rob_merge':
            return rob_merge(data)
        case 'agf_merge':
            return agf_merge(data)
        case 'agf_opt_merge':
            return agf_opt_merge(data)
        case 'che_merge':
            return che_merge(data)
        case 'ale_merge':
            return ale_merge(data)
        case 'nik_rew_merge_skip':
            return nik_rew_merge_skip(data)
        case _:
            raise NameError("Unset or unknown algorithm")
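
# Illustrative usage with assumed data (not part of the original file), with
# ALGO=nik_merge:
#
#   union_find([['a', 'b']], [['b', 'c'], ['d']])  ->  [{'a', 'b', 'c'}, {'d'}]
#
# i.e. previously known clusters and fresh address lists are merged whenever
# they share an address.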

def cluster_partition(iter: "Iterable[Row]") -> Iterable:
    yield union_find([], list(map(lambda row: row['addresses'], iter)))
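
# cluster_partition() runs once per Spark partition: it consumes the
# partition's rows and yields a single element, the list of merged address
# sets for that partition. The fold() below then combines those per-partition
# results into one global clustering by reapplying union_find.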

master = Master(config)
master.spark.catalog.clearCache()
master.spark.sparkContext.setCheckpointDir(config['spark_checkpoint_dir'])

tx_df = master.get_tx_dataframe()

# Turn transactions into a list of ('id', [addr, addr, ...])
tx_grouped = tx_df \
    .groupBy('tx_id') \
    .agg(F.collect_set('address').alias('addresses'))

res = tx_grouped \
    .repartition(5) \
    .rdd \
    .mapPartitions(cluster_partition) \
    .fold([], union_find)
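
# fold() combines the per-partition results into the final clustering, using
# union_find as the combine function and [] as the neutral zero value; this
# works because union_find simply concatenates and re-merges its two
# arguments. Illustrative shape of res (assumed data):
#
#   [{'addr1', 'addr2', 'addr3'}, {'addr4'}, ...]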

"""
for cluster in res:
    print()
    print(sorted(cluster))
"""

end = time.time()
print(end-start, end='')