Merge branch 'master' of logicalhacking.com:BrowserSecurity/ExtensionCrawler

Achim D. Brucker 2018-09-01 21:52:02 +01:00
commit c7419a2d9f
2 changed files with 108 additions and 63 deletions

Changed file 1 of 2:

@@ -172,12 +172,6 @@ def parse_args(argv):
 def main(argv):
     """Main function of the extension crawler."""
-    # Use a separate process which forks new worker processes. This should make sure
-    # that processes which got created after running for some time also require only
-    # little memory. Details:
-    # https://docs.python.org/3.6/library/multiprocessing.html#contexts-and-start-methods
-    multiprocessing.set_start_method("forkserver")
     today = datetime.datetime.now(datetime.timezone.utc).isoformat()
     basedir, parallel, verbose, discover, max_discover, ext_timeout, start_pystuck = parse_args(argv)
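The deleted block was the crawler's explicit choice of a multiprocessing start method; the removed comment's rationale is that a forkserver keeps workers small because they are forked from a freshly started helper process rather than from the long-running parent. After this change the crawler presumably falls back to the platform default (fork on Linux). A minimal, self-contained sketch of the setting being dropped (the Pool and the work function are illustrative, not the crawler's worker code):

import multiprocessing

def work(n):
    return n * n

if __name__ == "__main__":
    # "forkserver": a helper process is started early and stays small; new workers
    # are forked from it instead of from the (potentially large) long-running parent.
    multiprocessing.set_start_method("forkserver")
    with multiprocessing.Pool(4) as pool:
        print(pool.map(work, range(8)))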

Changed file 2 of 2:

@@ -17,12 +17,14 @@
 #
 import sys
+import os
 import getopt
-import sqlite3
-from multiprocessing import Process, Queue, set_start_method
+from multiprocessing import Process, Queue
 from itertools import islice, groupby
 from operator import itemgetter
 import heapq
+import MySQLdb
+from MySQLdb import cursors


 def unique_justseen(iterable, key=None):
@@ -32,17 +34,32 @@ def unique_justseen(iterable, key=None):
     return map(next, map(itemgetter(1), groupby(iterable, key)))


+def grouper(it, chunksize):
+    while True:
+        chunk = list(islice(it, chunksize))
+        yield chunk
+        if len(chunk) < chunksize:
+            break
+
+
+def execute(q, args=None):
+    db = MySQLdb.connect(read_default_file=os.path.expanduser("~/.my.cnf"), cursorclass=cursors.SSCursor)
+    cursor = db.cursor()
+    cursor.execute(q, args)
+    return cursor
+
+
 class SimhashTable(Process):
     STOP = "stop"

-    def __init__(self, splitter, outqueue, fp_it, q_it):
+    def __init__(self, splitter, outqueue, fp_q, query_q):
         super().__init__()
         self.outqueue = outqueue
         self.splitter = splitter
         self.table = {}
-        self.fp_it = fp_it
-        self.q_it = q_it
+        self.fp_q = fp_q
+        self.query_q = query_q

     @staticmethod
     def bit_count(n):
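Two helpers added in this hunk are easier to read outside the diff: grouper() batches an iterator into lists of at most chunksize items (an input whose length is an exact multiple of chunksize ends with one empty chunk), and execute() opens a fresh MySQL connection from the [client] section of ~/.my.cnf and returns a server-side cursor, so large result sets can be streamed row by row. A small usage sketch, assuming the definitions above are in scope and ~/.my.cnf points at the ExtensionCrawler database; the SQL text is an illustrative assumption, not a query from the commit:

# grouper: chunking behaviour of the helper added above
print(list(grouper(iter(range(5)), 2)))   # [[0, 1], [2, 3], [4]]
print(list(grouper(iter(range(4)), 2)))   # [[0, 1], [2, 3], []]  (trailing empty chunk)

# execute: the SSCursor keeps the result set on the MySQL server, so iterating the
# returned cursor streams rows without loading the whole table into memory.
for (path, simhash) in execute(
        "select path, simhash from cdnjs where simhash is not null and path like %s",
        ("%.js",)):
    print(path, simhash)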
@@ -64,25 +81,28 @@ class SimhashTable(Process):
             self.table[fp_chunk] += [fp]

     def _query(self, q):
-        q_chunk = self.get_chunk(q)
-        if q_chunk in self.table:
-            for fp in self.table[q_chunk]:
+        query_chunk = self.get_chunk(q)
+        if query_chunk in self.table:
+            for fp in self.table[query_chunk]:
                 diff = SimhashTable.bit_count(q ^ fp[1])
                 if diff < 4:
                     yield (fp, diff)

     def run(self):
-        for fp in self.fp_it:
-            self._add(fp)
-        for (q_info, q) in self.q_it:
-            for ((fp_info, fp), diff) in self._query(q):
-                self.outqueue.put((q_info, fp_info, diff))
+        for fps in iter(self.fp_q.get, SimhashTable.STOP):
+            for fp in fps:
+                self._add(fp)
+        for queries in iter(self.query_q.get, SimhashTable.STOP):
+            for (query_info, q) in queries:
+                for ((fp_info, fp), diff) in self._query(q):
+                    self.outqueue.put((query_info, fp_info, diff))
         self.outqueue.put(SimhashTable.STOP)


 class SimhashBucket(Process):
     """Implementation of http://wwwconference.org/www2007/papers/paper215.pdf"""

-    def __init__(self, nr_of_tables, fp_it, q_it):
+    def __init__(self, nr_of_tables, fp_it, query_it):
         super().__init__()
         # So far, we support the variants with 4 and 20 tables. Each element of splitters
         # describes the key for one table. The first element of the tuple indicates the number
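The docstring points at a WWW 2007 paper on simhash-based near-duplicate detection: each SimhashTable keys fingerprints by one slice ("chunk") of their bits, so a query is only compared bit-by-bit against fingerprints that collide with it on that slice, and running several tables with different slices keeps the miss probability low. A one-table sketch of the idea; the chunk choice and the sample path are illustrative, only the < 4 Hamming-distance threshold is taken from _query above:

def bit_count(n):
    return bin(n).count("1")

def get_chunk(simhash):
    # one possible splitter: use the top 16 of 64 bits as the table key
    return simhash >> 48

table = {}

def add(info, simhash):
    table.setdefault(get_chunk(simhash), []).append((info, simhash))

def query(q):
    # only fingerprints sharing the key chunk are checked bit-by-bit
    for info, fp in table.get(get_chunk(q), []):
        if bit_count(q ^ fp) < 4:          # same threshold as SimhashTable._query
            yield (info, bit_count(q ^ fp))

add("cdnjs: jquery.min.js (illustrative)", 0x0123456789ABCDEF)
print(list(query(0x0123456789ABCDEE)))    # 1 bit apart -> reported as a near-match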
@@ -105,55 +125,79 @@ class SimhashBucket(Process):
             raise Exception(f"Unsupported number of tables: {nr_of_tables}")

         self.fp_it = fp_it
-        self.q_it = q_it
+        self.query_it = query_it
         self.splitters = splitters
         self.tables = []
-        self.outqueues = [Queue(100) for _ in range(len(self.splitters))]
+        self.fp_qs = [Queue(100) for _ in range(len(self.splitters))]
+        self.query_qs = [Queue(100) for _ in range(len(self.splitters))]
+        self.out_qs = [Queue(100) for _ in range(len(self.splitters))]
+
+    @staticmethod
+    def broadcast(it, qs, chunksize=1):
+        for x in grouper(it, chunksize):
+            for q in qs:
+                q.put(x)
+        for q in qs:
+            q.put(SimhashTable.STOP)

     def run(self):
-        self.tables = [SimhashTable(splitter, outqueue, self.fp_it, self.q_it) for (outqueue, splitter) in zip(self.outqueues, self.splitters)]
+        self.tables = [SimhashTable(*args) for args in zip(self.splitters, self.out_qs, self.fp_qs, self.query_qs)]
         for tbl in self.tables:
             tbl.start()
+        SimhashBucket.broadcast(self.fp_it, self.fp_qs, 100)
+        SimhashBucket.broadcast(self.query_it, self.query_qs, 100)
         for tbl in self.tables:
             tbl.join()

     def __iter__(self):
-        return unique_justseen(heapq.merge(*[iter(q.get, SimhashTable.STOP) for q in self.outqueues]))
+        return unique_justseen(heapq.merge(*[iter(q.get, SimhashTable.STOP) for q in self.out_qs]))


-def get_cdnjs_simhashes(db_path, limit=None):
-    with sqlite3.connect(db_path) as db:
-        for (simhash, library, path, size, typ, md5) in db.execute("select simhash, library, path, size, typ, md5 from cdnjs where "
-                                                                    "simhash IS NOT NULL AND path like '%.js' and "
-                                                                    "HEX(md5) <> 'd41d8cd98f00b204e9800998ecf8427e' order by path, size, typ, md5" +
-                                                                    (f" LIMIT {int(limit)}" if limit is not None else "")):
-            yield ((path, size, typ, md5.hex()), int(simhash))
+def get_cdnjs_simhashes(limit=None):
+    for (simhash, path, typ, library, version) in execute((
+            "select simhash, path, typ, library, version from "
+            "cdnjs where "
+            "simhash IS NOT NULL AND path like '%.js' and "
+            "HEX(md5) <> 'd41d8cd98f00b204e9800998ecf8427e' "
+            "order by path, typ {limit}")
+            .format(
+                limit=f"limit {limit}" if limit else ""
+            )):
+        # Whatever gets returned here must be sorted
+        yield ((path, typ, library, version), int(simhash))


-def get_crxfile_simhashes(db_path, extension_limit=None, crxfile_limit=None):
-    with sqlite3.connect(db_path) as db:
-        for (extid, date) in islice(db.execute("select extid, max(date) as date from extension group by extid order by extid"), extension_limit):
-            for (crx_etag,) in db.execute("select crx_etag from extension where extid=? and date=? order by crx_etag", (extid, date)):
-                for (path, md5, typ, simhash) in db.execute("select path, md5, typ, simhash from crxfile where crx_etag=? and simhash is not null and path like '%.js' order by path, md5, typ", (crx_etag,)):
-                    for (size,) in db.execute("select size from libdet where md5=? and typ=? and size >= 1024 order by size", (md5, typ)):
-                        yield ((extid, date, crx_etag, path, md5.hex(), typ, size), int(simhash))
+def get_crxfile_simhashes(extension_limit=None, crxfile_limit=None):
+    for (extid, date, crx_etag, path, typ, simhash) in execute((
+            "select extid, date, crx_etag, path, typ, simhash from "
+            "(select * from extension_most_recent order by extid, date {extension_limit}) extension join "
+            "(select * from crxfile order by crx_etag, path, typ {crxfile_limit}) crxfile using (crx_etag) "
+            "join libdet using (md5, typ) "
+            "where simhash is not null and path like '%.js' and size >= 1024")
+            .format(
+                extension_limit=f"limit {extension_limit}" if extension_limit else "",
+                crxfile_limit=f"limit {crxfile_limit}" if crxfile_limit else ""
+            )):
+        # Whatever gets returned here must be sorted
+        yield ((extid, date, crx_etag, path, typ), int(simhash))


 def print_help():
-    print("""simhashbucket [OPTION] <DB_PATH>""")
+    print("""simhashbucket [OPTIONS]""")
     print(""" -h, --help print this help text""")
-    print(""" --limit-cdnjs <N> only retrieve N rows""")
-    print(""" --limit-extension <N> only retrieve N rows""")
-    print(""" --limit-crxfile <N> only retrieve N rows""")
-    print(""" --tables <N> number of tables to use for the bucket (4 or 20 so far)""")
+    print(""" --limit-cdnjs <N> only retrieve N rows, default: all""")
+    print(""" --limit-extension <N> only retrieve N rows, default: all""")
+    print(""" --limit-crxfile <N> only retrieve N rows, default: all""")
+    print(""" --tables <N> number of tables to use for the bucket (4 or 20 so far), default: 4""")


 def parse_args(argv):
     limit_cdnjs = None
     limit_extension = None
     limit_crxfile = None
-    tables = 20
+    tables = 4
     try:
         opts, args = getopt.getopt(argv, "h", [
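The two "Whatever gets returned here must be sorted" comments above are load-bearing: SimhashBucket.__iter__ merges the per-table output queues with heapq.merge, which only yields a globally sorted stream if every input is already sorted, and unique_justseen (the itertools recipe defined at the top of the file) then collapses the duplicate matches that several tables report for the same query/fingerprint pair. A small self-contained sketch of that combination (the tuples are made up):

import heapq
from itertools import groupby
from operator import itemgetter

def unique_justseen(iterable, key=None):
    # itertools recipe: keep only the first of each run of equal elements
    return map(next, map(itemgetter(1), groupby(iterable, key)))

# Two tables report (query_info, fp_info, diff) matches; both streams are sorted,
# and both tables happen to find the same match for "ext2".
table_a = [("ext1", "lib1", 2), ("ext2", "lib3", 0)]
table_b = [("ext2", "lib3", 0), ("ext3", "lib2", 3)]

print(list(unique_justseen(heapq.merge(table_a, table_b))))
# [('ext1', 'lib1', 2), ('ext2', 'lib3', 0), ('ext3', 'lib2', 3)]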
@@ -161,6 +205,7 @@ def parse_args(argv):
     except getopt.GetoptError:
         print_help()
         sys.exit(2)

+    try:
         for opt, arg in opts:
             if opt == "--limit-cdnjs":
                 limit_cdnjs = int(arg)
@@ -170,22 +215,28 @@ def parse_args(argv):
                 limit_crxfile = int(arg)
             elif opt == "--tables":
                 tables = int(arg)
+            elif opt in ["-h", "--help"]:
+                print_help()
+                sys.exit(0)
+    except ValueError:
+        print("Arguments to int options must be an int!", file=sys.stderr)
+        print_help()
+        sys.exit(2)

-    if len(args) != 1:
-        print_help()
-        sys.exit(2)
-
-    db_path = args[0]
-    return limit_cdnjs, limit_extension, limit_crxfile, tables, db_path
+    if len(args) != 0:
+        print_help()
+        sys.exit(2)
+
+    return limit_cdnjs, limit_extension, limit_crxfile, tables


 def main(args):
-    limit_cdnjs, limit_extension, limit_crxfile, tables, db_path = parse_args(args)
+    limit_cdnjs, limit_extension, limit_crxfile, tables = parse_args(args)

-    fp_it = get_cdnjs_simhashes(db_path, limit_cdnjs)
-    q_it = get_crxfile_simhashes(db_path, limit_extension, limit_crxfile)
+    fp_it = get_cdnjs_simhashes(limit_cdnjs)
+    query_it = get_crxfile_simhashes(limit_extension, limit_crxfile)

-    bucket = SimhashBucket(tables, fp_it, q_it)
+    bucket = SimhashBucket(tables, fp_it, query_it)
     bucket.start()

     for tup in bucket:
         sys.stdout.write("|".join([str(x) for x in tup]) + "\n")
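With the <DB_PATH> argument removed, the tool is now configured entirely through the options shown in print_help plus the MySQL credentials in ~/.my.cnf, and it writes one match per line to stdout: the crxfile tuple, the matching cdnjs tuple, and the bit difference, separated by "|". A typical invocation might look like the following (script name and redirection are illustrative, not taken from the commit):

python3 simhashbucket.py --tables 20 --limit-cdnjs 100000 > matches.psv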