Added MD5Table and made max simhash dist configurable.

Michael Herzberg 2018-09-12 21:56:18 +01:00
parent 19bccc8be1
commit 81b2bbd21f
1 changed file with 116 additions and 72 deletions
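In brief: the commit drops the hard-coded 4-table and 20-table variants and instead derives one lookup table per simhash chunk from the requested maximum Hamming distance. A minimal standalone sketch of the new splitter computation (the function name is illustrative, not part of the commit):

import math

def make_splitters(fp_size, max_dist):
    # By pigeonhole, two fp_size-bit simhashes that differ in at most
    # max_dist bits must agree exactly on at least one of max_dist + 1
    # disjoint chunks, so one table per chunk suffices as a candidate filter.
    chunksize = math.ceil(fp_size / (max_dist + 1))
    return [[(i, min(chunksize, fp_size - i))] for i in range(0, fp_size, chunksize)]

# Reproduces the example from the comment in the diff below:
# [[(0, 11)], [(11, 11)], [(22, 11)], [(33, 11)], [(44, 11)], [(55, 9)]]
print(make_splitters(64, 5))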


@@ -18,8 +18,9 @@
 import sys
 import os
+import math
 import getopt
-from multiprocessing import Process, Queue
+from multiprocessing import Process, Queue, Manager
 from itertools import islice, groupby
 from operator import itemgetter
 import heapq
@@ -43,18 +44,44 @@ def grouper(it, chunksize):
 def execute(q, args=None):
-    db = MySQLdb.connect(read_default_file=os.path.expanduser("~/.my.cnf"), cursorclass=cursors.SSCursor)
+    db = MySQLdb.connect(read_default_file=os.path.expanduser("~/.my.cnf"), cursorclass=cursors.SSDictCursor)
     cursor = db.cursor()
     cursor.execute(q, args)
     return cursor
+
+
+class MD5Table(Process):
+    def __init__(self, out_q, fp_q, query_q):
+        super().__init__()
+        self.out_q = out_q
+        self.fp_q = fp_q
+        self.query_q = query_q
+        self.table = {}
+
+    def run(self):
+        for fps in iter(self.fp_q.get, SimhashTable.STOP):
+            for fp_info, _, fp_md5 in fps:
+                if fp_md5 not in self.table:
+                    self.table[fp_md5] = []
+                self.table[fp_md5] += [fp_info]
+        for queries in iter(self.query_q.get, SimhashTable.STOP):
+            for query_info, _, query_md5 in queries:
+                if query_md5 in self.table:
+                    for fp_info in self.table[query_md5]:
+                        self.out_q.put((query_info, fp_info, -1))
+        self.out_q.put(SimhashTable.STOP)
+
+
 class SimhashTable(Process):
     STOP = "stop"

-    def __init__(self, splitter, outqueue, fp_q, query_q):
+    def __init__(self, max_dist, splitter, out_q, fp_q, query_q):
         super().__init__()
-        self.outqueue = outqueue
+        self.max_dist = max_dist
+        self.out_q = out_q
         self.splitter = splitter
         self.table = {}
@@ -74,115 +101,110 @@ class SimhashTable(Process):
             sum += (n >> s) & (pow(2, c) - 1)
         return sum

-    def _add(self, fp):
-        fp_chunk = self.get_chunk(fp[1])
+    def _add(self, fp_info, fp_simhash):
+        fp_chunk = self.get_chunk(fp_simhash)
         if not fp_chunk in self.table:
             self.table[fp_chunk] = []
-        self.table[fp_chunk] += [fp]
+        self.table[fp_chunk] += [(fp_info, fp_simhash)]

-    def _query(self, q):
-        query_chunk = self.get_chunk(q)
+    def _query(self, query_simhash):
+        query_chunk = self.get_chunk(query_simhash)
         if query_chunk in self.table:
-            for fp in self.table[query_chunk]:
-                diff = SimhashTable.bit_count(q ^ fp[1])
-                if diff < 4:
-                    yield (fp, diff)
+            for fp_info, fp_simhash in self.table[query_chunk]:
+                diff = SimhashTable.bit_count(query_simhash ^ fp_simhash)
+                if diff <= self.max_dist:
+                    yield ((fp_info, fp_simhash), diff)

     def run(self):
         for fps in iter(self.fp_q.get, SimhashTable.STOP):
-            for fp in fps:
-                self._add(fp)
+            for fp_info, fp_simhash, _ in fps:
+                self._add(fp_info, fp_simhash)
         for queries in iter(self.query_q.get, SimhashTable.STOP):
-            for (query_info, q) in queries:
-                for ((fp_info, fp), diff) in self._query(q):
-                    self.outqueue.put((query_info, fp_info, diff))
-        self.outqueue.put(SimhashTable.STOP)
+            for query_info, query_simhash, _ in queries:
+                for ((fp_info, fp_simhash), diff) in self._query(query_simhash):
+                    self.out_q.put((query_info, fp_info, diff))
+        self.out_q.put(SimhashTable.STOP)


 class SimhashBucket(Process):
     """Implementation of http://wwwconference.org/www2007/papers/paper215.pdf"""

-    def __init__(self, nr_of_tables, fp_it, query_it):
+    def __init__(self, fp_it, query_it, max_dist=3, fp_size=64):
         super().__init__()
-        # So far, we support the variants with 4 and 20 tables. Each element of splitters
-        # describes the key for one table. The first element of the tuple indicates the number
-        # of bits that we shift the simhash to the right; the second element indicates how many
-        # bits, from the right side, we end up taking.
-        if nr_of_tables == 4:
-            splitters = [[(0, 16)], [(16, 16)], [(32, 16)], [(48, 16)]]
-        elif nr_of_tables == 20:
-            block_sizes = [11, 11, 11, 11, 10, 10]
-            splitters = []
-            for i in range(0, len(block_sizes)):
-                for j in range(i + 1, len(block_sizes)):
-                    for k in range(j + 1, len(block_sizes)):
-                        splitters += [[
-                            (sum(block_sizes[i+1:]), block_sizes[i]),
-                            (sum(block_sizes[j+1:]), block_sizes[j]),
-                            (sum(block_sizes[k+1:]), block_sizes[k]),
-                        ]]
-        else:
-            raise Exception(f"Unsupported number of tables: {nr_of_tables}")
+        # Each element of splitters describes the key for one table. The first element of the tuple indicates the
+        # number of bits that we shift the simhash to the right; the second element indicates how many bits, from
+        # the right side, we end up taking. For example, with max_dist=5, we end up with [(0, 11)], [(11, 11)],
+        # [(22, 11)], [(33, 11)], [(44, 11)], [(55, 9)].
+        if max_dist >= 0:
+            chunksize = math.ceil(fp_size / (max_dist + 1))
+            self.splitters = [[(i, min(chunksize, fp_size - i))] for i in range(0, fp_size, chunksize)]
+        else:
+            self.splitters = []

         self.fp_it = fp_it
         self.query_it = query_it
-        self.splitters = splitters

         self.tables = []
-        self.fp_qs = [Queue(100) for _ in range(len(self.splitters))]
-        self.query_qs = [Queue(100) for _ in range(len(self.splitters))]
-        self.out_qs = [Queue(100) for _ in range(len(self.splitters))]
+        self.fp_qs = [Queue() for _ in range(len(self.splitters) + 1)]
+        self.query_qs = [Queue() for _ in range(len(self.splitters) + 1)]
+        self.out_qs = [Queue() for _ in range(len(self.splitters) + 1)]
+        self.max_dist = max_dist
+        self.fp_store = Manager().list()
+        self.query_store = Manager().list()

     @staticmethod
-    def broadcast(it, qs, chunksize=1):
+    def broadcast(it, qs, store, chunksize=1):
         for x in grouper(it, chunksize):
+            store += [info for info, _, _ in x]
             for q in qs:
-                q.put(x)
+                q.put([(len(store) - len(x) + i, simhash, md5) for i, (_, simhash, md5) in enumerate(x)])
         for q in qs:
             q.put(SimhashTable.STOP)

     def run(self):
-        self.tables = [SimhashTable(*args) for args in zip(self.splitters, self.out_qs, self.fp_qs, self.query_qs)]
+        self.tables = [SimhashTable(self.max_dist, *args) for args in zip(self.splitters, self.out_qs, self.fp_qs, self.query_qs)] \
+            + [MD5Table(self.out_qs[-1], self.fp_qs[-1], self.query_qs[-1])]
         for tbl in self.tables:
             tbl.start()

-        SimhashBucket.broadcast(self.fp_it, self.fp_qs, 100)
-        SimhashBucket.broadcast(self.query_it, self.query_qs, 100)
+        SimhashBucket.broadcast(self.fp_it, self.fp_qs, self.fp_store, 100)
+        SimhashBucket.broadcast(self.query_it, self.query_qs, self.query_store, 100)

-        for tbl in self.tables:
+        for i, tbl in enumerate(self.tables):
             tbl.join()

     def __iter__(self):
-        return unique_justseen(heapq.merge(*[iter(q.get, SimhashTable.STOP) for q in self.out_qs]))
+        for query_i, fp_i, diff in unique_justseen(heapq.merge(*[iter(q.get, SimhashTable.STOP) for q in self.out_qs])):
+            yield self.query_store[query_i], self.fp_store[fp_i], diff


 def get_cdnjs_simhashes(limit=None):
-    for (simhash, path, typ, library, version) in execute((
-            "select simhash, path, typ, library, version from "
+    for row in execute((
+            "select simhash, path, typ, library, version, md5, add_date from "
             "cdnjs where "
             "simhash IS NOT NULL AND path like '%.js' and "
-            "HEX(md5) <> 'd41d8cd98f00b204e9800998ecf8427e' "
-            "order by path, typ {limit}")
+            "HEX(md5) <> 'd41d8cd98f00b204e9800998ecf8427e' and typ = 'NORMALIZED' and add_date is not null "
+            "{limit}")
             .format(
                 limit=f"limit {limit}" if limit else ""
             )):
-        # Whatever gets returned here must be sorted
-        yield ((path, typ, library, version), int(simhash))
+        yield row, int(row['simhash']), row['md5']


 def get_crxfile_simhashes(extension_limit=None, crxfile_limit=None):
-    for (extid, date, crx_etag, path, typ, simhash) in execute((
-            "select extid, date, crx_etag, path, typ, simhash from "
-            "(select * from extension_most_recent order by extid, date {extension_limit}) extension join "
-            "(select * from crxfile order by crx_etag, path, typ {crxfile_limit}) crxfile using (crx_etag) "
+    for row in execute((
+            "select extid, crx_etag, path, typ, simhash, md5, lastupdated, name from "
+            "(select * from extension_most_recent where downloads >= 100000 {extension_limit}) extension join "
+            "(select * from crxfile {crxfile_limit}) crxfile using (crx_etag) "
             "join libdet using (md5, typ) "
-            "where simhash is not null and path like '%.js' and size >= 1024")
+            "where simhash is not null and path like '%.js' and typ = 'NORMALIZED'")
             .format(
                 extension_limit=f"limit {extension_limit}" if extension_limit else "",
                 crxfile_limit=f"limit {crxfile_limit}" if crxfile_limit else ""
             )):
-        # Whatever gets returned here must be sorted
-        yield ((extid, date, crx_etag, path, typ), int(simhash))
+        yield row, int(row['simhash']), row['md5']


 def print_help():
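Aside on the reworked broadcast in the hunk above: instead of pushing full row dicts to every table's queue, each batch's info dicts are appended once to a shared Manager().list(), and the tables only ever receive (index, simhash, md5) triples; __iter__ then resolves the indices back through the store. A minimal illustration outside the Process machinery (batch contents are made up):

from multiprocessing import Manager

store = Manager().list()
batch = [({"path": "a.js"}, 3, "md5a"), ({"path": "b.js"}, 7, "md5b")]
store += [info for info, _, _ in batch]  # full dicts stored exactly once
wire = [(len(store) - len(batch) + i, simhash, md5)
        for i, (_, simhash, md5) in enumerate(batch)]
print(wire)      # [(0, 3, 'md5a'), (1, 7, 'md5b')] -- small index payloads
print(store[0])  # {'path': 'a.js'} -- resolved back on the output side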
@@ -191,17 +213,15 @@ def print_help():
     print("""    --limit-cdnjs <N>      only retrieve N rows, default: all""")
     print("""    --limit-extension <N>  only retrieve N rows, default: all""")
     print("""    --limit-crxfile <N>    only retrieve N rows, default: all""")
-    print("""    --tables <N>           number of tables to use for the bucket (4 or 20 so far), default: 4""")


 def parse_args(argv):
     limit_cdnjs = None
     limit_extension = None
     limit_crxfile = None
-    tables = 4
     try:
         opts, args = getopt.getopt(argv, "h", [
-            "limit-cdnjs=", "limit-extension=", "limit-crxfile=", "help", "tables="])
+            "limit-cdnjs=", "limit-extension=", "limit-crxfile=", "help"])
     except getopt.GetoptError:
         print_help()
         sys.exit(2)
@@ -213,8 +233,6 @@ def parse_args(argv):
             limit_extension = int(arg)
         elif opt == "--limit-crxfile":
             limit_crxfile = int(arg)
-        elif opt == "--tables":
-            tables = int(arg)
         elif opt in ["-h", "--help"]:
             print_help()
             sys.exit(0)
@@ -227,19 +245,45 @@ def parse_args(argv):
         print_help()
         sys.exit(2)

-    return limit_cdnjs, limit_extension, limit_crxfile, tables
+    return limit_cdnjs, limit_extension, limit_crxfile


 def main(args):
-    limit_cdnjs, limit_extension, limit_crxfile, tables = parse_args(args)
+    limit_cdnjs, limit_extension, limit_crxfile = parse_args(args)

     fp_it = get_cdnjs_simhashes(limit_cdnjs)
     query_it = get_crxfile_simhashes(limit_extension, limit_crxfile)

-    bucket = SimhashBucket(tables, fp_it, query_it)
+    bucket = SimhashBucket(fp_it, query_it, max_dist=-1)
     bucket.start()
-    for tup in bucket:
-        sys.stdout.write("|".join([str(x) for x in tup]) + "\n")
+
+    libraries = {}
+    for query_info, fp_info, diff in bucket:
+        if diff == -1:
+            lib = fp_info["library"]
+            t = (fp_info["add_date"], fp_info["version"])
+            if lib not in libraries:
+                libraries[lib] = {}
+            if t not in libraries[lib]:
+                libraries[lib][t] = []
+            libraries[lib][t] += [(query_info, fp_info)]
+            #if fp_info["library"] == "jquery" and fp_info["version"] == "2.1.1":
+            #    print(f"{query_info['extid']} ({query_info['crx_etag']}): {query_info['name']} ({query_info['lastupdated']})")
+
+    res = []
+    for lib in libraries:
+        assigned_MD5s = set()
+        for add_date, version in sorted(libraries[lib], key=lambda tup: tup[0], reverse=True):
+            md5s = set()
+            exts = set()
+            for query_info, fp_info in libraries[lib][(add_date, version)]:
+                if fp_info["md5"] not in assigned_MD5s:
+                    exts.add((query_info["extid"], query_info["crx_etag"]))
+                    md5s.add(fp_info["md5"])
+            for md5 in md5s:
+                assigned_MD5s.add(md5)
+            res += [(len(exts), lib, version)]
+
+    for N, lib, version in sorted(res):
+        print(f"{lib} (v{version}): {N}")


 if __name__ == "__main__":
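Note on the new semantics: MD5Table mirrors the SimhashTable queue protocol but matches on exact MD5 digests and reports hits with a sentinel distance of -1; since main() now constructs the bucket with max_dist=-1, no simhash tables are built and the run degenerates to pure exact-MD5 matching. A hedged usage sketch with in-memory iterators in place of the MySQL feeds (the module name and all field values are made up for illustration):

# Assumes the patched script is importable as `simhash_bucket`.
from simhash_bucket import SimhashBucket

fp_it = iter([({"library": "jquery", "version": "2.1.1",
                "add_date": "2014-06-18", "md5": "m1"}, 0b1011, "m1")])
query_it = iter([({"extid": "aaaa", "crx_etag": "e1", "name": "Some Ext",
                   "lastupdated": "2018-09-01"}, 0b1010, "m1")])

bucket = SimhashBucket(fp_it, query_it, max_dist=-1)  # MD5 matching only
bucket.start()
for query_info, fp_info, diff in bucket:
    assert diff == -1  # exact-MD5 hits carry the sentinel distance
    print(query_info["extid"], "->", fp_info["library"], fp_info["version"])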