WIP: autocorres: refactor add_defs

This commit is contained in:
Japheth Lim 2016-06-28 17:24:44 +10:00
parent 6c35cf176a
commit 2f53afd90b
5 changed files with 223 additions and 336 deletions

View File

@ -430,4 +430,117 @@ fun define_funcs
(results, accum, ctxt)
end
(* Utility for doing conversions in parallel.
* The conversion of each function f should depend only on the previous
* define phase for f (which necessarily also includes f's callees). *)
type convert_result = {
body: term, (* new body *)
proof: thm, (* corres thm *)
rec_callees: symset, (* minimal rec_callees after translation *)
callee_consts: term Symtab.table, (* assumed frees for other callees *)
arg_frees: (string * typ) list, (* fixed argument frees, including measure *)
traces: (string * AutoCorresData.Trace) list (* traces *)
}
fun par_convert
(* Worker: lthy -> function_infos for func and callees -> fn_name -> results *)
(convert: local_theory -> FunctionInfo2.function_info Symtab.table ->
string -> convert_result)
(prev_results: FunctionInfo2.phase_results)
(* Return converted functions in recursive groups.
* The groups are tagged with fn_infos from prev_results to identify them. *)
: (FunctionInfo2.function_info Symtab.table * convert_result Symtab.table) FSeq.fseq =
(* Knowing that prev_results is in topological order,
* we accumulate its function_infos, which will be a superset
* of the callee infos that each conversion requires. *)
FSeq.fold_map (fn fn_infos_accum => fn (lthy, fn_infos) => let
val fn_infos_accum = Symtab.merge (K false) (fn_infos_accum, fn_infos);
(* Convert fn_infos in parallel, but join the results right away.
* This is fine because we will define them together. *)
val conv_results =
Symtab.dest fn_infos
|> Par_List.map (fn (f, _) =>
(f, convert lthy fn_infos_accum f))
|> Symtab.make;
in ((fn_infos, conv_results), fn_infos_accum) end)
Symtab.empty prev_results;
(* Utility for defining functions.
*
* Definitions update the theory sequentially.
* Each definition step produces a lthy that contains the current function
* group, and can immediately be used in the next conversion phase for
* those functions. Hence we return the intermediate lthys as futures.
*
* The actual recursive function groups may be finer-grained than in
* converted_groups, as function calls can be removed by dead code
* elimination and other transformations. Hence we detect the actual
* function groups before defining them.
*
* FIXME: this currently discards traces *)
fun define_funcs_sequence
(lthy: local_theory)
(define_worker: local_theory ->
(* previous infos for functions *)
FunctionInfo2.function_info Symtab.table ->
(* new infos for callees *)
FunctionInfo2.function_info Symtab.table ->
(* data for functions *)
convert_result Symtab.table ->
(* new infos for functions *)
FunctionInfo2.function_info Symtab.table * local_theory)
(* accumulator, initially empty *)
(defined_so_far: FunctionInfo2.function_info Symtab.table)
(converted_groups: (FunctionInfo2.function_info Symtab.table *
convert_result Symtab.table) FSeq.fseq)
: FunctionInfo2.phase_results =
FSeq.mk (fn () =>
case FSeq.uncons converted_groups of
NONE => NONE
| SOME ((prev_infos, conv_group), remaining_groups) => SOME let
val (call_graph, f_convs) =
Symtab.dest conv_group
|> map (fn (f, result) =>
((f, #rec_callees result), (f, result)))
|> split_list;
val f_convs = Symtab.make f_convs;
(* Split this recursive group prior to defining the functions.
* Function calls can be removed by dead code elimination and other
* transformations, which could cause the group to split. *)
val f_callees = Symtab.make (map (apsnd Symset.dest) call_graph);
(* Add each function to its own callees to get a complete inverse *)
val f_callers = flip_symtab (Symtab.map cons f_callees);
val topo_sorted_functions =
Topo_Sort.topo_sort {
cmp = String.compare,
graph = Symtab.lookup f_callees #> the,
converse = Symtab.lookup f_callers #> the
} (Symtab.keys f_callees |> sort String.compare)
|> map Symset.make;
val f_convss =
topo_sorted_functions |> map (fn group =>
map (fn f => (f, the (Symtab.lookup f_convs f)))
(Symset.dest group)
|> Symtab.make);
(* Define each function group and append it to the result sequence. *)
fun add_subgroup_defs lthy defined_so_far [] =
define_funcs_sequence lthy define_worker defined_so_far remaining_groups
| add_subgroup_defs lthy defined_so_far (f_convs :: f_convss) =
FSeq.fcons (fn () => let
val (new_defs, lthy') =
define_worker lthy prev_infos defined_so_far f_convs;
(* Update (non-recursive) callees *)
val new_defs = FunctionInfo2.recalc_base_callees defined_so_far new_defs;
val defined_so_far' = Symtab.merge (K false) (defined_so_far, new_defs);
in ((lthy', new_defs),
add_subgroup_defs lthy' defined_so_far' f_convss)
end);
in (* a bit ugly -- may need to tweak the FSeq interface *)
the (FSeq.uncons (add_subgroup_defs lthy defined_so_far f_convss)) end);
end

View File

@ -676,7 +676,7 @@ let
val is_function_lifted = String_Memo.memo is_function_lifted;
(* Convert to new heap format. *)
fun convert lthy l2_infos f =
fun convert lthy l2_infos f: AutoCorresUtil2.convert_result =
let
val f_l2_info = the (Symtab.lookup l2_infos f);
@ -819,12 +819,14 @@ let
val callee_consts =
callee_terms |> map (fn (callee, (_, const, _)) => (callee, const)) |> Symtab.make;
in
(f_body,
Morphism.thm export_thm thm,
rec_callees,
callee_consts,
map dest_Free (measure_var :: arg_frees),
(if member (op =) trace_funcs f then [("HL", AutoCorresData.RuleTrace trace)] else []) @ opt_traces)
{ body = f_body,
proof = Morphism.thm export_thm thm,
rec_callees = rec_callees,
callee_consts = callee_consts,
arg_frees = map dest_Free (measure_var :: arg_frees),
traces = (if member (op =) trace_funcs f
then [("HL", AutoCorresData.RuleTrace trace)] else []) @ opt_traces
}
end
(* Define a previously-converted function (or recursive function group).
@ -833,13 +835,13 @@ let
(lthy: local_theory)
(l2_infos: FunctionInfo2.function_info Symtab.table)
(hl_callees: FunctionInfo2.function_info Symtab.table)
(* name, raw body, callee consts, corres thm, arg frees *)
(funcs: (string * (term * term Symtab.table * thm * (string * typ) list)) list)
(funcs: AutoCorresUtil2.convert_result Symtab.table)
: FunctionInfo2.function_info Symtab.table * local_theory = let
val funcs' = funcs |>
map (fn (name, def as (body, callee_consts, thm, frees)) =>
(* FIXME: abstract_fn_body should really be moved into define_funcs *)
(name, ((AutoCorresUtil2.abstract_fn_body l2_infos (name, def), thm, frees))));
val funcs' = Symtab.dest funcs |>
map (fn (name, {body, proof, callee_consts, arg_frees, ...}) =>
(name, ((AutoCorresUtil2.abstract_fn_body l2_infos
(name, (body, callee_consts, proof, arg_frees))),
proof, arg_frees)));
val (new_thms, (), lthy') =
AutoCorresUtil2.define_funcs FunctionInfo2.HL filename l2_infos
hl_function_name
@ -860,81 +862,18 @@ let
end) new_thms;
in (new_infos, lthy') end;
(* All conversions can run in parallel.
* Each conversion depends only on the corresponding L2 define phase
* (which necessarily also includes L2 callees). *)
val converted_groups =
FSeq.fold_map (fn l2_infos_accum => fn (lthy, l2_infos) => let
val l2_infos_accum' = Symtab.merge (K false) (l2_infos_accum, l2_infos);
(* Convert l2_infos in parallel, but join the results right away.
* This is fine because we will define them together. *)
val conv_results =
Symtab.dest l2_infos
|> Par_List.map (fn (f, _) => (f, convert lthy l2_infos_accum' f))
|> Symtab.make;
in ((l2_infos, conv_results), l2_infos_accum') end)
Symtab.empty l2_results;
(* Do conversions in parallel. *)
val converted_groups = AutoCorresUtil2.par_convert convert l2_results;
(* Definitions update lthy sequentially.
* Each definition step produces a lthy that has a prefix of the update sequence,
* and can be used in an L2 translation that depends only on that prefix.
* Hence we return the intermediate lthys as futures. *)
fun add_defs lthy defined_so_far converted_groups = FSeq.mk (fn () =>
case FSeq.uncons converted_groups of
NONE => NONE
| SOME ((l2_infos, conv_group), remaining_groups) => SOME let
val (call_graph, f_convs) =
Symtab.dest conv_group
|> map (fn (f, (hl_body, corres_thm, rec_callees, callee_consts, arg_frees, traces)) =>
(* FIXME: return traces *)
((f, rec_callees), (f, (hl_body, callee_consts, corres_thm, arg_frees))))
|> split_list;
val f_convs = Symtab.make f_convs;
(* Split this recursive group prior to defining the functions.
* Function calls can be removed by dead code elimination and other
* transformations, which could cause the group to split. *)
val f_callees = Symtab.make (map (apsnd Symset.dest) call_graph);
(* Add each function to its own callees to get a complete inverse *)
val f_callers = flip_symtab (Symtab.map cons f_callees);
val topo_sorted_functions =
Topo_Sort.topo_sort {
cmp = String.compare,
graph = Symtab.lookup f_callees #> the,
converse = Symtab.lookup f_callers #> the
} (Symtab.keys f_callees |> sort String.compare)
|> map Symset.make;
val f_convss =
topo_sorted_functions |> map (fn group =>
map (fn f => (f, the (Symtab.lookup f_convs f)))
(Symset.dest group)
|> Symtab.make);
(* Define each function group and append it to the result sequence. *)
fun add_subgroup_defs lthy defined_so_far [] =
add_defs lthy defined_so_far remaining_groups
| add_subgroup_defs lthy defined_so_far (f_convs :: f_convss) =
FSeq.fcons (fn () => let
val (new_defs, lthy') =
define lthy l2_infos defined_so_far (Symtab.dest f_convs);
(* Update (non-recursive) callees *)
val new_defs = FunctionInfo2.recalc_base_callees defined_so_far new_defs;
val defined_so_far' = Symtab.merge (K false) (defined_so_far, new_defs);
in ((lthy', new_defs),
add_subgroup_defs lthy' defined_so_far' f_convss)
end);
in (* a bit ugly -- may need to tweak the FSeq interface *)
the (FSeq.uncons (add_subgroup_defs lthy defined_so_far f_convss)) end);
(* Sequence of intermediate states: (lthy, new_defs) *)
(* Sequence of new function_infos and intermediate lthys *)
val def_results = FSeq.mk (fn _ =>
(* If there's nothing to translate, we won't have a lthy to use *)
if FSeq.null l2_results then NONE else
let (* Get initial lthy from end of L2 defs *)
val (l2_lthy, _) = FSeq.list_of l2_results |> List.last;
in FSeq.uncons (add_defs l2_lthy Symtab.empty converted_groups) end);
val results = AutoCorresUtil2.define_funcs_sequence
l2_lthy define Symtab.empty converted_groups;
in FSeq.uncons results end);
(* Produce a mapping from each function group to its L1 phase_infos and the
* earliest intermediate lthy where it is defined. *)

View File

@ -1510,9 +1510,7 @@ fun convert
(trace_opt: bool)
(l2_function_name: string -> string)
(f_name: string)
(* FIXME: name type *)
: term * thm * symset * term Symtab.table *
(string * typ) list * (string * AutoCorresData.Trace) list = let
: AutoCorresUtil2.convert_result = let
val (l1_call_info, l1_infos) = FunctionInfo2.calc_call_graph l1_infos;
val f_info = Utils.the' ("L2 conversion missing info for " ^ f_name)
@ -1563,12 +1561,13 @@ fun convert
val callee_consts =
callee_terms |> map (fn (callee, (_, const, _)) => (callee, const)) |> Symtab.make;
in
(f_body,
Morphism.thm export_thm thm, (* Expose callee assumptions *)
rec_callees,
callee_consts,
dest_Free measure_var :: arg_frees,
opt_traces)
{ body = f_body,
proof = Morphism.thm export_thm thm, (* Expose callee assumptions *)
rec_callees = rec_callees,
callee_consts = callee_consts,
arg_frees = dest_Free measure_var :: arg_frees,
traces = opt_traces
}
end
@ -1581,12 +1580,14 @@ fun define
(l1_infos: FunctionInfo2.function_info Symtab.table)
(l2_callees: FunctionInfo2.function_info Symtab.table)
(l2_function_name: string -> string)
(* name, raw body, callee consts, corres thm, arg frees *)
(funcs: (string * (term * term Symtab.table * thm * (string * typ) list)) list)
(funcs: AutoCorresUtil2.convert_result Symtab.table)
: FunctionInfo2.function_info Symtab.table * local_theory = let
val funcs' = funcs |>
map (fn (name, def as (body, callee_consts, thm, frees)) =>
(name, ((AutoCorresUtil2.abstract_fn_body l1_infos (name, def), thm, frees))));
(* FIXME: the abstract_fn_body step should be moved into define_funcs *)
val funcs' = Symtab.dest funcs |>
map (fn (name, {body, proof, callee_consts, arg_frees, ...}) =>
(name, ((AutoCorresUtil2.abstract_fn_body l1_infos
(name, (body, callee_consts, proof, arg_frees))),
proof, arg_frees)));
val (new_thms, (), lthy') =
AutoCorresUtil2.define_funcs FunctionInfo2.L2 filename l1_infos
l2_function_name
@ -1623,83 +1624,23 @@ fun translate
(l2_function_name: string -> string)
: FunctionInfo2.phase_results =
let
(* All conversions can run in parallel.
* Each conversion depends only on the corresponding L1 define phase
* (which necessarily also includes L1 callees). *)
(* Do conversions in parallel. *)
val converted_groups =
FSeq.fold_map (fn l1_infos_accum => fn (lthy, l1_infos) => let
val l1_infos_accum' = Symtab.merge (K false) (l1_infos_accum, l1_infos);
(* Convert l1_infos in parallel, but join the results right away.
* This is fine because we will define them together. *)
val conv_results =
Symtab.dest l1_infos
|> Par_List.map (fn (f, _) =>
(f, convert lthy prog_info l1_infos_accum' do_opt trace_opt l2_function_name f))
|> Symtab.make;
in ((l1_infos, conv_results), l1_infos_accum') end)
Symtab.empty l1_results;
AutoCorresUtil2.par_convert
(fn lthy => fn l1_infos => convert lthy prog_info l1_infos do_opt trace_opt l2_function_name)
l1_results;
(* Definitions update lthy sequentially.
* Each definition step produces a lthy that has a prefix of the update sequence,
* and can be used in an L2 translation that depends only on that prefix.
* Hence we return the intermediate lthys as futures. *)
fun add_defs lthy defined_so_far converted_groups = FSeq.mk (fn () =>
case FSeq.uncons converted_groups of
NONE => NONE
| SOME ((l1_infos, conv_group), remaining_groups) => SOME let
val (call_graph, f_convs) =
Symtab.dest conv_group
|> map (fn (f, (l2_body, corres_thm, rec_callees, callee_consts, arg_frees, traces)) =>
(* FIXME: return traces *)
((f, rec_callees), (f, (l2_body, callee_consts, corres_thm, arg_frees))))
|> split_list;
val f_convs = Symtab.make f_convs;
(* Split this recursive group prior to defining the functions.
* Function calls can be removed by dead code elimination and other
* transformations, which could cause the group to split. *)
val f_callees = Symtab.make (map (apsnd Symset.dest) call_graph);
(* Add each function to its own callees to get a complete inverse *)
val f_callers = flip_symtab (Symtab.map cons f_callees);
val topo_sorted_functions =
Topo_Sort.topo_sort {
cmp = String.compare,
graph = Symtab.lookup f_callees #> the,
converse = Symtab.lookup f_callers #> the
} (Symtab.keys f_callees |> sort String.compare)
|> map Symset.make;
val f_convss =
topo_sorted_functions |> map (fn group =>
map (fn f => (f, the (Symtab.lookup f_convs f)))
(Symset.dest group)
|> Symtab.make);
(* Define each function group and append it to the result sequence. *)
fun add_subgroup_defs lthy defined_so_far [] =
add_defs lthy defined_so_far remaining_groups
| add_subgroup_defs lthy defined_so_far (f_convs :: f_convss) =
FSeq.fcons (fn () => let
val (new_defs, lthy') =
define lthy filename prog_info l1_infos
defined_so_far l2_function_name (Symtab.dest f_convs);
(* Update (non-recursive) callees *)
val new_defs = FunctionInfo2.recalc_base_callees defined_so_far new_defs;
val defined_so_far' = Symtab.merge (K false) (defined_so_far, new_defs);
in ((lthy', new_defs),
add_subgroup_defs lthy' defined_so_far' f_convss)
end);
in (* a bit ugly -- may need to tweak the FSeq interface *)
the (FSeq.uncons (add_subgroup_defs lthy defined_so_far f_convss)) end);
(* Sequence of intermediate states: (lthy, new_defs) *)
(* Sequence of new function_infos and intermediate lthys *)
val def_results = FSeq.mk (fn _ =>
(* If there's nothing to translate, we won't have a lthy to use *)
if FSeq.null l1_results then NONE else
let (* Get initial lthy from end of L1 defs *)
val (l1_lthy, _) = FSeq.list_of l1_results |> List.last;
in FSeq.uncons (add_defs l1_lthy Symtab.empty converted_groups) end);
fun define_worker lthy l2_callee_infos l1_infos f_convs =
define lthy filename prog_info l2_callee_infos l1_infos l2_function_name f_convs;
val results = AutoCorresUtil2.define_funcs_sequence
l1_lthy define_worker Symtab.empty converted_groups;
in FSeq.uncons results end);
(* Produce a mapping from each function group to its L1 phase_infos and the
* earliest intermediate lthy where it is defined. *)

View File

@ -444,8 +444,7 @@ fun convert
(trace_opt: bool)
(l1_function_name: string -> string)
(f_name: string)
: term * thm * symset * term Symtab.table *
(string * typ) list * (string * AutoCorresData.Trace) list =
: AutoCorresUtil2.convert_result =
let
(* FIXME: refactor? *)
val (simpl_calls, simpl_defs) = FunctionInfo2.calc_call_graph simpl_defs;
@ -492,12 +491,14 @@ let
val callee_consts =
callee_terms |> map (fn (callee, (_, const, _)) => (callee, const)) |> Symtab.make;
in
(f_body,
Morphism.thm export_thm thm, (* Also generalizes callees *)
rec_callees,
callee_consts,
[dest_Free measure_var],
opt_traces)
{ body = f_body,
(* Expose callee assumptions and generalizes calle vars *)
proof = Morphism.thm export_thm thm,
rec_callees = rec_callees,
callee_consts = callee_consts,
arg_frees = [dest_Free measure_var],
traces = opt_traces
}
end
@ -505,15 +506,14 @@ let
* lthy must include all definitions from l1_callees.
* simpl_defs must include current function set and its immediate callees. *)
fun define
(lthy: local_theory)
(filename: string)
(prog_info: ProgramInfo.prog_info)
(simpl_defs: FunctionInfo2.function_info Symtab.table)
(check_termination: bool)
(l1_callees: FunctionInfo2.function_info Symtab.table)
(l1_function_name: string -> string)
(* name, raw body, callee consts, corres thm, arg frees *)
(funcs: (string * (term * term Symtab.table * thm * (string * typ) list)) list)
(lthy: local_theory)
(simpl_infos: FunctionInfo2.function_info Symtab.table)
(l1_callees: FunctionInfo2.function_info Symtab.table)
(funcs: AutoCorresUtil2.convert_result Symtab.table)
: FunctionInfo2.function_info Symtab.table * local_theory = let
(* All L1 functions have the same signature: measure \<Rightarrow> L1_monad *)
val l1_fn_type = AutoCorresUtil2.measureT --> mk_l1monadT (#state_type prog_info);
@ -521,19 +521,20 @@ fun define
(* L1corres for f's callees. *)
fun get_l1_fn_assumption ctxt fn_name free _ _ measure_var =
mk_L1corres_call_prop ctxt prog_info check_termination
(the (Symtab.lookup simpl_defs fn_name)) (betapply (free, measure_var));
(the (Symtab.lookup simpl_infos fn_name)) (betapply (free, measure_var));
val funcs' = funcs |>
map (fn (name, def as (body, callee_consts, thm, frees)) =>
(* FIXME: abstract_fn_body should really be moved into define_funcs *)
(name, ((AutoCorresUtil2.abstract_fn_body simpl_defs (name, def), thm, frees))));
val funcs' = Symtab.dest funcs |>
map (fn (name, {body, proof, callee_consts, arg_frees, ...}) =>
(name, ((AutoCorresUtil2.abstract_fn_body simpl_infos
(name, (body, callee_consts, proof, arg_frees))),
proof, arg_frees)));
val (new_thms, (), lthy') =
AutoCorresUtil2.define_funcs FunctionInfo2.L1 filename simpl_defs
AutoCorresUtil2.define_funcs FunctionInfo2.L1 filename simpl_infos
l1_function_name (K l1_fn_type) get_l1_fn_assumption (K []) @{thm L1corres_recguard_0}
lthy (Symtab.map (K #corres_thm) l1_callees) ()
funcs';
val new_defs = Symtab.map (fn f_name => fn (const, def, corres_thm) => let
val f_info = the (Symtab.lookup simpl_defs f_name);
val f_info = the (Symtab.lookup simpl_infos f_name);
in f_info
|> FunctionInfo2.function_info_upd_phase FunctionInfo2.L1
|> FunctionInfo2.function_info_upd_definition def
@ -568,78 +569,32 @@ fun translate
(lthy: local_theory)
: FunctionInfo2.phase_results =
let
val (l1_call_graph, simpl_infos) = FunctionInfo2.calc_call_graph simpl_infos;
(* Initial function groups, in topological order *)
val initial_results =
#topo_sorted_functions l1_call_graph
|> map (fn f_names => let
val f_infos =
Symset.dest f_names
|> List.mapPartial (fn f => Option.map (pair f) (Symtab.lookup simpl_infos f))
|> Symtab.make;
in (lthy, f_infos) end)
|> FSeq.of_list;
val funcs_to_translate = Symtab.keys simpl_infos;
(* All conversions can run in parallel. *)
val converted_funcs =
funcs_to_translate |> map (fn f =>
(f, Future.fork (fn _ =>
convert lthy prog_info simpl_infos check_termination do_opt trace_opt l1_function_name f)))
|> Symtab.make;
(* Do conversions in parallel. *)
val converted_groups =
AutoCorresUtil2.par_convert
(fn lthy => fn l1_infos =>
convert lthy prog_info simpl_infos check_termination
do_opt trace_opt l1_function_name)
initial_results;
(* Definitions update lthy sequentially.
* We use the arbitrary (but deterministic) ordering defined by get_topo_sorted_functions.
* Each definition step produces a lthy that has a prefix of the update sequence,
* and can be used in an L2 translation that depends only on that prefix.
* Hence we return the intermediate lthys as futures. *)
fun add_defs _ _ [] = FSeq.empty ()
| add_defs lthy defined_so_far (f_names :: next_names) =
FSeq.fcons (fn () => let
(* Wait for conversions to finish *)
val (call_graph, f_convs) =
map (fn f => let
val conv = the' ("didn't convert function: " ^ quote f ^ "??")
(Symtab.lookup converted_funcs f);
val (l1_body, corres_thm, rec_callees, callee_consts, arg_frees, traces) =
Future.join conv
(* FIXME: return traces *)
in ((f, rec_callees), (f, (l1_body, callee_consts, corres_thm, arg_frees))) end)
f_names
|> split_list;
val f_convs = Symtab.make f_convs;
(* Split this recursive group prior to defining the functions.
* Function calls can be removed by dead code elimination and other
* transformations, which could cause the group to split. *)
val f_callees = Symtab.make (map (apsnd Symset.dest) call_graph);
(* Add each function to its own callees to get a complete inverse *)
val f_callers = flip_symtab (Symtab.map cons f_callees);
val topo_sorted_functions =
Topo_Sort.topo_sort {
cmp = String.compare,
graph = Symtab.lookup f_callees #> the,
converse = Symtab.lookup f_callers #> the
} (Symtab.keys f_callees |> sort String.compare)
|> map Symset.make;
val f_convss =
topo_sorted_functions |> map (fn group =>
map (fn f => (f, the (Symtab.lookup f_convs f)))
(Symset.dest group)
|> Symtab.make);
(* Define each function group and append it to the result sequence. *)
fun add_subgroup_defs lthy defined_so_far [] =
add_defs lthy defined_so_far next_names
| add_subgroup_defs lthy defined_so_far (f_convs :: f_convss) =
FSeq.fcons (fn () => let
val (new_defs, lthy') =
define lthy filename prog_info simpl_infos check_termination
defined_so_far l1_function_name (Symtab.dest f_convs);
(* Update (non-recursive) callees *)
val new_defs = FunctionInfo2.recalc_base_callees defined_so_far new_defs;
val defined_so_far' = Symtab.merge (K false) (defined_so_far, new_defs);
in ((lthy', new_defs),
add_subgroup_defs lthy' defined_so_far' f_convss)
end);
in (* a bit ugly -- may need to tweak the FSeq interface *)
the (FSeq.uncons (add_subgroup_defs lthy defined_so_far f_convss)) end);
val (simpl_calls, _) = FunctionInfo2.calc_call_graph simpl_infos;
val function_groups = #topo_sorted_functions simpl_calls;
(* Sequence of intermediate results: (lthy, new_defs) *)
val def_results = add_defs lthy Symtab.empty (map Symset.dest function_groups);
(* Sequence of new function_infos and intermediate lthys *)
val def_results = AutoCorresUtil2.define_funcs_sequence
lthy (define filename prog_info check_termination l1_function_name)
Symtab.empty converted_groups;
(* Produce a mapping from each function group to its L1 phase_infos and the
* earliest intermediate lthy where it is defined. *)

View File

@ -175,7 +175,7 @@ let
(if Symset.contains no_signed_abs fn_name then [] else sword_abs)
(* Convert each function. *)
fun convert lthy l2_infos f =
fun convert lthy l2_infos f: AutoCorresUtil2.convert_result =
let
val old_fn_info = the (Symtab.lookup l2_infos f);
val wa_rules = rules_for f;
@ -400,12 +400,14 @@ let
val callee_consts =
callee_terms |> map (fn (callee, (_, const, _)) => (callee, const)) |> Symtab.make;
in
(f_body,
Morphism.thm export_thm thm,
rec_callees,
callee_consts,
map dest_Free (measure_var :: arg_frees),
(if member (op =) trace_funcs f then [("WA", AutoCorresData.RuleTrace trace)] else []) @ opt_traces)
{ body = f_body,
proof = Morphism.thm export_thm thm,
rec_callees = rec_callees,
callee_consts = callee_consts,
arg_frees = map dest_Free (measure_var :: arg_frees),
traces = (if member (op =) trace_funcs f
then [("WA", AutoCorresData.RuleTrace trace)] else []) @ opt_traces
}
end
(* Define a previously-converted function (or recursive function group).
@ -414,13 +416,13 @@ let
(lthy: local_theory)
(l2_infos: FunctionInfo2.function_info Symtab.table)
(wa_callees: FunctionInfo2.function_info Symtab.table)
(* name, raw body, callee consts, corres thm, arg frees *)
(funcs: (string * (term * term Symtab.table * thm * (string * typ) list)) list)
(funcs: AutoCorresUtil2.convert_result Symtab.table)
: FunctionInfo2.function_info Symtab.table * local_theory = let
val funcs' = funcs |>
map (fn (name, def as (body, callee_consts, thm, frees)) =>
(* FIXME: abstract_fn_body should really be moved into define_funcs *)
(name, ((AutoCorresUtil2.abstract_fn_body l2_infos (name, def), thm, frees))));
val funcs' = Symtab.dest funcs |>
map (fn (name, {body, proof, callee_consts, arg_frees, ...}) =>
(name, ((AutoCorresUtil2.abstract_fn_body l2_infos
(name, (body, callee_consts, proof, arg_frees))),
proof, arg_frees)));
val (new_thms, (), lthy') =
AutoCorresUtil2.define_funcs FunctionInfo2.WA filename l2_infos
wa_function_name
@ -445,73 +447,8 @@ let
end) new_thms;
in (new_infos, lthy') end;
(* All conversions can run in parallel.
* Each conversion depends only on the previous define phase
* (which necessarily also includes function callees). *)
val converted_groups =
FSeq.fold_map (fn l2_infos_accum => fn (lthy, l2_infos) => let
val l2_infos_accum' = Symtab.merge (K false) (l2_infos_accum, l2_infos);
(* Convert l2_infos in parallel, but join the results right away.
* This is fine because we will define them together. *)
val conv_results =
Symtab.dest l2_infos
|> Par_List.map (fn (f, _) => (f, convert lthy l2_infos_accum' f))
|> Symtab.make;
in ((l2_infos, conv_results), l2_infos_accum') end)
Symtab.empty l2_results;
(* Definitions update lthy sequentially.
* Each definition step produces a lthy that has a prefix of the update sequence,
* and can be used in an L2 translation that depends only on that prefix.
* Hence we return the intermediate lthys as futures. *)
fun add_defs lthy defined_so_far converted_groups = FSeq.mk (fn () =>
case FSeq.uncons converted_groups of
NONE => NONE
| SOME ((l2_infos, conv_group), remaining_groups) => SOME let
val (call_graph, f_convs) =
Symtab.dest conv_group
|> map (fn (f, (hl_body, corres_thm, rec_callees, callee_consts, arg_frees, traces)) =>
(* FIXME: return traces *)
((f, rec_callees), (f, (hl_body, callee_consts, corres_thm, arg_frees))))
|> split_list;
val f_convs = Symtab.make f_convs;
(* Split this recursive group prior to defining the functions.
* Function calls can be removed by dead code elimination and other
* transformations, which could cause the group to split. *)
val f_callees = Symtab.make (map (apsnd Symset.dest) call_graph);
(* Add each function to its own callees to get a complete inverse *)
val f_callers = flip_symtab (Symtab.map cons f_callees);
val topo_sorted_functions =
Topo_Sort.topo_sort {
cmp = String.compare,
graph = Symtab.lookup f_callees #> the,
converse = Symtab.lookup f_callers #> the
} (Symtab.keys f_callees |> sort String.compare)
|> map Symset.make;
val f_convss =
topo_sorted_functions |> map (fn group =>
map (fn f => (f, the (Symtab.lookup f_convs f)))
(Symset.dest group)
|> Symtab.make);
(* Define each function group and append it to the result sequence. *)
fun add_subgroup_defs lthy defined_so_far [] =
add_defs lthy defined_so_far remaining_groups
| add_subgroup_defs lthy defined_so_far (f_convs :: f_convss) =
FSeq.fcons (fn () => let
val (new_defs, lthy') =
define lthy l2_infos defined_so_far (Symtab.dest f_convs);
(* Update (non-recursive) callees *)
val new_defs = FunctionInfo2.recalc_base_callees defined_so_far new_defs;
val defined_so_far' = Symtab.merge (K false) (defined_so_far, new_defs);
in ((lthy', new_defs),
add_subgroup_defs lthy' defined_so_far' f_convss)
end);
in (* a bit ugly -- may need to tweak the FSeq interface *)
the (FSeq.uncons (add_subgroup_defs lthy defined_so_far f_convss)) end);
(* Do conversions in parallel. *)
val converted_groups = AutoCorresUtil2.par_convert convert l2_results;
(* Sequence of intermediate states: (lthy, new_defs) *)
val def_results = FSeq.mk (fn _ =>
@ -519,7 +456,9 @@ let
if FSeq.null l2_results then NONE else
let (* Get initial lthy from end of L2 defs *)
val (l2_lthy, _) = FSeq.list_of l2_results |> List.last;
in FSeq.uncons (add_defs l2_lthy Symtab.empty converted_groups) end);
val results = AutoCorresUtil2.define_funcs_sequence
l2_lthy define Symtab.empty converted_groups;
in FSeq.uncons results end);
(* Produce a mapping from each function group to its L1 phase_infos and the
* earliest intermediate lthy where it is defined. *)