WIP: autocorres: crude VER-517 prototypes for WA and TS phases

HL is still pending; the new code also needs to be refactored itself.
This commit is contained in:
Japheth Lim 2016-06-16 13:04:20 +10:00
parent 84cb9deaf8
commit 08c3475a09
10 changed files with 1231 additions and 98 deletions

View File

@ -145,11 +145,13 @@ ML_file "monad_convert.ML"
ML_file "type_strengthen.ML"
ML_file "autocorres.ML"
declare [[ML_print_depth=42]]
declare [[ML_print_depth=27]]
ML_file "function_info2.ML"
ML_file "autocorres_util2.ML"
ML_file "simpl_conv2.ML"
ML_file "local_var_extract2.ML"
ML_file "word_abstract2.ML"
ML_file "type_strengthen2.ML"
(* Setup "autocorres" keyword. *)
ML {*
@ -159,82 +161,4 @@ ML {*
(Toplevel.theory o (fn (opt, filename) => AutoCorres.do_autocorres opt filename)))
*}
ML \<open>
fun assume_called_functions_corres ctxt fn_info callees
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var =
let
(* Assume the existence of a function, along with a theorem about its
* behaviour. *)
fun assume_func ctxt fn_name is_recursive_call =
let
val fn_args = get_fn_args fn_name
(* Fix a variable for the function. *)
val ([fixed_fn_name], ctxt') = Variable.variant_fixes [get_const_name fn_name] ctxt
val fn_free = Free (fixed_fn_name, get_fn_type fn_name)
(* Fix a variable for the measure and function arguments. *)
val (measure_var_name :: arg_names, ctxt'')
= Variable.variant_fixes ("rec_measure'" :: (map fst fn_args)) ctxt'
val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args))
val my_measure_var = Free (measure_var_name, @{typ nat})
(*
* A measure variable is needed to handle recursion: for recursive calls,
* we need to decrement the caller's input measure value (and our
* assumption will need to assume this to). This is so we can later prove
* termination of our function definition: the measure always reaches zero.
*
* Non-recursive calls can have a fresh value.
*)
val measure_var =
if is_recursive_call then
@{const "recguard_dec"} $ callers_measure_var
else
my_measure_var
(* Create our assumption. *)
val assumption =
get_fn_assumption ctxt'' fn_name fn_free fn_arg_terms
is_recursive_call measure_var
|> fold Logic.all (rev ((if is_recursive_call then [] else [my_measure_var]) @ fn_arg_terms))
|> Sign.no_vars ctxt'
|> Thm.cterm_of ctxt'
val ([thm], ctxt''') = Assumption.add_assumes [assumption] ctxt'
(* Generate a morphism for escaping this context. *)
val m = (Assumption.export_morphism ctxt''' ctxt')
$> (Variable.export_morphism ctxt' ctxt)
in
(fn_free, thm, ctxt''', m)
end
(* Apply each assumption. *)
val (res, (ctxt', m)) = fold_map (
fn (fn_name, is_recursive_call) =>
fn (ctxt, m) =>
let
val (free, thm, ctxt', m') =
assume_func ctxt fn_name is_recursive_call
in
((fn_name, (is_recursive_call, free, thm)), (ctxt', m' $> m))
end)
callees (ctxt, Morphism.identity)
in
(ctxt', m, res)
end;
(*
fun assume_called_functions_corres ctxt fn_info callees
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var
*)
assume_called_functions_corres @{context} () [("a", false), ("r", true)]
(K @{typ "nat \<Rightarrow> nat \<Rightarrow> string \<Rightarrow> string"})
(fn ctxt => fn name => fn term => fn args => fn is_rec => fn meas =>
HOLogic.mk_Trueprop (@{term "my_corres :: (string \<Rightarrow> string) \<Rightarrow> bool"} $ betapplys (term, meas :: args)))
(fn f => if f = "a" then [("arg_a", @{typ nat})] else [("arg_r", @{typ nat})])
I @{term "rec_measure :: nat"}
\<close>
end

View File

@ -1584,6 +1584,7 @@ fun define
fun prepare_fn_body (fn_name, corres_thm, arg_frees) = let
val _ = @{trace} ("prepare_fn_body", fn_name, corres_thm);
(* FIXME: move this to convert *)
val @{term_pat "Trueprop (L2corres _ _ _ _ ?body _)"} = Thm.concl_of corres_thm;
val (callees, recursive_callees) = AutoCorresUtil2.get_callees l1_infos fn_name;
val calls = map (fn c => Var (guess_callee_var corres_thm c)) callees;
@ -1622,7 +1623,6 @@ fun define
@{thm L2corres_recguard_0}
lthy (Symtab.map (K #corres_thm) l2_callees) ()
funcs';
val f_names = map (fn (name, _, _) => name) funcs;
val new_infos = Symtab.map (fn f_name => fn (const, def, corres_thm) => let
val old_info = the (Symtab.lookup l1_infos f_name);
in old_info
@ -1631,11 +1631,15 @@ fun define
|> FunctionInfo2.function_info_upd_definition def
|> FunctionInfo2.function_info_upd_corres_thm corres_thm
|> FunctionInfo2.function_info_upd_mono_thm NONE (* added later *)
(* Update arg names to match our newly converted functions *)
|> FunctionInfo2.function_info_upd_args
(map (apfst (to_free_var_name lthy' o ProgramInfo.demangle_name)) (#args old_info))
end) new_thms;
(* FIXME: return traces *)
in (new_infos, lthy') end;
(* FIXME: move *)
fun symtab_merge allow_dups tabs =
maps Symtab.dest tabs
|> (if allow_dups then sort_distinct (fast_string_ord o apply2 fst) else I)
@ -1648,7 +1652,7 @@ fun translate filename prog_info
(* lazy results from L1 *)
(l1_results: (symset * (local_theory * FunctionInfo2.function_info Symtab.table) future) list)
do_opt trace_opt l2_function_name =
(* if there's nothing to translate, we won't have a lthy *)
(* if there's nothing to translate, we won't have a lthy to use *)
if null l1_results then [] else
let
(* TODO: we should recalculate this from l1_results to take dead-code elim

View File

@ -20,6 +20,10 @@ begin
(* Parse the input file. *)
install_C_file "factorial.c"
(*
autocorres [scope_depth=0, scope=factorial] "factorial.c"
autocorres [scope_depth=0, scope=call_factorial] "factorial.c"
*)
autocorres "factorial.c"
context factorial begin

View File

@ -19,6 +19,10 @@ install_C_file "simple.c"
(* Abstract the input file. *)
autocorres [ ts_force pure = max, ts_force nondet = gcd, unsigned_word_abs = gcd ] "simple.c"
(*
autocorres [ ts_force nondet = gcd, unsigned_word_abs = gcd, scope_depth=0, scope=gcd ] "simple.c"
autocorres [ ts_force pure = max, scope_depth=0, scope=max ] "simple.c"
*)
(* Generated theorems and proofs. *)
thm simple.max'_def simple.max'_ac_corres

View File

@ -23,4 +23,3 @@ unsigned gcd(unsigned a, unsigned b) {
}
return b;
}

View File

@ -49,36 +49,43 @@ let val simpl_infos = FunctionInfo2.init_function_info @{context} "type_strength
in (frees1, corres1, Symtab.dest l1_infos2) end
\<close>
ML \<open>
local_setup \<open>
fn lthy =>
let val filename = "type_strengthen.c";
val simpl_info = FunctionInfo2.init_function_info @{context} filename;
val prog_info = ProgramInfo.get_prog_info @{context} filename;
val simpl_info = FunctionInfo2.init_function_info lthy filename;
val prog_info = ProgramInfo.get_prog_info lthy filename;
val l1_results =
SimplConv2.translate filename prog_info simpl_info
true true false (fn f => "l1_" ^ f ^ "'") @{context};
true true false (fn f => "l1_" ^ f ^ "'") lthy;
val l2_results =
LocalVarExtract2.translate filename prog_info l1_results
true false (fn f => "l2_" ^ f ^ "'");
in l2_results |> map (snd #> Future.join) |> map (snd #> Symtab.dest) end
val wa_results =
WordAbstract2.translate filename prog_info l2_results Symset.empty Symset.empty []
true false (fn f => "wa_" ^ f ^ "'");
val ts_rules = Monad_Types.get_ordered_rules [] (Context.Proof lthy);
val ts_results =
TypeStrengthen2.translate ts_rules Symtab.empty filename prog_info
wa_results (fn f => f ^ "'") true true;
in ts_results |> rev |> hd |> fst end
\<close>
ML \<open>
FunctionInfo2.init_function_info @{context} "type_strengthen.c"
|> Symtab.dest
\<close>
end
declare [[ML_print_depth=99]]
autocorres [
ts_rules = nondet,
scope = st_i,
skip_heap_abs, skip_word_abs
] "type_strengthen.c"
(* We can also specify which monads are used for type strengthening.
Here, we exclude the read-only monad completely, and specify
rules for some individual functions. *)
autocorres [
ts_rules = pure option nondet,
ts_force option = pure_f,
statistics
skip_heap_abs
] "type_strengthen.c"
context type_strengthen begin

View File

@ -386,6 +386,7 @@ let
#L2_simp_rules rule_set' @
@{thms gets_bind_ign L2_call_fail HOL.simp_thms}
val _ = @{trace} ("ts proof rules", map fst functions, defs ~~ fn_defs)
val rewrite_thm =
Goal.init goal
|> (fn goal_state => if is_recursive then

View File

@ -0,0 +1,640 @@
(*
* Copyright 2014, NICTA
*
* This software may be distributed and modified according to the terms of
* the BSD 2-Clause license. Note that NO WARRANTY is provided.
* See "LICENSE_BSD2.txt" for details.
*
* @TAG(NICTA_BSD)
*)
(*
* Lift monadic structures into lighter-weight monads.
*)
structure TypeStrengthen2 =
struct
exception AllLiftingFailed of (string * thm) list
exception LiftingFailed of unit
(* FIXME: use AUTOCORRES_SIMPSET (need to fix unknown deps of the corres prover) *)
val ts_simpset = simpset_of @{context}
(* Misc util functions. *)
val the' = Utils.the'
val apply_tac = Utils.apply_tac
fun get_l2_state_typ ctxt prog_info l2_infos fn_name =
let
val term = #const (the (Symtab.lookup l2_infos fn_name));
in
LocalVarExtract.dest_l2monad_T (fastype_of term) |> snd |> #1
end;
fun get_typ_from_L2 (rule_set : Monad_Types.monad_type) L2_typ =
LocalVarExtract.dest_l2monad_T L2_typ |> snd |> #typ_from_L2 rule_set
(*
* Make an equality prop of the form "L2_call <foo> = <liftE> $ <bar>".
*
* L2_call and <liftE> will typically be desired to be polymorphic in their
* exception type. We fix it to "unit"; the caller will need to introduce
* polymorphism as necessary.
*
* If "measure" is non-NONE, then that term will be used instead of a free
* variable.
* If "state_typ" is non-NONE, then "measure" is assumed to also take a
* state parameter of the given type.
*)
fun make_lift_equality ctxt prog_info l2_infos fn_name
(rule_set : Monad_Types.monad_type) state_typ measure rhs_term =
let
val thy = Proof_Context.theory_of ctxt
(* Fetch function variables. *)
val fn_def = the (Symtab.lookup l2_infos fn_name)
val inputs = #args fn_def
val input_vars = map (fn (x, y) => Var ((x, 0), y)) inputs
(*
* Measure var.
*
* The L2 left-hand-side will always use this measure, while the
* right-hand-side will only include a measure if the function is actually
* recursive.
*)
val is_recursive = FunctionInfo2.is_function_recursive fn_def
val default_measure_var = @{term "rec_measure' :: nat"}
val measure_term = Option.getOpt (measure, default_measure_var)
(*
* Construct the equality.
*
* This is a little delicate: in particular, we need to ensure that
* the type of the resulting term is strictly correct. In particular,
* our "lift_fn" will have type variables that need to be modified.
* "Utils.mk_term" will fill in type variables of the base term based
* on what is applied to it. So, we need to ensure that the lift
* function is in our base term, and that its type variables have the
* same names ("'s" for state, "'a" for return type) that we use
* below.
*)
(* FIXME: use @{mk_term} *)
val base_term = @{term "%L a b. (L2_call :: ('s, 'a, unit) L2_monad => ('s, 'a, unit) L2_monad) a = L b"}
val a = betapplys (#const fn_def, measure_term :: input_vars)
val b = if not is_recursive then betapplys (rhs_term, input_vars) else
case state_typ of
NONE => betapplys (rhs_term, measure_term :: input_vars)
| SOME s => betapplys (rhs_term, [measure_term] @ input_vars @ [Free ("s'", s)]) (* FIXME: Free *)
|> lambda (Free ("s'", s))
val term = Utils.mk_term thy (betapply (base_term, #L2_call_lift rule_set)) [a, b]
|> HOLogic.mk_Trueprop
in
(* Convert it into a trueprop with meta-foralls. *)
Utils.vars_to_metaforall term
end
(*
* Assume recursively called functions correctly map into the given type.
*
* We return:
*
* (<newly generated context>,
* <the measure variable used>,
* <generated assumptions>,
* <table mapping free term names back to their function names>,
* <morphism to escape the context>)
*)
fun assume_rec_lifted ctxt prog_info l2_infos rule_set fn_name =
let
val thy = Proof_Context.theory_of ctxt
val fn_def = the (Symtab.lookup l2_infos fn_name)
(* Find recursive calls. *)
val recursive_calls = Symset.dest (#rec_callees fn_def)
(* Fix a variable for each such call, plus another for our measure variable. *)
val (measure_fix :: dest_fn_fixes, ctxt')
= Variable.add_fixes ("rec_measure'" :: map (fn x => "rec'" ^ x) recursive_calls) ctxt
val rec_fun_names = Symtab.make (dest_fn_fixes ~~ recursive_calls)
(* For recursive calls, we need a term representing our measure variable and
* another representing our decremented measure variable. *)
val measure_var = Free (measure_fix, @{typ nat})
val dec_measure_var = @{const "recguard_dec"} $ measure_var
(* For each recursive call, generate a theorem assuming that it lifts into
* the type/monad of "rule_set". *)
val dest_fn_thms = map (fn (callee, var) =>
let
val fn_def' = the (Symtab.lookup l2_infos callee)
val inputs = #args fn_def'
val T = (@{typ nat} :: map snd inputs)
---> (fastype_of (#const fn_def') |> get_typ_from_L2 rule_set)
(* NB: pure functions would not use state, but recursive functions cannot
* be lifted to pure (because we trigger failure when the measure hits
* 0). So we can always assume there is state. *)
val state_typ = get_l2_state_typ ctxt prog_info l2_infos fn_name
val x = make_lift_equality ctxt prog_info l2_infos callee rule_set
(SOME state_typ) (SOME dec_measure_var) (Free (var, T))
in
Thm.cterm_of ctxt' x
end) (recursive_calls ~~ dest_fn_fixes)
(* Our measure does not type-check for pure functions,
causing a TERM exception from mk_term. *)
handle TERM _ => raise LiftingFailed ()
(* Assume the theorems we just generated. *)
val (thms, ctxt'') = Assumption.add_assumes dest_fn_thms ctxt'
val thms = map (fn t => (#polymorphic_thm rule_set) OF [t]) thms
in
(ctxt'',
measure_var,
thms,
rec_fun_names,
Assumption.export_morphism ctxt'' ctxt'
$> Variable.export_morphism ctxt' ctxt)
end
(*
* Given a function definition, attempt to lift it into a different
* monadic structure by applying a set of rewrite rules.
*
* For example, given:
*
* foo x y = doE
* a <- returnOk 3;
* b <- returnOk 5;
* returnOk (a + b)
* odE
*
* we may be able to lift to:
*
* foo x y = returnOk (let
* a = 3;
* b = 5;
* in
* a + b)
*
* This second function has the form "lift $ term" for some lifting function
* "lift" and some new term "term". (These would be "returnOk" and "let a = ...
* in a + b" in the example above, respectively.)
*
* We return a theorem of the form "foo x y == <lift> $ <term>", along with the
* new term "<term>". If the lift was unsuccessful, we return "NONE".
*)
fun perform_lift ctxt prog_info l2_infos rule_set fn_name =
let
(* Assume recursive calls can be successfully lifted into this type. *)
val (ctxt', measure_var, thms, dict, m)
= assume_rec_lifted ctxt prog_info l2_infos rule_set fn_name
val fn_def = #definition (the (Symtab.lookup l2_infos fn_name))
(* Extract the term from our function definition. *)
val fn_thm' = Utils.named_cterm_instantiate ctxt
[("rec_measure'", Thm.cterm_of ctxt' measure_var)] fn_def
handle THM _ => fn_def
val ct = Thm.prop_of fn_thm' |> Utils.rhs_of |> Thm.cterm_of ctxt'
(* Rewrite the term using the given rewrite rules. *)
val t = Thm.term_of ct
val thm = case Monad_Convert.monad_rewrite ctxt rule_set thms true t of
SOME t => t
| NONE => Utils.named_cterm_instantiate ctxt [("x", ct)] @{thm reflexive}
(* Convert "def == old_body" and "old_body == new_body" into "def == new_body". *)
val thm_ml = Thm.transitive fn_thm' thm
val thm_ml = Morphism.thm m thm_ml
(* Get the newly rewritten term. *)
val new_term = Thm.concl_of thm_ml |> Utils.rhs_of
(*val _ = @{trace} (fn_name, #name rule_set, cterm_of (Proof_Context.theory_of ctxt) new_term)*)
(* Determine if the conversion was successful. *)
val success = #valid_term rule_set ctxt new_term
in
(* Determine if we were a success. *)
if success then
SOME (thm_ml, dict)
else
NONE
end
(* Like perform_lift, but also applies the polishing rules, hopefully yielding
* an even nicer definition. *)
fun perform_lift_and_polish ctxt prog_info fn_info rule_set do_opt fn_name =
case (perform_lift ctxt prog_info fn_info rule_set fn_name)
of NONE => NONE
| SOME (thm, dict) => SOME let
(* Measure the size of the new theorem. *)
val _ = Statistics.gather ctxt "TS" fn_name
(Thm.concl_of thm |> Utils.rhs_of)
(* Apply any polishing rules. *)
val polish_thm = Monad_Convert.polish ctxt rule_set do_opt thm
(* Measure the term. *)
val _ = Statistics.gather ctxt "polish" fn_name
(Thm.concl_of polish_thm |> Utils.rhs_of)
in (polish_thm, dict) end
(*
* Attempt to lift a function (or recursive function group) into the given monad.
*
* If successful, we define the new function (vis. group) to the theory.
* We then return a theorem of the form:
*
* "L2_call foo x y z == <lift> $ new_foo x y z"
*
* where "lift" is a lifting function, such as "returnOk" or "gets", etc.
*
* If the lift does not succeed, the function returns NONE.
*
* The callees of this function need to be already translated in ts_infos
* and also defined in lthy.
*)
fun lift_function_rewrite rule_set filename prog_info l2_infos ts_infos
fn_names make_function_name keep_going do_opt lthy =
let
val these_l2_infos = map (the o Symtab.lookup l2_infos) fn_names
(* Determine if this function is recursive. *)
val is_recursive = FunctionInfo2.is_function_recursive (hd these_l2_infos)
(* Fetch relevant callees. *)
val callees =
map #callees these_l2_infos
|> Symset.union_sets
|> Symset.dest
val callee_infos = map (Symtab.lookup ts_infos #> the) callees
(* Add monad_mono theorems. *)
val mono_thms = List.mapPartial #mono_thm callee_infos
(* When lifting, also lift our callees. *)
val rule_set' = Monad_Types.update_mt_lift_rules
(fn thms => Monad_Types.thmset_adds thms (map #corres_thm callee_infos @ mono_thms))
rule_set
(*
* Attempt to lift all functions into this type.
*
* For mutually recursive functions, every function in the group needs to be
* lifted in the same type.
*
* Eliminate the "SOME", raising an exception if any function in the group
* couldn't be lifted to this type.
*)
val lifted_functions =
map (perform_lift_and_polish lthy prog_info l2_infos rule_set' do_opt) fn_names
val lifted_functions = map (fn x =>
case x of
SOME a => a
| NONE => raise LiftingFailed ())
lifted_functions
val thms = map fst lifted_functions
val dicts = map snd lifted_functions
(*
* Generate terms necessary for defining the function, and define the
* functions.
*)
fun gen_fun_def_term (fn_name, dict, thm) =
let
(* If this function is recursive, it has a measure parameter. *)
val measure_param = if is_recursive then [("rec_measure'", @{typ nat})] else []
(* Fetch function parameters. *)
val fn_info = the (Symtab.lookup l2_infos fn_name)
val fn_params = #args fn_info
(* Extract the term from the function theorem. *)
fun tail_of (_ $ x) = x
val term = Thm.concl_of thm
|> Utils.rhs_of
|> tail_of
|> Utils.unsafe_unvarify
|> (fn term =>
foldr (fn (v, t) => Utils.abs_over "" v t)
term (map Free (measure_param @ fn_params)))
val fn_type = fastype_of term
(* Replace place-holder function names with our generated constant name. *)
val term = map_aterms (fn t => case t of
Free (n, T) =>
(case Symtab.lookup dict n of
NONE => t
| SOME a => Free (make_function_name a, T))
| x => x) term
in
(make_function_name fn_name, measure_param @ fn_params, term)
end
val input_defs = map gen_fun_def_term (Utils.zip3 fn_names dicts thms)
val (ts_defs, lthy) = Utils.define_functions input_defs false is_recursive lthy
(*
(* Record the definitions in our theory data. *)
val lthy = fold (fn (fn_name, def) =>
Local_Theory.background_theory (
AutoCorresData.add_def filename ("TS" ^ "def") fn_name def))
(Utils.zip fn_names defs) lthy
*)
(* Instantiate variables in our equivalence theorem to their newly-defined values. *)
fun do_inst_thm thm =
Utils.instantiate_thm_vars lthy (
fn ((name, _), t) =>
try (unprefix "rec'") name
|> Option.map make_function_name
|> Option.map (Utils.get_term lthy)
|> Option.map (Thm.cterm_of lthy)) thm
val inst_thms = map do_inst_thm thms
(* HACK: If an L2 function takes no parameters, its measure gets eta-contracted
away, preventing eqsubst_tac from unifying the rec_measure's schematic
variable. So we get rid of the schematic var pre-emptively. *)
val inst_thms =
map (fn inst_thm => Drule.abs_def inst_thm handle TERM _ => inst_thm) inst_thms
(* Generate a theorem converting "L2_call <func>" into its new form,
* such as L2_call <func> = liftE $ <new_func_def> *)
val final_props = map (fn fn_name =>
make_lift_equality lthy prog_info l2_infos fn_name rule_set'
NONE NONE (Utils.get_term lthy (make_function_name fn_name))) fn_names
(* Convert meta-logic into HOL statements, conjunct them together and setup
* our goal statement. *)
val int_props = map (Object_Logic.atomize_term lthy) final_props
val goal = Utils.mk_conj_list int_props
|> HOLogic.mk_Trueprop
|> Thm.cterm_of lthy
val simps =
#L2_simp_rules rule_set' @
@{thms gets_bind_ign L2_call_fail HOL.simp_thms}
val rewrite_thm =
Goal.init goal
|> (fn goal_state => if is_recursive then
goal_state
|> apply_tac "start induction"
(resolve_tac lthy @{thms recguard_induct} 1)
|> apply_tac "(base case) split subgoals"
(TRY (REPEAT_ALL_NEW (Tactic.match_tac lthy [@{thm conjI}]) 1))
|> apply_tac "(base case) solve base cases"
(EVERY (map (fn (ts_def, l2_info) =>
SOLVES (
(EqSubst.eqsubst_tac lthy [0] [#definition l2_info] 1)
THEN (EqSubst.eqsubst_tac lthy [0] [ts_def] 1)
THEN (simp_tac (put_simpset ts_simpset lthy) 1))) (ts_defs ~~ these_l2_infos)))
|> apply_tac "(rec case) spliting induct case prems"
(TRY (REPEAT_ALL_NEW (Tactic.ematch_tac lthy [@{thm conjE}]) 1))
|> apply_tac "(rec case) split inductive case subgoals"
(TRY (REPEAT_ALL_NEW (Tactic.match_tac lthy [@{thm conjI}]) 1))
|> apply_tac "(rec case) unfolding strengthened function definition"
(EVERY (map (fn (def, inst_thm) =>
((EqSubst.eqsubst_tac lthy [0] [def] 1)
THEN (EqSubst.eqsubst_tac lthy [0] [inst_thm] 1)
THEN (REPEAT (CHANGED (asm_full_simp_tac (put_simpset ts_simpset lthy) 1)))))
(ts_defs ~~ inst_thms)))
else
goal_state
|> apply_tac "unfolding strengthen function definition"
(EqSubst.eqsubst_tac lthy [0] [hd ts_defs] 1)
|> apply_tac "unfolding L2 rewritten theorem"
(EqSubst.eqsubst_tac lthy [0] [hd inst_thms] 1)
|> apply_tac "simplifying remainder"
(TRY (simp_tac (put_simpset HOL_ss (Utils.set_hidden_ctxt lthy) addsimps simps) 1))
)
|> Goal.finish lthy
(* Now, using this combined theorem, generate a theorem for each individual
* function. *)
fun prove_partial_pred thm pred =
Thm.cterm_of lthy pred
|> Goal.init
|> apply_tac "inserting hypothesis"
(cut_tac thm 1)
|> apply_tac "normalising into rule format"
((REPEAT (resolve_tac lthy @{thms allI} 1))
THEN (REPEAT (eresolve_tac lthy @{thms conjE} 1))
THEN (REPEAT (eresolve_tac lthy @{thms allE} 1)))
|> apply_tac "solving goal" (assume_tac lthy 1)
|> Goal.finish lthy
|> Object_Logic.rulify lthy
val new_thms = map (prove_partial_pred rewrite_thm) final_props
(*
* Make the theorems polymorphic in their exception type.
*
* That is, these theories may all be applied regardless of what the type of
* the exception part of the monad is, but are currently specialised to
* when the exception part of the monad is unit. We apply a "polymorphism theorem" to change
* the type of the rule from:
*
* ('s, unit + 'a) nondet_monad
*
* to
*
* ('s, 'e + 'a) nondet_monad
*)
val new_thms = map (fn t => #polymorphic_thm rule_set' OF [t]) new_thms
|> map (fn t => Drule.generalize ([], ["rec_measure'"]) t)
(*
(* Record correctness theorem(s) for what we just did. *)
val lthy = fold (fn (fn_name, thm) =>
Local_Theory.background_theory
(AutoCorresData.add_thm filename "TScorres" fn_name thm))
(fn_names ~~ new_thms) lthy
*)
(* FIXME: resurrect in a better place
(* Output final corres rule. *)
fun output_rule fn_name lthy =
let
val thy = Proof_Context.theory_of lthy
val l1_thm = the (AutoCorresData.get_thm thy filename "L1corres" fn_name)
handle Option => raise SimplConv.FunctionNotFound fn_name
val l2_thm = the (AutoCorresData.get_thm thy filename "L2corres" fn_name)
handle Option => raise SimplConv.FunctionNotFound fn_name
val hl_thm = the (AutoCorresData.get_thm thy filename "HLcorres" fn_name)
handle Option => (@{thm L2Tcorres_trivial_from_local_var_extract} OF [l2_thm]
handle THM _ => l2_thm) (* catch this failure below *)
val wa_thm = the (AutoCorresData.get_thm thy filename "WAcorres" fn_name)
(* If there is no WAcorres thm, it is probably because word_abstract
* was disabled using no_word_abs.
* In that case we just carry the hl_thm through. *)
handle Option => (@{thm corresTA_trivial_from_heap_lift} OF [hl_thm]
handle THM _ => hl_thm) (* catch this failure below *)
val ts_thm = the (AutoCorresData.get_thm thy filename "TScorres" fn_name)
handle Option => raise SimplConv.FunctionNotFound fn_name
val lthy = let
val final_thm = @{thm ac_corres_chain} OF [l1_thm, l2_thm, hl_thm, wa_thm, ts_thm]
val final_thm' = Simplifier.simplify lthy final_thm
val (_, lthy) = Utils.define_lemma (make_function_name fn_name ^ "_ac_corres") final_thm' lthy
in
lthy
end
handle THM _ =>
(Utils.THM_non_critical keep_going
("autocorres: failed to prove ac_corres theorem for " ^ fn_name)
0 [l1_thm, l2_thm, hl_thm, wa_thm, ts_thm];
lthy)
in
lthy
end
val lthy = fold output_rule fn_names lthy
*)
(* Final mono rules for recursive functions. *)
fun try_mono_thm (fn_name, rewrite_thm) =
case #mono_thm (the (Symtab.lookup l2_infos fn_name)) of
NONE => NONE
| SOME l2_mono_thm =>
case try (#prove_mono rule_set' rewrite_thm) l2_mono_thm of
NONE => (Utils.ac_warning ("Failed to prove monad_mono for function: " ^ fn_name);
NONE)
| SOME ts_mono_thm => SOME (fn_name, ts_mono_thm);
val mono_thms = (if is_recursive then List.mapPartial try_mono_thm (fn_names ~~ new_thms) else [])
|> Symtab.make;
val ts_infos =
map (fn ((f, l2_info), (ts_def, ts_corres)) => let
val mono_thm = Symtab.lookup mono_thms f;
val const = Utils.get_term lthy (make_function_name f);
val ts_info = l2_info
|> FunctionInfo2.function_info_upd_phase FunctionInfo2.TS
|> FunctionInfo2.function_info_upd_definition ts_def
|> FunctionInfo2.function_info_upd_corres_thm ts_corres
|> FunctionInfo2.function_info_upd_const const;
in (f, ts_info) end)
((fn_names ~~ these_l2_infos) ~~ (ts_defs ~~ new_thms))
in
(lthy, Symtab.make ts_infos, #name rule_set')
end
(* Return the lifting rule(s) to try for a function set.
This is moved out of lift_function so that it can be used to
provide argument checking in the AutoCorres.abstract wrapper. *)
fun compute_lift_rules rules force_lift fn_names =
let
fun all_list f xs = fold (fn x => (fn b => b andalso f x)) xs true
val forced = fn_names
|> map (fn func => case Symtab.lookup force_lift func of
SOME rule => [(func, rule)]
| NONE => [])
|> List.concat
in
case forced of
[] => rules (* No restrictions *)
| ((func, rule) :: rest) =>
(* Functions in the same set must all use the same lifting rule. *)
if map snd rest |> all_list (fn rule' => #name rule = #name rule')
then [rule] (* Try the specified rule *)
else error ("autocorres: this set of mutually recursive functions " ^
"cannot be lifted to different monads: " ^
commas_quote (map fst forced))
end
(* Lift the given function set, trying each rule until one succeeds. *)
fun lift_function rules force_lift filename prog_info l2_infos ts_infos
fn_names make_function_name keep_going do_opt lthy =
let
val rules' = compute_lift_rules rules force_lift fn_names
(* Find the first lift that works. *)
fun first (rule::xs) =
(lift_function_rewrite rule filename prog_info l2_infos ts_infos
fn_names make_function_name keep_going do_opt lthy
handle LiftingFailed _ => first xs)
| first [] = raise AllLiftingFailed (map (fn f =>
(f, #definition (the (Symtab.lookup l2_infos f)))) fn_names)
in
first rules'
end
(* Show how many functions were lifted to each monad. *)
fun print_statistics results =
let
fun count_dups x [] = [x]
| count_dups (head, count) (next::rest) =
if head = next then
count_dups (head, count + 1) rest
else
(head, count) :: (count_dups (next, 1) rest)
val tabulated = count_dups ("__fake__", 0) (sort_strings results) |> tl
val data = map (fn (a,b) =>
(" " ^ a ^ ": " ^ (PolyML.makestring b) ^ "\n")
) tabulated
|> String.concat
in
writeln ("Type Strengthening Statistics: \n" ^ data)
end
(* Run through every function, attempting to strengthen its type. *)
fun translate
(rules : Monad_Types.monad_type list)
(force_lift : Monad_Types.monad_type Symtab.table)
(filename : string)
(prog_info : ProgramInfo.prog_info)
(* lazy results from L2 (we currently pre-force them, though) *)
(l2_results: (symset * (local_theory * FunctionInfo2.function_info Symtab.table) future) list)
(make_function_name : string -> string)
(keep_going : bool)
(do_opt : bool) =
if null l2_results then [] else
let
(* Wait for previous stage to finish first. *)
val lthy = fst (Future.join (snd (hd (rev l2_results))));
val l2_results = map (snd #> Future.join #> snd) l2_results;
val l2_infos = LocalVarExtract2.symtab_merge false l2_results;
(* Prettify bound variable names in definitions. *)
val l2_infos = Symtab.map (fn f_name => fn info => let
val def = #definition (the (Symtab.lookup l2_infos f_name));
val pretty_def = PrettyBoundVarNames.pretty_bound_vars_thm
lthy (Utils.crhs_of (Thm.cprop_of def)) keep_going
|> Thm.transitive def;
in FunctionInfo2.function_info_upd_definition pretty_def info end)
l2_infos;
(* For now, just works sequentially like the old TypeStrengthen.
TODO: parallelise conversion and (especially) polishing *)
fun translate_group fn_names (lthy, _, ts_infos) =
let
val _ = writeln ("Translating (type strengthen) " ^ Utils.commas fn_names);
val start_time = Timer.startRealTimer ();
val (lthy, new_ts_infos, monad_name) =
lift_function rules force_lift filename prog_info l2_infos ts_infos
fn_names make_function_name keep_going do_opt lthy;
val _ = writeln (" --> " ^ monad_name);
val _ = @{trace} ("Converted (TS) " ^ Utils.commas fn_names ^ " in " ^
Time.toString (Timer.checkRealTimer start_time) ^ " s");
in (lthy, new_ts_infos, Symtab.merge (K false) (ts_infos, new_ts_infos)) end;
val ts_results =
Utils.accumulate translate_group (lthy, Symtab.empty, Symtab.empty)
(map (map fst o Symtab.dest) l2_results)
|> fst
|> map (fn (lthy, ts_infos, _) => (lthy, ts_infos));
in
ts_results
end
end

View File

@ -397,8 +397,7 @@ fun dest_conj_list (Const (@{const_name "HOL.conj"}, _) $ l $ r)
*)
fun apply_tac (step : string) tac (thm : thm) =
(tac thm |> Seq.hd) handle Option =>
error ("Failed to apply tactic during '" ^ step ^ "': " ^ (
(PolyML.makestring (cprems_of thm))))
raise TERM ("Failed to apply tactic during " ^ quote step, Thm.prems_of thm)
(*
* A "the" operator that explains what is going wrong.

View File

@ -0,0 +1,551 @@
(*
* Copyright 2014, NICTA
*
* This software may be distributed and modified according to the terms of
* the BSD 2-Clause license. Note that NO WARRANTY is provided.
* See "LICENSE_BSD2.txt" for details.
*
* @TAG(NICTA_BSD)
*)
(*
* Rewrite L2 specifications to use "nat" and "int" data-types instead of
* "word" data types. The former tend to be easier to reason about.
*)
structure WordAbstract2 =
struct
(* Maximum depth that we will go before assuming that we are diverging. *)
val WORD_ABS_MAX_DEPTH = 200
(* Convenience shortcuts. *)
val warning = Utils.ac_warning
val apply_tac = Utils.apply_tac
val the' = Utils.the'
type WARules = {
ctype : typ, atype : typ,
abs_fn : term, inv_fn : term,
rules : thm list
}
fun mk_word_abs_rule T =
{
ctype = fastype_of (@{mk_term "x :: (?'W::len) word" ('W)} T),
atype = @{typ nat},
abs_fn = @{mk_term "unat :: (?'W::len) word => nat" ('W)} T,
inv_fn = @{mk_term "of_nat :: nat => (?'W::len) word" ('W)} T,
rules = @{thms word_abs_word32}
}
val word_abs : WARules list =
map mk_word_abs_rule [@{typ 32}, @{typ 16}, @{typ 8}]
fun mk_sword_abs_rule T =
{
ctype = fastype_of (@{mk_term "x :: (?'W::len) signed word" ('W)} T),
atype = @{typ int},
abs_fn = @{mk_term "sint :: (?'W::len) signed word => int" ('W)} T,
inv_fn = @{mk_term "of_int :: int => (?'W::len) signed word" ('W)} T,
rules = @{thms word_abs_sword32}
}
val sword_abs : WARules list =
map mk_sword_abs_rule [@{typ 32}, @{typ 16}, @{typ 8}]
(* Get abstract version of a HOL type. *)
fun get_abs_type (rules : WARules list) T =
Option.getOpt
(List.find (fn r => #ctype r = T) rules
|> Option.map (fn r => #atype r),
T)
(* Get abstraction function for a HOL type. *)
fun get_abs_fn (rules : WARules list) T =
Option.getOpt
(List.find (fn r => #ctype r = T) rules
|> Option.map (fn r => #abs_fn r),
@{mk_term "id :: ?'a => ?'a" ('a)} T)
fun get_abs_inv_fn (rules : WARules list) t =
Option.getOpt
(List.find (fn r => #ctype r = fastype_of t) rules
|> Option.map (fn r => #inv_fn r $ t),
t)
(*
* From a list of abstract arguments to a function, derive a list of
* concrete arguments and types and a precondition that links the two.
*)
fun get_wa_conc_args rules l2_infos fn_name fn_args =
let
(* Construct arguments for the concrete body. We use the abstract names
* with a prime ('), but with the concrete types. *)
val conc_types = the (Symtab.lookup l2_infos fn_name) |> #args |> map snd
val conc_args = map (fn (Free (x, Tc), Ta) => Free (x ^ "'", Ta)) (* FIXME: Free *)
(fn_args ~~ conc_types)
val arg_pairs = (conc_args ~~ fn_args)
(* Create preconditions that link the new types to the old types. *)
val precond =
map (fn (old, new) => @{mk_term "abs_var ?n ?f ?o" (o, f, n)}
(old, get_abs_fn rules (fastype_of old), new))
arg_pairs
|> Utils.mk_conj_list
in
(conc_types, conc_args, precond, arg_pairs)
end
(* Get the expected type of a function from its name. *)
fun get_expected_fn_type rules l2_infos fn_name =
let
val fn_info = the (Symtab.lookup l2_infos fn_name)
val fn_params_typ = map ((get_abs_type rules) o snd) (#args fn_info)
val fn_ret_typ = get_abs_type rules (#return_type fn_info)
val globals_typ = LocalVarExtract.dest_l2monad_T (fastype_of (#const fn_info)) |> snd |> #1
val measure_typ = @{typ "nat"}
in
(measure_typ :: fn_params_typ)
---> LocalVarExtract.mk_l2monadT globals_typ fn_ret_typ @{typ unit}
end
(* Get the expected theorem that will be generated about a function. *)
fun get_expected_fn_thm rules l2_infos ctxt fn_name
function_free fn_args _ measure_var =
let
val old_def = the (Symtab.lookup l2_infos fn_name)
val (old_arg_types, old_args, precond, arg_pairs)
= get_wa_conc_args rules l2_infos fn_name fn_args
val old_term = betapplys (#const old_def, measure_var :: old_args)
val new_term = betapplys (function_free, measure_var :: fn_args)
in
@{mk_term "Trueprop (corresTA (%x. ?P) ?rt id ?A ?C)" (rt, A, C, P)}
(get_abs_fn rules (#return_type old_def), new_term, old_term, precond)
|> fold (fn t => fn v => Logic.all t v) (rev (map fst arg_pairs))
end
(* Get arguments passed into the function. *)
fun get_expected_fn_args rules l2_infos fn_name =
map (apsnd (get_abs_type rules)) (#args (the (Symtab.lookup l2_infos fn_name)))
(*
* Convert a theorem of the form:
*
* corresTA (%_. abs_var True a f a' \<and> abs_var True b f b' \<and> ...) ...
*
* into
*
* [| abstract_val A a f a'; abstract_val B b (f b') |] ==> corresTA (%_. A \<and> B \<and> ...) ...
*
* the latter of which better suits our resolution approach of proof
* construction.
*)
fun extract_preconds_of_call thm =
let
fun r thm =
r (thm RS @{thm corresTA_extract_preconds_of_call_step})
handle THM _ => (thm RS @{thm corresTA_extract_preconds_of_call_final}
handle THM _ => thm RS @{thm corresTA_extract_preconds_of_call_final'});
in
r (thm RS @{thm corresTA_extract_preconds_of_call_init})
end
(* Convert a program by abstracting words. *)
fun translate
(filename: string)
(prog_info: ProgramInfo.prog_info)
(* lazy results *)
(l2_results: (symset * (local_theory * FunctionInfo2.function_info Symtab.table) future) list)
(unsigned_abs: Symset.key Symset.set)
(no_signed_abs: Symset.key Symset.set)
(trace_funcs: string list)
(do_opt: bool)
(trace_opt: bool)
(wa_function_name: string -> string) =
if null l2_results then [] else
let
(*
* Select the rules to translate each function.
*)
fun rules_for fn_name =
(if Symset.contains unsigned_abs fn_name then word_abs else []) @
(if Symset.contains no_signed_abs fn_name then [] else sword_abs)
(* Results for individual functions *)
val l2_results' = maps (fn (f_names, r) =>
Symset.dest f_names ~~ replicate (Symset.card f_names) r) l2_results;
val get_l2_result = let
val table = Symtab.make l2_results';
in fn f => the' ("WA: missing lazy result for function: " ^ f) (Symtab.lookup table f) end;
(* Abstract functions. *)
fun convert f =
let
(* Info for f and its callees *)
val (lthy, f_l2_info) = Future.join (get_l2_result f);
val callee_l1_infos = Symset.dest (FunctionInfo2.all_callees (the (Symtab.lookup f_l2_info f)))
|> map (fn callee => snd (Future.join (get_l2_result callee)));
val l2_infos = LocalVarExtract2.symtab_merge true (f_l2_info :: callee_l1_infos);
val old_fn_info = the (Symtab.lookup l2_infos f);
val wa_rules = rules_for f
(* Fix measure variable. *)
val ([measure_var_name], lthy) = Variable.variant_fixes ["rec_measure'"] lthy;
val measure_var = Free (measure_var_name, LocalVarExtract2.measureT);
(* Fix argument variables. *)
val new_fn_args = get_expected_fn_args wa_rules l2_infos f;
val (arg_names, lthy) = Variable.variant_fixes (map fst new_fn_args) lthy;
val arg_frees = map Free (arg_names ~~ map snd new_fn_args);
val (lthy, export_thm, callee_terms) =
AutoCorresUtil2.assume_called_functions_corres lthy
(#callees old_fn_info) (#rec_callees old_fn_info)
(fn f => get_expected_fn_type (rules_for f) l2_infos f)
(fn lthy => fn f => get_expected_fn_thm (rules_for f) l2_infos lthy f)
(fn f => get_expected_fn_args (rules_for f) l2_infos f)
wa_function_name
measure_var;
(* Construct free variables to represent our concrete arguments. *)
val (conc_types, conc_args, precond, arg_pairs)
= get_wa_conc_args wa_rules l2_infos f arg_frees
(* Fetch the function definition, and instantiate its arguments. *)
val old_body_def =
#definition old_fn_info
(* Instantiate the arguments. *)
|> Utils.inst_args lthy (map (Thm.cterm_of lthy) (measure_var :: conc_args))
(* Get old body definition with function arguments. *)
val old_term = betapplys (#const old_fn_info, measure_var :: conc_args)
(* Get a schematic variable accepting new arguments. *)
val new_var = betapplys (
Var (("A", 0), get_expected_fn_type wa_rules l2_infos f), measure_var :: arg_frees)
(* Fetch monotonicity theorems of callees. *)
val callee_mono_thms =
Symset.dest (FunctionInfo2.all_callees old_fn_info)
|> List.mapPartial (fn callee =>
if FunctionInfo2.is_function_recursive (the (Symtab.lookup l2_infos callee))
then the (Symtab.lookup l2_infos callee) |> #mono_thm
else NONE)
(*
* Generate a schematic goal.
*
* We only want ?A to depend on abstracted variables and ?C to depend on
* concrete variables. We force this by applying bound variables to each
* of the schematics, giving us something like:
*
* !!a a' b b'. corresTA ... (?A a b) (?C a' b')
*
* The abstract side will hence be prevented from capturing (i.e., using)
* concrete variables, and vice-versa.
*)
val goal = @{mk_term "Trueprop (corresTA (%x. ?precond) ?ra id ?A ?C)" (ra, A, C, precond)}
(get_abs_fn wa_rules (#return_type old_fn_info), new_var, old_term, precond)
|> fold (fn t => fn v => Logic.all t v) (rev (arg_frees @ map fst arg_pairs))
|> Thm.cterm_of lthy
|> Goal.init
|> Utils.apply_tac "move precond to assumption" (resolve_tac lthy @{thms corresTA_precond_to_asm} 1)
|> Utils.apply_tac "split precond" (REPEAT (CHANGED (eresolve_tac lthy @{thms conjE} 1)))
|> Utils.apply_tac "create schematic precond" (resolve_tac lthy @{thms corresTA_precond_to_guard} 1)
|> Utils.apply_tac "unfold RHS" (CHANGED (Utils.unfold_once_tac lthy (Utils.abs_def lthy old_body_def) 1))
(*
* Fetch rules from the theory.
*)
val rules = Utils.get_rules lthy @{named_theorems word_abs}
@ List.concat (map #rules wa_rules)
@ @{thms word_abs_default}
val fo_rules = [@{thm abstract_val_fun_app}]
val rules = rules @ (map (snd #> #3 #> extract_preconds_of_call) callee_terms)
@ callee_mono_thms
(* Standard tactics. *)
fun my_rtac ctxt thm n =
Utils.trace_if_success ctxt thm (
DETERM (EVERY' (resolve_tac ctxt [thm] :: replicate (Rule_Cases.get_consumes thm) (assume_tac ctxt)) n))
(* Apply a conversion to the abstract/concrete side of the given "abstract_val" term. *)
fun wa_conc_body_conv conv =
Conv.params_conv (~1) (K (Conv.concl_conv (~1) ((Conv.arg_conv (Utils.nth_arg_conv 4 conv)))))
(* Tactics and conversions for converting goals into first-order format. *)
fun to_fo_tac ctxt =
CONVERSION (Drule.beta_eta_conversion then_conv wa_conc_body_conv (HeapLift.mk_first_order ctxt) ctxt)
fun from_fo_tac ctxt =
CONVERSION (wa_conc_body_conv (HeapLift.dest_first_order ctxt then_conv Drule.beta_eta_conversion) ctxt)
fun make_fo_tac tac ctxt = ((to_fo_tac ctxt THEN' tac) THEN_ALL_NEW from_fo_tac ctxt)
(*
* Recursively solve subgoals.
*
* We allow backtracking in order to solve a particular subgoal, but once a
* subgoal is completed we don't ever try to solve it in a different way.
*
* This allows us to try different approaches to solving subgoals,
* hopefully reducing exponential explosion (of many different combinations
* of "good solutions") once we hit an unsolvable subgoal.
*)
fun SOLVE_ALL _ _ 0 thm =
raise THM ("Word abstraction diverging", 0, [thm])
| SOLVE_ALL ctxt tacs depth thm =
let
fun TRY_ALL [] = no_tac
| TRY_ALL (x::xs) =
(x ctxt THEN REPEAT (SELECT_GOAL (SOLVE_ALL ctxt tacs (depth - 1)) 1))
APPEND (TRY_ALL xs)
in
if Thm.nprems_of thm > 0 then
DETERM (SOLVES (TRY_ALL tacs)) thm
else
all_tac thm
end
(*
* Eliminate a lambda term in the concrete state, but only if the
* lambda is "real".
*
* That is, we don't attempt to eta-expand in order to apply the theorem
* "abstract_val_lambda", because that may lead to an infinite loop with
* "abstract_val_fun_app".
*)
fun lambda_tac n thm =
case Logic.concl_of_goal (Thm.prop_of thm) n of
(Const (@{const_name "Trueprop"}, _) $
(Const (@{const_name "abstract_val"}, _) $ _ $ _ $ _ $ (
Abs (_, _, _)))) =>
resolve_tac lthy @{thms abstract_val_lambda} n thm
| _ => no_tac thm
(* All tactics we try, in the order we should try them. *)
val step_tacs =
[(@{thm imp_refl}, assume_tac lthy 1)]
@ (map (fn thm => (thm, my_rtac lthy thm 1)) rules)
@ (map (fn thm => (thm, make_fo_tac (my_rtac lthy thm) lthy 1)) fo_rules)
@ [(@{thm abstract_val_lambda}, lambda_tac 1)]
@ [(@{thm reflexive},
fn thm =>
(if Config.get lthy ML_Options.exception_trace then
warning ("Could not solve subgoal: " ^
(Goal_Display.string_of_goal lthy thm))
else (); no_tac thm))]
(* Solve the goal. *)
val replay_failure_start = 1
val replay_failures = Unsynchronized.ref replay_failure_start
val (thm, trace) =
case AutoCorresTrace.maybe_trace_solve_tac lthy (member (op =) trace_funcs f) true false
(K step_tacs) goal (SOME WORD_ABS_MAX_DEPTH) replay_failures of
NONE => (* intentionally generate a TRACE_SOLVE_TAC_FAIL *)
(AutoCorresTrace.trace_solve_tac lthy false false (K step_tacs) goal NONE (Unsynchronized.ref 0);
(* never reached *) error "word_abstract fail tac: impossible")
| SOME (thm, [trace]) => (Goal.finish lthy thm, trace)
val _ = if !replay_failures < replay_failure_start then
@{trace} (f ^ " WA: reverted to slow replay " ^
Int.toString(replay_failure_start - !replay_failures) ^ " time(s)") else ()
(* Clean out any final function application ($) constants or "id" constants
* generated by some rules. *)
fun corresTA_abs_conv conv =
Utils.remove_meta_conv (fn ctxt => Utils.nth_arg_conv 4 (conv ctxt)) lthy
val thm =
Conv.fconv_rule (
corresTA_abs_conv (fn ctxt =>
(HeapLift.dest_first_order ctxt)
then_conv (Simplifier.rewrite (
put_simpset HOL_basic_ss ctxt addsimps [@{thm id_def}]))
then_conv Drule.beta_eta_conversion
)
) thm
(* Ensure no schematics remain in the goal. *)
val _ = Sign.no_vars lthy (Thm.prop_of thm)
(* Gather statistics. *)
val _ = Statistics.gather lthy "WA" f
(Variable.gen_all lthy thm
|> Thm.prop_of
|> HOLogic.dest_Trueprop
|> (fn t => Utils.term_nth_arg t 3))
(*
* Instantiate abstract function's meta-forall variables with their actual values.
*
* That is, we go from:
*
* !!a b c a' b' c'. corresTA (P a b c) ...
*
* to
*
* !!a' b' c'. corresTA (P a b c) ...
*)
val thm = Drule.forall_elim_list (map (Thm.cterm_of lthy) arg_frees) thm
(* Apply peephole optimisations to the theorem. *)
val _ = writeln ("Simpifying (WA) " ^ f)
val (thm, opt_traces) = L2Opt.cleanup_thm_tagged lthy thm (if do_opt then 0 else 2) 4 trace_opt "WA"
(* We end up with an unwanted L2_guard outside the L2_recguard.
* L2Opt should simplify the condition to (%_. True) even if (not do_opt),
* so we match the guard and get rid of it here. *)
val thm = Simplifier.rewrite_rule lthy @{thms corresTA_simp_trivial_guard} thm
(* Gather post-optimisation statistics. *)
val _ = Statistics.gather lthy "WAsimp" f
(Variable.gen_all lthy thm
|> Thm.prop_of
|> HOLogic.dest_Trueprop
|> (fn t => Utils.term_nth_arg t 3))
(* Extract the abstract term out of a corresTA thm. *)
fun dest_corresWA_term_abs @{term_pat "corresTA _ _ _ ?t _"} = t
fun get_body_of_thm thm =
Thm.concl_of (Variable.gen_all lthy thm)
|> HOLogic.dest_Trueprop
|> dest_corresWA_term_abs
val _ = @{trace} ("WA conv", f, thm)
val thm' = Morphism.thm export_thm (Variable.gen_all lthy thm)
val _ = @{trace} ("WA conv final", f, thm')
in
(get_body_of_thm thm, thm', dest_Free measure_var :: new_fn_args,
(if member (op =) trace_funcs f then [("WA", AutoCorresData.RuleTrace trace)] else []) @ opt_traces)
end
(* Define a previously-converted function (or recursive function group).
* lthy must include all definitions from l2_callees. *)
fun define
(lthy: local_theory)
(l2_infos: FunctionInfo2.function_info Symtab.table)
(wa_callees: FunctionInfo2.function_info Symtab.table)
(funcs: (string * thm * (string * typ) list) list) (* name, corres, arg frees *)
: FunctionInfo2.function_info Symtab.table * local_theory = let
(* FIXME: dedup with convert, LocalVarExtract *)
(* FIXME: pass this from assume_called_functions_corres, etc. *)
fun guess_callee_var thm callee = let
val base_name = wa_function_name callee;
val mentioned_vars = Term.add_vars (Thm.prop_of thm) [];
in hd (filter (fn ((v, _), _) => v = base_name) mentioned_vars) end;
fun prepare_fn_body (fn_name, corres_thm, arg_frees) = let
val _ = @{trace} ("WA prepare_fn_body", fn_name, corres_thm);
val @{term_pat "Trueprop (corresTA _ _ _ ?body _)"} = Thm.concl_of corres_thm;
val (callees, recursive_callees) = AutoCorresUtil2.get_callees l2_infos fn_name;
val calls = map (fn c => Var (guess_callee_var corres_thm c)) callees;
val recursive_calls = map (fn c => Var (guess_callee_var corres_thm c)) recursive_callees;
(*
* The returned body will have free variables as placeholders for the function's
* measure parameter and other arguments, and schematic variables for the functions it calls.
*
* We modify the body to be of the form:
*
* %fun1 fun2 rec1 rec2 measure arg1 arg2. f <...>
*
* That is, all non-recursive calls are abstracted out the front, followed by
* recursive calls, followed by the measure variable, followed by function
* arguments. This is the format expected by define_funcs.
*)
val abs_body = body
|> fold lambda (rev (map Free arg_frees))
|> fold lambda (rev recursive_calls)
|> fold lambda (rev calls);
in abs_body end;
val funcs' = funcs |>
map (fn (name, thm, frees) =>
(name, (* FIXME: define_funcs needs this currently *)
(Thm.generalize ([], map fst frees) (Thm.maxidx_of thm + 1) thm,
prepare_fn_body (name, thm, frees))));
val _ = @{trace} ("WA.define", map (fn (name, (thm, body)) => (name, thm, Thm.cterm_of lthy body)) funcs');
val (new_thms, (), lthy') =
AutoCorresUtil2.define_funcs FunctionInfo2.WA filename l2_infos
wa_function_name
(fn f => get_expected_fn_type (rules_for f) l2_infos f)
(fn lthy => fn f => get_expected_fn_thm (rules_for f) l2_infos lthy f)
(fn f => get_expected_fn_args (rules_for f) l2_infos f)
@{thm corresTA_recguard_0}
lthy (Symtab.map (K #corres_thm) wa_callees) ()
funcs';
val new_infos = Symtab.map (fn f_name => fn (const, def, corres_thm) => let
val old_info = the (Symtab.lookup l2_infos f_name);
in old_info
|> FunctionInfo2.function_info_upd_phase FunctionInfo2.WA
|> FunctionInfo2.function_info_upd_const const
|> FunctionInfo2.function_info_upd_definition def
|> FunctionInfo2.function_info_upd_corres_thm corres_thm
|> FunctionInfo2.function_info_upd_return_type
(get_abs_type (rules_for f_name) (#return_type old_info))
|> FunctionInfo2.function_info_upd_args
(map (apsnd (get_abs_type (rules_for f_name))) (#args old_info))
|> FunctionInfo2.function_info_upd_mono_thm NONE (* added later *)
end) new_thms;
in (new_infos, lthy') end;
val function_groups = map fst l2_results;
(* All conversions can run in parallel.
* Each conversion depends only on the previous define phase
* (which necessarily also includes definitions of callees). *)
val converted_funcs =
maps Symset.dest function_groups
|> map (fn f => (f, Future.fork (fn _ => convert f)))
|> Symtab.make;
val _ = Symtab.dest converted_funcs |> map snd |> map Future.join
(* Definitions update lthy sequentially.
* We use the arbitrary (but deterministic) ordering defined by get_topo_sorted_functions.
* Each definition step produces a lthy that has a prefix of the update sequence,
* and can be used subsequently to convert a function that depends only on that prefix.
* Hence we produce the intermediate lthys lazily for maximum parallelism. *)
fun add_def f_names accum = Future.fork (fn _ => let
(* Wait for previous definition to finish *)
val (lthy, _, defined_so_far) = Future.join accum;
(* Wait for conversions to finish *)
val f_convs = map (fn f => let
val conv = the' ("didn't convert function: " ^ quote f ^ "??") (Symtab.lookup converted_funcs f);
val (wa_body, corres_thm, arg_frees, traces) = Future.join conv;
in (f, corres_thm, arg_frees) end) f_names;
(* Get L2 phase results (should be a no-op at this point) *)
val (_, l2_infos) = Future.join (get_l2_result (hd f_names));
val (new_defs, lthy') = define lthy l2_infos defined_so_far f_convs;
in (lthy', new_defs, Symtab.merge (K false) (defined_so_far, new_defs)) end);
(* Get initial lthy from end of L2 defs *)
val l2_def_context = Future.map (fn (lthy, _) => (lthy, Symtab.empty, Symtab.empty))
(snd (hd (rev l2_results)));
(* Chain of intermediate states: (lthy, new_defs, accumulator) *)
val (def_results, _) = Utils.accumulate add_def l2_def_context
(map Symset.dest function_groups);
(* Produce a mapping from each function group to its L1 phase_infos and the
* earliest intermediate lthy where it is defined. *)
val results =
def_results
|> map (Future.map (fn (lthy, f_defs, _ (* discard accum here *)) => let
(* Add monad_mono proofs. These are done in parallel as well
* (though in practice, they already run pretty quickly). *)
val mono_thms = if FunctionInfo2.is_function_recursive (snd (hd (Symtab.dest f_defs)))
then LocalVarExtract2.l2_monad_mono lthy f_defs
else Symtab.empty;
val f_defs' = f_defs |> Symtab.map (fn f =>
FunctionInfo2.function_info_upd_mono_thm (Symtab.lookup mono_thms f));
in (lthy, f_defs') end));
in function_groups ~~ results end
end