lh-l4v/tools/autocorres/autocorres_util.ML
Gerwin Klein 0fbe82511d isabelle2021-1: AutoCorres
After these changes AutoCorres type checks and compiles, and the proofs
work, but for most test cases we still get runtime exceptions.

Signed-off-by: Gerwin Klein <gerwin.klein@proofcraft.systems>
2022-03-29 08:38:25 +11:00

(*
* Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
*
* SPDX-License-Identifier: BSD-2-Clause
*)
(*
* Common code for all translation phases: defining funcs, calculating dependencies,
* variable fixes, etc.
*)
(*
* Here is a brief explanation of how most AutoCorres phases work
* with each other.
*
* AutoCorres's phases are L1, L2, HL, WA and TS. (TS doesn't share these
* utils for historical reasons; fixing that is another story.)
* Basically, each of L1, L2, HL and WA:
* 1. takes a list of input function specs;
* 2. converts each function individually;
* 3. defines each new function (or recursive function group).
* This updates the local theory sequentially;
* 4. proves monad_mono theorems and places them into the
* output list of functions;
* 5. outputs a list of new function specs in the original format.
*
* === Concurrency ===
* To support concurrent processing better, we do not use lists.
* Instead, we use a future-chained sequence (FSeq, below) so that
* define and convert steps can be done in parallel (up to the
* dependencies between them, of course).
*
* (We do not use a plain list of futures because a define
* step may produce one or more function groups, so we can't
* know how many groups there will be in advance. See the
* recursive group splitting comment for define_funcs_sequence.)
*
* Additionally, AutoCorres is structured so that conversions
* do not require the most up-to-date local theory, so we also
* output a stream of intermediate local theories. This allows
* conversions of phase N+1 to be pipelined with define steps of
* phase N.
*
* FunctionInfo.phase_results is the uniform sequence type that
* most AutoCorres translation phases adhere to.
*
* === (2) Conversion ===
* Converting a function starts by assuming correspondence theorems
* for all the functions that it calls (including itself, if
* recursive). We invent free variables to stand for those functions;
* see assume_called_functions_corres.
*
* Because it's fiddly to have these assumptions everywhere,
* we use Assumption to hide them in the thm hyps during conversion.
* When done, we export the assumptions using Morphism.thm.
*
* After performing these conversions, we get a corres theorem
* with corres assumptions for called functions (along with other
* auxiliary info). These are generally packaged into a
* convert_result record.
*
* The conversions are all independent, so we launch them in
* topological order; see par_convert. This is the most convenient
* order because each conversion takes place between the previous and next
* define steps, which already require topological order.
*
* === (3) Definition ===
* We take the sequence of conversion results and define each
* function (or recursive group) in the theory.
*
* Each function group and its convert_results are processed by
* define_funcs. Conventionally, AutoCorres phases provide a
* "define" wrapper that sets up the required inputs to define_funcs
* and constructs function_infos for the newly defined functions.
*
* There is also a high-level wrapper, define_funcs_sequence, that
* calls these "define" wrappers in the correct order.
* It also splits recursive groups after defining them; see its
* documentation for details.
*
* === (4) Corollaries ===
* Currently, each phase only proves one type of corollary,
* monad_mono theorems. These proofs are duplicated in the source
* of the individual phases (this should be fixed) and do not make
* use of the utils here.
*
* === Incremental mode support ===
* AutoCorres supports incremental translation, which means that
* we need to insert previously-translated function data at the
* appropriate places. The par_convert and define_funcs_sequence
* wrappers take "existing_foo" arguments and ensure that these
* are available to the per-phase convert and define steps.
*)
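(*
* Sketch of how a phase might wire these utilities together. This is not
* part of the original source: "convert", "define", "add_trace" and
* "initial_defined" are hypothetical names, and the argument order follows
* the signatures of par_convert and define_funcs_sequence below.
*
*   val converted =
*     par_convert convert existing_infos prev_results add_trace
*   val results =
*     define_funcs_sequence lthy define existing_infos initial_defined converted
*
* Here "results" is again a FunctionInfo.phase_results, so it can serve as
* prev_results for the next phase.
*)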
structure AutoCorresUtil =
struct
(*
* Maximum time to let an individual function translation phase run.
*
* Note that this is wall time, and not CPU time, so it is a very rough
* tool.
* FIXME: convert to proper option
*)
val max_run_time = Unsynchronized.ref NONE
(*
val max_run_time = Unsynchronized.ref (SOME (seconds 900.0))
*)
exception AutocorresTimeout of string list
fun time_limit f v =
case !max_run_time of
SOME t =>
(Timeout.apply t f ()
handle Timeout.TIMEOUT _ =>
raise AutocorresTimeout v)
| NONE =>
f ()
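(*
* Usage sketch (not part of the original source; "convert_one" and
* "fn_name" are hypothetical names). The first argument is a thunk and the
* second is the string list carried by AutocorresTimeout, so a timeout can
* be reported against the function being translated:
*
*   val result = time_limit (fn () => convert_one lthy fn_name) [fn_name]
*
* With max_run_time set to NONE, as above, this simply evaluates the thunk.
*)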
(* Should we use concurrency? *)
val concurrent = Unsynchronized.ref true;
(*
* Conditionally fork a task, depending on the value
* of "concurrent".
*)
fun maybe_fork f = if !concurrent then Future.fork f else Future.value (f ());
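(*
* Usage sketch (not part of the original source; "expensive_step" is a
* hypothetical function). The result is a future in both modes, so callers
* join it the same way regardless of the "concurrent" flag:
*
*   val fut = maybe_fork (fn () => expensive_step ())
*   val res = Future.join fut
*
* When "concurrent" is false, the thunk has already been evaluated by the
* time maybe_fork returns, and the join just unwraps the value.
*)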
(*
* Get functions called by a particular function.
*
* We split the result into standard calls and recursive calls (i.e., calls
* which may recursively call back into us).
*)
fun get_callees fn_infos fn_name =
let
val fn_info = the (Symtab.lookup fn_infos fn_name)
in
(Symset.dest (#callees fn_info), Symset.dest (#rec_callees fn_info))
end
(* Measure variables are currently hardcoded as nats. *)
val measureT = @{typ nat};
(*
* Assume theorems for called functions.
*
* A new context is returned with the assumptions in it, along with a morphism
* used for exporting the theorems out, and a list of the functions assumed:
*
* (<function name>, (<is_mutually_recursive>, <function free>, <function thm>))
*
* In this context, the theorems refer to functions by fixed free variables.
*
* get_fn_args may return user-friendly argument names that clash with other names.
* We will process these names to avoid conflicts.
*
* get_fn_assumption should produce the desired theorems to assume. Its arguments:
* context (with fixed vars), callee name, callee term, arg terms, is recursive, measure term
* (all terms are fixed free vars).
*
* get_const_name generates names for the free function placeholders.
* FIXME: probably unnecessary and/or broken.
*
* We return a morphism that exposes the assumptions and generalises over the
* assumed constants.
* FIXME: automatically generalising over the assumed free variables is
* probably broken. Instead, the caller should manually generalise and
* instantiate the frees to avoid clashes.
*)
fun assume_called_functions_corres ctxt callees rec_callees
get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var =
let
(* Assume the existence of a function, along with a theorem about its
* behaviour. *)
fun assume_func ctxt fn_name is_recursive_call =
let
val fn_args = get_fn_args fn_name
(* Fix a variable for the function. *)
val ([fixed_fn_name], ctxt') = Variable.variant_fixes [get_const_name fn_name] ctxt
val fn_free = Free (fixed_fn_name, get_fn_type fn_name)
(* Fix a variable for the measure and function arguments. *)
val (measure_var_name :: arg_names, ctxt'')
= Variable.variant_fixes ("rec_measure'" :: (map fst fn_args)) ctxt'
val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args))
val my_measure_var = Free (measure_var_name, @{typ nat})
(*
* A measure variable is needed to handle recursion: for recursive calls,
* we need to decrement the caller's input measure value (and our
* assumption will need to assume this too). This is so we can later prove
* termination of our function definition: the measure always reaches zero.
*
* Non-recursive calls can have a fresh value.
*)
val measure_var =
if is_recursive_call then
@{const "recguard_dec"} $ callers_measure_var
else
my_measure_var
(* Create our assumption. *)
val assumption =
get_fn_assumption ctxt'' fn_name fn_free fn_arg_terms
is_recursive_call measure_var
|> fold Logic.all (rev ((if is_recursive_call then [] else [my_measure_var]) @ fn_arg_terms))
|> Sign.no_vars ctxt'
|> Thm.cterm_of ctxt'
val ([thm], ctxt''') = Assumption.add_assumes [assumption] ctxt'
(* Generate a morphism for escaping this context. *)
val m = (Assumption.export_morphism ctxt''' ctxt')
$> (Variable.export_morphism ctxt' ctxt)
in
(fn_free, thm, ctxt''', m)
end
(* Add assumptions: recursive calls first, matching the order in define_functions *)
val (res, (ctxt', m)) = fold_map (
fn (fn_name, is_recursive_call) =>
fn (ctxt, m) =>
let
val (free, thm, ctxt', m') =
assume_func ctxt fn_name is_recursive_call
in
((fn_name, (is_recursive_call, free, thm)), (ctxt', m' $> m))
end)
(map (fn f => (f, false)) (Symset.dest callees) @
map (fn f => (f, true)) (Symset.dest rec_callees))
(ctxt, Morphism.identity)
in
(ctxt', m, res)
end
(* Determine which functions are called by a code fragment.
* Only function terms in callee_consts are used. *)
fun get_body_callees
(callee_consts: string Termtab.table)
(body: term)
: symset =
Term.fold_aterms (fn t => fn a =>
(Termtab.lookup callee_consts t
|> Option.map single
|> the_default []) @ a)
body []
|> Symset.make;
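(*
* Illustrative sketch (not part of the original source; "f_term", "g_term",
* "x" and "y" are hypothetical terms). Only atomic subterms that appear as
* keys in callee_consts contribute to the result:
*
*   val consts = Termtab.make [(f_term, "f"), (g_term, "g")]
*   val used = get_body_callees consts (f_term $ x $ (f_term $ y))
*
* Here "used" should contain only "f": g_term does not occur in the body,
* and the two occurrences of f_term collapse into one set element.
*)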
(* Determine which recursive calls are actually used by a code fragment.
* This is used to make adjustments to recursive function groups
* between conversion and definition steps.
*
* callee_terms is a list of (callee name, (is_recursive, func const, thm))
* as provided by assume_called_functions_corres *)
fun get_rec_callees
(callee_terms: (string * (bool * term * thm)) list)
(body: term)
: symset = let
val callee_lookup =
callee_terms |> List.mapPartial (fn (callee, (is_rec, const, _)) =>
if is_rec then SOME (const, callee) else NONE)
|> Termtab.make;
in get_body_callees callee_lookup body end;
(*
* Given one or more function specs, define them and instantiate corres proofs.
*
* "callee_thms" contains corres theorems for already-defined functions.
*
* "fn_infos" is used to look up function callees. It is expected
* to consist of the previous translation output for the functions
* being defined, but may of course contain other entries.
*
* "functions" contains a list of (fn_name, (body, corres proof, arg_frees)).
* The body should be of the form generated by abstract_fn_body,
* with lambda abstractions for all callees and arguments.
*
* The given corres proof is expected to use the free variables in
* arg_frees for the function's arguments (including the measure variable,
* if there is one). It is also expected to use schematic
* variables for assumed callees.
* (FIXME: this interface should be simplified a bit.)
*
* We assume that all functions in this list are mutually recursive.
* (If not, you should call "define_funcs" multiple times, each
* time with a single function.)
*
* Returns the new function constants, definitions, final corres proofs,
* and local theory.
*)
fun define_funcs
(phase : FunctionInfo.phase)
(filename : string)
(fn_infos : FunctionInfo.function_info Symtab.table)
(get_const_name : string -> string)
(get_fn_type : string -> typ)
(get_fn_assumption : Proof.context -> string -> term -> term list -> bool -> term -> term)
(get_fn_args : string -> (string * typ) list)
(rec_base_case : thm)
(ctxt : Proof.context)
(callee_thms : thm Symtab.table)
(functions : (string * (term * thm * (string * typ) list)) list)
: (term * thm * thm) Symtab.table * Proof.context =
let
val fn_names = map fst functions
val fn_bodies = map (snd #> #1) functions
(* Generalise over the function arguments *)
val fn_thms = map (snd #> (fn (_, thm, frees) =>
(Thm.generalize (Names.empty, Names.make_set (map fst frees)) (Thm.maxidx_of thm + 1) thm)))
functions
val _ = writeln ("Defining (" ^ FunctionInfo.string_of_phase phase ^ ") " ^
(Utils.commas (map get_const_name fn_names)))
(*
* Determine if we are in a recursive case by checking to see if the
* first function in our list makes recursive calls to any other
* function. (This "other function" will be itself if it is simple
* recursion, but may be a different function if we are mutually
* recursive.)
*)
val is_recursive = FunctionInfo.is_function_recursive (the (Symtab.lookup fn_infos (hd fn_names)))
val _ = assert (length fn_names = 1 orelse is_recursive)
"define_funcs passed multiple functions, but they don't appear to be recursive."
(*
* Patch functions into our function body in the following order:
*
* * Non-recursive calls;
* * Recursive calls
*)
fun fill_body fn_name body =
let
val fn_info = the (Symtab.lookup fn_infos fn_name)
val non_rec_calls = map (fn x => Utils.get_term ctxt (get_const_name x)) (Symset.dest (#callees fn_info))
val rec_calls = map (fn x => Free (get_const_name x, get_fn_type x)) (Symset.dest (#rec_callees fn_info))
in
body
|> (fn t => betapplys (t, non_rec_calls))
|> (fn t => betapplys (t, rec_calls))
end
(*
* Define our functions.
*
* Definitions should be of the form:
*
* %arg1 arg2 arg3. (arg1 + arg2 + arg3)
*
* Mutually recursive calls should be of the form "Free (fn_name, fn_type)".
*)
val defs = map (
fn (fn_name, fn_body) => let
val fn_args = get_fn_args fn_name
(* FIXME: this retraces assume_called_functions_corres *)
val (fn_free :: measure_free :: arg_frees, _) = Variable.variant_fixes
(get_const_name fn_name :: "rec_measure'" :: map fst fn_args) ctxt
in (get_const_name fn_name, (* be inflexible when it comes to fn_name *)
(measure_free, @{typ nat}) :: (arg_frees ~~ map snd fn_args), (* changing arg names is ok *)
fill_body fn_name fn_body) end)
(fn_names ~~ fn_bodies)
val (fn_def_thms, ctxt) = Utils.define_functions defs true is_recursive ctxt
(*
* Instantiate schematic function calls in our theorems with their
* concrete definitions.
*)
val combined_callees = map (get_callees fn_infos) (map fst functions)
val combined_normal_calls =
map fst combined_callees |> flat |> sort_distinct fast_string_ord
val combined_recursive_calls =
map snd combined_callees |> flat |> sort_distinct fast_string_ord
val callee_terms =
(combined_recursive_calls @ combined_normal_calls)
|> map (fn x => (get_const_name x, Utils.get_term ctxt (get_const_name x)))
|> Symtab.make
fun fill_proofs thm =
Utils.instantiate_thm_vars ctxt
(fn ((name, _), _) =>
Symtab.lookup callee_terms name
|> Option.map (Thm.cterm_of ctxt)) thm
val fn_thms = map fill_proofs fn_thms
(* Fix free variable for the measure. *)
val ([measure_var_name], ctxt') = Variable.variant_fixes ["m"] ctxt
val measure_var = Free (measure_var_name, @{typ nat})
(* Generate corres predicates for each function. *)
val preds = map (
fn fn_name =>
let
fun mk_forall v t = HOLogic.all_const (Term.fastype_of v) $ lambda v t
val fn_const = Utils.get_term ctxt' (get_const_name fn_name)
(* Fetch parameters to this function. *)
val free_params =
get_fn_args fn_name
|> Variable.variant_frees ctxt' [measure_var]
|> map Free
in
(* Generate the prop. *)
get_fn_assumption ctxt' fn_name fn_const
free_params is_recursive measure_var
|> fold Logic.all (rev free_params)
end) fn_names
(* We generate a goal which solves all the mutually recursive calls simultaneously. *)
val goal = map (Object_Logic.atomize_term ctxt') preds
|> Utils.mk_conj_list
|> HOLogic.mk_Trueprop
|> Thm.cterm_of ctxt'
(* Prove each of the predicates above, leaving any assumptions about called
* functions unsolved. *)
val pred_thms = map (
fn (pred, thm, body_def) =>
Thm.trivial (Thm.cterm_of ctxt' pred)
|> Utils.apply_tac "unfold body" (Hypsubst.stac ctxt' body_def 1)
|> Utils.apply_tac "apply rule" (resolve_tac ctxt' [thm] 1)
|> Goal.norm_result ctxt
|> singleton (Variable.export ctxt' ctxt)
)
(Utils.zip3 preds fn_thms fn_def_thms)
(* Create a set of "helper theorems", which should be sufficient to discharge
* all assumptions that our callees refine. *)
val helper_thms =
(map (Symtab.lookup callee_thms #> the) combined_normal_calls) @ pred_thms
|> map (Thm.forall_intr_vars)
|> map (Conv.fconv_rule (Object_Logic.atomize ctxt))
(* Generate a proof term of equivalence using the folded definitions. *)
val new_thm =
Goal.init goal
|> (fn thm =>
if is_recursive then (
Utils.apply_tac "start induction"
(resolve_tac ctxt'
[Utils.named_cterm_instantiate ctxt'
[("n", Thm.cterm_of ctxt' measure_var)] @{thm recguard_induct}]
1) thm
|> Utils.apply_tac "unfold bodies"
(EVERY (map (fn x => (EqSubst.eqsubst_tac ctxt' [1] [x] 1)) (rev fn_def_thms)))
|> Utils.apply_tac "solve induction base cases"
(SOLVES ((simp_tac (put_simpset HOL_ss ctxt' addsimps [rec_base_case]) 1)))
|> Utils.apply_tac "solve remaining goals"
(Utils.metis_insert_tac ctxt helper_thms 1)
) else (
Utils.apply_tac "solve remaining goals"
(Utils.metis_insert_tac ctxt helper_thms 1) thm
))
|> Goal.finish ctxt'
(*
* The proof above is of the form (L1corres a & L1corres b & ...).
* Split it up into several proofs.
*)
fun prove_partial_corres thm pred =
Thm.cterm_of ctxt' pred
|> Goal.init
|> Utils.apply_tac "solving using metis" (Utils.metis_tac ctxt [thm] 1)
|> Goal.finish ctxt'
(* Generate the final theorems. *)
val new_thms =
map (prove_partial_corres new_thm) preds
|> (Variable.export ctxt' ctxt)
|> map (Goal.norm_result ctxt)
val results =
fn_names ~~ (new_thms ~~ fn_def_thms)
|> map (fn (fn_name, (corres_thm, def_thm)) =>
(* FIXME: ugly way to get the function constant *)
(fn_name, (Utils.get_term ctxt (get_const_name fn_name), def_thm, corres_thm)))
|> Symtab.make;
in
(results, ctxt)
end
(* Support function for incremental translation.
* This updates #callees of the functions we are translating, to include
* background functions that have already been translated.
* (We don't need to handle #rec_callees because recursive groups
* cannot be translated piecemeal.) *)
fun add_background_callees
(background: FunctionInfo.function_info Symtab.table)
: FunctionInfo.function_info Symtab.table ->
FunctionInfo.function_info Symtab.table = let
val bg_consts =
Symtab.dest background
|> map (fn (f, bg_info) => (#raw_const bg_info, f))
|> Termtab.make;
in Symtab.map (K (fn fn_info => let
val bg_callees = get_body_callees bg_consts (Thm.prop_of (#definition fn_info));
in FunctionInfo.function_info_upd_callees (Symset.union (#callees fn_info) bg_callees) fn_info end))
end;
(* Utility for doing conversions in parallel.
* The conversion of each function f should depend only on the previous
* define phase for f (which necessarily also includes f's callees). *)
type convert_result = {
body: term, (* new body *)
proof: thm, (* corres thm *)
rec_callees: symset, (* minimal rec_callees after translation *)
callee_consts: term Symtab.table, (* assumed frees for other callees *)
arg_frees: (string * typ) list, (* fixed argument frees, including measure *)
traces: (string * AutoCorresData.Trace) list (* traces *)
}
fun par_convert
(* Worker: lthy -> function_infos for func and callees -> func name -> results *)
(convert: local_theory -> FunctionInfo.function_info Symtab.table ->
string -> convert_result)
(* Functions from prior incremental translation *)
(existing_infos: FunctionInfo.function_info Symtab.table)
(* Functions from previous phase *)
(prev_results: FunctionInfo.phase_results)
(* Add traces from the conversion result. *)
(add_trace: string -> string -> AutoCorresData.Trace -> unit)
(* Return converted functions in recursive groups.
* The groups are tagged with fn_infos from prev_results to identify them. *)
: (FunctionInfo.function_info Symtab.table * convert_result Symtab.table) FSeq.fseq =
(* Knowing that prev_results is in topological order,
* we accumulate its function_infos, which will be a superset
* of the callee infos that each conversion requires. *)
FSeq.fold_map (fn fn_infos_accum => fn (lthy, fn_infos) => let
val fn_infos = add_background_callees existing_infos fn_infos;
val fn_infos_accum = Symtab.merge (K false) (fn_infos_accum, fn_infos);
(* Convert fn_infos in parallel, but join the results right away.
* This is fine because we will define them together. *)
val conv_results =
Symtab.dest fn_infos
|> Par_List.map (fn (f, _) =>
(f, convert lthy fn_infos_accum f))
|> Symtab.make;
val _ = app (fn (f, result) => app (fn (typ, trace) =>
add_trace f typ trace) (#traces result))
(Symtab.dest conv_results);
in ((fn_infos, conv_results), fn_infos_accum) end)
existing_infos prev_results;
(* Given a function body containing arguments and assumed function calls,
* abstract the code over those parameters.
*
* The returned body will have free variables as placeholders for the function's
* measure parameter and other arguments, as well as for the functions it calls.
*
* We modify the body to be of the form:
*
* %fun1 fun2 rec1 rec2 measure arg1 arg2. f <...>
*
* That is, all non-recursive calls are abstracted out the front, followed by
* recursive calls, followed by the measure variable, followed by function
* arguments. This is the format expected by define_funcs.
*)
fun abstract_fn_body
(prev_fn_infos: FunctionInfo.function_info Symtab.table)
(fn_name, {body, callee_consts, arg_frees, ...} : convert_result) = let
val (callees, rec_callees) = get_callees prev_fn_infos fn_name;
val calls = map (the o Symtab.lookup callee_consts) callees;
val rec_calls = map (the o Symtab.lookup callee_consts) rec_callees;
val abs_body = body
|> fold lambda (rev (map Free arg_frees))
|> fold lambda (rev rec_calls)
|> fold lambda (rev calls);
in abs_body end;
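(*
* Worked sketch of the abstraction order (not part of the original source;
* all names are hypothetical). For a function with one non-recursive callee
* "g", one recursive callee "r" and arg_frees = [m, x], the three folds in
* abstract_fn_body wrap the body from the inside out:
*
*   body                                 f <...>
*   after fold lambda (rev arg_frees)    %m x. f <...>
*   after fold lambda (rev rec_calls)    %r m x. f <...>
*   after fold lambda (rev calls)        %g r m x. f <...>
*
* Each list is reversed before folding so that its first element ends up as
* the outermost binder of its group.
*)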
(* Utility for defining functions.
*
* Definitions update the theory sequentially.
* Each definition step produces a lthy that contains the current function
* group, and can immediately be used in the next conversion phase for
* those functions. Hence we return the intermediate lthys as futures.
*
* The actual recursive function groups may be finer-grained than in
* converted_groups, as function calls can be removed by dead code
* elimination and other transformations. Hence we detect the actual
* function groups after defining them. *)
fun define_funcs_sequence
(lthy: local_theory)
(define_worker: local_theory ->
(* previous infos for functions *)
FunctionInfo.function_info Symtab.table ->
(* new infos for callees *)
FunctionInfo.function_info Symtab.table ->
(* data for functions *)
convert_result Symtab.table ->
(* new infos for functions *)
FunctionInfo.function_info Symtab.table * local_theory)
(* previous infos from prior translation (in incremental mode) *)
(existing_infos: FunctionInfo.function_info Symtab.table)
(* functions defined so far, initially populated from prior translation *)
(defined_so_far: FunctionInfo.function_info Symtab.table)
(converted_groups: (FunctionInfo.function_info Symtab.table *
convert_result Symtab.table) FSeq.fseq)
: FunctionInfo.phase_results = FSeq.mk (fn () =>
case FSeq.uncons converted_groups of
NONE => NONE
| SOME ((prev_infos, conv_group), remaining_groups) => SOME let
(* NB: we don't need to add_background_callees to prev_infos because
* par_convert already does that *)
(* Define the function group, then split into minimal recursive groups. *)
val (new_infos, lthy') =
define_worker lthy (Symtab.merge (K false) (prev_infos, existing_infos))
defined_so_far conv_group;
(* Minimise callees and split recursive group if needed. *)
val new_infoss = FunctionInfo.recalc_callees defined_so_far new_infos;
val defined_so_far = Symtab.merge (K false) (defined_so_far, new_infos);
(* We can't wrap the first result because we're already in FSeq.mk.
* Fortunately, one is guaranteed to exist *)
val new_infos1 :: new_infoss = new_infoss;
(* Output new group(s) in the result sequence. *)
val remaining_results =
define_funcs_sequence lthy' define_worker
existing_infos defined_so_far remaining_groups;
in
((lthy', new_infos1),
FSeq.append
(FSeq.of_list (map (fn defs => (lthy', defs)) new_infoss))
remaining_results)
end);
end