247 lines
8.6 KiB
Python
Executable File
247 lines
8.6 KiB
Python
Executable File
#!/usr/bin/env python3.6
|
|
#
|
|
# Copyright (C) 2018 The University of Sheffield, UK
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
#
|
|
|
|
import getopt
import heapq
import os
import sys
from itertools import combinations, groupby, islice
from multiprocessing import Process, Queue
from operator import itemgetter

import MySQLdb
from MySQLdb import cursors
|
|
|
|
|
def unique_justseen(iterable, key=None):
    """Yield unique elements, preserving order; drop consecutive duplicates.

    unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
    unique_justseen('ABBCcAD', str.lower) --> A B C A D
    """
    for _, group in groupby(iterable, key):
        yield next(group)
|
|
|
|
|
|
def grouper(it, chunksize):
    """Yield lists of up to *chunksize* items drawn from iterator *it*.

    The final chunk may be shorter. Unlike the previous version, this never
    yields a trailing empty list (which happened whenever the input length
    was an exact multiple of chunksize, or the input was empty) — empty
    batches were pointless work for the queue consumers.
    """
    while True:
        chunk = list(islice(it, chunksize))
        if not chunk:
            break
        yield chunk
        # A short chunk means the iterator is exhausted; stop without
        # probing it again.
        if len(chunk) < chunksize:
            break
|
|
|
|
|
|
def execute(q, args=None):
    """Run query *q* (with optional *args*) and return a streaming cursor.

    Connects to MySQL using the credentials in ~/.my.cnf and a server-side
    cursor (SSCursor), so large result sets can be iterated without being
    loaded into memory at once. The caller owns the returned cursor.
    """
    connection = MySQLdb.connect(
        read_default_file=os.path.expanduser("~/.my.cnf"),
        cursorclass=cursors.SSCursor)
    cur = connection.cursor()
    cur.execute(q, args)
    return cur
|
|
|
|
|
|
class SimhashTable(Process):
|
|
STOP = "stop"
|
|
|
|
def __init__(self, splitter, outqueue, fp_q, query_q):
|
|
super().__init__()
|
|
self.outqueue = outqueue
|
|
|
|
self.splitter = splitter
|
|
self.table = {}
|
|
self.fp_q = fp_q
|
|
self.query_q = query_q
|
|
|
|
@staticmethod
|
|
def bit_count(n):
|
|
return bin(n).count("1")
|
|
|
|
def get_chunk(self, n):
|
|
"""Reduces the simhash to a small chunk, given by self.splitters. The chunk will
|
|
then be compared exactly in order to increase performance."""
|
|
sum = 0
|
|
for (s, c) in self.splitter:
|
|
sum <<= c
|
|
sum += (n >> s) & (pow(2, c) - 1)
|
|
return sum
|
|
|
|
def _add(self, fp):
|
|
fp_chunk = self.get_chunk(fp[1])
|
|
if not fp_chunk in self.table:
|
|
self.table[fp_chunk] = []
|
|
self.table[fp_chunk] += [fp]
|
|
|
|
def _query(self, q):
|
|
query_chunk = self.get_chunk(q)
|
|
if query_chunk in self.table:
|
|
for fp in self.table[query_chunk]:
|
|
diff = SimhashTable.bit_count(q ^ fp[1])
|
|
if diff < 4:
|
|
yield (fp, diff)
|
|
|
|
def run(self):
|
|
for fps in iter(self.fp_q.get, SimhashTable.STOP):
|
|
for fp in fps:
|
|
self._add(fp)
|
|
for queries in iter(self.query_q.get, SimhashTable.STOP):
|
|
for (query_info, q) in queries:
|
|
for ((fp_info, fp), diff) in self._query(q):
|
|
self.outqueue.put((query_info, fp_info, diff))
|
|
self.outqueue.put(SimhashTable.STOP)
|
|
|
|
|
|
class SimhashBucket(Process):
|
|
"""Implementation of http://wwwconference.org/www2007/papers/paper215.pdf"""
|
|
|
|
def __init__(self, nr_of_tables, fp_it, query_it):
|
|
super().__init__()
|
|
# So far, we support the variants with 4 and 20 tables. Each element of splitters
|
|
# describes the key for one table. The first element of the tuple indicates the number
|
|
# of bits that we shift the simhash to the right; the second element indicates how many
|
|
# bits, from the right side, we end up taking.
|
|
if nr_of_tables == 4:
|
|
splitters = [[(0, 16)], [(16, 16)], [(32, 16)], [(48, 16)]]
|
|
elif nr_of_tables == 20:
|
|
block_sizes = [11, 11, 11, 11, 10, 10]
|
|
splitters = []
|
|
for i in range(0, len(block_sizes)):
|
|
for j in range(i + 1, len(block_sizes)):
|
|
for k in range(j + 1, len(block_sizes)):
|
|
splitters += [[
|
|
(sum(block_sizes[i+1:]), block_sizes[i]),
|
|
(sum(block_sizes[j+1:]), block_sizes[j]),
|
|
(sum(block_sizes[k+1:]), block_sizes[k]),
|
|
]]
|
|
else:
|
|
raise Exception(f"Unsupported number of tables: {nr_of_tables}")
|
|
|
|
self.fp_it = fp_it
|
|
self.query_it = query_it
|
|
|
|
self.splitters = splitters
|
|
self.tables = []
|
|
self.fp_qs = [Queue(100) for _ in range(len(self.splitters))]
|
|
self.query_qs = [Queue(100) for _ in range(len(self.splitters))]
|
|
self.out_qs = [Queue(100) for _ in range(len(self.splitters))]
|
|
|
|
@staticmethod
|
|
def broadcast(it, qs, chunksize=1):
|
|
for x in grouper(it, chunksize):
|
|
for q in qs:
|
|
q.put(x)
|
|
for q in qs:
|
|
q.put(SimhashTable.STOP)
|
|
|
|
def run(self):
|
|
self.tables = [SimhashTable(*args) for args in zip(self.splitters, self.out_qs, self.fp_qs, self.query_qs)]
|
|
for tbl in self.tables:
|
|
tbl.start()
|
|
|
|
SimhashBucket.broadcast(self.fp_it, self.fp_qs, 100)
|
|
SimhashBucket.broadcast(self.query_it, self.query_qs, 100)
|
|
|
|
for tbl in self.tables:
|
|
tbl.join()
|
|
|
|
def __iter__(self):
|
|
return unique_justseen(heapq.merge(*[iter(q.get, SimhashTable.STOP) for q in self.out_qs]))
|
|
|
|
|
|
def get_cdnjs_simhashes(limit=None):
    """Yield ((path, typ, library, version), simhash) rows from cdnjs.

    Skips rows without a simhash, non-.js paths, and empty files (the md5
    of the empty string). Whatever gets returned here must be sorted, so
    the query orders by (path, typ).

    limit: optional maximum number of rows to retrieve (default: all).
    """
    limit_clause = f"limit {limit}" if limit else ""
    query = (
        "select simhash, path, typ, library, version from "
        "cdnjs where "
        "simhash IS NOT NULL AND path like '%.js' and "
        "HEX(md5) <> 'd41d8cd98f00b204e9800998ecf8427e' "
        f"order by path, typ {limit_clause}")
    for (simhash, path, typ, library, version) in execute(query):
        yield ((path, typ, library, version), int(simhash))
|
|
|
|
|
|
def get_crxfile_simhashes(extension_limit=None, crxfile_limit=None):
    """Yield ((extid, date, crx_etag, path, typ), simhash) for extension files.

    Joins the most recent extensions with their crx files and libdet rows,
    keeping only .js files of at least 1 KiB that have a simhash. Whatever
    gets returned here must be sorted.

    extension_limit: optional row cap on the extension subquery.
    crxfile_limit: optional row cap on the crxfile subquery.
    """
    ext_clause = f"limit {extension_limit}" if extension_limit else ""
    crx_clause = f"limit {crxfile_limit}" if crxfile_limit else ""
    query = (
        "select extid, date, crx_etag, path, typ, simhash from "
        f"(select * from extension_most_recent order by extid, date {ext_clause}) extension join "
        f"(select * from crxfile order by crx_etag, path, typ {crx_clause}) crxfile using (crx_etag) "
        "join libdet using (md5, typ) "
        "where simhash is not null and path like '%.js' and size >= 1024")
    for (extid, date, crx_etag, path, typ, simhash) in execute(query):
        yield ((extid, date, crx_etag, path, typ), int(simhash))
|
|
|
|
|
|
def print_help():
    """Print the command-line usage summary to stdout."""
    for line in (
            "simhashbucket [OPTIONS]",
            " -h, --help print this help text",
            " --limit-cdnjs <N> only retrieve N rows, default: all",
            " --limit-extension <N> only retrieve N rows, default: all",
            " --limit-crxfile <N> only retrieve N rows, default: all",
            " --tables <N> number of tables to use for the bucket (4 or 20 so far), default: 4"):
        print(line)
|
|
|
|
def parse_args(argv):
    """Parse the command line given in *argv* (without the program name).

    Returns (limit_cdnjs, limit_extension, limit_crxfile, tables).
    Prints usage and exits with status 2 on unknown options, malformed
    integer arguments, or stray positional arguments; exits with status 0
    after printing help for -h/--help.
    """
    limit_cdnjs = limit_extension = limit_crxfile = None
    tables = 4

    long_options = [
        "limit-cdnjs=", "limit-extension=", "limit-crxfile=", "help", "tables="]
    try:
        opts, positional = getopt.getopt(argv, "h", long_options)
    except getopt.GetoptError:
        print_help()
        sys.exit(2)

    try:
        for opt, arg in opts:
            if opt in ("-h", "--help"):
                print_help()
                sys.exit(0)
            elif opt == "--limit-cdnjs":
                limit_cdnjs = int(arg)
            elif opt == "--limit-extension":
                limit_extension = int(arg)
            elif opt == "--limit-crxfile":
                limit_crxfile = int(arg)
            elif opt == "--tables":
                tables = int(arg)
    except ValueError:
        print("Arguments to int options must be an int!", file=sys.stderr)
        print_help()
        sys.exit(2)

    # Positional arguments are not accepted.
    if positional:
        print_help()
        sys.exit(2)

    return limit_cdnjs, limit_extension, limit_crxfile, tables
|
|
|
|
|
|
def main(args):
    """Match cdnjs library fingerprints against extension files.

    Streams both fingerprint sets from the database, feeds them through a
    SimhashBucket, and prints each match as a '|'-separated line on stdout.
    """
    limit_cdnjs, limit_extension, limit_crxfile, tables = parse_args(args)

    bucket = SimhashBucket(
        tables,
        get_cdnjs_simhashes(limit_cdnjs),
        get_crxfile_simhashes(limit_extension, limit_crxfile))
    bucket.start()
    for match in bucket:
        print("|".join(str(field) for field in match))
|
|
|
|
|
|
# Script entry point: forward all CLI arguments (minus the program name).
if __name__ == "__main__":
    main(sys.argv[1:])
|