1655 lines
67 KiB
Standard ML
1655 lines
67 KiB
Standard ML
(*
|
|
* Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
|
|
*
|
|
* SPDX-License-Identifier: BSD-2-Clause
|
|
*)
|
|
|
|
(*
|
|
* Extract local variables out of converted L1 fragments.
|
|
*
|
|
* The main interface to this module is translate (and helper functions
|
|
* convert and define). See AutoCorresUtil for a conceptual overview.
|
|
*)
|
|
structure LocalVarExtract =
|
|
struct
|
|
|
|
(* Open Prog: presumably this provides the program-tree constructors used
 * below (Modify, Seq, Catch, Guard, Call, While, ...), get_node_data and the
 * varset type -- verify against the Prog structure. *)
open Prog
|
|
|
|
(*
 * Short names for the Varset operations used throughout this structure.
 *
 * INTER, MINUS and UNION are left-associative infixes at precedence 1, so
 * "a UNION b MINUS c" parses as "(a UNION b) MINUS c".
 *)
val empty_set = Varset.empty
val make_set = Varset.make
val dest_set = Varset.dest
val union_sets = Varset.union_sets

infix 1 INTER MINUS UNION

fun (x UNION y) = Varset.union x y
fun (x INTER y) = Varset.inter x y
(* "x MINUS y": note the argument order is swapped relative to Varset.subtract. *)
fun (x MINUS y) = Varset.subtract y x
|
|
|
|
(* Local aliases for frequently used Utils helpers. *)
val the' = Utils.the'
val apply_tac = Utils.apply_tac
(* NOTE: presumably shadows Isabelle's own "warning" within this structure. *)
val warning = Utils.ac_warning
|
|
|
|
(* The simpset for the automated tactics in this file: AUTOCORRES_SIMPSET
 * installed into ctxt, extended with pred_conj_def. *)
fun setup_l2_ss ctxt =
  let
    val base = put_simpset AUTOCORRES_SIMPSET ctxt
  in
    base addsimps @{thms pred_conj_def}
  end
|
|
|
|
(*
 * Render a variable set as an Isabelle "string list" term.  Only the names
 * are used (types are dropped).  Each name is passed through
 * ProgramInfo.demangle_name before being encoded as an Isabelle string.
 *)
fun var_set_to_isa_list prog_info s =
  let
    val encode_name =
      Utils.encode_isa_string o ProgramInfo.demangle_name prog_info o fst
  in
    Utils.encode_isa_list @{typ string} (map encode_name (dest_set s))
  end
|
|
|
|
(*
 * Replace reads of local variables in "term" with free variables.
 *
 * "vars" is a list of (name, getter_term) pairs.  The pairs are processed in
 * list order against the progressively rewritten term.  Whenever a getter
 * term occurs, it is abstracted out (bound name = the variable's name) and
 * the abstraction is applied to "name_map (name, T)", where T is the type of
 * the getter term.
 *
 * Returns (found, term'):
 *   found - the (name, T) pairs of every variable that occurred, in REVERSE
 *           order of "vars";
 *   term' - the rewritten term.
 *
 * For instance, assuming name_map maps ("x", T) to the free variable x:
 *
 *   convert_local_vars name_map @{term "a_' s + b + c"}
 *       [("x", @{term "a_' s"}), ("y", @{term "b_' s"})]
 *
 * returns ([("x", T)], @{term "x + b + c"}), where T is the type of "a_' s".
 * "y" is absent because "b_' s" does not occur in the term.
 *)
fun convert_local_vars name_map term vars =
  let
    fun step ((var_name, var_term), (found, current)) =
      if Utils.contains_subterm var_term current then
        let
          val varT = fastype_of var_term
          val replacement = name_map (var_name, varT)
          (* Abstract var_term out of the current term, then instantiate the
           * abstraction with the replacement free variable. *)
          val rewritten =
            betapply (Utils.abs_over var_name var_term current, replacement)
        in
          ((var_name, varT) :: found, rewritten)
        end
      else
        (found, current)
  in
    List.foldl step ([], term) vars
  end
|
|
|
|
(*
 * Return (inputs, outputs) for the function "fn_name":
 *   inputs  - the "#args" of its L1 function_info, as a varset;
 *   outputs - empty for a void C function.  Otherwise it is the singleton set
 *             holding the C parser's return-variable name paired with the
 *             function_info's "#return_type".
 * Raises Option if fn_name is not in l1_infos, and fails via Utils.the' if it
 * is missing from the C parser's csenv.
 *)
fun get_fn_input_output_vars prog_info l1_infos fn_name =
  let
    val info = the (Symtab.lookup l1_infos fn_name)
    val inputs = make_set (#args info)
    val rettype =
      Utils.the' ("Function missing from C-parser's csenv: " ^ quote fn_name)
        (ProgramAnalysis.get_rettype fn_name (#csenv prog_info))
    val outputs =
      if rettype = Absyn.Void then
        empty_set
      else
        make_set [(MString.dest (NameGeneration.return_var_name rettype),
                   #return_type info)]
  in
    (inputs, outputs)
  end
|
|
|
|
(*
 * The return variable of "fn_name" as a (name, type) pair.  Functions with
 * no output variable (void C functions) yield the placeholder
 * ("void", unit).
 *)
fun get_ret_var prog_info l1_infos fn_name =
  (case Varset.dest (snd (get_fn_input_output_vars prog_info l1_infos fn_name)) of
    ret :: _ => ret
  | [] => ("void", @{typ unit}))
|
|
|
|
(*
 * Split the type of an L2 monadic term into its argument types and the
 * monad's state, return and exception types.
 *
 * The body type (after all binders) must be
 *
 *   (('e + 'r) * 's) set * bool
 *
 * The returned value is (arg_types, ('s, 'r, 'e)).  Here arg_types is the
 * list of binder types with the LAST binder (the monad's own state
 * argument) dropped.  Raises Bind if the body type does not have the shape
 * above.
 *)
fun dest_l2monad_T monadT =
  let
    val (state, ret, ex) =
      (case body_type monadT of
        Type ("Product_Type.prod",
              [Type ("Set.set",
                     [Type ("Product_Type.prod",
                            [Type ("Sum_Type.sum", [ex, ret]), state])]), _]) =>
          (state, ret, ex)
      | _ => raise Bind)
    val binders = binder_types monadT
  in
    (List.take (binders, length binders - 1), (state, ret, ex))
  end
|
|
(* The (state, return, exception) types of an L2 monadic term, and projections
 * onto each component. *)
fun l2monad_type monad = snd (dest_l2monad_T (fastype_of monad))

fun l2monad_state_type monad = let val (stateT, _, _) = l2monad_type monad in stateT end

fun l2monad_ret_type monad = let val (_, retT, _) = l2monad_type monad in retT end

fun l2monad_ex_type monad = let val (_, _, exT) = l2monad_type monad in exT end
|
|
|
|
(*
 * Project the abstract (5th argument) or concrete (6th argument) program out
 * of an "L2corres" term.  Any other term raises Match.
 *)
fun dest_L2corres_term_abs trm =
  (case trm of @{term_pat "L2corres _ _ _ _ ?t _"} => t)

fun dest_L2corres_term_conc trm =
  (case trm of @{term_pat "L2corres _ _ _ _ _ ?t"} => t)
|
|
|
|
(* The type ('s, 'r, 'e) L2_monad for the given state, return and exception
 * types. *)
fun mk_l2monadT stateT retT exT =
  let
    val template = @{typ "('a, 'b, 'c) L2_monad"}
  in
    Utils.gen_typ template [stateT, retT, exT]
  end
|
|
|
|
(*
 * Translate the argument of an L1_spec, a relation on full states of the form
 *
 *   {(s, t). f s t}
 *
 * into an L2 spec relation that mentions only the globals of "s" and "t".
 *
 * Returns SOME of the translated Collect-set.  Returns NONE, after emitting a
 * warning, when the relation still depends on the full states once the
 * globals have been abstracted out (e.g. because it reads or writes local
 * variables).
 *)
fun parse_spec ctxt prog_info term =
  let
    val thy = Proof_Context.theory_of ctxt
    val stateT = #state_type prog_info

    (* With L1 simplification disabled, the spec may still be a union or
     * intersection of Collect-sets; normalise these away first. *)
    val spec = Raw_Simplifier.rewrite_term thy
      (map safe_mk_meta_eq @{thms Collect_prod_inter Collect_prod_union}) [] term

    (* Ask "(old, new) : spec" for placeholder states and simplify the
     * membership away, exposing the body "f old new". *)
    val old_state = Free ("_dummy_state1", stateT)
    val new_state = Free ("_dummy_state2", stateT)
    val state_pair = HOLogic.mk_tuple [old_state, new_state]
    val memberC = Const (@{const_name "Set.member"},
      fastype_of state_pair --> fastype_of spec --> @{typ bool})
    val body =
      Envir.beta_eta_contract (memberC $ state_pair $ spec)
      |> Raw_Simplifier.rewrite_term thy
           (map mk_meta_eq @{thms split_def fst_conv snd_conv mem_Collect_eq}) []

    (* Abstract the new-state globals first so that they end up innermost. *)
    val globals_of = #globals_getter prog_info
    val relation =
      body
      |> Utils.abs_over "t" (globals_of $ new_state)
      |> Utils.abs_over "s" (globals_of $ old_state)
      |> HOLogic.mk_case_prod
    val collectC = @{mk_term "Collect :: (?'s \<Rightarrow> bool) \<Rightarrow> ?'s set" ('s)}
      (domain_type (fastype_of relation))

    (* Any remaining placeholder means the spec touches more than the globals. *)
    val touches_full_state =
      Utils.contains_subterm old_state relation
      orelse Utils.contains_subterm new_state relation
  in
    if touches_full_state then
      (warning ("Can't parse spec term: " ^ (Utils.term_to_string ctxt spec)); NONE)
    else
      SOME (collectC $ relation)
  end
|
|
|
|
|
|
(*
 * Parse an L1 expression "%s. f s" over the full state.
 *
 * Returns (vars_read, globals_used, expr):
 *   vars_read    - (name, type) of every local variable read by f, as
 *                  reported by convert_local_vars;
 *   globals_used - whether f reads the globals part of the state;
 *   expr         - SOME "%s. f'", where reads of the globals become the
 *                  bound "s" and local-variable reads become name_map's free
 *                  variables.  It is NONE, with a warning, if f still uses
 *                  the state in some other way.  Examples are the C parser's
 *                  "lvar_init" dummy or the always-false guard "bot".
 *
 * For example, "%s. a_' s + b_' s" reads a and b, does not use the globals,
 * and yields roughly SOME "%s. a + b" (a, b being name_map's free variables).
 *)
fun parse_expr ctxt prog_info name_map term =
  let
    val st = Free ("_dummy_state", #state_type prog_info)

    (* Work on "f st" so that field reads appear as "getter st". *)
    val applied = Envir.beta_eta_contract (term $ st)

    (* Globals are abstracted first so that "s" ends up innermost. *)
    val globals_read = #globals_getter prog_info $ st
    val globals_used = Utils.contains_subterm globals_read applied
    val over_globals = Utils.abs_over "s" globals_read applied

    val getters =
      Symtab.dest (#var_getters prog_info)
      |> map (fn (name, getter) => (name, getter $ st))
    val (read_vars, abstracted) = convert_local_vars name_map over_globals getters

    (* Leftover uses of st mean f transforms the state arbitrarily. *)
    val result =
      if Utils.contains_subterm st abstracted then
        (warning ("Can't parse expression: " ^ (Utils.term_to_string ctxt applied)); NONE)
      else
        SOME abstracted
  in
    (read_vars, globals_used, result)
  end
|
|
|
|
(*
 * Parse the state-update function of an "L1_modify".  The same parsing is
 * used for an L1_call's argument setup and return extraction.
 *
 * The update is expected to be a chain of field updates
 * "foo_'_update f (bar_'_update g (... s))".  The result has one entry per
 * update, outermost first:
 *
 *   ((var_name, var_type), vars_read, reads_globals, new_value option)
 *
 * new_value is produced by parse_expr.  An update that leaves the state
 * unchanged yields [].
 *)
local
  (* Parse the outermost field update.  Also returns the remaining (inner)
   * update as "%s. rest", to be parsed next. *)
  fun parse_modify' ctxt prog_info name_map term =
    let
      val st = Free ("_dummy_state", #state_type prog_info)

      (* Accept both "%x. (foo x) x" and plain "foo" by applying a state and
       * normalising. *)
      val clause = Envir.beta_eta_contract (term $ st)

      (* Split "foo_'_update f rest".  Custom modifies clauses (e.g. written
       * with AUXUPD) may not have this shape and are rejected here. *)
      val (setter, update_fn, rest) =
        (case clause of
          (Const c $ v $ r) => (Const c, v, r)
        | other => Utils.invalid_term "Const (x,y) $ z" other)
      val (var_name, var_type) = ProgramInfo.guess_var_name_type_from_setter_term setter

      (* f maps the field's old value to its new one; feeding it
       * "foo_' st" yields the new value as an expression over st. *)
      val getter =
        (case Symtab.lookup (#var_getters prog_info) var_name of
          SOME g => g
        | NONE => Utils.invalid_input "valid local variable getter" var_name)
      val new_value = Envir.beta_eta_contract (betapply (update_fn, getter $ st))

      val abs_state = Utils.abs_over "s" st
      val (read_vars, reads_globals, parsed) =
        parse_expr ctxt prog_info name_map (abs_state new_value)
    in
      ((var_name, var_type), read_vars, reads_globals, parsed, abs_state rest)
    end
in
  fun parse_modify ctxt prog_info name_map term =
    let
      val st = Free ("_dummy_state", #state_type prog_info)
      val is_identity = Envir.beta_eta_contract (term $ st) = st
    in
      if is_identity then
        []
      else
        let
          val (updated, reads, reads_globals, parsed, residual) =
            parse_modify' ctxt prog_info name_map term
        in
          (updated, reads, reads_globals, parsed)
            :: parse_modify ctxt prog_info name_map residual
        end
    end
end
|
|
|
|
(*
 * Build a precondition stating that each variable in "vars" holds the value
 * of its name_map free variable:
 *
 *   "(%s. n_' s = n) and (%s. i_' s = i) and ..."
 *
 * The conjuncts are combined with Utils.chain_preds over the state type.
 *)
fun mk_precond prog_info name_map vars =
  let
    val stateT = #state_type prog_info
    val st = Free ("_dummy_state", stateT)

    fun read_var name =
      (case Symtab.lookup (#var_getters prog_info) name of
        NONE => Utils.invalid_input "valid local variable name" name
      | SOME getter => getter $ st)

    fun var_pred (v as (name, _)) =
      Utils.abs_over "s" st (HOLogic.mk_eq (read_var name, name_map v))
  in
    dest_set vars |> map var_pred |> Utils.chain_preds stateT
  end
|
|
|
|
(*
 * Build the extraction function that reads the variables of "vars" out of
 * the state as a tuple, e.g.
 *
 *   "%s. (a_' s, b_' s, c_' s)"
 *)
fun mk_xf (prog_info : ProgramInfo.prog_info) vars =
  let
    val st = Free ("_dummy_state", #state_type prog_info)
    fun read_var name =
      (case Symtab.lookup (#var_getters prog_info) name of
        SOME getter => getter $ st
      | NONE => Utils.invalid_input "valid local variable name" name)
    val fields = map (read_var o fst) (dest_set vars)
  in
    Utils.abs_over "s" st (HOLogic.mk_tuple fields)
  end
|
|
|
|
(*
 * The proposition
 *
 *   L2corres globals ret_xf ex_xf P l2_term l1_term
 *
 * where ret_xf / ex_xf extract return_vars / except_vars, and P fixes
 * precond_vars to their name_map free variables.
 *)
fun mk_corresXF_prop thy prog_info name_map return_vars except_vars precond_vars l2_term l1_term =
  let
    val precond = mk_precond prog_info name_map precond_vars
    val ret_xf = mk_xf prog_info return_vars
    val ex_xf = mk_xf prog_info except_vars
    val args = [#globals_getter prog_info, ret_xf, ex_xf, precond, l2_term, l1_term]
  in
    HOLogic.mk_Trueprop (Utils.mk_term thy @{term L2corres} args)
  end
|
|
|
|
(*
 * Prove the L2corres correspondence between an L2 and an L1 term using "tac".
 *
 * ctxt              - local theory context
 * return_vars       - variables returned by the abstract monad
 * except_vars       - variables thrown by the abstract monad
 * precond_vars      - variables that must agree between abstract and concrete;
 *                     their name_map free variables are fixed in the proof
 * l2_term / l1_term - abstract and concrete programs
 *)
fun mk_corresXF_thm ctxt prog_info name_map return_vars except_vars precond_vars l2_term l1_term tac =
  let
    val fixed_names = map (fst o dest_Free o name_map) (dest_set precond_vars)
    val prop =
      mk_corresXF_prop (Proof_Context.theory_of ctxt) prog_info name_map
        return_vars except_vars precond_vars l2_term l1_term
  in
    Goal.prove ctxt fixed_names [] prop (fn _ => tac)
  end
|
|
|
|
(*
 * mk_corresXF_thm, proving the goal from an existing theorem "thm".  The
 * goal and thm are both unfolded with split_def; thm is then resolved, and
 * the remaining subgoals are simplified with the L2 simpset until nothing
 * changes.
 *)
fun mk_corresXF_thm' ctxt prog_info name_map return_vars except_vars precond_vars l2_term l1_term thm =
  let
    val split_eq = mk_meta_eq @{thm split_def}
    val tac =
      rewrite_goal_tac ctxt [split_eq] 1
      THEN resolve_tac ctxt [rewrite_rule ctxt [split_eq] thm] 1
      THEN REPEAT (CHANGED (asm_full_simp_tac (setup_l2_ss ctxt) 1))
  in
    mk_corresXF_thm ctxt prog_info name_map return_vars except_vars precond_vars
      l2_term l1_term tac
  end
|
|
|
|
(*
 * Find the constant identifying the callee in an L1_call's function argument.
 *
 * If the head is a constant whose last argument is a constant named
 * "..._'proc", that procedure constant is returned; otherwise the head
 * constant itself.  A bare (unapplied) lambda is looked through.  Anything
 * else raises TERM.
 *)
fun l1call_function_const t =
  (case strip_comb t of
    (head as Const _, args) =>
      (case rev args of
        (callee as Const (callee_name, _)) :: _ =>
          if String.isSuffix "_'proc" callee_name then callee else head
      | _ => head)
  | (Abs (_, _, body), []) => l1call_function_const body
  | _ => raise TERM ("l1call_function_const", [t]))
|
|
|
|
(*
 * Parse an L1 program into a Prog tree.
 *
 * Each node carries the original L1 subterm as its first component.  Every
 * expression and modify clause is parsed (see parse_expr / parse_modify).
 * Each is stored as a (parsed_expr option, vars_read, reads_globals) triple,
 * so later passes know which local variables each part of the program uses.
 * Terms that are not recognised L1 constructs are rejected with
 * Utils.invalid_term.
 *)
fun parse_l1 ctxt prog_info l1_infos l1_call_info name_map term =
  let
    fun recur t = parse_l1 ctxt prog_info l1_infos l1_call_info name_map t

    (* Parse an expression and package it as a node annotation. *)
    fun annotated_expr t =
      let
        val (read_vars, reads_globals, parsed) = parse_expr ctxt prog_info name_map t
      in
        (parsed, make_set read_vars, reads_globals)
      end
  in
    case term of
      (Const (@{const_name "L1_skip"}, _)) =>
        Modify (term,
          (SOME (Abs ("s", #globals_type prog_info, @{term "()"})), empty_set, false), NONE)

    | (Const (@{const_name "L1_modify"}, _) $ m) =>
        (case parse_modify ctxt prog_info name_map m of
          [(updated_var, read_vars, reads_globals, parsed)] =>
            Modify (term, (parsed, make_set read_vars, reads_globals), SOME updated_var)
        | _ => Utils.invalid_term "Modifies clause too complex." m)

    | (Const (@{const_name "L1_seq"}, _) $ lhs $ rhs) =>
        Seq (term, recur lhs, recur rhs)

    | (Const (@{const_name "L1_catch"}, _) $ lhs $ rhs) =>
        Catch (term, recur lhs, recur rhs)

    | (Const (@{const_name "L1_guard"}, _) $ c) =>
        Guard (term, annotated_expr c)

    | (Const (@{const_name "L1_throw"}, _)) =>
        Throw term

    | (Const (@{const_name "L1_condition"}, _) $ c $ lhs $ rhs) =>
        Condition (term, annotated_expr c, recur lhs, recur rhs)

    | (Const (@{const_name "L1_call"}, _) $ arg_setup $ callee $ _ $ ret_extract) =>
        let
          (* The argument setup is kept as separate expressions rather than a
           * modify: the assigned variables only live inside this call. *)
          val arg_exprs =
            parse_modify ctxt prog_info name_map arg_setup
            |> map (fn (_, read_vars, reads_globals, t) =>
                 (t, make_set read_vars, reads_globals))

          val callee' =
            (case callee of
              Const (@{const_name "measure_call"}, _) $ f => f
            | _ => callee)

          val callee_info =
            l1call_function_const callee'
            |> Termtab.lookup (#const_to_function l1_call_info)
            |> Utils.the' ("Unknown function " ^ quote (@{make_string} callee'))
            |> Symtab.lookup l1_infos
            |> the

          (* The callee's return value is referred to as the bound "ret" in
           * the extracted return assignment. *)
          val ret_var = get_ret_var prog_info l1_infos (#name callee_info)
          val ret_clauses =
            parse_modify ctxt prog_info name_map
              (betapply (ret_extract, Free ("_dummy_state", #state_type prog_info)))
            |> map (fn (target, read_vars, reads_globals, e) =>
                 (target, (make_set read_vars) MINUS (make_set [ret_var]),
                  reads_globals, Option.map (Utils.abs_over "ret" (name_map ret_var)) e))

          val (ret_expr, updated_var) =
            (case ret_clauses of
              [(target, read_vars, reads_globals, e)] =>
                ((e, read_vars, reads_globals), SOME target)
            | [] => ((NONE, empty_set, false), NONE)
            | x => Utils.invalid_input "single return param" (@{make_string} x))
        in
          Call (term, arg_exprs, ret_expr, updated_var, ())
        end

    | (Const (@{const_name "L1_while"}, _) $ c $ body) =>
        While (term, annotated_expr c, recur body)

    | (Const (@{const_name "L1_init"}, _) $ setter) =>
        Init (term, SOME (ProgramInfo.guess_var_name_type_from_setter_term setter))

    | (Const (@{const_name "L1_spec"}, _) $ c) =>
        Spec (term, (parse_spec ctxt prog_info c, empty_set, true))

    | (Const (@{const_name "L1_fail"}, _)) =>
        Fail term

    | (Const (@{const_name "L1_recguard"}, _) $ _ $ body) =>
        RecGuard (term, recur body)

    | other => Utils.invalid_term "a L1 term" other
  end
|
|
|
|
(*
 * Prove that the L1 program "term" does not modify the local variable "var".
 *
 * The result is a validE fact whose precondition and (normal and
 * exceptional) postconditions all state "var holds its name_map value".
 * Atomic L1 constructs are handled by their *_lp rule (hoareE_TrueI for
 * L1_spec).  Compound constructs recursively prove their sub-programs and
 * resolve those proofs into the premises of the construct's rule.
 *)
fun mk_preservation_proof ctxt prog_info name_map var term =
  let
    val thy = Proof_Context.theory_of ctxt

    (* Run "tac", then simplify remaining subgoals as far as possible. *)
    fun then_simp tac =
      tac THEN (TRY (REPEAT (CHANGED (asm_full_simp_tac (setup_l2_ss ctxt) 1))))

    fun by_rule thm = then_simp (resolve_tac ctxt [thm] 1)

    (* Resolve with "rules", then discharge its premises, in order, with the
     * preservation proofs of "subterms". *)
    fun by_compound rules subterms =
      let
        val sub_thms = map (mk_preservation_proof ctxt prog_info name_map var) subterms
        val chained =
          fold (fn th => fn acc => acc THEN resolve_tac ctxt [th] 1)
            sub_thms (resolve_tac ctxt rules 1)
      in
        then_simp chained
      end

    val precond = mk_precond prog_info name_map (make_set [var])
    val postcond = absdummy @{typ unit} precond
    val goal =
      Utils.mk_term thy @{term validE} [precond, term, postcond, postcond]
      |> HOLogic.mk_Trueprop

    val tac =
      (case term of
        (Const (@{const_name "L1_skip"}, _)) => by_rule @{thm L1_skip_lp}
      | (Const (@{const_name "L1_init"}, _) $ _) => by_rule @{thm L1_init_lp}
      | (Const (@{const_name "L1_modify"}, _) $ _) => by_rule @{thm L1_modify_lp}
      | (Const (@{const_name "L1_call"}, _) $ _ $ _ $ _ $ _) => by_rule @{thm L1_call_lp}
      | (Const (@{const_name "L1_guard"}, _) $ _) => by_rule @{thm L1_guard_lp}
      | (Const (@{const_name "L1_throw"}, _)) => by_rule @{thm L1_throw_lp}
      | (Const (@{const_name "L1_spec"}, _) $ _) => by_rule @{thm hoareE_TrueI}
      | (Const (@{const_name "L1_fail"}, _)) => by_rule @{thm L1_fail_lp}
      | (Const (@{const_name "L1_while"}, _) $ _ $ body) =>
          by_compound @{thms L1_while_lp} [body]
      | (Const (@{const_name "L1_condition"}, _) $ _ $ lhs $ rhs) =>
          by_compound @{thms L1_condition_lp} [lhs, rhs]
      | (Const (@{const_name "L1_seq"}, _) $ lhs $ rhs) =>
          by_compound @{thms L1_seq_lp} [lhs, rhs]
      | (Const (@{const_name "L1_catch"}, _) $ lhs $ rhs) =>
          by_compound @{thms L1_catch_lp} [lhs, rhs]
      | (Const (@{const_name "L1_recguard"}, _) $ _ $ body) =>
          by_compound @{thms L1_recguard_lp} [body]
      | other => Utils.invalid_term "a L1 term" other)
  in
    Thm.cterm_of ctxt goal
    |> Goal.init
    |> Utils.apply_tac ("proving variable preservation for var '" ^ (fst var) ^ "'") tac
    |> Goal.finish ctxt
  end
|
|
|
|
(*
 * Prove that none of the variables in "var_set" is modified by "term".
 * Per-variable preservation facts are merged with combine_validE, starting
 * from hoareE_TrueI.  An Option exception raised while building the proof is
 * reported as an error naming the variable set.
 *)
fun mk_multivar_preservation_proof ctxt prog_info name_map term var_set =
  let
    fun preserve v = mk_preservation_proof ctxt prog_info name_map v term
    fun combine (th, acc) = @{thm combine_validE} OF [th, acc]
  in
    List.foldl combine @{thm hoareE_TrueI} (map preserve (dest_set var_set))
  end
  handle Option => error ("Preservation proof failed for " ^ quote (@{make_string} var_set))
|
|
|
|
(*
 * Build a well-typed L2 monadic term.
 *
 * const  - name of the monadic combinator (e.g. @{const_name "L2_gets"})
 * ret    - variables returned by the monad; only their types are used, to
 *          form the return tuple type
 * throw  - variables thrown by the monad; likewise only used for the type
 * params - arguments applied (with beta reduction) to the constant
 *)
fun mk_l2monad (prog_info : ProgramInfo.prog_info) const ret throw params =
  let
    fun tupleT vars = HOLogic.mk_tupleT (map snd (dest_set vars))
    val monadT = mk_l2monadT (#globals_type prog_info) (tupleT ret) (tupleT throw)
    val constT = map fastype_of params ---> monadT
  in
    betapplys (Const (const, constT), params)
  end
|
|
|
|
(* Abstract over the tuple of "vars", each variable being represented by its
 * name_map term (via Utils.abs_over_tuple). *)
fun abs_over_tuple_vars (name_map : (string * typ) -> term) (vars : varset) =
  dest_set vars
  |> map (fn v as (name, _) => (name, name_map v))
  |> Utils.abs_over_tuple
|
|
|
|
(*
 * Rephrase an L2corres theorem
 *
 *   L2corres st ret_xf ex_xf P (foo a b c) X
 *
 * as
 *
 *   L2corres st ret_xf ex_xf P ((%(a, b, c). foo a b c) r') X
 *
 * with "r'" a fresh schematic tuple of the variables in "vars".  This helps
 * unification in proofs that expect the abstract monad to have the form
 * "A x", where "x" is the return value of an earlier monad.
 *)
fun abs_over_thm ctxt (name_map : (string * typ) -> term) (thm : thm) (vars : varset) =
  let
    (* Schematics become frees while the conclusion is rebuilt.  Frees are
     * then turned back into schematics (index 0), which also covers r' and
     * the name_map variables. *)
    fun var_to_free (Var ((n, _), T)) = Free (n, T)
      | var_to_free other = other
    fun free_to_var (Free (n, T)) = Var ((n, 0), T)
      | free_to_var other = other

    val concl = HOLogic.dest_Trueprop (map_aterms var_to_free (Thm.concl_of thm))
    val @{term_pat "?head \<comment> \<open>L2corres\<close> ?st ?ret_xf ?ex_xf ?precond ?l2_term ?l1_term"} =
      concl

    val r_tupleT = HOLogic.mk_tupleT (map snd (dest_set vars))
    val wrapped_l2 = abs_over_tuple_vars name_map vars l2_term $ Free ("r'", r_tupleT)
    val new_concl =
      HOLogic.mk_Trueprop
        (map_aterms free_to_var (head $ st $ ret_xf $ ex_xf $ precond $ wrapped_l2 $ l1_term))

    val split_eq = mk_meta_eq @{thm split_def}
    val goal = list_implies (cprems_of thm, Thm.cterm_of ctxt new_concl)
  in
    Goal.init goal
    |> asm_full_simp_tac (put_simpset HOL_basic_ss ctxt addsimps [split_eq]) 1 |> Seq.hd
    |> resolve_tac ctxt [rewrite_rule ctxt [split_eq] thm] 1 |> Seq.hd
    |> REPEAT (assume_tac ctxt 1) |> Seq.hd
    |> Goal.finish ctxt
  end
|
|
|
|
(*
 * Adapt the translation (vars_read, vars_returned, monad, thm) of the L1
 * node "term" so that the monad returns exactly "needed_returns".  When
 * "allow_excess" is set, any superset of it is also accepted.  Allowing
 * excess returns can give neater output.
 *
 * If the returned variables already fit, the translation is passed through
 * unchanged.  Otherwise the monad is followed by
 *
 *   L2_seq monad (%(returned vars). L2_gets (%_. needed tuple) names)
 *
 * Variables that are needed but not returned by the monad must survive
 * "term" unchanged.  A preservation proof is built for them, and the
 * correspondence is re-established via L2corres_inject_return; these
 * variables are added to the read set of the result.
 *
 * This is needed, e.g., so that both branches of a "condition" return the
 * same variables.  "fn_vars" is not used here.
 *)
fun inject_return_vals ctxt prog_info name_map needed_returns allow_excess throw_vars fn_vars
    term (translated as (vars_read, vars_returned, output_monad, thm)) =
  let
    val already_fits =
      needed_returns = vars_returned
      orelse (allow_excess andalso Varset.subset (needed_returns, vars_returned))
  in
    if already_fits then
      translated
    else
      let
        val (l1_term, _, _) = get_node_data term

        (* "return (needed vars)", abstracted over the monad's returned tuple. *)
        val needed_tuple = HOLogic.mk_tuple (map name_map (dest_set needed_returns))
        val injected_return =
          mk_l2monad prog_info @{const_name L2_gets} needed_returns throw_vars
            [absdummy (#globals_type prog_info) needed_tuple,
             var_set_to_isa_list prog_info needed_returns]
          |> abs_over_tuple_vars name_map vars_returned

        val generated_term =
          mk_l2monad prog_info @{const_name L2_seq} needed_returns throw_vars
            [output_monad, injected_return]

        (* Needed variables the monad does not return must be preserved. *)
        val preserved_vals = needed_returns MINUS vars_returned
        val result_reads = vars_read UNION preserved_vals

        val preserve_proof =
          mk_multivar_preservation_proof ctxt prog_info name_map l1_term preserved_vals
        val generated_thm =
          mk_corresXF_thm' ctxt prog_info name_map needed_returns throw_vars result_reads
            generated_term l1_term
            (@{thm L2corres_inject_return} OF [thm, @{thm validE_weaken} OF [preserve_proof]])
      in
        (result_reads, needed_returns, generated_term, generated_thm)
      end
  end
|
|
|
|
(*
|
|
* Convert an L1 function into an L2 function.
|
|
*
|
|
* We assume that our input term has come out of the L1 conversion functions.
|
|
*
|
|
* We have inputs of the following:
|
|
*
|
|
* ctxt: Isabelle context
|
|
*
|
|
* needed_vars:
|
|
*
|
|
* Variables that are read in later executions.
|
|
*
|
|
* These are passed into the conversion so that we know what variables
|
|
* we need to track for later execution, and what variables we can
|
|
* just discard on the spot.
|
|
*
|
|
* If we didn't know what we actually needed to track, then the
|
|
* converted code would be significantly bloated due to returning
|
|
* variables that aren't actually used.
|
|
*
|
|
* allow_excess:
|
|
*
|
|
* Are we allowed to return _more_ variables than otherwise needed
|
|
* according to needed_vars? By setting this to true, more efficient
|
|
* code can be generated.
|
|
*
|
|
* throw_vars:
|
|
*
|
|
* Variables that must be thrown in the event we decide to emit an
|
|
* "L2_throw" call. These are calculated as we enter a try/catch block
|
|
* to ensure that all sites are consistent in the values they throw.
|
|
*
|
|
* term: The L1 term to convert.
|
|
*
|
|
* The return value of this function is a tuple:
|
|
*
|
|
* (<vars read by block>, <vars returned>, <term>, <proof>)
|
|
*
|
|
* The "vars returned" is the variables that are returned through the "bind"
|
|
* combinator.
|
|
*)
|
|
fun do_conv
    (ctxt : Proof.context)
    prog_info
    (* L1 function infos; the Call case looks callees up here by name. *)
    (l1_infos : FunctionInfo.function_info Symtab.table)
    (* Call graph; "#const_to_function" maps a callee constant to its name. *)
    (l1_call_info : FunctionInfo.call_graph_info)
    (* Maps a (name, type) variable to the term that stands for it. *)
    name_map
    (* Only forwarded to "inject_return_vals". *)
    (fn_vars : varset)
    (* Callee name -> (is_recursive, fixed callee term, callee corres thm). *)
    (callee_proofs : (bool * term * thm) Symtab.table)
    (* Variables the enclosing context needs this block to return. *)
    (needed_vars : varset)
    (* Only forwarded to "inject_return_vals"; the Seq case sets it for its LHS. *)
    (allow_excess : bool)
    (* Variables carried by exceptions thrown out of this block. *)
    (throw_vars : varset)
    (* Parsed L1 node; node data is (L1 term, live vars, modified vars). *)
    (term : (term * varset * varset, term option * varset * bool, (string * typ) option, unit) prog)
    : (varset * varset * term * thm) =
let
  (* Unpack the analysis data attached to this node. *)
  val l1_term = get_node_data term |> #1
  val live_vars = get_node_data term |> #2
  val modified_vars = get_node_data term |> #3

  (* Post-processing applied to most cases' results; its behaviour is defined by
   * "inject_return_vals" (outside this view). The Throw, Guard/NONE and Fail
   * cases, and calls with no callee proof, return their tuples without it. *)
  val inject =
    inject_return_vals ctxt prog_info name_map needed_vars allow_excess throw_vars fn_vars term

  (* Build the corres theorem between "generated_term" and this node's L1 term,
   * using "thm" as the supporting rule. *)
  fun mkthm read_vars ret_vars generated_term thm =
    mk_corresXF_thm' ctxt prog_info name_map ret_vars throw_vars read_vars generated_term l1_term thm

  val mk_monad = mk_l2monad prog_info
in
  case term of
    Init (_, SOME output_var) =>
      let
        val out_vars = make_set [output_var]
        val generated_term = mk_monad @{const_name L2_unknown} out_vars throw_vars
            [Utils.ml_str_list_to_isa [fst output_var]]
        val thm = mkthm empty_set out_vars generated_term @{thm L2corres_spec_unknown}
      in
        inject (empty_set, out_vars, generated_term, thm)
      end

    (* L1_skip. *)
  | Modify (_, (SOME expr, _, _), NONE) =>
      let
        val generated_term = mk_monad @{const_name L2_gets}
            empty_set throw_vars [expr, var_set_to_isa_list prog_info empty_set]
        val thm = mkthm empty_set empty_set generated_term @{thm L2corres_gets_skip}
      in
        inject (empty_set, empty_set, generated_term, thm)
      end

    (* L1_modify with unparsable expression. *)
  | Modify (_, (NONE, _, _), SOME output_var) =>
      let
        val out_vars = make_set [output_var]
        val generated_term = mk_monad @{const_name L2_unknown} out_vars throw_vars []
        val thm = mkthm empty_set out_vars generated_term @{thm L2corres_modify_unknown}
      in
        inject (empty_set, out_vars, generated_term, thm)
      end

    (* L1_modify that only modifies globals. *)
  | Modify (_, (SOME expr, read_vars, _), SOME ("globals'", _)) =>
      let
        val generated_term = mk_monad @{const_name L2_modify} empty_set throw_vars [expr]
        val thm = mkthm read_vars empty_set generated_term @{thm L2corres_modify_global}
      in
        inject (read_vars, empty_set, generated_term, thm)
      end

    (* L1_modify that only modifies a local and also reads globals. *)
  | Modify (_, (SOME expr, read_vars, _), SOME output_var) =>
      let
        val generated_term = mk_monad @{const_name L2_gets}
            (make_set [output_var]) throw_vars [expr, var_set_to_isa_list prog_info (make_set [output_var])]
        val thm = mkthm read_vars (make_set [output_var]) generated_term @{thm L2corres_modify_gets}
      in
        inject (read_vars, make_set [output_var], generated_term, thm)
      end

    (* Throw the tuple of "throw_vars"; these are therefore the block's reads.
     * The result is returned without "inject". *)
  | Throw _ =>
      let
        val generated_term = mk_monad @{const_name L2_throw} needed_vars throw_vars
            [HOLogic.mk_tuple (dest_set throw_vars |> map name_map),
             var_set_to_isa_list prog_info throw_vars]
        val thm = mkthm throw_vars needed_vars generated_term @{thm L2corres_throw}
      in
        (throw_vars, needed_vars, generated_term, thm)
      end

  | Spec (_, (SOME expr, read_vars, _)) =>
      let
        val generated_term = mk_monad @{const_name "L2_spec"} needed_vars throw_vars [expr]
        val thm = mkthm read_vars needed_vars generated_term @{thm L2corres_spec}
      in
        inject (read_vars, needed_vars, generated_term, thm)
      end

    (* Unparsable spec: emit L2_fail. *)
  | Spec (_, (NONE, _, _)) =>
      let
        val generated_term = mk_monad @{const_name "L2_fail"} needed_vars throw_vars []
        val thm = mkthm empty_set needed_vars generated_term @{thm L2corres_fail}
      in
        inject (empty_set, needed_vars, generated_term, thm)
      end

  | Guard (_, (SOME expr, read_vars, _)) =>
      let
        val generated_term = mk_monad @{const_name "L2_guard"} empty_set throw_vars [expr]
        val thm = mkthm read_vars empty_set generated_term @{thm L2corres_guard}
      in
        inject (read_vars, empty_set, generated_term, thm)
      end

    (* Unparsable guard: emit L2_fail (returned without "inject"). *)
  | Guard (_, (NONE, _, _)) =>
      let
        val generated_term = mk_monad @{const_name "L2_fail"} needed_vars throw_vars []
        val thm = mkthm empty_set needed_vars generated_term @{thm L2corres_fail}
      in
        (empty_set, needed_vars, generated_term, thm)
      end

  | Fail _ =>
      let
        val generated_term = mk_monad @{const_name "L2_fail"} needed_vars throw_vars []
        val thm = mkthm empty_set needed_vars generated_term @{thm L2corres_fail}
      in
        (empty_set, needed_vars, generated_term, thm)
      end

  | Seq (_, lhs, rhs) =>
      let
        val (_, rhs_live, rhs_modified) = get_node_data rhs
        val (lhs_term, _, lhs_modified) = get_node_data lhs

        (* Convert LHS and RHS. The LHS must return whatever it modifies that
         * is live in the RHS. *)
        val ret_vars = rhs_live INTER lhs_modified
        val (lhs_reads, lhs_rets, new_lhs, lhs_thm)
            = do_conv ctxt prog_info l1_infos l1_call_info name_map
                fn_vars callee_proofs ret_vars true throw_vars lhs
        val (rhs_reads, rhs_rets, new_rhs, rhs_thm)
            = do_conv ctxt prog_info l1_infos l1_call_info name_map
                fn_vars callee_proofs needed_vars allow_excess throw_vars rhs
        (* RHS reads of variables the LHS modifies are satisfied by the LHS. *)
        val block_reads = lhs_reads UNION (rhs_reads MINUS lhs_modified)

        (* Reconstruct body to support our input tuple. *)
        val rhs_thm = abs_over_thm ctxt name_map rhs_thm lhs_rets
        val new_rhs = abs_over_tuple_vars name_map lhs_rets new_rhs

        (* Generate the final term. *)
        val generated_term = mk_monad @{const_name L2_seq} rhs_rets throw_vars [new_lhs, new_rhs]

        (* Generate a proof. *)
        val thm =
          let
            (* Show that certain variables are preserved by the LHS. *)
            val needed_preserves = (rhs_reads MINUS lhs_modified)
            val preserve_proof = mk_multivar_preservation_proof ctxt prog_info name_map lhs_term needed_preserves
          in
            mkthm block_reads rhs_rets generated_term
                (@{thm L2corres_seq} OF [lhs_thm, rhs_thm,
                    @{thm validE_weaken} OF [preserve_proof]])
          end
      in
        inject (block_reads, rhs_rets, generated_term, thm)
      end

  | Catch (_, lhs, rhs) =>
      let
        val (lhs_term, _, lhs_modified) = get_node_data lhs
        val (_, rhs_live, _) = get_node_data rhs

        (* Convert LHS and RHS. The LHS throws whatever it modifies that is
         * live in the handler. *)
        val lhs_throws = rhs_live INTER lhs_modified
        val (lhs_reads, lhs_rets, new_lhs, lhs_thm)
            = do_conv ctxt prog_info l1_infos l1_call_info name_map
                fn_vars callee_proofs (needed_vars) false lhs_throws lhs
        val (rhs_reads, _, new_rhs, rhs_thm)
            = do_conv ctxt prog_info l1_infos l1_call_info name_map
                fn_vars callee_proofs (needed_vars) false throw_vars rhs
        val block_reads = lhs_reads UNION (rhs_reads MINUS lhs_throws)

        (* Reconstruct body to support our input tuple. *)
        val rhs_thm = abs_over_thm ctxt name_map rhs_thm lhs_throws
        val new_rhs = abs_over_tuple_vars name_map lhs_throws new_rhs

        (* Generate the final term. *)
        val generated_term = mk_monad @{const_name L2_catch} needed_vars throw_vars [new_lhs, new_rhs]

        (* Generate a proof. *)
        val thm =
          let
            (* Show that certain variables are preserved by the LHS. *)
            val needed_preserves = (rhs_reads MINUS lhs_modified)
            val preserve_proof = mk_multivar_preservation_proof ctxt prog_info name_map lhs_term needed_preserves
          in
            mkthm block_reads needed_vars generated_term
                (@{thm L2corres_catch} OF [lhs_thm, rhs_thm, @{thm validE_weaken} OF [preserve_proof]])
          end
      in
        inject (block_reads, needed_vars, generated_term, thm)
      end

  | RecGuard (_, body) =>
      let
        (* Convert body. *)
        val (body_reads, vars_returned, new_body, body_thm) =
            do_conv ctxt prog_info l1_infos l1_call_info name_map
                fn_vars callee_proofs needed_vars false throw_vars body

        (* Get recguard variable. This binding is partial: it assumes the L1
         * term is a constant applied to exactly two arguments. *)
        val (_ $ var $ _) = l1_term

        (* Generate the final term. *)
        val generated_term =
            mk_monad @{const_name "L2_recguard"} vars_returned throw_vars [
              var, new_body]
        val thm = mkthm body_reads vars_returned generated_term
            (@{thm L2corres_recguard} OF [body_thm])
      in
        inject (body_reads, vars_returned, generated_term, thm)
      end

  | Condition (_, (SOME expr, read_vars, _), lhs, rhs) =>
      let
        (* Convert LHS and RHS. Both branches return the needed variables
         * that this block modifies. *)
        val requested_vars = needed_vars INTER modified_vars
        val (lhs_reads, _, new_lhs, lhs_thm)
            = do_conv ctxt prog_info l1_infos l1_call_info name_map
                fn_vars callee_proofs requested_vars false throw_vars lhs
        val (rhs_reads, _, new_rhs, rhs_thm)
            = do_conv ctxt prog_info l1_infos l1_call_info name_map
                fn_vars callee_proofs requested_vars false throw_vars rhs
        val block_reads = lhs_reads UNION rhs_reads UNION read_vars

        (* Generate the final term. *)
        val generated_term = mk_monad @{const_name "L2_condition"}
            requested_vars throw_vars [expr, new_lhs, new_rhs]
        val thm = mkthm block_reads requested_vars generated_term
            (@{thm L2corres_cond} OF [lhs_thm, rhs_thm])
      in
        inject (block_reads, requested_vars, generated_term, thm)
      end

  | While (_, (SOME expr, read_vars, _), body) =>
      let
        (* Convert body. The loop state ("iterators") is every needed or live
         * variable that the loop modifies. *)
        val loop_iterators = (needed_vars UNION live_vars) INTER modified_vars
        val (body_reads, _, new_body, body_thm) =
            do_conv ctxt prog_info l1_infos l1_call_info name_map
                fn_vars callee_proofs loop_iterators false throw_vars body
        val (body_term, _, body_modifies) = get_node_data body

        (* Reconstruct body to support our input tuple. *)
        val new_body = abs_over_tuple_vars name_map loop_iterators new_body
        val body_thm = abs_over_thm ctxt name_map body_thm loop_iterators

        (* Generate the final term: condition and body abstracted over the
         * iterator tuple, plus the initial tuple and the iterator names. *)
        val generated_term =
            mk_monad @{const_name "L2_while"} loop_iterators throw_vars [
              abs_over_tuple_vars name_map loop_iterators expr,
              new_body,
              HOLogic.mk_tuple (dest_set loop_iterators |> map name_map),
              var_set_to_isa_list prog_info loop_iterators]

        (* Generate a proof. *)
        val thm =
          let
            (* Show that certain variables are preserved by the LHS. *)
            val needed_preserves = ((body_reads UNION read_vars) MINUS body_modifies)
            val preserve_proof = mk_multivar_preservation_proof ctxt prog_info name_map body_term needed_preserves

            (* Instantiate while loop rule to avoid ambiguous unification. *)
            val tracked_vars = (body_reads UNION read_vars UNION loop_iterators)
            val invariant_precond = abs_over_tuple_vars name_map loop_iterators
                (mk_precond prog_info name_map tracked_vars)
            val base_thm = Utils.named_cterm_instantiate ctxt [
                ("P", Thm.cterm_of ctxt invariant_precond),
                ("A", Thm.cterm_of ctxt new_body)
              ] @{thm L2corres_while}
          in
            mkthm (body_reads UNION read_vars UNION loop_iterators) loop_iterators generated_term
                (base_thm OF [body_thm, @{thm validE_weaken} OF [preserve_proof]])
          end
      in
        inject (body_reads UNION read_vars UNION loop_iterators, loop_iterators, generated_term, thm)
      end

  | Call (_, expr_list, (ret_expr, ret_read_vars, _), ret_var, measure_term) =>
      let
        val @{term_pat "_ ?arg_setup ?dest_fn ?globals_extract ?ret_extract"} = l1_term

        (* Split the callee application into its measure argument and the
         * callee itself. This rebinding shadows the "measure_term" bound by
         * the Call pattern above. *)
        val (measure_term, dest_fn) = case dest_fn of
            (c as Const (@{const_name "measure_call"}, _)) $ f => (c, f)
          | f $ (c as Const (@{const_name "undefined"}, _)) => (c, f)
          | f $ (c as Const (@{const_name "recguard_dec"}, _) $ _) => (c, f)
          | _ => raise TERM ("local_var_extract: strange function call", [dest_fn])

        (* Get destination function. *)
        val dest_fn = Termtab.lookup (#const_to_function l1_call_info) (l1call_function_const dest_fn)
                      |> Option.mapPartial (Symtab.lookup l1_infos)

        (* Lookup the callee proof, if it exists. *)
        val callee_proof = Option.mapPartial
                             (Symtab.lookup callee_proofs) (Option.map #name dest_fn)
      in
        (* Determine if we have a proof for the callee. Without one, the call
         * is translated to L2_fail (returned without "inject"). *)
        case callee_proof of
          NONE =>
            (let
              val generated_term = mk_monad @{const_name "L2_fail"} needed_vars throw_vars []
              val thm = mkthm empty_set needed_vars generated_term @{thm L2corres_fail}
            in
              (empty_set, needed_vars, generated_term, thm)
            end)

        | SOME (is_recursive, callee_free, callee_thm) =>
            (let
              (* Get information about the function. *)
              val dest_fn = the (dest_fn)
              val args = #args dest_fn

              (* Parse argument setup. *)
              val arg_setup_vals = parse_modify ctxt prog_info name_map arg_setup |> List.rev

              (* Ensure that we can parse everything. *)
              val arg_setup_vals =
                map (fn (a, b, c, parsed_expr) =>
                  case parsed_expr of
                    NONE =>
                      raise Utils.InvalidInput ("Could not parse function parameter '" ^ (fst a) ^ "'")
                  | SOME x =>
                      (a, b, c, x)
                ) arg_setup_vals

              (* Sanity check: ensure that we have the correct number of arguments. *)
              val _ = if length arg_setup_vals <> length args then
                  raise TERM ("Argument list length does not match function definition.", [arg_setup])
                else
                  ()

              (* Rename input parameter names. *)
              val arg_setup_vals = map (fn ((a,t),b,c,d) => ((a ^ "'param", t), b, c, d)) arg_setup_vals

              (* Generate the call. *)
              (* The measure is the first arg, so we need to skip it when applying the others. *)
              val args = map (fn (a,_,_,_) => name_map a) arg_setup_vals
              val call_args = let
                  val var = Free ("rec_measure'", @{typ "nat"})
                in
                  lambda var (betapplys (callee_free, var :: args))
                end

              val call_measure = case measure_term of
                  Const (@{const_name "measure_call"}, _) => @{mk_term "measure_call ?f" f} call_args
                | _ => betapply (call_args, measure_term)

              (* Pick the call form by how the result is used. NOTE(review): a
               * (SOME _, NONE) combination is not matched and would raise Match. *)
              val (call, ret_vars) =
                case (ret_var, ret_expr) of
                  (SOME ("globals'", _), SOME e) =>
                    (mk_monad @{const_name L2_modifycall} empty_set throw_vars
                      [call_measure, e], empty_set)
                | (SOME x, SOME e) =>
                    (mk_monad @{const_name L2_returncall} (make_set [x]) throw_vars
                      [call_measure, e], make_set [x])
                | (NONE, _) =>
                    (mk_monad @{const_name L2_voidcall} empty_set throw_vars
                      [call_measure], empty_set)

              (*
               * We have a list of arguments; some may be expressions that refer to
               * global variables, while others will be purely local variables. We
               * just emit them all as "L2_gets" calls, and will clean them up
               * later.
               *)
              val extractors = foldr (
                fn ((updated_var, read_vars, is_globals_reader, expr), rest) =>
                  let
                    val ret_type = (make_set [("x'", fastype_of expr |> body_type)])
                    val rest_type = (make_set [("x'", l2monad_ret_type rest)])
                    val getter = mk_monad @{const_name L2_folded_gets} ret_type throw_vars
                                   [expr, Utils.ml_str_list_to_isa [fst updated_var]]
                  in
                    mk_monad @{const_name "L2_seq"} rest_type throw_vars [
                      getter,
                      Utils.abs_over (fst updated_var) (name_map updated_var) rest]
                  end
                )
                call
                arg_setup_vals
              val read_vars = union_sets (map #2 expr_list) UNION ret_read_vars

              (* Generate a proof. *)
              (* Debug tracing is disabled; change "false" to "true" to print goals. *)
              val my_debug_tac = if false then print_tac ctxt else fn _ => all_tac
              val L2_call_thms = @{thms L2corres_returncall L2corres_voidcall L2corres_modifycall}
              val L2_reccall_thms = @{thms L2corres_recursive_returncall L2corres_recursive_voidcall L2corres_recursive_modifycall}
              (* Strip the argument getters, then apply a call rule together with
               * the callee proof (non-recursive rules also consume the callee's
               * mono theorem, if any), then simplify the remaining goals. *)
              val thm =
                  mk_corresXF_thm ctxt prog_info name_map ret_vars throw_vars read_vars extractors l1_term (
                    (my_debug_tac "unfold folded_gets"
                      THEN (REPEAT (resolve_tac ctxt @{thms L2corres_folded_gets} 1)))
                    THEN (my_debug_tac "apply callee proof"
                      THEN FIRST (map (fn thm => resolve_tac ctxt [thm] 1 THEN
                                                 resolve_tac ctxt (List.mapPartial I [#mono_thm dest_fn]) 1 THEN
                                                 resolve_tac ctxt [callee_thm] 1)
                                      L2_call_thms @
                                  map (fn thm => resolve_tac ctxt [thm] 1 THEN
                                                 resolve_tac ctxt [callee_thm] 1)
                                      L2_reccall_thms))
                    THEN (my_debug_tac "final simp"
                      THEN (REPEAT (CHANGED (asm_full_simp_tac (setup_l2_ss ctxt) 1))))
                  )
            in
              inject (read_vars, ret_vars, extractors, thm)
            end)
      end
  | _ => Utils.invalid_input "a parsed L1 term"
      (l1_term |> head_of |> @{make_string})
end
|
|
|
|
(* Get the expected type of a function from its name. *)
|
|
fun get_expected_l2_fn_type prog_info l1_infos fn_name =
  let
    (* The function type is: measure, then each declared argument type, then
     * an L2 monad over the program's globals and the function's return type. *)
    val info = Symtab.lookup l1_infos fn_name |> the
    val arg_typs = map snd (#args info)
    val result_typ = mk_l2monadT (#globals_type prog_info) (#return_type info) @{typ unit}
  in
    (AutoCorresUtil.measureT :: arg_typs) ---> result_typ
  end
|
|
|
|
(* Get the function's declared arguments as (demangled name, type) pairs. *)
|
|
fun get_expected_l2_fn_args lthy prog_info l1_infos fn_name =
  let
    (* "lthy" is not used by this function. *)
    val declared_args = #args (Symtab.lookup l1_infos fn_name |> the)
    val demangle = ProgramInfo.demangle_name prog_info
  in
    (* Demangle each argument name, keeping its type unchanged. *)
    map (fn (arg_name, arg_typ) => (demangle arg_name, arg_typ)) declared_args
  end
|
|
|
|
(*
 * Build the corres proposition expected for "fn_name". It relates "fn_free"
 * applied to "measure_var" and "fn_args" to the L1 constant applied to
 * "measure_var". The function's input variables are read through "fn_args";
 * its output variables are returned, and no variables are thrown.
 *)
fun get_expected_l2_fn_thm prog_info l1_infos ctxt fn_name fn_free fn_args _ measure_var =
  let
    val (in_vars, out_vars) = get_fn_input_output_vars prog_info l1_infos fn_name

    (* Associate each declared argument name with the term passed for it. *)
    val info = Symtab.lookup l1_infos fn_name |> the
    val arg_table = Symtab.make (map fst (#args info) ~~ fn_args)
    fun lookup_arg (name, _) = the (Symtab.lookup arg_table name)

    val thy = Proof_Context.theory_of ctxt
    val callee_app = betapplys (fn_free, measure_var :: fn_args)
    val l1_app = betapply (#const info, measure_var)
  in
    mk_corresXF_prop thy prog_info lookup_arg out_vars empty_set in_vars callee_app l1_app
  end
|
|
|
|
(* Extract the abstract body of an L2corres theorem. *)
|
|
fun get_body_of_thm ctxt thm =
  let
    (* Generalise free variables, then strip Trueprop from the conclusion. *)
    val generalised = Variable.gen_all ctxt thm
    val prop = HOLogic.dest_Trueprop (Thm.concl_of generalised)
  in
    dest_L2corres_term_abs prop
  end
|
|
|
|
(*
 * Convert the L1 body of "fn_name" into L2 form and prove the correspondence.
 *
 * l1_term:     the L1 function term; "convert" passes the L1 constant applied
 *              to the measure variable.
 * init_unfold: an equation used to rewrite "l1_term" into its body
 *              ("convert" passes the function's instantiated L1 definition).
 * fn_args:     the terms standing for the function's arguments in the result.
 * callee_terms: callee proofs, forwarded to "do_conv".
 *
 * Returns the cleaned-up corres theorem together with the optimiser traces.
 *)
fun get_l2corres_thm ctxt prog_info l1_infos l1_call_info do_opt trace_opt fn_name
    callee_terms fn_args l1_term init_unfold = let

  (* Look up the L1 information for the function being converted. *)
  val fn_info = the (Symtab.lookup l1_infos fn_name)

  (* Get input and return variables. *)
  val (fn_input_vars, fn_ret_vars) = get_fn_input_output_vars prog_info l1_infos fn_name

  (* Get mapping from internal variable names to external arguments. During
   * conversion every variable is represented by a "lvar'"-prefixed Free;
   * argument variables are renamed to "fn_args" at the end. *)
  val m = Symtab.make (map fst (#args fn_info) ~~ fn_args)
  fun name_map_ext (n, T) = Symtab.lookup m n |> the
  fun name_map_internal (n, T) = Free ("lvar'" ^ n, T)

  (*
   * Many constructs from SIMPL (and also L1) are in set form, but we really
   * need them to be in functional form to be able to effectively parse them.
   * In particular we can parse:
   *
   *   (%s. n_' s)
   *
   * but not:
   *
   *   {s. n_' s}
   *
   * We do some basic conversions here to convert common sets into lambda
   * functions.
   *)
  val init_rule = Thm.cterm_of ctxt l1_term
                  |> Conv.rewr_conv (safe_mk_meta_eq init_unfold)

  (* Extract the term we will be working with. *)
  val source_term = Thm.concl_of init_rule |> Utils.rhs_of_eq

  (* Do basic parsing. *)
  val parsed_term = parse_l1 ctxt prog_info l1_infos l1_call_info name_map_internal source_term

  (* Get a list of all variables either read from or written to. *)
  val all_vars = Prog.fold_prog
      (K I)
      (fn (_, vars, _) => fn old_vars => vars UNION old_vars)
      (fn mod_var => fn old_vars =>
          case mod_var of SOME x => (Varset.insert x old_vars) | NONE => old_vars)
      (K I)
      parsed_term empty_set

  (* Perform liveness analysis of the function. *)
  val liveness_data = calc_live_vars parsed_term fn_ret_vars

  (*
   * Get information about modified variables.
   *
   * "NONE" represents "modifies potentially all variables"; we modify
   * the results to fit this.
   *)
  val modification_data =
      get_modified_vars parsed_term
      |> map_prog (fn x => Option.getOpt (x, all_vars)) I I I

  (* Combine collected data into (parsed, live, modified) node triples, the
   * form "do_conv" expects. *)
  fun zip_node_data a b c =
      zip_progs a (zip_progs b c)
      |> map_prog (fn (a, (b, c)) => (a, b, c)) fst fst fst
  val input_term = zip_node_data parsed_term liveness_data modification_data

  (* Ensure that the only live variables at the beginning of the function are
   * those that are function inputs. Anything else only triggers a warning. *)
  val fn_inputs = get_node_data liveness_data
  val fn_params = #args fn_info
  val excess_inputs = fn_inputs MINUS (make_set fn_params)
  val _ =
      if excess_inputs <> empty_set then
        warning
          ("Input function '" ^ fn_name ^ "' has unresolved variables: "
              ^ @{make_string} (dest_set excess_inputs))
      else
        ()

  (* Do the conversion. The read-variable set ("vars_read") is not used below. *)
  val (vars_read, _, term, thm) =
      do_conv ctxt prog_info l1_infos l1_call_info name_map_internal fn_input_vars
          callee_terms fn_ret_vars false empty_set input_term

  (* Replace our internal terms with external terms. *)
  val replacements = (map name_map_internal fn_params) ~~ (map name_map_ext fn_params)
  val term = Raw_Simplifier.rewrite_term (Proof_Context.theory_of ctxt)
      [] [Termtab.lookup (Termtab.make replacements)] term

  (*
   * Generate a theorem with a folded RHS, with the LHS unfolded.
   *
   * The idea here is that we must generate a theorem of the form we
   * committed to in "get_expected_l2_fn_thm", but with schematic variables.
   *)
  val new_thm =
      mk_corresXF_prop (Proof_Context.theory_of ctxt) prog_info name_map_ext
          fn_ret_vars empty_set fn_input_vars
          term l1_term
      |> Thm.cterm_of ctxt
      |> Goal.init
      |> apply_tac "unfold RHS" (EqSubst.eqsubst_tac ctxt [0] [init_rule] 1)
      |> apply_tac "generalise guard" (resolve_tac ctxt @{thms L2corres_guard_imp} 1)
      |> apply_tac "solve main goal" (resolve_tac ctxt [thm] 1)
      |> apply_tac "solve guard_imp" (REPEAT (FIRST [
            resolve_tac ctxt @{thms HOL.refl} 1,
            resolve_tac ctxt @{thms pred_conjI} 1,
            resolve_tac ctxt @{thms conjI} 1,
            CHANGED (asm_full_simp_tac (setup_l2_ss ctxt) 1)]))
      |> Goal.finish ctxt

  (* Remove intermediate scaffolding, rewriting only argument 5 of the
   * conclusion. *)
  val new_thm = Conv.fconv_rule (
      Utils.remove_meta_conv (fn ctxt =>
          Utils.nth_arg_conv 5 (
              Raw_Simplifier.rewrite ctxt false @{thms L2_remove_scaffolding_1}
              then_conv
              Raw_Simplifier.rewrite ctxt false @{thms L2_remove_scaffolding_2})) ctxt) new_thm

  (* Cleanup. *)
  val _ = writeln ("Simplifying (L2) " ^ fn_name)
  val new_thm = Simplifier.simplify (put_simpset HOL_basic_ss ctxt addsimps
                  (* this rule is expensive *)
                  (if do_opt then @{thms L2_unknown_bind} else []))
                  new_thm

  val _ = writeln ("Simplifying (L2opt) " ^ fn_name)
  (* HACK: we need to avoid these simps until heap_lift *)
  val cleanup_del = @{thms ptr_coerce.simps ptr_add_0_id}
  val (new_thm, traces) = L2Opt.cleanup_thm_tagged (ctxt delsimps cleanup_del) new_thm
                  (if do_opt then 0 else 2) 5 trace_opt "L2"
in
  (new_thm, traces)
end
|
|
|
|
(*
|
|
* Prove monad_mono property for recursive functions.
|
|
* Note that this is also used for subsequent L2-based phases.
|
|
*)
|
|
|
|
fun l2_monad_mono lthy (l2_infos: FunctionInfo.function_info Symtab.table) =
let
  (*
   * For the induction, we need to have the form
   *   "\<And> m. (ALL a b... f m a b...) /\ (ALL a b... g m a b...) /\ ..."
   * and this gets annoying pretty quickly. But it is probably unavoidable.
   *)
  val (fn_names, fn_defs) = split_list (Symtab.dest l2_infos);
  (* Fix a fresh measure variable; "lthy" is shadowed by the extended context. *)
  val ([measure_var_name], lthy) = Variable.variant_fixes ["rec_measure'"] lthy;
  val measure = Free (measure_var_name, AutoCorresUtil.measureT)

  (* Build "ALL args. monad_mono_step (%m. f m args) measure" for one function,
   * with the definition's schematic argument Vars turned into Frees. *)
  fun make_mono_step_stmt current_def =
    let
      (* def should be of the form "func ?locale_args... ?measure ?args... = ..." *)
      val (_, locale_args) = strip_comb (#const current_def)
      val (_, all_args) = Utils.lhs_of_eq (term_of_thm (#definition current_def)) |> strip_comb
      (* Drop the locale arguments and the measure argument. *)
      val _ :: args = drop (length locale_args) all_args
      val args = args |> map (fn Var ((name, _), typ) => Free (name, typ))
    in
      fold (fn arg => fn t => @{mk_term "All ?P" P} (lambda arg t)) args
        (@{mk_term "monad_mono_step ?f ?m" (f, m)}
           (lambda measure (betapplys (#const current_def, measure :: args)), measure))
    end
  (* Right-nested conjunction of a non-empty list of propositions. *)
  val mk_conj_list = foldr1 (fn (a, b) => @{term "conj"} $ a $ b)

  (* One tactic per function that rewrites with its (abstracted) definition. *)
  val func_expand = map (fn fn_def => EqSubst.eqsubst_tac lthy [0]
                            [Utils.abs_def lthy (#definition fn_def)]) fn_defs
  (* Induct on the measure. Base case: each conjunct is discharged by the
   * recguard-0 rule. Step case: split the induction hypothesis, then for each
   * conjunct unfold the definition and apply the monad_mono_step rules. *)
  val tac =
      resolve_tac lthy @{thms nat.induct} 1
      THEN EVERY (map (fn expand =>
               TRY (resolve_tac lthy @{thms conjI} 1)
               THEN REPEAT (resolve_tac lthy @{thms allI} 1)
               THEN expand 1
               THEN resolve_tac lthy @{thms monad_mono_step_L2_recguard_0} 1) func_expand)
      THEN REPEAT (eresolve_tac lthy @{thms conjE} 1)
      THEN EVERY (map (fn expand =>
               TRY (resolve_tac lthy @{thms conjI} 1)
               THEN REPEAT (resolve_tac lthy @{thms allI} 1)
               THEN expand 1
               THEN REPEAT (FIRST (
                      map (fn t => resolve_tac lthy [t] 1) @{thms L2_monad_mono_step_rules}
                      (* We use simp to solve assumptions (assume_tac doesn't work
                       * because the assumptions are ALL-quantified) and to
                       * split tuple cases. *)
                      @ [CHANGED (asm_full_simp_tac (clear_simpset lthy
                                   addsimps @{thms split_conv split_tupled_all}) 1)])))
             func_expand)

  (* Prove the combined statement, universally quantified over the measure. *)
  val mono_thm = map make_mono_step_stmt fn_defs
                 |> mk_conj_list
                 |> (fn t => Logic.all measure (@{term "Trueprop"} $ t))
                 |> (fn t => Goal.prove lthy [] [] t (K tac))

  (* We have finished the induction, now we extract the individual results. *)
  (* Build "!!args. monad_mono (%m. f m args)" for one function. *)
  fun make_mono_stmt L2_def =
    let
      val (_, locale_args) = strip_comb (#const L2_def)
      val (_, all_args) = Utils.lhs_of_eq (term_of_thm (#definition L2_def)) |> strip_comb
      val _ :: args = drop (length locale_args) all_args
      val args = args |> map (fn Var ((name, _), typ) => Free (name, typ))
    in
      fold Logic.all args
        (@{mk_term "Trueprop (monad_mono ?f)" f}
           (lambda measure (betapplys (#const L2_def, measure :: args))))
    end
  val final_thms = fn_defs
      |> map (fn fn_def => Goal.prove lthy [] [] (make_mono_stmt fn_def)
                (K (asm_full_simp_tac (lthy addsimps [@{thm monad_mono_alt_def}, mono_thm]) 1)))
in
  (* Table from function name to its monad_mono theorem. *)
  final_thms
  |> (fn thms => fn_names ~~ thms)
  |> Symtab.make
end
|
|
|
|
(* For functions that are not translated, just generate a trivial wrapper. *)
|
|
fun mk_l2corres_call_simpl_thm prog_info l1_infos ctxt fn_name fn_args =
  let
    (* L1 info of the function being wrapped. *)
    val info = Symtab.lookup l1_infos fn_name |> the
    val (in_vars, out_vars) = get_fn_input_output_vars prog_info l1_infos fn_name

    (* Associate each declared argument name with the term passed for it. *)
    val arg_table = Symtab.make (map fst (#args info) ~~ fn_args)
    fun lookup_arg (name, _) = the (Symtab.lookup arg_table name)

    (* Instantiations for the schematic variables of L2corres_L2_call_simpl. *)
    val arg_xf = mk_precond prog_info lookup_arg in_vars
    val ret_xf = mk_xf prog_info out_vars
    val ex_xf = Abs ("s", #state_type prog_info, HOLogic.unit)
    val l1_call = betapply (#const info, Free ("rec_measure'", @{typ "nat"}))

    val instantiations =
      map (fn (var_name, t) => (var_name, Thm.cterm_of ctxt t))
        [("l1_f", l1_call),
         ("ex_xf", ex_xf), ("gs", #globals_getter prog_info),
         ("ret_xf", ret_xf), ("arg_xf", arg_xf)]
    val rule = Utils.named_cterm_instantiate ctxt instantiations @{thm L2corres_L2_call_simpl}
  in
    (* Discharge the rule's premise with the function's definition. *)
    rule OF [#definition info]
  end
|
|
|
|
(*
|
|
* Convert a single function. Returns a thm that looks like
|
|
* \<lbrakk> L2corres ?callee1 l1_callee1; ... \<rbrakk> \<Longrightarrow>
|
|
* L2corres (conversion result...) l1_f
|
|
* i.e. with assumptions for called functions, which are parameterised as Vars.
|
|
*)
|
|
fun convert
    (lthy: local_theory) (* must contain at least L1 callee defs, but no other requirements *)
    (prog_info: ProgramInfo.prog_info)
    (l1_infos: FunctionInfo.function_info Symtab.table)
    (do_opt: bool)
    (trace_opt: bool)
    (l2_function_name: string -> string)
    (f_name: string)
    : AutoCorresUtil.convert_result = let
  (* Recompute the call graph; "l1_infos" is rebound to the returned table. *)
  val (l1_call_info, l1_infos) = FunctionInfo.calc_call_graph l1_infos;

  (* Fail early if the function or any of its callees has no L1 info. *)
  val f_info = Utils.the' ("L2 conversion missing info for " ^ f_name)
                 (Symtab.lookup l1_infos f_name);
  val callee_names = FunctionInfo.all_callees f_info;
  val _ = filter (fn f => not (isSome (Symtab.lookup l1_infos f))) (Symset.dest callee_names)
          |> (fn bad => if null bad then () else
                error ("L2 conversion missing callees for " ^ f_name ^ ": " ^ commas bad));

  (* Fix measure variable. *)
  val ([measure_var_name], lthy') = Variable.variant_fixes ["rec_measure'"] lthy;
  val measure_var = Free (measure_var_name, AutoCorresUtil.measureT);

  (* Add callee assumptions. Note that our define code has to use the same assumption order.
   * ("get_expected_l2_fn_args" ignores its context argument, so passing the
   * original "lthy" there is harmless.) *)
  val (lthy'', export_thm, callee_terms) =
      AutoCorresUtil.assume_called_functions_corres lthy'
        (#callees f_info) (#rec_callees f_info)
        (get_expected_l2_fn_type prog_info l1_infos)
        (get_expected_l2_fn_thm prog_info l1_infos)
        (get_expected_l2_fn_args lthy prog_info l1_infos)
        l2_function_name
        measure_var;

  (* Fix argument variables.
   * We do this after fixing the callees, because there is still some broken code
   * (e.g. in define_funcs) that requires callee var to exactly match the
   * names generated by l2_function_name. *)
  val f_args = map (apfst (ProgramInfo.demangle_name prog_info)) (#args f_info);
  val (arg_names, lthy''') = Variable.variant_fixes (map fst f_args) lthy'';
  val arg_frees = arg_names ~~ map snd f_args;

  (* Instantiate the L1 definition's schematic variable named "rec_measure'"
   * with our fixed measure, which may have received a different name. *)
  val f_l1_def = Utils.named_cterm_instantiate lthy'''
                   [("rec_measure'" (* FIXME *), Thm.cterm_of lthy''' measure_var)]
                   (#definition f_info)
  (* SIMPL wrappers get a trivial wrapper theorem; everything else goes
   * through the full local-variable extraction. *)
  val (thm, opt_traces) =
      if #is_simpl_wrapper f_info
      then (mk_l2corres_call_simpl_thm prog_info l1_infos lthy''' f_name (map Free arg_frees), [])
      else get_l2corres_thm lthy''' prog_info l1_infos l1_call_info do_opt trace_opt f_name
             (Symtab.make callee_terms) (map Free arg_frees)
             (betapply (#const f_info, measure_var))
             f_l1_def;

  val f_body = dest_L2corres_term_abs (HOLogic.dest_Trueprop (Thm.concl_of thm));
  (* Get actual recursive callees *)
  val rec_callees = AutoCorresUtil.get_rec_callees callee_terms f_body;

  (* Return the constants that we fixed. This will be used to process the returned body. *)
  val callee_consts =
      callee_terms |> map (fn (callee, (_, const, _)) => (callee, const)) |> Symtab.make;
in
  { body = f_body,
    proof = Morphism.thm export_thm thm, (* Expose callee assumptions *)
    rec_callees = rec_callees,
    callee_consts = callee_consts,
    (* The measure comes first, followed by the function's arguments. *)
    arg_frees = dest_Free measure_var :: arg_frees,
    traces = opt_traces
  }
end
|
|
|
|
|
|
(* Define a previously-converted function (or recursive function group).
|
|
* lthy must include all definitions from l2_callees. *)
|
|
fun define
    (lthy: local_theory)
    (filename: string)
    (prog_info: ProgramInfo.prog_info)
    (l1_infos: FunctionInfo.function_info Symtab.table)
    (l2_callees: FunctionInfo.function_info Symtab.table)
    (l2_function_name: string -> string)
    (funcs: AutoCorresUtil.convert_result Symtab.table)
    : FunctionInfo.function_info Symtab.table * local_theory =
  let
    (* FIXME: the abstract_fn_body step should be moved into define_funcs *)
    fun prepare (entry as (name, {proof, arg_frees, ...})) =
      (name, (AutoCorresUtil.abstract_fn_body l1_infos entry, proof, arg_frees))
    val prepared = map prepare (Symtab.dest funcs)

    (* Corres theorems of already-defined L2 callees. *)
    val callee_corres_thms = Symtab.map (K #corres_thm) l2_callees

    val (defined, lthy') =
      AutoCorresUtil.define_funcs
        FunctionInfo.L2 filename l1_infos l2_function_name
        (get_expected_l2_fn_type prog_info l1_infos)
        (get_expected_l2_fn_thm prog_info l1_infos)
        (get_expected_l2_fn_args lthy prog_info l1_infos)
        @{thm L2corres_recguard_0}
        lthy callee_corres_thms
        prepared

    (* Derive the L2 info of "f_name" from its L1 info and the new definition. *)
    fun to_l2_info f_name (const, def, corres_thm) =
      let
        val l1_info = the (Symtab.lookup l1_infos f_name)
        (* Update arg names to match our newly converted functions *)
        val demangled_args = map (apfst (ProgramInfo.demangle_name prog_info)) (#args l1_info)
      in
        l1_info
        |> FunctionInfo.function_info_upd_phase FunctionInfo.L2
        |> FunctionInfo.function_info_upd_const const
        |> FunctionInfo.function_info_upd_definition def
        |> FunctionInfo.function_info_upd_corres_thm corres_thm
        |> FunctionInfo.function_info_upd_mono_thm NONE (* added later *)
        |> FunctionInfo.function_info_upd_args demangled_args
      end
  in
    (Symtab.map to_l2_info defined, lthy')
  end;
|
|
|
|
|
|
(*
|
|
* Translate all functions from L1 to L2 format.
|
|
*)
|
|
fun translate
    (filename: string)
    (prog_info: ProgramInfo.prog_info)
    (l1_results: FunctionInfo.phase_results)
    (existing_l1_infos: FunctionInfo.function_info Symtab.table)
    (existing_l2_infos: FunctionInfo.function_info Symtab.table)
    (do_opt: bool)
    (trace_opt: bool)
    (add_trace: string -> string -> AutoCorresData.Trace -> unit)
    (l2_function_name: string -> string)
    : FunctionInfo.phase_results =
let
  (* Do conversions in parallel. *)
  val converted_groups =
      AutoCorresUtil.par_convert
          (fn lthy => fn l1_infos => convert lthy prog_info l1_infos do_opt trace_opt l2_function_name)
          existing_l1_infos l1_results add_trace;

  (* Sequence of new function_infos and intermediate lthys *)
  val def_results = FSeq.mk (fn _ =>
      (* If there's nothing to translate, we won't have a lthy to use *)
      if FSeq.null l1_results then NONE else
      let
        (* Get initial lthy from end of L1 defs *)
        val (l1_lthy, _) = FSeq.list_of l1_results |> List.last;
        (* "define" takes the L1 infos before the L2 callee infos. Both tables
         * have the same type, so they must be passed by name here; swapping
         * them would make "define" look up the new functions among the L2
         * callees and take callee corres theorems from the L1 infos. *)
        fun define_worker lthy l2_callee_infos l1_infos f_convs =
            define lthy filename prog_info l1_infos l2_callee_infos l2_function_name f_convs;
        val results = AutoCorresUtil.define_funcs_sequence
                        l1_lthy define_worker existing_l1_infos existing_l2_infos converted_groups;
      in FSeq.uncons results end);

  (* Pair each defined function group's L2 function_infos with the earliest
   * intermediate lthy where it is defined, adding monad_mono theorems. *)
  val results =
      def_results
      |> FSeq.map (fn (lthy, l2_defs) => let
            (* Add monad_mono proofs. These are done in parallel as well
             * (though in practice, they already run pretty quickly).
             * Only recursive groups get them; the check inspects the
             * group's first entry. *)
            val mono_thms = if FunctionInfo.is_function_recursive (snd (hd (Symtab.dest l2_defs)))
                            then l2_monad_mono lthy l2_defs
                            else Symtab.empty;
            val l2_defs' = l2_defs |> Symtab.map (fn f =>
                              FunctionInfo.function_info_upd_mono_thm (Symtab.lookup mono_thms f));
            in (lthy, l2_defs') end);
in
  results
end
|
|
|
|
end
|