From 5d4f321dca1b9b1b8ba690ca5d6f9b23143d8418 Mon Sep 17 00:00:00 2001
From: Michael Herzberg
Date: Tue, 10 Jul 2018 20:47:07 +0100
Subject: [PATCH] Changed way of parallelism for simhashbucket and added comparemd5.

---
 comparemd5    |  44 ++++++++++++
 simhashbucket | 193 +++++++++++++++++++++++++-------------------------
 2 files changed, 139 insertions(+), 98 deletions(-)
 create mode 100755 comparemd5

diff --git a/comparemd5 b/comparemd5
new file mode 100755
index 0000000..7bc42df
--- /dev/null
+++ b/comparemd5
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3.6
+#
+# Copyright (C) 2018 The University of Sheffield, UK
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+import sqlite3
+import sys
+
+db_path = sys.argv[1]
+
+with sqlite3.connect(db_path) as db:
+    hit = 0
+    miss = 0
+    s = {}
+    for (md5, library, path, typ) in db.execute("select md5, library, path, typ from cdnjs"):
+        s[md5] = (library, path, typ)
+
+    for (extid, date) in db.execute("select extid, max(date) as date from extension group by extid order by extid"):
+        for (crx_etag,) in db.execute("select crx_etag from extension where extid=? and date=? order by crx_etag", (extid, date)):
+            for (path, md5, typ, simhash) in db.execute("select path, md5, typ, simhash from crxfile where crx_etag=? and simhash is not null and path like '%.js' order by path, md5, typ", (crx_etag,)):
+                for (size,) in db.execute("select size from libdet where md5=? and typ=? and size >= 1024 order by size", (md5, typ)):
+                    if md5 in s:
+                        hit += 1
+                        # library, path, typ = s[md5]
+                        # print("|".join((library, path, typ, extid, date, path, typ)))
+                    else:
+                        miss += 1
+                        print("|".join((extid, date, path, typ)))
+
+    print(f"Hit: {hit}")
+    print(f"Miss: {miss}")
diff --git a/simhashbucket b/simhashbucket
index 6f89fae..382a8ba 100755
--- a/simhashbucket
+++ b/simhashbucket
@@ -18,32 +18,85 @@
 import sys
 import getopt
-import os
 import sqlite3
-from itertools import groupby
-import time
-from multiprocessing import Pool, cpu_count
+from multiprocessing import Process, Queue, set_start_method
+from itertools import islice, groupby
+from operator import itemgetter
+import heapq
 
-class SimhashBucket:
+def unique_justseen(iterable, key=None):
+    "List unique elements, preserving order. Remember only the element just seen."
+    # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
+    # unique_justseen('ABBCcAD', str.lower) --> A B C A D
+    return map(next, map(itemgetter(1), groupby(iterable, key)))
+
+
+class SimhashTable(Process):
+    STOP = "stop"
+
+    def __init__(self, splitter, outqueue, fp_it, q_it):
+        super().__init__()
+        self.outqueue = outqueue
+
+        self.splitter = splitter
+        self.table = {}
+        self.fp_it = fp_it
+        self.q_it = q_it
+
+    @staticmethod
+    def bit_count(n):
+        return bin(n).count("1")
+
+    def get_chunk(self, n):
+        """Reduces the simhash to a small chunk, given by self.splitter. The chunk will
+        then be compared exactly in order to increase performance."""
+        sum = 0
+        for (s, c) in self.splitter:
+            sum <<= c
+            sum += (n >> s) & (pow(2, c) - 1)
+        return sum
+
+    def _add(self, fp):
+        fp_chunk = self.get_chunk(fp[1])
+        if not fp_chunk in self.table:
+            self.table[fp_chunk] = []
+        self.table[fp_chunk] += [fp]
+
+    def _query(self, q):
+        q_chunk = self.get_chunk(q)
+        if q_chunk in self.table:
+            for fp in self.table[q_chunk]:
+                diff = SimhashTable.bit_count(q ^ fp[1])
+                if diff < 4:
+                    yield (fp, diff)
+
+    def run(self):
+        for fp in self.fp_it:
+            self._add(fp)
+        for (q_info, q) in self.q_it:
+            for ((fp_info, fp), diff) in self._query(q):
+                self.outqueue.put((q_info, fp_info, diff))
+        self.outqueue.put(SimhashTable.STOP)
+
+class SimhashBucket(Process):
     """Implementation of http://wwwconference.org/www2007/papers/paper215.pdf"""
-    def __init__(self, nr_of_tables):
-        self.tables = nr_of_tables * [{}]
-
+    def __init__(self, nr_of_tables, fp_it, q_it):
+        super().__init__()
         # So far, we support the variants with 4 and 20 tables. Each element of splitters
         # describes the key for one table. The first element of the tuple indicates the number
         # of bits that we shift the simhash to the right; the second element indicates how many
         # bits, from the right side, we end up taking.
         if nr_of_tables == 4:
-            self.splitters = [[(0, 16)], [(16, 16)], [(32, 16)], [(48, 16)]]
+            splitters = [[(0, 16)], [(16, 16)], [(32, 16)], [(48, 16)]]
         elif nr_of_tables == 20:
             block_sizes = [11, 11, 11, 11, 10, 10]
-            self.splitters = []
+            splitters = []
             for i in range(0, len(block_sizes)):
                 for j in range(i + 1, len(block_sizes)):
                     for k in range(j + 1, len(block_sizes)):
-                        self.splitters += [[
+                        splitters += [[
                             (sum(block_sizes[i+1:]), block_sizes[i]),
                             (sum(block_sizes[j+1:]), block_sizes[j]),
                             (sum(block_sizes[k+1:]), block_sizes[k]),
@@ -51,82 +104,42 @@
         else:
             raise Exception(f"Unsupported number of tables: {nr_of_tables}")
 
+        self.fp_it = fp_it
+        self.q_it = q_it
 
-    def bit_count(self, n):
-        return bin(n).count("1")
+        self.splitters = splitters
+        self.tables = []
+        self.outqueues = [Queue(100) for _ in range(len(self.splitters))]
 
-    def get_chunk(self, n, i):
-        """Reduces the simhash to a small chunk, given by self.splitters. The chunk will
-        then be compared exactly in order to increase performance."""
-        sum = 0
-        for (s, c) in self.splitters[i]:
-            sum <<= c
-            sum += (n >> s) & (pow(2, c) - 1)
-        return sum
+    def run(self):
+        self.tables = [SimhashTable(splitter, outqueue, self.fp_it, self.q_it) for (outqueue, splitter) in zip(self.outqueues, self.splitters)]
+        for tbl in self.tables:
+            tbl.start()
 
-    def add(self, fp):
-        for i, tbl in enumerate(self.tables):
-            fp_chunk = self.get_chunk(fp[0], i)
-            if not fp_chunk in tbl:
-                tbl[fp_chunk] = []
-            tbl[fp_chunk] += [fp]
-
-    def query(self, q):
-        for i, tbl in enumerate(self.tables):
-            q_chunk = self.get_chunk(q, i)
-            if q_chunk in tbl:
-                for fp in tbl[q_chunk]:
-                    diff = self.bit_count(q ^ fp[0])
-                    if diff < 4:
-                        yield (fp, diff)
-
-    def addMany(self, fps):
-        for fp in fps:
-            self.add(fp)
-
-    def queryMany(self, qs):
-        for q in qs:
-            for (fp, diff) in self.query(q):
-                yield (fp, diff)
+        for tbl in self.tables:
+            tbl.join()
 
-def groupby_first(xs, n):
-    return ((x, [y[n:] for y in y]) for x, y in groupby(xs, key=lambda x: x[:n]))
+    def __iter__(self):
+        return unique_justseen(heapq.merge(*[iter(q.get, SimhashTable.STOP) for q in self.outqueues]))
 
 
 def get_cdnjs_simhashes(db_path, limit=None):
     with sqlite3.connect(db_path) as db:
-        db.row_factory = sqlite3.Row
-        for row in db.execute("select simhash, library, path, size, typ from cdnjs where "
+        for (simhash, library, path, size, typ, md5) in db.execute("select simhash, library, path, size, typ, md5 from cdnjs where "
                               "simhash IS NOT NULL AND path like '%.js' and "
-                              "HEX(md5) <> 'd41d8cd98f00b204e9800998ecf8427e'" +
+                              "HEX(md5) <> 'd41d8cd98f00b204e9800998ecf8427e' order by path, size, typ, md5" +
                               (f" LIMIT {int(limit)}" if limit is not None else "")):
-            yield row
+            yield ((path, size, typ, md5.hex()), int(simhash))
 
 
 def get_crxfile_simhashes(db_path, extension_limit=None, crxfile_limit=None):
     with sqlite3.connect(db_path) as db:
-        db.row_factory = sqlite3.Row
-        for row in db.execute(("select extid, date, crx_etag, path, size, typ, simhash from "
-                               # "((select * from extension e1 where date=(select max(date) from extension e2 where e1.extid=e2.extid) order by extid {}) as d1 join "
-                               "((select * from extension join (select extid, max(date) as date from extension group by extid) using (extid, date) order by extid {}) as d1 join "
-                               "(select * from crxfile where simhash is not null and path like '%.js' order by crx_etag, path) as d2 using (crx_etag)) join "
-                               "(select * from libdet where size >= 1024) as d3 using (md5,typ) {}").format(
-                                   "LIMIT " + str(int(extension_limit)) if extension_limit is not None else "",
-                                   "LIMIT " + str(int(crxfile_limit)) if crxfile_limit is not None else "",
-                               )):
-            yield row
-
-def process(tup):
-    (extid, date, crx_etag), rest = tup
-
-    result = []
-    for (path,), rest2 in groupby_first(rest, 1):
-        for size, typ, simhash in rest2:
-            for ((_, (lib_path, lib_size, lib_typ)), diff) in gbl_bucket.query(simhash):
-                result += [(extid, date, crx_etag, path, size, typ, lib_path, lib_size, lib_typ, diff)]
-    return result
-
+        for (extid, date) in islice(db.execute("select extid, max(date) as date from extension group by extid order by extid"), extension_limit):
+            for (crx_etag,) in db.execute("select crx_etag from extension where extid=? and date=? order by crx_etag", (extid, date)):
+                for (path, md5, typ, simhash) in db.execute("select path, md5, typ, simhash from crxfile where crx_etag=? and simhash is not null and path like '%.js' order by path, md5, typ", (crx_etag,)):
+                    for (size,) in db.execute("select size from libdet where md5=? and typ=? and size >= 1024 order by size", (md5, typ)):
+                        yield ((extid, date, crx_etag, path, md5.hex(), typ, size), int(simhash))
 
 def print_help():
     print("""simhashbucket [OPTION] """)
@@ -134,19 +147,17 @@
     print("""  --limit-cdnjs      only retrieve N rows""")
     print("""  --limit-extension  only retrieve N rows""")
    print("""  --limit-crxfile    only retrieve N rows""")
-    print("""  -t                 number of parallel threads""")
     print("""  --tables           number of tables to use for the bucket (4 or 20 so far)""")
 
 def parse_args(argv):
     limit_cdnjs = None
     limit_extension = None
     limit_crxfile = None
-    parallel = cpu_count()
     tables = 20
 
     try:
-        opts, args = getopt.getopt(argv, "ht:", [
-            "limit-cdnjs=", "limit-extension=", "limit-crxfile=", "help", "parallel=", "tables="])
+        opts, args = getopt.getopt(argv, "h", [
+            "limit-cdnjs=", "limit-extension=", "limit-crxfile=", "help", "tables="])
     except getopt.GetoptError:
         print_help()
         sys.exit(2)
@@ -157,8 +168,6 @@
             limit_extension = int(arg)
         elif opt == "--limit-crxfile":
             limit_crxfile = int(arg)
-        elif opt in ("-t", "--parallel"):
-            parallel = int(arg)
         elif opt == "--tables":
             tables = int(arg)
 
@@ -167,31 +176,19 @@
         sys.exit(2)
 
     db_path = args[0]
-    return limit_cdnjs, limit_extension, limit_crxfile, parallel, tables, db_path
-
-
-def init(bucket):
-    global gbl_bucket
-    gbl_bucket = bucket
+    return limit_cdnjs, limit_extension, limit_crxfile, tables, db_path
 
 
 def main(args):
-    limit_cdnjs, limit_extension, limit_crxfile, parallel, tables, db_path = parse_args(args)
-    bucket = SimhashBucket(tables)
+    limit_cdnjs, limit_extension, limit_crxfile, tables, db_path = parse_args(args)
 
-    start_build = time.time()
-    bucket.addMany(((int(row["simhash"]), (row["path"], row["size"], row["typ"])) for row in get_cdnjs_simhashes(db_path, limit_cdnjs)))
-    sys.stderr.write(f"Building the bucket took {format(time.time() - start_build, '.2f')} seconds\n")
+    fp_it = get_cdnjs_simhashes(db_path, limit_cdnjs)
+    q_it = get_crxfile_simhashes(db_path, limit_extension, limit_crxfile)
 
-    start_query = time.time()
-
-    data = ((row["extid"], row["date"], row["crx_etag"], row["path"], row["size"], row["typ"], int(row["simhash"])) for row in get_crxfile_simhashes(db_path, limit_extension, limit_crxfile))
-    with Pool(parallel, initializer=init, initargs=(bucket,)) as p:
-        for tups in p.imap_unordered(process, groupby_first(data, 3), 100):
-            for tup in tups:
-                sys.stdout.write("|".join([str(x) for x in tup]) + "\n")
-                sys.stdout.flush()
-    sys.stderr.write(f"The query took {format(time.time() - start_query, '.2f')} seconds\n")
+    bucket = SimhashBucket(tables, fp_it, q_it)
+    bucket.start()
+    for tup in bucket:
+        sys.stdout.write("|".join([str(x) for x in tup]) + "\n")
 
 
 if __name__ == "__main__":
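
For reference, the bucketing scheme that SimhashTable and SimhashBucket implement above can be exercised on its own. The standalone sketch below is not part of the patch; it reuses the 20-table splitter construction and the get_chunk logic from simhashbucket (the example hash values are purely illustrative) to show that two 64-bit simhashes differing in at most three bits always share at least one table key, which is why _query only needs an exact chunk lookup before checking the Hamming distance:

# Standalone sketch of the bucketing scheme used above: a 64-bit simhash is cut
# into six blocks (11+11+11+11+10+10 bits), and each of the C(6,3) = 20 tables
# keys on the concatenation of three blocks, so any two simhashes within
# Hamming distance 3 must agree on at least one table key.

from itertools import combinations

BLOCK_SIZES = [11, 11, 11, 11, 10, 10]

# Same (shift, width) construction as SimhashBucket.__init__ uses for 20 tables.
SPLITTERS = [
    [(sum(BLOCK_SIZES[i + 1:]), BLOCK_SIZES[i]) for i in blocks]
    for blocks in combinations(range(len(BLOCK_SIZES)), 3)
]


def get_chunk(simhash, splitter):
    """Concatenate the selected blocks of simhash into one integer key,
    mirroring SimhashTable.get_chunk."""
    key = 0
    for shift, width in splitter:
        key = (key << width) | ((simhash >> shift) & ((1 << width) - 1))
    return key


a = 0x0123456789ABCDEF
b = a ^ (1 << 3) ^ (1 << 40) ^ (1 << 63)  # flip three bits of a
hits = sum(get_chunk(a, s) == get_chunk(b, s) for s in SPLITTERS)
print(f"{hits} of {len(SPLITTERS)} table keys collide")  # at least one collision is guaranteed (here: 1 of 20)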