Changed way of parallelism for simhashbucket and added comparemd5.
This commit is contained in:
parent
da5133b2b1
commit
5d4f321dca
|
@ -0,0 +1,44 @@
|
|||
#!/usr/bin/env python3.6
|
||||
#
|
||||
# Copyright (C) 2018 The University of Sheffield, UK
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
#
|
||||
|
||||
import sqlite3
import sys

# Database path is the single command-line argument.
db_path = sys.argv[1]

with sqlite3.connect(db_path) as db:
    # Map md5 -> (library, path, typ) for every known cdnjs file.
    cdnjs_files = {
        md5: (library, path, typ)
        for (md5, library, path, typ) in db.execute(
            "select md5, library, path, typ from cdnjs")
    }

    hit = 0
    miss = 0
    # Walk the most recent crawl of every extension and check each JS file
    # (with a known simhash and non-trivial size) against the cdnjs md5 set.
    for (extid, date) in db.execute("select extid, max(date) as date from extension group by extid order by extid"):
        for (crx_etag,) in db.execute("select crx_etag from extension where extid=? and date=? order by crx_etag", (extid, date)):
            for (path, md5, typ, simhash) in db.execute("select path, md5, typ, simhash from crxfile where crx_etag=? and simhash is not null and path like '%.js' order by path, md5, typ", (crx_etag,)):
                for (size,) in db.execute("select size from libdet where md5=? and typ=? and size >= 1024 order by size", (md5, typ)):
                    if md5 in cdnjs_files:
                        hit += 1
                    else:
                        miss += 1
                        print("|".join((extid, date, path, typ)))

    print(f"Hit: {hit}")
    print(f"Miss: {miss}")
|
191
simhashbucket
191
simhashbucket
|
@ -18,32 +18,85 @@
|
|||
|
||||
import sys
|
||||
import getopt
|
||||
import os
|
||||
import sqlite3
|
||||
from itertools import groupby
|
||||
import time
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from multiprocessing import Process, Queue, set_start_method
|
||||
from itertools import islice, groupby
|
||||
from operator import itemgetter
|
||||
import heapq
|
||||
|
||||
|
||||
class SimhashBucket:
|
||||
def unique_justseen(iterable, key=None):
    """Yield the elements of *iterable* in order, collapsing runs of
    consecutive duplicates (as judged by *key*) into a single element.

    unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
    unique_justseen('ABBCcAD', str.lower) --> A B C A D
    """
    for _, run in groupby(iterable, key):
        yield next(run)
|
||||
|
||||
|
||||
class SimhashTable(Process):
    """One table of a multi-table simhash index
    (http://wwwconference.org/www2007/papers/paper215.pdf).

    Runs as a separate process: it first indexes all fingerprints from
    ``fp_it`` by an exact-match chunk of their simhash, then streams the
    queries from ``q_it`` against the index and puts every near-duplicate
    match on ``outqueue``, followed by the ``STOP`` sentinel.
    """

    # Sentinel value put on the output queue when this table is finished.
    STOP = "stop"
    # Matches must have Hamming distance strictly below this threshold.
    MAX_DIFF = 4

    def __init__(self, splitter, outqueue, fp_it, q_it):
        """splitter: list of (shift, bit_count) pairs selecting which bits of
            a simhash form this table's exact-match chunk.
        outqueue: queue receiving (q_info, fp_info, diff) tuples, then STOP.
        fp_it: iterable of (info, simhash) fingerprints to index.
        q_it: iterable of (info, simhash) queries to look up.
        """
        super().__init__()
        self.outqueue = outqueue
        self.splitter = splitter
        self.table = {}  # chunk -> list of (info, simhash) fingerprints
        self.fp_it = fp_it
        self.q_it = q_it

    @staticmethod
    def bit_count(n):
        """Return the number of set bits in n (Hamming weight)."""
        return bin(n).count("1")

    def get_chunk(self, n):
        """Reduce the simhash n to a small chunk, given by self.splitter.

        The chunk is compared exactly (dict lookup), so only candidates in
        the same bucket need a full Hamming-distance check — this is what
        makes the lookup fast."""
        chunk = 0
        for (shift, count) in self.splitter:
            chunk <<= count
            chunk += (n >> shift) & ((1 << count) - 1)
        return chunk

    def _add(self, fp):
        """Index the fingerprint fp = (info, simhash) under its chunk."""
        self.table.setdefault(self.get_chunk(fp[1]), []).append(fp)

    def _query(self, q):
        """Yield ((info, simhash), diff) for every indexed fingerprint whose
        Hamming distance to the simhash q is below MAX_DIFF."""
        for fp in self.table.get(self.get_chunk(q), []):
            diff = SimhashTable.bit_count(q ^ fp[1])
            if diff < self.MAX_DIFF:
                yield (fp, diff)

    def run(self):
        """Process entry point: build the index, answer all queries, then
        signal completion with STOP."""
        for fp in self.fp_it:
            self._add(fp)
        for (q_info, q) in self.q_it:
            for ((fp_info, fp), diff) in self._query(q):
                self.outqueue.put((q_info, fp_info, diff))
        self.outqueue.put(SimhashTable.STOP)
|
||||
|
||||
class SimhashBucket(Process):
|
||||
"""Implementation of http://wwwconference.org/www2007/papers/paper215.pdf"""
|
||||
|
||||
def __init__(self, nr_of_tables):
|
||||
self.tables = nr_of_tables * [{}]
|
||||
|
||||
def __init__(self, nr_of_tables, fp_it, q_it):
|
||||
super().__init__()
|
||||
# So far, we support the variants with 4 and 20 tables. Each element of splitters
|
||||
# describes the key for one table. The first element of the tuple indicates the number
|
||||
# of bits that we shift the simhash to the right; the second element indicates how many
|
||||
# bits, from the right side, we end up taking.
|
||||
if nr_of_tables == 4:
|
||||
self.splitters = [[(0, 16)], [(16, 16)], [(32, 16)], [(48, 16)]]
|
||||
splitters = [[(0, 16)], [(16, 16)], [(32, 16)], [(48, 16)]]
|
||||
elif nr_of_tables == 20:
|
||||
block_sizes = [11, 11, 11, 11, 10, 10]
|
||||
self.splitters = []
|
||||
splitters = []
|
||||
for i in range(0, len(block_sizes)):
|
||||
for j in range(i + 1, len(block_sizes)):
|
||||
for k in range(j + 1, len(block_sizes)):
|
||||
self.splitters += [[
|
||||
splitters += [[
|
||||
(sum(block_sizes[i+1:]), block_sizes[i]),
|
||||
(sum(block_sizes[j+1:]), block_sizes[j]),
|
||||
(sum(block_sizes[k+1:]), block_sizes[k]),
|
||||
|
@ -51,82 +104,42 @@ class SimhashBucket:
|
|||
else:
|
||||
raise Exception(f"Unsupported number of tables: {nr_of_tables}")
|
||||
|
||||
self.fp_it = fp_it
|
||||
self.q_it = q_it
|
||||
|
||||
def bit_count(self, n):
|
||||
return bin(n).count("1")
|
||||
self.splitters = splitters
|
||||
self.tables = []
|
||||
self.outqueues = [Queue(100) for _ in range(len(self.splitters))]
|
||||
|
||||
def get_chunk(self, n, i):
|
||||
"""Reduces the simhash to a small chunk, given by self.splitters. The chunk will
|
||||
then be compared exactly in order to increase performance."""
|
||||
sum = 0
|
||||
for (s, c) in self.splitters[i]:
|
||||
sum <<= c
|
||||
sum += (n >> s) & (pow(2, c) - 1)
|
||||
return sum
|
||||
def run(self):
|
||||
self.tables = [SimhashTable(splitter, outqueue, self.fp_it, self.q_it) for (outqueue, splitter) in zip(self.outqueues, self.splitters)]
|
||||
for tbl in self.tables:
|
||||
tbl.start()
|
||||
|
||||
def add(self, fp):
|
||||
for i, tbl in enumerate(self.tables):
|
||||
fp_chunk = self.get_chunk(fp[0], i)
|
||||
if not fp_chunk in tbl:
|
||||
tbl[fp_chunk] = []
|
||||
tbl[fp_chunk] += [fp]
|
||||
|
||||
def query(self, q):
|
||||
for i, tbl in enumerate(self.tables):
|
||||
q_chunk = self.get_chunk(q, i)
|
||||
if q_chunk in tbl:
|
||||
for fp in tbl[q_chunk]:
|
||||
diff = self.bit_count(q ^ fp[0])
|
||||
if diff < 4:
|
||||
yield (fp, diff)
|
||||
|
||||
def addMany(self, fps):
|
||||
for fp in fps:
|
||||
self.add(fp)
|
||||
|
||||
def queryMany(self, qs):
|
||||
for q in qs:
|
||||
for (fp, diff) in self.query(q):
|
||||
yield (fp, diff)
|
||||
for tbl in self.tables:
|
||||
tbl.join()
|
||||
|
||||
|
||||
def groupby_first(xs, n):
|
||||
return ((x, [y[n:] for y in y]) for x, y in groupby(xs, key=lambda x: x[:n]))
|
||||
def __iter__(self):
|
||||
return unique_justseen(heapq.merge(*[iter(q.get, SimhashTable.STOP) for q in self.outqueues]))
|
||||
|
||||
|
||||
def get_cdnjs_simhashes(db_path, limit=None):
|
||||
with sqlite3.connect(db_path) as db:
|
||||
db.row_factory = sqlite3.Row
|
||||
for row in db.execute("select simhash, library, path, size, typ from cdnjs where "
|
||||
for (simhash, library, path, size, typ, md5) in db.execute("select simhash, library, path, size, typ, md5 from cdnjs where "
|
||||
"simhash IS NOT NULL AND path like '%.js' and "
|
||||
"HEX(md5) <> 'd41d8cd98f00b204e9800998ecf8427e'" +
|
||||
"HEX(md5) <> 'd41d8cd98f00b204e9800998ecf8427e' order by path, size, typ, md5" +
|
||||
(f" LIMIT {int(limit)}" if limit is not None else "")):
|
||||
yield row
|
||||
yield ((path, size, typ, md5.hex()), int(simhash))
|
||||
|
||||
|
||||
def get_crxfile_simhashes(db_path, extension_limit=None, crxfile_limit=None):
|
||||
with sqlite3.connect(db_path) as db:
|
||||
db.row_factory = sqlite3.Row
|
||||
for row in db.execute(("select extid, date, crx_etag, path, size, typ, simhash from "
|
||||
# "((select * from extension e1 where date=(select max(date) from extension e2 where e1.extid=e2.extid) order by extid {}) as d1 join "
|
||||
"((select * from extension join (select extid, max(date) as date from extension group by extid) using (extid, date) order by extid {}) as d1 join "
|
||||
"(select * from crxfile where simhash is not null and path like '%.js' order by crx_etag, path) as d2 using (crx_etag)) join "
|
||||
"(select * from libdet where size >= 1024) as d3 using (md5,typ) {}").format(
|
||||
"LIMIT " + str(int(extension_limit)) if extension_limit is not None else "",
|
||||
"LIMIT " + str(int(crxfile_limit)) if crxfile_limit is not None else "",
|
||||
)):
|
||||
yield row
|
||||
|
||||
def process(tup):
|
||||
(extid, date, crx_etag), rest = tup
|
||||
|
||||
result = []
|
||||
for (path,), rest2 in groupby_first(rest, 1):
|
||||
for size, typ, simhash in rest2:
|
||||
for ((_, (lib_path, lib_size, lib_typ)), diff) in gbl_bucket.query(simhash):
|
||||
result += [(extid, date, crx_etag, path, size, typ, lib_path, lib_size, lib_typ, diff)]
|
||||
return result
|
||||
|
||||
for (extid, date) in islice(db.execute("select extid, max(date) as date from extension group by extid order by extid"), extension_limit):
|
||||
for (crx_etag,) in db.execute("select crx_etag from extension where extid=? and date=? order by crx_etag", (extid, date)):
|
||||
for (path, md5, typ, simhash) in db.execute("select path, md5, typ, simhash from crxfile where crx_etag=? and simhash is not null and path like '%.js' order by path, md5, typ", (crx_etag,)):
|
||||
for (size,) in db.execute("select size from libdet where md5=? and typ=? and size >= 1024 order by size", (md5, typ)):
|
||||
yield ((extid, date, crx_etag, path, md5.hex(), typ, size), int(simhash))
|
||||
|
||||
def print_help():
|
||||
print("""simhashbucket [OPTION] <DB_PATH>""")
|
||||
|
@ -134,19 +147,17 @@ def print_help():
|
|||
print(""" --limit-cdnjs <N> only retrieve N rows""")
|
||||
print(""" --limit-extension <N> only retrieve N rows""")
|
||||
print(""" --limit-crxfile <N> only retrieve N rows""")
|
||||
print(""" -t <THREADS> number of parallel threads""")
|
||||
print(""" --tables <N> number of tables to use for the bucket (4 or 20 so far)""")
|
||||
|
||||
def parse_args(argv):
|
||||
limit_cdnjs = None
|
||||
limit_extension = None
|
||||
limit_crxfile = None
|
||||
parallel = cpu_count()
|
||||
tables = 20
|
||||
|
||||
try:
|
||||
opts, args = getopt.getopt(argv, "ht:", [
|
||||
"limit-cdnjs=", "limit-extension=", "limit-crxfile=", "help", "parallel=", "tables="])
|
||||
opts, args = getopt.getopt(argv, "h", [
|
||||
"limit-cdnjs=", "limit-extension=", "limit-crxfile=", "help", "tables="])
|
||||
except getopt.GetoptError:
|
||||
print_help()
|
||||
sys.exit(2)
|
||||
|
@ -157,8 +168,6 @@ def parse_args(argv):
|
|||
limit_extension = int(arg)
|
||||
elif opt == "--limit-crxfile":
|
||||
limit_crxfile = int(arg)
|
||||
elif opt in ("-t", "--parallel"):
|
||||
parallel = int(arg)
|
||||
elif opt == "--tables":
|
||||
tables = int(arg)
|
||||
|
||||
|
@ -167,31 +176,19 @@ def parse_args(argv):
|
|||
sys.exit(2)
|
||||
db_path = args[0]
|
||||
|
||||
return limit_cdnjs, limit_extension, limit_crxfile, parallel, tables, db_path
|
||||
|
||||
|
||||
def init(bucket):
|
||||
global gbl_bucket
|
||||
gbl_bucket = bucket
|
||||
return limit_cdnjs, limit_extension, limit_crxfile, tables, db_path
|
||||
|
||||
|
||||
def main(args):
|
||||
limit_cdnjs, limit_extension, limit_crxfile, parallel, tables, db_path = parse_args(args)
|
||||
bucket = SimhashBucket(tables)
|
||||
limit_cdnjs, limit_extension, limit_crxfile, tables, db_path = parse_args(args)
|
||||
|
||||
start_build = time.time()
|
||||
bucket.addMany(((int(row["simhash"]), (row["path"], row["size"], row["typ"])) for row in get_cdnjs_simhashes(db_path, limit_cdnjs)))
|
||||
sys.stderr.write(f"Building the bucket took {format(time.time() - start_build, '.2f')} seconds\n")
|
||||
fp_it = get_cdnjs_simhashes(db_path, limit_cdnjs)
|
||||
q_it = get_crxfile_simhashes(db_path, limit_extension, limit_crxfile)
|
||||
|
||||
start_query = time.time()
|
||||
|
||||
data = ((row["extid"], row["date"], row["crx_etag"], row["path"], row["size"], row["typ"], int(row["simhash"])) for row in get_crxfile_simhashes(db_path, limit_extension, limit_crxfile))
|
||||
with Pool(parallel, initializer=init, initargs=(bucket,)) as p:
|
||||
for tups in p.imap_unordered(process, groupby_first(data, 3), 100):
|
||||
for tup in tups:
|
||||
bucket = SimhashBucket(tables, fp_it, q_it)
|
||||
bucket.start()
|
||||
for tup in bucket:
|
||||
sys.stdout.write("|".join([str(x) for x in tup]) + "\n")
|
||||
sys.stdout.flush()
|
||||
sys.stderr.write(f"The query took {format(time.time() - start_query, '.2f')} seconds\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue