diff --git a/simhashbucket b/simhashbucket
index 382a8ba..7fd86c8 100755
--- a/simhashbucket
+++ b/simhashbucket
@@ -17,12 +17,14 @@
 #
 
 import sys
+import os
 import getopt
-import sqlite3
-from multiprocessing import Process, Queue, set_start_method
+from multiprocessing import Process, Queue
 from itertools import islice, groupby
 from operator import itemgetter
 import heapq
+import MySQLdb
+from MySQLdb import cursors
 
 
 def unique_justseen(iterable, key=None):
@@ -32,17 +34,32 @@ def unique_justseen(iterable, key=None):
     return map(next, map(itemgetter(1), groupby(iterable, key)))
 
 
+def grouper(it, chunksize):
+    while True:
+        chunk = list(islice(it, chunksize))
+        yield chunk
+        if len(chunk) < chunksize:
+            break
+
+
+def execute(q, args=None):
+    db = MySQLdb.connect(read_default_file=os.path.expanduser("~/.my.cnf"), cursorclass=cursors.SSCursor)
+    cursor = db.cursor()
+    cursor.execute(q, args)
+    return cursor
+
+
 class SimhashTable(Process):
     STOP = "stop"
 
-    def __init__(self, splitter, outqueue, fp_it, q_it):
+    def __init__(self, splitter, outqueue, fp_q, query_q):
         super().__init__()
 
         self.outqueue = outqueue
         self.splitter = splitter
         self.table = {}
-        self.fp_it = fp_it
-        self.q_it = q_it
+        self.fp_q = fp_q
+        self.query_q = query_q
 
     @staticmethod
     def bit_count(n):
@@ -64,25 +81,28 @@ class SimhashTable(Process):
             self.table[fp_chunk] += [fp]
 
     def _query(self, q):
-        q_chunk = self.get_chunk(q)
-        if q_chunk in self.table:
-            for fp in self.table[q_chunk]:
+        query_chunk = self.get_chunk(q)
+        if query_chunk in self.table:
+            for fp in self.table[query_chunk]:
                 diff = SimhashTable.bit_count(q ^ fp[1])
                 if diff < 4:
                     yield (fp, diff)
 
     def run(self):
-        for fp in self.fp_it:
-            self._add(fp)
-        for (q_info, q) in self.q_it:
-            for ((fp_info, fp), diff) in self._query(q):
-                self.outqueue.put((q_info, fp_info, diff))
+        for fps in iter(self.fp_q.get, SimhashTable.STOP):
+            for fp in fps:
+                self._add(fp)
+        for queries in iter(self.query_q.get, SimhashTable.STOP):
+            for (query_info, q) in queries:
+                for ((fp_info, fp), diff) in self._query(q):
+                    self.outqueue.put((query_info, fp_info, diff))
         self.outqueue.put(SimhashTable.STOP)
 
+
 class SimhashBucket(Process):
     """Implementation of http://wwwconference.org/www2007/papers/paper215.pdf"""
 
-    def __init__(self, nr_of_tables, fp_it, q_it):
+    def __init__(self, nr_of_tables, fp_it, query_it):
         super().__init__()
         # So far, we support the variants with 4 and 20 tables. Each element of splitters
         # describes the key for one table. The first element of the tuple indicates the number
@@ -105,55 +125,79 @@ class SimhashBucket(Process):
             raise Exception(f"Unsupported number of tables: {nr_of_tables}")
 
         self.fp_it = fp_it
-        self.q_it = q_it
+        self.query_it = query_it
         self.splitters = splitters
 
         self.tables = []
-        self.outqueues = [Queue(100) for _ in range(len(self.splitters))]
+        self.fp_qs = [Queue(100) for _ in range(len(self.splitters))]
+        self.query_qs = [Queue(100) for _ in range(len(self.splitters))]
+        self.out_qs = [Queue(100) for _ in range(len(self.splitters))]
+
+    @staticmethod
+    def broadcast(it, qs, chunksize=1):
+        for x in grouper(it, chunksize):
+            for q in qs:
+                q.put(x)
+        for q in qs:
+            q.put(SimhashTable.STOP)
 
     def run(self):
-        self.tables = [SimhashTable(splitter, outqueue, self.fp_it, self.q_it) for (outqueue, splitter) in zip(self.outqueues, self.splitters)]
+        self.tables = [SimhashTable(*args) for args in zip(self.splitters, self.out_qs, self.fp_qs, self.query_qs)]
         for tbl in self.tables:
            tbl.start()
 
+        SimhashBucket.broadcast(self.fp_it, self.fp_qs, 100)
+        SimhashBucket.broadcast(self.query_it, self.query_qs, 100)
+
         for tbl in self.tables:
             tbl.join()
-
     def __iter__(self):
-        return unique_justseen(heapq.merge(*[iter(q.get, SimhashTable.STOP) for q in self.outqueues]))
+        return unique_justseen(heapq.merge(*[iter(q.get, SimhashTable.STOP) for q in self.out_qs]))
 
 
-def get_cdnjs_simhashes(db_path, limit=None):
-    with sqlite3.connect(db_path) as db:
-        for (simhash, library, path, size, typ, md5) in db.execute("select simhash, library, path, size, typ, md5 from cdnjs where "
-                "simhash IS NOT NULL AND path like '%.js' and "
-                "HEX(md5) <> 'd41d8cd98f00b204e9800998ecf8427e' order by path, size, typ, md5" +
-                (f" LIMIT {int(limit)}" if limit is not None else "")):
-            yield ((path, size, typ, md5.hex()), int(simhash))
+def get_cdnjs_simhashes(limit=None):
+    for (simhash, path, typ, library, version) in execute((
+            "select simhash, path, typ, library, version from "
+            "cdnjs where "
+            "simhash IS NOT NULL AND path like '%.js' and "
+            "HEX(md5) <> 'd41d8cd98f00b204e9800998ecf8427e' "
+            "order by path, typ {limit}")
+            .format(
+                limit=f"limit {limit}" if limit else ""
+            )):
+        # Whatever gets returned here must be sorted
+        yield ((path, typ, library, version), int(simhash))
 
 
-def get_crxfile_simhashes(db_path, extension_limit=None, crxfile_limit=None):
-    with sqlite3.connect(db_path) as db:
-        for (extid, date) in islice(db.execute("select extid, max(date) as date from extension group by extid order by extid"), extension_limit):
-            for (crx_etag,) in db.execute("select crx_etag from extension where extid=? and date=? order by crx_etag", (extid, date)):
-                for (path, md5, typ, simhash) in db.execute("select path, md5, typ, simhash from crxfile where crx_etag=? and simhash is not null and path like '%.js' order by path, md5, typ", (crx_etag,)):
-                    for (size,) in db.execute("select size from libdet where md5=? and typ=? and size >= 1024 order by size", (md5, typ)):
-                        yield ((extid, date, crx_etag, path, md5.hex(), typ, size), int(simhash))
+def get_crxfile_simhashes(extension_limit=None, crxfile_limit=None):
+    for (extid, date, crx_etag, path, typ, simhash) in execute((
+            "select extid, date, crx_etag, path, typ, simhash from "
+            "(select * from extension_most_recent order by extid, date {extension_limit}) extension join "
+            "(select * from crxfile order by crx_etag, path, typ {crxfile_limit}) crxfile using (crx_etag) "
+            "join libdet using (md5, typ) "
+            "where simhash is not null and path like '%.js' and size >= 1024")
+            .format(
+                extension_limit=f"limit {extension_limit}" if extension_limit else "",
+                crxfile_limit=f"limit {crxfile_limit}" if crxfile_limit else ""
+            )):
+        # Whatever gets returned here must be sorted
+        yield ((extid, date, crx_etag, path, typ), int(simhash))
 
 
 def print_help():
-    print("""simhashbucket [OPTION] <DB_PATH>""")
-    print("""  -h, --help        print this help text""")
-    print("""  --limit-cdnjs     only retrieve N rows""")
-    print("""  --limit-extension only retrieve N rows""")
-    print("""  --limit-crxfile   only retrieve N rows""")
-    print("""  --tables          number of tables to use for the bucket (4 or 20 so far)""")
+    print("""simhashbucket [OPTIONS]""")
+    print("""  -h, --help        print this help text""")
+    print("""  --limit-cdnjs     only retrieve N rows, default: all""")
+    print("""  --limit-extension only retrieve N rows, default: all""")
+    print("""  --limit-crxfile   only retrieve N rows, default: all""")
+    print("""  --tables          number of tables to use for the bucket (4 or 20 so far), default: 4""")
 
 
 def parse_args(argv):
     limit_cdnjs = None
     limit_extension = None
     limit_crxfile = None
-    tables = 20
+    tables = 4
     try:
         opts, args = getopt.getopt(argv, "h", [
@@ -161,31 +205,38 @@ def parse_args(argv):
     except getopt.GetoptError:
         print_help()
         sys.exit(2)
-    for opt, arg in opts:
-        if opt == "--limit-cdnjs":
-            limit_cdnjs = int(arg)
-        elif opt == "--limit-extension":
-            limit_extension = int(arg)
-        elif opt == "--limit-crxfile":
-            limit_crxfile = int(arg)
-        elif opt == "--tables":
-            tables = int(arg)
-
-    if len(args) != 1:
+    try:
+        for opt, arg in opts:
+            if opt == "--limit-cdnjs":
+                limit_cdnjs = int(arg)
+            elif opt == "--limit-extension":
+                limit_extension = int(arg)
+            elif opt == "--limit-crxfile":
+                limit_crxfile = int(arg)
+            elif opt == "--tables":
+                tables = int(arg)
+            elif opt in ["-h", "--help"]:
+                print_help()
+                sys.exit(0)
+    except ValueError:
+        print("Arguments to int options must be an int!", file=sys.stderr)
         print_help()
         sys.exit(2)
 
-    db_path = args[0]
-    return limit_cdnjs, limit_extension, limit_crxfile, tables, db_path
+    if len(args) != 0:
+        print_help()
+        sys.exit(2)
+
+    return limit_cdnjs, limit_extension, limit_crxfile, tables
 
 
 def main(args):
-    limit_cdnjs, limit_extension, limit_crxfile, tables, db_path = parse_args(args)
+    limit_cdnjs, limit_extension, limit_crxfile, tables = parse_args(args)
 
-    fp_it = get_cdnjs_simhashes(db_path, limit_cdnjs)
-    q_it = get_crxfile_simhashes(db_path, limit_extension, limit_crxfile)
+    fp_it = get_cdnjs_simhashes(limit_cdnjs)
+    query_it = get_crxfile_simhashes(limit_extension, limit_crxfile)
 
-    bucket = SimhashBucket(tables, fp_it, q_it)
+    bucket = SimhashBucket(tables, fp_it, query_it)
     bucket.start()
 
     for tup in bucket:
         sys.stdout.write("|".join([str(x) for x in tup]) + "\n")
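Reviewer note: the table layout this patch parallelizes is the near-duplicate scheme from the paper cited in the `SimhashBucket` docstring (Manku et al., WWW 2007). In the 4-table variant, a 64-bit fingerprint is split into four 16-bit blocks; two fingerprints within Hamming distance 3 differ in at most 3 blocks, so by pigeonhole at least one block matches exactly and can serve as an exact-match table key, with a popcount check filtering false positives (`diff < 4` in `_query` is that distance <= 3 test). Below is a minimal single-process sketch of the idea using fixed 16-bit blocks rather than the patch's `splitters`; the names `SketchIndex`, `find_near`, and `BLOCKS` are illustrative, not part of the patch.

```python
from collections import defaultdict

# Four 16-bit blocks of a 64-bit fingerprint, as (shift, mask) pairs.
BLOCKS = [(shift, 0xFFFF) for shift in (0, 16, 32, 48)]

def bit_count(n):
    # Popcount, same role as SimhashTable.bit_count in the patch.
    return bin(n).count("1")

class SketchIndex:
    def __init__(self):
        # One exact-match table per block: block value -> fingerprints.
        self.tables = [defaultdict(list) for _ in BLOCKS]

    def add(self, fp):
        for table, (shift, mask) in zip(self.tables, BLOCKS):
            table[(fp >> shift) & mask].append(fp)

    def find_near(self, q, max_diff=3):
        # A fingerprint within distance max_diff of q differs in at most
        # max_diff of the 4 blocks, so it matches q exactly in at least one
        # block and is guaranteed to surface as a candidate in that table.
        seen = set()
        for table, (shift, mask) in zip(self.tables, BLOCKS):
            for fp in table.get((q >> shift) & mask, ()):
                if fp not in seen and bit_count(q ^ fp) <= max_diff:
                    seen.add(fp)
                    yield fp, bit_count(q ^ fp)

idx = SketchIndex()
idx.add(0x0123456789ABCDEF)
print(list(idx.find_near(0x0123456789ABCDEC)))  # distance 2 -> reported once
```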
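The structural change in the patch is the move from shared iterators to per-table queues: a `multiprocessing.Process` cannot share a live generator (here, a MySQL cursor) with its children, so the parent now consumes each cursor once and broadcasts chunks to every table process, terminating each queue with a sentinel. A self-contained sketch of that fan-out pattern, under the same conventions (the `worker` function and its bookkeeping are mine; the patch uses `grouper`/`broadcast` with chunks of 100):

```python
from itertools import islice
from multiprocessing import Process, Queue

STOP = "stop"  # sentinel, mirroring SimhashTable.STOP

def chunks(it, size):
    # Yield non-empty lists of up to `size` items.
    it = iter(it)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            break
        yield chunk

def broadcast(it, queues, size=100):
    # Send every chunk to every queue, then close each with the sentinel.
    for chunk in chunks(it, size):
        for q in queues:
            q.put(chunk)
    for q in queues:
        q.put(STOP)

def worker(q, name):
    total = 0
    for chunk in iter(q.get, STOP):  # iter(callable, sentinel) stops at STOP
        total += len(chunk)
    print(f"{name} received {total} items")

if __name__ == "__main__":
    qs = [Queue(100) for _ in range(4)]
    workers = [Process(target=worker, args=(q, f"table-{i}")) for i, q in enumerate(qs)]
    for w in workers:
        w.start()
    broadcast(range(1000), qs)
    for w in workers:
        w.join()
```

Chunking amortizes the pickling and IPC cost per item. One nit on the patch itself: `grouper` yields a trailing empty chunk when the input length is a multiple of `chunksize` (and a single empty chunk for empty input); that is harmless here because consumers merely iterate over it, but the `chunks` variant above avoids it.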
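The "Whatever gets returned here must be sorted" comments are load-bearing: each table process emits matches in input order, `__iter__` lazily stitches the per-table output streams together with `heapq.merge`, and `unique_justseen` drops the duplicates that arise when the same (query, fingerprint) pair matches in more than one table. The merge only interleaves correctly if every stream is itself sorted, hence the `order by` clauses in both generators. A small illustration with made-up result tuples:

```python
import heapq
from itertools import groupby
from operator import itemgetter

def unique_justseen(iterable, key=None):
    # Same helper as in simhashbucket: drop adjacent duplicates.
    return map(next, map(itemgetter(1), groupby(iterable, key)))

table0 = [(1, "jquery.js", 2), (3, "react.js", 1)]  # sorted stream
table1 = [(1, "jquery.js", 2), (2, "vue.js", 0)]    # sorted stream
print(list(unique_justseen(heapq.merge(table0, table1))))
# [(1, 'jquery.js', 2), (2, 'vue.js', 0), (3, 'react.js', 1)]
# -- the pair found by both tables is emitted only once
```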