WIP: autocorres: crude VER-517 prototypes for WA and TS phases
HL is still pending; the new code also needs to be refactored itself.
This commit is contained in:
parent
84cb9deaf8
commit
08c3475a09
|
@ -145,11 +145,13 @@ ML_file "monad_convert.ML"
|
|||
ML_file "type_strengthen.ML"
|
||||
ML_file "autocorres.ML"
|
||||
|
||||
declare [[ML_print_depth=42]]
|
||||
declare [[ML_print_depth=27]]
|
||||
ML_file "function_info2.ML"
|
||||
ML_file "autocorres_util2.ML"
|
||||
ML_file "simpl_conv2.ML"
|
||||
ML_file "local_var_extract2.ML"
|
||||
ML_file "word_abstract2.ML"
|
||||
ML_file "type_strengthen2.ML"
|
||||
|
||||
(* Setup "autocorres" keyword. *)
|
||||
ML {*
|
||||
|
@ -159,82 +161,4 @@ ML {*
|
|||
(Toplevel.theory o (fn (opt, filename) => AutoCorres.do_autocorres opt filename)))
|
||||
*}
|
||||
|
||||
ML \<open>
|
||||
|
||||
fun assume_called_functions_corres ctxt fn_info callees
|
||||
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var =
|
||||
let
|
||||
(* Assume the existence of a function, along with a theorem about its
|
||||
* behaviour. *)
|
||||
fun assume_func ctxt fn_name is_recursive_call =
|
||||
let
|
||||
val fn_args = get_fn_args fn_name
|
||||
|
||||
(* Fix a variable for the function. *)
|
||||
val ([fixed_fn_name], ctxt') = Variable.variant_fixes [get_const_name fn_name] ctxt
|
||||
val fn_free = Free (fixed_fn_name, get_fn_type fn_name)
|
||||
|
||||
(* Fix a variable for the measure and function arguments. *)
|
||||
val (measure_var_name :: arg_names, ctxt'')
|
||||
= Variable.variant_fixes ("rec_measure'" :: (map fst fn_args)) ctxt'
|
||||
val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args))
|
||||
val my_measure_var = Free (measure_var_name, @{typ nat})
|
||||
|
||||
(*
|
||||
* A measure variable is needed to handle recursion: for recursive calls,
|
||||
* we need to decrement the caller's input measure value (and our
|
||||
* assumption will need to assume this to). This is so we can later prove
|
||||
* termination of our function definition: the measure always reaches zero.
|
||||
*
|
||||
* Non-recursive calls can have a fresh value.
|
||||
*)
|
||||
val measure_var =
|
||||
if is_recursive_call then
|
||||
@{const "recguard_dec"} $ callers_measure_var
|
||||
else
|
||||
my_measure_var
|
||||
|
||||
(* Create our assumption. *)
|
||||
val assumption =
|
||||
get_fn_assumption ctxt'' fn_name fn_free fn_arg_terms
|
||||
is_recursive_call measure_var
|
||||
|> fold Logic.all (rev ((if is_recursive_call then [] else [my_measure_var]) @ fn_arg_terms))
|
||||
|> Sign.no_vars ctxt'
|
||||
|> Thm.cterm_of ctxt'
|
||||
val ([thm], ctxt''') = Assumption.add_assumes [assumption] ctxt'
|
||||
|
||||
(* Generate a morphism for escaping this context. *)
|
||||
val m = (Assumption.export_morphism ctxt''' ctxt')
|
||||
$> (Variable.export_morphism ctxt' ctxt)
|
||||
in
|
||||
(fn_free, thm, ctxt''', m)
|
||||
end
|
||||
|
||||
(* Apply each assumption. *)
|
||||
val (res, (ctxt', m)) = fold_map (
|
||||
fn (fn_name, is_recursive_call) =>
|
||||
fn (ctxt, m) =>
|
||||
let
|
||||
val (free, thm, ctxt', m') =
|
||||
assume_func ctxt fn_name is_recursive_call
|
||||
in
|
||||
((fn_name, (is_recursive_call, free, thm)), (ctxt', m' $> m))
|
||||
end)
|
||||
callees (ctxt, Morphism.identity)
|
||||
in
|
||||
(ctxt', m, res)
|
||||
end;
|
||||
|
||||
(*
|
||||
fun assume_called_functions_corres ctxt fn_info callees
|
||||
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var
|
||||
*)
|
||||
assume_called_functions_corres @{context} () [("a", false), ("r", true)]
|
||||
(K @{typ "nat \<Rightarrow> nat \<Rightarrow> string \<Rightarrow> string"})
|
||||
(fn ctxt => fn name => fn term => fn args => fn is_rec => fn meas =>
|
||||
HOLogic.mk_Trueprop (@{term "my_corres :: (string \<Rightarrow> string) \<Rightarrow> bool"} $ betapplys (term, meas :: args)))
|
||||
(fn f => if f = "a" then [("arg_a", @{typ nat})] else [("arg_r", @{typ nat})])
|
||||
I @{term "rec_measure :: nat"}
|
||||
\<close>
|
||||
|
||||
end
|
||||
|
|
|
@ -1584,6 +1584,7 @@ fun define
|
|||
|
||||
fun prepare_fn_body (fn_name, corres_thm, arg_frees) = let
|
||||
val _ = @{trace} ("prepare_fn_body", fn_name, corres_thm);
|
||||
(* FIXME: move this to convert *)
|
||||
val @{term_pat "Trueprop (L2corres _ _ _ _ ?body _)"} = Thm.concl_of corres_thm;
|
||||
val (callees, recursive_callees) = AutoCorresUtil2.get_callees l1_infos fn_name;
|
||||
val calls = map (fn c => Var (guess_callee_var corres_thm c)) callees;
|
||||
|
@ -1622,7 +1623,6 @@ fun define
|
|||
@{thm L2corres_recguard_0}
|
||||
lthy (Symtab.map (K #corres_thm) l2_callees) ()
|
||||
funcs';
|
||||
val f_names = map (fn (name, _, _) => name) funcs;
|
||||
val new_infos = Symtab.map (fn f_name => fn (const, def, corres_thm) => let
|
||||
val old_info = the (Symtab.lookup l1_infos f_name);
|
||||
in old_info
|
||||
|
@ -1631,11 +1631,15 @@ fun define
|
|||
|> FunctionInfo2.function_info_upd_definition def
|
||||
|> FunctionInfo2.function_info_upd_corres_thm corres_thm
|
||||
|> FunctionInfo2.function_info_upd_mono_thm NONE (* added later *)
|
||||
(* Update arg names to match our newly converted functions *)
|
||||
|> FunctionInfo2.function_info_upd_args
|
||||
(map (apfst (to_free_var_name lthy' o ProgramInfo.demangle_name)) (#args old_info))
|
||||
end) new_thms;
|
||||
(* FIXME: return traces *)
|
||||
in (new_infos, lthy') end;
|
||||
|
||||
|
||||
(* FIXME: move *)
|
||||
fun symtab_merge allow_dups tabs =
|
||||
maps Symtab.dest tabs
|
||||
|> (if allow_dups then sort_distinct (fast_string_ord o apply2 fst) else I)
|
||||
|
@ -1648,7 +1652,7 @@ fun translate filename prog_info
|
|||
(* lazy results from L1 *)
|
||||
(l1_results: (symset * (local_theory * FunctionInfo2.function_info Symtab.table) future) list)
|
||||
do_opt trace_opt l2_function_name =
|
||||
(* if there's nothing to translate, we won't have a lthy *)
|
||||
(* if there's nothing to translate, we won't have a lthy to use *)
|
||||
if null l1_results then [] else
|
||||
let
|
||||
(* TODO: we should recalculate this from l1_results to take dead-code elim
|
||||
|
|
|
@ -20,6 +20,10 @@ begin
|
|||
(* Parse the input file. *)
|
||||
install_C_file "factorial.c"
|
||||
|
||||
(*
|
||||
autocorres [scope_depth=0, scope=factorial] "factorial.c"
|
||||
autocorres [scope_depth=0, scope=call_factorial] "factorial.c"
|
||||
*)
|
||||
autocorres "factorial.c"
|
||||
|
||||
context factorial begin
|
||||
|
|
|
@ -19,6 +19,10 @@ install_C_file "simple.c"
|
|||
|
||||
(* Abstract the input file. *)
|
||||
autocorres [ ts_force pure = max, ts_force nondet = gcd, unsigned_word_abs = gcd ] "simple.c"
|
||||
(*
|
||||
autocorres [ ts_force nondet = gcd, unsigned_word_abs = gcd, scope_depth=0, scope=gcd ] "simple.c"
|
||||
autocorres [ ts_force pure = max, scope_depth=0, scope=max ] "simple.c"
|
||||
*)
|
||||
|
||||
(* Generated theorems and proofs. *)
|
||||
thm simple.max'_def simple.max'_ac_corres
|
||||
|
|
|
@ -23,4 +23,3 @@ unsigned gcd(unsigned a, unsigned b) {
|
|||
}
|
||||
return b;
|
||||
}
|
||||
|
||||
|
|
|
@ -49,36 +49,43 @@ let val simpl_infos = FunctionInfo2.init_function_info @{context} "type_strength
|
|||
in (frees1, corres1, Symtab.dest l1_infos2) end
|
||||
\<close>
|
||||
|
||||
ML \<open>
|
||||
local_setup \<open>
|
||||
fn lthy =>
|
||||
let val filename = "type_strengthen.c";
|
||||
val simpl_info = FunctionInfo2.init_function_info @{context} filename;
|
||||
val prog_info = ProgramInfo.get_prog_info @{context} filename;
|
||||
val simpl_info = FunctionInfo2.init_function_info lthy filename;
|
||||
val prog_info = ProgramInfo.get_prog_info lthy filename;
|
||||
val l1_results =
|
||||
SimplConv2.translate filename prog_info simpl_info
|
||||
true true false (fn f => "l1_" ^ f ^ "'") @{context};
|
||||
true true false (fn f => "l1_" ^ f ^ "'") lthy;
|
||||
val l2_results =
|
||||
LocalVarExtract2.translate filename prog_info l1_results
|
||||
true false (fn f => "l2_" ^ f ^ "'");
|
||||
in l2_results |> map (snd #> Future.join) |> map (snd #> Symtab.dest) end
|
||||
|
||||
val wa_results =
|
||||
WordAbstract2.translate filename prog_info l2_results Symset.empty Symset.empty []
|
||||
true false (fn f => "wa_" ^ f ^ "'");
|
||||
|
||||
val ts_rules = Monad_Types.get_ordered_rules [] (Context.Proof lthy);
|
||||
val ts_results =
|
||||
TypeStrengthen2.translate ts_rules Symtab.empty filename prog_info
|
||||
wa_results (fn f => f ^ "'") true true;
|
||||
in ts_results |> rev |> hd |> fst end
|
||||
\<close>
|
||||
|
||||
ML \<open>
|
||||
FunctionInfo2.init_function_info @{context} "type_strengthen.c"
|
||||
|> Symtab.dest
|
||||
\<close>
|
||||
|
||||
end
|
||||
|
||||
declare [[ML_print_depth=99]]
|
||||
autocorres [
|
||||
ts_rules = nondet,
|
||||
scope = st_i,
|
||||
skip_heap_abs, skip_word_abs
|
||||
] "type_strengthen.c"
|
||||
|
||||
|
||||
(* We can also specify which monads are used for type strengthening.
|
||||
Here, we exclude the read-only monad completely, and specify
|
||||
rules for some individual functions. *)
|
||||
autocorres [
|
||||
ts_rules = pure option nondet,
|
||||
ts_force option = pure_f,
|
||||
statistics
|
||||
skip_heap_abs
|
||||
] "type_strengthen.c"
|
||||
|
||||
context type_strengthen begin
|
||||
|
|
|
@ -386,6 +386,7 @@ let
|
|||
#L2_simp_rules rule_set' @
|
||||
@{thms gets_bind_ign L2_call_fail HOL.simp_thms}
|
||||
|
||||
val _ = @{trace} ("ts proof rules", map fst functions, defs ~~ fn_defs)
|
||||
val rewrite_thm =
|
||||
Goal.init goal
|
||||
|> (fn goal_state => if is_recursive then
|
||||
|
|
|
@ -0,0 +1,640 @@
|
|||
(*
|
||||
* Copyright 2014, NICTA
|
||||
*
|
||||
* This software may be distributed and modified according to the terms of
|
||||
* the BSD 2-Clause license. Note that NO WARRANTY is provided.
|
||||
* See "LICENSE_BSD2.txt" for details.
|
||||
*
|
||||
* @TAG(NICTA_BSD)
|
||||
*)
|
||||
|
||||
(*
|
||||
* Lift monadic structures into lighter-weight monads.
|
||||
*)
|
||||
structure TypeStrengthen2 =
|
||||
struct
|
||||
|
||||
exception AllLiftingFailed of (string * thm) list
|
||||
exception LiftingFailed of unit
|
||||
|
||||
(* FIXME: use AUTOCORRES_SIMPSET (need to fix unknown deps of the corres prover) *)
|
||||
val ts_simpset = simpset_of @{context}
|
||||
|
||||
(* Misc util functions. *)
|
||||
val the' = Utils.the'
|
||||
val apply_tac = Utils.apply_tac
|
||||
|
||||
fun get_l2_state_typ ctxt prog_info l2_infos fn_name =
|
||||
let
|
||||
val term = #const (the (Symtab.lookup l2_infos fn_name));
|
||||
in
|
||||
LocalVarExtract.dest_l2monad_T (fastype_of term) |> snd |> #1
|
||||
end;
|
||||
|
||||
fun get_typ_from_L2 (rule_set : Monad_Types.monad_type) L2_typ =
|
||||
LocalVarExtract.dest_l2monad_T L2_typ |> snd |> #typ_from_L2 rule_set
|
||||
|
||||
(*
|
||||
* Make an equality prop of the form "L2_call <foo> = <liftE> $ <bar>".
|
||||
*
|
||||
* L2_call and <liftE> will typically be desired to be polymorphic in their
|
||||
* exception type. We fix it to "unit"; the caller will need to introduce
|
||||
* polymorphism as necessary.
|
||||
*
|
||||
* If "measure" is non-NONE, then that term will be used instead of a free
|
||||
* variable.
|
||||
* If "state_typ" is non-NONE, then "measure" is assumed to also take a
|
||||
* state parameter of the given type.
|
||||
*)
|
||||
fun make_lift_equality ctxt prog_info l2_infos fn_name
|
||||
(rule_set : Monad_Types.monad_type) state_typ measure rhs_term =
|
||||
let
|
||||
val thy = Proof_Context.theory_of ctxt
|
||||
|
||||
(* Fetch function variables. *)
|
||||
val fn_def = the (Symtab.lookup l2_infos fn_name)
|
||||
val inputs = #args fn_def
|
||||
val input_vars = map (fn (x, y) => Var ((x, 0), y)) inputs
|
||||
|
||||
(*
|
||||
* Measure var.
|
||||
*
|
||||
* The L2 left-hand-side will always use this measure, while the
|
||||
* right-hand-side will only include a measure if the function is actually
|
||||
* recursive.
|
||||
*)
|
||||
val is_recursive = FunctionInfo2.is_function_recursive fn_def
|
||||
val default_measure_var = @{term "rec_measure' :: nat"}
|
||||
val measure_term = Option.getOpt (measure, default_measure_var)
|
||||
|
||||
(*
|
||||
* Construct the equality.
|
||||
*
|
||||
* This is a little delicate: in particular, we need to ensure that
|
||||
* the type of the resulting term is strictly correct. In particular,
|
||||
* our "lift_fn" will have type variables that need to be modified.
|
||||
* "Utils.mk_term" will fill in type variables of the base term based
|
||||
* on what is applied to it. So, we need to ensure that the lift
|
||||
* function is in our base term, and that its type variables have the
|
||||
* same names ("'s" for state, "'a" for return type) that we use
|
||||
* below.
|
||||
*)
|
||||
(* FIXME: use @{mk_term} *)
|
||||
val base_term = @{term "%L a b. (L2_call :: ('s, 'a, unit) L2_monad => ('s, 'a, unit) L2_monad) a = L b"}
|
||||
val a = betapplys (#const fn_def, measure_term :: input_vars)
|
||||
val b = if not is_recursive then betapplys (rhs_term, input_vars) else
|
||||
case state_typ of
|
||||
NONE => betapplys (rhs_term, measure_term :: input_vars)
|
||||
| SOME s => betapplys (rhs_term, [measure_term] @ input_vars @ [Free ("s'", s)]) (* FIXME: Free *)
|
||||
|> lambda (Free ("s'", s))
|
||||
val term = Utils.mk_term thy (betapply (base_term, #L2_call_lift rule_set)) [a, b]
|
||||
|> HOLogic.mk_Trueprop
|
||||
|
||||
in
|
||||
(* Convert it into a trueprop with meta-foralls. *)
|
||||
Utils.vars_to_metaforall term
|
||||
end
|
||||
|
||||
(*
|
||||
* Assume recursively called functions correctly map into the given type.
|
||||
*
|
||||
* We return:
|
||||
*
|
||||
* (<newly generated context>,
|
||||
* <the measure variable used>,
|
||||
* <generated assumptions>,
|
||||
* <table mapping free term names back to their function names>,
|
||||
* <morphism to escape the context>)
|
||||
*)
|
||||
fun assume_rec_lifted ctxt prog_info l2_infos rule_set fn_name =
|
||||
let
|
||||
val thy = Proof_Context.theory_of ctxt
|
||||
val fn_def = the (Symtab.lookup l2_infos fn_name)
|
||||
|
||||
(* Find recursive calls. *)
|
||||
val recursive_calls = Symset.dest (#rec_callees fn_def)
|
||||
|
||||
(* Fix a variable for each such call, plus another for our measure variable. *)
|
||||
val (measure_fix :: dest_fn_fixes, ctxt')
|
||||
= Variable.add_fixes ("rec_measure'" :: map (fn x => "rec'" ^ x) recursive_calls) ctxt
|
||||
val rec_fun_names = Symtab.make (dest_fn_fixes ~~ recursive_calls)
|
||||
|
||||
(* For recursive calls, we need a term representing our measure variable and
|
||||
* another representing our decremented measure variable. *)
|
||||
val measure_var = Free (measure_fix, @{typ nat})
|
||||
val dec_measure_var = @{const "recguard_dec"} $ measure_var
|
||||
|
||||
(* For each recursive call, generate a theorem assuming that it lifts into
|
||||
* the type/monad of "rule_set". *)
|
||||
val dest_fn_thms = map (fn (callee, var) =>
|
||||
let
|
||||
val fn_def' = the (Symtab.lookup l2_infos callee)
|
||||
val inputs = #args fn_def'
|
||||
val T = (@{typ nat} :: map snd inputs)
|
||||
---> (fastype_of (#const fn_def') |> get_typ_from_L2 rule_set)
|
||||
|
||||
(* NB: pure functions would not use state, but recursive functions cannot
|
||||
* be lifted to pure (because we trigger failure when the measure hits
|
||||
* 0). So we can always assume there is state. *)
|
||||
val state_typ = get_l2_state_typ ctxt prog_info l2_infos fn_name
|
||||
val x = make_lift_equality ctxt prog_info l2_infos callee rule_set
|
||||
(SOME state_typ) (SOME dec_measure_var) (Free (var, T))
|
||||
in
|
||||
Thm.cterm_of ctxt' x
|
||||
end) (recursive_calls ~~ dest_fn_fixes)
|
||||
(* Our measure does not type-check for pure functions,
|
||||
causing a TERM exception from mk_term. *)
|
||||
handle TERM _ => raise LiftingFailed ()
|
||||
|
||||
(* Assume the theorems we just generated. *)
|
||||
val (thms, ctxt'') = Assumption.add_assumes dest_fn_thms ctxt'
|
||||
val thms = map (fn t => (#polymorphic_thm rule_set) OF [t]) thms
|
||||
in
|
||||
(ctxt'',
|
||||
measure_var,
|
||||
thms,
|
||||
rec_fun_names,
|
||||
Assumption.export_morphism ctxt'' ctxt'
|
||||
$> Variable.export_morphism ctxt' ctxt)
|
||||
end
|
||||
|
||||
|
||||
(*
|
||||
* Given a function definition, attempt to lift it into a different
|
||||
* monadic structure by applying a set of rewrite rules.
|
||||
*
|
||||
* For example, given:
|
||||
*
|
||||
* foo x y = doE
|
||||
* a <- returnOk 3;
|
||||
* b <- returnOk 5;
|
||||
* returnOk (a + b)
|
||||
* odE
|
||||
*
|
||||
* we may be able to lift to:
|
||||
*
|
||||
* foo x y = returnOk (let
|
||||
* a = 3;
|
||||
* b = 5;
|
||||
* in
|
||||
* a + b)
|
||||
*
|
||||
* This second function has the form "lift $ term" for some lifting function
|
||||
* "lift" and some new term "term". (These would be "returnOk" and "let a = ...
|
||||
* in a + b" in the example above, respectively.)
|
||||
*
|
||||
* We return a theorem of the form "foo x y == <lift> $ <term>", along with the
|
||||
* new term "<term>". If the lift was unsuccessful, we return "NONE".
|
||||
*)
|
||||
fun perform_lift ctxt prog_info l2_infos rule_set fn_name =
|
||||
let
|
||||
(* Assume recursive calls can be successfully lifted into this type. *)
|
||||
val (ctxt', measure_var, thms, dict, m)
|
||||
= assume_rec_lifted ctxt prog_info l2_infos rule_set fn_name
|
||||
|
||||
val fn_def = #definition (the (Symtab.lookup l2_infos fn_name))
|
||||
|
||||
(* Extract the term from our function definition. *)
|
||||
val fn_thm' = Utils.named_cterm_instantiate ctxt
|
||||
[("rec_measure'", Thm.cterm_of ctxt' measure_var)] fn_def
|
||||
handle THM _ => fn_def
|
||||
val ct = Thm.prop_of fn_thm' |> Utils.rhs_of |> Thm.cterm_of ctxt'
|
||||
|
||||
(* Rewrite the term using the given rewrite rules. *)
|
||||
val t = Thm.term_of ct
|
||||
val thm = case Monad_Convert.monad_rewrite ctxt rule_set thms true t of
|
||||
SOME t => t
|
||||
| NONE => Utils.named_cterm_instantiate ctxt [("x", ct)] @{thm reflexive}
|
||||
|
||||
(* Convert "def == old_body" and "old_body == new_body" into "def == new_body". *)
|
||||
val thm_ml = Thm.transitive fn_thm' thm
|
||||
val thm_ml = Morphism.thm m thm_ml
|
||||
|
||||
(* Get the newly rewritten term. *)
|
||||
val new_term = Thm.concl_of thm_ml |> Utils.rhs_of
|
||||
|
||||
(*val _ = @{trace} (fn_name, #name rule_set, cterm_of (Proof_Context.theory_of ctxt) new_term)*)
|
||||
|
||||
(* Determine if the conversion was successful. *)
|
||||
val success = #valid_term rule_set ctxt new_term
|
||||
in
|
||||
(* Determine if we were a success. *)
|
||||
if success then
|
||||
SOME (thm_ml, dict)
|
||||
else
|
||||
NONE
|
||||
end
|
||||
|
||||
(* Like perform_lift, but also applies the polishing rules, hopefully yielding
|
||||
* an even nicer definition. *)
|
||||
fun perform_lift_and_polish ctxt prog_info fn_info rule_set do_opt fn_name =
|
||||
case (perform_lift ctxt prog_info fn_info rule_set fn_name)
|
||||
of NONE => NONE
|
||||
| SOME (thm, dict) => SOME let
|
||||
|
||||
(* Measure the size of the new theorem. *)
|
||||
val _ = Statistics.gather ctxt "TS" fn_name
|
||||
(Thm.concl_of thm |> Utils.rhs_of)
|
||||
|
||||
(* Apply any polishing rules. *)
|
||||
val polish_thm = Monad_Convert.polish ctxt rule_set do_opt thm
|
||||
|
||||
(* Measure the term. *)
|
||||
val _ = Statistics.gather ctxt "polish" fn_name
|
||||
(Thm.concl_of polish_thm |> Utils.rhs_of)
|
||||
|
||||
in (polish_thm, dict) end
|
||||
|
||||
|
||||
(*
|
||||
* Attempt to lift a function (or recursive function group) into the given monad.
|
||||
*
|
||||
* If successful, we define the new function (vis. group) to the theory.
|
||||
* We then return a theorem of the form:
|
||||
*
|
||||
* "L2_call foo x y z == <lift> $ new_foo x y z"
|
||||
*
|
||||
* where "lift" is a lifting function, such as "returnOk" or "gets", etc.
|
||||
*
|
||||
* If the lift does not succeed, the function returns NONE.
|
||||
*
|
||||
* The callees of this function need to be already translated in ts_infos
|
||||
* and also defined in lthy.
|
||||
*)
|
||||
fun lift_function_rewrite rule_set filename prog_info l2_infos ts_infos
|
||||
fn_names make_function_name keep_going do_opt lthy =
|
||||
let
|
||||
val these_l2_infos = map (the o Symtab.lookup l2_infos) fn_names
|
||||
|
||||
(* Determine if this function is recursive. *)
|
||||
val is_recursive = FunctionInfo2.is_function_recursive (hd these_l2_infos)
|
||||
|
||||
(* Fetch relevant callees. *)
|
||||
val callees =
|
||||
map #callees these_l2_infos
|
||||
|> Symset.union_sets
|
||||
|> Symset.dest
|
||||
val callee_infos = map (Symtab.lookup ts_infos #> the) callees
|
||||
|
||||
(* Add monad_mono theorems. *)
|
||||
val mono_thms = List.mapPartial #mono_thm callee_infos
|
||||
|
||||
(* When lifting, also lift our callees. *)
|
||||
val rule_set' = Monad_Types.update_mt_lift_rules
|
||||
(fn thms => Monad_Types.thmset_adds thms (map #corres_thm callee_infos @ mono_thms))
|
||||
rule_set
|
||||
|
||||
(*
|
||||
* Attempt to lift all functions into this type.
|
||||
*
|
||||
* For mutually recursive functions, every function in the group needs to be
|
||||
* lifted in the same type.
|
||||
*
|
||||
* Eliminate the "SOME", raising an exception if any function in the group
|
||||
* couldn't be lifted to this type.
|
||||
*)
|
||||
val lifted_functions =
|
||||
map (perform_lift_and_polish lthy prog_info l2_infos rule_set' do_opt) fn_names
|
||||
|
||||
val lifted_functions = map (fn x =>
|
||||
case x of
|
||||
SOME a => a
|
||||
| NONE => raise LiftingFailed ())
|
||||
lifted_functions
|
||||
val thms = map fst lifted_functions
|
||||
val dicts = map snd lifted_functions
|
||||
|
||||
(*
|
||||
* Generate terms necessary for defining the function, and define the
|
||||
* functions.
|
||||
*)
|
||||
fun gen_fun_def_term (fn_name, dict, thm) =
|
||||
let
|
||||
(* If this function is recursive, it has a measure parameter. *)
|
||||
val measure_param = if is_recursive then [("rec_measure'", @{typ nat})] else []
|
||||
|
||||
(* Fetch function parameters. *)
|
||||
val fn_info = the (Symtab.lookup l2_infos fn_name)
|
||||
val fn_params = #args fn_info
|
||||
|
||||
(* Extract the term from the function theorem. *)
|
||||
fun tail_of (_ $ x) = x
|
||||
val term = Thm.concl_of thm
|
||||
|> Utils.rhs_of
|
||||
|> tail_of
|
||||
|> Utils.unsafe_unvarify
|
||||
|> (fn term =>
|
||||
foldr (fn (v, t) => Utils.abs_over "" v t)
|
||||
term (map Free (measure_param @ fn_params)))
|
||||
val fn_type = fastype_of term
|
||||
|
||||
(* Replace place-holder function names with our generated constant name. *)
|
||||
val term = map_aterms (fn t => case t of
|
||||
Free (n, T) =>
|
||||
(case Symtab.lookup dict n of
|
||||
NONE => t
|
||||
| SOME a => Free (make_function_name a, T))
|
||||
| x => x) term
|
||||
in
|
||||
(make_function_name fn_name, measure_param @ fn_params, term)
|
||||
end
|
||||
val input_defs = map gen_fun_def_term (Utils.zip3 fn_names dicts thms)
|
||||
val (ts_defs, lthy) = Utils.define_functions input_defs false is_recursive lthy
|
||||
|
||||
(*
|
||||
(* Record the definitions in our theory data. *)
|
||||
val lthy = fold (fn (fn_name, def) =>
|
||||
Local_Theory.background_theory (
|
||||
AutoCorresData.add_def filename ("TS" ^ "def") fn_name def))
|
||||
(Utils.zip fn_names defs) lthy
|
||||
*)
|
||||
|
||||
(* Instantiate variables in our equivalence theorem to their newly-defined values. *)
|
||||
fun do_inst_thm thm =
|
||||
Utils.instantiate_thm_vars lthy (
|
||||
fn ((name, _), t) =>
|
||||
try (unprefix "rec'") name
|
||||
|> Option.map make_function_name
|
||||
|> Option.map (Utils.get_term lthy)
|
||||
|> Option.map (Thm.cterm_of lthy)) thm
|
||||
val inst_thms = map do_inst_thm thms
|
||||
|
||||
(* HACK: If an L2 function takes no parameters, its measure gets eta-contracted
|
||||
away, preventing eqsubst_tac from unifying the rec_measure's schematic
|
||||
variable. So we get rid of the schematic var pre-emptively. *)
|
||||
val inst_thms =
|
||||
map (fn inst_thm => Drule.abs_def inst_thm handle TERM _ => inst_thm) inst_thms
|
||||
|
||||
(* Generate a theorem converting "L2_call <func>" into its new form,
|
||||
* such as L2_call <func> = liftE $ <new_func_def> *)
|
||||
val final_props = map (fn fn_name =>
|
||||
make_lift_equality lthy prog_info l2_infos fn_name rule_set'
|
||||
NONE NONE (Utils.get_term lthy (make_function_name fn_name))) fn_names
|
||||
|
||||
(* Convert meta-logic into HOL statements, conjunct them together and setup
|
||||
* our goal statement. *)
|
||||
val int_props = map (Object_Logic.atomize_term lthy) final_props
|
||||
val goal = Utils.mk_conj_list int_props
|
||||
|> HOLogic.mk_Trueprop
|
||||
|> Thm.cterm_of lthy
|
||||
|
||||
val simps =
|
||||
#L2_simp_rules rule_set' @
|
||||
@{thms gets_bind_ign L2_call_fail HOL.simp_thms}
|
||||
|
||||
val rewrite_thm =
|
||||
Goal.init goal
|
||||
|> (fn goal_state => if is_recursive then
|
||||
goal_state
|
||||
|> apply_tac "start induction"
|
||||
(resolve_tac lthy @{thms recguard_induct} 1)
|
||||
|> apply_tac "(base case) split subgoals"
|
||||
(TRY (REPEAT_ALL_NEW (Tactic.match_tac lthy [@{thm conjI}]) 1))
|
||||
|> apply_tac "(base case) solve base cases"
|
||||
(EVERY (map (fn (ts_def, l2_info) =>
|
||||
SOLVES (
|
||||
(EqSubst.eqsubst_tac lthy [0] [#definition l2_info] 1)
|
||||
THEN (EqSubst.eqsubst_tac lthy [0] [ts_def] 1)
|
||||
THEN (simp_tac (put_simpset ts_simpset lthy) 1))) (ts_defs ~~ these_l2_infos)))
|
||||
|> apply_tac "(rec case) spliting induct case prems"
|
||||
(TRY (REPEAT_ALL_NEW (Tactic.ematch_tac lthy [@{thm conjE}]) 1))
|
||||
|> apply_tac "(rec case) split inductive case subgoals"
|
||||
(TRY (REPEAT_ALL_NEW (Tactic.match_tac lthy [@{thm conjI}]) 1))
|
||||
|> apply_tac "(rec case) unfolding strengthened function definition"
|
||||
(EVERY (map (fn (def, inst_thm) =>
|
||||
((EqSubst.eqsubst_tac lthy [0] [def] 1)
|
||||
THEN (EqSubst.eqsubst_tac lthy [0] [inst_thm] 1)
|
||||
THEN (REPEAT (CHANGED (asm_full_simp_tac (put_simpset ts_simpset lthy) 1)))))
|
||||
(ts_defs ~~ inst_thms)))
|
||||
else
|
||||
goal_state
|
||||
|> apply_tac "unfolding strengthen function definition"
|
||||
(EqSubst.eqsubst_tac lthy [0] [hd ts_defs] 1)
|
||||
|> apply_tac "unfolding L2 rewritten theorem"
|
||||
(EqSubst.eqsubst_tac lthy [0] [hd inst_thms] 1)
|
||||
|> apply_tac "simplifying remainder"
|
||||
(TRY (simp_tac (put_simpset HOL_ss (Utils.set_hidden_ctxt lthy) addsimps simps) 1))
|
||||
)
|
||||
|> Goal.finish lthy
|
||||
|
||||
(* Now, using this combined theorem, generate a theorem for each individual
|
||||
* function. *)
|
||||
fun prove_partial_pred thm pred =
|
||||
Thm.cterm_of lthy pred
|
||||
|> Goal.init
|
||||
|> apply_tac "inserting hypothesis"
|
||||
(cut_tac thm 1)
|
||||
|> apply_tac "normalising into rule format"
|
||||
((REPEAT (resolve_tac lthy @{thms allI} 1))
|
||||
THEN (REPEAT (eresolve_tac lthy @{thms conjE} 1))
|
||||
THEN (REPEAT (eresolve_tac lthy @{thms allE} 1)))
|
||||
|> apply_tac "solving goal" (assume_tac lthy 1)
|
||||
|> Goal.finish lthy
|
||||
|> Object_Logic.rulify lthy
|
||||
val new_thms = map (prove_partial_pred rewrite_thm) final_props
|
||||
|
||||
(*
|
||||
* Make the theorems polymorphic in their exception type.
|
||||
*
|
||||
* That is, these theories may all be applied regardless of what the type of
|
||||
* the exception part of the monad is, but are currently specialised to
|
||||
* when the exception part of the monad is unit. We apply a "polymorphism theorem" to change
|
||||
* the type of the rule from:
|
||||
*
|
||||
* ('s, unit + 'a) nondet_monad
|
||||
*
|
||||
* to
|
||||
*
|
||||
* ('s, 'e + 'a) nondet_monad
|
||||
*)
|
||||
val new_thms = map (fn t => #polymorphic_thm rule_set' OF [t]) new_thms
|
||||
|> map (fn t => Drule.generalize ([], ["rec_measure'"]) t)
|
||||
|
||||
(*
|
||||
(* Record correctness theorem(s) for what we just did. *)
|
||||
val lthy = fold (fn (fn_name, thm) =>
|
||||
Local_Theory.background_theory
|
||||
(AutoCorresData.add_thm filename "TScorres" fn_name thm))
|
||||
(fn_names ~~ new_thms) lthy
|
||||
*)
|
||||
|
||||
(* FIXME: resurrect in a better place
|
||||
(* Output final corres rule. *)
|
||||
fun output_rule fn_name lthy =
|
||||
let
|
||||
val thy = Proof_Context.theory_of lthy
|
||||
val l1_thm = the (AutoCorresData.get_thm thy filename "L1corres" fn_name)
|
||||
handle Option => raise SimplConv.FunctionNotFound fn_name
|
||||
val l2_thm = the (AutoCorresData.get_thm thy filename "L2corres" fn_name)
|
||||
handle Option => raise SimplConv.FunctionNotFound fn_name
|
||||
val hl_thm = the (AutoCorresData.get_thm thy filename "HLcorres" fn_name)
|
||||
handle Option => (@{thm L2Tcorres_trivial_from_local_var_extract} OF [l2_thm]
|
||||
handle THM _ => l2_thm) (* catch this failure below *)
|
||||
val wa_thm = the (AutoCorresData.get_thm thy filename "WAcorres" fn_name)
|
||||
(* If there is no WAcorres thm, it is probably because word_abstract
|
||||
* was disabled using no_word_abs.
|
||||
* In that case we just carry the hl_thm through. *)
|
||||
handle Option => (@{thm corresTA_trivial_from_heap_lift} OF [hl_thm]
|
||||
handle THM _ => hl_thm) (* catch this failure below *)
|
||||
val ts_thm = the (AutoCorresData.get_thm thy filename "TScorres" fn_name)
|
||||
handle Option => raise SimplConv.FunctionNotFound fn_name
|
||||
val lthy = let
|
||||
val final_thm = @{thm ac_corres_chain} OF [l1_thm, l2_thm, hl_thm, wa_thm, ts_thm]
|
||||
val final_thm' = Simplifier.simplify lthy final_thm
|
||||
val (_, lthy) = Utils.define_lemma (make_function_name fn_name ^ "_ac_corres") final_thm' lthy
|
||||
in
|
||||
lthy
|
||||
end
|
||||
handle THM _ =>
|
||||
(Utils.THM_non_critical keep_going
|
||||
("autocorres: failed to prove ac_corres theorem for " ^ fn_name)
|
||||
0 [l1_thm, l2_thm, hl_thm, wa_thm, ts_thm];
|
||||
lthy)
|
||||
in
|
||||
lthy
|
||||
end
|
||||
|
||||
val lthy = fold output_rule fn_names lthy
|
||||
*)
|
||||
|
||||
(* Final mono rules for recursive functions. *)
|
||||
fun try_mono_thm (fn_name, rewrite_thm) =
|
||||
case #mono_thm (the (Symtab.lookup l2_infos fn_name)) of
|
||||
NONE => NONE
|
||||
| SOME l2_mono_thm =>
|
||||
case try (#prove_mono rule_set' rewrite_thm) l2_mono_thm of
|
||||
NONE => (Utils.ac_warning ("Failed to prove monad_mono for function: " ^ fn_name);
|
||||
NONE)
|
||||
| SOME ts_mono_thm => SOME (fn_name, ts_mono_thm);
|
||||
val mono_thms = (if is_recursive then List.mapPartial try_mono_thm (fn_names ~~ new_thms) else [])
|
||||
|> Symtab.make;
|
||||
|
||||
val ts_infos =
|
||||
map (fn ((f, l2_info), (ts_def, ts_corres)) => let
|
||||
val mono_thm = Symtab.lookup mono_thms f;
|
||||
val const = Utils.get_term lthy (make_function_name f);
|
||||
val ts_info = l2_info
|
||||
|> FunctionInfo2.function_info_upd_phase FunctionInfo2.TS
|
||||
|> FunctionInfo2.function_info_upd_definition ts_def
|
||||
|> FunctionInfo2.function_info_upd_corres_thm ts_corres
|
||||
|> FunctionInfo2.function_info_upd_const const;
|
||||
in (f, ts_info) end)
|
||||
((fn_names ~~ these_l2_infos) ~~ (ts_defs ~~ new_thms))
|
||||
in
|
||||
(lthy, Symtab.make ts_infos, #name rule_set')
|
||||
end
|
||||
|
||||
|
||||
(* Return the lifting rule(s) to try for a function set.
|
||||
This is moved out of lift_function so that it can be used to
|
||||
provide argument checking in the AutoCorres.abstract wrapper. *)
|
||||
fun compute_lift_rules rules force_lift fn_names =
|
||||
let
|
||||
fun all_list f xs = fold (fn x => (fn b => b andalso f x)) xs true
|
||||
|
||||
val forced = fn_names
|
||||
|> map (fn func => case Symtab.lookup force_lift func of
|
||||
SOME rule => [(func, rule)]
|
||||
| NONE => [])
|
||||
|> List.concat
|
||||
in
|
||||
case forced of
|
||||
[] => rules (* No restrictions *)
|
||||
| ((func, rule) :: rest) =>
|
||||
(* Functions in the same set must all use the same lifting rule. *)
|
||||
if map snd rest |> all_list (fn rule' => #name rule = #name rule')
|
||||
then [rule] (* Try the specified rule *)
|
||||
else error ("autocorres: this set of mutually recursive functions " ^
|
||||
"cannot be lifted to different monads: " ^
|
||||
commas_quote (map fst forced))
|
||||
end
|
||||
|
||||
|
||||
(* Lift the given function set, trying each rule until one succeeds. *)
|
||||
fun lift_function rules force_lift filename prog_info l2_infos ts_infos
|
||||
fn_names make_function_name keep_going do_opt lthy =
|
||||
let
|
||||
val rules' = compute_lift_rules rules force_lift fn_names
|
||||
|
||||
(* Find the first lift that works. *)
|
||||
fun first (rule::xs) =
|
||||
(lift_function_rewrite rule filename prog_info l2_infos ts_infos
|
||||
fn_names make_function_name keep_going do_opt lthy
|
||||
handle LiftingFailed _ => first xs)
|
||||
| first [] = raise AllLiftingFailed (map (fn f =>
|
||||
(f, #definition (the (Symtab.lookup l2_infos f)))) fn_names)
|
||||
in
|
||||
first rules'
|
||||
end
|
||||
|
||||
(* Show how many functions were lifted to each monad. *)
|
||||
fun print_statistics results =
|
||||
let
|
||||
fun count_dups x [] = [x]
|
||||
| count_dups (head, count) (next::rest) =
|
||||
if head = next then
|
||||
count_dups (head, count + 1) rest
|
||||
else
|
||||
(head, count) :: (count_dups (next, 1) rest)
|
||||
val tabulated = count_dups ("__fake__", 0) (sort_strings results) |> tl
|
||||
val data = map (fn (a,b) =>
|
||||
(" " ^ a ^ ": " ^ (PolyML.makestring b) ^ "\n")
|
||||
) tabulated
|
||||
|> String.concat
|
||||
in
|
||||
writeln ("Type Strengthening Statistics: \n" ^ data)
|
||||
end
|
||||
|
||||
(* Run through every function, attempting to strengthen its type. *)
|
||||
fun translate
|
||||
(rules : Monad_Types.monad_type list)
|
||||
(force_lift : Monad_Types.monad_type Symtab.table)
|
||||
(filename : string)
|
||||
(prog_info : ProgramInfo.prog_info)
|
||||
(* lazy results from L2 (we currently pre-force them, though) *)
|
||||
(l2_results: (symset * (local_theory * FunctionInfo2.function_info Symtab.table) future) list)
|
||||
(make_function_name : string -> string)
|
||||
(keep_going : bool)
|
||||
(do_opt : bool) =
|
||||
if null l2_results then [] else
|
||||
let
|
||||
(* Wait for previous stage to finish first. *)
|
||||
val lthy = fst (Future.join (snd (hd (rev l2_results))));
|
||||
val l2_results = map (snd #> Future.join #> snd) l2_results;
|
||||
val l2_infos = LocalVarExtract2.symtab_merge false l2_results;
|
||||
|
||||
(* Prettify bound variable names in definitions. *)
|
||||
val l2_infos = Symtab.map (fn f_name => fn info => let
|
||||
val def = #definition (the (Symtab.lookup l2_infos f_name));
|
||||
val pretty_def = PrettyBoundVarNames.pretty_bound_vars_thm
|
||||
lthy (Utils.crhs_of (Thm.cprop_of def)) keep_going
|
||||
|> Thm.transitive def;
|
||||
in FunctionInfo2.function_info_upd_definition pretty_def info end)
|
||||
l2_infos;
|
||||
|
||||
(* For now, just works sequentially like the old TypeStrengthen.
|
||||
TODO: parallelise conversion and (especially) polishing *)
|
||||
fun translate_group fn_names (lthy, _, ts_infos) =
|
||||
let
|
||||
val _ = writeln ("Translating (type strengthen) " ^ Utils.commas fn_names);
|
||||
val start_time = Timer.startRealTimer ();
|
||||
|
||||
val (lthy, new_ts_infos, monad_name) =
|
||||
lift_function rules force_lift filename prog_info l2_infos ts_infos
|
||||
fn_names make_function_name keep_going do_opt lthy;
|
||||
|
||||
val _ = writeln (" --> " ^ monad_name);
|
||||
val _ = @{trace} ("Converted (TS) " ^ Utils.commas fn_names ^ " in " ^
|
||||
Time.toString (Timer.checkRealTimer start_time) ^ " s");
|
||||
in (lthy, new_ts_infos, Symtab.merge (K false) (ts_infos, new_ts_infos)) end;
|
||||
|
||||
val ts_results =
|
||||
Utils.accumulate translate_group (lthy, Symtab.empty, Symtab.empty)
|
||||
(map (map fst o Symtab.dest) l2_results)
|
||||
|> fst
|
||||
|> map (fn (lthy, ts_infos, _) => (lthy, ts_infos));
|
||||
in
|
||||
ts_results
|
||||
end
|
||||
|
||||
end
|
|
@ -397,8 +397,7 @@ fun dest_conj_list (Const (@{const_name "HOL.conj"}, _) $ l $ r)
|
|||
*)
|
||||
fun apply_tac (step : string) tac (thm : thm) =
|
||||
(tac thm |> Seq.hd) handle Option =>
|
||||
error ("Failed to apply tactic during '" ^ step ^ "': " ^ (
|
||||
(PolyML.makestring (cprems_of thm))))
|
||||
raise TERM ("Failed to apply tactic during " ^ quote step, Thm.prems_of thm)
|
||||
|
||||
(*
|
||||
* A "the" operator that explains what is going wrong.
|
||||
|
|
|
@ -0,0 +1,551 @@
|
|||
(*
|
||||
* Copyright 2014, NICTA
|
||||
*
|
||||
* This software may be distributed and modified according to the terms of
|
||||
* the BSD 2-Clause license. Note that NO WARRANTY is provided.
|
||||
* See "LICENSE_BSD2.txt" for details.
|
||||
*
|
||||
* @TAG(NICTA_BSD)
|
||||
*)
|
||||
|
||||
(*
|
||||
* Rewrite L2 specifications to use "nat" and "int" data-types instead of
|
||||
* "word" data types. The former tend to be easier to reason about.
|
||||
*)
|
||||
|
||||
structure WordAbstract2 =
|
||||
struct
|
||||
|
||||
(* Maximum depth that we will go before assuming that we are diverging. *)
|
||||
val WORD_ABS_MAX_DEPTH = 200
|
||||
|
||||
(* Convenience shortcuts. *)
|
||||
val warning = Utils.ac_warning
|
||||
val apply_tac = Utils.apply_tac
|
||||
val the' = Utils.the'
|
||||
|
||||
type WARules = {
|
||||
ctype : typ, atype : typ,
|
||||
abs_fn : term, inv_fn : term,
|
||||
rules : thm list
|
||||
}
|
||||
|
||||
fun mk_word_abs_rule T =
|
||||
{
|
||||
ctype = fastype_of (@{mk_term "x :: (?'W::len) word" ('W)} T),
|
||||
atype = @{typ nat},
|
||||
abs_fn = @{mk_term "unat :: (?'W::len) word => nat" ('W)} T,
|
||||
inv_fn = @{mk_term "of_nat :: nat => (?'W::len) word" ('W)} T,
|
||||
rules = @{thms word_abs_word32}
|
||||
}
|
||||
|
||||
val word_abs : WARules list =
|
||||
map mk_word_abs_rule [@{typ 32}, @{typ 16}, @{typ 8}]
|
||||
|
||||
fun mk_sword_abs_rule T =
|
||||
{
|
||||
ctype = fastype_of (@{mk_term "x :: (?'W::len) signed word" ('W)} T),
|
||||
atype = @{typ int},
|
||||
abs_fn = @{mk_term "sint :: (?'W::len) signed word => int" ('W)} T,
|
||||
inv_fn = @{mk_term "of_int :: int => (?'W::len) signed word" ('W)} T,
|
||||
rules = @{thms word_abs_sword32}
|
||||
}
|
||||
|
||||
val sword_abs : WARules list =
|
||||
map mk_sword_abs_rule [@{typ 32}, @{typ 16}, @{typ 8}]
|
||||
|
||||
(* Get abstract version of a HOL type. *)
|
||||
fun get_abs_type (rules : WARules list) T =
|
||||
Option.getOpt
|
||||
(List.find (fn r => #ctype r = T) rules
|
||||
|> Option.map (fn r => #atype r),
|
||||
T)
|
||||
|
||||
(* Get abstraction function for a HOL type. *)
|
||||
fun get_abs_fn (rules : WARules list) T =
|
||||
Option.getOpt
|
||||
(List.find (fn r => #ctype r = T) rules
|
||||
|> Option.map (fn r => #abs_fn r),
|
||||
@{mk_term "id :: ?'a => ?'a" ('a)} T)
|
||||
|
||||
fun get_abs_inv_fn (rules : WARules list) t =
|
||||
Option.getOpt
|
||||
(List.find (fn r => #ctype r = fastype_of t) rules
|
||||
|> Option.map (fn r => #inv_fn r $ t),
|
||||
t)
|
||||
|
||||
(*
|
||||
* From a list of abstract arguments to a function, derive a list of
|
||||
* concrete arguments and types and a precondition that links the two.
|
||||
*)
|
||||
fun get_wa_conc_args rules l2_infos fn_name fn_args =
|
||||
let
|
||||
(* Construct arguments for the concrete body. We use the abstract names
|
||||
* with a prime ('), but with the concrete types. *)
|
||||
val conc_types = the (Symtab.lookup l2_infos fn_name) |> #args |> map snd
|
||||
val conc_args = map (fn (Free (x, Tc), Ta) => Free (x ^ "'", Ta)) (* FIXME: Free *)
|
||||
(fn_args ~~ conc_types)
|
||||
val arg_pairs = (conc_args ~~ fn_args)
|
||||
|
||||
(* Create preconditions that link the new types to the old types. *)
|
||||
val precond =
|
||||
map (fn (old, new) => @{mk_term "abs_var ?n ?f ?o" (o, f, n)}
|
||||
(old, get_abs_fn rules (fastype_of old), new))
|
||||
arg_pairs
|
||||
|> Utils.mk_conj_list
|
||||
in
|
||||
(conc_types, conc_args, precond, arg_pairs)
|
||||
end
|
||||
|
||||
(* Get the expected type of a function from its name. *)
|
||||
fun get_expected_fn_type rules l2_infos fn_name =
|
||||
let
|
||||
val fn_info = the (Symtab.lookup l2_infos fn_name)
|
||||
val fn_params_typ = map ((get_abs_type rules) o snd) (#args fn_info)
|
||||
val fn_ret_typ = get_abs_type rules (#return_type fn_info)
|
||||
val globals_typ = LocalVarExtract.dest_l2monad_T (fastype_of (#const fn_info)) |> snd |> #1
|
||||
val measure_typ = @{typ "nat"}
|
||||
in
|
||||
(measure_typ :: fn_params_typ)
|
||||
---> LocalVarExtract.mk_l2monadT globals_typ fn_ret_typ @{typ unit}
|
||||
end
|
||||
|
||||
(* Get the expected theorem that will be generated about a function. *)
|
||||
fun get_expected_fn_thm rules l2_infos ctxt fn_name
|
||||
function_free fn_args _ measure_var =
|
||||
let
|
||||
val old_def = the (Symtab.lookup l2_infos fn_name)
|
||||
val (old_arg_types, old_args, precond, arg_pairs)
|
||||
= get_wa_conc_args rules l2_infos fn_name fn_args
|
||||
|
||||
val old_term = betapplys (#const old_def, measure_var :: old_args)
|
||||
val new_term = betapplys (function_free, measure_var :: fn_args)
|
||||
in
|
||||
@{mk_term "Trueprop (corresTA (%x. ?P) ?rt id ?A ?C)" (rt, A, C, P)}
|
||||
(get_abs_fn rules (#return_type old_def), new_term, old_term, precond)
|
||||
|> fold (fn t => fn v => Logic.all t v) (rev (map fst arg_pairs))
|
||||
end
|
||||
|
||||
(* Get arguments passed into the function. *)
|
||||
fun get_expected_fn_args rules l2_infos fn_name =
|
||||
map (apsnd (get_abs_type rules)) (#args (the (Symtab.lookup l2_infos fn_name)))
|
||||
|
||||
(*
|
||||
* Convert a theorem of the form:
|
||||
*
|
||||
* corresTA (%_. abs_var True a f a' \<and> abs_var True b f b' \<and> ...) ...
|
||||
*
|
||||
* into
|
||||
*
|
||||
* [| abstract_val A a f a'; abstract_val B b (f b') |] ==> corresTA (%_. A \<and> B \<and> ...) ...
|
||||
*
|
||||
* the latter of which better suits our resolution approach of proof
|
||||
* construction.
|
||||
*)
|
||||
fun extract_preconds_of_call thm =
|
||||
let
|
||||
fun r thm =
|
||||
r (thm RS @{thm corresTA_extract_preconds_of_call_step})
|
||||
handle THM _ => (thm RS @{thm corresTA_extract_preconds_of_call_final}
|
||||
handle THM _ => thm RS @{thm corresTA_extract_preconds_of_call_final'});
|
||||
in
|
||||
r (thm RS @{thm corresTA_extract_preconds_of_call_init})
|
||||
end
|
||||
|
||||
(* Convert a program by abstracting words. *)
|
||||
fun translate
|
||||
(filename: string)
|
||||
(prog_info: ProgramInfo.prog_info)
|
||||
(* lazy results *)
|
||||
(l2_results: (symset * (local_theory * FunctionInfo2.function_info Symtab.table) future) list)
|
||||
(unsigned_abs: Symset.key Symset.set)
|
||||
(no_signed_abs: Symset.key Symset.set)
|
||||
(trace_funcs: string list)
|
||||
(do_opt: bool)
|
||||
(trace_opt: bool)
|
||||
(wa_function_name: string -> string) =
|
||||
if null l2_results then [] else
|
||||
let
|
||||
(*
|
||||
* Select the rules to translate each function.
|
||||
*)
|
||||
fun rules_for fn_name =
|
||||
(if Symset.contains unsigned_abs fn_name then word_abs else []) @
|
||||
(if Symset.contains no_signed_abs fn_name then [] else sword_abs)
|
||||
|
||||
(* Results for individual functions *)
|
||||
val l2_results' = maps (fn (f_names, r) =>
|
||||
Symset.dest f_names ~~ replicate (Symset.card f_names) r) l2_results;
|
||||
|
||||
val get_l2_result = let
|
||||
val table = Symtab.make l2_results';
|
||||
in fn f => the' ("WA: missing lazy result for function: " ^ f) (Symtab.lookup table f) end;
|
||||
|
||||
(* Abstract functions. *)
|
||||
fun convert f =
|
||||
let
|
||||
(* Info for f and its callees *)
|
||||
val (lthy, f_l2_info) = Future.join (get_l2_result f);
|
||||
val callee_l1_infos = Symset.dest (FunctionInfo2.all_callees (the (Symtab.lookup f_l2_info f)))
|
||||
|> map (fn callee => snd (Future.join (get_l2_result callee)));
|
||||
val l2_infos = LocalVarExtract2.symtab_merge true (f_l2_info :: callee_l1_infos);
|
||||
val old_fn_info = the (Symtab.lookup l2_infos f);
|
||||
|
||||
val wa_rules = rules_for f
|
||||
|
||||
(* Fix measure variable. *)
|
||||
val ([measure_var_name], lthy) = Variable.variant_fixes ["rec_measure'"] lthy;
|
||||
val measure_var = Free (measure_var_name, LocalVarExtract2.measureT);
|
||||
|
||||
(* Fix argument variables. *)
|
||||
val new_fn_args = get_expected_fn_args wa_rules l2_infos f;
|
||||
val (arg_names, lthy) = Variable.variant_fixes (map fst new_fn_args) lthy;
|
||||
val arg_frees = map Free (arg_names ~~ map snd new_fn_args);
|
||||
|
||||
val (lthy, export_thm, callee_terms) =
|
||||
AutoCorresUtil2.assume_called_functions_corres lthy
|
||||
(#callees old_fn_info) (#rec_callees old_fn_info)
|
||||
(fn f => get_expected_fn_type (rules_for f) l2_infos f)
|
||||
(fn lthy => fn f => get_expected_fn_thm (rules_for f) l2_infos lthy f)
|
||||
(fn f => get_expected_fn_args (rules_for f) l2_infos f)
|
||||
wa_function_name
|
||||
measure_var;
|
||||
|
||||
(* Construct free variables to represent our concrete arguments. *)
|
||||
val (conc_types, conc_args, precond, arg_pairs)
|
||||
= get_wa_conc_args wa_rules l2_infos f arg_frees
|
||||
|
||||
(* Fetch the function definition, and instantiate its arguments. *)
|
||||
val old_body_def =
|
||||
#definition old_fn_info
|
||||
(* Instantiate the arguments. *)
|
||||
|> Utils.inst_args lthy (map (Thm.cterm_of lthy) (measure_var :: conc_args))
|
||||
|
||||
(* Get old body definition with function arguments. *)
|
||||
val old_term = betapplys (#const old_fn_info, measure_var :: conc_args)
|
||||
|
||||
(* Get a schematic variable accepting new arguments. *)
|
||||
val new_var = betapplys (
|
||||
Var (("A", 0), get_expected_fn_type wa_rules l2_infos f), measure_var :: arg_frees)
|
||||
|
||||
(* Fetch monotonicity theorems of callees. *)
|
||||
val callee_mono_thms =
|
||||
Symset.dest (FunctionInfo2.all_callees old_fn_info)
|
||||
|> List.mapPartial (fn callee =>
|
||||
if FunctionInfo2.is_function_recursive (the (Symtab.lookup l2_infos callee))
|
||||
then the (Symtab.lookup l2_infos callee) |> #mono_thm
|
||||
else NONE)
|
||||
|
||||
(*
|
||||
* Generate a schematic goal.
|
||||
*
|
||||
* We only want ?A to depend on abstracted variables and ?C to depend on
|
||||
* concrete variables. We force this by applying bound variables to each
|
||||
* of the schematics, giving us something like:
|
||||
*
|
||||
* !!a a' b b'. corresTA ... (?A a b) (?C a' b')
|
||||
*
|
||||
* The abstract side will hence be prevented from capturing (i.e., using)
|
||||
* concrete variables, and vice-versa.
|
||||
*)
|
||||
val goal = @{mk_term "Trueprop (corresTA (%x. ?precond) ?ra id ?A ?C)" (ra, A, C, precond)}
|
||||
(get_abs_fn wa_rules (#return_type old_fn_info), new_var, old_term, precond)
|
||||
|> fold (fn t => fn v => Logic.all t v) (rev (arg_frees @ map fst arg_pairs))
|
||||
|> Thm.cterm_of lthy
|
||||
|> Goal.init
|
||||
|> Utils.apply_tac "move precond to assumption" (resolve_tac lthy @{thms corresTA_precond_to_asm} 1)
|
||||
|> Utils.apply_tac "split precond" (REPEAT (CHANGED (eresolve_tac lthy @{thms conjE} 1)))
|
||||
|> Utils.apply_tac "create schematic precond" (resolve_tac lthy @{thms corresTA_precond_to_guard} 1)
|
||||
|> Utils.apply_tac "unfold RHS" (CHANGED (Utils.unfold_once_tac lthy (Utils.abs_def lthy old_body_def) 1))
|
||||
|
||||
(*
|
||||
* Fetch rules from the theory.
|
||||
*)
|
||||
val rules = Utils.get_rules lthy @{named_theorems word_abs}
|
||||
@ List.concat (map #rules wa_rules)
|
||||
@ @{thms word_abs_default}
|
||||
val fo_rules = [@{thm abstract_val_fun_app}]
|
||||
|
||||
|
||||
val rules = rules @ (map (snd #> #3 #> extract_preconds_of_call) callee_terms)
|
||||
@ callee_mono_thms
|
||||
|
||||
(* Standard tactics. *)
|
||||
fun my_rtac ctxt thm n =
|
||||
Utils.trace_if_success ctxt thm (
|
||||
DETERM (EVERY' (resolve_tac ctxt [thm] :: replicate (Rule_Cases.get_consumes thm) (assume_tac ctxt)) n))
|
||||
|
||||
(* Apply a conversion to the abstract/concrete side of the given "abstract_val" term. *)
|
||||
fun wa_conc_body_conv conv =
|
||||
Conv.params_conv (~1) (K (Conv.concl_conv (~1) ((Conv.arg_conv (Utils.nth_arg_conv 4 conv)))))
|
||||
|
||||
(* Tactics and conversions for converting goals into first-order format. *)
|
||||
fun to_fo_tac ctxt =
|
||||
CONVERSION (Drule.beta_eta_conversion then_conv wa_conc_body_conv (HeapLift.mk_first_order ctxt) ctxt)
|
||||
fun from_fo_tac ctxt =
|
||||
CONVERSION (wa_conc_body_conv (HeapLift.dest_first_order ctxt then_conv Drule.beta_eta_conversion) ctxt)
|
||||
fun make_fo_tac tac ctxt = ((to_fo_tac ctxt THEN' tac) THEN_ALL_NEW from_fo_tac ctxt)
|
||||
|
||||
|
||||
(*
|
||||
* Recursively solve subgoals.
|
||||
*
|
||||
* We allow backtracking in order to solve a particular subgoal, but once a
|
||||
* subgoal is completed we don't ever try to solve it in a different way.
|
||||
*
|
||||
* This allows us to try different approaches to solving subgoals,
|
||||
* hopefully reducing exponential explosion (of many different combinations
|
||||
* of "good solutions") once we hit an unsolvable subgoal.
|
||||
*)
|
||||
fun SOLVE_ALL _ _ 0 thm =
|
||||
raise THM ("Word abstraction diverging", 0, [thm])
|
||||
| SOLVE_ALL ctxt tacs depth thm =
|
||||
let
|
||||
fun TRY_ALL [] = no_tac
|
||||
| TRY_ALL (x::xs) =
|
||||
(x ctxt THEN REPEAT (SELECT_GOAL (SOLVE_ALL ctxt tacs (depth - 1)) 1))
|
||||
APPEND (TRY_ALL xs)
|
||||
in
|
||||
if Thm.nprems_of thm > 0 then
|
||||
DETERM (SOLVES (TRY_ALL tacs)) thm
|
||||
else
|
||||
all_tac thm
|
||||
end
|
||||
|
||||
(*
|
||||
* Eliminate a lambda term in the concrete state, but only if the
|
||||
* lambda is "real".
|
||||
*
|
||||
* That is, we don't attempt to eta-expand in order to apply the theorem
|
||||
* "abstract_val_lambda", because that may lead to an infinite loop with
|
||||
* "abstract_val_fun_app".
|
||||
*)
|
||||
fun lambda_tac n thm =
|
||||
case Logic.concl_of_goal (Thm.prop_of thm) n of
|
||||
(Const (@{const_name "Trueprop"}, _) $
|
||||
(Const (@{const_name "abstract_val"}, _) $ _ $ _ $ _ $ (
|
||||
Abs (_, _, _)))) =>
|
||||
resolve_tac lthy @{thms abstract_val_lambda} n thm
|
||||
| _ => no_tac thm
|
||||
|
||||
(* All tactics we try, in the order we should try them. *)
|
||||
val step_tacs =
|
||||
[(@{thm imp_refl}, assume_tac lthy 1)]
|
||||
@ (map (fn thm => (thm, my_rtac lthy thm 1)) rules)
|
||||
@ (map (fn thm => (thm, make_fo_tac (my_rtac lthy thm) lthy 1)) fo_rules)
|
||||
@ [(@{thm abstract_val_lambda}, lambda_tac 1)]
|
||||
@ [(@{thm reflexive},
|
||||
fn thm =>
|
||||
(if Config.get lthy ML_Options.exception_trace then
|
||||
warning ("Could not solve subgoal: " ^
|
||||
(Goal_Display.string_of_goal lthy thm))
|
||||
else (); no_tac thm))]
|
||||
|
||||
(* Solve the goal. *)
|
||||
val replay_failure_start = 1
|
||||
val replay_failures = Unsynchronized.ref replay_failure_start
|
||||
val (thm, trace) =
|
||||
case AutoCorresTrace.maybe_trace_solve_tac lthy (member (op =) trace_funcs f) true false
|
||||
(K step_tacs) goal (SOME WORD_ABS_MAX_DEPTH) replay_failures of
|
||||
NONE => (* intentionally generate a TRACE_SOLVE_TAC_FAIL *)
|
||||
(AutoCorresTrace.trace_solve_tac lthy false false (K step_tacs) goal NONE (Unsynchronized.ref 0);
|
||||
(* never reached *) error "word_abstract fail tac: impossible")
|
||||
| SOME (thm, [trace]) => (Goal.finish lthy thm, trace)
|
||||
val _ = if !replay_failures < replay_failure_start then
|
||||
@{trace} (f ^ " WA: reverted to slow replay " ^
|
||||
Int.toString(replay_failure_start - !replay_failures) ^ " time(s)") else ()
|
||||
|
||||
(* Clean out any final function application ($) constants or "id" constants
|
||||
* generated by some rules. *)
|
||||
fun corresTA_abs_conv conv =
|
||||
Utils.remove_meta_conv (fn ctxt => Utils.nth_arg_conv 4 (conv ctxt)) lthy
|
||||
|
||||
val thm =
|
||||
Conv.fconv_rule (
|
||||
corresTA_abs_conv (fn ctxt =>
|
||||
(HeapLift.dest_first_order ctxt)
|
||||
then_conv (Simplifier.rewrite (
|
||||
put_simpset HOL_basic_ss ctxt addsimps [@{thm id_def}]))
|
||||
then_conv Drule.beta_eta_conversion
|
||||
)
|
||||
) thm
|
||||
|
||||
(* Ensure no schematics remain in the goal. *)
|
||||
val _ = Sign.no_vars lthy (Thm.prop_of thm)
|
||||
|
||||
(* Gather statistics. *)
|
||||
val _ = Statistics.gather lthy "WA" f
|
||||
(Variable.gen_all lthy thm
|
||||
|> Thm.prop_of
|
||||
|> HOLogic.dest_Trueprop
|
||||
|> (fn t => Utils.term_nth_arg t 3))
|
||||
|
||||
(*
|
||||
* Instantiate abstract function's meta-forall variables with their actual values.
|
||||
*
|
||||
* That is, we go from:
|
||||
*
|
||||
* !!a b c a' b' c'. corresTA (P a b c) ...
|
||||
*
|
||||
* to
|
||||
*
|
||||
* !!a' b' c'. corresTA (P a b c) ...
|
||||
*)
|
||||
val thm = Drule.forall_elim_list (map (Thm.cterm_of lthy) arg_frees) thm
|
||||
|
||||
(* Apply peephole optimisations to the theorem. *)
|
||||
val _ = writeln ("Simpifying (WA) " ^ f)
|
||||
val (thm, opt_traces) = L2Opt.cleanup_thm_tagged lthy thm (if do_opt then 0 else 2) 4 trace_opt "WA"
|
||||
|
||||
(* We end up with an unwanted L2_guard outside the L2_recguard.
|
||||
* L2Opt should simplify the condition to (%_. True) even if (not do_opt),
|
||||
* so we match the guard and get rid of it here. *)
|
||||
val thm = Simplifier.rewrite_rule lthy @{thms corresTA_simp_trivial_guard} thm
|
||||
|
||||
(* Gather post-optimisation statistics. *)
|
||||
val _ = Statistics.gather lthy "WAsimp" f
|
||||
(Variable.gen_all lthy thm
|
||||
|> Thm.prop_of
|
||||
|> HOLogic.dest_Trueprop
|
||||
|> (fn t => Utils.term_nth_arg t 3))
|
||||
|
||||
(* Extract the abstract term out of a corresTA thm. *)
|
||||
fun dest_corresWA_term_abs @{term_pat "corresTA _ _ _ ?t _"} = t
|
||||
fun get_body_of_thm thm =
|
||||
Thm.concl_of (Variable.gen_all lthy thm)
|
||||
|> HOLogic.dest_Trueprop
|
||||
|> dest_corresWA_term_abs
|
||||
|
||||
val _ = @{trace} ("WA conv", f, thm)
|
||||
val thm' = Morphism.thm export_thm (Variable.gen_all lthy thm)
|
||||
val _ = @{trace} ("WA conv final", f, thm')
|
||||
in
|
||||
(get_body_of_thm thm, thm', dest_Free measure_var :: new_fn_args,
|
||||
(if member (op =) trace_funcs f then [("WA", AutoCorresData.RuleTrace trace)] else []) @ opt_traces)
|
||||
end
|
||||
|
||||
(* Define a previously-converted function (or recursive function group).
|
||||
* lthy must include all definitions from l2_callees. *)
|
||||
fun define
|
||||
(lthy: local_theory)
|
||||
(l2_infos: FunctionInfo2.function_info Symtab.table)
|
||||
(wa_callees: FunctionInfo2.function_info Symtab.table)
|
||||
(funcs: (string * thm * (string * typ) list) list) (* name, corres, arg frees *)
|
||||
: FunctionInfo2.function_info Symtab.table * local_theory = let
|
||||
(* FIXME: dedup with convert, LocalVarExtract *)
|
||||
|
||||
(* FIXME: pass this from assume_called_functions_corres, etc. *)
|
||||
fun guess_callee_var thm callee = let
|
||||
val base_name = wa_function_name callee;
|
||||
val mentioned_vars = Term.add_vars (Thm.prop_of thm) [];
|
||||
in hd (filter (fn ((v, _), _) => v = base_name) mentioned_vars) end;
|
||||
|
||||
fun prepare_fn_body (fn_name, corres_thm, arg_frees) = let
|
||||
val _ = @{trace} ("WA prepare_fn_body", fn_name, corres_thm);
|
||||
val @{term_pat "Trueprop (corresTA _ _ _ ?body _)"} = Thm.concl_of corres_thm;
|
||||
val (callees, recursive_callees) = AutoCorresUtil2.get_callees l2_infos fn_name;
|
||||
val calls = map (fn c => Var (guess_callee_var corres_thm c)) callees;
|
||||
val recursive_calls = map (fn c => Var (guess_callee_var corres_thm c)) recursive_callees;
|
||||
|
||||
(*
|
||||
* The returned body will have free variables as placeholders for the function's
|
||||
* measure parameter and other arguments, and schematic variables for the functions it calls.
|
||||
*
|
||||
* We modify the body to be of the form:
|
||||
*
|
||||
* %fun1 fun2 rec1 rec2 measure arg1 arg2. f <...>
|
||||
*
|
||||
* That is, all non-recursive calls are abstracted out the front, followed by
|
||||
* recursive calls, followed by the measure variable, followed by function
|
||||
* arguments. This is the format expected by define_funcs.
|
||||
*)
|
||||
val abs_body = body
|
||||
|> fold lambda (rev (map Free arg_frees))
|
||||
|> fold lambda (rev recursive_calls)
|
||||
|> fold lambda (rev calls);
|
||||
in abs_body end;
|
||||
|
||||
val funcs' = funcs |>
|
||||
map (fn (name, thm, frees) =>
|
||||
(name, (* FIXME: define_funcs needs this currently *)
|
||||
(Thm.generalize ([], map fst frees) (Thm.maxidx_of thm + 1) thm,
|
||||
prepare_fn_body (name, thm, frees))));
|
||||
val _ = @{trace} ("WA.define", map (fn (name, (thm, body)) => (name, thm, Thm.cterm_of lthy body)) funcs');
|
||||
val (new_thms, (), lthy') =
|
||||
AutoCorresUtil2.define_funcs FunctionInfo2.WA filename l2_infos
|
||||
wa_function_name
|
||||
(fn f => get_expected_fn_type (rules_for f) l2_infos f)
|
||||
(fn lthy => fn f => get_expected_fn_thm (rules_for f) l2_infos lthy f)
|
||||
(fn f => get_expected_fn_args (rules_for f) l2_infos f)
|
||||
@{thm corresTA_recguard_0}
|
||||
lthy (Symtab.map (K #corres_thm) wa_callees) ()
|
||||
funcs';
|
||||
val new_infos = Symtab.map (fn f_name => fn (const, def, corres_thm) => let
|
||||
val old_info = the (Symtab.lookup l2_infos f_name);
|
||||
in old_info
|
||||
|> FunctionInfo2.function_info_upd_phase FunctionInfo2.WA
|
||||
|> FunctionInfo2.function_info_upd_const const
|
||||
|> FunctionInfo2.function_info_upd_definition def
|
||||
|> FunctionInfo2.function_info_upd_corres_thm corres_thm
|
||||
|> FunctionInfo2.function_info_upd_return_type
|
||||
(get_abs_type (rules_for f_name) (#return_type old_info))
|
||||
|> FunctionInfo2.function_info_upd_args
|
||||
(map (apsnd (get_abs_type (rules_for f_name))) (#args old_info))
|
||||
|> FunctionInfo2.function_info_upd_mono_thm NONE (* added later *)
|
||||
end) new_thms;
|
||||
in (new_infos, lthy') end;
|
||||
|
||||
|
||||
val function_groups = map fst l2_results;
|
||||
|
||||
(* All conversions can run in parallel.
|
||||
* Each conversion depends only on the previous define phase
|
||||
* (which necessarily also includes definitions of callees). *)
|
||||
val converted_funcs =
|
||||
maps Symset.dest function_groups
|
||||
|> map (fn f => (f, Future.fork (fn _ => convert f)))
|
||||
|> Symtab.make;
|
||||
|
||||
val _ = Symtab.dest converted_funcs |> map snd |> map Future.join
|
||||
|
||||
(* Definitions update lthy sequentially.
|
||||
* We use the arbitrary (but deterministic) ordering defined by get_topo_sorted_functions.
|
||||
* Each definition step produces a lthy that has a prefix of the update sequence,
|
||||
* and can be used subsequently to convert a function that depends only on that prefix.
|
||||
* Hence we produce the intermediate lthys lazily for maximum parallelism. *)
|
||||
fun add_def f_names accum = Future.fork (fn _ => let
|
||||
(* Wait for previous definition to finish *)
|
||||
val (lthy, _, defined_so_far) = Future.join accum;
|
||||
(* Wait for conversions to finish *)
|
||||
val f_convs = map (fn f => let
|
||||
val conv = the' ("didn't convert function: " ^ quote f ^ "??") (Symtab.lookup converted_funcs f);
|
||||
val (wa_body, corres_thm, arg_frees, traces) = Future.join conv;
|
||||
in (f, corres_thm, arg_frees) end) f_names;
|
||||
(* Get L2 phase results (should be a no-op at this point) *)
|
||||
val (_, l2_infos) = Future.join (get_l2_result (hd f_names));
|
||||
val (new_defs, lthy') = define lthy l2_infos defined_so_far f_convs;
|
||||
in (lthy', new_defs, Symtab.merge (K false) (defined_so_far, new_defs)) end);
|
||||
|
||||
(* Get initial lthy from end of L2 defs *)
|
||||
val l2_def_context = Future.map (fn (lthy, _) => (lthy, Symtab.empty, Symtab.empty))
|
||||
(snd (hd (rev l2_results)));
|
||||
(* Chain of intermediate states: (lthy, new_defs, accumulator) *)
|
||||
val (def_results, _) = Utils.accumulate add_def l2_def_context
|
||||
(map Symset.dest function_groups);
|
||||
|
||||
(* Produce a mapping from each function group to its L1 phase_infos and the
|
||||
* earliest intermediate lthy where it is defined. *)
|
||||
val results =
|
||||
def_results
|
||||
|> map (Future.map (fn (lthy, f_defs, _ (* discard accum here *)) => let
|
||||
(* Add monad_mono proofs. These are done in parallel as well
|
||||
* (though in practice, they already run pretty quickly). *)
|
||||
val mono_thms = if FunctionInfo2.is_function_recursive (snd (hd (Symtab.dest f_defs)))
|
||||
then LocalVarExtract2.l2_monad_mono lthy f_defs
|
||||
else Symtab.empty;
|
||||
val f_defs' = f_defs |> Symtab.map (fn f =>
|
||||
FunctionInfo2.function_info_upd_mono_thm (Symtab.lookup mono_thms f));
|
||||
in (lthy, f_defs') end));
|
||||
in function_groups ~~ results end
|
||||
|
||||
end
|
Loading…
Reference in New Issue