tools: haskell-translator: use a Def class instead of a dict

This commit is contained in:
Corey Richardson 2016-01-29 15:36:15 +11:00
parent 307fa64568
commit a8aea960f3
No known key found for this signature in database
GPG Key ID: 990278AD76243314
2 changed files with 110 additions and 92 deletions

View File

@ -33,6 +33,22 @@ class Call(object):
self.body = False
class Def(object):
def __init__(self):
self.type = None
self.defined = None
self.body = None
self.line = None
self.sig = None
self.instance_proofs = []
self.instance_extras = []
self.comments = []
self.primrec = None
self.deriving = []
self.instance_defs = {}
def parse(call):
"""Parses a file."""
set_global(call)
@ -131,7 +147,7 @@ def top_transform(input):
# Forget about the comments for now
# defs_plus_comments = [(d['line'], d) for d in defs] + comments
# defs_plus_comments = [d.line, d) for d in defs] + comments
# defs_plus_comments.sort()
# defs = []
# prev_comments = []
@ -140,7 +156,7 @@ def top_transform(input):
# prev_comments.append(term[2])
# else:
# d = term[1]
# d['comments'] = prev_comments
# d.comments = prev_comments
# defs.append(d)
# prev_comments = []
@ -158,7 +174,7 @@ def get_lines(defs, call):
all the potential output generated at parse time."""
if call.restr:
defs = [d for d in defs if d['type'] == 'comments'
defs = [d for d in defs if d.type == 'comments'
or call.restr(d)]
output = []
@ -230,11 +246,11 @@ def group_defs(defs):
defgroups = []
defined = ''
for d in defs:
this_defines = d['defined']
if d['type'] != 'definitions':
this_defines = d.defined
if d.type != 'definitions':
this_defines = ''
if this_defines == defined and this_defines:
defgroups[-1]['body'].extend(d['body'])
defgroups[-1].body.extend(d.body)
else:
defgroups.append(d)
defined = this_defines
@ -251,7 +267,9 @@ def create_def(elt):
def create_def_2(line, children, n):
d = {'body': [(line, children)], 'line': n}
d = Def()
d.body = [(line, children)]
d.line = n
lead = line.split(None, 3)
if lead[0] in ['import', 'module', 'class']:
return
@ -265,8 +283,8 @@ def create_def_2(line, children, n):
type = 'definitions'
defined = lead[0]
d['type'] = type
d['defined'] = defined
d.type = type
d.defined = defined
return d
@ -284,28 +302,28 @@ def defs_transform(d):
may include its type signature, and may include the special
case of multiple definitions."""
# the first tokens of the first line in the first definition
if d['type'] == 'newtype':
if d.type == 'newtype':
return newtype_transform(d)
elif d['type'] == 'instance':
elif d.type == 'instance':
return instance_transform(d)
lead = d['body'][0][0].split(None, 2)
lead = d.body[0][0].split(None, 2)
if lead[1] == '::':
d['sig'] = type_sig_transform(d['body'][0])
d['body'] = d['body'][1:]
d.sig = type_sig_transform(d.body[0])
d.body.pop(0)
if d['defined'] in primrecs:
if d.defined in primrecs:
return primrec_transform(d)
if len(d['body']) > 1:
d['body'] = pattern_match_transform(d['body'])
if len(d.body) > 1:
d.body = pattern_match_transform(d.body)
if len(d['body']) == 0:
if len(d.body) == 0:
print()
print(d)
assert 0
d['body'] = body_transform(d['body'], d['defined'], d.get('sig', None))
d.body = body_transform(d.body, d.defined, d.sig)
return d
@ -313,33 +331,33 @@ def def_lines(d, call):
"""Produces the set of lines associated with a definition."""
if call.all_bits:
L = []
if 'comments' in d:
L.extend(flatten_tree(d['comments']))
if d.comments:
L.extend(flatten_tree(d.comments))
L.append('')
if d['type'] == 'definitions':
if d.type == 'definitions':
L.append('definition')
if 'sig' in d:
L.extend(flatten_tree([d['sig']]))
if d.sig:
L.extend(flatten_tree([d.sig]))
L.append('where')
L.extend(flatten_tree(d['body']))
elif d['type'] == 'newtype':
L.extend(flatten_tree(d['body']))
if 'instance_proofs' in d:
L.extend(flatten_tree(d['instance_proofs']))
L.extend(flatten_tree(d.body))
elif d.type == 'newtype':
L.extend(flatten_tree(d.body))
if d.instance_proofs:
L.extend(flatten_tree(d.instance_proofs))
L.append('')
if 'instance_extras' in d:
L.extend(flatten_tree(d['instance_extras']))
if d.instance_extras:
L.extend(flatten_tree(d.instance_extras))
L.append('')
return L
if call.instanceproofs:
if not call.bodies_only:
instance_proofs = flatten_tree(d.get('instance_proofs', []))
instance_proofs = flatten_tree(d.instance_proofs)
else:
instance_proofs = []
if not call.decls_only:
instance_extras = flatten_tree(d.get('instance_extras', []))
instance_extras = flatten_tree(d.instance_extras)
else:
instance_extras = []
@ -350,13 +368,13 @@ def def_lines(d, call):
if call.body:
return get_lambda_body_lines(d)
comments = d.get('comments', [])
comments = d.comments
try:
typesig = flatten_tree([d['sig']])
typesig = flatten_tree([d.sig])
except:
typesig = []
body = flatten_tree(d['body'])
type = d['type']
body = flatten_tree(d.body)
type = d.type
if type == 'definitions':
if call.decls_only:
@ -365,9 +383,9 @@ def def_lines(d, call):
else:
return []
elif call.bodies_only:
if 'sig' in d:
defname = '%s_def' % d['defined']
if 'primrec' in d:
if d.sig:
defname = '%s_def' % d.defined
if d.primrec:
print('warning body-only primrec:')
print(body[0])
return comments + ['primrec'] + body
@ -375,7 +393,7 @@ def def_lines(d, call):
else:
return comments + ['definition'] + body
else:
if 'primrec' in d:
if d.primrec:
return comments + ['primrec'] + typesig \
+ ['where'] + body
if typesig:
@ -579,10 +597,10 @@ def newtype_transform(d):
options are divided with | and each of whose options has named
elements, and forms a datatype statement and definitions for
the named extractors and their update functions."""
if len(d['body']) != 1:
if len(d.body) != 1:
print('--- newtype long body ---')
print(d)
[(line, children)] = d['body']
[(line, children)] = d.body
if children and children[-1][0].lstrip().startswith('deriving'):
l = reduce_to_single_line(children[-1])
@ -590,7 +608,7 @@ def newtype_transform(d):
r = re.compile(r"[,\s\(\)]+")
bits = r.split(l)
bits = [bit for bit in bits if bit and bit != 'deriving']
d['deriving'] = bits
d.deriving = bits
line = reduce_to_single_line((line, children))
@ -599,11 +617,11 @@ def newtype_transform(d):
line = bits[1]
bits = line.split('=', 1)
header = type_conv(bits[0].strip())
d['typename'] = header
d['typedeps'] = set()
d.typename = header
d.typedeps = set()
if len(bits) == 1:
# line of form 'data Blah' introduces unknown type?
d['body'] = [('typedecl %s' % header, [])]
d.body = [('typedecl %s' % header, [])]
all_type_arities[header] = [] # HACK
return d
line = bits[1]
@ -622,7 +640,7 @@ def typename_transform(line, header, d):
try:
[oldtype] = line.split()
except:
sys.stderr.write('Warning: type assignment with parameters not supported %s\n' % d['body'])
sys.stderr.write('Warning: type assignment with parameters not supported %s\n' % d.body)
call.bad_type_assignment = True
return
if oldtype.startswith('Data.Word.Word'):
@ -631,13 +649,13 @@ def typename_transform(line, header, d):
oldtype = type_conv(oldtype)
bits = oldtype.split()
for bit in bits:
d['typedeps'].add(bit)
d.typedeps.add(bit)
lines = [
'type_synonym %s = "%s"' % (header, oldtype),
# translations (* TYPE 1 *)',
# "%s" <=(type) "%s"' % (oldtype, header)
]
d['body'] = [(line, []) for line in lines]
d.body = [(line, []) for line in lines]
return d
@ -668,7 +686,7 @@ def simple_newtype_transform(line, header, d):
if ' ' in typename:
typename = '"%s"' % typename
l = l + ' ' + typename
d['typedeps'].add(typename)
d.typedeps.add(typename)
lines.append(l)
arities.append((str(bits[0]), len(bits[1:])))
@ -676,7 +694,7 @@ def simple_newtype_transform(line, header, d):
if list((dict(arities)).values()) == [1] and header not in dontwrap:
return type_wrapper_type(header, last_lhs, last_rhs, d)
d['body'] = [('datatype %s =' % header, [(line, []) for line in lines])]
d.body = [('datatype %s =' % header, [(line, []) for line in lines])]
set_instance_proofs(header, arities, d)
@ -704,7 +722,7 @@ def named_newtype_transform(line, header, d):
else:
l = l + ' "' + type + '"'
for bit in type.split():
d['typedeps'].add(bit)
d.typedeps.add(bit)
lines.append(l)
names = {}
@ -752,7 +770,7 @@ def named_newtype_transform(line, header, d):
set_instance_proofs(header, arities, d)
d['body'] = [('datatype %s =' % header, [(line, []) for line in lines])]
d.body = [('datatype %s =' % header, [(line, []) for line in lines])]
return d
@ -866,8 +884,8 @@ def named_constructor_check(name, map, header):
def type_wrapper_type(header, cons, rhs, d, decons=None):
if '\\<Rightarrow>' in d['typedeps']:
d['body'] = [('(* type declaration of %s omitted *)' % header, [])]
if '\\<Rightarrow>' in d.typedeps:
d.body = [('(* type declaration of %s omitted *)' % header, [])]
return d
lines = [
'type_synonym %s = "%s"' % (header, rhs),
@ -900,12 +918,12 @@ def type_wrapper_type(header, cons, rhs, d, decons=None):
lines.extend(named_constructor_translation(cons, [(decons, decons_type)
], header))
d['body'] = [(line, []) for line in lines]
d.body = [(line, []) for line in lines]
return d
def instance_transform(d):
[(line, children)] = d['body']
[(line, children)] = d.body
bits = line.split(None, 3)
assert bits[0] == 'instance'
classname = bits[1]
@ -933,11 +951,11 @@ def instance_transform(d):
defs_dict = {}
for d2 in defs:
if d2 is not None:
defs_dict[d2['defined']] = d2
d['instance_defs'] = defs_dict
d['deriving'] = [classname]
defs_dict[d2.defined] = d2
d.instance_defs = defs_dict
d.deriving = [classname]
if typename not in all_type_arities:
sys.stderr.write('FAIL: attempting %s\n' % d['defined'])
sys.stderr.write('FAIL: attempting %s\n' % d.defined)
sys.stderr.write('(typename %r)\n' % typename)
sys.stderr.write('when reading %s\n' % filename)
sys.stderr.write('but class not defined yet\n')
@ -959,7 +977,7 @@ def set_instance_proofs(header, constructor_arities, d):
exs = []
canonical = list(enumerate(constructor_arities))
classes = d.get('deriving', [])
classes = d.deriving
instance_proof_fns = set(
sorted((instance_proof_table[classname] for classname in classes),
key=lambda x: x.order))
@ -968,17 +986,17 @@ def set_instance_proofs(header, constructor_arities, d):
pfs.extend(npfs)
exs.extend(nexs)
if d['type'] == 'newtype' and len(canonical) == 1 and False:
if d.type == 'newtype' and len(canonical) == 1 and False:
[(cons, n)] = constructor_arities
if n == 1:
pfs.extend(finite_instance_proofs(header, cons))
if pfs:
lead = '(* %s instance proofs *)' % header
d['instance_proofs'] = [(lead, [(line, []) for line in pfs])]
d.instance_proofs = [(lead, [(line, []) for line in pfs])]
if exs:
lead = '(* %s extra instance defs *)' % header
d['instance_extras'] = [(lead, [(line, []) for line in exs])]
d.instance_extras = [(lead, [(line, []) for line in exs])]
def finite_instance_proofs(header, cons):
@ -1015,11 +1033,11 @@ def storable_instance_proofs(header, canonical, d):
proofs.append('')
proofs.append('instance %s :: storable ..' % header)
defs = d.get('instance_defs', {})
defs = d.instance_defs
extradefs.append('')
if 'objBits' in defs:
extradefs.append('definition')
body = flatten_tree(defs['objBits']['body'])
body = flatten_tree(defs['objBits'].body)
bits = body[0].split('objBits')
assert bits[0].strip() == '"'
if bits[1].strip().startswith('_'):
@ -1032,7 +1050,7 @@ def storable_instance_proofs(header, canonical, d):
extradefs.append('')
if 'makeObject' in defs:
extradefs.append('definition')
body = flatten_tree(defs['makeObject']['body'])
body = flatten_tree(defs['makeObject'].body)
bits = body[0].split('makeObject')
assert bits[0].strip() == '"'
body[0] = ' makeObject_%s: "(makeObject :: %s) %s' \
@ -1042,7 +1060,7 @@ def storable_instance_proofs(header, canonical, d):
extradefs.extend(['', 'definition', ])
if 'loadObject' in defs:
extradefs.append(' loadObject_%s:' % header)
extradefs.extend(flatten_tree(defs['loadObject']['body']))
extradefs.extend(flatten_tree(defs['loadObject'].body))
else:
extradefs.extend([
' loadObject_%s[simp]:' % header,
@ -1053,7 +1071,7 @@ def storable_instance_proofs(header, canonical, d):
extradefs.extend(['', 'definition', ])
if 'updateObject' in defs:
extradefs.append(' updateObject_%s:' % header)
body = flatten_tree(defs['updateObject']['body'])
body = flatten_tree(defs['updateObject'].body)
bits = body[0].split('updateObject')
assert bits[0].strip() == '"'
bits = bits[1].split(None, 1)
@ -1083,11 +1101,11 @@ def pspace_storable_instance_proofs(header, canonical, d):
proofs.append(
' auto simp: projectKO_opts_defs split: kernel_object.splits arch_kernel_object.splits)')
defs = d.get('instance_defs', {})
defs = d.instance_defs
extradefs.append('')
if 'objBits' in defs:
extradefs.append('definition')
body = flatten_tree(defs['objBits']['body'])
body = flatten_tree(defs['objBits'].body)
bits = body[0].split('objBits')
assert bits[0].strip() == '"'
if bits[1].strip().startswith('_'):
@ -1100,7 +1118,7 @@ def pspace_storable_instance_proofs(header, canonical, d):
extradefs.append('')
if 'makeObject' in defs:
extradefs.append('definition')
body = flatten_tree(defs['makeObject']['body'])
body = flatten_tree(defs['makeObject'].body)
bits = body[0].split('makeObject')
assert bits[0].strip() == '"'
body[0] = ' makeObject_%s: "(makeObject :: %s) %s' \
@ -1110,7 +1128,7 @@ def pspace_storable_instance_proofs(header, canonical, d):
extradefs.extend(['', 'definition', ])
if 'loadObject' in defs:
extradefs.append(' loadObject_%s:' % header)
extradefs.extend(flatten_tree(defs['loadObject']['body']))
extradefs.extend(flatten_tree(defs['loadObject'].body))
else:
extradefs.extend([
' loadObject_%s[simp]:' % header,
@ -1121,7 +1139,7 @@ def pspace_storable_instance_proofs(header, canonical, d):
extradefs.extend(['', 'definition', ])
if 'updateObject' in defs:
extradefs.append(' updateObject_%s:' % header)
body = flatten_tree(defs['updateObject']['body'])
body = flatten_tree(defs['updateObject'].body)
bits = body[0].split('updateObject')
assert bits[0].strip() == '"'
bits = bits[1].split(None, 1)
@ -2478,11 +2496,11 @@ def detect_recursion(body):
def primrec_transform(d):
sig = d.get('sig', None)
defn = d['defined']
sig = d.sig
defn = d.defined
body = []
is_not_first = False
for (l, c) in d['body']:
for (l, c) in d.body:
[(l, c)] = body_transform([(l, c)], defn, sig, nopattern=True)
if is_not_first:
l = "| " + l
@ -2495,8 +2513,8 @@ def primrec_transform(d):
(l, c) = remove_trailing_string('"', (l, c))
(l, c) = add_trailing_string(')"', (l, c))
body.append((l, c))
d['primrec'] = True
d['body'] = body
d.primrec = True
d.body = body
return d
@ -2601,9 +2619,9 @@ def pattern_match_transform(body):
def get_lambda_body_lines(d):
"""Returns lines equivalent to the body of the function as
a lambda expression."""
fn = d['defined']
fn = d.defined
[(line, children)] = d['body']
[(line, children)] = d.body
line = line[1:]
# find \<equiv> in first or 2nd line
@ -2671,27 +2689,27 @@ def has_trailing_string(s, xxx_todo_changeme11):
def ensure_type_ordering(defs):
typedefs = [d for d in defs if d['type'] == 'newtype']
other = [d for d in defs if d['type'] != 'newtype']
typedefs = [d for d in defs if d.type == 'newtype']
other = [d for d in defs if d.type != 'newtype']
final_typedefs = []
while typedefs:
try:
i = 0
deps = typedefs[i]['typedeps']
deps = typedefs[i].typedeps
while 1:
for j, term in enumerate(typedefs):
if term['typename'] in deps:
if term.typename in deps:
break
else:
break
i = j
deps = typedefs[i]['typedeps']
deps = typedefs[i].typedeps
final_typedefs.append(typedefs.pop(i))
except Exception as e:
print('Exception hit ordering types:')
for td in typedefs:
print(' - %s' % td['typename'])
print(' - %s' % td.typename)
raise e
return final_typedefs + other

View File

@ -63,16 +63,16 @@ for line in instructions:
if 'ONLY' in bits:
n = bits.index('ONLY')
m = set(bits[n + 1:])
call.restr = lambda x: x['defined'] in m
call.restr = lambda x: x.defined in m
elif 'NOT' in bits:
n = bits.index('NOT')
m = set(bits[n + 1:])
call.restr = lambda x: not x['defined'] in m
call.restr = lambda x: not x.defined in m
elif 'BODY' in bits:
call.body = True
assert bits[-2] == 'BODY'
fn = bits[-1]
call.restr = lambda x: x['defined'] == fn
call.restr = lambda x: x.defined == fn
try:
parsed = lhs_pars.parse(call)