#!/usr/bin/env python3.6
#
# Copyright (C) 2018 The University of Sheffield, UK
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Match simhashes of extension JavaScript files against cdnjs simhashes.

Fingerprints (cdnjs files) are indexed in several lookup tables, each keyed
by a different subset of the simhash bits.  Queries (extension files) are
looked up in every table and all fingerprints within a small Hamming
distance are written to stdout, one ``|``-separated match per line.
"""

import sys
import getopt
import sqlite3
from contextlib import closing
from multiprocessing import Process, Queue, set_start_method
from itertools import combinations, islice, groupby
from operator import itemgetter
import heapq


def unique_justseen(iterable, key=None):
    """List unique elements, preserving order. Remember only the element just seen."""
    # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
    # unique_justseen('ABBCcAD', str.lower) --> A B C A D
    return map(next, map(itemgetter(1), groupby(iterable, key)))


class SimhashTable(Process):
    """Worker process holding one lookup table keyed by a chunk of the simhash.

    The process first indexes every fingerprint from ``fp_it``, then looks up
    every query from ``q_it`` and puts ``(q_info, fp_info, diff)`` tuples on
    ``outqueue``.  ``STOP`` is put on the queue once the process is done.
    """

    # Sentinel put on the output queue when the table has finished.
    STOP = "stop"
    # Largest Hamming distance (in bits) that is still reported as a match.
    MAX_DISTANCE = 3

    def __init__(self, splitter, outqueue, fp_it, q_it):
        """Store the table configuration; no work happens until ``start()``.

        splitter -- list of ``(shift, width)`` pairs selecting the key bits
        outqueue -- queue receiving the matches and the final ``STOP``
        fp_it    -- iterable of ``(info, simhash)`` fingerprints to index
        q_it     -- iterable of ``(info, simhash)`` queries to look up
        """
        super().__init__()
        self.outqueue = outqueue
        self.splitter = splitter
        self.table = {}
        self.fp_it = fp_it
        self.q_it = q_it

    @staticmethod
    def bit_count(n):
        """Return the number of 1 bits in the binary representation of n."""
        return bin(n).count("1")

    def get_chunk(self, n):
        """Reduces the simhash to a small chunk, given by self.splitter.

        The chunk will then be compared exactly in order to increase
        performance.  For every ``(shift, width)`` pair, ``width`` bits are
        taken from ``n`` after shifting it right by ``shift``; the pieces are
        concatenated in splitter order.
        """
        chunk = 0
        for shift, width in self.splitter:
            # The freshly shifted-in low bits are zero, so OR-ing equals adding.
            chunk = (chunk << width) | ((n >> shift) & ((1 << width) - 1))
        return chunk

    def _add(self, fp):
        """Index the fingerprint ``(info, simhash)`` under its chunk."""
        self.table.setdefault(self.get_chunk(fp[1]), []).append(fp)

    def _query(self, q):
        """Yield ``(fp, diff)`` for indexed fingerprints close to simhash q.

        Only fingerprints sharing q's chunk are inspected; ``diff`` is the
        Hamming distance and is at most ``MAX_DISTANCE``.
        """
        for fp in self.table.get(self.get_chunk(q), ()):
            diff = SimhashTable.bit_count(q ^ fp[1])
            if diff <= self.MAX_DISTANCE:
                yield (fp, diff)

    def run(self):
        """Index all fingerprints, answer all queries, then signal ``STOP``."""
        try:
            for fp in self.fp_it:
                self._add(fp)
            for (q_info, q) in self.q_it:
                for ((fp_info, _fp), diff) in self._query(q):
                    self.outqueue.put((q_info, fp_info, diff))
        finally:
            # Always send the sentinel: a consumer blocked on the queue would
            # otherwise wait forever if this worker fails half-way.
            self.outqueue.put(SimhashTable.STOP)


class SimhashBucket(Process):
    """Implementation of http://wwwconference.org/www2007/papers/paper215.pdf"""

    def __init__(self, nr_of_tables, fp_it, q_it):
        """Prepare one splitter and one bounded output queue per table.

        nr_of_tables -- number of lookup tables; 4 and 20 are supported
        fp_it        -- iterable of ``(info, simhash)`` fingerprints
        q_it         -- iterable of ``(info, simhash)`` queries

        Raises ValueError for an unsupported number of tables.
        """
        super().__init__()
        self.fp_it = fp_it
        self.q_it = q_it
        self.splitters = self._make_splitters(nr_of_tables)
        self.tables = []
        self.outqueues = [Queue(100) for _ in range(len(self.splitters))]

    @staticmethod
    def _make_splitters(nr_of_tables):
        """Return the list of table keys for the requested number of tables.

        So far, we support the variants with 4 and 20 tables. Each element of
        the result describes the key for one table as a list of tuples. The
        first element of a tuple indicates the number of bits that we shift
        the simhash to the right; the second element indicates how many bits,
        from the right side, we end up taking.
        """
        if nr_of_tables == 4:
            return [[(0, 16)], [(16, 16)], [(32, 16)], [(48, 16)]]
        if nr_of_tables == 20:
            block_sizes = [11, 11, 11, 11, 10, 10]
            # A block's shift is the total width of all blocks to its right.
            blocks = [(sum(block_sizes[i + 1:]), size)
                      for i, size in enumerate(block_sizes)]
            # Every choice of 3 out of the 6 blocks forms one table key.
            return [list(combo) for combo in combinations(blocks, 3)]
        raise ValueError(f"Unsupported number of tables: {nr_of_tables}")

    def run(self):
        """Start one SimhashTable process per splitter and wait for them all."""
        # NOTE(review): fp_it and q_it are typically generators, which cannot
        # be pickled -- presumably this relies on the "fork" start method.
        self.tables = [SimhashTable(splitter, outqueue, self.fp_it, self.q_it)
                       for (outqueue, splitter) in zip(self.outqueues, self.splitters)]
        for tbl in self.tables:
            tbl.start()
        for tbl in self.tables:
            tbl.join()

    def __iter__(self):
        """Iterate over the merged matches of all tables.

        Each queue is read until its ``STOP`` sentinel.  The streams are
        combined with ``heapq.merge`` and adjacent duplicates are dropped.
        NOTE(review): ``heapq.merge`` assumes each table emits its matches in
        ascending order -- confirm the input iterators guarantee that.
        """
        return unique_justseen(heapq.merge(*[iter(q.get, SimhashTable.STOP) for q in self.outqueues]))


def get_cdnjs_simhashes(db_path, limit=None):
    """Yield ``((path, size, typ, md5_hex), simhash)`` for cdnjs ``.js`` files.

    db_path -- path to the SQLite database containing the ``cdnjs`` table
    limit   -- maximum number of rows to retrieve, or None for all rows

    Rows without a simhash and rows whose md5 is that of the empty file are
    skipped.
    """
    query = ("select simhash, library, path, size, typ, md5 from cdnjs where "
             "simhash IS NOT NULL AND path like '%.js' and "
             # HEX() yields upper-case digits; lower-case them so that the
             # comparison with the empty-file md5 can actually match.
             "LOWER(HEX(md5)) <> 'd41d8cd98f00b204e9800998ecf8427e' "
             "order by path, size, typ, md5 LIMIT ?")
    # A negative LIMIT means "no upper bound" in SQLite.
    row_limit = -1 if limit is None else int(limit)
    # sqlite3's own context manager only handles the transaction; closing()
    # makes sure the connection itself is released.
    with closing(sqlite3.connect(db_path)) as db:
        for (simhash, _library, path, size, typ, md5) in db.execute(query, (row_limit,)):
            yield ((path, size, typ, md5.hex()), int(simhash))


def get_crxfile_simhashes(db_path, extension_limit=None, crxfile_limit=None):
    """Yield ``((extid, date, crx_etag, path, md5_hex, typ, size), simhash)``.

    For the most recent date of every extension, all ``.js`` crx files with a
    simhash are joined with their ``libdet`` sizes of at least 1024.

    db_path         -- path to the SQLite database
    extension_limit -- only look at the first N extensions (None: all)
    crxfile_limit   -- stop after N yielded entries in total (None: no limit)
    """
    emitted = 0
    # sqlite3's own context manager only handles the transaction; closing()
    # makes sure the connection itself is released.
    with closing(sqlite3.connect(db_path)) as db:
        for (extid, date) in islice(db.execute("select extid, max(date) as date from extension group by extid order by extid"), extension_limit):
            for (crx_etag,) in db.execute("select crx_etag from extension where extid=? and date=? order by crx_etag", (extid, date)):
                for (path, md5, typ, simhash) in db.execute("select path, md5, typ, simhash from crxfile where crx_etag=? and simhash is not null and path like '%.js' order by path, md5, typ", (crx_etag,)):
                    for (size,) in db.execute("select size from libdet where md5=? and typ=? and size >= 1024 order by size", (md5, typ)):
                        if crxfile_limit is not None and emitted >= crxfile_limit:
                            return
                        yield ((extid, date, crx_etag, path, md5.hex(), typ, size), int(simhash))
                        emitted += 1


def print_help():
    """Print the command line usage to stdout."""
    print("""simhashbucket [OPTION] <DB_PATH>""")
    print("""  -h, --help                print this help text""")
    print("""  --limit-cdnjs <N>         only retrieve N rows""")
    print("""  --limit-extension <N>     only retrieve N rows""")
    print("""  --limit-crxfile <N>       only retrieve N rows""")
    print("""  --tables <N>              number of tables to use for the bucket (4 or 20 so far)""")


def parse_args(argv):
    """Parse the command line arguments (without the program name).

    Returns ``(limit_cdnjs, limit_extension, limit_crxfile, tables, db_path)``.
    Exits with status 0 after printing the help for ``-h``/``--help`` and
    with status 2 (after printing the help) on any usage error.
    """
    limit_cdnjs = None
    limit_extension = None
    limit_crxfile = None
    tables = 20

    try:
        opts, args = getopt.getopt(argv, "h", [
            "limit-cdnjs=", "limit-extension=", "limit-crxfile=", "help", "tables="])
    except getopt.GetoptError:
        print_help()
        sys.exit(2)

    for opt, arg in opts:
        if opt in ("-h", "--help"):
            print_help()
            sys.exit(0)
        # All remaining options take an integer value.
        try:
            value = int(arg)
        except ValueError:
            print_help()
            sys.exit(2)
        if opt == "--limit-cdnjs":
            limit_cdnjs = value
        elif opt == "--limit-extension":
            limit_extension = value
        elif opt == "--limit-crxfile":
            limit_crxfile = value
        elif opt == "--tables":
            tables = value

    if len(args) != 1:
        print_help()
        sys.exit(2)
    db_path = args[0]

    return limit_cdnjs, limit_extension, limit_crxfile, tables, db_path


def main(args):
    """Run the lookup and write one ``|``-separated match per line to stdout."""
    limit_cdnjs, limit_extension, limit_crxfile, tables, db_path = parse_args(args)

    fp_it = get_cdnjs_simhashes(db_path, limit_cdnjs)
    q_it = get_crxfile_simhashes(db_path, limit_extension, limit_crxfile)

    bucket = SimhashBucket(tables, fp_it, q_it)
    bucket.start()
    for tup in bucket:
        sys.stdout.write("|".join([str(x) for x in tup]) + "\n")
    # All queues are drained at this point; reap the coordinating process.
    bucket.join()


if __name__ == "__main__":
    main(sys.argv[1:])