WIP: autocorres: draft of more modular dependencies for L1, L2
Prototype for Jira VER-517.
This commit is contained in:
parent
3400debdc2
commit
2caf6520e5
|
@ -145,6 +145,11 @@ ML_file "monad_convert.ML"
|
|||
ML_file "type_strengthen.ML"
|
||||
ML_file "autocorres.ML"
|
||||
|
||||
declare [[ML_print_depth=42]]
|
||||
ML_file "autocorres_util2.ML"
|
||||
ML_file "simpl_conv2.ML"
|
||||
ML_file "local_var_extract2.ML"
|
||||
|
||||
(* Setup "autocorres" keyword. *)
|
||||
ML {*
|
||||
Outer_Syntax.command @{command_keyword "autocorres"}
|
||||
|
@ -153,4 +158,82 @@ ML {*
|
|||
(Toplevel.theory o (fn (opt, filename) => AutoCorres.do_autocorres opt filename)))
|
||||
*}
|
||||
|
||||
ML \<open>
|
||||
|
||||
fun assume_called_functions_corres ctxt fn_info callees
|
||||
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var =
|
||||
let
|
||||
(* Assume the existence of a function, along with a theorem about its
|
||||
* behaviour. *)
|
||||
fun assume_func ctxt fn_name is_recursive_call =
|
||||
let
|
||||
val fn_args = get_fn_args fn_name
|
||||
|
||||
(* Fix a variable for the function. *)
|
||||
val ([fixed_fn_name], ctxt') = Variable.variant_fixes [get_const_name fn_name] ctxt
|
||||
val fn_free = Free (fixed_fn_name, get_fn_type fn_name)
|
||||
|
||||
(* Fix a variable for the measure and function arguments. *)
|
||||
val (measure_var_name :: arg_names, ctxt'')
|
||||
= Variable.variant_fixes ("rec_measure'" :: (map fst fn_args)) ctxt'
|
||||
val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args))
|
||||
val my_measure_var = Free (measure_var_name, @{typ nat})
|
||||
|
||||
(*
|
||||
* A measure variable is needed to handle recursion: for recursive calls,
|
||||
* we need to decrement the caller's input measure value (and our
|
||||
* assumption will need to assume this to). This is so we can later prove
|
||||
* termination of our function definition: the measure always reaches zero.
|
||||
*
|
||||
* Non-recursive calls can have a fresh value.
|
||||
*)
|
||||
val measure_var =
|
||||
if is_recursive_call then
|
||||
@{const "recguard_dec"} $ callers_measure_var
|
||||
else
|
||||
my_measure_var
|
||||
|
||||
(* Create our assumption. *)
|
||||
val assumption =
|
||||
get_fn_assumption ctxt'' fn_name fn_free fn_arg_terms
|
||||
is_recursive_call measure_var
|
||||
|> fold Logic.all (rev ((if is_recursive_call then [] else [my_measure_var]) @ fn_arg_terms))
|
||||
|> Sign.no_vars ctxt'
|
||||
|> Thm.cterm_of ctxt'
|
||||
val ([thm], ctxt''') = Assumption.add_assumes [assumption] ctxt'
|
||||
|
||||
(* Generate a morphism for escaping this context. *)
|
||||
val m = (Assumption.export_morphism ctxt''' ctxt')
|
||||
$> (Variable.export_morphism ctxt' ctxt)
|
||||
in
|
||||
(fn_free, thm, ctxt''', m)
|
||||
end
|
||||
|
||||
(* Apply each assumption. *)
|
||||
val (res, (ctxt', m)) = fold_map (
|
||||
fn (fn_name, is_recursive_call) =>
|
||||
fn (ctxt, m) =>
|
||||
let
|
||||
val (free, thm, ctxt', m') =
|
||||
assume_func ctxt fn_name is_recursive_call
|
||||
in
|
||||
((fn_name, (is_recursive_call, free, thm)), (ctxt', m' $> m))
|
||||
end)
|
||||
callees (ctxt, Morphism.identity)
|
||||
in
|
||||
(ctxt', m, res)
|
||||
end;
|
||||
|
||||
(*
|
||||
fun assume_called_functions_corres ctxt fn_info callees
|
||||
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var
|
||||
*)
|
||||
assume_called_functions_corres @{context} () [("a", false), ("r", true)]
|
||||
(K @{typ "nat \<Rightarrow> nat \<Rightarrow> string \<Rightarrow> string"})
|
||||
(fn ctxt => fn name => fn term => fn args => fn is_rec => fn meas =>
|
||||
HOLogic.mk_Trueprop (@{term "my_corres :: (string \<Rightarrow> string) \<Rightarrow> bool"} $ betapplys (term, meas :: args)))
|
||||
(fn f => if f = "a" then [("arg_a", @{typ nat})] else [("arg_r", @{typ nat})])
|
||||
I @{term "rec_measure :: nat"}
|
||||
\<close>
|
||||
|
||||
end
|
||||
|
|
|
@ -57,10 +57,36 @@ sig
|
|||
(* result *)
|
||||
-> (Proof.context * FunctionInfo.fn_info)
|
||||
|
||||
val define_funcs:
|
||||
FunctionInfo.phase ->
|
||||
string ->
|
||||
FunctionInfo.fn_info ->
|
||||
(string -> string) ->
|
||||
(string -> typ) ->
|
||||
(Proof.context -> string -> term -> term list -> bool -> term -> term) ->
|
||||
(string -> (string * typ) list) ->
|
||||
thm ->
|
||||
Proof.context ->
|
||||
thm Symtab.table ->
|
||||
'a ->
|
||||
(string * (thm * term * (string * AutoCorresData.Trace) list)) list ->
|
||||
thm list * 'a * Proof.context
|
||||
|
||||
val map_all : Proof.context -> FunctionInfo.fn_info -> (string -> FunctionInfo.function_info -> 'a) -> 'a list
|
||||
val concurrent : bool Unsynchronized.ref
|
||||
val has_simpl_body_def : local_theory -> string -> bool
|
||||
val max_run_time : Time.time option Unsynchronized.ref
|
||||
|
||||
val assume_called_functions_corres :
|
||||
Proof.context ->
|
||||
(string * bool) list ->
|
||||
(string -> typ) ->
|
||||
(Proof.context -> string -> term -> term list -> bool -> term -> term) ->
|
||||
(string -> (string * typ) list) ->
|
||||
(string -> string) ->
|
||||
term ->
|
||||
Proof.context * morphism * (string * (bool * term * thm)) list
|
||||
val get_callees : FunctionInfo.fn_info -> string -> string list * string list (* (nonrecs, recs) *)
|
||||
end;
|
||||
|
||||
structure AutoCorresUtil : AUTOCORRES_UTIL =
|
||||
|
@ -262,14 +288,32 @@ fun is_Trueprop (Const (@{const_name "Trueprop"}, _) $ _) = true
|
|||
| is_Trueprop _ = false
|
||||
|
||||
(*
|
||||
* Assume the existence of the given list of functions.
|
||||
* Assume theorems for called functions.
|
||||
*
|
||||
* A new context is returned with the assumptions in it, along with a morphism
|
||||
* used for exporting the theorems out, and a list of the functions assumed:
|
||||
*
|
||||
* (<function name>, (<is_mutually_recursive>, <function free>, <function thm>))
|
||||
* (<function name>, (<is_mutually_recursive>, <function free>, <arg frees>, <function thm>))
|
||||
*
|
||||
* In this context, the theorems refer to functions by fixed free variables.
|
||||
*
|
||||
* get_fn_args may return user-friendly argument names that clash with other names.
|
||||
* We will process these names to avoid conflicts.
|
||||
*
|
||||
* get_fn_assumption should produce the desired theorems to assume. Its arguments:
|
||||
* context (with fixed vars), callee name, callee term, arg terms, is recursive, measure term
|
||||
* (all terms are fixed free vars).
|
||||
*
|
||||
* get_const_name generates names for the free function placeholders.
|
||||
* FIXME: probably unnecessary and/or broken.
|
||||
*
|
||||
* We return two morphisms:
|
||||
* - the first one makes the assumptions visible again
|
||||
* - the second one automatically generalizes the assumed constants
|
||||
* (this exists for backwards compat; all new code should explicitly use the
|
||||
* returned free variable set)
|
||||
*)
|
||||
fun assume_called_functions_corres ctxt fn_info callees
|
||||
fun assume_called_functions_corres ctxt callees
|
||||
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var =
|
||||
let
|
||||
(* Assume the existence of a function, along with a theorem about its
|
||||
|
@ -354,7 +398,7 @@ fun gen_corres_for_function
|
|||
(ctxt : Proof.context)
|
||||
(fn_name : string) =
|
||||
let
|
||||
val _ = writeln ("Converting (" ^ FunctionInfo.string_of_phase phase ^") " ^ fn_name)
|
||||
val _ = writeln ("Converting (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name)
|
||||
val start_time = Timer.startRealTimer ()
|
||||
|
||||
(* Get a list of functions we call. *)
|
||||
|
@ -364,21 +408,30 @@ let
|
|||
(map (fn x => (x, false)) normal_calls)
|
||||
@ (map (fn x => (x, true)) recursive_calls)
|
||||
|
||||
(* Make sure the desired function name is available. *)
|
||||
val fn_target_name = get_const_name fn_name
|
||||
val ([fn_free], ctxt') = Variable.variant_fixes [fn_target_name] ctxt
|
||||
val _ = if fn_free = fn_target_name then () else
|
||||
warning ("Variable clobbered: " ^ fn_target_name ^ " -> " ^ fn_free ^ ". Translating " ^
|
||||
fn_name ^ " may fail.")
|
||||
val fn_var_morph = Variable.export_morphism ctxt' ctxt
|
||||
|
||||
(* Fix a measure variable that will be used to track recursion progress. *)
|
||||
val ([measure_var_name], ctxt') = Variable.variant_fixes ["rec_measure'"] ctxt
|
||||
val ([measure_var_name], ctxt'') = Variable.variant_fixes ["rec_measure'"] ctxt'
|
||||
val measure_var = Free (measure_var_name, @{typ nat})
|
||||
val measure_var_morph = Variable.export_morphism ctxt' ctxt
|
||||
val measure_var_morph = Variable.export_morphism ctxt'' ctxt'
|
||||
|
||||
(* Fix variables for function arguments. *)
|
||||
val fn_args = get_fn_args fn_name
|
||||
val (arg_names, ctxt'')
|
||||
= Variable.variant_fixes (map (fn (a, _) => a ^ "'arg") fn_args) ctxt'
|
||||
val (arg_names, ctxt''')
|
||||
= Variable.variant_fixes (map (fn (a, _) => a ^ "'arg") fn_args) ctxt''
|
||||
val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args))
|
||||
val fn_args_morph = Variable.export_morphism ctxt'' ctxt'
|
||||
val fn_args_morph = Variable.export_morphism ctxt''' ctxt''
|
||||
|
||||
val _ = @{trace} ("Vars", fn_free, measure_var_name, arg_names)
|
||||
(* Enter a context where we assume our callees exist. *)
|
||||
val (ctxt''', m, callee_info_and_proofs)
|
||||
= assume_called_functions_corres ctxt'' fn_info callees
|
||||
val (ctxt'''', m, callee_info_and_proofs)
|
||||
= assume_called_functions_corres ctxt''' callees
|
||||
get_fn_type get_fn_assumption get_fn_args get_const_name
|
||||
measure_var
|
||||
|
||||
|
@ -387,7 +440,7 @@ let
|
|||
* term and a tactic for proving correspondence.
|
||||
*)
|
||||
val callee_tab = Symtab.make callee_info_and_proofs
|
||||
val (body, thm, trace) = convert ctxt''' fn_name callee_tab measure_var fn_arg_terms
|
||||
val (body, thm, trace) = convert ctxt'''' fn_name callee_tab measure_var fn_arg_terms
|
||||
|
||||
(*
|
||||
* The returned body will have free variables as placeholders for the function's
|
||||
|
@ -408,7 +461,7 @@ let
|
|||
|> fold lambda (rev (map (fn x => Symtab.lookup callee_tab x |> the |> #2) normal_calls))
|
||||
|
||||
(* Export the theorem out of our context. *)
|
||||
val exported_thm = Morphism.thm (m $> fn_args_morph $> measure_var_morph) thm
|
||||
val exported_thm = Morphism.thm (m $> fn_args_morph $> measure_var_morph $> fn_var_morph) thm
|
||||
|
||||
(* TODO: allow this message to be configured *)
|
||||
val _ = @{trace} ("Converted (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name ^ " in " ^
|
||||
|
@ -423,15 +476,15 @@ end
|
|||
* "callee_thms" contains a table mapping function names to complete
|
||||
* corres proofs for those functions.
|
||||
*
|
||||
* "functions" contains a list of (fn_name, (proof, callees)). We
|
||||
* assume that all functions in this list are mutually recursive. (If
|
||||
* not, you should call "define_funcs" multiple times, each
|
||||
* "functions" contains a list of (fn_name, (proof, body, proof traces)).
|
||||
* The body should be of the form generated by gen_corres_for_function,
|
||||
* with lambda abstractions for all callees and arguments.
|
||||
*
|
||||
* We assume that all functions in this list are mutually recursive.
|
||||
* (If not, you should call "define_funcs" multiple times, each
|
||||
* time with a single function.)
|
||||
*
|
||||
* This code is quite complex in order to support mutual recursion,
|
||||
* where function definitions and proofs must simultaneously take place
|
||||
* for several functions: if we were only supporting non-recursive
|
||||
* functions, life would be easier.
|
||||
* The proof traces are stored into the theory. (This should probably be moved.)
|
||||
*)
|
||||
fun define_funcs
|
||||
(phase : FunctionInfo.phase)
|
||||
|
@ -444,7 +497,7 @@ fun define_funcs
|
|||
(rec_base_case : thm)
|
||||
(ctxt : Proof.context)
|
||||
(callee_thms : thm Symtab.table)
|
||||
_
|
||||
accum (* translate allows an accumulator, but we don't use it here *)
|
||||
(functions : (string * (thm * term * (string * AutoCorresData.Trace) list)) list)
|
||||
=
|
||||
let
|
||||
|
@ -455,6 +508,8 @@ fun define_funcs
|
|||
|
||||
val _ = writeln ("Defining (" ^ FunctionInfo.string_of_phase phase ^ ") " ^
|
||||
(Utils.commas (map get_const_name fn_names)))
|
||||
val _ = @{trace} ("define_funcs function(s)", functions)
|
||||
val _ = @{trace} ("define_funcs callee(s)", Symtab.dest callee_thms)
|
||||
|
||||
(*
|
||||
* Determine if we are in a recursive case by checking to see if the
|
||||
|
@ -495,10 +550,14 @@ fun define_funcs
|
|||
* Mutually recursive calls should be of the form "Free (fn_name, fn_type)".
|
||||
*)
|
||||
val defs = map (
|
||||
fn (fn_name, fn_body) =>
|
||||
(get_const_name fn_name,
|
||||
("rec_measure'", @{typ nat}) :: get_fn_args fn_name,
|
||||
fill_body fn_name fn_body))
|
||||
fn (fn_name, fn_body) => let
|
||||
val fn_args = get_fn_args fn_name
|
||||
(* FIXME: this retraces assume_called_functions_corres *)
|
||||
val (fn_free :: measure_free :: arg_frees, _) = Variable.variant_fixes
|
||||
(get_const_name fn_name :: "rec_measure'" :: map fst fn_args) ctxt
|
||||
in (get_const_name fn_name, (* be inflexible when it comes to fn_name *)
|
||||
(measure_free, @{typ nat}) :: (arg_frees ~~ map snd fn_args), (* changing arg names is ok *)
|
||||
fill_body fn_name fn_body) end)
|
||||
(fn_names ~~ fn_bodies)
|
||||
val (fn_def_thms, ctxt) = Utils.define_functions defs true is_recursive ctxt
|
||||
|
||||
|
@ -560,12 +619,12 @@ fun define_funcs
|
|||
(* Prove each of the predicates above, leaving any assumptions about called
|
||||
* functions unsolved. *)
|
||||
val pred_thms = map (
|
||||
fn (pred, thm, body_def) =>
|
||||
fn (pred, thm, body_def) => (@{trace} ("define_funcs applying rule", thm);
|
||||
Thm.trivial (Thm.cterm_of ctxt' pred)
|
||||
|> Utils.apply_tac "unfold body" (Hypsubst.stac ctxt' body_def 1)
|
||||
|> Utils.apply_tac "apply rule" (resolve_tac ctxt' [thm] 1)
|
||||
|> Goal.norm_result ctxt
|
||||
|> singleton (Variable.export ctxt' ctxt)
|
||||
|> singleton (Variable.export ctxt' ctxt))
|
||||
)
|
||||
(Utils.zip3 preds fn_thms fn_def_thms)
|
||||
|
||||
|
@ -630,7 +689,7 @@ fun define_funcs
|
|||
(fold (fn (phase, fn_name, trace) =>
|
||||
AutoCorresData.add_trace filename phase fn_name trace) fn_traces) ctxt
|
||||
in
|
||||
(new_thms, (), ctxt)
|
||||
(new_thms, accum, ctxt)
|
||||
end
|
||||
|
||||
(*
|
||||
|
|
|
@ -0,0 +1,664 @@
|
|||
(*
|
||||
* Copyright 2014, NICTA
|
||||
*
|
||||
* This software may be distributed and modified according to the terms of
|
||||
* the BSD 2-Clause license. Note that NO WARRANTY is provided.
|
||||
* See "LICENSE_BSD2.txt" for details.
|
||||
*
|
||||
* @TAG(NICTA_BSD)
|
||||
*)
|
||||
|
||||
(*
|
||||
* Common code for all translation phases: defining funcs, calculating dependencies, variable fixes, etc.
|
||||
*)
|
||||
|
||||
structure AutoCorresUtil2 =
|
||||
struct
|
||||
|
||||
(*
|
||||
* Maximum time to let an individual function translation phase to run for.
|
||||
*
|
||||
* Note that this is wall time, and not CPU time, so it is a very rough
|
||||
* tool.
|
||||
* FIXME: convert to proper option
|
||||
*)
|
||||
val max_run_time = Unsynchronized.ref NONE
|
||||
(*
|
||||
val max_run_time = Unsynchronized.ref (SOME (seconds 900.0))
|
||||
*)
|
||||
|
||||
exception AutocorresTimeout of string list
|
||||
|
||||
fun time_limit f v =
|
||||
case !max_run_time of
|
||||
SOME t =>
|
||||
(TimeLimit.timeLimit t f ()
|
||||
handle TimeLimit.TimeOut =>
|
||||
raise AutocorresTimeout v)
|
||||
| NONE =>
|
||||
f ()
|
||||
|
||||
(* Should we use concurrency? *)
|
||||
val concurrent = Unsynchronized.ref true;
|
||||
|
||||
(*
|
||||
* Conditionally fork a group of tasks, depending on the value
|
||||
* of "concurrent".
|
||||
*)
|
||||
datatype 'a maybe_fork = Future of 'a future | Boring of 'a
|
||||
|
||||
(* Fork a group of tasks. *)
|
||||
fun maybe_fork ctxt vals =
|
||||
if ((!concurrent) andalso not (Config.get ctxt ML_Options.exception_trace)) then
|
||||
map Future (Future.forks {
|
||||
name = "", group = NONE, deps = [], pri = ~1, interrupts = true}
|
||||
vals)
|
||||
else
|
||||
map (fn x => Boring (x ())) vals
|
||||
|
||||
(* Ensure a forked task has completed. *)
|
||||
fun maybe_join v =
|
||||
case v of
|
||||
Boring x => x
|
||||
| Future x => Future.join x
|
||||
|
||||
(* Functional map. *)
|
||||
fun par_map ctxt f a =
|
||||
if ((!concurrent) andalso not (Config.get ctxt ML_Options.exception_trace)) then
|
||||
Par_List.map f a
|
||||
else
|
||||
map f a
|
||||
|
||||
(* Does a SIMPL body exist for the given function name? *)
|
||||
fun has_simpl_body_def lthy name =
|
||||
try (fn name => Proof_Context.get_thm lthy (name ^ "_body_def")) name
|
||||
|> is_some
|
||||
|
||||
|
||||
(*
|
||||
* A translation step transforms the program from one form to another;
|
||||
* such as from SIMPL to a monadic type, or from one type of monad to
|
||||
* another.
|
||||
*
|
||||
* "filename" is the name of the file we are translating: this is required as a
|
||||
* key to fetch data stashed away by the C parser.
|
||||
*
|
||||
* "lthy" is the local theory.
|
||||
*
|
||||
* "convert" performs any proof work required for this translation step. All
|
||||
* conversions are performed in parallel, so must be able to be completed
|
||||
* without results from previous steps.
|
||||
*
|
||||
* "define" actually sets up any definitions required by the translation;
|
||||
* definition steps occur serially, but may be in parallel with conversion
|
||||
* steps whose results are not yet required.
|
||||
*
|
||||
* "prove_mono" should prove the monad_mono property for recursive functions.
|
||||
* (* It is run in parallel over all recursive groups. *)
|
||||
*
|
||||
* Because functions handed to us by the C parser may be mutually recursive and
|
||||
* such mutually recursive functions must typically be defined simultaneously,
|
||||
* "define" is handed a list of functions which must all be defined in one
|
||||
* step.
|
||||
*)
|
||||
fun translate lthy phase fn_info initial_callees convert define gen_new_info prove_mono v =
|
||||
let
|
||||
(* Get list of functions we need to translate.
|
||||
* This is a bit complicated because we need to skip over functions that
|
||||
* have already been translated, and hence we need to recalculate the
|
||||
* function call graph. *)
|
||||
val functions_to_translate =
|
||||
Symtab.dest (FunctionInfo.get_all_functions fn_info)
|
||||
|> map_filter (fn (name, info) =>
|
||||
case FunctionInfo.Phasetab.lookup (#phases info) phase of
|
||||
NONE => SOME name
|
||||
| SOME _ => NONE)
|
||||
|> Symset.make
|
||||
val fn_info_restricted = FunctionInfo.map_fn_info (fn info =>
|
||||
if Symset.contains functions_to_translate (#name info) then SOME info else NONE) fn_info
|
||||
val function_groups = FunctionInfo.get_topo_sorted_functions fn_info_restricted
|
||||
val all_functions = List.concat function_groups
|
||||
|
||||
(*
|
||||
* Convert every function.
|
||||
*
|
||||
* We perform the conversions using futures, which are run in parallel.
|
||||
* This allows us to perform conversions while we start defining functions,
|
||||
* hopefully speeding everything up on multicore systems.
|
||||
*)
|
||||
val converted_body_thms =
|
||||
map (fn name => fn _ =>
|
||||
time_limit (fn _ => convert lthy name) [name]) all_functions
|
||||
|> maybe_fork lthy
|
||||
val converted_bodies = Symtab.make (all_functions ~~ converted_body_thms)
|
||||
|
||||
(* In sorted order, define constants and proofs for the functions. *)
|
||||
fun translate fn_names (callee_thms, new_phase_infos, v, lthy) =
|
||||
let
|
||||
val defs = map (fn fn_name =>
|
||||
Symtab.lookup converted_bodies fn_name |> the |> maybe_join) fn_names
|
||||
val (proofs, v, lthy)
|
||||
= time_limit (fn _ =>
|
||||
define lthy callee_thms v (fn_names ~~ defs)) fn_names
|
||||
val new_callee_thms = fold Symtab.update_new
|
||||
(fn_names ~~ proofs) callee_thms
|
||||
val new_phase_infos = fold (fn n =>
|
||||
Symtab.update_new (n, gen_new_info lthy v (FunctionInfo.get_function_info fn_info n)))
|
||||
fn_names new_phase_infos
|
||||
in
|
||||
(new_callee_thms, new_phase_infos, v, lthy)
|
||||
end
|
||||
|
||||
val (proof_table, new_phase_infos, v, lthy)
|
||||
= fold translate function_groups (initial_callees, Symtab.empty, v, lthy)
|
||||
|
||||
val mono_thms =
|
||||
function_groups
|
||||
|> map (fn funcs => if not (FunctionInfo.is_function_recursive fn_info (hd funcs))
|
||||
then K Symtab.empty else (fn _ => time_limit (fn _ =>
|
||||
(List.mapPartial (fn f =>
|
||||
case Symtab.lookup new_phase_infos f of
|
||||
SOME phase_info =>
|
||||
SOME (FunctionInfo.function_info_add_phase phase_info
|
||||
(FunctionInfo.get_function_info fn_info f))
|
||||
| _ => NONE) funcs
|
||||
|> prove_mono lthy)) funcs))
|
||||
|> maybe_fork lthy |> map maybe_join
|
||||
|> maps Symtab.dest |> Symtab.make
|
||||
|
||||
val new_phase_infos = new_phase_infos |>
|
||||
Symtab.map (fn func => FunctionInfo.phase_info_upd_mono_thm (Symtab.lookup mono_thms func))
|
||||
in
|
||||
(lthy, FunctionInfo.add_phases (fn name => K (Symtab.lookup new_phase_infos name)) fn_info, v)
|
||||
end
|
||||
|
||||
(*
|
||||
* A translation step that maps over every function in the program.
|
||||
*
|
||||
* "convert" performs any proof work required for this translation step. All
|
||||
* conversions are performed in parallel, so must be able to be completed
|
||||
* without results from previous steps.
|
||||
*
|
||||
* We return a list of all results.
|
||||
*)
|
||||
fun map_all ctxt fn_info convert =
|
||||
par_map ctxt (uncurry convert) (FunctionInfo.get_all_functions fn_info |> Symtab.dest)
|
||||
|
||||
(*
|
||||
* Get functions called by a particular function.
|
||||
*
|
||||
* We split the result into standard calls and recursive calls (i.e., calls
|
||||
* which may recursively call back into us).
|
||||
*)
|
||||
fun get_callees fn_info fn_name =
|
||||
let
|
||||
(* Get a list of functions we call. *)
|
||||
val all_callees = FunctionInfo.get_function_callees fn_info fn_name
|
||||
|
||||
(* Fetch calls that may recursively call back to us. *)
|
||||
val recursive_calls = FunctionInfo.get_recursive_group fn_info fn_name
|
||||
|
||||
(* Remove "recursive_calls" from the standard callee set. *)
|
||||
val callees =
|
||||
Symset.make all_callees
|
||||
|> Symset.subtract (Symset.make recursive_calls)
|
||||
|> Symset.dest
|
||||
in
|
||||
(callees, recursive_calls)
|
||||
end
|
||||
|
||||
(* Is the given term a Trueprop? *)
|
||||
fun is_Trueprop (Const (@{const_name "Trueprop"}, _) $ _) = true
|
||||
| is_Trueprop _ = false
|
||||
|
||||
(*
|
||||
* Assume theorems for called functions.
|
||||
*
|
||||
* A new context is returned with the assumptions in it, along with a morphism
|
||||
* used for exporting the theorems out, and a list of the functions assumed:
|
||||
*
|
||||
* (<function name>, (<is_mutually_recursive>, <function free>, <arg frees>, <function thm>))
|
||||
*
|
||||
* In this context, the theorems refer to functions by fixed free variables.
|
||||
*
|
||||
* get_fn_args may return user-friendly argument names that clash with other names.
|
||||
* We will process these names to avoid conflicts.
|
||||
*
|
||||
* get_fn_assumption should produce the desired theorems to assume. Its arguments:
|
||||
* context (with fixed vars), callee name, callee term, arg terms, is recursive, measure term
|
||||
* (all terms are fixed free vars).
|
||||
*
|
||||
* get_const_name generates names for the free function placeholders.
|
||||
* FIXME: probably unnecessary and/or broken.
|
||||
*
|
||||
* We return two morphisms:
|
||||
* - the first one makes the assumptions visible again
|
||||
* - the second one automatically generalizes the assumed constants
|
||||
* (this exists for backwards compat; all new code should explicitly use the
|
||||
* returned free variable set)
|
||||
*)
|
||||
fun assume_called_functions_corres ctxt callees
|
||||
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var =
|
||||
let
|
||||
(* Assume the existence of a function, along with a theorem about its
|
||||
* behaviour. *)
|
||||
fun assume_func ctxt fn_name is_recursive_call =
|
||||
let
|
||||
val fn_args = get_fn_args fn_name
|
||||
|
||||
(* Fix a variable for the function. *)
|
||||
val ([fixed_fn_name], ctxt') = Variable.variant_fixes [get_const_name fn_name] ctxt
|
||||
val fn_free = Free (fixed_fn_name, get_fn_type fn_name)
|
||||
|
||||
(* Fix a variable for the measure and function arguments. *)
|
||||
val (measure_var_name :: arg_names, ctxt'')
|
||||
= Variable.variant_fixes ("rec_measure'" :: (map fst fn_args)) ctxt'
|
||||
val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args))
|
||||
val my_measure_var = Free (measure_var_name, @{typ nat})
|
||||
|
||||
(*
|
||||
* A measure variable is needed to handle recursion: for recursive calls,
|
||||
* we need to decrement the caller's input measure value (and our
|
||||
* assumption will need to assume this to). This is so we can later prove
|
||||
* termination of our function definition: the measure always reaches zero.
|
||||
*
|
||||
* Non-recursive calls can have a fresh value.
|
||||
*)
|
||||
val measure_var =
|
||||
if is_recursive_call then
|
||||
@{const "recguard_dec"} $ callers_measure_var
|
||||
else
|
||||
my_measure_var
|
||||
|
||||
(* Create our assumption. *)
|
||||
val assumption =
|
||||
get_fn_assumption ctxt'' fn_name fn_free fn_arg_terms
|
||||
is_recursive_call measure_var
|
||||
|> fold Logic.all (rev ((if is_recursive_call then [] else [my_measure_var]) @ fn_arg_terms))
|
||||
|> Sign.no_vars ctxt'
|
||||
|> Thm.cterm_of ctxt'
|
||||
val ([thm], ctxt''') = Assumption.add_assumes [assumption] ctxt'
|
||||
|
||||
(* Generate a morphism for escaping this context. *)
|
||||
val m = (Assumption.export_morphism ctxt''' ctxt')
|
||||
$> (Variable.export_morphism ctxt' ctxt)
|
||||
in
|
||||
(fn_free, thm, ctxt''', m)
|
||||
end
|
||||
|
||||
(* Apply each assumption. *)
|
||||
val (res, (ctxt', m)) = fold_map (
|
||||
fn (fn_name, is_recursive_call) =>
|
||||
fn (ctxt, m) =>
|
||||
let
|
||||
val (free, thm, ctxt', m') =
|
||||
assume_func ctxt fn_name is_recursive_call
|
||||
in
|
||||
((fn_name, (is_recursive_call, free, thm)), (ctxt', m' $> m))
|
||||
end)
|
||||
callees (ctxt, Morphism.identity)
|
||||
in
|
||||
(ctxt', m, res)
|
||||
end
|
||||
|
||||
(*
|
||||
* Convert a single function.
|
||||
*
|
||||
* Given a single concrete function, abstract that function and
|
||||
* return a theorem that shows the correspondence.
|
||||
*
|
||||
* A theorem is returned which has assumptions that called functions
|
||||
* correspond, giving a goal that this given function corresponds.
|
||||
*)
|
||||
fun gen_corres_for_function
|
||||
(phase : FunctionInfo.phase)
|
||||
(fn_info : FunctionInfo.fn_info)
|
||||
(get_fn_type : string -> typ)
|
||||
(get_fn_assumption : Proof.context -> string -> term -> term list -> bool -> term -> term)
|
||||
(get_fn_args : string -> (string * typ) list)
|
||||
(get_const_name : string -> string)
|
||||
(convert : Proof.context -> string -> ((bool * term * thm) Symtab.table) ->
|
||||
term -> term list -> (term * thm * (string * AutoCorresData.Trace) list))
|
||||
(ctxt : Proof.context)
|
||||
(fn_name : string) =
|
||||
let
|
||||
val _ = writeln ("Converting (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name)
|
||||
val start_time = Timer.startRealTimer ()
|
||||
|
||||
(* Get a list of functions we call. *)
|
||||
val (normal_calls, recursive_calls)
|
||||
= get_callees fn_info fn_name
|
||||
val callees =
|
||||
(map (fn x => (x, false)) normal_calls)
|
||||
@ (map (fn x => (x, true)) recursive_calls)
|
||||
|
||||
(* Make sure the desired function name is available. *)
|
||||
val fn_target_name = get_const_name fn_name
|
||||
val ([fn_free], ctxt') = Variable.variant_fixes [fn_target_name] ctxt
|
||||
val _ = if fn_free = fn_target_name then () else
|
||||
warning ("Variable clobbered: " ^ fn_target_name ^ " -> " ^ fn_free ^ ". Translating " ^
|
||||
fn_name ^ " may fail.")
|
||||
val fn_var_morph = Variable.export_morphism ctxt' ctxt
|
||||
|
||||
(* Fix a measure variable that will be used to track recursion progress. *)
|
||||
val ([measure_var_name], ctxt'') = Variable.variant_fixes ["rec_measure'"] ctxt'
|
||||
val measure_var = Free (measure_var_name, @{typ nat})
|
||||
val measure_var_morph = Variable.export_morphism ctxt'' ctxt'
|
||||
|
||||
(* Fix variables for function arguments. *)
|
||||
val fn_args = get_fn_args fn_name
|
||||
val (arg_names, ctxt''')
|
||||
= Variable.variant_fixes (map (fn (a, _) => a ^ "'arg") fn_args) ctxt''
|
||||
val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args))
|
||||
val fn_args_morph = Variable.export_morphism ctxt''' ctxt''
|
||||
|
||||
val _ = @{trace} ("Vars", fn_free, measure_var_name, arg_names)
|
||||
(* Enter a context where we assume our callees exist. *)
|
||||
val (ctxt'''', m, callee_info_and_proofs)
|
||||
= assume_called_functions_corres ctxt''' callees
|
||||
get_fn_type get_fn_assumption get_fn_args get_const_name
|
||||
measure_var
|
||||
|
||||
(*
|
||||
* Do the conversion. We receive a new monadic version of the SIMPL
|
||||
* term and a tactic for proving correspondence.
|
||||
*)
|
||||
val callee_tab = Symtab.make callee_info_and_proofs
|
||||
val (body, thm, trace) = convert ctxt'''' fn_name callee_tab measure_var fn_arg_terms
|
||||
|
||||
(*
|
||||
* The returned body will have free variables as placeholders for the function's
|
||||
* input parameters, for the functions it calls, and for its measure variable.
|
||||
*
|
||||
* We modify the body to be of the form:
|
||||
*
|
||||
* %fun1 fun2 rec1 rec2 measure arg1 arg2. f <...>
|
||||
*
|
||||
* That is, all non-recursive calls are abstracted out the front, followed by
|
||||
* recursive calls, followed by the measure variable, followed by function
|
||||
* arguments.
|
||||
*)
|
||||
val body =
|
||||
fold lambda (rev fn_arg_terms) body
|
||||
|> lambda measure_var
|
||||
|> fold lambda (rev (map (fn x => Symtab.lookup callee_tab x |> the |> #2) recursive_calls))
|
||||
|> fold lambda (rev (map (fn x => Symtab.lookup callee_tab x |> the |> #2) normal_calls))
|
||||
|
||||
(* Export the theorem out of our context. *)
|
||||
val exported_thm = Morphism.thm (m $> fn_args_morph $> measure_var_morph $> fn_var_morph) thm
|
||||
|
||||
(* TODO: allow this message to be configured *)
|
||||
val _ = @{trace} ("Converted (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name ^ " in " ^
|
||||
Time.toString (Timer.checkRealTimer start_time) ^ " s")
|
||||
in
|
||||
(exported_thm, body, trace)
|
||||
end
|
||||
|
||||
(*
|
||||
* Given a SIMPL function, define a constant and a proof for it.
|
||||
*
|
||||
* "callee_thms" contains a table mapping function names to complete
|
||||
* corres proofs for those functions.
|
||||
*
|
||||
* "functions" contains a list of (fn_name, (proof, body, proof traces)).
|
||||
* The body should be of the form generated by gen_corres_for_function,
|
||||
* with lambda abstractions for all callees and arguments.
|
||||
*
|
||||
* We assume that all functions in this list are mutually recursive.
|
||||
* (If not, you should call "define_funcs" multiple times, each
|
||||
* time with a single function.)
|
||||
*
|
||||
* The proof traces are stored into the theory. (This should probably be moved.)
|
||||
*)
|
||||
fun define_funcs
|
||||
(phase : FunctionInfo.phase)
|
||||
(filename : string)
|
||||
(fn_info : FunctionInfo.fn_info)
|
||||
(get_const_name : string -> string)
|
||||
(get_fn_type : string -> typ)
|
||||
(get_fn_assumption : Proof.context -> string -> term -> term list -> bool -> term -> term)
|
||||
(get_fn_args : string -> (string * typ) list)
|
||||
(rec_base_case : thm)
|
||||
(ctxt : Proof.context)
|
||||
(callee_thms : thm Symtab.table)
|
||||
accum (* translate allows an accumulator, but we don't use it here *)
|
||||
(functions : (string * (thm * term * (string * AutoCorresData.Trace) list)) list)
|
||||
=
|
||||
let
|
||||
val fn_names = map fst functions
|
||||
val fn_thms = map (snd #> #1) functions
|
||||
val fn_bodies = map (snd #> #2) functions
|
||||
val fn_traces = map (fn (fn_name, (_, _, traces)) => map (fn (module, trace) => (module, fn_name, trace)) traces) functions |> List.concat
|
||||
|
||||
val _ = writeln ("Defining (" ^ FunctionInfo.string_of_phase phase ^ ") " ^
|
||||
(Utils.commas (map get_const_name fn_names)))
|
||||
val _ = @{trace} ("function(s)", functions)
|
||||
val _ = @{trace} ("callee(s)", Symtab.dest callee_thms)
|
||||
|
||||
(*
|
||||
* Determine if we are in a recursive case by checking to see if the
|
||||
* first function in our list makes recursive calls to any other
|
||||
* function. (This "other function" will be itself if it is simple
|
||||
* recursion, but may be a different function if we are mutually
|
||||
* recursive.)
|
||||
*)
|
||||
val is_recursive = FunctionInfo.is_function_recursive fn_info (hd fn_names)
|
||||
val _ = assert (length fn_names = 1 orelse is_recursive)
|
||||
"define_funcs passed multiple functions, but they don't appear to be recursive."
|
||||
|
||||
(*
|
||||
* Patch in functions into our function body in the following order:
|
||||
*
|
||||
* * Non-recursive calls;
|
||||
* * Recursive calls
|
||||
*)
|
||||
fun fill_body fn_name body =
|
||||
let
|
||||
val (normal_calls, recursive_calls)
|
||||
= get_callees fn_info fn_name
|
||||
val non_rec_calls = map (fn x => Utils.get_term ctxt (get_const_name x)) normal_calls
|
||||
val rec_calls = map (fn x => Free (get_const_name x, get_fn_type x)) recursive_calls
|
||||
in
|
||||
body
|
||||
|> (fn t => betapplys (t, non_rec_calls))
|
||||
|> (fn t => betapplys (t, rec_calls))
|
||||
end
|
||||
|
||||
(*
|
||||
* Define our functions.
|
||||
*
|
||||
* Definitions should be of the form:
|
||||
*
|
||||
* %arg1 arg2 arg3. (arg1 + arg2 + arg3)
|
||||
*
|
||||
* Mutually recursive calls should be of the form "Free (fn_name, fn_type)".
|
||||
*)
|
||||
val defs = map (
|
||||
fn (fn_name, fn_body) => let
|
||||
val fn_args = get_fn_args fn_name
|
||||
(* FIXME: this retraces assume_called_functions_corres *)
|
||||
val (fn_free :: measure_free :: arg_frees, _) = Variable.variant_fixes
|
||||
(get_const_name fn_name :: "rec_measure'" :: map fst fn_args) ctxt
|
||||
in (get_const_name fn_name, (* be inflexible when it comes to fn_name *)
|
||||
(measure_free, @{typ nat}) :: (arg_frees ~~ map snd fn_args), (* changing arg names is ok *)
|
||||
fill_body fn_name fn_body) end)
|
||||
(fn_names ~~ fn_bodies)
|
||||
val (fn_def_thms, ctxt) = Utils.define_functions defs true is_recursive ctxt
|
||||
|
||||
(* Record the constant in our theory data. *)
|
||||
val ctxt = fold (fn (fn_name, def) =>
|
||||
Local_Theory.background_theory (
|
||||
AutoCorresData.add_def filename (FunctionInfo.string_of_phase phase ^ "def") fn_name def))
|
||||
(Utils.zip fn_names fn_def_thms) ctxt
|
||||
|
||||
(*
|
||||
* Instantiate schematic function calls in our theorems with their
|
||||
* concrete definitions.
|
||||
*)
|
||||
val combined_callees = map (get_callees fn_info) (map fst functions)
|
||||
val combined_normal_calls =
|
||||
map fst combined_callees |> flat |> sort_distinct fast_string_ord
|
||||
val combined_recursive_calls =
|
||||
map snd combined_callees |> flat |> sort_distinct fast_string_ord
|
||||
val callee_terms =
|
||||
(combined_recursive_calls @ combined_normal_calls)
|
||||
|> map (fn x => (get_const_name x, Utils.get_term ctxt (get_const_name x)))
|
||||
|> Symtab.make
|
||||
fun fill_proofs thm =
|
||||
Utils.instantiate_thm_vars ctxt
|
||||
(fn ((name, _), _) =>
|
||||
Symtab.lookup callee_terms name
|
||||
|> Option.map (Thm.cterm_of ctxt)) thm
|
||||
val fn_thms = map fill_proofs fn_thms
|
||||
|
||||
(* Fix free variable for the measure. *)
|
||||
val ([measure_var_name], ctxt') = Variable.variant_fixes ["m"] ctxt
|
||||
val measure_var = Free (measure_var_name, @{typ nat})
|
||||
|
||||
(* Generate corres predicates for each function. *)
|
||||
val preds = map (
|
||||
fn fn_name =>
|
||||
let
|
||||
fun mk_forall v t = HOLogic.all_const (Term.fastype_of v) $ lambda v t
|
||||
val fn_const = Utils.get_term ctxt' (get_const_name fn_name)
|
||||
|
||||
(* Fetch parameters to this function. *)
|
||||
val free_params =
|
||||
get_fn_args fn_name
|
||||
|> Variable.variant_frees ctxt' [measure_var]
|
||||
|> map Free
|
||||
in
|
||||
(* Generate the prop. *)
|
||||
get_fn_assumption ctxt' fn_name fn_const
|
||||
free_params is_recursive measure_var
|
||||
|> fold Logic.all (rev free_params)
|
||||
end) fn_names
|
||||
|
||||
(* We generate a goal which solves all the mutually recursive calls simultaneously. *)
|
||||
val goal = map (Object_Logic.atomize_term ctxt') preds
|
||||
|> Utils.mk_conj_list
|
||||
|> HOLogic.mk_Trueprop
|
||||
|> Thm.cterm_of ctxt'
|
||||
|
||||
(* Prove each of the predicates above, leaving any assumptions about called
|
||||
* functions unsolved. *)
|
||||
val pred_thms = map (
|
||||
fn (pred, thm, body_def) => (@{trace} thm;
|
||||
Thm.trivial (Thm.cterm_of ctxt' pred)
|
||||
|> Utils.apply_tac "unfold body" (Hypsubst.stac ctxt' body_def 1)
|
||||
|> Utils.apply_tac "apply rule" (resolve_tac ctxt' [thm] 1)
|
||||
|> Goal.norm_result ctxt
|
||||
|> singleton (Variable.export ctxt' ctxt))
|
||||
)
|
||||
(Utils.zip3 preds fn_thms fn_def_thms)
|
||||
|
||||
(* Create a set of "helper theorems", which should be sufficient to discharge
|
||||
* all assumptions that our callees refine. *)
|
||||
val helper_thms =
|
||||
(map (Symtab.lookup callee_thms #> the) combined_normal_calls) @ pred_thms
|
||||
|> map (Drule.forall_intr_vars)
|
||||
|> map (Conv.fconv_rule (Object_Logic.atomize ctxt))
|
||||
|
||||
(* Generate a proof term of equivalence using the folded definitions. *)
|
||||
val new_thm =
|
||||
Goal.init goal
|
||||
|> (fn thm =>
|
||||
if is_recursive then (
|
||||
Utils.apply_tac "start induction"
|
||||
(resolve_tac ctxt'
|
||||
[Utils.named_cterm_instantiate ctxt'
|
||||
[("n", Thm.cterm_of ctxt' measure_var)] @{thm recguard_induct}]
|
||||
1) thm
|
||||
|> Utils.apply_tac "unfold bodies"
|
||||
(EVERY (map (fn x => (EqSubst.eqsubst_tac ctxt' [1] [x] 1)) (rev fn_def_thms)))
|
||||
|> Utils.apply_tac "solve induction base cases"
|
||||
(SOLVES ((simp_tac (put_simpset HOL_ss ctxt' addsimps [rec_base_case]) 1)))
|
||||
|> Utils.apply_tac "solve remaing goals"
|
||||
(Utils.metis_insert_tac ctxt helper_thms 1)
|
||||
) else (
|
||||
Utils.apply_tac "solve remaing goals"
|
||||
(Utils.metis_insert_tac ctxt helper_thms 1) thm
|
||||
))
|
||||
|> Goal.finish ctxt'
|
||||
|
||||
(*
|
||||
* The proof above is of the form (L1corres a & L1corres b & ...).
|
||||
* Split it up into several proofs.
|
||||
*)
|
||||
fun prove_partial_l1_corres thm pred =
|
||||
Thm.cterm_of ctxt' pred
|
||||
|> Goal.init
|
||||
|> Utils.apply_tac "solving using metis" (Utils.metis_tac ctxt [thm] 1)
|
||||
|> Goal.finish ctxt'
|
||||
|
||||
(* Generate the final theorems. *)
|
||||
val new_thms =
|
||||
map (prove_partial_l1_corres new_thm) preds
|
||||
|> (Variable.export ctxt' ctxt)
|
||||
|> map (Goal.norm_result ctxt)
|
||||
|
||||
(* Record the theorems in our theory data. *)
|
||||
val ctxt = fold (fn (fn_name, thm) =>
|
||||
Local_Theory.background_theory
|
||||
(AutoCorresData.add_thm filename (FunctionInfo.string_of_phase phase ^ "corres") fn_name thm))
|
||||
(fn_names ~~ new_thms) ctxt
|
||||
|
||||
(* Add the theorems to the context. *)
|
||||
val ctxt = fold (fn (fn_name, thm) =>
|
||||
Utils.define_lemma (fn_name ^ "_" ^ FunctionInfo.string_of_phase phase ^ "corres") thm #> snd)
|
||||
(fn_names ~~ new_thms) ctxt
|
||||
|
||||
(* Add the traces to the context. *)
|
||||
val ctxt = Local_Theory.background_theory
|
||||
(fold (fn (phase, fn_name, trace) =>
|
||||
AutoCorresData.add_trace filename phase fn_name trace) fn_traces) ctxt
|
||||
in
|
||||
(new_thms, accum, ctxt)
|
||||
end
|
||||
|
||||
(*
|
||||
* Do a translation phase, converting every function from one form to another.
|
||||
*)
|
||||
fun do_translation_phase
|
||||
(phase : FunctionInfo.phase)
|
||||
(filename : string)
|
||||
(prog_info : ProgramInfo.prog_info)
|
||||
(fn_info : FunctionInfo.fn_info)
|
||||
(get_fn_type : string -> typ)
|
||||
(get_fn_assumption : local_theory -> string -> term -> term list -> bool -> term -> term)
|
||||
(get_fn_args : string -> (string * typ) list)
|
||||
(get_const_name : string -> string)
|
||||
(convert : local_theory -> string -> ((bool * term * thm) Symtab.table) ->
|
||||
term -> term list -> (term * thm * (string * AutoCorresData.Trace) list))
|
||||
(gen_new_info : local_theory -> FunctionInfo.function_info -> FunctionInfo.phase_info)
|
||||
(prove_mono : local_theory -> FunctionInfo.function_info list -> thm Symtab.table)
|
||||
(rec_base_case : thm)
|
||||
(ctxt : Proof.context) =
|
||||
let
|
||||
val do_gen_corres =
|
||||
gen_corres_for_function phase fn_info get_fn_type get_fn_assumption
|
||||
get_fn_args get_const_name convert;
|
||||
val do_define_funcs =
|
||||
define_funcs phase filename fn_info get_const_name get_fn_type
|
||||
get_fn_assumption get_fn_args rec_base_case
|
||||
(* Lookup functions that have already been translated (i.e. phase exists) *)
|
||||
val initial_callees = Symtab.dest (FunctionInfo.get_all_functions fn_info)
|
||||
|> List.mapPartial (fn (fn_name, info) =>
|
||||
FunctionInfo.Phasetab.lookup (#phases info) phase
|
||||
|> Option.mapPartial (fn phase_info =>
|
||||
AutoCorresData.get_thm (Proof_Context.theory_of ctxt) filename
|
||||
(FunctionInfo.string_of_phase phase ^ "corres") fn_name)
|
||||
|> Option.map (fn thm => (fn_name, thm)))
|
||||
|> Symtab.make
|
||||
|
||||
(* Do the translation. *)
|
||||
val (ctxt', new_fn_info, _) =
|
||||
translate ctxt phase fn_info initial_callees do_gen_corres do_define_funcs
|
||||
(fn lthy => K (gen_new_info lthy)) prove_mono ()
|
||||
|
||||
(* Map function information. *)
|
||||
in
|
||||
(ctxt', new_fn_info)
|
||||
end
|
||||
|
||||
end
|
|
@ -25,21 +25,29 @@ sig
|
|||
|
||||
(* Function info for a single phase. *)
|
||||
type phase_info = {
|
||||
phase : phase,
|
||||
args : (string * typ) list,
|
||||
return_type : typ,
|
||||
const : term,
|
||||
raw_const : term,
|
||||
(*
|
||||
callees : string list,
|
||||
rec_callees : string list,
|
||||
*)
|
||||
definition : thm,
|
||||
mono_thm : thm option,
|
||||
phase : phase
|
||||
mono_thm : thm option
|
||||
};
|
||||
val phase_info_upd_phase : phase -> phase_info -> phase_info;
|
||||
val phase_info_upd_args : (string * typ) list -> phase_info -> phase_info;
|
||||
val phase_info_upd_return_type : typ -> phase_info -> phase_info;
|
||||
(* also updates raw_const *)
|
||||
val phase_info_upd_const : term -> phase_info -> phase_info;
|
||||
(*
|
||||
val phase_info_upd_callees : string list -> phase_info -> phase_info;
|
||||
val phase_info_upd_rec_callees : string list -> phase_info -> phase_info;
|
||||
*)
|
||||
val phase_info_upd_definition : thm -> phase_info -> phase_info;
|
||||
val phase_info_upd_mono_thm : thm option -> phase_info -> phase_info;
|
||||
val phase_info_upd_phase : phase -> phase_info -> phase_info;
|
||||
|
||||
(* Function info for a single function. *)
|
||||
type function_info = {
|
||||
|
@ -101,6 +109,9 @@ structure Phasetab = Table(
|
|||
val ord = phase_ord);
|
||||
|
||||
type phase_info = {
|
||||
(* The translation phase for this definition. *)
|
||||
phase : phase,
|
||||
|
||||
(* Arguments of the function, in order, excluding measure variables. *)
|
||||
args : (string * typ) list,
|
||||
|
||||
|
@ -120,10 +131,7 @@ type phase_info = {
|
|||
definition : thm,
|
||||
|
||||
(* monad_mono theorem for the function, if it is recursive. *)
|
||||
mono_thm : thm option,
|
||||
|
||||
(* The translation phase for this definition. *)
|
||||
phase : phase
|
||||
mono_thm : thm option
|
||||
};
|
||||
|
||||
type function_info = {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,630 @@
|
|||
(*
|
||||
* Copyright 2014, NICTA
|
||||
*
|
||||
* This software may be distributed and modified according to the terms of
|
||||
* the BSD 2-Clause license. Note that NO WARRANTY is provided.
|
||||
* See "LICENSE_BSD2.txt" for details.
|
||||
*
|
||||
* @TAG(NICTA_BSD)
|
||||
*)
|
||||
|
||||
(*
|
||||
* Automatically convert SIMPL code fragments into a monadic form, with proofs
|
||||
* of correspondence between the two.
|
||||
*)
|
||||
structure SimplConv2 =
|
||||
struct
|
||||
|
||||
(* Convenience shortcuts. *)
|
||||
val warning = Utils.ac_warning
|
||||
val apply_tac = Utils.apply_tac
|
||||
val the' = Utils.the'
|
||||
|
||||
exception FunctionNotFound of string
|
||||
|
||||
val simpl_conv_ss = AUTOCORRES_SIMPSET
|
||||
|
||||
(*
|
||||
* Given a function constant name such as "Blah.foo_'proc", guess the underlying
|
||||
* function name "foo".
|
||||
*)
|
||||
fun guess_function_name const_name =
|
||||
const_name |> unsuffix "_'proc" |> Long_Name.base_name
|
||||
|
||||
(* Generate a L1 monad type. *)
|
||||
fun mk_l1monadT stateT =
|
||||
Utils.gen_typ @{typ "'a L1_monad"} [stateT]
|
||||
|
||||
(*
|
||||
* Extract the L1 monadic term out of a L1corres constant.
|
||||
*)
|
||||
fun get_L1corres_monad @{term_pat "L1corres _ _ ?l1_monad _"} = l1_monad
|
||||
| get_L1corres_monad t = raise TERM ("get_L1corres_monad", [t])
|
||||
|
||||
(*
|
||||
* Generate a SIMPL term that calls the given function.
|
||||
*
|
||||
* For instance, we might return:
|
||||
*
|
||||
* "Call foo_'proc"
|
||||
*)
|
||||
fun mk_SIMPL_call_term ctxt prog_info fn_info target_fn =
|
||||
@{mk_term "Call ?proc :: (?'s, int, strictc_errortype) com" (proc, 's)}
|
||||
(FunctionInfo.get_phase_info fn_info FunctionInfo.CP target_fn |> #const, #state_type prog_info)
|
||||
|
||||
(*
|
||||
* Construct a correspondence lemma between a given monadic term and a SIMPL fragment.
|
||||
*
|
||||
* The term is of the form:
|
||||
*
|
||||
* L1corres check_termination \<Gamma> monad simpl
|
||||
*)
|
||||
fun mk_L1corres_prop prog_info check_termination monad_term simpl_term =
|
||||
@{mk_term "L1corres ?ct ?gamma ?monad ?simpl" (ct, gamma, monad, simpl)}
|
||||
(Utils.mk_bool check_termination, #gamma prog_info, monad_term, simpl_term)
|
||||
|
||||
(*
|
||||
* Construct a prop claiming that the given term is equivalent to
|
||||
* a call to the given SIMPL function:
|
||||
*
|
||||
* L1corres ct \<Gamma> <term> (Call foo_'proc)
|
||||
*
|
||||
*)
|
||||
fun mk_L1corres_call_prop ctxt prog_info fn_info check_termination target_fn_name term =
|
||||
mk_L1corres_prop prog_info check_termination term
|
||||
(mk_SIMPL_call_term ctxt prog_info fn_info target_fn_name)
|
||||
|> HOLogic.mk_Trueprop
|
||||
|
||||
(*
|
||||
* Convert a SIMPL fragment into a monadic term.
|
||||
*
|
||||
* We return the monadic version of the input fragment and a tactic
|
||||
* to prove correspondence.
|
||||
*)
|
||||
fun simpl_conv'
|
||||
(prog_info : ProgramInfo.prog_info)
|
||||
(fn_info : FunctionInfo.fn_info)
|
||||
(ctxt : Proof.context)
|
||||
(callee_terms : (bool * term * thm) Symtab.table)
|
||||
(measure_var : term)
|
||||
(simpl_term : term) =
|
||||
let
|
||||
fun prove_term subterms base_thm result_term =
|
||||
let
|
||||
val subterms' = map (simpl_conv' prog_info fn_info ctxt
|
||||
callee_terms measure_var) subterms;
|
||||
val converted_terms = map fst subterms';
|
||||
val subproofs = map snd subterms';
|
||||
val new_term = (result_term converted_terms);
|
||||
in
|
||||
(new_term, (resolve_tac ctxt [base_thm] 1) THEN (EVERY subproofs))
|
||||
end
|
||||
|
||||
(* Construct a "L1 monad" term with the given arguments applied to it. *)
|
||||
fun mk_l1 (Const (a, _)) args =
|
||||
Term.betapplys (Const (a, map fastype_of args
|
||||
---> mk_l1monadT (#state_type prog_info)), args)
|
||||
|
||||
(* Convert a set construct into a predicate construct. *)
|
||||
fun set_to_pred t =
|
||||
(Const (@{const_name L1_set_to_pred},
|
||||
fastype_of t --> (HOLogic.dest_setT (fastype_of t) --> @{typ bool})) $ t)
|
||||
in
|
||||
(case simpl_term of
|
||||
(*
|
||||
* Various easy cases of SIMPL to monadic conversion.
|
||||
*)
|
||||
|
||||
(Const (@{const_name Skip}, _)) =>
|
||||
prove_term [] @{thm L1corres_skip}
|
||||
(fn _ => mk_l1 @{term "L1_skip"} [])
|
||||
|
||||
| (Const (@{const_name Seq}, _) $ left $ right) =>
|
||||
prove_term [left, right] @{thm L1corres_seq}
|
||||
(fn [l, r] => mk_l1 @{term "L1_seq"} [l, r])
|
||||
|
||||
| (Const (@{const_name Basic}, _) $ m) =>
|
||||
prove_term [] @{thm L1corres_modify}
|
||||
(fn _ => mk_l1 @{term "L1_modify"} [m])
|
||||
|
||||
| (Const (@{const_name Cond}, _) $ c $ left $ right) =>
|
||||
prove_term [left, right] @{thm L1corres_condition}
|
||||
(fn [l, r] => mk_l1 @{term "L1_condition"} [set_to_pred c, l, r])
|
||||
|
||||
| (Const (@{const_name Catch}, _) $ left $ right) =>
|
||||
prove_term [left, right] @{thm L1corres_catch}
|
||||
(fn [l, r] => mk_l1 @{term "L1_catch"} [l, r])
|
||||
|
||||
| (Const (@{const_name While}, _) $ c $ body) =>
|
||||
prove_term [body] @{thm L1corres_while}
|
||||
(fn [body] => mk_l1 @{term "L1_while"} [set_to_pred c, body])
|
||||
|
||||
| (Const (@{const_name Throw}, _)) =>
|
||||
prove_term [] @{thm L1corres_throw}
|
||||
(fn _ => mk_l1 @{term "L1_throw"} [])
|
||||
|
||||
| (Const (@{const_name Guard}, _) $ _ $ c $ body) =>
|
||||
prove_term [body] @{thm L1corres_guard}
|
||||
(fn [body] => mk_l1 @{term "L1_seq"} [mk_l1 @{term "L1_guard"} [set_to_pred c], body])
|
||||
|
||||
| @{term_pat "lvar_nondet_init _ ?upd"} =>
|
||||
prove_term [] @{thm L1corres_init}
|
||||
(fn _ => mk_l1 @{term "L1_init"} [upd])
|
||||
|
||||
| (Const (@{const_name Spec}, _) $ s) =>
|
||||
prove_term [] @{thm L1corres_spec}
|
||||
(fn _ => mk_l1 @{term "L1_spec"} [s])
|
||||
|
||||
| (Const (@{const_name guarded_spec_body}, _) $ _ $ s) =>
|
||||
prove_term [] @{thm L1corres_guarded_spec}
|
||||
(fn _ => mk_l1 @{term "L1_spec"} [s])
|
||||
|
||||
(*
|
||||
* "call": This is primarily what is output by the C parser. We
|
||||
* accept input terms of the form:
|
||||
*
|
||||
* "call <argument_setup> <proc_to_call> <locals_reset> (%_ s. Basic (<store return value> s))".
|
||||
*
|
||||
* In particular, the last argument needs to be of precisely the
|
||||
* form above. SIMPL, in theory, supports complex expressions in
|
||||
* the last argument. In practice, the C parser only outputs
|
||||
* the form above, and supporting more would be a pain.
|
||||
*)
|
||||
| (Const (@{const_name call}, _) $ a $ (fn_const as Const (b, _)) $ c $ (Abs (_, _, Abs (_, _, (Const (@{const_name Basic}, _) $ d))))) =>
|
||||
let
|
||||
val state_type = #state_type prog_info
|
||||
val target_fn_name = FunctionInfo.get_function_from_const fn_info fn_const |> Option.map #name
|
||||
in
|
||||
case Option.mapPartial (Symtab.lookup callee_terms) target_fn_name of
|
||||
NONE =>
|
||||
(* If no proof of our callee could be found, we emit a call to
|
||||
* "fail". This may happen for functions without bodies. *)
|
||||
let
|
||||
val _ = warning ("Function '" ^ guess_function_name b ^ "' contains no body. "
|
||||
^ "Replacing the function call with a \"fail\" command.")
|
||||
in
|
||||
prove_term [] @{thm L1corres_fail} (fn _ => mk_l1 @{term "L1_fail"} [])
|
||||
end
|
||||
| SOME (is_rec, term, thm) =>
|
||||
let
|
||||
(*
|
||||
* If this is an internal recursive call, decrement the measure.
|
||||
* Or if this is calling a recursive function, use measure_call.
|
||||
* If the callee isn't recursive, it doesn't use the measure var
|
||||
* and we can just give an arbitrary value.
|
||||
*)
|
||||
val target_fn_name = (the target_fn_name)
|
||||
val target_rec = FunctionInfo.is_function_recursive fn_info target_fn_name
|
||||
val term' =
|
||||
if is_rec then
|
||||
term $ (@{term "recguard_dec"} $ measure_var)
|
||||
else if target_rec then
|
||||
@{mk_term "measure_call ?f" f} term
|
||||
else
|
||||
term $ @{term "undefined :: nat"}
|
||||
in
|
||||
(* Generate the term. *)
|
||||
(mk_l1 @{term "L1_call"}
|
||||
[a, term', c, absdummy state_type d],
|
||||
resolve_tac ctxt [if is_rec orelse not target_rec then
|
||||
@{thm L1corres_reccall} else @{thm L1corres_call}] 1
|
||||
THEN resolve_tac ctxt [thm] 1)
|
||||
end
|
||||
end
|
||||
|
||||
(* TODO : Don't currently support DynCom *)
|
||||
| other => Utils.invalid_term "a SIMPL term" other)
|
||||
end
|
||||
|
||||
(* Perform post-processing on a theorem. *)
|
||||
fun cleanup_thm ctxt do_opt trace_opt prog_info fn_name thm =
|
||||
let
|
||||
(* Measure the term. *)
|
||||
fun gather_stats phase thm =
|
||||
Statistics.gather ctxt phase fn_name
|
||||
(Thm.concl_of thm |> HOLogic.dest_Trueprop |> get_L1corres_monad)
|
||||
val _ = gather_stats "L1" thm
|
||||
|
||||
(* For each function, we want to prepend a statement that sets its return
|
||||
* value undefined. It is actually always defined, but our analysis isn't
|
||||
* sophisticated enough to realise. *)
|
||||
fun prepend_undef thm fn_name =
|
||||
let
|
||||
val ret_var_name =
|
||||
Symtab.lookup (ProgramAnalysis.get_fninfo (#csenv prog_info)) fn_name
|
||||
|> the
|
||||
|> (fn (ctype, _, _) => NameGeneration.return_var_name ctype |> MString.dest)
|
||||
val ret_var_setter = Symtab.lookup (#var_setters prog_info) ret_var_name
|
||||
val ret_var_getter = Symtab.lookup (#var_getters prog_info) ret_var_name
|
||||
fun try_unify (x::xs) =
|
||||
((x ()) handle THM _ => try_unify xs)
|
||||
in
|
||||
case ret_var_setter of
|
||||
SOME _ =>
|
||||
(* Prepend the L1_init code. *)
|
||||
Utils.named_cterm_instantiate ctxt
|
||||
[("X", Thm.cterm_of ctxt (the ret_var_setter)),
|
||||
("X'", Thm.cterm_of ctxt (the ret_var_getter))]
|
||||
(try_unify [
|
||||
(fn _ => @{thm L1corres_prepend_unknown_var_recguard} OF [thm]),
|
||||
(fn _ => @{thm L1corres_prepend_unknown_var} OF [thm]),
|
||||
(fn _ => @{thm L1corres_prepend_unknown_var'} OF [thm])])
|
||||
|
||||
(* Discharge the given proof obligation. *)
|
||||
|> simp_tac (put_simpset simpl_conv_ss ctxt) 1 |> Seq.hd
|
||||
| NONE => thm
|
||||
end
|
||||
val thm = prepend_undef thm fn_name
|
||||
|
||||
(* Conversion combinator to apply a conversion only to the L1 subterm of a
|
||||
* L1corres term. *)
|
||||
fun l1conv conv = (Conv.arg_conv (Utils.nth_arg_conv 3 conv))
|
||||
|
||||
(* Conversion to simplify guards. *)
|
||||
fun guard_conv' c =
|
||||
case (Thm.term_of c) of
|
||||
(Const (@{const_name "L1_guard"}, _) $ _) =>
|
||||
Simplifier.asm_full_rewrite (put_simpset simpl_conv_ss ctxt) c
|
||||
| _ =>
|
||||
Conv.all_conv c
|
||||
val guard_conv = Conv.top_conv (K guard_conv') ctxt
|
||||
|
||||
(* Apply all the conversions on the generated term. *)
|
||||
val (thm, guard_opt_trace) = AutoCorresTrace.fconv_rule_maybe_traced ctxt (l1conv guard_conv) thm trace_opt
|
||||
val (thm, peephole_opt_trace) =
|
||||
AutoCorresTrace.fconv_rule_maybe_traced ctxt
|
||||
(l1conv (Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps
|
||||
(if do_opt then Utils.get_rules ctxt @{named_theorems L1opt} else []))))
|
||||
thm trace_opt
|
||||
val _ = gather_stats "L1peep" thm
|
||||
|
||||
(* Rewrite exceptions. *)
|
||||
val (thm, exn_opt_trace) = AutoCorresTrace.fconv_rule_maybe_traced ctxt
|
||||
(l1conv (ExceptionRewrite.except_rewrite_conv ctxt do_opt)) thm trace_opt
|
||||
val _ = gather_stats "L1except" thm
|
||||
in
|
||||
(thm,
|
||||
[("L1 guard opt", guard_opt_trace), ("L1 peephole opt", peephole_opt_trace), ("L1 exception opt", exn_opt_trace)]
|
||||
|> List.mapPartial (fn (n, tr) => case tr of NONE => NONE | SOME x => SOME (n, AutoCorresData.SimpTrace x))
|
||||
)
|
||||
end
|
||||
|
||||
(*
|
||||
* Get theorems about a SIMPL body in a format convenient to reason about.
|
||||
*
|
||||
* In particular, we unfold parts of SIMPL where we would prefer to reason
|
||||
* about raw definitions instead of more abstract constructs generated
|
||||
* by the C parser.
|
||||
*)
|
||||
fun get_simpl_body ctxt fn_info fn_name =
|
||||
let
|
||||
(* Find the definition of the given function. *)
|
||||
val simpl_thm = #definition (FunctionInfo.get_phase_info fn_info FunctionInfo.CP fn_name)
|
||||
handle ERROR _ => raise FunctionNotFound fn_name;
|
||||
|
||||
(* Unfold terms in the body which we don't want to deal with. *)
|
||||
val unfolded_simpl_thm =
|
||||
Conv.fconv_rule (Utils.rhs_conv
|
||||
(Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps
|
||||
(Utils.get_rules ctxt @{named_theorems L1unfold}))))
|
||||
simpl_thm
|
||||
val unfolded_simpl_term = Thm.concl_of unfolded_simpl_thm |> Utils.rhs_of;
|
||||
|
||||
(*
|
||||
* Get the implementation definition for this function. These rules are of
|
||||
* the form "Gamma foo_'proc = Some foo_body".
|
||||
*)
|
||||
val impl_thm =
|
||||
Proof_Context.get_thm ctxt (fn_name ^ "_impl")
|
||||
|> Local_Defs.unfold ctxt [unfolded_simpl_thm]
|
||||
|> SOME
|
||||
handle (ERROR _) => NONE
|
||||
in
|
||||
(unfolded_simpl_term, unfolded_simpl_thm, impl_thm)
|
||||
end
|
||||
|
||||
fun get_l1corres_thm prog_info fn_info check_termination ctxt do_opt trace_opt fn_name
|
||||
callee_terms measure_var = let
|
||||
val thy = Proof_Context.theory_of ctxt
|
||||
val (simpl_term, simpl_thm, impl_thm) = get_simpl_body ctxt fn_info fn_name
|
||||
|
||||
(* Fetch stats on pre-converted term. *)
|
||||
val _ = Statistics.gather ctxt "CParser" fn_name simpl_term
|
||||
|
||||
(*
|
||||
* Do the conversion. We receive a new monadic version of the SIMPL
|
||||
* term and a tactic for proving correspondence.
|
||||
*)
|
||||
val (monad, tactic) = simpl_conv' prog_info fn_info ctxt
|
||||
callee_terms measure_var simpl_term
|
||||
|
||||
(*
|
||||
* Wrap the monad in a "L1_recguard" statement, which triggers
|
||||
* failure when the measure reaches zero. This lets us automatically
|
||||
* prove termination of the recursive function.
|
||||
*)
|
||||
val is_recursive = FunctionInfo.is_function_recursive fn_info fn_name
|
||||
val (monad, tactic) =
|
||||
if is_recursive then
|
||||
(Utils.mk_term thy @{term "L1_recguard"} [measure_var, monad],
|
||||
(resolve_tac ctxt @{thms L1corres_recguard} 1 THEN tactic))
|
||||
else
|
||||
(monad, tactic)
|
||||
|
||||
(*
|
||||
* Return a new theorem of correspondence between the original
|
||||
* SIMPL body (with folded constants) and the output monad term.
|
||||
*)
|
||||
in
|
||||
mk_L1corres_call_prop ctxt prog_info fn_info check_termination fn_name monad
|
||||
|> Thm.cterm_of ctxt
|
||||
|> Goal.init
|
||||
|> (case impl_thm of
|
||||
NONE => apply_tac "unfold SIMPL body" (resolve_tac ctxt @{thms L1corres_undefined_call} 1)
|
||||
| SOME def => apply_tac "unfold SIMPL body" (resolve_tac ctxt @{thms L1corres_Call} 1 THEN
|
||||
resolve_tac ctxt [def] 1)
|
||||
#> apply_tac "solve L1corres" tactic)
|
||||
|> Goal.finish ctxt
|
||||
(* Apply simplifications to the L1 term. *)
|
||||
|> cleanup_thm ctxt do_opt trace_opt prog_info fn_name
|
||||
end
|
||||
|
||||
fun get_body_of_l1corres_thm thm =
|
||||
(* Extract the monad from the thm. *)
|
||||
Thm.concl_of thm
|
||||
|> HOLogic.dest_Trueprop
|
||||
|> get_L1corres_monad
|
||||
|
||||
fun split_conj thm =
|
||||
(thm RS @{thm conjunct1}) :: split_conj (thm RS @{thm conjunct2})
|
||||
handle THM _ => [thm]
|
||||
|
||||
(* Prove monad_mono for recursive functions. *)
|
||||
fun l1_monad_mono lthy (l1_defs : FunctionInfo.phase_info Symtab.table) =
|
||||
let
|
||||
val l1_defs' = Symtab.dest l1_defs;
|
||||
fun mk_stmt [func] = @{mk_term "monad_mono ?f" f} func
|
||||
| mk_stmt (func :: funcs) = @{mk_term "monad_mono ?f \<and> ?g" (f, g)} (func, mk_stmt funcs);
|
||||
val mono_thm = @{term "Trueprop"} $ mk_stmt (map (#const o snd) l1_defs');
|
||||
val func_expand = map (fn (_, l1_def) =>
|
||||
EqSubst.eqsubst_tac lthy [0] [Utils.abs_def lthy (#definition l1_def)]) l1_defs';
|
||||
val tac =
|
||||
REPEAT (EqSubst.eqsubst_tac lthy [0]
|
||||
[@{thm monad_mono_alt_def}, @{thm all_conj_distrib} RS @{thm sym}] 1)
|
||||
THEN resolve_tac lthy @{thms allI} 1 THEN resolve_tac lthy @{thms nat.induct} 1
|
||||
THEN EVERY (map (fn expand =>
|
||||
TRY (resolve_tac lthy @{thms conjI} 1)
|
||||
THEN expand 1
|
||||
THEN resolve_tac lthy @{thms monad_mono_step_L1_recguard_0} 1) func_expand)
|
||||
THEN REPEAT (eresolve_tac lthy @{thms conjE} 1)
|
||||
THEN EVERY (map (fn expand =>
|
||||
TRY (resolve_tac lthy @{thms conjI} 1)
|
||||
THEN expand 1
|
||||
THEN REPEAT (FIRST [assume_tac lthy 1,
|
||||
resolve_tac lthy @{thms L1_monad_mono_step_rules} 1]))
|
||||
func_expand);
|
||||
in
|
||||
Goal.prove lthy [] [] mono_thm (K tac)
|
||||
|> split_conj
|
||||
|> (fn thms => map fst l1_defs' ~~ thms)
|
||||
|> Symtab.make
|
||||
end
|
||||
|
||||
|
||||
(* For functions that are not translated, just generate a trivial wrapper. *)
|
||||
fun mk_l1corres_call_simpl_thm fn_info check_termination ctxt fn_name = let
|
||||
val info = FunctionInfo.get_phase_info fn_info FunctionInfo.CP fn_name
|
||||
val const = #const info
|
||||
val impl_thm = Proof_Context.get_thm ctxt (fn_name ^ "_impl")
|
||||
val gamma = safe_mk_meta_eq impl_thm |> Thm.concl_of |> Logic.dest_equals
|
||||
|> fst |> (fn (f $ _) => f | t => raise TERM ("gamma", [t]))
|
||||
val thm = Utils.named_cterm_instantiate ctxt
|
||||
[("ct", Thm.cterm_of ctxt (Utils.mk_bool check_termination)),
|
||||
("proc", Thm.cterm_of ctxt const),
|
||||
("Gamma", Thm.cterm_of ctxt gamma)]
|
||||
@{thm L1corres_call_simpl}
|
||||
in thm end
|
||||
|
||||
|
||||
(*
|
||||
* Convert a single function. Returns a thm that looks like
|
||||
* \<lbrakk> L1corres ?callee1 (Call callee1_'proc); ... \<rbrakk> \<Longrightarrow>
|
||||
* L1corres (conversion result...) (Call f_'proc)
|
||||
* i.e. with assumptions for called functions, which are parameterised as Vars.
|
||||
*)
|
||||
fun convert
|
||||
(lthy: local_theory)
|
||||
(prog_info: ProgramInfo.prog_info)
|
||||
(fn_info: FunctionInfo.fn_info) (* needs CP phase_info for each callee *)
|
||||
(check_termination: bool)
|
||||
(do_opt: bool)
|
||||
(trace_opt: bool)
|
||||
(l1_function_name: string -> string)
|
||||
(f_name: string)
|
||||
: thm * (string * typ) list = let
|
||||
val callee_names = FunctionInfo.get_function_callees fn_info f_name;
|
||||
(* TODO sanity check:
|
||||
- all SIMPL defs exist *)
|
||||
|
||||
val measureT = @{typ nat};
|
||||
|
||||
(* All L1 functions have the same signature: measure \<Rightarrow> L1_monad *)
|
||||
val l1_fn_type = measureT --> mk_l1monadT (#state_type prog_info);
|
||||
|
||||
(* L1corres for f's callees. *)
|
||||
fun get_l1_fn_assumption ctxt fn_name free _ _ measure_var =
|
||||
mk_L1corres_call_prop ctxt prog_info fn_info check_termination fn_name (betapply (free, measure_var));
|
||||
|
||||
(* Fix measure variable. *)
|
||||
val ([measure_var_name], lthy') = Variable.variant_fixes ["rec_measure'"] lthy;
|
||||
val measure_var = Free (measure_var_name, measureT);
|
||||
|
||||
(* Add callee assumptions. Note that our define code has to use the same assumption order. *)
|
||||
val group = Symset.make (FunctionInfo.get_recursive_group fn_info f_name);
|
||||
val callee_terms =
|
||||
map (fn callee => (callee, Symset.contains group callee)) callee_names;
|
||||
val (lthy'', export_thm, callee_terms) =
|
||||
AutoCorresUtil.assume_called_functions_corres lthy'
|
||||
callee_terms
|
||||
(K l1_fn_type)
|
||||
get_l1_fn_assumption
|
||||
(K [])
|
||||
l1_function_name
|
||||
measure_var;
|
||||
|
||||
val (thm, opt_traces) =
|
||||
if #is_simpl_wrapper (FunctionInfo.get_function_info fn_info f_name)
|
||||
then (mk_l1corres_call_simpl_thm fn_info check_termination lthy'' f_name, [])
|
||||
else get_l1corres_thm prog_info fn_info check_termination lthy'' do_opt trace_opt f_name
|
||||
(Symtab.make callee_terms) measure_var;
|
||||
in (Morphism.thm export_thm thm,
|
||||
(* Provide the fixed vars so the user can generalize/instantiate them *)
|
||||
[dest_Free measure_var])
|
||||
end
|
||||
|
||||
|
||||
(* Define a previously-converted function (or recursive function group).
|
||||
* lthy must include all definitions from l1_callees *)
|
||||
fun define
|
||||
(lthy: local_theory)
|
||||
(filename: string)
|
||||
(prog_info: ProgramInfo.prog_info)
|
||||
(fn_info: FunctionInfo.fn_info) (* doesn't need to have L1 info *)
|
||||
(check_termination: bool)
|
||||
(l1_callees: (FunctionInfo.phase_info * thm) Symtab.table) (* L1 callees & corres thms *)
|
||||
(l1_function_name: string -> string)
|
||||
(funcs: (string * thm * (string * typ) list) list) (* name, corres, arg frees *)
|
||||
: (FunctionInfo.phase_info * thm) Symtab.table * local_theory = let
|
||||
(* FIXME: dedup with convert *)
|
||||
|
||||
(* All L1 functions have the same signature: measure \<Rightarrow> L1_monad *)
|
||||
val measureT = @{typ nat};
|
||||
val l1_fn_type = measureT --> mk_l1monadT (#state_type prog_info);
|
||||
|
||||
(* L1corres for f's callees. *)
|
||||
fun get_l1_fn_assumption ctxt fn_name free _ _ measure_var =
|
||||
mk_L1corres_call_prop ctxt prog_info fn_info check_termination fn_name (betapply (free, measure_var));
|
||||
|
||||
(* FIXME: pass this from assume_called_functions_corres, etc. *)
|
||||
fun guess_callee_var thm callee = let
|
||||
val base_name = l1_function_name callee;
|
||||
val mentioned_vars = Term.add_vars (Thm.prop_of thm) [];
|
||||
in hd (filter (fn ((v, _), _) => v = base_name) mentioned_vars) end;
|
||||
|
||||
fun prepare_fn_body (fn_name, corres_thm, measure_free) = let
|
||||
val @{term_pat "Trueprop (L1corres _ _ ?body _)"} = Thm.concl_of corres_thm;
|
||||
val (callees, recursive_callees) = AutoCorresUtil.get_callees fn_info fn_name;
|
||||
val calls = map (fn c => Var (guess_callee_var corres_thm c)) callees;
|
||||
val recursive_calls = map (fn c => Var (guess_callee_var corres_thm c)) recursive_callees;
|
||||
(*
|
||||
* The returned body will have free variables as placeholders for the function's
|
||||
* measure parameter and other arguments, and schematic variables for the functions it calls.
|
||||
*
|
||||
* We modify the body to be of the form:
|
||||
*
|
||||
* %fun1 fun2 rec1 rec2 measure arg1 arg2. f <...>
|
||||
*
|
||||
* That is, all non-recursive calls are abstracted out the front, followed by
|
||||
* recursive calls, followed by the measure variable, followed by function
|
||||
* arguments (none for L1). This is the format expected by define_funcs.
|
||||
*)
|
||||
val abs_body = body
|
||||
|> fold lambda (rev (map Free measure_free))
|
||||
|> fold lambda (rev recursive_calls)
|
||||
|> fold lambda (rev calls);
|
||||
in abs_body end;
|
||||
|
||||
val fn_info' = FunctionInfo.add_phases (fn f => K (Option.map fst (Symtab.lookup l1_callees f))) fn_info;
|
||||
val funcs' = funcs |>
|
||||
map (fn (name, thm, frees) =>
|
||||
(name, (* FIXME: define_funcs needs this currently *)
|
||||
(Thm.generalize ([], map fst frees) (Thm.maxidx_of thm + 1) thm,
|
||||
prepare_fn_body (name, thm, frees), [])));
|
||||
val _ = @{trace} ("L1.define", map (fn (name, (thm, body, _)) => (name, thm, Thm.cterm_of lthy body)) funcs');
|
||||
val (corres_thms, (), lthy') =
|
||||
AutoCorresUtil.define_funcs FunctionInfo.L1 filename fn_info'
|
||||
l1_function_name (K l1_fn_type) get_l1_fn_assumption (K []) @{thm L1corres_recguard_0}
|
||||
lthy (Symtab.map (K snd) l1_callees) ()
|
||||
funcs';
|
||||
val f_names = map (fn (name, _, _) => name) funcs;
|
||||
val new_phases = map2 (fn f_name => fn corres => let
|
||||
val old_phase = FunctionInfo.get_phase_info fn_info FunctionInfo.CP f_name;
|
||||
val def = the (AutoCorresData.get_def (Proof_Context.theory_of lthy') filename "L1def" f_name);
|
||||
val f_const = Utils.get_term lthy' (l1_function_name f_name);
|
||||
in old_phase
|
||||
|> FunctionInfo.phase_info_upd_phase FunctionInfo.L1
|
||||
|> FunctionInfo.phase_info_upd_definition def
|
||||
|> FunctionInfo.phase_info_upd_const f_const
|
||||
|> FunctionInfo.phase_info_upd_mono_thm NONE (* done in translate *)
|
||||
end) f_names corres_thms;
|
||||
(* FIXME: return traces *)
|
||||
in (Symtab.make (f_names ~~ (new_phases ~~ corres_thms)), lthy') end;
|
||||
|
||||
|
||||
(*
|
||||
* Top level translation from SIMPL to a monadic spec.
|
||||
*
|
||||
* We accept a filename (the same filename passed to the C parser; the
|
||||
* parser stashes away important information using this filename as the
|
||||
* key) and a local theory.
|
||||
*
|
||||
* We define a number of new functions (the converted monadic
|
||||
* specifications of the SIMPL functions) and theorems (proving
|
||||
* correspondence between our generated specs and the original SIMPL
|
||||
* code).
|
||||
*)
|
||||
(* FIXME: use AutoCorresUtil.Future instead of default *)
|
||||
fun translate filename prog_info fn_info check_termination do_opt trace_opt l1_function_name lthy =
|
||||
let
|
||||
val funcs_to_translate = Symtab.keys (FunctionInfo.get_all_functions fn_info);
|
||||
|
||||
(* All conversions can run in parallel. *)
|
||||
val converted_funcs =
|
||||
funcs_to_translate |> map (fn f =>
|
||||
(f, Future.fork (fn _ =>
|
||||
convert lthy prog_info fn_info check_termination do_opt trace_opt l1_function_name f)))
|
||||
|> Symtab.make;
|
||||
|
||||
(* Definitions update lthy sequentially.
|
||||
* We use the arbitrary (but deterministic) ordering defined by get_topo_sorted_functions.
|
||||
* Each definition step produces a lthy that has a prefix of the update sequence,
|
||||
* and can be used in an L2 translation that depends only on that prefix.
|
||||
* Hence we return the intermediate lthys as futures. *)
|
||||
fun add_def f_names accum = Future.fork (fn _ => let
|
||||
(* Get output from previous definition.
|
||||
* Technically we don't need all of defined_so_far, but we're guaranteed
|
||||
* to have them at this point already. *)
|
||||
val (lthy, _, defined_so_far) = Future.join accum;
|
||||
(* Wait for conversions to finish *)
|
||||
val f_convs = map (fn f => let
|
||||
val conv = the' ("didn't convert function: " ^ quote f ^ "??") (Symtab.lookup converted_funcs f);
|
||||
val (corres_thm, arg_frees) = Future.join conv;
|
||||
in (f, corres_thm, arg_frees) end) f_names;
|
||||
val (new_defs, lthy') = define lthy filename prog_info fn_info check_termination
|
||||
(defined_so_far: (FunctionInfo.phase_info * thm) Symtab.table)
|
||||
l1_function_name f_convs;
|
||||
in (lthy', new_defs, Symtab.merge (K false) (defined_so_far, new_defs)) end);
|
||||
|
||||
val function_groups = FunctionInfo.get_topo_sorted_functions fn_info;
|
||||
(* Chain of intermediate states: (lthy, new_defs, accumulator) *)
|
||||
val (def_results, _) = Utils.accumulate add_def (Future.value (lthy, Symtab.empty, Symtab.empty)) function_groups;
|
||||
|
||||
(* Produce a mapping from each function group to its L1 phase_infos and the
|
||||
* earliest intermediate lthy where it is defined. *)
|
||||
val results =
|
||||
def_results
|
||||
|> map (Future.map (fn (lthy, f_defs, _ (* discard accum here *)) => let
|
||||
(* Add monad_mono proofs. These are done in parallel as well
|
||||
* (though in practice, they already run pretty quickly). *)
|
||||
val mono_thms = if FunctionInfo.is_function_recursive fn_info (hd (Symtab.keys f_defs))
|
||||
then l1_monad_mono lthy (Symtab.map (K fst) f_defs)
|
||||
else Symtab.empty;
|
||||
val f_defs' = f_defs |> Symtab.map (fn f =>
|
||||
apfst (FunctionInfo.phase_info_upd_mono_thm (Symtab.lookup mono_thms f)));
|
||||
in (lthy, f_defs') end));
|
||||
in
|
||||
(* would be ready to become a symtab, but a list also preserves order *)
|
||||
function_groups ~~ results
|
||||
end
|
||||
|
||||
end
|
|
@ -149,6 +149,10 @@ unsigned opt_a(unsigned m, unsigned n) {
|
|||
return opt_a(m - 1, opt_a(m, n - 1));
|
||||
}
|
||||
|
||||
/* Test for measure_call */
|
||||
unsigned opt_a2(unsigned n) {
|
||||
return opt_a(n, n);
|
||||
}
|
||||
|
||||
|
||||
/*********************
|
||||
|
|
|
@ -22,6 +22,59 @@ install_C_file "type_strengthen.c"
|
|||
For example, suppose that we do not want to lift loops to the option monad: *)
|
||||
declare gets_theE_L2_while [ts_rule option del]
|
||||
|
||||
context type_strengthen begin
|
||||
ML \<open>
|
||||
let val fn_info = FunctionInfo.init_fn_info @{context} "type_strengthen.c"
|
||||
val prog_info = ProgramInfo.get_prog_info @{context} "type_strengthen.c"
|
||||
val (corres1, frees1) = SimplConv2.convert @{context} prog_info fn_info true true false (fn f => "l1_" ^ f) "opt_j";
|
||||
val (corres2, frees2) = SimplConv2.convert @{context} prog_info fn_info true true false (fn f => "l1_" ^ f) "st_i";
|
||||
(*val thm' = Thm.generalize ([], map fst frees) (Thm.maxidx_of thm + 1) thm*)
|
||||
val lthy0 = @{context};
|
||||
val (l1_infos1, lthy1) =
|
||||
SimplConv2.define lthy0 "type_strengthen.c" prog_info fn_info true
|
||||
Symtab.empty
|
||||
(fn f => "l1_" ^ f)
|
||||
[("opt_j", corres1, frees1)]
|
||||
val (l1_infos2, lthy2) =
|
||||
SimplConv2.define lthy1 "type_strengthen.c" prog_info fn_info true
|
||||
l1_infos1
|
||||
(fn f => "l1_" ^ f)
|
||||
[("st_i", corres2, frees2)]
|
||||
in (frees1, corres1, Symtab.dest l1_infos2) end
|
||||
\<close>
|
||||
|
||||
ML \<open>
|
||||
let val filename = "type_strengthen.c";
|
||||
val fn_info = FunctionInfo.init_fn_info @{context} filename;
|
||||
val prog_info = ProgramInfo.get_prog_info @{context} filename;
|
||||
val l1_results =
|
||||
SimplConv2.translate filename prog_info fn_info
|
||||
true true false (fn f => "l1_" ^ f ^ "'") @{context};
|
||||
(*
|
||||
val l2_results =
|
||||
LocalVarExtract2.translate filename prog_info fn_info l1_results
|
||||
true false (fn f => "l2_" ^ f ^ "'");
|
||||
*)
|
||||
in l1_results |> map (snd #> Future.join) |> map (snd #> Symtab.dest) end
|
||||
\<close>
|
||||
|
||||
ML \<open>
|
||||
FunctionInfo.is_function_recursive (FunctionInfo.init_fn_info @{context} "type_strengthen.c") "opt_a"
|
||||
\<close>
|
||||
|
||||
ML \<open>
|
||||
FunctionInfo.get_topo_sorted_functions (FunctionInfo.init_fn_info @{context} "type_strengthen.c")
|
||||
\<close>
|
||||
end
|
||||
|
||||
declare [[ML_print_depth=99]]
|
||||
autocorres [
|
||||
ts_rules = nondet,
|
||||
scope = st_i,
|
||||
skip_heap_abs, skip_word_abs
|
||||
] "type_strengthen.c"
|
||||
|
||||
|
||||
(* We can also specify which monads are used for type strengthening.
|
||||
Here, we exclude the read-only monad completely, and specify
|
||||
rules for some individual functions. *)
|
||||
|
@ -32,6 +85,7 @@ autocorres [
|
|||
] "type_strengthen.c"
|
||||
|
||||
context type_strengthen begin
|
||||
|
||||
(* pure_f (and indirectly, pure_f2) are now lifted to the option monad. *)
|
||||
thm pure_f'_def pure_f2'_def
|
||||
thm pure_g'_def pure_h'_def
|
||||
|
|
|
@ -85,6 +85,13 @@ fun enumerate xs = let
|
|||
fun nubBy _ [] = []
|
||||
| nubBy f (x::xs) = x :: filter (fn y => f x <> f y) (nubBy f xs)
|
||||
|
||||
fun accumulate f acc xs = let
|
||||
fun walk results acc [] = (results [], acc)
|
||||
| walk results acc (x::xs) = let
|
||||
val acc' = f x acc;
|
||||
in walk (results o cons acc') acc' xs end;
|
||||
in walk I acc xs end;
|
||||
|
||||
(* Define a constant "name" of type "term" into the local theory "lthy". *)
|
||||
fun define_const name term lthy =
|
||||
let
|
||||
|
|
Loading…
Reference in New Issue