WIP: autocorres: split up function_info data structures

With this we move away from a global mutable fn_info; instead we will
use a table of persistent (lazy) entries for each phase.
Function call metadata is also now either stored locally or recomputed
on-demand for each stage (with a few TODOs).
This commit is contained in:
Japheth Lim 2016-06-14 18:10:28 +10:00
parent 2caf6520e5
commit 84cb9deaf8
6 changed files with 649 additions and 348 deletions

View File

@ -146,6 +146,7 @@ ML_file "type_strengthen.ML"
ML_file "autocorres.ML"
declare [[ML_print_depth=42]]
ML_file "function_info2.ML"
ML_file "autocorres_util2.ML"
ML_file "simpl_conv2.ML"
ML_file "local_var_extract2.ML"

View File

@ -190,21 +190,11 @@ fun map_all ctxt fn_info convert =
* We split the result into standard calls and recursive calls (i.e., calls
* which may recursively call back into us).
*)
fun get_callees fn_info fn_name =
fun get_callees fn_infos fn_name =
let
(* Get a list of functions we call. *)
val all_callees = FunctionInfo.get_function_callees fn_info fn_name
(* Fetch calls that may recursively call back to us. *)
val recursive_calls = FunctionInfo.get_recursive_group fn_info fn_name
(* Remove "recursive_calls" from the standard callee set. *)
val callees =
Symset.make all_callees
|> Symset.subtract (Symset.make recursive_calls)
|> Symset.dest
val fn_info = the (Symtab.lookup fn_infos fn_name)
in
(callees, recursive_calls)
(Symset.dest (#callees fn_info), Symset.dest (#rec_callees fn_info))
end
(* Is the given term a Trueprop? *)
@ -237,7 +227,7 @@ fun is_Trueprop (Const (@{const_name "Trueprop"}, _) $ _) = true
* (this exists for backwards compat; all new code should explicitly use the
* returned free variable set)
*)
fun assume_called_functions_corres ctxt callees
fun assume_called_functions_corres ctxt callees rec_callees
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var =
let
(* Assume the existence of a function, along with a theorem about its
@ -286,7 +276,7 @@ let
(fn_free, thm, ctxt''', m)
end
(* Apply each assumption. *)
(* Add assumptions: recursive calls first, matching the order in define_functions *)
val (res, (ctxt', m)) = fold_map (
fn (fn_name, is_recursive_call) =>
fn (ctxt, m) =>
@ -296,111 +286,25 @@ let
in
((fn_name, (is_recursive_call, free, thm)), (ctxt', m' $> m))
end)
callees (ctxt, Morphism.identity)
(map (fn f => (f, false)) (Symset.dest callees) @
map (fn f => (f, true)) (Symset.dest rec_callees))
(ctxt, Morphism.identity)
in
(ctxt', m, res)
end
(*
* Convert a single function.
*
* Given a single concrete function, abstract that function and
* return a theorem that shows the correspondence.
*
* A theorem is returned which has assumptions that called functions
* correspond, giving a goal that this given function corresponds.
*)
fun gen_corres_for_function
(phase : FunctionInfo.phase)
(fn_info : FunctionInfo.fn_info)
(get_fn_type : string -> typ)
(get_fn_assumption : Proof.context -> string -> term -> term list -> bool -> term -> term)
(get_fn_args : string -> (string * typ) list)
(get_const_name : string -> string)
(convert : Proof.context -> string -> ((bool * term * thm) Symtab.table) ->
term -> term list -> (term * thm * (string * AutoCorresData.Trace) list))
(ctxt : Proof.context)
(fn_name : string) =
let
val _ = writeln ("Converting (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name)
val start_time = Timer.startRealTimer ()
(* Get a list of functions we call. *)
val (normal_calls, recursive_calls)
= get_callees fn_info fn_name
val callees =
(map (fn x => (x, false)) normal_calls)
@ (map (fn x => (x, true)) recursive_calls)
(* Make sure the desired function name is available. *)
val fn_target_name = get_const_name fn_name
val ([fn_free], ctxt') = Variable.variant_fixes [fn_target_name] ctxt
val _ = if fn_free = fn_target_name then () else
warning ("Variable clobbered: " ^ fn_target_name ^ " -> " ^ fn_free ^ ". Translating " ^
fn_name ^ " may fail.")
val fn_var_morph = Variable.export_morphism ctxt' ctxt
(* Fix a measure variable that will be used to track recursion progress. *)
val ([measure_var_name], ctxt'') = Variable.variant_fixes ["rec_measure'"] ctxt'
val measure_var = Free (measure_var_name, @{typ nat})
val measure_var_morph = Variable.export_morphism ctxt'' ctxt'
(* Fix variables for function arguments. *)
val fn_args = get_fn_args fn_name
val (arg_names, ctxt''')
= Variable.variant_fixes (map (fn (a, _) => a ^ "'arg") fn_args) ctxt''
val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args))
val fn_args_morph = Variable.export_morphism ctxt''' ctxt''
val _ = @{trace} ("Vars", fn_free, measure_var_name, arg_names)
(* Enter a context where we assume our callees exist. *)
val (ctxt'''', m, callee_info_and_proofs)
= assume_called_functions_corres ctxt''' callees
get_fn_type get_fn_assumption get_fn_args get_const_name
measure_var
(*
* Do the conversion. We receive a new monadic version of the SIMPL
* term and a tactic for proving correspondence.
*)
val callee_tab = Symtab.make callee_info_and_proofs
val (body, thm, trace) = convert ctxt'''' fn_name callee_tab measure_var fn_arg_terms
(*
* The returned body will have free variables as placeholders for the function's
* input parameters, for the functions it calls, and for its measure variable.
*
* We modify the body to be of the form:
*
* %fun1 fun2 rec1 rec2 measure arg1 arg2. f <...>
*
* That is, all non-recursive calls are abstracted out the front, followed by
* recursive calls, followed by the measure variable, followed by function
* arguments.
*)
val body =
fold lambda (rev fn_arg_terms) body
|> lambda measure_var
|> fold lambda (rev (map (fn x => Symtab.lookup callee_tab x |> the |> #2) recursive_calls))
|> fold lambda (rev (map (fn x => Symtab.lookup callee_tab x |> the |> #2) normal_calls))
(* Export the theorem out of our context. *)
val exported_thm = Morphism.thm (m $> fn_args_morph $> measure_var_morph $> fn_var_morph) thm
(* TODO: allow this message to be configured *)
val _ = @{trace} ("Converted (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name ^ " in " ^
Time.toString (Timer.checkRealTimer start_time) ^ " s")
in
(exported_thm, body, trace)
end
(*
* Given a SIMPL function, define a constant and a proof for it.
* Given one or more function specs, define them and instantiate corres proofs.
*
* "callee_thms" contains a table mapping function names to complete
* corres proofs for those functions.
* corres proofs for those functions. At this point, the functions are
* still free variables.
*
* "functions" contains a list of (fn_name, (proof, body, proof traces)).
* "fn_infos" is used to look up function callees. It is expected
* to consist of the previous translation output for the functions
* being defined, but may of course contain other entries.
*
* "functions" contains a list of (fn_name, (proof, body)).
* The body should be of the form generated by gen_corres_for_function,
* with lambda abstractions for all callees and arguments.
*
@ -408,12 +312,13 @@ end
* (If not, you should call "define_funcs" multiple times, each
* time with a single function.)
*
* The proof traces are stored into the theory. (This should probably be moved.)
* Returns the new function constants, definitions, corres proofs,
* and local theory.
*)
fun define_funcs
(phase : FunctionInfo.phase)
(phase : FunctionInfo2.phase)
(filename : string)
(fn_info : FunctionInfo.fn_info)
(fn_infos : FunctionInfo2.function_info Symtab.table)
(get_const_name : string -> string)
(get_fn_type : string -> typ)
(get_fn_assumption : Proof.context -> string -> term -> term list -> bool -> term -> term)
@ -421,19 +326,20 @@ fun define_funcs
(rec_base_case : thm)
(ctxt : Proof.context)
(callee_thms : thm Symtab.table)
accum (* translate allows an accumulator, but we don't use it here *)
(functions : (string * (thm * term * (string * AutoCorresData.Trace) list)) list)
=
(accum : 'a) (* translate allows an accumulator, but we don't use it here *)
(functions : (string * (thm * term)) list)
: (term * thm * thm) Symtab.table * 'a * Proof.context =
let
val fn_names = map fst functions
val fn_thms = map (snd #> #1) functions
val fn_bodies = map (snd #> #2) functions
val fn_traces = map (fn (fn_name, (_, _, traces)) => map (fn (module, trace) => (module, fn_name, trace)) traces) functions |> List.concat
val _ = writeln ("Defining (" ^ FunctionInfo.string_of_phase phase ^ ") " ^
val _ = writeln ("Defining (" ^ FunctionInfo2.string_of_phase phase ^ ") " ^
(Utils.commas (map get_const_name fn_names)))
(*
val _ = @{trace} ("function(s)", functions)
val _ = @{trace} ("callee(s)", Symtab.dest callee_thms)
*)
(*
* Determine if we are in a recursive case by checking to see if the
@ -442,7 +348,7 @@ fun define_funcs
* recursion, but may be a different function if we are mutually
* recursive.)
*)
val is_recursive = FunctionInfo.is_function_recursive fn_info (hd fn_names)
val is_recursive = FunctionInfo2.is_function_recursive (the (Symtab.lookup fn_infos (hd fn_names)))
val _ = assert (length fn_names = 1 orelse is_recursive)
"define_funcs passed multiple functions, but they don't appear to be recursive."
@ -454,10 +360,9 @@ fun define_funcs
*)
fun fill_body fn_name body =
let
val (normal_calls, recursive_calls)
= get_callees fn_info fn_name
val non_rec_calls = map (fn x => Utils.get_term ctxt (get_const_name x)) normal_calls
val rec_calls = map (fn x => Free (get_const_name x, get_fn_type x)) recursive_calls
val fn_info = the (Symtab.lookup fn_infos fn_name)
val non_rec_calls = map (fn x => Utils.get_term ctxt (get_const_name x)) (Symset.dest (#callees fn_info))
val rec_calls = map (fn x => Free (get_const_name x, get_fn_type x)) (Symset.dest (#rec_callees fn_info))
in
body
|> (fn t => betapplys (t, non_rec_calls))
@ -485,17 +390,19 @@ fun define_funcs
(fn_names ~~ fn_bodies)
val (fn_def_thms, ctxt) = Utils.define_functions defs true is_recursive ctxt
(*
(* Record the constant in our theory data. *)
val ctxt = fold (fn (fn_name, def) =>
Local_Theory.background_theory (
AutoCorresData.add_def filename (FunctionInfo.string_of_phase phase ^ "def") fn_name def))
(Utils.zip fn_names fn_def_thms) ctxt
*)
(*
* Instantiate schematic function calls in our theorems with their
* concrete definitions.
*)
val combined_callees = map (get_callees fn_info) (map fst functions)
val combined_callees = map (get_callees fn_infos) (map fst functions)
val combined_normal_calls =
map fst combined_callees |> flat |> sort_distinct fast_string_ord
val combined_recursive_calls =
@ -543,12 +450,12 @@ fun define_funcs
(* Prove each of the predicates above, leaving any assumptions about called
* functions unsolved. *)
val pred_thms = map (
fn (pred, thm, body_def) => (@{trace} thm;
fn (pred, thm, body_def) =>
Thm.trivial (Thm.cterm_of ctxt' pred)
|> Utils.apply_tac "unfold body" (Hypsubst.stac ctxt' body_def 1)
|> Utils.apply_tac "apply rule" (resolve_tac ctxt' [thm] 1)
|> Goal.norm_result ctxt
|> singleton (Variable.export ctxt' ctxt))
|> singleton (Variable.export ctxt' ctxt)
)
(Utils.zip3 preds fn_thms fn_def_thms)
@ -585,7 +492,7 @@ fun define_funcs
* The proof above is of the form (L1corres a & L1corres b & ...).
* Split it up into several proofs.
*)
fun prove_partial_l1_corres thm pred =
fun prove_partial_corres thm pred =
Thm.cterm_of ctxt' pred
|> Goal.init
|> Utils.apply_tac "solving using metis" (Utils.metis_tac ctxt [thm] 1)
@ -593,10 +500,11 @@ fun define_funcs
(* Generate the final theorems. *)
val new_thms =
map (prove_partial_l1_corres new_thm) preds
map (prove_partial_corres new_thm) preds
|> (Variable.export ctxt' ctxt)
|> map (Goal.norm_result ctxt)
(*
(* Record the theorems in our theory data. *)
val ctxt = fold (fn (fn_name, thm) =>
Local_Theory.background_theory
@ -607,18 +515,22 @@ fun define_funcs
val ctxt = fold (fn (fn_name, thm) =>
Utils.define_lemma (fn_name ^ "_" ^ FunctionInfo.string_of_phase phase ^ "corres") thm #> snd)
(fn_names ~~ new_thms) ctxt
*)
(* Add the traces to the context. *)
val ctxt = Local_Theory.background_theory
(fold (fn (phase, fn_name, trace) =>
AutoCorresData.add_trace filename phase fn_name trace) fn_traces) ctxt
val results =
fn_names ~~ (new_thms ~~ fn_def_thms)
|> map (fn (fn_name, (corres_thm, def_thm)) =>
(* FIXME: ugly way to get the function constant *)
(fn_name, (Utils.get_term ctxt (get_const_name fn_name), def_thm, corres_thm)))
|> Symtab.make;
in
(new_thms, accum, ctxt)
(results, accum, ctxt)
end
(*
* Do a translation phase, converting every function from one form to another.
*)
(*
fun do_translation_phase
(phase : FunctionInfo.phase)
(filename : string)
@ -660,5 +572,6 @@ let
in
(ctxt', new_fn_info)
end
*)
end

View File

@ -0,0 +1,373 @@
(*
* Copyright 2014, NICTA
*
* This software may be distributed and modified according to the terms of
* the BSD 2-Clause license. Note that NO WARRANTY is provided.
* See "LICENSE_BSD2.txt" for details.
*
* @TAG(NICTA_BSD)
*)
(*
* Information about functions in the program we are translating,
* and the call-graph between them.
* To support incremental translation, we store the function information
* for every intermediate phase as well.
*)
signature FUNCTION_INFO2 =
sig
(*** Basic data types ***)
(* List of AutoCorres phases. *)
datatype phase = CP (* Initial definition we get from the C parser *)
| L1 (* SimplConv *)
| L2 (* LocalVarExtract *)
| HL (* HeapLift *) (* TODO: rename to HeapAbstract *)
| WA (* WordAbstract *)
| TS (* TypeStrengthen *);
val string_of_phase : phase -> string;
val phase_ord : phase * phase -> order;
structure Phasetab : TABLE;
(* Function info for a single phase. *)
type function_info = {
(* Name of the function. *)
name : string,
(* The translation phase for this definition. *)
phase : phase,
(* Constant for the function, which can be inserted as a call to the
* function. Unlike "raw_const", this includes any locale parameters
* required by the function. *)
const: term,
(* Raw constant for the function. Existence of this constant in another
* function's body indicates that that function calls this one. *)
raw_const: term,
(* Arguments of the function, in order, excluding measure variables. *)
args : (string * typ) list,
(* Return type of the function ("unit" is used for void). *)
return_type : typ,
(* Function calls. Mutually recursive calls go in rec_callees. *)
callees : symset,
rec_callees : symset,
(* Definition of the function. *)
definition : thm,
(* corres theorem for the function. (TrueI when phase = CP.) *)
corres_thm : thm,
(* monad_mono theorem for the function, if it is recursive. *)
mono_thm : thm option,
(* Is this function actually being translated, or are we just
* wrapping the SIMPL code? *)
is_simpl_wrapper : bool,
(* Is this function generated by AutoCorres as a placeholder for
* a function we didn't have the source code to? *)
invented_body : bool
};
val function_info_upd_name : string -> function_info -> function_info;
val function_info_upd_phase : phase -> function_info -> function_info;
(* Also updates raw_const. *)
val function_info_upd_const : term -> function_info -> function_info;
val function_info_upd_args : (string * typ) list -> function_info -> function_info;
val function_info_upd_return_type : typ -> function_info -> function_info;
val function_info_upd_callees : symset -> function_info -> function_info;
val function_info_upd_rec_callees : symset -> function_info -> function_info;
val function_info_upd_definition : thm -> function_info -> function_info;
val function_info_upd_mono_thm : thm option -> function_info -> function_info;
val function_info_upd_corres_thm : thm -> function_info -> function_info;
val function_info_upd_invented_body : bool -> function_info -> function_info;
val function_info_upd_is_simpl_wrapper : bool -> function_info -> function_info;
(* Convenience getters. *)
val is_function_recursive : function_info -> bool;
val all_callees : function_info -> symset;
(* Generate initial function_info from the C Parser's output. *)
val init_function_info : Proof.context -> string -> function_info Symtab.table;
type call_graph_info = {
(* Topologically sorted function calls, in dependency order.
* Each sub-list represents one function or recursive function group. *)
topo_sorted_functions : symset list,
(* Table mapping raw_consts to functions. *)
const_to_function : string Termtab.table,
(* Table mapping each recursive function to its recursive function group.
* Non-recursive functions do not appear in the table. *)
recursive_group_of : symset Symtab.table
};
(* Calculate call-graph information.
* Also updates the callees and rec_callees entries of its inputs,
* which are assumed to have outdated callee info.
*
* Ideally, we'd also have a pre_function_info type that doesn't have
* outdated callees, but dealing with ML records is annoying. *)
val calc_call_graph : function_info Symtab.table -> call_graph_info * function_info Symtab.table;
end;
structure FunctionInfo2 : FUNCTION_INFO2 =
struct
datatype phase = CP | L1 | L2 | HL | WA | TS;
fun string_of_phase CP = "CP"
| string_of_phase L1 = "L1"
| string_of_phase L2 = "L2"
| string_of_phase HL = "HL"
| string_of_phase WA = "WA"
| string_of_phase TS = "TS";
fun encode_phase CP = 0
| encode_phase L1 = 1
| encode_phase L2 = 2
| encode_phase HL = 3
| encode_phase WA = 4
| encode_phase TS = 5;
val phase_ord = int_ord o apply2 encode_phase;
structure Phasetab = Table(
type key = phase
val ord = phase_ord);
type function_info = {
name : string,
phase : phase,
const : term,
raw_const : term,
args : (string * typ) list,
return_type : typ,
callees : symset,
rec_callees : symset,
definition : thm,
mono_thm : thm option,
corres_thm : thm,
invented_body : bool,
is_simpl_wrapper : bool
};
(* We use FunctionalRecordUpdate internally to define the setters *)
open FunctionalRecordUpdate;
local
fun from name phase const raw_const args return_type callees rec_callees
definition mono_thm corres_thm is_simpl_wrapper invented_body =
{ name = name,
phase = phase,
const = const,
raw_const = raw_const,
args = args,
return_type = return_type,
callees = callees,
rec_callees = rec_callees,
definition = definition,
mono_thm = mono_thm,
corres_thm = corres_thm,
is_simpl_wrapper = is_simpl_wrapper,
invented_body = invented_body };
fun from' invented_body is_simpl_wrapper corres_thm mono_thm definition
rec_callees callees return_type args raw_const const phase name =
{ name = name,
phase = phase,
const = const,
raw_const = raw_const,
args = args,
return_type = return_type,
callees = callees,
rec_callees = rec_callees,
definition = definition,
mono_thm = mono_thm,
corres_thm = corres_thm,
is_simpl_wrapper = is_simpl_wrapper,
invented_body = invented_body };
fun to f { name,
phase,
const,
raw_const,
args,
return_type,
callees,
rec_callees,
definition,
mono_thm,
corres_thm,
is_simpl_wrapper,
invented_body } =
f name phase const raw_const args return_type callees rec_callees
definition mono_thm corres_thm is_simpl_wrapper invented_body;
fun update x = makeUpdate13 (from, from', to) x;
in
fun function_info_upd_name name pinfo = update pinfo (U#name name) $$;
fun function_info_upd_phase phase pinfo = update pinfo (U#phase phase) $$;
fun function_info_upd_const_ const pinfo = update pinfo (U#const const) $$;
fun function_info_upd_raw_const_ raw_const pinfo = update pinfo (U#raw_const raw_const) $$;
fun function_info_upd_args args pinfo = update pinfo (U#args args) $$;
fun function_info_upd_return_type return_type pinfo = update pinfo (U#return_type return_type) $$;
fun function_info_upd_callees callees pinfo = update pinfo (U#callees callees) $$;
fun function_info_upd_rec_callees rec_callees pinfo = update pinfo (U#rec_callees rec_callees) $$;
fun function_info_upd_definition definition pinfo = update pinfo (U#definition definition) $$;
fun function_info_upd_mono_thm mono_thm pinfo = update pinfo (U#mono_thm mono_thm) $$;
fun function_info_upd_corres_thm corres_thm pinfo = update pinfo (U#corres_thm corres_thm) $$;
fun function_info_upd_is_simpl_wrapper is_simpl_wrapper pinfo = update pinfo (U#is_simpl_wrapper is_simpl_wrapper) $$;
fun function_info_upd_invented_body invented_body pinfo = update pinfo (U#invented_body invented_body) $$;
fun function_info_upd_const t = function_info_upd_const_ t o function_info_upd_raw_const_ (head_of t);
end;
fun is_function_recursive { rec_callees, ... } = not (Symset.is_empty rec_callees);
fun all_callees { rec_callees, callees, ... } = Symset.union rec_callees callees;
type call_graph_info = {
topo_sorted_functions : symset list,
const_to_function : string Termtab.table,
recursive_group_of : symset Symtab.table
};
fun calc_call_graph fn_infos = let
val const_to_function =
Symtab.dest fn_infos
|> map (fn (name, info) => (#raw_const info, name))
|> Termtab.make;
(* Get a function's direct callees, based on the list of constants that appear
* in its definition. *)
fun get_direct_callees fn_info = let
val body =
#definition fn_info
|> Thm.concl_of
|> Utils.rhs_of_eq;
in
(* Ignore function bodies if we are using SIMPL wrappers. *)
if #is_simpl_wrapper fn_info then [] else
Term.fold_aterms (fn t => fn a =>
(Termtab.lookup const_to_function t
|> Option.map single
|> the_default []) @ a) body []
|> distinct (op =)
end;
(* Call graph of all functions. *)
val fn_callees_lists = fn_infos |> Symtab.map (K get_direct_callees);
(* Add each function to its own callees to get a complete inverse *)
val fn_callers_lists = flip_symtab (Symtab.map cons fn_callees_lists);
val topo_sorted_functions =
Topo_Sort.topo_sort {
cmp = String.compare,
graph = Symtab.lookup fn_callees_lists #> the,
converse = Symtab.lookup fn_callers_lists #> the
} (Symtab.keys fn_callees_lists |> sort String.compare)
|> map Symset.make;
val fn_callees = Symtab.map (K Symset.make) fn_callees_lists;
fun is_recursive_singleton f =
Symset.contains (Utils.the' ("is_recursive_singleton: " ^ f)
(Symtab.lookup fn_callees f)) f;
val recursive_group_of =
topo_sorted_functions
|> maps (fn f_group =>
(* Exclude non-recursive functions *)
if Symset.card f_group = 1 andalso not (is_recursive_singleton (hd (Symset.dest f_group)))
then []
else Symset.dest f_group ~~ replicate (Symset.card f_group) f_group)
|> Symtab.make;
(* Now update callee info. *)
fun maybe_symset NONE = Symset.empty
| maybe_symset (SOME x) = x;
val fn_infos' =
fn_infos |> Symtab.map (fn f => let
val (rec_callees, callees) =
Symset.dest (Utils.the' ("not in fn_callees: " ^ f) (Symtab.lookup fn_callees f))
|> List.partition (Symset.contains (maybe_symset (Symtab.lookup recursive_group_of f)));
in function_info_upd_callees (Symset.make callees) o
function_info_upd_rec_callees (Symset.make rec_callees) end);
in ({ topo_sorted_functions = topo_sorted_functions,
const_to_function = const_to_function,
recursive_group_of = recursive_group_of
}, fn_infos')
end;
fun init_function_info ctxt filename = let
val thy = Proof_Context.theory_of ctxt;
val prog_info = ProgramInfo.get_prog_info ctxt filename;
val csenv = #csenv prog_info;
(* Get information about a single function. *)
fun gen_fn_info name (return_ctype, _, carg_list) = let
(* Convert C Parser return type into a HOL return type. *)
val return_type =
if return_ctype = Absyn.Void then
@{typ unit}
else
CalculateState.ctype_to_typ (thy, return_ctype);
(* Convert arguments into a list of (name, HOL type) pairs. *)
val arg_list = map (fn v =>
(ProgramAnalysis.get_mname v |> MString.dest,
CalculateState.ctype_to_typ (thy, ProgramAnalysis.get_vi_type v))
) carg_list;
(*
* Get constant, type signature and definition of the function.
*
* The definition may not exist if the function is declared "extern", but
* never defined. In this case, we replace the body of the function with
* what amounts to a "fail" command. Any C body is a valid refinement of
* this, allowing our abstraction to succeed.
*)
val const = Utils.get_term ctxt (name ^ "_'proc");
val myvars_typ = #state_type prog_info;
val (definition, invented) =
(Proof_Context.get_thm ctxt (name ^ "_body_def"), false)
handle ERROR _ =>
(Thm.instantiate ([((("'a", 0), ["HOL.type"]), Thm.ctyp_of ctxt myvars_typ)], [])
@{thm undefined_function_body_def}, true);
in {
name = name,
phase = CP,
args = arg_list,
return_type = return_type,
const = const,
raw_const = const,
callees = Symset.empty, (* filled in later *)
rec_callees = Symset.empty,
definition = definition,
mono_thm = NONE,
corres_thm = @{thm TrueI},
is_simpl_wrapper = false,
invented_body = invented
}
end
val raw_infos = ProgramAnalysis.get_fninfo csenv
|> Symtab.dest
|> map (uncurry gen_fn_info);
(* We discard the call graph info here.
* After calling init_function_info, we often want to change some of the entries,
* which usually requires recalculating it anyway. *)
val (_, fn_infos) =
calc_call_graph (Symtab.make (map (fn info => (#name info, info)) raw_infos));
in
fn_infos
end;
end; (* structure FunctionInfo *)
(* Save function information into the theory. *)
structure AutoCorresFunctionInfo = Theory_Data(
type T = FunctionInfo.fn_info Symtab.table;
val empty = Symtab.empty;
val extend = I;
fun merge (l, r) =
Symtab.merge (fn _ => true) (l, r);
)

View File

@ -75,10 +75,10 @@ fun convert_local_vars name_map term [] = ([], term)
convert_local_vars name_map term vars
(* Get the set of variables a function accepts and returns. *)
fun get_fn_input_output_vars prog_info fn_info fn_name =
fun get_fn_input_output_vars prog_info l1_infos fn_name =
let
val fn_def = FunctionInfo.get_phase_info fn_info FunctionInfo.L1 fn_name
val inputs = #args fn_def |> Varset.make
val fn_info = the (Symtab.lookup l1_infos fn_name);
val inputs = #args fn_info |> Varset.make;
(* Get the return type of a function. *)
val return_ctype =
@ -90,15 +90,15 @@ let
empty_set
else
make_set [(NameGeneration.return_var_name return_ctype |> MString.dest,
#return_type fn_def)]
#return_type fn_info)]
in
(inputs, outputs)
end
(* Get the return variable of a particular function. *)
fun get_ret_var prog_info fn_info fn_name =
fun get_ret_var prog_info l1_infos fn_name =
let
val (_, outputs) = get_fn_input_output_vars prog_info fn_info fn_name
val (_, outputs) = get_fn_input_output_vars prog_info l1_infos fn_name
in
hd ((Varset.dest outputs) @ [("void", @{typ unit})])
end
@ -427,7 +427,7 @@ fun l1call_function_const t = case strip_comb t |> apsnd rev of
* In particular, we break down the structure of the program and parse the
* usage of local variables in all expressions and modifies clauses.
*)
fun parse_l1 ctxt prog_info fn_info name_map term =
fun parse_l1 ctxt prog_info l1_infos l1_call_info name_map term =
case term of
(Const (@{const_name "L1_skip"}, _)) =>
Modify (term,
@ -445,12 +445,12 @@ fun parse_l1 ctxt prog_info fn_info name_map term =
end
| (Const (@{const_name "L1_seq"}, _) $ lhs $ rhs) =>
Seq (term, parse_l1 ctxt prog_info fn_info name_map lhs,
parse_l1 ctxt prog_info fn_info name_map rhs)
Seq (term, parse_l1 ctxt prog_info l1_infos l1_call_info name_map lhs,
parse_l1 ctxt prog_info l1_infos l1_call_info name_map rhs)
| (Const (@{const_name "L1_catch"}, _) $ lhs $ rhs) =>
Catch (term, parse_l1 ctxt prog_info fn_info name_map lhs,
parse_l1 ctxt prog_info fn_info name_map rhs)
Catch (term, parse_l1 ctxt prog_info l1_infos l1_call_info name_map lhs,
parse_l1 ctxt prog_info l1_infos l1_call_info name_map rhs)
| (Const (@{const_name "L1_guard"}, _) $ c) =>
let
@ -468,8 +468,8 @@ fun parse_l1 ctxt prog_info fn_info name_map term =
val (read_vars, is_globals_reader, parsed_expr) = parse_expr ctxt prog_info name_map cond
in
Condition (term, (parsed_expr, make_set read_vars, is_globals_reader),
parse_l1 ctxt prog_info fn_info name_map lhs,
parse_l1 ctxt prog_info fn_info name_map rhs)
parse_l1 ctxt prog_info l1_infos l1_call_info name_map lhs,
parse_l1 ctxt prog_info l1_infos l1_call_info name_map rhs)
end
| (Const (@{const_name "L1_call"}, L1_call_type)
@ -488,13 +488,13 @@ fun parse_l1 ctxt prog_info fn_info name_map term =
(* Get the name of the variable the return value of the function will
* be placed into. *)
val dest_fn = FunctionInfo.get_function_from_const fn_info
(l1call_function_const dest_fn_term)
val dest_fn = Termtab.lookup (#const_to_function l1_call_info) (l1call_function_const dest_fn_term)
|> Utils.the' ("Unknown function " ^ quote (@{make_string} dest_fn_term))
|> Symtab.lookup l1_infos |> the
val dest_fn_name = #name dest_fn
(* Parse the return arguments. *)
val ret_var = get_ret_var prog_info fn_info (#name dest_fn)
val ret_var = get_ret_var prog_info l1_infos (#name dest_fn)
val parsed_clause =
parse_modify ctxt prog_info name_map (betapply (ret_extract, Free ("_dummy_state", #state_type prog_info)))
|> map (fn (target_var, read_vars, globals_read, expr) =>
@ -517,7 +517,7 @@ fun parse_l1 ctxt prog_info fn_info name_map term =
val (read_vars, is_globals_reader, parsed_expr) = parse_expr ctxt prog_info name_map cond;
in
While (term, (parsed_expr, make_set read_vars, is_globals_reader),
parse_l1 ctxt prog_info fn_info name_map body)
parse_l1 ctxt prog_info l1_infos l1_call_info name_map body)
end
| (Const (@{const_name "L1_init"}, _) $ setter) =>
@ -538,7 +538,7 @@ fun parse_l1 ctxt prog_info fn_info name_map term =
Fail term
| (Const (@{const_name "L1_recguard"}, _) $ var $ body) =>
RecGuard (term, parse_l1 ctxt prog_info fn_info name_map body)
RecGuard (term, parse_l1 ctxt prog_info l1_infos l1_call_info name_map body)
| other => Utils.invalid_term "a L1 term" other
@ -802,7 +802,8 @@ fun inject_return_vals ctxt prog_info name_map needed_returns allow_excess throw
fun do_conv
(ctxt : Proof.context)
prog_info
fn_info
(l1_infos : FunctionInfo2.function_info Symtab.table)
(l1_call_info : FunctionInfo2.call_graph_info)
name_map
(fn_vars : varset)
(callee_proofs : (bool * term * thm) Symtab.table)
@ -929,9 +930,11 @@ in
(* Convert LHS and RHS. *)
val ret_vars = rhs_live INTER lhs_modified
val (lhs_reads, lhs_rets, new_lhs, lhs_thm)
= do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs ret_vars true throw_vars lhs
= do_conv ctxt prog_info l1_infos l1_call_info name_map
fn_vars callee_proofs ret_vars true throw_vars lhs
val (rhs_reads, rhs_rets, new_rhs, rhs_thm)
= do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs needed_vars allow_excess throw_vars rhs
= do_conv ctxt prog_info l1_infos l1_call_info name_map
fn_vars callee_proofs needed_vars allow_excess throw_vars rhs
val block_reads = lhs_reads UNION (rhs_reads MINUS lhs_modified)
(* Reconstruct body to support our input tuple. *)
@ -964,9 +967,11 @@ in
(* Convert LHS and RHS. *)
val lhs_throws = rhs_live INTER lhs_modified
val (lhs_reads, lhs_rets, new_lhs, lhs_thm)
= do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs (needed_vars) false lhs_throws lhs
= do_conv ctxt prog_info l1_infos l1_call_info name_map
fn_vars callee_proofs (needed_vars) false lhs_throws lhs
val (rhs_reads, _, new_rhs, rhs_thm)
= do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs (needed_vars) false throw_vars rhs
= do_conv ctxt prog_info l1_infos l1_call_info name_map
fn_vars callee_proofs (needed_vars) false throw_vars rhs
val block_reads = lhs_reads UNION (rhs_reads MINUS lhs_throws)
(* Reconstruct body to support our input tuple. *)
@ -994,7 +999,8 @@ in
let
(* Convert body. *)
val (body_reads, vars_returned, new_body, body_thm) =
do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs needed_vars false throw_vars body
do_conv ctxt prog_info l1_infos l1_call_info name_map
fn_vars callee_proofs needed_vars false throw_vars body
(* Get recguard variable. *)
val (_ $ var $ _) = l1_term
@ -1014,9 +1020,11 @@ in
(* Convert LHS and RHS. *)
val requested_vars = needed_vars INTER modified_vars
val (lhs_reads, _, new_lhs, lhs_thm)
= do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs requested_vars false throw_vars lhs
= do_conv ctxt prog_info l1_infos l1_call_info name_map
fn_vars callee_proofs requested_vars false throw_vars lhs
val (rhs_reads, _, new_rhs, rhs_thm)
= do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs requested_vars false throw_vars rhs
= do_conv ctxt prog_info l1_infos l1_call_info name_map
fn_vars callee_proofs requested_vars false throw_vars rhs
val block_reads = lhs_reads UNION rhs_reads UNION read_vars
(* Generate the final term. *)
@ -1033,7 +1041,8 @@ in
(* Convert body. *)
val loop_iterators = (needed_vars UNION live_vars) INTER modified_vars
val (body_reads, _, new_body, body_thm) =
do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs loop_iterators false throw_vars body
do_conv ctxt prog_info l1_infos l1_call_info name_map
fn_vars callee_proofs loop_iterators false throw_vars body
val (body_term, _, body_modifies) = get_node_data body
(* Reconstruct body to support our input tuple. *)
@ -1082,8 +1091,8 @@ in
| _ => raise TERM ("local_var_extract: strange function call", [dest_fn])
(* Get destination function. *)
val dest_fn = FunctionInfo.get_function_from_const fn_info
(l1call_function_const dest_fn)
val dest_fn = Termtab.lookup (#const_to_function l1_call_info) (l1call_function_const dest_fn)
|> Option.mapPartial (Symtab.lookup l1_infos)
(* Lookup the callee proof, if it exists. *)
val callee_proof = Option.mapPartial
@ -1102,7 +1111,7 @@ in
| SOME (is_recursive, callee_free, callee_thm) =>
(let
(* Get information about the function. *)
val dest_fn = FunctionInfo.get_1_phase_info (the dest_fn) FunctionInfo.L1
val dest_fn = the (dest_fn)
val args = #args dest_fn
(* Parse argument setup. *)
@ -1205,12 +1214,12 @@ end
val measureT = @{typ nat};
(* Get the expected type of a function from its name. *)
fun get_expected_l2_fn_type prog_info fn_info fn_name =
fun get_expected_l2_fn_type prog_info l1_infos fn_name =
let
val fn_def = FunctionInfo.get_phase_info fn_info FunctionInfo.L1 fn_name
val fn_params_typ = measureT :: map snd (#args fn_def)
val fn_info = the (Symtab.lookup l1_infos fn_name)
val fn_params_typ = measureT :: map snd (#args fn_info)
in
fn_params_typ ---> mk_l2monadT (#globals_type prog_info) (#return_type fn_def) @{typ unit}
fn_params_typ ---> mk_l2monadT (#globals_type prog_info) (#return_type fn_info) @{typ unit}
end
(* Avoid clashes with fixed free variables such as C-parser's "symbol_table".
@ -1218,29 +1227,29 @@ end
fun to_free_var_name lthy x = if Variable.is_fixed lthy x then to_free_var_name lthy (x ^ "'") else x
(* Get arguments passed into the function. *)
fun get_expected_l2_fn_args lthy prog_info fn_info fn_name =
fun get_expected_l2_fn_args lthy prog_info l1_infos fn_name =
let
val fn_def = FunctionInfo.get_phase_info fn_info FunctionInfo.L1 fn_name
val fn_def = the (Symtab.lookup l1_infos fn_name)
in
map (apfst (to_free_var_name lthy o ProgramInfo.demangle_name)) (#args fn_def)
end
fun get_expected_l2_fn_thm prog_info fn_info ctxt fn_name fn_free fn_args _ measure_var =
fun get_expected_l2_fn_thm prog_info l1_infos ctxt fn_name fn_free fn_args _ measure_var =
let
(* Fetch input/output params for monad type. *)
val (input_params, output_params) = get_fn_input_output_vars prog_info fn_info fn_name
val (input_params, output_params) = get_fn_input_output_vars prog_info l1_infos fn_name
(* Get mapping from internal variable names that we use to the names passed
* in "fn_args". *)
val fn_def = FunctionInfo.get_phase_info fn_info FunctionInfo.L1 fn_name
val args = map fst (#args fn_def)
val fn_info = the (Symtab.lookup l1_infos fn_name)
val args = map fst (#args fn_info)
val m = Symtab.make (args ~~ fn_args)
fun name_map (n, _) = Symtab.lookup m n |> the
in
mk_corresXF_prop (Proof_Context.theory_of ctxt) prog_info name_map
output_params empty_set input_params
(betapplys (fn_free, measure_var :: fn_args))
(betapply (#const fn_def, measure_var))
(betapply (#const fn_info, measure_var))
end
(* Extract the abstract body of a L2corres theorem. *)
@ -1249,17 +1258,17 @@ fun get_body_of_thm ctxt thm =
|> HOLogic.dest_Trueprop
|> dest_L2corres_term_abs
fun get_l2corres_thm ctxt prog_info fn_info do_opt trace_opt fn_name
fun get_l2corres_thm ctxt prog_info l1_infos l1_call_info do_opt trace_opt fn_name
callee_terms fn_args l1_term init_unfold = let
(* Get information about the return variable. *)
val fn_def = FunctionInfo.get_phase_info fn_info FunctionInfo.L1 fn_name
val fn_info = the (Symtab.lookup l1_infos fn_name)
(* Get return variables. *)
val (fn_input_vars, fn_ret_vars) = get_fn_input_output_vars prog_info fn_info fn_name
val (fn_input_vars, fn_ret_vars) = get_fn_input_output_vars prog_info l1_infos fn_name
(* Get mapping from internal variable names to external arguments. *)
val m = Symtab.make (map fst (#args fn_def) ~~ fn_args)
val m = Symtab.make (map fst (#args fn_info) ~~ fn_args)
fun name_map_ext (n, T) = Symtab.lookup m n |> the
fun name_map_internal (n, T) = Free ("lvar'" ^ n, T)
@ -1285,7 +1294,7 @@ fun get_l2corres_thm ctxt prog_info fn_info do_opt trace_opt fn_name
|> Utils.unsafe_unvarify
(* Do basic parsing. *)
val parsed_term = parse_l1 ctxt prog_info fn_info name_map_internal source_term
val parsed_term = parse_l1 ctxt prog_info l1_infos l1_call_info name_map_internal source_term
(* Get a list of all variables either read from or written to. *)
val all_vars = Prog.fold_prog
@ -1318,7 +1327,7 @@ fun get_l2corres_thm ctxt prog_info fn_info do_opt trace_opt fn_name
(* Ensure that the only live variables at the beginning of the function are
* those that are function inputs. *)
val fn_inputs = get_node_data liveness_data
val fn_params = #args fn_def
val fn_params = #args fn_info
val excess_inputs = fn_inputs MINUS (make_set fn_params)
val _ =
if excess_inputs <> empty_set then
@ -1329,8 +1338,9 @@ fun get_l2corres_thm ctxt prog_info fn_info do_opt trace_opt fn_name
()
(* Do the conversion. *)
val (vars_read, _, term, thm) = do_conv ctxt prog_info fn_info name_map_internal fn_input_vars
callee_terms fn_ret_vars false empty_set input_term
val (vars_read, _, term, thm) =
do_conv ctxt prog_info l1_infos l1_call_info name_map_internal fn_input_vars
callee_terms fn_ret_vars false empty_set input_term
(* Replace our internal terms with external terms. *)
val replacements = (map name_map_internal fn_params) ~~ (map name_map_ext fn_params)
@ -1396,14 +1406,14 @@ end
* Note that this is also used for subsequent L2-based phases.
*)
fun l2_monad_mono lthy (fn_infos: FunctionInfo.phase_info Symtab.table) =
fun l2_monad_mono lthy (l2_infos: FunctionInfo2.function_info Symtab.table) =
let
(*
* For the induction, we need to have the form
* "\<And> m. (ALL a b... f m a b...) /\ (ALL a b... g m a b...) /\ ..."
* and this gets annoying pretty quickly. But it is probably unavoidable.
*)
val (fn_names, fn_defs) = split_list (Symtab.dest fn_infos);
val (fn_names, fn_defs) = split_list (Symtab.dest l2_infos);
val measure = Free ("rec_measure'", measureT)
fun make_mono_step_stmt current_def =
let
@ -1469,13 +1479,13 @@ in
end
(* For functions that are not translated, just generate a trivial wrapper. *)
fun mk_l2corres_call_simpl_thm prog_info fn_info ctxt fn_name fn_args = let
val fn_def = FunctionInfo.get_phase_info fn_info FunctionInfo.L1 fn_name
fun mk_l2corres_call_simpl_thm prog_info l1_infos ctxt fn_name fn_args = let
val fn_def = the (Symtab.lookup l1_infos fn_name)
val const = #const fn_def
val args = #args fn_def
(* Get return variables. *)
val (fn_input_vars, fn_ret_vars) = get_fn_input_output_vars prog_info fn_info fn_name
val (fn_input_vars, fn_ret_vars) = get_fn_input_output_vars prog_info l1_infos fn_name
(* Get mapping from internal variable names to external arguments. *)
val m = Symtab.make (map fst args ~~ fn_args)
fun name_map_ext (n, T) = Symtab.lookup m n |> the
@ -1502,57 +1512,50 @@ fun mk_l2corres_call_simpl_thm prog_info fn_info ctxt fn_name fn_args = let
fun convert
(lthy: local_theory) (* must contain at least L1 callee defs, but no other requirements *)
(prog_info: ProgramInfo.prog_info)
(fn_info: FunctionInfo.fn_info) (* legacy for aux data, won't actually contain L1 data *)
(* L1 data for f and callees, may include extra funcs *)
(l1_info: (FunctionInfo.phase_info * thm) Symtab.table)
(l1_infos: FunctionInfo2.function_info Symtab.table)
(do_opt: bool)
(trace_opt: bool)
(l2_function_name: string -> string)
(f_name: string)
: thm * (string * typ) list = let
val callee_names = FunctionInfo.get_function_callees fn_info f_name;
val _ = filter (fn f => not (isSome (Symtab.lookup l1_info f))) callee_names
(* FIXME: refactor? *)
val (l1_call_info, l1_infos) = FunctionInfo2.calc_call_graph l1_infos;
val f_info = Utils.the' ("L2 conversion missing info for " ^ f_name)
(Symtab.lookup l1_infos f_name);
val callee_names = FunctionInfo2.all_callees f_info;
val _ = filter (fn f => not (isSome (Symtab.lookup l1_infos f))) (Symset.dest callee_names)
|> (fn bad => if null bad then () else
error ("L2 conversion missing callees for " ^ f_name ^ ": " ^ commas bad));
(* TODO sanity check:
- the required L1 defs exist *)
(* Place the L1 data where the existing code expects it. *)
val fn_info' = fn_info |> FunctionInfo.add_phases (fn f =>
K (Symtab.lookup l1_info f |> Option.map fst));
(* Fix measure variable. *)
val ([measure_var_name], lthy') = Variable.variant_fixes ["rec_measure'"] lthy;
val measure_var = Free (measure_var_name, measureT);
(* Fix argument variables. *)
val (f_l1_info, f_l1corres) = the (Symtab.lookup l1_info f_name);
val f_args = #args f_l1_info;
val f_args = #args f_info;
val (arg_names, lthy'') = Variable.variant_fixes (map fst f_args) lthy';
val arg_frees = arg_names ~~ map snd f_args;
(* Add callee assumptions. Note that our define code has to use the same assumption order. *)
val group = Symset.make (FunctionInfo.get_recursive_group fn_info f_name);
val callee_terms =
map (fn callee => (callee, Symset.contains group callee)) callee_names;
val (lthy''', export_thm, callee_terms) =
AutoCorresUtil.assume_called_functions_corres lthy''
callee_terms
(get_expected_l2_fn_type prog_info fn_info')
(get_expected_l2_fn_thm prog_info fn_info')
(get_expected_l2_fn_args lthy prog_info fn_info')
AutoCorresUtil2.assume_called_functions_corres lthy''
(#callees f_info) (#rec_callees f_info)
(get_expected_l2_fn_type prog_info l1_infos)
(get_expected_l2_fn_thm prog_info l1_infos)
(get_expected_l2_fn_args lthy prog_info l1_infos)
l2_function_name
measure_var;
val f_l1_def = Utils.named_cterm_instantiate lthy'''
[("rec_measure'", Thm.cterm_of lthy''' measure_var)]
(#definition f_l1_info)
(#definition f_info)
val (thm, opt_traces) =
if #is_simpl_wrapper (FunctionInfo.get_function_info fn_info' f_name)
then (mk_l2corres_call_simpl_thm prog_info fn_info' lthy''' f_name (map Free arg_frees), [])
else get_l2corres_thm lthy''' prog_info fn_info' do_opt trace_opt f_name
if #is_simpl_wrapper f_info
then (mk_l2corres_call_simpl_thm prog_info l1_infos lthy''' f_name (map Free arg_frees), [])
else get_l2corres_thm lthy''' prog_info l1_infos l1_call_info do_opt trace_opt f_name
(Symtab.make callee_terms) (map Free f_args)
(betapply (#const f_l1_info, measure_var))
(betapply (#const f_info, measure_var))
f_l1_def;
in (Morphism.thm export_thm thm,
(* Provide the fixed vars so the user can generalize/instantiate them *)
@ -1566,11 +1569,11 @@ fun define
(lthy: local_theory)
(filename: string)
(prog_info: ProgramInfo.prog_info)
(fn_info: FunctionInfo.fn_info) (* required to have L1 *)
(l2_callees: (FunctionInfo.phase_info * thm) Symtab.table) (* L2 callees & corres thms *)
(l1_infos: FunctionInfo2.function_info Symtab.table)
(l2_callees: FunctionInfo2.function_info Symtab.table)
(l2_function_name: string -> string)
(funcs: (string * thm * (string * typ) list) list) (* name, corres, arg frees *)
: (FunctionInfo.phase_info * thm) Symtab.table * local_theory = let
: FunctionInfo2.function_info Symtab.table * local_theory = let
(* FIXME: dedup with convert *)
(* FIXME: pass this from assume_called_functions_corres, etc. *)
@ -1582,7 +1585,7 @@ fun define
fun prepare_fn_body (fn_name, corres_thm, arg_frees) = let
val _ = @{trace} ("prepare_fn_body", fn_name, corres_thm);
val @{term_pat "Trueprop (L2corres _ _ _ _ ?body _)"} = Thm.concl_of corres_thm;
val (callees, recursive_callees) = AutoCorresUtil.get_callees fn_info fn_name;
val (callees, recursive_callees) = AutoCorresUtil2.get_callees l1_infos fn_name;
val calls = map (fn c => Var (guess_callee_var corres_thm c)) callees;
val recursive_calls = map (fn c => Var (guess_callee_var corres_thm c)) recursive_callees;
@ -1604,68 +1607,75 @@ fun define
|> fold lambda (rev calls);
in abs_body end;
val fn_info' = FunctionInfo.add_phases (fn f => K (Option.map fst (Symtab.lookup l2_callees f))) fn_info;
val funcs' = funcs |>
map (fn (name, thm, frees) =>
(name, (* FIXME: define_funcs needs this currently *)
(Thm.generalize ([], map fst frees) (Thm.maxidx_of thm + 1) thm,
prepare_fn_body (name, thm, frees), [])));
val _ = @{trace} ("L2.define", map (fn (name, (thm, body, _)) => (name, thm, Thm.cterm_of lthy body)) funcs');
val (corres_thms, (), lthy') =
AutoCorresUtil.define_funcs FunctionInfo.L2 filename fn_info'
prepare_fn_body (name, thm, frees))));
val _ = @{trace} ("L2.define", map (fn (name, (thm, body)) => (name, thm, Thm.cterm_of lthy body)) funcs');
val (new_thms, (), lthy') =
AutoCorresUtil2.define_funcs FunctionInfo2.L2 filename l1_infos
l2_function_name
(get_expected_l2_fn_type prog_info fn_info')
(get_expected_l2_fn_thm prog_info fn_info')
(get_expected_l2_fn_args lthy prog_info fn_info')
(get_expected_l2_fn_type prog_info l1_infos)
(get_expected_l2_fn_thm prog_info l1_infos)
(get_expected_l2_fn_args lthy prog_info l1_infos)
@{thm L2corres_recguard_0}
lthy (Symtab.map (K snd) l2_callees) ()
lthy (Symtab.map (K #corres_thm) l2_callees) ()
funcs';
val f_names = map (fn (name, _, _) => name) funcs;
val new_phases = map2 (fn f_name => fn corres => let
val old_phase = FunctionInfo.get_phase_info fn_info' FunctionInfo.L1 f_name;
val def = the (AutoCorresData.get_def (Proof_Context.theory_of lthy') filename "L2def" f_name);
val f_const = Utils.get_term lthy' (l2_function_name f_name);
in old_phase
|> FunctionInfo.phase_info_upd_phase FunctionInfo.L2
|> FunctionInfo.phase_info_upd_definition def
|> FunctionInfo.phase_info_upd_const f_const
|> FunctionInfo.phase_info_upd_mono_thm NONE (* TODO *)
end) f_names corres_thms;
val new_infos = Symtab.map (fn f_name => fn (const, def, corres_thm) => let
val old_info = the (Symtab.lookup l1_infos f_name);
in old_info
|> FunctionInfo2.function_info_upd_phase FunctionInfo2.L2
|> FunctionInfo2.function_info_upd_const const
|> FunctionInfo2.function_info_upd_definition def
|> FunctionInfo2.function_info_upd_corres_thm corres_thm
|> FunctionInfo2.function_info_upd_mono_thm NONE (* added later *)
end) new_thms;
(* FIXME: return traces *)
in (Symtab.make (f_names ~~ (new_phases ~~ corres_thms)), lthy') end;
in (new_infos, lthy') end;
fun symtab_merge tabs = maps Symtab.dest tabs |> Symtab.make;
fun symtab_merge allow_dups tabs =
maps Symtab.dest tabs
|> (if allow_dups then sort_distinct (fast_string_ord o apply2 fst) else I)
|> Symtab.make;
(*
* Translate all functions from L1 to L2 format.
*)
fun translate filename prog_info fn_info
fun translate filename prog_info
(* lazy results from L1 *)
(l1_results: (string * (local_theory * (FunctionInfo.phase_info * thm) Symtab.table) future) list)
(l1_results: (symset * (local_theory * FunctionInfo2.function_info Symtab.table) future) list)
do_opt trace_opt l2_function_name =
(* if there's nothing to translate, we won't have a lthy *)
if null l1_results then [] else
let
val funcs_to_translate = Symtab.keys (FunctionInfo.get_all_functions fn_info);
(* TODO: we should recalculate this from l1_results to take dead-code elim
* into account, but we'd need to do this lazily! *)
val function_groups = map fst l1_results;
(* Results for individual functions *)
val l1_results' = maps (fn (f_names, r) =>
Symset.dest f_names ~~ replicate (Symset.card f_names) r) l1_results;
val get_l1_result = let
val table = Symtab.make l1_results;
val table = Symtab.make l1_results';
in fn f => the' ("missing L1 lazy result for function: " ^ f) (Symtab.lookup table f) end;
(* All conversions can run in parallel.
* Each conversion depends only on the corresponding L1 define phase
* (which necessarily also includes L1 callees). *)
val converted_funcs =
funcs_to_translate |> map (fn f =>
maps Symset.dest function_groups
|> map (fn f =>
(f, Future.fork (fn _ => let
(* L1 info for f and its callees *)
val (r_lthy, f_l1_info) = Future.join (get_l1_result f);
val callee_l1_infos = FunctionInfo.get_function_callees fn_info f
val callee_l1_infos = Symset.dest (FunctionInfo2.all_callees (the (Symtab.lookup f_l1_info f)))
|> map (fn callee => snd (Future.join (get_l1_result callee)));
val l1_infos = symtab_merge (f_l1_info :: callee_l1_infos);
in convert r_lthy prog_info fn_info l1_infos do_opt trace_opt l2_function_name f end)))
val l1_infos = symtab_merge true (f_l1_info :: callee_l1_infos);
in convert r_lthy prog_info l1_infos do_opt trace_opt l2_function_name f end)))
|> Symtab.make;
(* Definitions update lthy sequentially.
@ -1681,20 +1691,19 @@ let
val conv = the' ("didn't convert function: " ^ quote f ^ "??") (Symtab.lookup converted_funcs f);
val (corres_thm, arg_frees) = Future.join conv;
in (f, corres_thm, arg_frees) end) f_names;
(* Add L1 phase results. *)
(* Get L1 phase results (should be a no-op at this point) *)
val (_, l1_infos) = Future.join (get_l1_result (hd f_names));
val fn_info' = FunctionInfo.add_phases (fn f => K (Symtab.lookup l1_infos f |> Option.map fst)) fn_info;
val (new_defs, lthy') = define lthy filename prog_info fn_info'
(defined_so_far: (FunctionInfo.phase_info * thm) Symtab.table)
val (new_defs, lthy') = define lthy filename prog_info l1_infos
(defined_so_far: FunctionInfo2.function_info Symtab.table)
l2_function_name f_convs;
in (lthy', new_defs, Symtab.merge (K false) (defined_so_far, new_defs)) end);
val function_groups = FunctionInfo.get_topo_sorted_functions fn_info;
(* Get initial lthy from end of L1 defs *)
val l1_def_context = Future.map (fn (lthy, _) => (lthy, Symtab.empty, Symtab.empty))
(snd (hd (rev l1_results)));
(* Chain of intermediate states: (lthy, new_defs, accumulator) *)
val (def_results, _) = Utils.accumulate add_def l1_def_context function_groups;
val (def_results, _) = Utils.accumulate add_def l1_def_context
(map Symset.dest function_groups);
(* Produce a mapping from each function group to its L1 phase_infos and the
* earliest intermediate lthy where it is defined. *)
@ -1703,14 +1712,13 @@ let
|> map (Future.map (fn (lthy, f_defs, _ (* discard accum here *)) => let
(* Add monad_mono proofs. These are done in parallel as well
* (though in practice, they already run pretty quickly). *)
val mono_thms = if FunctionInfo.is_function_recursive fn_info (hd (Symtab.keys f_defs))
then l2_monad_mono lthy (Symtab.map (K fst) f_defs)
val mono_thms = if FunctionInfo2.is_function_recursive (snd (hd (Symtab.dest f_defs)))
then l2_monad_mono lthy f_defs
else Symtab.empty;
val f_defs' = f_defs |> Symtab.map (fn f =>
apfst (FunctionInfo.phase_info_upd_mono_thm (Symtab.lookup mono_thms f)));
FunctionInfo2.function_info_upd_mono_thm (Symtab.lookup mono_thms f));
in (lthy, f_defs') end));
in
(* would be ready to become a symtab, but a list also preserves order *)
function_groups ~~ results
end

View File

@ -48,9 +48,9 @@ fun get_L1corres_monad @{term_pat "L1corres _ _ ?l1_monad _"} = l1_monad
*
* "Call foo_'proc"
*)
fun mk_SIMPL_call_term ctxt prog_info fn_info target_fn =
fun mk_SIMPL_call_term ctxt prog_info target_fn =
@{mk_term "Call ?proc :: (?'s, int, strictc_errortype) com" (proc, 's)}
(FunctionInfo.get_phase_info fn_info FunctionInfo.CP target_fn |> #const, #state_type prog_info)
(#const target_fn, #state_type prog_info)
(*
* Construct a correspondence lemma between a given monadic term and a SIMPL fragment.
@ -70,9 +70,9 @@ fun mk_L1corres_prop prog_info check_termination monad_term simpl_term =
* L1corres ct \<Gamma> <term> (Call foo_'proc)
*
*)
fun mk_L1corres_call_prop ctxt prog_info fn_info check_termination target_fn_name term =
fun mk_L1corres_call_prop ctxt prog_info check_termination target_fn term =
mk_L1corres_prop prog_info check_termination term
(mk_SIMPL_call_term ctxt prog_info fn_info target_fn_name)
(mk_SIMPL_call_term ctxt prog_info target_fn)
|> HOLogic.mk_Trueprop
(*
@ -83,7 +83,8 @@ fun mk_L1corres_call_prop ctxt prog_info fn_info check_termination target_fn_nam
*)
fun simpl_conv'
(prog_info : ProgramInfo.prog_info)
(fn_info : FunctionInfo.fn_info)
(simpl_defs : FunctionInfo2.function_info Symtab.table)
(simpl_calls : FunctionInfo2.call_graph_info)
(ctxt : Proof.context)
(callee_terms : (bool * term * thm) Symtab.table)
(measure_var : term)
@ -91,7 +92,7 @@ fun simpl_conv'
let
fun prove_term subterms base_thm result_term =
let
val subterms' = map (simpl_conv' prog_info fn_info ctxt
val subterms' = map (simpl_conv' prog_info simpl_defs simpl_calls ctxt
callee_terms measure_var) subterms;
val converted_terms = map fst subterms';
val subproofs = map snd subterms';
@ -173,7 +174,7 @@ fun simpl_conv'
| (Const (@{const_name call}, _) $ a $ (fn_const as Const (b, _)) $ c $ (Abs (_, _, Abs (_, _, (Const (@{const_name Basic}, _) $ d))))) =>
let
val state_type = #state_type prog_info
val target_fn_name = FunctionInfo.get_function_from_const fn_info fn_const |> Option.map #name
val target_fn_name = Termtab.lookup (#const_to_function simpl_calls) fn_const
in
case Option.mapPartial (Symtab.lookup callee_terms) target_fn_name of
NONE =>
@ -194,7 +195,9 @@ fun simpl_conv'
* and we can just give an arbitrary value.
*)
val target_fn_name = (the target_fn_name)
val target_rec = FunctionInfo.is_function_recursive fn_info target_fn_name
val target_fn = Utils.the' ("missing SIMPL def for " ^ target_fn_name)
(Symtab.lookup simpl_defs target_fn_name)
val target_rec = FunctionInfo2.is_function_recursive target_fn
val term' =
if is_rec then
term $ (@{term "recguard_dec"} $ measure_var)
@ -296,10 +299,10 @@ end
* about raw definitions instead of more abstract constructs generated
* by the C parser.
*)
fun get_simpl_body ctxt fn_info fn_name =
fun get_simpl_body ctxt simpl_defs fn_name =
let
(* Find the definition of the given function. *)
val simpl_thm = #definition (FunctionInfo.get_phase_info fn_info FunctionInfo.CP fn_name)
val simpl_thm = #definition (the (Symtab.lookup simpl_defs fn_name))
handle ERROR _ => raise FunctionNotFound fn_name;
(* Unfold terms in the body which we don't want to deal with. *)
@ -323,10 +326,11 @@ in
(unfolded_simpl_term, unfolded_simpl_thm, impl_thm)
end
fun get_l1corres_thm prog_info fn_info check_termination ctxt do_opt trace_opt fn_name
callee_terms measure_var = let
fun get_l1corres_thm prog_info simpl_defs simpl_calls check_termination ctxt do_opt trace_opt
fn_name callee_terms measure_var = let
val fn_def = the (Symtab.lookup simpl_defs fn_name);
val thy = Proof_Context.theory_of ctxt
val (simpl_term, simpl_thm, impl_thm) = get_simpl_body ctxt fn_info fn_name
val (simpl_term, simpl_thm, impl_thm) = get_simpl_body ctxt simpl_defs fn_name
(* Fetch stats on pre-converted term. *)
val _ = Statistics.gather ctxt "CParser" fn_name simpl_term
@ -335,7 +339,7 @@ fun get_l1corres_thm prog_info fn_info check_termination ctxt do_opt trace_opt f
* Do the conversion. We receive a new monadic version of the SIMPL
* term and a tactic for proving correspondence.
*)
val (monad, tactic) = simpl_conv' prog_info fn_info ctxt
val (monad, tactic) = simpl_conv' prog_info simpl_defs simpl_calls ctxt
callee_terms measure_var simpl_term
(*
@ -343,7 +347,7 @@ fun get_l1corres_thm prog_info fn_info check_termination ctxt do_opt trace_opt f
* failure when the measure reaches zero. This lets us automatically
* prove termination of the recursive function.
*)
val is_recursive = FunctionInfo.is_function_recursive fn_info fn_name
val is_recursive = FunctionInfo2.is_function_recursive fn_def
val (monad, tactic) =
if is_recursive then
(Utils.mk_term thy @{term "L1_recguard"} [measure_var, monad],
@ -356,7 +360,7 @@ fun get_l1corres_thm prog_info fn_info check_termination ctxt do_opt trace_opt f
* SIMPL body (with folded constants) and the output monad term.
*)
in
mk_L1corres_call_prop ctxt prog_info fn_info check_termination fn_name monad
mk_L1corres_call_prop ctxt prog_info check_termination (the (Symtab.lookup simpl_defs fn_name)) monad
|> Thm.cterm_of ctxt
|> Goal.init
|> (case impl_thm of
@ -380,7 +384,7 @@ fun split_conj thm =
handle THM _ => [thm]
(* Prove monad_mono for recursive functions. *)
fun l1_monad_mono lthy (l1_defs : FunctionInfo.phase_info Symtab.table) =
fun l1_monad_mono lthy (l1_defs : FunctionInfo2.function_info Symtab.table) =
let
val l1_defs' = Symtab.dest l1_defs;
fun mk_stmt [func] = @{mk_term "monad_mono ?f" f} func
@ -412,10 +416,9 @@ end
(* For functions that are not translated, just generate a trivial wrapper. *)
fun mk_l1corres_call_simpl_thm fn_info check_termination ctxt fn_name = let
val info = FunctionInfo.get_phase_info fn_info FunctionInfo.CP fn_name
val const = #const info
val impl_thm = Proof_Context.get_thm ctxt (fn_name ^ "_impl")
fun mk_l1corres_call_simpl_thm check_termination ctxt simpl_def = let
val const = #const simpl_def
val impl_thm = Proof_Context.get_thm ctxt (#name simpl_def ^ "_impl")
val gamma = safe_mk_meta_eq impl_thm |> Thm.concl_of |> Logic.dest_equals
|> fst |> (fn (f $ _) => f | t => raise TERM ("gamma", [t]))
val thm = Utils.named_cterm_instantiate ctxt
@ -435,14 +438,18 @@ fun mk_l1corres_call_simpl_thm fn_info check_termination ctxt fn_name = let
fun convert
(lthy: local_theory)
(prog_info: ProgramInfo.prog_info)
(fn_info: FunctionInfo.fn_info) (* needs CP phase_info for each callee *)
(simpl_defs: FunctionInfo2.function_info Symtab.table)
(check_termination: bool)
(do_opt: bool)
(trace_opt: bool)
(l1_function_name: string -> string)
(f_name: string)
: thm * (string * typ) list = let
val callee_names = FunctionInfo.get_function_callees fn_info f_name;
(* FIXME: refactor? *)
val (simpl_calls, simpl_defs) = FunctionInfo2.calc_call_graph simpl_defs;
val f_info = Utils.the' ("SimplConv: missing SIMPL def for " ^ f_name) (Symtab.lookup simpl_defs f_name);
val callee_names = FunctionInfo2.all_callees f_info;
(* TODO sanity check:
- all SIMPL defs exist *)
@ -453,19 +460,19 @@ fun convert
(* L1corres for f's callees. *)
fun get_l1_fn_assumption ctxt fn_name free _ _ measure_var =
mk_L1corres_call_prop ctxt prog_info fn_info check_termination fn_name (betapply (free, measure_var));
mk_L1corres_call_prop ctxt prog_info check_termination
(Utils.the' ("SimplConv: missing callee def for " ^ fn_name)
(Symtab.lookup simpl_defs fn_name)) (betapply (free, measure_var));
(* Fix measure variable. *)
val ([measure_var_name], lthy') = Variable.variant_fixes ["rec_measure'"] lthy;
val measure_var = Free (measure_var_name, measureT);
(* Add callee assumptions. Note that our define code has to use the same assumption order. *)
val group = Symset.make (FunctionInfo.get_recursive_group fn_info f_name);
val callee_terms =
map (fn callee => (callee, Symset.contains group callee)) callee_names;
val group = Option.getOpt (Symtab.lookup (#recursive_group_of simpl_calls) f_name, Symset.empty);
val (lthy'', export_thm, callee_terms) =
AutoCorresUtil.assume_called_functions_corres lthy'
callee_terms
AutoCorresUtil2.assume_called_functions_corres lthy'
(#callees f_info) (#rec_callees f_info)
(K l1_fn_type)
get_l1_fn_assumption
(K [])
@ -473,9 +480,9 @@ fun convert
measure_var;
val (thm, opt_traces) =
if #is_simpl_wrapper (FunctionInfo.get_function_info fn_info f_name)
then (mk_l1corres_call_simpl_thm fn_info check_termination lthy'' f_name, [])
else get_l1corres_thm prog_info fn_info check_termination lthy'' do_opt trace_opt f_name
if #is_simpl_wrapper f_info
then (mk_l1corres_call_simpl_thm check_termination lthy'' f_info, [])
else get_l1corres_thm prog_info simpl_defs simpl_calls check_termination lthy'' do_opt trace_opt f_name
(Symtab.make callee_terms) measure_var;
in (Morphism.thm export_thm thm,
(* Provide the fixed vars so the user can generalize/instantiate them *)
@ -484,17 +491,18 @@ fun convert
(* Define a previously-converted function (or recursive function group).
* lthy must include all definitions from l1_callees *)
* lthy must include all definitions from l1_callees.
* simpl_defs must include current function set and its immediate callees. *)
fun define
(lthy: local_theory)
(filename: string)
(prog_info: ProgramInfo.prog_info)
(fn_info: FunctionInfo.fn_info) (* doesn't need to have L1 info *)
(simpl_defs: FunctionInfo2.function_info Symtab.table)
(check_termination: bool)
(l1_callees: (FunctionInfo.phase_info * thm) Symtab.table) (* L1 callees & corres thms *)
(l1_callees: FunctionInfo2.function_info Symtab.table)
(l1_function_name: string -> string)
(funcs: (string * thm * (string * typ) list) list) (* name, corres, arg frees *)
: (FunctionInfo.phase_info * thm) Symtab.table * local_theory = let
: FunctionInfo2.function_info Symtab.table * local_theory = let
(* FIXME: dedup with convert *)
(* All L1 functions have the same signature: measure \<Rightarrow> L1_monad *)
@ -503,7 +511,8 @@ fun define
(* L1corres for f's callees. *)
fun get_l1_fn_assumption ctxt fn_name free _ _ measure_var =
mk_L1corres_call_prop ctxt prog_info fn_info check_termination fn_name (betapply (free, measure_var));
mk_L1corres_call_prop ctxt prog_info check_termination
(the (Symtab.lookup simpl_defs fn_name)) (betapply (free, measure_var));
(* FIXME: pass this from assume_called_functions_corres, etc. *)
fun guess_callee_var thm callee = let
@ -513,9 +522,11 @@ fun define
fun prepare_fn_body (fn_name, corres_thm, measure_free) = let
val @{term_pat "Trueprop (L1corres _ _ ?body _)"} = Thm.concl_of corres_thm;
val (callees, recursive_callees) = AutoCorresUtil.get_callees fn_info fn_name;
val calls = map (fn c => Var (guess_callee_var corres_thm c)) callees;
val recursive_calls = map (fn c => Var (guess_callee_var corres_thm c)) recursive_callees;
val fn_info = the (Symtab.lookup simpl_defs fn_name);
val calls = map (fn c => Var (guess_callee_var corres_thm c))
(Symset.dest (#callees fn_info));
val recursive_calls = map (fn c => Var (guess_callee_var corres_thm c))
(Symset.dest (#rec_callees fn_info));
(*
* The returned body will have free variables as placeholders for the function's
* measure parameter and other arguments, and schematic variables for the functions it calls.
@ -534,31 +545,29 @@ fun define
|> fold lambda (rev calls);
in abs_body end;
val fn_info' = FunctionInfo.add_phases (fn f => K (Option.map fst (Symtab.lookup l1_callees f))) fn_info;
val funcs' = funcs |>
map (fn (name, thm, frees) =>
(name, (* FIXME: define_funcs needs this currently *)
(Thm.generalize ([], map fst frees) (Thm.maxidx_of thm + 1) thm,
prepare_fn_body (name, thm, frees), [])));
val _ = @{trace} ("L1.define", map (fn (name, (thm, body, _)) => (name, thm, Thm.cterm_of lthy body)) funcs');
val (corres_thms, (), lthy') =
AutoCorresUtil.define_funcs FunctionInfo.L1 filename fn_info'
prepare_fn_body (name, thm, frees))));
val _ = @{trace} ("L1.define", map (fn (name, (thm, body)) => (name, thm, Thm.cterm_of lthy body)) funcs');
val (new_thms, (), lthy') =
AutoCorresUtil2.define_funcs FunctionInfo2.L1 filename simpl_defs
l1_function_name (K l1_fn_type) get_l1_fn_assumption (K []) @{thm L1corres_recguard_0}
lthy (Symtab.map (K snd) l1_callees) ()
lthy (Symtab.map (K #corres_thm) l1_callees) ()
funcs';
val f_names = map (fn (name, _, _) => name) funcs;
val new_phases = map2 (fn f_name => fn corres => let
val old_phase = FunctionInfo.get_phase_info fn_info FunctionInfo.CP f_name;
val def = the (AutoCorresData.get_def (Proof_Context.theory_of lthy') filename "L1def" f_name);
val f_const = Utils.get_term lthy' (l1_function_name f_name);
in old_phase
|> FunctionInfo.phase_info_upd_phase FunctionInfo.L1
|> FunctionInfo.phase_info_upd_definition def
|> FunctionInfo.phase_info_upd_const f_const
|> FunctionInfo.phase_info_upd_mono_thm NONE (* done in translate *)
end) f_names corres_thms;
val new_defs = Symtab.map (fn f_name => fn (const, def, corres_thm) => let
val f_info = the (Symtab.lookup simpl_defs f_name);
in f_info
|> FunctionInfo2.function_info_upd_phase FunctionInfo2.L1
|> FunctionInfo2.function_info_upd_definition def
|> FunctionInfo2.function_info_upd_corres_thm corres_thm
|> FunctionInfo2.function_info_upd_const const
|> FunctionInfo2.function_info_upd_mono_thm NONE (* done in translate *)
end) new_thms;
(* FIXME: return traces *)
in (Symtab.make (f_names ~~ (new_phases ~~ corres_thms)), lthy') end;
in (new_defs, lthy') end;
(*
@ -574,15 +583,15 @@ fun define
* code).
*)
(* FIXME: use AutoCorresUtil.Future instead of default *)
fun translate filename prog_info fn_info check_termination do_opt trace_opt l1_function_name lthy =
fun translate filename prog_info simpl_infos check_termination do_opt trace_opt l1_function_name lthy =
let
val funcs_to_translate = Symtab.keys (FunctionInfo.get_all_functions fn_info);
val funcs_to_translate = Symtab.keys simpl_infos;
(* All conversions can run in parallel. *)
val converted_funcs =
funcs_to_translate |> map (fn f =>
(f, Future.fork (fn _ =>
convert lthy prog_info fn_info check_termination do_opt trace_opt l1_function_name f)))
convert lthy prog_info simpl_infos check_termination do_opt trace_opt l1_function_name f)))
|> Symtab.make;
(* Definitions update lthy sequentially.
@ -600,14 +609,15 @@ let
val conv = the' ("didn't convert function: " ^ quote f ^ "??") (Symtab.lookup converted_funcs f);
val (corres_thm, arg_frees) = Future.join conv;
in (f, corres_thm, arg_frees) end) f_names;
val (new_defs, lthy') = define lthy filename prog_info fn_info check_termination
(defined_so_far: (FunctionInfo.phase_info * thm) Symtab.table)
l1_function_name f_convs;
val (new_defs, lthy') = define lthy filename prog_info simpl_infos check_termination
defined_so_far l1_function_name f_convs;
in (lthy', new_defs, Symtab.merge (K false) (defined_so_far, new_defs)) end);
val function_groups = FunctionInfo.get_topo_sorted_functions fn_info;
val (simpl_calls, _) = FunctionInfo2.calc_call_graph simpl_infos;
val function_groups = #topo_sorted_functions simpl_calls;
(* Chain of intermediate states: (lthy, new_defs, accumulator) *)
val (def_results, _) = Utils.accumulate add_def (Future.value (lthy, Symtab.empty, Symtab.empty)) function_groups;
val (def_results, _) = Utils.accumulate add_def (Future.value (lthy, Symtab.empty, Symtab.empty))
(map Symset.dest function_groups);
(* Produce a mapping from each function group to its L1 phase_infos and the
* earliest intermediate lthy where it is defined. *)
@ -616,14 +626,13 @@ let
|> map (Future.map (fn (lthy, f_defs, _ (* discard accum here *)) => let
(* Add monad_mono proofs. These are done in parallel as well
* (though in practice, they already run pretty quickly). *)
val mono_thms = if FunctionInfo.is_function_recursive fn_info (hd (Symtab.keys f_defs))
then l1_monad_mono lthy (Symtab.map (K fst) f_defs)
val mono_thms = if FunctionInfo2.is_function_recursive (snd (hd (Symtab.dest f_defs)))
then l1_monad_mono lthy f_defs
else Symtab.empty;
val f_defs' = f_defs |> Symtab.map (fn f =>
apfst (FunctionInfo.phase_info_upd_mono_thm (Symtab.lookup mono_thms f)));
FunctionInfo2.function_info_upd_mono_thm (Symtab.lookup mono_thms f));
in (lthy, f_defs') end));
in
(* would be ready to become a symtab, but a list also preserves order *)
function_groups ~~ results
end

View File

@ -24,47 +24,44 @@ declare gets_theE_L2_while [ts_rule option del]
context type_strengthen begin
ML \<open>
let val fn_info = FunctionInfo.init_fn_info @{context} "type_strengthen.c"
FunctionInfo2.init_function_info @{context} "type_strengthen.c"
|> Symtab.dest
|> map (fn (f, info) => (f, Symset.dest (#callees info), Symset.dest (#rec_callees info)))
\<close>
ML \<open>
let val simpl_infos = FunctionInfo2.init_function_info @{context} "type_strengthen.c"
val prog_info = ProgramInfo.get_prog_info @{context} "type_strengthen.c"
val (corres1, frees1) = SimplConv2.convert @{context} prog_info fn_info true true false (fn f => "l1_" ^ f) "opt_j";
val (corres2, frees2) = SimplConv2.convert @{context} prog_info fn_info true true false (fn f => "l1_" ^ f) "st_i";
val (corres1, frees1) = SimplConv2.convert @{context} prog_info simpl_infos true true false (fn f => "l1_" ^ f) "opt_j";
val (corres2, frees2) = SimplConv2.convert @{context} prog_info simpl_infos true true false (fn f => "l1_" ^ f) "st_i";
(*val thm' = Thm.generalize ([], map fst frees) (Thm.maxidx_of thm + 1) thm*)
val lthy0 = @{context};
val (l1_infos1, lthy1) =
SimplConv2.define lthy0 "type_strengthen.c" prog_info fn_info true
SimplConv2.define lthy0 "type_strengthen.c" prog_info simpl_infos true
Symtab.empty
(fn f => "l1_" ^ f)
[("opt_j", corres1, frees1)]
[("opt_j", corres1, frees1)];
val _ = @{trace} ("l1_infos1", Symtab.dest l1_infos1);
val (l1_infos2, lthy2) =
SimplConv2.define lthy1 "type_strengthen.c" prog_info fn_info true
SimplConv2.define lthy1 "type_strengthen.c" prog_info simpl_infos true
l1_infos1
(fn f => "l1_" ^ f)
[("st_i", corres2, frees2)]
[("st_i", corres2, frees2)];
in (frees1, corres1, Symtab.dest l1_infos2) end
\<close>
ML \<open>
let val filename = "type_strengthen.c";
val fn_info = FunctionInfo.init_fn_info @{context} filename;
val simpl_info = FunctionInfo2.init_function_info @{context} filename;
val prog_info = ProgramInfo.get_prog_info @{context} filename;
val l1_results =
SimplConv2.translate filename prog_info fn_info
SimplConv2.translate filename prog_info simpl_info
true true false (fn f => "l1_" ^ f ^ "'") @{context};
(*
val l2_results =
LocalVarExtract2.translate filename prog_info fn_info l1_results
LocalVarExtract2.translate filename prog_info l1_results
true false (fn f => "l2_" ^ f ^ "'");
*)
in l1_results |> map (snd #> Future.join) |> map (snd #> Symtab.dest) end
in l2_results |> map (snd #> Future.join) |> map (snd #> Symtab.dest) end
\<close>
ML \<open>
FunctionInfo.is_function_recursive (FunctionInfo.init_fn_info @{context} "type_strengthen.c") "opt_a"
\<close>
ML \<open>
FunctionInfo.get_topo_sorted_functions (FunctionInfo.init_fn_info @{context} "type_strengthen.c")
\<close>
end
declare [[ML_print_depth=99]]