(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

(*
 * Type strengthening: lift monadic structures into lighter-weight monads.
 *)

structure TypeStrengthen =
struct

(*
 * Raised by lift_function when none of the candidate rule sets could be
 * used for a function group. Carries, for every function in the group,
 * its name and its L2 definition theorem.
 *)
exception AllLiftingFailed of (string * thm) list

(*
 * Raised when a function group cannot be lifted into one particular monad.
 * lift_function catches it and moves on to the next candidate rule set.
 *)
exception LiftingFailed of unit

(* FIXME: use AUTOCORRES_SIMPSET (need to fix unknown deps of the corres prover) *)
val ts_simpset = simpset_of @{context}

(* Misc util functions. *)
val the' = Utils.the'
val apply_tac = Utils.apply_tac

(*
 * Return the state type of the L2 monad of "fn_name", read off the type of
 * its L2 constant via LocalVarExtract.dest_l2monad_T.
 * ("ctxt" and "prog_info" are currently unused.)
 *)
fun get_l2_state_typ ctxt prog_info l2_infos fn_name =
let
  val term = #const (the (Symtab.lookup l2_infos fn_name));
in
  LocalVarExtract.dest_l2monad_T (fastype_of term) |> snd |> #1
end;

(*
 * Given an L2 monad type, compute the corresponding type in the target
 * monad of "rule_set" (via its "typ_from_L2" field).
 *)
fun get_typ_from_L2 (rule_set : Monad_Types.monad_type) L2_typ =
  LocalVarExtract.dest_l2monad_T L2_typ |> snd |> #typ_from_L2 rule_set

(*
 * Make an equality prop of the form "L2_call <foo> = <liftE> $ <bar>",
 * universally quantified (meta-forall) over the function's arguments.
 *
 * L2_call and <liftE> will typically be desired to be polymorphic in their
 * exception type. We fix it to "unit"; the caller will need to introduce
 * polymorphism as necessary.
 *
 * If "measure" is non-NONE, then that term will be used instead of the
 * default "rec_measure'" variable.
 *
 * For recursive functions only, "rhs_term" is also applied to the measure;
 * if additionally "state_typ" is SOME s, "rhs_term" is applied to a state
 * variable of type s as well, which is then abstracted again.
 *)
fun make_lift_equality ctxt prog_info l2_infos fn_name
    (rule_set : Monad_Types.monad_type) state_typ measure rhs_term =
let
  val thy = Proof_Context.theory_of ctxt

  (* Fetch function variables. *)
  val fn_def = the (Symtab.lookup l2_infos fn_name)
  val inputs = #args fn_def
  val input_vars = map (fn (x, y) => Var ((x, 0), y)) inputs

  (*
   * Measure var.
   *
   * The L2 left-hand-side will always use this measure, while the
   * right-hand-side will only include a measure if the function is actually
   * recursive.
   *)
  val is_recursive = FunctionInfo.is_function_recursive fn_def
  val default_measure_var = @{term "rec_measure' :: nat"} (* FIXME: Free *)
  val measure_term = Option.getOpt (measure, default_measure_var)

  (*
   * Construct the equality.
   *
   * This is a little delicate: in particular, we need to ensure that
   * the type of the resulting term is strictly correct. In particular,
   * our "lift_fn" will have type variables that need to be modified.
   * "Utils.mk_term" will fill in type variables of the base term based
   * on what is applied to it. So, we need to ensure that the lift
   * function is in our base term, and that its type variables have the
   * same names ("'s" for state, "'a" for return type) that we use
   * below.
   *)
  (* FIXME: use @{mk_term} *)
  val base_term = @{term "%L a b. (L2_call :: ('s, 'a, unit) L2_monad => ('s, 'a, unit) L2_monad) a = L b"}
  val a = betapplys (#const fn_def, measure_term :: input_vars)
  val b =
    if not is_recursive then betapplys (rhs_term, input_vars)
    else
      (case state_typ of
          NONE => betapplys (rhs_term, measure_term :: input_vars)
        | SOME s =>
            betapplys (rhs_term, [measure_term] @ input_vars @ [Free ("s'", s)]) (* FIXME: Free *)
            |> lambda (Free ("s'", s)))
  val term = Utils.mk_term thy (betapply (base_term, #L2_call_lift rule_set)) [a, b]
    |> HOLogic.mk_Trueprop
in
  (* Convert it into a trueprop with meta-foralls. *)
  Utils.vars_to_metaforall term
end

(*
 * Assume recursively called functions correctly map into the given type.
 *
 * We return:
 *
 *   (<newly generated context>,
 *    <the measure variable used>,
 *    <generated assumptions>,
 *    <table mapping free term names back to their function names>,
 *    <morphism to escape the context>)
 *
 * Raises LiftingFailed if the assumptions cannot be stated for this monad.
 *
 * FIXME: refactor with AutoCorresUtil.assume_called_functions_corres
 *)
fun assume_rec_lifted ctxt prog_info l2_infos rule_set fn_name =
let
  val fn_def = the (Symtab.lookup l2_infos fn_name)

  (* Find recursive calls. *)
  val recursive_calls = Symset.dest (#rec_callees fn_def)

  (* Fix a variable for each such call, plus another for our measure variable. *)
  val (measure_fix :: dest_fn_fixes, ctxt')
      = Variable.add_fixes ("rec_measure'" :: map (fn x => "rec'" ^ x) recursive_calls) ctxt
  val rec_fun_names = Symtab.make (dest_fn_fixes ~~ recursive_calls)

  (* For recursive calls, we need a term representing our measure variable and
   * another representing our decremented measure variable. *)
  val measure_var = Free (measure_fix, @{typ nat})
  val dec_measure_var = @{const "recguard_dec"} $ measure_var

  (* State the assumption that one recursive callee lifts into the
   * type/monad of "rule_set". *)
  fun assume_callee (callee, var) =
  let
    val callee_def = the (Symtab.lookup l2_infos callee)
    val callee_typ = (@{typ nat} :: map snd (#args callee_def))
        ---> (fastype_of (#const callee_def) |> get_typ_from_L2 rule_set)

    (* NB: pure functions would not use state, but recursive functions cannot
     * be lifted to pure (because we trigger failure when the measure hits
     * 0). So we can always assume there is state. *)
    val state_typ = get_l2_state_typ ctxt prog_info l2_infos fn_name
    val prop = make_lift_equality ctxt prog_info l2_infos callee rule_set
        (SOME state_typ) (SOME dec_measure_var) (Free (var, callee_typ))
  in
    Thm.cterm_of ctxt' prop
  end

  val dest_fn_thms = map assume_callee (recursive_calls ~~ dest_fn_fixes)
    (* Our measure does not type-check for pure functions,
       causing a TERM exception from mk_term. *)
    handle TERM _ => raise LiftingFailed ()

  (* Assume the theorems we just generated. *)
  val (assumed_thms, ctxt'') = Assumption.add_assumes dest_fn_thms ctxt'
  val thms = map (fn t => (#polymorphic_thm rule_set) OF [t]) assumed_thms
in
  (ctxt'',
   measure_var,
   thms,
   rec_fun_names,
   Assumption.export_morphism ctxt'' ctxt'
     $> Variable.export_morphism ctxt' ctxt)
end

(*
 * Given a function definition, attempt to lift it into a different
 * monadic structure by applying a set of rewrite rules.
 *
 * For example, given:
 *
 *   foo x y = doE
 *     a <- returnOk 3;
 *     b <- returnOk 5;
 *     returnOk (a + b)
 *   odE
 *
 * we may be able to lift to:
 *
 *   foo x y = returnOk (let
 *     a = 3;
 *     b = 5;
 *   in
 *     a + b)
 *
 * This second function has the form "lift $ term" for some lifting function
 * "lift" and some new term "term". (These would be "returnOk" and "let a = ...
 * in a + b" in the example above, respectively.)
 *
 * On success we return SOME (thm, dict), where "thm" has the form
 * "foo x y == <lift> $ <term>" (with the recursive-call assumptions exported,
 * so callees appear as schematic variables), and "dict" maps those callee
 * placeholder names back to function names. If the rewritten body is not
 * valid for the target monad, we return NONE.
 *)
fun perform_lift ctxt prog_info l2_infos rule_set fn_name =
let
  (* Assume recursive calls can be successfully lifted into this type. *)
  val (ctxt', measure_var, thms, dict, m)
      = assume_rec_lifted ctxt prog_info l2_infos rule_set fn_name

  val fn_def = #definition (the (Symtab.lookup l2_infos fn_name))

  (* Extract the term from our function definition. *)
  val fn_thm' = Utils.named_cterm_instantiate ctxt
                  [("rec_measure'" (* FIXME *), Thm.cterm_of ctxt' measure_var)] fn_def
                handle THM _ => fn_def
  val ct = Thm.prop_of fn_thm' |> Utils.rhs_of |> Thm.cterm_of ctxt'

  (* Rewrite the term using the given rewrite rules; fall back to
   * reflexivity if no rewrite applies. *)
  val body_thm =
    case Monad_Convert.monad_rewrite ctxt rule_set thms true (Thm.term_of ct) of
        SOME rewritten => rewritten
      | NONE => Utils.named_cterm_instantiate ctxt [("x", ct)] @{thm reflexive}

  (* Convert "def == old_body" and "old_body == new_body" into "def == new_body". *)
  val thm_ml = Thm.transitive fn_thm' body_thm

  (* Get the newly rewritten term. *)
  val new_term = Thm.concl_of thm_ml |> Utils.rhs_of

  (* Export assumptions, converting callees to schematics. *)
  val exported_thm = Morphism.thm m thm_ml

  (* Determine if the conversion was successful. *)
  val success = #valid_term rule_set ctxt new_term
in
  if success then SOME (exported_thm, dict) else NONE
end

(* Like perform_lift, but also applies the polishing rules, hopefully yielding
 * an even nicer definition. *)
fun perform_lift_and_polish ctxt prog_info l2_infos rule_set do_opt fn_name =
  perform_lift ctxt prog_info l2_infos rule_set fn_name
  |> Option.map (fn (thm, dict) =>
       (* Apply any polishing rules. *)
       (Monad_Convert.polish ctxt rule_set do_opt thm, dict))

(*
 * Attempt to lift a function (or recursive function group) into the given monad.
 *
 * If successful, we define the new function (resp. group) in the theory and
 * prove, for each function, a theorem of the form:
 *
 *   "L2_call foo x y z == <lift> $ new_foo x y z"
 *
 * where "lift" is a lifting function, such as "returnOk" or "gets", etc.
 *
 * We return (<updated lthy>, <table of new TS function infos>, <monad name>).
 * If any function in the group cannot be lifted, LiftingFailed is raised.
 *
 * The callees of this function need to be already translated in ts_infos
 * and also defined in lthy.
 *)
fun lift_function_rewrite rule_set filename prog_info l2_infos ts_infos
                          fn_names make_function_name keep_going do_opt lthy =
let
  val these_l2_infos = map (the o Symtab.lookup l2_infos) fn_names

  (* Determine if this function is recursive. *)
  val is_recursive = FunctionInfo.is_function_recursive (hd these_l2_infos)

  (* Fetch relevant callees. *)
  val callees =
    map #callees these_l2_infos
    |> Symset.union_sets
    |> Symset.dest
  val callee_infos = map (Symtab.lookup ts_infos #> the) callees

  (* Add monad_mono theorems. *)
  val callee_l2_infos = map (Symtab.lookup l2_infos #> the) callees
  val callee_mono_thms = List.mapPartial #mono_thm callee_l2_infos

  (* When lifting, also lift our callees. *)
  val rule_set' = Monad_Types.update_mt_lift_rules
    (fn thms => Monad_Types.thmset_adds thms (map #corres_thm callee_infos @ callee_mono_thms))
    rule_set

  (*
   * Attempt to lift all functions into this type.
   *
   * For mutually recursive functions, every function in the group needs to be
   * lifted in the same type, so we give up (raising LiftingFailed) as soon as
   * any function in the group cannot be lifted, without attempting the rest.
   *)
  val lifted_functions =
    map (fn fn_name =>
           case perform_lift_and_polish lthy prog_info l2_infos rule_set' do_opt fn_name of
               SOME result => result
             | NONE => raise LiftingFailed ())
        fn_names
  val (thms, dicts) = ListPair.unzip lifted_functions

  (*
   * Generate terms necessary for defining the function, and define the
   * functions.
   *)
  fun gen_fun_def_term (fn_name, dict, thm) lthy =
  let
    (* Fix function parameters. *)
    val fn_info = the (Symtab.lookup l2_infos fn_name)
    (* Only produce measure parameters for recursive functions. *)
    val all_arg_names = (if is_recursive then ["rec_measure'"] else []) @
                        map fst (#args fn_info)
    val all_arg_types = (if is_recursive then [AutoCorresUtil.measureT] else []) @
                        map snd (#args fn_info)
    val (all_arg_names, lthy') = Variable.variant_fixes all_arg_names lthy
    val (measure_param, fn_args) =
      if is_recursive
      then chop 1 (all_arg_names ~~ all_arg_types)
      else ([], all_arg_names ~~ all_arg_types)

    (* Extract function body and abstract over the function arguments.
     * FIXME: this is an ugly way to do it. It's probably better if perform_lift
     * doesn't convert to schematic vars to begin with *)

    (* First, convert schematic argument variables (of unknown names) to known free variables. *)
    val inst_args = map (Free #> Thm.cterm_of lthy' #> SOME) fn_args
    (* L2 function always takes a measure variable, so reserve a slot *)
    val inst_measure =
      if is_recursive
      then map (Free #> Thm.cterm_of lthy' #> SOME) measure_param
      else [NONE]
    (* thm also has a schematic const for each recursive assumption,
     * which we need to skip. (The measure var appears first, because it
     * is also used in each recursive call.) *)
    val skip_callees = replicate (Symset.card (#rec_callees fn_info)) NONE
    val inst_thm = Drule.infer_instantiate' lthy' (inst_measure @ skip_callees @ inst_args) thm

    (* Extract the body from the conversion theorem.
     * E.g. for "L2_call foo = liftE body" we extract "body". *)
    fun tail_of (_ $ x) = x
      | tail_of t = raise TERM ("TypeStrengthen: expected a lifted body of the form <lift> $ <body>", [t])
    val body =
      Thm.concl_of inst_thm
      |> Utils.rhs_of_eq
      |> tail_of
      (* This converts callee Vars in assumptions to Frees. Won't be
       * necessary if we generate Frees (see previous FIXME). *)
      |> Utils.unsafe_unvarify
    (* Abstract over args, which are now known arg frees *)
    val term = foldr (fn (v, t) => Utils.abs_over "" v t)
                     body (map Free (measure_param @ fn_args))

    (* Replace place-holder function names with our generated constant name. *)
    fun rename_placeholder (t as Free (n, T)) =
          (case Symtab.lookup dict n of
              NONE => t
            | SOME a => Free (make_function_name a, T))
      | rename_placeholder t = t
    val term = map_aterms rename_placeholder term
  in
    ((make_function_name fn_name, measure_param @ fn_args, term), lthy')
  end
  (* NB: discard lthy because we don't want the placeholders to be fixed during define_functions *)
  val (input_defs, _) = fold_map gen_fun_def_term (Utils.zip3 fn_names dicts thms) lthy
  val (ts_defs, lthy) = Utils.define_functions input_defs false is_recursive lthy
  (* TODO: we may want to cleanup callees and rec_callees here, like we do
   * in other phases. It's not crucial, however, since this is the
   * final phase. *)

  (* Instantiate variables in our equivalence theorem to their newly-defined values. *)
  fun do_inst_thm thm =
    Utils.instantiate_thm_vars lthy (
      fn ((name, _), t) =>
        try (unprefix "rec'") name
        (* FIXME: lookup from ts_infos *)
        |> Option.map make_function_name
        |> Option.map (Utils.get_term lthy)
        |> Option.map (Thm.cterm_of lthy)) thm
  val inst_thms = map do_inst_thm thms

  (* HACK: If an L2 function takes no parameters, its measure gets eta-contracted
     away, preventing eqsubst_tac from unifying the rec_measure's schematic
     variable. So we get rid of the schematic var pre-emptively. *)
  val inst_thms =
    map (fn inst_thm => Drule.abs_def inst_thm handle TERM _ => inst_thm) inst_thms

  (* Generate a theorem converting "L2_call <func>" into its new form,
   * such as L2_call <func> = liftE $ <new_func_def> *)
  val final_props = map (fn fn_name =>
    make_lift_equality lthy prog_info l2_infos fn_name rule_set'
        NONE NONE (Utils.get_term lthy (make_function_name fn_name))) fn_names

  (* Convert meta-logic into HOL statements, conjunct them together and setup
   * our goal statement. *)
  val int_props = map (Object_Logic.atomize_term lthy) final_props
  val goal = Utils.mk_conj_list int_props
    |> HOLogic.mk_Trueprop
    |> Thm.cterm_of lthy

  val simps =
    #L2_simp_rules rule_set' @
    @{thms gets_bind_ign L2_call_fail HOL.simp_thms}

  (* Proof script for a recursive group: induct on the measure, then unfold
   * both the L2 and the new definitions in each case. *)
  fun prove_recursive goal_state =
    goal_state
    |> apply_tac "start induction"
         (resolve_tac lthy @{thms recguard_induct} 1)
    |> apply_tac "(base case) split subgoals"
         (TRY (REPEAT_ALL_NEW (Tactic.match_tac lthy [@{thm conjI}]) 1))
    |> apply_tac "(base case) solve base cases"
         (EVERY (map (fn (ts_def, l2_info) =>
            SOLVES (
              (EqSubst.eqsubst_tac lthy [0] [#definition l2_info] 1)
              THEN (EqSubst.eqsubst_tac lthy [0] [ts_def] 1)
              THEN (simp_tac (put_simpset ts_simpset lthy) 1))) (ts_defs ~~ these_l2_infos)))
    |> apply_tac "(rec case) spliting induct case prems"
         (TRY (REPEAT_ALL_NEW (Tactic.ematch_tac lthy [@{thm conjE}]) 1))
    |> apply_tac "(rec case) split inductive case subgoals"
         (TRY (REPEAT_ALL_NEW (Tactic.match_tac lthy [@{thm conjI}]) 1))
    |> apply_tac "(rec case) unfolding strengthened function definition"
         (EVERY (map (fn (def, inst_thm) =>
            ((EqSubst.eqsubst_tac lthy [0] [def] 1)
             THEN (EqSubst.eqsubst_tac lthy [0] [inst_thm] 1)
             THEN (REPEAT (CHANGED (asm_full_simp_tac (put_simpset ts_simpset lthy) 1)))))
            (ts_defs ~~ inst_thms)))

  (* Proof script for a single non-recursive function. *)
  fun prove_nonrecursive goal_state =
    goal_state
    |> apply_tac "unfolding strengthen function definition"
         (EqSubst.eqsubst_tac lthy [0] [hd ts_defs] 1)
    |> apply_tac "unfolding L2 rewritten theorem"
         (EqSubst.eqsubst_tac lthy [0] [hd inst_thms] 1)
    |> apply_tac "simplifying remainder"
         (TRY (simp_tac (put_simpset HOL_ss (Utils.set_hidden_ctxt lthy) addsimps simps) 1))

  val rewrite_thm =
    Goal.init goal
    |> (if is_recursive then prove_recursive else prove_nonrecursive)
    |> Goal.finish lthy

  (* Now, using this combined theorem, generate a theorem for each individual
   * function. *)
  fun prove_partial_pred thm pred =
    Thm.cterm_of lthy pred
    |> Goal.init
    |> apply_tac "inserting hypothesis"
         (cut_tac thm 1)
    |> apply_tac "normalising into rule format"
         ((REPEAT (resolve_tac lthy @{thms allI} 1))
          THEN (REPEAT (eresolve_tac lthy @{thms conjE} 1))
          THEN (REPEAT (eresolve_tac lthy @{thms allE} 1)))
    |> apply_tac "solving goal" (assume_tac lthy 1)
    |> Goal.finish lthy
    |> Object_Logic.rulify lthy
  val new_thms = map (prove_partial_pred rewrite_thm) final_props

  (*
   * Make the theorems polymorphic in their exception type.
   *
   * That is, these theories may all be applied regardless of what the type of
   * the exception part of the monad is, but are currently specialised to
   * when the exception part of the monad is unit. We apply a "polymorphism theorem" to change
   * the type of the rule from:
   *
   *   ('s, unit + 'a) nondet_monad
   *
   * to
   *
   *   ('s, 'e + 'a) nondet_monad
   *)
  val new_thms = map (fn t => #polymorphic_thm rule_set' OF [t]) new_thms
                 |> map (fn t => Drule.generalize (Names.empty, Names.make_set ["rec_measure'"]) t)

  (* Final mono rules for recursive functions. *)
  fun try_mono_thm (fn_name, corres_thm) =
    case #mono_thm (the (Symtab.lookup l2_infos fn_name)) of
        NONE => NONE
      | SOME l2_mono_thm =>
          (case try (#prove_mono rule_set' corres_thm) l2_mono_thm of
              NONE => (Utils.ac_warning ("Failed to prove monad_mono for function: " ^ fn_name);
                       NONE)
            | SOME ts_mono_thm => SOME (fn_name, ts_mono_thm))
  val ts_mono_thms =
    (if is_recursive then List.mapPartial try_mono_thm (fn_names ~~ new_thms) else [])
    |> Symtab.make

  (* Build the TS-phase function info for each function in the group. *)
  fun mk_ts_info ((f, l2_info), (ts_def, ts_corres)) =
  let
    val mono_thm = Symtab.lookup ts_mono_thms f
    val const = Utils.get_term lthy (make_function_name f)
    val ts_info = l2_info
      |> FunctionInfo.function_info_upd_phase FunctionInfo.TS
      |> FunctionInfo.function_info_upd_definition ts_def
      |> FunctionInfo.function_info_upd_corres_thm ts_corres
      |> FunctionInfo.function_info_upd_const const
      |> FunctionInfo.function_info_upd_mono_thm mono_thm
  in
    (f, ts_info)
  end
  val new_ts_infos =
    map mk_ts_info ((fn_names ~~ these_l2_infos) ~~ (ts_defs ~~ new_thms))
in
  (lthy, Symtab.make new_ts_infos, #name rule_set')
end

(* Return the lifting rule(s) to try for a function set.
   This is moved out of lift_function so that it can be used to
   provide argument checking in the AutoCorres.abstract wrapper.

   If no function in the set has a forced rule, all "rules" are returned.
   Otherwise all forced rules must agree (by name) and only that rule is
   returned; disagreement is reported with "error". *)
fun compute_lift_rules (rules : Monad_Types.monad_type list) force_lift fn_names =
let
  (* Functions in this set that have a user-forced lifting rule. *)
  val forced =
    List.mapPartial
      (fn func => Option.map (fn rule => (func, rule)) (Symtab.lookup force_lift func))
      fn_names
in
  case forced of
      [] => rules (* No restrictions *)
    | ((_, rule) :: rest) =>
        (* Functions in the same set must all use the same lifting rule. *)
        if List.all (fn (_, rule') => #name rule = #name rule') rest
        then [rule] (* Try the specified rule *)
        else error ("autocorres: this set of mutually recursive functions " ^
                    "cannot be lifted to different monads: " ^
                    commas_quote (map fst forced))
end

(* Lift the given function set, trying each rule until one succeeds.
 * Raises AllLiftingFailed if every candidate rule fails. *)
fun lift_function rules force_lift filename prog_info l2_infos ts_infos
                  fn_names make_function_name keep_going do_opt lthy =
let
  val rules' = compute_lift_rules rules force_lift fn_names

  (* Find the first lift that works. *)
  fun first (rule :: xs) =
        (lift_function_rewrite rule filename prog_info l2_infos ts_infos
                               fn_names make_function_name keep_going do_opt lthy
         handle LiftingFailed _ => first xs)
    | first [] =
        raise AllLiftingFailed (map (fn f =>
          (f, #definition (the (Symtab.lookup l2_infos f)))) fn_names)
in
  first rules'
end

(* Show how many functions were lifted to each monad. "results" is a list of
 * monad names, one per lifted function. *)
fun print_statistics results =
let
  (* Fold over the sorted names, merging runs of equal names into counts.
   * The accumulator is kept in reverse order. *)
  fun add_name name ((prev, count) :: rest) =
        if name = prev then (prev, count + 1) :: rest
        else (name, 1) :: (prev, count) :: rest
    | add_name name [] = [(name, 1)]
  val tabulated = fold add_name (sort_strings results) [] |> rev
  val data = tabulated
    |> map (fn (name, count) => " " ^ name ^ ": " ^ string_of_int count ^ "\n")
    |> String.concat
in
  writeln ("Type Strengthening Statistics: \n" ^ data)
end

(* Run through every function, attempting to strengthen its type.
 * FIXME: this stage is currently completely sequential. Conversions
 *        that do not depend on each other should be in parallel;
 *        this requires splitting the convert and define stages as usual. *)
fun translate
    (rules : Monad_Types.monad_type list)
    (force_lift : Monad_Types.monad_type Symtab.table)
    (filename : string)
    (prog_info : ProgramInfo.prog_info)
    (l2_results : FunctionInfo.phase_results)
    (existing_l2_infos : FunctionInfo.function_info Symtab.table)
    (existing_ts_infos : FunctionInfo.function_info Symtab.table)
    (make_function_name : string -> string)
    (keep_going : bool)
    (do_opt : bool)
    (add_trace: string -> string -> AutoCorresData.Trace -> unit)
    : FunctionInfo.phase_results =
  if FSeq.null l2_results then FSeq.empty () else
  let
    (* Wait for previous stage to finish first. *)
    val l2_results = FSeq.list_of l2_results;
    val lthy = List.last l2_results |> fst;
    val l2_infos = Utils.symtab_merge false (map snd l2_results);

    (* Prettify bound variable names in definitions. *)
    fun prettify_info _ (info : FunctionInfo.function_info) =
    let
      val def = #definition info
      val pretty_def =
        PrettyBoundVarNames.pretty_bound_vars_thm
            lthy (Utils.crhs_of (Thm.cprop_of def)) keep_going
        |> Thm.transitive def
    in
      FunctionInfo.function_info_upd_definition pretty_def info
    end
    val l2_infos = Symtab.map prettify_info l2_infos;

    (* Add prior L2 translations to round out L2 callees. *)
    val l2_infos = AutoCorresUtil.add_background_callees existing_l2_infos l2_infos;
    val l2_infos = Symtab.merge (K false) (l2_infos, existing_l2_infos);

    (* For now, just works sequentially like the old TypeStrengthen. *)
    fun translate_group fn_names (lthy, _, ts_infos) =
    let
      val _ = writeln ("Translating (type strengthen) " ^ Utils.commas fn_names);
      val start_time = Timer.startRealTimer ();

      val (lthy, new_ts_infos, monad_name) =
            lift_function rules force_lift filename prog_info l2_infos ts_infos
                          fn_names make_function_name keep_going do_opt lthy;

      val _ = writeln (" --> " ^ monad_name);
      val _ = tracing ("Converted (TS) " ^ Utils.commas fn_names ^ " in " ^
                       Time.toString (Timer.checkRealTimer start_time) ^ " s");
    in (lthy, new_ts_infos, Symtab.merge (K false) (ts_infos, new_ts_infos)) end;

    val ts_results =
          Utils.accumulate translate_group (lthy, Symtab.empty, existing_ts_infos)
                           (map (map fst o Symtab.dest o snd) l2_results)
          |> fst
          |> map (fn (lthy, ts_infos, _) => (lthy, ts_infos));
  in
    FSeq.of_list ts_results
  end

end