WIP: autocorres: draft of more modular dependencies for L1, L2

Prototype for Jira VER-517.
This commit is contained in:
Japheth Lim 2016-06-09 17:22:23 +10:00
parent 3400debdc2
commit 2caf6520e5
9 changed files with 3261 additions and 35 deletions

View File

@ -145,6 +145,11 @@ ML_file "monad_convert.ML"
ML_file "type_strengthen.ML"
ML_file "autocorres.ML"
declare [[ML_print_depth=42]]
ML_file "autocorres_util2.ML"
ML_file "simpl_conv2.ML"
ML_file "local_var_extract2.ML"
(* Setup "autocorres" keyword. *)
ML {*
Outer_Syntax.command @{command_keyword "autocorres"}
@ -153,4 +158,82 @@ ML {*
(Toplevel.theory o (fn (opt, filename) => AutoCorres.do_autocorres opt filename)))
*}
ML \<open>
fun assume_called_functions_corres ctxt fn_info callees
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var =
let
(* Assume the existence of a function, along with a theorem about its
* behaviour. *)
fun assume_func ctxt fn_name is_recursive_call =
let
val fn_args = get_fn_args fn_name
(* Fix a variable for the function. *)
val ([fixed_fn_name], ctxt') = Variable.variant_fixes [get_const_name fn_name] ctxt
val fn_free = Free (fixed_fn_name, get_fn_type fn_name)
(* Fix a variable for the measure and function arguments. *)
val (measure_var_name :: arg_names, ctxt'')
= Variable.variant_fixes ("rec_measure'" :: (map fst fn_args)) ctxt'
val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args))
val my_measure_var = Free (measure_var_name, @{typ nat})
(*
* A measure variable is needed to handle recursion: for recursive calls,
* we need to decrement the caller's input measure value (and our
* assumption will need to assume this to). This is so we can later prove
* termination of our function definition: the measure always reaches zero.
*
* Non-recursive calls can have a fresh value.
*)
val measure_var =
if is_recursive_call then
@{const "recguard_dec"} $ callers_measure_var
else
my_measure_var
(* Create our assumption. *)
val assumption =
get_fn_assumption ctxt'' fn_name fn_free fn_arg_terms
is_recursive_call measure_var
|> fold Logic.all (rev ((if is_recursive_call then [] else [my_measure_var]) @ fn_arg_terms))
|> Sign.no_vars ctxt'
|> Thm.cterm_of ctxt'
val ([thm], ctxt''') = Assumption.add_assumes [assumption] ctxt'
(* Generate a morphism for escaping this context. *)
val m = (Assumption.export_morphism ctxt''' ctxt')
$> (Variable.export_morphism ctxt' ctxt)
in
(fn_free, thm, ctxt''', m)
end
(* Apply each assumption. *)
val (res, (ctxt', m)) = fold_map (
fn (fn_name, is_recursive_call) =>
fn (ctxt, m) =>
let
val (free, thm, ctxt', m') =
assume_func ctxt fn_name is_recursive_call
in
((fn_name, (is_recursive_call, free, thm)), (ctxt', m' $> m))
end)
callees (ctxt, Morphism.identity)
in
(ctxt', m, res)
end;
(*
fun assume_called_functions_corres ctxt fn_info callees
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var
*)
assume_called_functions_corres @{context} () [("a", false), ("r", true)]
(K @{typ "nat \<Rightarrow> nat \<Rightarrow> string \<Rightarrow> string"})
(fn ctxt => fn name => fn term => fn args => fn is_rec => fn meas =>
HOLogic.mk_Trueprop (@{term "my_corres :: (string \<Rightarrow> string) \<Rightarrow> bool"} $ betapplys (term, meas :: args)))
(fn f => if f = "a" then [("arg_a", @{typ nat})] else [("arg_r", @{typ nat})])
I @{term "rec_measure :: nat"}
\<close>
end

View File

@ -57,10 +57,36 @@ sig
(* result *)
-> (Proof.context * FunctionInfo.fn_info)
val define_funcs:
FunctionInfo.phase ->
string ->
FunctionInfo.fn_info ->
(string -> string) ->
(string -> typ) ->
(Proof.context -> string -> term -> term list -> bool -> term -> term) ->
(string -> (string * typ) list) ->
thm ->
Proof.context ->
thm Symtab.table ->
'a ->
(string * (thm * term * (string * AutoCorresData.Trace) list)) list ->
thm list * 'a * Proof.context
val map_all : Proof.context -> FunctionInfo.fn_info -> (string -> FunctionInfo.function_info -> 'a) -> 'a list
val concurrent : bool Unsynchronized.ref
val has_simpl_body_def : local_theory -> string -> bool
val max_run_time : Time.time option Unsynchronized.ref
val assume_called_functions_corres :
Proof.context ->
(string * bool) list ->
(string -> typ) ->
(Proof.context -> string -> term -> term list -> bool -> term -> term) ->
(string -> (string * typ) list) ->
(string -> string) ->
term ->
Proof.context * morphism * (string * (bool * term * thm)) list
val get_callees : FunctionInfo.fn_info -> string -> string list * string list (* (nonrecs, recs) *)
end;
structure AutoCorresUtil : AUTOCORRES_UTIL =
@ -262,14 +288,32 @@ fun is_Trueprop (Const (@{const_name "Trueprop"}, _) $ _) = true
| is_Trueprop _ = false
(*
* Assume the existence of the given list of functions.
* Assume theorems for called functions.
*
* A new context is returned with the assumptions in it, along with a morphism
* used for exporting the theorems out, and a list of the functions assumed:
*
* (<function name>, (<is_mutually_recursive>, <function free>, <function thm>))
* (<function name>, (<is_mutually_recursive>, <function free>, <arg frees>, <function thm>))
*
* In this context, the theorems refer to functions by fixed free variables.
*
* get_fn_args may return user-friendly argument names that clash with other names.
* We will process these names to avoid conflicts.
*
* get_fn_assumption should produce the desired theorems to assume. Its arguments:
* context (with fixed vars), callee name, callee term, arg terms, is recursive, measure term
* (all terms are fixed free vars).
*
* get_const_name generates names for the free function placeholders.
* FIXME: probably unnecessary and/or broken.
*
* We return two morphisms:
* - the first one makes the assumptions visible again
* - the second one automatically generalizes the assumed constants
* (this exists for backwards compat; all new code should explicitly use the
* returned free variable set)
*)
fun assume_called_functions_corres ctxt fn_info callees
fun assume_called_functions_corres ctxt callees
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var =
let
(* Assume the existence of a function, along with a theorem about its
@ -354,7 +398,7 @@ fun gen_corres_for_function
(ctxt : Proof.context)
(fn_name : string) =
let
val _ = writeln ("Converting (" ^ FunctionInfo.string_of_phase phase ^") " ^ fn_name)
val _ = writeln ("Converting (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name)
val start_time = Timer.startRealTimer ()
(* Get a list of functions we call. *)
@ -364,21 +408,30 @@ let
(map (fn x => (x, false)) normal_calls)
@ (map (fn x => (x, true)) recursive_calls)
(* Make sure the desired function name is available. *)
val fn_target_name = get_const_name fn_name
val ([fn_free], ctxt') = Variable.variant_fixes [fn_target_name] ctxt
val _ = if fn_free = fn_target_name then () else
warning ("Variable clobbered: " ^ fn_target_name ^ " -> " ^ fn_free ^ ". Translating " ^
fn_name ^ " may fail.")
val fn_var_morph = Variable.export_morphism ctxt' ctxt
(* Fix a measure variable that will be used to track recursion progress. *)
val ([measure_var_name], ctxt') = Variable.variant_fixes ["rec_measure'"] ctxt
val ([measure_var_name], ctxt'') = Variable.variant_fixes ["rec_measure'"] ctxt'
val measure_var = Free (measure_var_name, @{typ nat})
val measure_var_morph = Variable.export_morphism ctxt' ctxt
val measure_var_morph = Variable.export_morphism ctxt'' ctxt'
(* Fix variables for function arguments. *)
val fn_args = get_fn_args fn_name
val (arg_names, ctxt'')
= Variable.variant_fixes (map (fn (a, _) => a ^ "'arg") fn_args) ctxt'
val (arg_names, ctxt''')
= Variable.variant_fixes (map (fn (a, _) => a ^ "'arg") fn_args) ctxt''
val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args))
val fn_args_morph = Variable.export_morphism ctxt'' ctxt'
val fn_args_morph = Variable.export_morphism ctxt''' ctxt''
val _ = @{trace} ("Vars", fn_free, measure_var_name, arg_names)
(* Enter a context where we assume our callees exist. *)
val (ctxt''', m, callee_info_and_proofs)
= assume_called_functions_corres ctxt'' fn_info callees
val (ctxt'''', m, callee_info_and_proofs)
= assume_called_functions_corres ctxt''' callees
get_fn_type get_fn_assumption get_fn_args get_const_name
measure_var
@ -387,7 +440,7 @@ let
* term and a tactic for proving correspondence.
*)
val callee_tab = Symtab.make callee_info_and_proofs
val (body, thm, trace) = convert ctxt''' fn_name callee_tab measure_var fn_arg_terms
val (body, thm, trace) = convert ctxt'''' fn_name callee_tab measure_var fn_arg_terms
(*
* The returned body will have free variables as placeholders for the function's
@ -408,7 +461,7 @@ let
|> fold lambda (rev (map (fn x => Symtab.lookup callee_tab x |> the |> #2) normal_calls))
(* Export the theorem out of our context. *)
val exported_thm = Morphism.thm (m $> fn_args_morph $> measure_var_morph) thm
val exported_thm = Morphism.thm (m $> fn_args_morph $> measure_var_morph $> fn_var_morph) thm
(* TODO: allow this message to be configured *)
val _ = @{trace} ("Converted (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name ^ " in " ^
@ -423,15 +476,15 @@ end
* "callee_thms" contains a table mapping function names to complete
* corres proofs for those functions.
*
* "functions" contains a list of (fn_name, (proof, callees)). We
* assume that all functions in this list are mutually recursive. (If
* not, you should call "define_funcs" multiple times, each
* "functions" contains a list of (fn_name, (proof, body, proof traces)).
* The body should be of the form generated by gen_corres_for_function,
* with lambda abstractions for all callees and arguments.
*
* We assume that all functions in this list are mutually recursive.
* (If not, you should call "define_funcs" multiple times, each
* time with a single function.)
*
* This code is quite complex in order to support mutual recursion,
* where function definitions and proofs must simultaneously take place
* for several functions: if we were only supporting non-recursive
* functions, life would be easier.
* The proof traces are stored into the theory. (This should probably be moved.)
*)
fun define_funcs
(phase : FunctionInfo.phase)
@ -444,7 +497,7 @@ fun define_funcs
(rec_base_case : thm)
(ctxt : Proof.context)
(callee_thms : thm Symtab.table)
_
accum (* translate allows an accumulator, but we don't use it here *)
(functions : (string * (thm * term * (string * AutoCorresData.Trace) list)) list)
=
let
@ -455,6 +508,8 @@ fun define_funcs
val _ = writeln ("Defining (" ^ FunctionInfo.string_of_phase phase ^ ") " ^
(Utils.commas (map get_const_name fn_names)))
val _ = @{trace} ("define_funcs function(s)", functions)
val _ = @{trace} ("define_funcs callee(s)", Symtab.dest callee_thms)
(*
* Determine if we are in a recursive case by checking to see if the
@ -495,10 +550,14 @@ fun define_funcs
* Mutually recursive calls should be of the form "Free (fn_name, fn_type)".
*)
val defs = map (
fn (fn_name, fn_body) =>
(get_const_name fn_name,
("rec_measure'", @{typ nat}) :: get_fn_args fn_name,
fill_body fn_name fn_body))
fn (fn_name, fn_body) => let
val fn_args = get_fn_args fn_name
(* FIXME: this retraces assume_called_functions_corres *)
val (fn_free :: measure_free :: arg_frees, _) = Variable.variant_fixes
(get_const_name fn_name :: "rec_measure'" :: map fst fn_args) ctxt
in (get_const_name fn_name, (* be inflexible when it comes to fn_name *)
(measure_free, @{typ nat}) :: (arg_frees ~~ map snd fn_args), (* changing arg names is ok *)
fill_body fn_name fn_body) end)
(fn_names ~~ fn_bodies)
val (fn_def_thms, ctxt) = Utils.define_functions defs true is_recursive ctxt
@ -560,12 +619,12 @@ fun define_funcs
(* Prove each of the predicates above, leaving any assumptions about called
* functions unsolved. *)
val pred_thms = map (
fn (pred, thm, body_def) =>
fn (pred, thm, body_def) => (@{trace} ("define_funcs applying rule", thm);
Thm.trivial (Thm.cterm_of ctxt' pred)
|> Utils.apply_tac "unfold body" (Hypsubst.stac ctxt' body_def 1)
|> Utils.apply_tac "apply rule" (resolve_tac ctxt' [thm] 1)
|> Goal.norm_result ctxt
|> singleton (Variable.export ctxt' ctxt)
|> singleton (Variable.export ctxt' ctxt))
)
(Utils.zip3 preds fn_thms fn_def_thms)
@ -630,7 +689,7 @@ fun define_funcs
(fold (fn (phase, fn_name, trace) =>
AutoCorresData.add_trace filename phase fn_name trace) fn_traces) ctxt
in
(new_thms, (), ctxt)
(new_thms, accum, ctxt)
end
(*

View File

@ -0,0 +1,664 @@
(*
* Copyright 2014, NICTA
*
* This software may be distributed and modified according to the terms of
* the BSD 2-Clause license. Note that NO WARRANTY is provided.
* See "LICENSE_BSD2.txt" for details.
*
* @TAG(NICTA_BSD)
*)
(*
* Common code for all translation phases: defining funcs, calculating dependencies, variable fixes, etc.
*)
structure AutoCorresUtil2 =
struct
(*
* Maximum time to let an individual function translation phase to run for.
*
* Note that this is wall time, and not CPU time, so it is a very rough
* tool.
* FIXME: convert to proper option
*)
val max_run_time = Unsynchronized.ref NONE
(*
val max_run_time = Unsynchronized.ref (SOME (seconds 900.0))
*)
exception AutocorresTimeout of string list
fun time_limit f v =
case !max_run_time of
SOME t =>
(TimeLimit.timeLimit t f ()
handle TimeLimit.TimeOut =>
raise AutocorresTimeout v)
| NONE =>
f ()
(* Should we use concurrency? *)
val concurrent = Unsynchronized.ref true;
(*
* Conditionally fork a group of tasks, depending on the value
* of "concurrent".
*)
datatype 'a maybe_fork = Future of 'a future | Boring of 'a
(* Fork a group of tasks. *)
fun maybe_fork ctxt vals =
if ((!concurrent) andalso not (Config.get ctxt ML_Options.exception_trace)) then
map Future (Future.forks {
name = "", group = NONE, deps = [], pri = ~1, interrupts = true}
vals)
else
map (fn x => Boring (x ())) vals
(* Ensure a forked task has completed. *)
fun maybe_join v =
case v of
Boring x => x
| Future x => Future.join x
(* Functional map. *)
fun par_map ctxt f a =
if ((!concurrent) andalso not (Config.get ctxt ML_Options.exception_trace)) then
Par_List.map f a
else
map f a
(* Does a SIMPL body exist for the given function name? *)
fun has_simpl_body_def lthy name =
try (fn name => Proof_Context.get_thm lthy (name ^ "_body_def")) name
|> is_some
(*
* A translation step transforms the program from one form to another;
* such as from SIMPL to a monadic type, or from one type of monad to
* another.
*
* "filename" is the name of the file we are translating: this is required as a
* key to fetch data stashed away by the C parser.
*
* "lthy" is the local theory.
*
* "convert" performs any proof work required for this translation step. All
* conversions are performed in parallel, so must be able to be completed
* without results from previous steps.
*
* "define" actually sets up any definitions required by the translation;
* definition steps occur serially, but may be in parallel with conversion
* steps whose results are not yet required.
*
* "prove_mono" should prove the monad_mono property for recursive functions.
* (* It is run in parallel over all recursive groups. *)
*
* Because functions handed to us by the C parser may be mutually recursive and
* such mutually recursive functions must typically be defined simultaneously,
* "define" is handed a list of functions which must all be defined in one
* step.
*)
fun translate lthy phase fn_info initial_callees convert define gen_new_info prove_mono v =
let
(* Get list of functions we need to translate.
* This is a bit complicated because we need to skip over functions that
* have already been translated, and hence we need to recalculate the
* function call graph. *)
val functions_to_translate =
Symtab.dest (FunctionInfo.get_all_functions fn_info)
|> map_filter (fn (name, info) =>
case FunctionInfo.Phasetab.lookup (#phases info) phase of
NONE => SOME name
| SOME _ => NONE)
|> Symset.make
val fn_info_restricted = FunctionInfo.map_fn_info (fn info =>
if Symset.contains functions_to_translate (#name info) then SOME info else NONE) fn_info
val function_groups = FunctionInfo.get_topo_sorted_functions fn_info_restricted
val all_functions = List.concat function_groups
(*
* Convert every function.
*
* We perform the conversions using futures, which are run in parallel.
* This allows us to perform conversions while we start defining functions,
* hopefully speeding everything up on multicore systems.
*)
val converted_body_thms =
map (fn name => fn _ =>
time_limit (fn _ => convert lthy name) [name]) all_functions
|> maybe_fork lthy
val converted_bodies = Symtab.make (all_functions ~~ converted_body_thms)
(* In sorted order, define constants and proofs for the functions. *)
fun translate fn_names (callee_thms, new_phase_infos, v, lthy) =
let
val defs = map (fn fn_name =>
Symtab.lookup converted_bodies fn_name |> the |> maybe_join) fn_names
val (proofs, v, lthy)
= time_limit (fn _ =>
define lthy callee_thms v (fn_names ~~ defs)) fn_names
val new_callee_thms = fold Symtab.update_new
(fn_names ~~ proofs) callee_thms
val new_phase_infos = fold (fn n =>
Symtab.update_new (n, gen_new_info lthy v (FunctionInfo.get_function_info fn_info n)))
fn_names new_phase_infos
in
(new_callee_thms, new_phase_infos, v, lthy)
end
val (proof_table, new_phase_infos, v, lthy)
= fold translate function_groups (initial_callees, Symtab.empty, v, lthy)
val mono_thms =
function_groups
|> map (fn funcs => if not (FunctionInfo.is_function_recursive fn_info (hd funcs))
then K Symtab.empty else (fn _ => time_limit (fn _ =>
(List.mapPartial (fn f =>
case Symtab.lookup new_phase_infos f of
SOME phase_info =>
SOME (FunctionInfo.function_info_add_phase phase_info
(FunctionInfo.get_function_info fn_info f))
| _ => NONE) funcs
|> prove_mono lthy)) funcs))
|> maybe_fork lthy |> map maybe_join
|> maps Symtab.dest |> Symtab.make
val new_phase_infos = new_phase_infos |>
Symtab.map (fn func => FunctionInfo.phase_info_upd_mono_thm (Symtab.lookup mono_thms func))
in
(lthy, FunctionInfo.add_phases (fn name => K (Symtab.lookup new_phase_infos name)) fn_info, v)
end
(*
* A translation step that maps over every function in the program.
*
* "convert" performs any proof work required for this translation step. All
* conversions are performed in parallel, so must be able to be completed
* without results from previous steps.
*
* We return a list of all results.
*)
fun map_all ctxt fn_info convert =
par_map ctxt (uncurry convert) (FunctionInfo.get_all_functions fn_info |> Symtab.dest)
(*
* Get functions called by a particular function.
*
* We split the result into standard calls and recursive calls (i.e., calls
* which may recursively call back into us).
*)
fun get_callees fn_info fn_name =
let
(* Get a list of functions we call. *)
val all_callees = FunctionInfo.get_function_callees fn_info fn_name
(* Fetch calls that may recursively call back to us. *)
val recursive_calls = FunctionInfo.get_recursive_group fn_info fn_name
(* Remove "recursive_calls" from the standard callee set. *)
val callees =
Symset.make all_callees
|> Symset.subtract (Symset.make recursive_calls)
|> Symset.dest
in
(callees, recursive_calls)
end
(* Is the given term a Trueprop? *)
fun is_Trueprop (Const (@{const_name "Trueprop"}, _) $ _) = true
| is_Trueprop _ = false
(*
* Assume theorems for called functions.
*
* A new context is returned with the assumptions in it, along with a morphism
* used for exporting the theorems out, and a list of the functions assumed:
*
* (<function name>, (<is_mutually_recursive>, <function free>, <arg frees>, <function thm>))
*
* In this context, the theorems refer to functions by fixed free variables.
*
* get_fn_args may return user-friendly argument names that clash with other names.
* We will process these names to avoid conflicts.
*
* get_fn_assumption should produce the desired theorems to assume. Its arguments:
* context (with fixed vars), callee name, callee term, arg terms, is recursive, measure term
* (all terms are fixed free vars).
*
* get_const_name generates names for the free function placeholders.
* FIXME: probably unnecessary and/or broken.
*
* We return two morphisms:
* - the first one makes the assumptions visible again
* - the second one automatically generalizes the assumed constants
* (this exists for backwards compat; all new code should explicitly use the
* returned free variable set)
*)
fun assume_called_functions_corres ctxt callees
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var =
let
(* Assume the existence of a function, along with a theorem about its
* behaviour. *)
fun assume_func ctxt fn_name is_recursive_call =
let
val fn_args = get_fn_args fn_name
(* Fix a variable for the function. *)
val ([fixed_fn_name], ctxt') = Variable.variant_fixes [get_const_name fn_name] ctxt
val fn_free = Free (fixed_fn_name, get_fn_type fn_name)
(* Fix a variable for the measure and function arguments. *)
val (measure_var_name :: arg_names, ctxt'')
= Variable.variant_fixes ("rec_measure'" :: (map fst fn_args)) ctxt'
val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args))
val my_measure_var = Free (measure_var_name, @{typ nat})
(*
* A measure variable is needed to handle recursion: for recursive calls,
* we need to decrement the caller's input measure value (and our
* assumption will need to assume this to). This is so we can later prove
* termination of our function definition: the measure always reaches zero.
*
* Non-recursive calls can have a fresh value.
*)
val measure_var =
if is_recursive_call then
@{const "recguard_dec"} $ callers_measure_var
else
my_measure_var
(* Create our assumption. *)
val assumption =
get_fn_assumption ctxt'' fn_name fn_free fn_arg_terms
is_recursive_call measure_var
|> fold Logic.all (rev ((if is_recursive_call then [] else [my_measure_var]) @ fn_arg_terms))
|> Sign.no_vars ctxt'
|> Thm.cterm_of ctxt'
val ([thm], ctxt''') = Assumption.add_assumes [assumption] ctxt'
(* Generate a morphism for escaping this context. *)
val m = (Assumption.export_morphism ctxt''' ctxt')
$> (Variable.export_morphism ctxt' ctxt)
in
(fn_free, thm, ctxt''', m)
end
(* Apply each assumption. *)
val (res, (ctxt', m)) = fold_map (
fn (fn_name, is_recursive_call) =>
fn (ctxt, m) =>
let
val (free, thm, ctxt', m') =
assume_func ctxt fn_name is_recursive_call
in
((fn_name, (is_recursive_call, free, thm)), (ctxt', m' $> m))
end)
callees (ctxt, Morphism.identity)
in
(ctxt', m, res)
end
(*
* Convert a single function.
*
* Given a single concrete function, abstract that function and
* return a theorem that shows the correspondence.
*
* A theorem is returned which has assumptions that called functions
* correspond, giving a goal that this given function corresponds.
*)
fun gen_corres_for_function
(phase : FunctionInfo.phase)
(fn_info : FunctionInfo.fn_info)
(get_fn_type : string -> typ)
(get_fn_assumption : Proof.context -> string -> term -> term list -> bool -> term -> term)
(get_fn_args : string -> (string * typ) list)
(get_const_name : string -> string)
(convert : Proof.context -> string -> ((bool * term * thm) Symtab.table) ->
term -> term list -> (term * thm * (string * AutoCorresData.Trace) list))
(ctxt : Proof.context)
(fn_name : string) =
let
val _ = writeln ("Converting (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name)
val start_time = Timer.startRealTimer ()
(* Get a list of functions we call. *)
val (normal_calls, recursive_calls)
= get_callees fn_info fn_name
val callees =
(map (fn x => (x, false)) normal_calls)
@ (map (fn x => (x, true)) recursive_calls)
(* Make sure the desired function name is available. *)
val fn_target_name = get_const_name fn_name
val ([fn_free], ctxt') = Variable.variant_fixes [fn_target_name] ctxt
val _ = if fn_free = fn_target_name then () else
warning ("Variable clobbered: " ^ fn_target_name ^ " -> " ^ fn_free ^ ". Translating " ^
fn_name ^ " may fail.")
val fn_var_morph = Variable.export_morphism ctxt' ctxt
(* Fix a measure variable that will be used to track recursion progress. *)
val ([measure_var_name], ctxt'') = Variable.variant_fixes ["rec_measure'"] ctxt'
val measure_var = Free (measure_var_name, @{typ nat})
val measure_var_morph = Variable.export_morphism ctxt'' ctxt'
(* Fix variables for function arguments. *)
val fn_args = get_fn_args fn_name
val (arg_names, ctxt''')
= Variable.variant_fixes (map (fn (a, _) => a ^ "'arg") fn_args) ctxt''
val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args))
val fn_args_morph = Variable.export_morphism ctxt''' ctxt''
val _ = @{trace} ("Vars", fn_free, measure_var_name, arg_names)
(* Enter a context where we assume our callees exist. *)
val (ctxt'''', m, callee_info_and_proofs)
= assume_called_functions_corres ctxt''' callees
get_fn_type get_fn_assumption get_fn_args get_const_name
measure_var
(*
* Do the conversion. We receive a new monadic version of the SIMPL
* term and a tactic for proving correspondence.
*)
val callee_tab = Symtab.make callee_info_and_proofs
val (body, thm, trace) = convert ctxt'''' fn_name callee_tab measure_var fn_arg_terms
(*
* The returned body will have free variables as placeholders for the function's
* input parameters, for the functions it calls, and for its measure variable.
*
* We modify the body to be of the form:
*
* %fun1 fun2 rec1 rec2 measure arg1 arg2. f <...>
*
* That is, all non-recursive calls are abstracted out the front, followed by
* recursive calls, followed by the measure variable, followed by function
* arguments.
*)
val body =
fold lambda (rev fn_arg_terms) body
|> lambda measure_var
|> fold lambda (rev (map (fn x => Symtab.lookup callee_tab x |> the |> #2) recursive_calls))
|> fold lambda (rev (map (fn x => Symtab.lookup callee_tab x |> the |> #2) normal_calls))
(* Export the theorem out of our context. *)
val exported_thm = Morphism.thm (m $> fn_args_morph $> measure_var_morph $> fn_var_morph) thm
(* TODO: allow this message to be configured *)
val _ = @{trace} ("Converted (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name ^ " in " ^
Time.toString (Timer.checkRealTimer start_time) ^ " s")
in
(exported_thm, body, trace)
end
(*
* Given a SIMPL function, define a constant and a proof for it.
*
* "callee_thms" contains a table mapping function names to complete
* corres proofs for those functions.
*
* "functions" contains a list of (fn_name, (proof, body, proof traces)).
* The body should be of the form generated by gen_corres_for_function,
* with lambda abstractions for all callees and arguments.
*
* We assume that all functions in this list are mutually recursive.
* (If not, you should call "define_funcs" multiple times, each
* time with a single function.)
*
* The proof traces are stored into the theory. (This should probably be moved.)
*)
fun define_funcs
(phase : FunctionInfo.phase)
(filename : string)
(fn_info : FunctionInfo.fn_info)
(get_const_name : string -> string)
(get_fn_type : string -> typ)
(get_fn_assumption : Proof.context -> string -> term -> term list -> bool -> term -> term)
(get_fn_args : string -> (string * typ) list)
(rec_base_case : thm)
(ctxt : Proof.context)
(callee_thms : thm Symtab.table)
accum (* translate allows an accumulator, but we don't use it here *)
(functions : (string * (thm * term * (string * AutoCorresData.Trace) list)) list)
=
let
val fn_names = map fst functions
val fn_thms = map (snd #> #1) functions
val fn_bodies = map (snd #> #2) functions
val fn_traces = map (fn (fn_name, (_, _, traces)) => map (fn (module, trace) => (module, fn_name, trace)) traces) functions |> List.concat
val _ = writeln ("Defining (" ^ FunctionInfo.string_of_phase phase ^ ") " ^
(Utils.commas (map get_const_name fn_names)))
val _ = @{trace} ("function(s)", functions)
val _ = @{trace} ("callee(s)", Symtab.dest callee_thms)
(*
* Determine if we are in a recursive case by checking to see if the
* first function in our list makes recursive calls to any other
* function. (This "other function" will be itself if it is simple
* recursion, but may be a different function if we are mutually
* recursive.)
*)
val is_recursive = FunctionInfo.is_function_recursive fn_info (hd fn_names)
val _ = assert (length fn_names = 1 orelse is_recursive)
"define_funcs passed multiple functions, but they don't appear to be recursive."
(*
* Patch in functions into our function body in the following order:
*
* * Non-recursive calls;
* * Recursive calls
*)
fun fill_body fn_name body =
let
val (normal_calls, recursive_calls)
= get_callees fn_info fn_name
val non_rec_calls = map (fn x => Utils.get_term ctxt (get_const_name x)) normal_calls
val rec_calls = map (fn x => Free (get_const_name x, get_fn_type x)) recursive_calls
in
body
|> (fn t => betapplys (t, non_rec_calls))
|> (fn t => betapplys (t, rec_calls))
end
(*
* Define our functions.
*
* Definitions should be of the form:
*
* %arg1 arg2 arg3. (arg1 + arg2 + arg3)
*
* Mutually recursive calls should be of the form "Free (fn_name, fn_type)".
*)
val defs = map (
fn (fn_name, fn_body) => let
val fn_args = get_fn_args fn_name
(* FIXME: this retraces assume_called_functions_corres *)
val (fn_free :: measure_free :: arg_frees, _) = Variable.variant_fixes
(get_const_name fn_name :: "rec_measure'" :: map fst fn_args) ctxt
in (get_const_name fn_name, (* be inflexible when it comes to fn_name *)
(measure_free, @{typ nat}) :: (arg_frees ~~ map snd fn_args), (* changing arg names is ok *)
fill_body fn_name fn_body) end)
(fn_names ~~ fn_bodies)
val (fn_def_thms, ctxt) = Utils.define_functions defs true is_recursive ctxt
(* Record the constant in our theory data. *)
val ctxt = fold (fn (fn_name, def) =>
Local_Theory.background_theory (
AutoCorresData.add_def filename (FunctionInfo.string_of_phase phase ^ "def") fn_name def))
(Utils.zip fn_names fn_def_thms) ctxt
(*
* Instantiate schematic function calls in our theorems with their
* concrete definitions.
*)
val combined_callees = map (get_callees fn_info) (map fst functions)
val combined_normal_calls =
map fst combined_callees |> flat |> sort_distinct fast_string_ord
val combined_recursive_calls =
map snd combined_callees |> flat |> sort_distinct fast_string_ord
val callee_terms =
(combined_recursive_calls @ combined_normal_calls)
|> map (fn x => (get_const_name x, Utils.get_term ctxt (get_const_name x)))
|> Symtab.make
fun fill_proofs thm =
Utils.instantiate_thm_vars ctxt
(fn ((name, _), _) =>
Symtab.lookup callee_terms name
|> Option.map (Thm.cterm_of ctxt)) thm
val fn_thms = map fill_proofs fn_thms
(* Fix free variable for the measure. *)
val ([measure_var_name], ctxt') = Variable.variant_fixes ["m"] ctxt
val measure_var = Free (measure_var_name, @{typ nat})
(* Generate corres predicates for each function. *)
val preds = map (
fn fn_name =>
let
fun mk_forall v t = HOLogic.all_const (Term.fastype_of v) $ lambda v t
val fn_const = Utils.get_term ctxt' (get_const_name fn_name)
(* Fetch parameters to this function. *)
val free_params =
get_fn_args fn_name
|> Variable.variant_frees ctxt' [measure_var]
|> map Free
in
(* Generate the prop. *)
get_fn_assumption ctxt' fn_name fn_const
free_params is_recursive measure_var
|> fold Logic.all (rev free_params)
end) fn_names
(* We generate a goal which solves all the mutually recursive calls simultaneously. *)
val goal = map (Object_Logic.atomize_term ctxt') preds
|> Utils.mk_conj_list
|> HOLogic.mk_Trueprop
|> Thm.cterm_of ctxt'
(* Prove each of the predicates above, leaving any assumptions about called
* functions unsolved. *)
val pred_thms = map (
fn (pred, thm, body_def) => (@{trace} thm;
Thm.trivial (Thm.cterm_of ctxt' pred)
|> Utils.apply_tac "unfold body" (Hypsubst.stac ctxt' body_def 1)
|> Utils.apply_tac "apply rule" (resolve_tac ctxt' [thm] 1)
|> Goal.norm_result ctxt
|> singleton (Variable.export ctxt' ctxt))
)
(Utils.zip3 preds fn_thms fn_def_thms)
(* Create a set of "helper theorems", which should be sufficient to discharge
* all assumptions that our callees refine. *)
val helper_thms =
(map (Symtab.lookup callee_thms #> the) combined_normal_calls) @ pred_thms
|> map (Drule.forall_intr_vars)
|> map (Conv.fconv_rule (Object_Logic.atomize ctxt))
(* Generate a proof term of equivalence using the folded definitions. *)
val new_thm =
Goal.init goal
|> (fn thm =>
if is_recursive then (
Utils.apply_tac "start induction"
(resolve_tac ctxt'
[Utils.named_cterm_instantiate ctxt'
[("n", Thm.cterm_of ctxt' measure_var)] @{thm recguard_induct}]
1) thm
|> Utils.apply_tac "unfold bodies"
(EVERY (map (fn x => (EqSubst.eqsubst_tac ctxt' [1] [x] 1)) (rev fn_def_thms)))
|> Utils.apply_tac "solve induction base cases"
(SOLVES ((simp_tac (put_simpset HOL_ss ctxt' addsimps [rec_base_case]) 1)))
|> Utils.apply_tac "solve remaing goals"
(Utils.metis_insert_tac ctxt helper_thms 1)
) else (
Utils.apply_tac "solve remaing goals"
(Utils.metis_insert_tac ctxt helper_thms 1) thm
))
|> Goal.finish ctxt'
(*
* The proof above is of the form (L1corres a & L1corres b & ...).
* Split it up into several proofs.
*)
fun prove_partial_l1_corres thm pred =
Thm.cterm_of ctxt' pred
|> Goal.init
|> Utils.apply_tac "solving using metis" (Utils.metis_tac ctxt [thm] 1)
|> Goal.finish ctxt'
(* Generate the final theorems. *)
val new_thms =
map (prove_partial_l1_corres new_thm) preds
|> (Variable.export ctxt' ctxt)
|> map (Goal.norm_result ctxt)
(* Record the theorems in our theory data. *)
val ctxt = fold (fn (fn_name, thm) =>
Local_Theory.background_theory
(AutoCorresData.add_thm filename (FunctionInfo.string_of_phase phase ^ "corres") fn_name thm))
(fn_names ~~ new_thms) ctxt
(* Add the theorems to the context. *)
val ctxt = fold (fn (fn_name, thm) =>
Utils.define_lemma (fn_name ^ "_" ^ FunctionInfo.string_of_phase phase ^ "corres") thm #> snd)
(fn_names ~~ new_thms) ctxt
(* Add the traces to the context. *)
val ctxt = Local_Theory.background_theory
(fold (fn (phase, fn_name, trace) =>
AutoCorresData.add_trace filename phase fn_name trace) fn_traces) ctxt
in
(new_thms, accum, ctxt)
end
(*
* Do a translation phase, converting every function from one form to another.
*)
fun do_translation_phase
(phase : FunctionInfo.phase)
(filename : string)
(prog_info : ProgramInfo.prog_info)
(fn_info : FunctionInfo.fn_info)
(get_fn_type : string -> typ)
(get_fn_assumption : local_theory -> string -> term -> term list -> bool -> term -> term)
(get_fn_args : string -> (string * typ) list)
(get_const_name : string -> string)
(convert : local_theory -> string -> ((bool * term * thm) Symtab.table) ->
term -> term list -> (term * thm * (string * AutoCorresData.Trace) list))
(gen_new_info : local_theory -> FunctionInfo.function_info -> FunctionInfo.phase_info)
(prove_mono : local_theory -> FunctionInfo.function_info list -> thm Symtab.table)
(rec_base_case : thm)
(ctxt : Proof.context) =
let
val do_gen_corres =
gen_corres_for_function phase fn_info get_fn_type get_fn_assumption
get_fn_args get_const_name convert;
val do_define_funcs =
define_funcs phase filename fn_info get_const_name get_fn_type
get_fn_assumption get_fn_args rec_base_case
(* Lookup functions that have already been translated (i.e. phase exists) *)
val initial_callees = Symtab.dest (FunctionInfo.get_all_functions fn_info)
|> List.mapPartial (fn (fn_name, info) =>
FunctionInfo.Phasetab.lookup (#phases info) phase
|> Option.mapPartial (fn phase_info =>
AutoCorresData.get_thm (Proof_Context.theory_of ctxt) filename
(FunctionInfo.string_of_phase phase ^ "corres") fn_name)
|> Option.map (fn thm => (fn_name, thm)))
|> Symtab.make
(* Do the translation. *)
val (ctxt', new_fn_info, _) =
translate ctxt phase fn_info initial_callees do_gen_corres do_define_funcs
(fn lthy => K (gen_new_info lthy)) prove_mono ()
(* Map function information. *)
in
(ctxt', new_fn_info)
end
end

View File

@ -25,21 +25,29 @@ sig
(* Function info for a single phase. *)
type phase_info = {
phase : phase,
args : (string * typ) list,
return_type : typ,
const : term,
raw_const : term,
(*
callees : string list,
rec_callees : string list,
*)
definition : thm,
mono_thm : thm option,
phase : phase
mono_thm : thm option
};
val phase_info_upd_phase : phase -> phase_info -> phase_info;
val phase_info_upd_args : (string * typ) list -> phase_info -> phase_info;
val phase_info_upd_return_type : typ -> phase_info -> phase_info;
(* also updates raw_const *)
val phase_info_upd_const : term -> phase_info -> phase_info;
(*
val phase_info_upd_callees : string list -> phase_info -> phase_info;
val phase_info_upd_rec_callees : string list -> phase_info -> phase_info;
*)
val phase_info_upd_definition : thm -> phase_info -> phase_info;
val phase_info_upd_mono_thm : thm option -> phase_info -> phase_info;
val phase_info_upd_phase : phase -> phase_info -> phase_info;
(* Function info for a single function. *)
type function_info = {
@ -101,6 +109,9 @@ structure Phasetab = Table(
val ord = phase_ord);
type phase_info = {
(* The translation phase for this definition. *)
phase : phase,
(* Arguments of the function, in order, excluding measure variables. *)
args : (string * typ) list,
@ -120,10 +131,7 @@ type phase_info = {
definition : thm,
(* monad_mono theorem for the function, if it is recursive. *)
mono_thm : thm option,
(* The translation phase for this definition. *)
phase : phase
mono_thm : thm option
};
type function_info = {

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,630 @@
(*
* Copyright 2014, NICTA
*
* This software may be distributed and modified according to the terms of
* the BSD 2-Clause license. Note that NO WARRANTY is provided.
* See "LICENSE_BSD2.txt" for details.
*
* @TAG(NICTA_BSD)
*)
(*
* Automatically convert SIMPL code fragments into a monadic form, with proofs
* of correspondence between the two.
*)
structure SimplConv2 =
struct
(* Convenience shortcuts. *)
val warning = Utils.ac_warning
val apply_tac = Utils.apply_tac
val the' = Utils.the'
exception FunctionNotFound of string
val simpl_conv_ss = AUTOCORRES_SIMPSET
(*
* Given a function constant name such as "Blah.foo_'proc", guess the underlying
* function name "foo".
*)
fun guess_function_name const_name =
const_name |> unsuffix "_'proc" |> Long_Name.base_name
(* Generate a L1 monad type. *)
fun mk_l1monadT stateT =
Utils.gen_typ @{typ "'a L1_monad"} [stateT]
(*
* Extract the L1 monadic term out of a L1corres constant.
*)
fun get_L1corres_monad @{term_pat "L1corres _ _ ?l1_monad _"} = l1_monad
| get_L1corres_monad t = raise TERM ("get_L1corres_monad", [t])
(*
* Generate a SIMPL term that calls the given function.
*
* For instance, we might return:
*
* "Call foo_'proc"
*)
fun mk_SIMPL_call_term ctxt prog_info fn_info target_fn =
@{mk_term "Call ?proc :: (?'s, int, strictc_errortype) com" (proc, 's)}
(FunctionInfo.get_phase_info fn_info FunctionInfo.CP target_fn |> #const, #state_type prog_info)
(*
* Construct a correspondence lemma between a given monadic term and a SIMPL fragment.
*
* The term is of the form:
*
* L1corres check_termination \<Gamma> monad simpl
*)
fun mk_L1corres_prop prog_info check_termination monad_term simpl_term =
@{mk_term "L1corres ?ct ?gamma ?monad ?simpl" (ct, gamma, monad, simpl)}
(Utils.mk_bool check_termination, #gamma prog_info, monad_term, simpl_term)
(*
* Construct a prop claiming that the given term is equivalent to
* a call to the given SIMPL function:
*
* L1corres ct \<Gamma> <term> (Call foo_'proc)
*
*)
fun mk_L1corres_call_prop ctxt prog_info fn_info check_termination target_fn_name term =
mk_L1corres_prop prog_info check_termination term
(mk_SIMPL_call_term ctxt prog_info fn_info target_fn_name)
|> HOLogic.mk_Trueprop
(*
* Convert a SIMPL fragment into a monadic term.
*
* We return the monadic version of the input fragment and a tactic
* to prove correspondence.
*)
fun simpl_conv'
(prog_info : ProgramInfo.prog_info)
(fn_info : FunctionInfo.fn_info)
(ctxt : Proof.context)
(callee_terms : (bool * term * thm) Symtab.table)
(measure_var : term)
(simpl_term : term) =
let
fun prove_term subterms base_thm result_term =
let
val subterms' = map (simpl_conv' prog_info fn_info ctxt
callee_terms measure_var) subterms;
val converted_terms = map fst subterms';
val subproofs = map snd subterms';
val new_term = (result_term converted_terms);
in
(new_term, (resolve_tac ctxt [base_thm] 1) THEN (EVERY subproofs))
end
(* Construct a "L1 monad" term with the given arguments applied to it. *)
fun mk_l1 (Const (a, _)) args =
Term.betapplys (Const (a, map fastype_of args
---> mk_l1monadT (#state_type prog_info)), args)
(* Convert a set construct into a predicate construct. *)
fun set_to_pred t =
(Const (@{const_name L1_set_to_pred},
fastype_of t --> (HOLogic.dest_setT (fastype_of t) --> @{typ bool})) $ t)
in
(case simpl_term of
(*
* Various easy cases of SIMPL to monadic conversion.
*)
(Const (@{const_name Skip}, _)) =>
prove_term [] @{thm L1corres_skip}
(fn _ => mk_l1 @{term "L1_skip"} [])
| (Const (@{const_name Seq}, _) $ left $ right) =>
prove_term [left, right] @{thm L1corres_seq}
(fn [l, r] => mk_l1 @{term "L1_seq"} [l, r])
| (Const (@{const_name Basic}, _) $ m) =>
prove_term [] @{thm L1corres_modify}
(fn _ => mk_l1 @{term "L1_modify"} [m])
| (Const (@{const_name Cond}, _) $ c $ left $ right) =>
prove_term [left, right] @{thm L1corres_condition}
(fn [l, r] => mk_l1 @{term "L1_condition"} [set_to_pred c, l, r])
| (Const (@{const_name Catch}, _) $ left $ right) =>
prove_term [left, right] @{thm L1corres_catch}
(fn [l, r] => mk_l1 @{term "L1_catch"} [l, r])
| (Const (@{const_name While}, _) $ c $ body) =>
prove_term [body] @{thm L1corres_while}
(fn [body] => mk_l1 @{term "L1_while"} [set_to_pred c, body])
| (Const (@{const_name Throw}, _)) =>
prove_term [] @{thm L1corres_throw}
(fn _ => mk_l1 @{term "L1_throw"} [])
| (Const (@{const_name Guard}, _) $ _ $ c $ body) =>
prove_term [body] @{thm L1corres_guard}
(fn [body] => mk_l1 @{term "L1_seq"} [mk_l1 @{term "L1_guard"} [set_to_pred c], body])
| @{term_pat "lvar_nondet_init _ ?upd"} =>
prove_term [] @{thm L1corres_init}
(fn _ => mk_l1 @{term "L1_init"} [upd])
| (Const (@{const_name Spec}, _) $ s) =>
prove_term [] @{thm L1corres_spec}
(fn _ => mk_l1 @{term "L1_spec"} [s])
| (Const (@{const_name guarded_spec_body}, _) $ _ $ s) =>
prove_term [] @{thm L1corres_guarded_spec}
(fn _ => mk_l1 @{term "L1_spec"} [s])
(*
* "call": This is primarily what is output by the C parser. We
* accept input terms of the form:
*
* "call <argument_setup> <proc_to_call> <locals_reset> (%_ s. Basic (<store return value> s))".
*
* In particular, the last argument needs to be of precisely the
* form above. SIMPL, in theory, supports complex expressions in
* the last argument. In practice, the C parser only outputs
* the form above, and supporting more would be a pain.
*)
| (Const (@{const_name call}, _) $ a $ (fn_const as Const (b, _)) $ c $ (Abs (_, _, Abs (_, _, (Const (@{const_name Basic}, _) $ d))))) =>
let
val state_type = #state_type prog_info
val target_fn_name = FunctionInfo.get_function_from_const fn_info fn_const |> Option.map #name
in
case Option.mapPartial (Symtab.lookup callee_terms) target_fn_name of
NONE =>
(* If no proof of our callee could be found, we emit a call to
* "fail". This may happen for functions without bodies. *)
let
val _ = warning ("Function '" ^ guess_function_name b ^ "' contains no body. "
^ "Replacing the function call with a \"fail\" command.")
in
prove_term [] @{thm L1corres_fail} (fn _ => mk_l1 @{term "L1_fail"} [])
end
| SOME (is_rec, term, thm) =>
let
(*
* If this is an internal recursive call, decrement the measure.
* Or if this is calling a recursive function, use measure_call.
* If the callee isn't recursive, it doesn't use the measure var
* and we can just give an arbitrary value.
*)
val target_fn_name = (the target_fn_name)
val target_rec = FunctionInfo.is_function_recursive fn_info target_fn_name
val term' =
if is_rec then
term $ (@{term "recguard_dec"} $ measure_var)
else if target_rec then
@{mk_term "measure_call ?f" f} term
else
term $ @{term "undefined :: nat"}
in
(* Generate the term. *)
(mk_l1 @{term "L1_call"}
[a, term', c, absdummy state_type d],
resolve_tac ctxt [if is_rec orelse not target_rec then
@{thm L1corres_reccall} else @{thm L1corres_call}] 1
THEN resolve_tac ctxt [thm] 1)
end
end
(* TODO : Don't currently support DynCom *)
| other => Utils.invalid_term "a SIMPL term" other)
end
(* Perform post-processing on a theorem. *)
fun cleanup_thm ctxt do_opt trace_opt prog_info fn_name thm =
let
(* Measure the term. *)
fun gather_stats phase thm =
Statistics.gather ctxt phase fn_name
(Thm.concl_of thm |> HOLogic.dest_Trueprop |> get_L1corres_monad)
val _ = gather_stats "L1" thm
(* For each function, we want to prepend a statement that sets its return
* value undefined. It is actually always defined, but our analysis isn't
* sophisticated enough to realise. *)
fun prepend_undef thm fn_name =
let
val ret_var_name =
Symtab.lookup (ProgramAnalysis.get_fninfo (#csenv prog_info)) fn_name
|> the
|> (fn (ctype, _, _) => NameGeneration.return_var_name ctype |> MString.dest)
val ret_var_setter = Symtab.lookup (#var_setters prog_info) ret_var_name
val ret_var_getter = Symtab.lookup (#var_getters prog_info) ret_var_name
fun try_unify (x::xs) =
((x ()) handle THM _ => try_unify xs)
in
case ret_var_setter of
SOME _ =>
(* Prepend the L1_init code. *)
Utils.named_cterm_instantiate ctxt
[("X", Thm.cterm_of ctxt (the ret_var_setter)),
("X'", Thm.cterm_of ctxt (the ret_var_getter))]
(try_unify [
(fn _ => @{thm L1corres_prepend_unknown_var_recguard} OF [thm]),
(fn _ => @{thm L1corres_prepend_unknown_var} OF [thm]),
(fn _ => @{thm L1corres_prepend_unknown_var'} OF [thm])])
(* Discharge the given proof obligation. *)
|> simp_tac (put_simpset simpl_conv_ss ctxt) 1 |> Seq.hd
| NONE => thm
end
val thm = prepend_undef thm fn_name
(* Conversion combinator to apply a conversion only to the L1 subterm of a
* L1corres term. *)
fun l1conv conv = (Conv.arg_conv (Utils.nth_arg_conv 3 conv))
(* Conversion to simplify guards. *)
fun guard_conv' c =
case (Thm.term_of c) of
(Const (@{const_name "L1_guard"}, _) $ _) =>
Simplifier.asm_full_rewrite (put_simpset simpl_conv_ss ctxt) c
| _ =>
Conv.all_conv c
val guard_conv = Conv.top_conv (K guard_conv') ctxt
(* Apply all the conversions on the generated term. *)
val (thm, guard_opt_trace) = AutoCorresTrace.fconv_rule_maybe_traced ctxt (l1conv guard_conv) thm trace_opt
val (thm, peephole_opt_trace) =
AutoCorresTrace.fconv_rule_maybe_traced ctxt
(l1conv (Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps
(if do_opt then Utils.get_rules ctxt @{named_theorems L1opt} else []))))
thm trace_opt
val _ = gather_stats "L1peep" thm
(* Rewrite exceptions. *)
val (thm, exn_opt_trace) = AutoCorresTrace.fconv_rule_maybe_traced ctxt
(l1conv (ExceptionRewrite.except_rewrite_conv ctxt do_opt)) thm trace_opt
val _ = gather_stats "L1except" thm
in
(thm,
[("L1 guard opt", guard_opt_trace), ("L1 peephole opt", peephole_opt_trace), ("L1 exception opt", exn_opt_trace)]
|> List.mapPartial (fn (n, tr) => case tr of NONE => NONE | SOME x => SOME (n, AutoCorresData.SimpTrace x))
)
end
(*
* Get theorems about a SIMPL body in a format convenient to reason about.
*
* In particular, we unfold parts of SIMPL where we would prefer to reason
* about raw definitions instead of more abstract constructs generated
* by the C parser.
*)
fun get_simpl_body ctxt fn_info fn_name =
let
(* Find the definition of the given function. *)
val simpl_thm = #definition (FunctionInfo.get_phase_info fn_info FunctionInfo.CP fn_name)
handle ERROR _ => raise FunctionNotFound fn_name;
(* Unfold terms in the body which we don't want to deal with. *)
val unfolded_simpl_thm =
Conv.fconv_rule (Utils.rhs_conv
(Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps
(Utils.get_rules ctxt @{named_theorems L1unfold}))))
simpl_thm
val unfolded_simpl_term = Thm.concl_of unfolded_simpl_thm |> Utils.rhs_of;
(*
* Get the implementation definition for this function. These rules are of
* the form "Gamma foo_'proc = Some foo_body".
*)
val impl_thm =
Proof_Context.get_thm ctxt (fn_name ^ "_impl")
|> Local_Defs.unfold ctxt [unfolded_simpl_thm]
|> SOME
handle (ERROR _) => NONE
in
(unfolded_simpl_term, unfolded_simpl_thm, impl_thm)
end
fun get_l1corres_thm prog_info fn_info check_termination ctxt do_opt trace_opt fn_name
callee_terms measure_var = let
val thy = Proof_Context.theory_of ctxt
val (simpl_term, simpl_thm, impl_thm) = get_simpl_body ctxt fn_info fn_name
(* Fetch stats on pre-converted term. *)
val _ = Statistics.gather ctxt "CParser" fn_name simpl_term
(*
* Do the conversion. We receive a new monadic version of the SIMPL
* term and a tactic for proving correspondence.
*)
val (monad, tactic) = simpl_conv' prog_info fn_info ctxt
callee_terms measure_var simpl_term
(*
* Wrap the monad in a "L1_recguard" statement, which triggers
* failure when the measure reaches zero. This lets us automatically
* prove termination of the recursive function.
*)
val is_recursive = FunctionInfo.is_function_recursive fn_info fn_name
val (monad, tactic) =
if is_recursive then
(Utils.mk_term thy @{term "L1_recguard"} [measure_var, monad],
(resolve_tac ctxt @{thms L1corres_recguard} 1 THEN tactic))
else
(monad, tactic)
(*
* Return a new theorem of correspondence between the original
* SIMPL body (with folded constants) and the output monad term.
*)
in
mk_L1corres_call_prop ctxt prog_info fn_info check_termination fn_name monad
|> Thm.cterm_of ctxt
|> Goal.init
|> (case impl_thm of
NONE => apply_tac "unfold SIMPL body" (resolve_tac ctxt @{thms L1corres_undefined_call} 1)
| SOME def => apply_tac "unfold SIMPL body" (resolve_tac ctxt @{thms L1corres_Call} 1 THEN
resolve_tac ctxt [def] 1)
#> apply_tac "solve L1corres" tactic)
|> Goal.finish ctxt
(* Apply simplifications to the L1 term. *)
|> cleanup_thm ctxt do_opt trace_opt prog_info fn_name
end
fun get_body_of_l1corres_thm thm =
(* Extract the monad from the thm. *)
Thm.concl_of thm
|> HOLogic.dest_Trueprop
|> get_L1corres_monad
fun split_conj thm =
(thm RS @{thm conjunct1}) :: split_conj (thm RS @{thm conjunct2})
handle THM _ => [thm]
(* Prove monad_mono for recursive functions. *)
fun l1_monad_mono lthy (l1_defs : FunctionInfo.phase_info Symtab.table) =
let
val l1_defs' = Symtab.dest l1_defs;
fun mk_stmt [func] = @{mk_term "monad_mono ?f" f} func
| mk_stmt (func :: funcs) = @{mk_term "monad_mono ?f \<and> ?g" (f, g)} (func, mk_stmt funcs);
val mono_thm = @{term "Trueprop"} $ mk_stmt (map (#const o snd) l1_defs');
val func_expand = map (fn (_, l1_def) =>
EqSubst.eqsubst_tac lthy [0] [Utils.abs_def lthy (#definition l1_def)]) l1_defs';
val tac =
REPEAT (EqSubst.eqsubst_tac lthy [0]
[@{thm monad_mono_alt_def}, @{thm all_conj_distrib} RS @{thm sym}] 1)
THEN resolve_tac lthy @{thms allI} 1 THEN resolve_tac lthy @{thms nat.induct} 1
THEN EVERY (map (fn expand =>
TRY (resolve_tac lthy @{thms conjI} 1)
THEN expand 1
THEN resolve_tac lthy @{thms monad_mono_step_L1_recguard_0} 1) func_expand)
THEN REPEAT (eresolve_tac lthy @{thms conjE} 1)
THEN EVERY (map (fn expand =>
TRY (resolve_tac lthy @{thms conjI} 1)
THEN expand 1
THEN REPEAT (FIRST [assume_tac lthy 1,
resolve_tac lthy @{thms L1_monad_mono_step_rules} 1]))
func_expand);
in
Goal.prove lthy [] [] mono_thm (K tac)
|> split_conj
|> (fn thms => map fst l1_defs' ~~ thms)
|> Symtab.make
end
(* For functions that are not translated, just generate a trivial wrapper. *)
fun mk_l1corres_call_simpl_thm fn_info check_termination ctxt fn_name = let
val info = FunctionInfo.get_phase_info fn_info FunctionInfo.CP fn_name
val const = #const info
val impl_thm = Proof_Context.get_thm ctxt (fn_name ^ "_impl")
val gamma = safe_mk_meta_eq impl_thm |> Thm.concl_of |> Logic.dest_equals
|> fst |> (fn (f $ _) => f | t => raise TERM ("gamma", [t]))
val thm = Utils.named_cterm_instantiate ctxt
[("ct", Thm.cterm_of ctxt (Utils.mk_bool check_termination)),
("proc", Thm.cterm_of ctxt const),
("Gamma", Thm.cterm_of ctxt gamma)]
@{thm L1corres_call_simpl}
in thm end
(*
* Convert a single function. Returns a thm that looks like
* \<lbrakk> L1corres ?callee1 (Call callee1_'proc); ... \<rbrakk> \<Longrightarrow>
* L1corres (conversion result...) (Call f_'proc)
* i.e. with assumptions for called functions, which are parameterised as Vars.
*)
fun convert
(lthy: local_theory)
(prog_info: ProgramInfo.prog_info)
(fn_info: FunctionInfo.fn_info) (* needs CP phase_info for each callee *)
(check_termination: bool)
(do_opt: bool)
(trace_opt: bool)
(l1_function_name: string -> string)
(f_name: string)
: thm * (string * typ) list = let
val callee_names = FunctionInfo.get_function_callees fn_info f_name;
(* TODO sanity check:
- all SIMPL defs exist *)
val measureT = @{typ nat};
(* All L1 functions have the same signature: measure \<Rightarrow> L1_monad *)
val l1_fn_type = measureT --> mk_l1monadT (#state_type prog_info);
(* L1corres for f's callees. *)
fun get_l1_fn_assumption ctxt fn_name free _ _ measure_var =
mk_L1corres_call_prop ctxt prog_info fn_info check_termination fn_name (betapply (free, measure_var));
(* Fix measure variable. *)
val ([measure_var_name], lthy') = Variable.variant_fixes ["rec_measure'"] lthy;
val measure_var = Free (measure_var_name, measureT);
(* Add callee assumptions. Note that our define code has to use the same assumption order. *)
val group = Symset.make (FunctionInfo.get_recursive_group fn_info f_name);
val callee_terms =
map (fn callee => (callee, Symset.contains group callee)) callee_names;
val (lthy'', export_thm, callee_terms) =
AutoCorresUtil.assume_called_functions_corres lthy'
callee_terms
(K l1_fn_type)
get_l1_fn_assumption
(K [])
l1_function_name
measure_var;
val (thm, opt_traces) =
if #is_simpl_wrapper (FunctionInfo.get_function_info fn_info f_name)
then (mk_l1corres_call_simpl_thm fn_info check_termination lthy'' f_name, [])
else get_l1corres_thm prog_info fn_info check_termination lthy'' do_opt trace_opt f_name
(Symtab.make callee_terms) measure_var;
in (Morphism.thm export_thm thm,
(* Provide the fixed vars so the user can generalize/instantiate them *)
[dest_Free measure_var])
end
(* Define a previously-converted function (or recursive function group).
* lthy must include all definitions from l1_callees *)
fun define
(lthy: local_theory)
(filename: string)
(prog_info: ProgramInfo.prog_info)
(fn_info: FunctionInfo.fn_info) (* doesn't need to have L1 info *)
(check_termination: bool)
(l1_callees: (FunctionInfo.phase_info * thm) Symtab.table) (* L1 callees & corres thms *)
(l1_function_name: string -> string)
(funcs: (string * thm * (string * typ) list) list) (* name, corres, arg frees *)
: (FunctionInfo.phase_info * thm) Symtab.table * local_theory = let
(* FIXME: dedup with convert *)
(* All L1 functions have the same signature: measure \<Rightarrow> L1_monad *)
val measureT = @{typ nat};
val l1_fn_type = measureT --> mk_l1monadT (#state_type prog_info);
(* L1corres for f's callees. *)
fun get_l1_fn_assumption ctxt fn_name free _ _ measure_var =
mk_L1corres_call_prop ctxt prog_info fn_info check_termination fn_name (betapply (free, measure_var));
(* FIXME: pass this from assume_called_functions_corres, etc. *)
fun guess_callee_var thm callee = let
val base_name = l1_function_name callee;
val mentioned_vars = Term.add_vars (Thm.prop_of thm) [];
in hd (filter (fn ((v, _), _) => v = base_name) mentioned_vars) end;
fun prepare_fn_body (fn_name, corres_thm, measure_free) = let
val @{term_pat "Trueprop (L1corres _ _ ?body _)"} = Thm.concl_of corres_thm;
val (callees, recursive_callees) = AutoCorresUtil.get_callees fn_info fn_name;
val calls = map (fn c => Var (guess_callee_var corres_thm c)) callees;
val recursive_calls = map (fn c => Var (guess_callee_var corres_thm c)) recursive_callees;
(*
* The returned body will have free variables as placeholders for the function's
* measure parameter and other arguments, and schematic variables for the functions it calls.
*
* We modify the body to be of the form:
*
* %fun1 fun2 rec1 rec2 measure arg1 arg2. f <...>
*
* That is, all non-recursive calls are abstracted out the front, followed by
* recursive calls, followed by the measure variable, followed by function
* arguments (none for L1). This is the format expected by define_funcs.
*)
val abs_body = body
|> fold lambda (rev (map Free measure_free))
|> fold lambda (rev recursive_calls)
|> fold lambda (rev calls);
in abs_body end;
val fn_info' = FunctionInfo.add_phases (fn f => K (Option.map fst (Symtab.lookup l1_callees f))) fn_info;
val funcs' = funcs |>
map (fn (name, thm, frees) =>
(name, (* FIXME: define_funcs needs this currently *)
(Thm.generalize ([], map fst frees) (Thm.maxidx_of thm + 1) thm,
prepare_fn_body (name, thm, frees), [])));
val _ = @{trace} ("L1.define", map (fn (name, (thm, body, _)) => (name, thm, Thm.cterm_of lthy body)) funcs');
val (corres_thms, (), lthy') =
AutoCorresUtil.define_funcs FunctionInfo.L1 filename fn_info'
l1_function_name (K l1_fn_type) get_l1_fn_assumption (K []) @{thm L1corres_recguard_0}
lthy (Symtab.map (K snd) l1_callees) ()
funcs';
val f_names = map (fn (name, _, _) => name) funcs;
val new_phases = map2 (fn f_name => fn corres => let
val old_phase = FunctionInfo.get_phase_info fn_info FunctionInfo.CP f_name;
val def = the (AutoCorresData.get_def (Proof_Context.theory_of lthy') filename "L1def" f_name);
val f_const = Utils.get_term lthy' (l1_function_name f_name);
in old_phase
|> FunctionInfo.phase_info_upd_phase FunctionInfo.L1
|> FunctionInfo.phase_info_upd_definition def
|> FunctionInfo.phase_info_upd_const f_const
|> FunctionInfo.phase_info_upd_mono_thm NONE (* done in translate *)
end) f_names corres_thms;
(* FIXME: return traces *)
in (Symtab.make (f_names ~~ (new_phases ~~ corres_thms)), lthy') end;
(*
* Top level translation from SIMPL to a monadic spec.
*
* We accept a filename (the same filename passed to the C parser; the
* parser stashes away important information using this filename as the
* key) and a local theory.
*
* We define a number of new functions (the converted monadic
* specifications of the SIMPL functions) and theorems (proving
* correspondence between our generated specs and the original SIMPL
* code).
*)
(* FIXME: use AutoCorresUtil.Future instead of default *)
fun translate filename prog_info fn_info check_termination do_opt trace_opt l1_function_name lthy =
let
val funcs_to_translate = Symtab.keys (FunctionInfo.get_all_functions fn_info);
(* All conversions can run in parallel. *)
val converted_funcs =
funcs_to_translate |> map (fn f =>
(f, Future.fork (fn _ =>
convert lthy prog_info fn_info check_termination do_opt trace_opt l1_function_name f)))
|> Symtab.make;
(* Definitions update lthy sequentially.
* We use the arbitrary (but deterministic) ordering defined by get_topo_sorted_functions.
* Each definition step produces a lthy that has a prefix of the update sequence,
* and can be used in an L2 translation that depends only on that prefix.
* Hence we return the intermediate lthys as futures. *)
fun add_def f_names accum = Future.fork (fn _ => let
(* Get output from previous definition.
* Technically we don't need all of defined_so_far, but we're guaranteed
* to have them at this point already. *)
val (lthy, _, defined_so_far) = Future.join accum;
(* Wait for conversions to finish *)
val f_convs = map (fn f => let
val conv = the' ("didn't convert function: " ^ quote f ^ "??") (Symtab.lookup converted_funcs f);
val (corres_thm, arg_frees) = Future.join conv;
in (f, corres_thm, arg_frees) end) f_names;
val (new_defs, lthy') = define lthy filename prog_info fn_info check_termination
(defined_so_far: (FunctionInfo.phase_info * thm) Symtab.table)
l1_function_name f_convs;
in (lthy', new_defs, Symtab.merge (K false) (defined_so_far, new_defs)) end);
val function_groups = FunctionInfo.get_topo_sorted_functions fn_info;
(* Chain of intermediate states: (lthy, new_defs, accumulator) *)
val (def_results, _) = Utils.accumulate add_def (Future.value (lthy, Symtab.empty, Symtab.empty)) function_groups;
(* Produce a mapping from each function group to its L1 phase_infos and the
* earliest intermediate lthy where it is defined. *)
val results =
def_results
|> map (Future.map (fn (lthy, f_defs, _ (* discard accum here *)) => let
(* Add monad_mono proofs. These are done in parallel as well
* (though in practice, they already run pretty quickly). *)
val mono_thms = if FunctionInfo.is_function_recursive fn_info (hd (Symtab.keys f_defs))
then l1_monad_mono lthy (Symtab.map (K fst) f_defs)
else Symtab.empty;
val f_defs' = f_defs |> Symtab.map (fn f =>
apfst (FunctionInfo.phase_info_upd_mono_thm (Symtab.lookup mono_thms f)));
in (lthy, f_defs') end));
in
(* would be ready to become a symtab, but a list also preserves order *)
function_groups ~~ results
end
end

View File

@ -149,6 +149,10 @@ unsigned opt_a(unsigned m, unsigned n) {
return opt_a(m - 1, opt_a(m, n - 1));
}
/* Test for measure_call */
unsigned opt_a2(unsigned n) {
return opt_a(n, n);
}
/*********************

View File

@ -22,6 +22,59 @@ install_C_file "type_strengthen.c"
For example, suppose that we do not want to lift loops to the option monad: *)
declare gets_theE_L2_while [ts_rule option del]
context type_strengthen begin
ML \<open>
let val fn_info = FunctionInfo.init_fn_info @{context} "type_strengthen.c"
val prog_info = ProgramInfo.get_prog_info @{context} "type_strengthen.c"
val (corres1, frees1) = SimplConv2.convert @{context} prog_info fn_info true true false (fn f => "l1_" ^ f) "opt_j";
val (corres2, frees2) = SimplConv2.convert @{context} prog_info fn_info true true false (fn f => "l1_" ^ f) "st_i";
(*val thm' = Thm.generalize ([], map fst frees) (Thm.maxidx_of thm + 1) thm*)
val lthy0 = @{context};
val (l1_infos1, lthy1) =
SimplConv2.define lthy0 "type_strengthen.c" prog_info fn_info true
Symtab.empty
(fn f => "l1_" ^ f)
[("opt_j", corres1, frees1)]
val (l1_infos2, lthy2) =
SimplConv2.define lthy1 "type_strengthen.c" prog_info fn_info true
l1_infos1
(fn f => "l1_" ^ f)
[("st_i", corres2, frees2)]
in (frees1, corres1, Symtab.dest l1_infos2) end
\<close>
ML \<open>
let val filename = "type_strengthen.c";
val fn_info = FunctionInfo.init_fn_info @{context} filename;
val prog_info = ProgramInfo.get_prog_info @{context} filename;
val l1_results =
SimplConv2.translate filename prog_info fn_info
true true false (fn f => "l1_" ^ f ^ "'") @{context};
(*
val l2_results =
LocalVarExtract2.translate filename prog_info fn_info l1_results
true false (fn f => "l2_" ^ f ^ "'");
*)
in l1_results |> map (snd #> Future.join) |> map (snd #> Symtab.dest) end
\<close>
ML \<open>
FunctionInfo.is_function_recursive (FunctionInfo.init_fn_info @{context} "type_strengthen.c") "opt_a"
\<close>
ML \<open>
FunctionInfo.get_topo_sorted_functions (FunctionInfo.init_fn_info @{context} "type_strengthen.c")
\<close>
end
declare [[ML_print_depth=99]]
autocorres [
ts_rules = nondet,
scope = st_i,
skip_heap_abs, skip_word_abs
] "type_strengthen.c"
(* We can also specify which monads are used for type strengthening.
Here, we exclude the read-only monad completely, and specify
rules for some individual functions. *)
@ -32,6 +85,7 @@ autocorres [
] "type_strengthen.c"
context type_strengthen begin
(* pure_f (and indirectly, pure_f2) are now lifted to the option monad. *)
thm pure_f'_def pure_f2'_def
thm pure_g'_def pure_h'_def

View File

@ -85,6 +85,13 @@ fun enumerate xs = let
fun nubBy _ [] = []
| nubBy f (x::xs) = x :: filter (fn y => f x <> f y) (nubBy f xs)
fun accumulate f acc xs = let
fun walk results acc [] = (results [], acc)
| walk results acc (x::xs) = let
val acc' = f x acc;
in walk (results o cons acc') acc' xs end;
in walk I acc xs end;
(* Define a constant "name" of type "term" into the local theory "lthy". *)
fun define_const name term lthy =
let