ExtensionCrawler/simhashbucket

196 lines
7.4 KiB
Python
Executable File

#!/usr/bin/env python3.6
#
# Copyright (C) 2018 The University of Sheffield, UK
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
import getopt
import heapq
import sqlite3
import sys
from contextlib import closing
from itertools import combinations, groupby, islice
from multiprocessing import Process, Queue, set_start_method
from operator import itemgetter
def unique_justseen(iterable, key=None):
    """Return an iterator over the first element of each run of equal elements.

    Order is preserved and only the element just seen is remembered, so a
    value shows up again when it recurs after a different one. ``key``
    optionally maps every element to the value used for the comparison.
    """
    # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
    # unique_justseen('ABBCcAD', str.lower) --> A B C A D
    runs = groupby(iterable, key)
    # Every group produced by groupby holds at least one element, so taking
    # its first member is always possible.
    return (next(members) for _, members in runs)
class SimhashTable(Process):
    """Worker process that owns one lookup table of simhash fingerprints.

    The table maps a chunk of each fingerprint's simhash (selected by
    ``splitter``) to the fingerprints sharing that chunk. When the process
    runs, it first indexes every fingerprint from ``fp_it`` and then looks up
    every query from ``q_it``, putting one ``(q_info, fp_info, diff)`` tuple on
    ``outqueue`` per match, followed by the ``STOP`` sentinel.
    """

    # Sentinel put on the output queue after the last result of this table.
    STOP = "stop"
    # Largest Hamming distance (number of differing bits) reported as a match.
    MAX_DISTANCE = 3

    def __init__(self, splitter, outqueue, fp_it, q_it):
        """Store the table configuration; no work happens until ``run``.

        splitter -- iterable of ``(shift, width)`` pairs describing which bit
                    ranges of a simhash form the table key
        outqueue -- queue-like object with a ``put`` method receiving results
        fp_it    -- iterable of ``(fp_info, simhash)`` fingerprints to index
        q_it     -- iterable of ``(q_info, simhash)`` queries to look up
        """
        super().__init__()
        self.outqueue = outqueue
        self.splitter = splitter
        self.table = {}
        self.fp_it = fp_it
        self.q_it = q_it

    @staticmethod
    def bit_count(n):
        """Return the number of ``1`` digits in the binary rendering of ``n``."""
        return bin(n).count("1")

    def get_chunk(self, n):
        """Reduce the simhash ``n`` to the small chunk given by ``self.splitter``.

        For every ``(shift, width)`` pair, ``width`` bits are taken from ``n``
        after shifting it right by ``shift`` bits; the pieces are concatenated
        in splitter order. The chunk is compared exactly, which narrows down
        the candidates before the more expensive bit-difference check.
        """
        chunk = 0
        for shift, width in self.splitter:
            piece = (n >> shift) & ((1 << width) - 1)
            chunk = (chunk << width) | piece
        return chunk

    def _add(self, fp):
        """Index the fingerprint ``fp`` (an ``(info, simhash)`` pair)."""
        self.table.setdefault(self.get_chunk(fp[1]), []).append(fp)

    def _query(self, q):
        """Yield ``(fp, diff)`` for each indexed fingerprint close to ``q``.

        Only fingerprints whose chunk equals the chunk of ``q`` are examined;
        ``diff`` is the number of bits in which the two simhashes differ.
        """
        for fp in self.table.get(self.get_chunk(q), ()):
            diff = SimhashTable.bit_count(q ^ fp[1])
            if diff <= SimhashTable.MAX_DISTANCE:
                yield (fp, diff)

    def run(self):
        """Build the table, answer all queries and finish with ``STOP``."""
        try:
            for fp in self.fp_it:
                self._add(fp)
            for (q_info, q) in self.q_it:
                for ((fp_info, _fp), diff) in self._query(q):
                    self.outqueue.put((q_info, fp_info, diff))
        finally:
            # Always emit the sentinel, even when an iterator raised: a
            # consumer reading until STOP would otherwise block forever.
            self.outqueue.put(SimhashTable.STOP)
class SimhashBucket(Process):
    """Implementation of http://wwwconference.org/www2007/papers/paper215.pdf

    The bucket consists of several ``SimhashTable`` workers, each keyed on a
    different selection of simhash bits. Running the bucket starts one worker
    per table; iterating over the bucket yields the merged, de-duplicated
    results the workers put on their output queues.
    """

    def __init__(self, nr_of_tables, fp_it, q_it):
        """Prepare the table keys and one bounded output queue per table.

        nr_of_tables -- number of tables to use; only 4 and 20 are supported
        fp_it        -- iterable of fingerprints handed to every table
        q_it         -- iterable of queries handed to every table

        Raises ValueError for an unsupported ``nr_of_tables``.
        """
        super().__init__()

        # So far, we support the variants with 4 and 20 tables. Each element of splitters
        # describes the key for one table. The first element of the tuple indicates the number
        # of bits that we shift the simhash to the right; the second element indicates how many
        # bits, from the right side, we end up taking.
        if nr_of_tables == 4:
            splitters = [[(0, 16)], [(16, 16)], [(32, 16)], [(48, 16)]]
        elif nr_of_tables == 20:
            block_sizes = [11, 11, 11, 11, 10, 10]
            # A block has to be shifted right by the total width of all the
            # blocks that follow it.
            blocks = [
                (sum(block_sizes[idx + 1:]), size)
                for idx, size in enumerate(block_sizes)
            ]
            # Every way of picking three of the six blocks keys one table.
            splitters = [list(chosen) for chosen in combinations(blocks, 3)]
        else:
            raise ValueError(f"Unsupported number of tables: {nr_of_tables}")

        self.fp_it = fp_it
        self.q_it = q_it
        self.splitters = splitters
        self.tables = []
        self.outqueues = [Queue(100) for _ in range(len(self.splitters))]

    def run(self):
        """Start one ``SimhashTable`` per splitter and wait for all of them."""
        # NOTE(review): every table is handed the same fp_it/q_it objects;
        # this presumably relies on each worker process getting its own copy
        # (fork start method) - confirm.
        self.tables = [
            SimhashTable(splitter, outqueue, self.fp_it, self.q_it)
            for (outqueue, splitter) in zip(self.outqueues, self.splitters)
        ]
        for tbl in self.tables:
            tbl.start()
        for tbl in self.tables:
            tbl.join()

    def __iter__(self):
        """Iterate over the results of all tables without adjacent duplicates.

        Each queue is read until its table's STOP sentinel arrives. The
        streams are combined with ``heapq.merge``, which expects every stream
        to be sorted already; equal results reported by several tables then
        end up next to each other and are collapsed into one.
        """
        streams = [iter(q.get, SimhashTable.STOP) for q in self.outqueues]
        return unique_justseen(heapq.merge(*streams))
def get_cdnjs_simhashes(db_path, limit=None):
    """Yield ``((path, size, typ, md5_hex), simhash)`` for cdnjs JavaScript files.

    Reads the ``cdnjs`` table of the SQLite database at ``db_path``, keeping
    rows that have a simhash and a path ending in ``.js`` and skipping the
    file whose MD5 is that of empty content. Rows are ordered by path, size,
    typ and md5. ``limit`` optionally caps the number of rows.
    """
    query = (
        "select simhash, library, path, size, typ, md5 from cdnjs where "
        "simhash IS NOT NULL AND path like '%.js' and "
        # SQLite's HEX() renders upper-case digits and string comparison is
        # case-sensitive, so the digest literal has to be upper-case as well.
        "HEX(md5) <> 'D41D8CD98F00B204E9800998ECF8427E' "
        "order by path, size, typ, md5"
    )
    if limit is not None:
        # int() guarantees that only a number is spliced into the statement.
        query += f" LIMIT {int(limit)}"

    # A sqlite3 connection used as a context manager only ends the
    # transaction; closing() makes sure the connection itself is released.
    with closing(sqlite3.connect(db_path)) as db:
        for (simhash, _library, path, size, typ, md5) in db.execute(query):
            yield ((path, size, typ, md5.hex()), int(simhash))
def _iter_crxfile_rows(db, extension_limit):
    """Yield the JavaScript crxfile rows of each extension's newest entry."""
    extensions = islice(
        db.execute("select extid, max(date) as date from extension group by extid order by extid"),
        extension_limit)
    for (extid, date) in extensions:
        for (crx_etag,) in db.execute("select crx_etag from extension where extid=? and date=? order by crx_etag", (extid, date)):
            for (path, md5, typ, simhash) in db.execute("select path, md5, typ, simhash from crxfile where crx_etag=? and simhash is not null and path like '%.js' order by path, md5, typ", (crx_etag,)):
                yield (extid, date, crx_etag, path, md5, typ, simhash)


def get_crxfile_simhashes(db_path, extension_limit=None, crxfile_limit=None):
    """Yield ``((extid, date, crx_etag, path, md5_hex, typ, size), simhash)``.

    For every extension in the SQLite database at ``db_path`` the entry with
    the newest date is selected; each of its ``.js`` files that has a simhash
    is reported once per matching ``libdet`` row of at least 1024 bytes.

    extension_limit -- only consider the first N extensions (ordered by extid)
    crxfile_limit   -- only read the first N crxfile rows in total; rows that
                       have no sufficiently large libdet entry still count
                       towards this limit
    """
    # A sqlite3 connection used as a context manager only ends the
    # transaction; closing() makes sure the connection itself is released.
    with closing(sqlite3.connect(db_path)) as db:
        rows = islice(_iter_crxfile_rows(db, extension_limit), crxfile_limit)
        for (extid, date, crx_etag, path, md5, typ, simhash) in rows:
            for (size,) in db.execute("select size from libdet where md5=? and typ=? and size >= 1024 order by size", (md5, typ)):
                yield ((extid, date, crx_etag, path, md5.hex(), typ, size), int(simhash))
def print_help():
    """Write the command-line usage summary to standard output."""
    usage_lines = (
        """simhashbucket [OPTION] <DB_PATH>""",
        """ -h, --help print this help text""",
        """ --limit-cdnjs <N> only retrieve N rows""",
        """ --limit-extension <N> only retrieve N rows""",
        """ --limit-crxfile <N> only retrieve N rows""",
        """ --tables <N> number of tables to use for the bucket (4 or 20 so far)""",
    )
    for usage_line in usage_lines:
        print(usage_line)
def parse_args(argv):
    """Parse the command line (without the program name).

    Returns the tuple ``(limit_cdnjs, limit_extension, limit_crxfile, tables,
    db_path)``. The limits default to ``None`` and ``tables`` defaults to 20.

    Exits with status 0 after printing the usage when ``-h``/``--help`` is
    given. Prints the usage and exits with status 2 on an unknown option, a
    non-integer option value, or when not exactly one database path is given.
    """
    # Every option other than the help flag carries an integer value.
    values = {
        "--limit-cdnjs": None,
        "--limit-extension": None,
        "--limit-crxfile": None,
        "--tables": 20,
    }

    try:
        opts, args = getopt.getopt(argv, "h", [
            "limit-cdnjs=", "limit-extension=", "limit-crxfile=", "help", "tables="])
    except getopt.GetoptError:
        print_help()
        sys.exit(2)

    for opt, arg in opts:
        if opt in ("-h", "--help"):
            print_help()
            sys.exit(0)
        try:
            values[opt] = int(arg)
        except ValueError:
            # Report a malformed number like any other usage error instead of
            # letting the traceback reach the user.
            print_help()
            sys.exit(2)

    if len(args) != 1:
        print_help()
        sys.exit(2)
    db_path = args[0]

    return (values["--limit-cdnjs"], values["--limit-extension"],
            values["--limit-crxfile"], values["--tables"], db_path)
def main(args):
    """Run the lookup described by ``args`` and print one line per result.

    The cdnjs rows serve as fingerprints and the crxfile rows as queries.
    Every tuple produced by the bucket is written to standard output with its
    elements converted to strings and separated by ``|``.
    """
    limit_cdnjs, limit_extension, limit_crxfile, tables, db_path = parse_args(args)

    bucket = SimhashBucket(
        tables,
        get_cdnjs_simhashes(db_path, limit_cdnjs),
        get_crxfile_simhashes(db_path, limit_extension, limit_crxfile),
    )
    bucket.start()

    for match in bucket:
        fields = map(str, match)
        sys.stdout.write("|".join(fields) + "\n")
if __name__ == "__main__":
    # Executed as a script: pass everything after the program name to main().
    cli_arguments = sys.argv[1:]
    main(cli_arguments)