diff --git a/tools/autocorres/AutoCorres.thy b/tools/autocorres/AutoCorres.thy index 5579a2515..af904e20d 100644 --- a/tools/autocorres/AutoCorres.thy +++ b/tools/autocorres/AutoCorres.thy @@ -145,6 +145,11 @@ ML_file "monad_convert.ML" ML_file "type_strengthen.ML" ML_file "autocorres.ML" +declare [[ML_print_depth=42]] +ML_file "autocorres_util2.ML" +ML_file "simpl_conv2.ML" +ML_file "local_var_extract2.ML" + (* Setup "autocorres" keyword. *) ML {* Outer_Syntax.command @{command_keyword "autocorres"} @@ -153,4 +158,82 @@ ML {* (Toplevel.theory o (fn (opt, filename) => AutoCorres.do_autocorres opt filename))) *} +ML \ + +fun assume_called_functions_corres ctxt fn_info callees + get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var = +let + (* Assume the existence of a function, along with a theorem about its + * behaviour. *) + fun assume_func ctxt fn_name is_recursive_call = + let + val fn_args = get_fn_args fn_name + + (* Fix a variable for the function. *) + val ([fixed_fn_name], ctxt') = Variable.variant_fixes [get_const_name fn_name] ctxt + val fn_free = Free (fixed_fn_name, get_fn_type fn_name) + + (* Fix a variable for the measure and function arguments. *) + val (measure_var_name :: arg_names, ctxt'') + = Variable.variant_fixes ("rec_measure'" :: (map fst fn_args)) ctxt' + val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args)) + val my_measure_var = Free (measure_var_name, @{typ nat}) + + (* + * A measure variable is needed to handle recursion: for recursive calls, + * we need to decrement the caller's input measure value (and our + * assumption will need to assume this to). This is so we can later prove + * termination of our function definition: the measure always reaches zero. + * + * Non-recursive calls can have a fresh value. + *) + val measure_var = + if is_recursive_call then + @{const "recguard_dec"} $ callers_measure_var + else + my_measure_var + + (* Create our assumption. *) + val assumption = + get_fn_assumption ctxt'' fn_name fn_free fn_arg_terms + is_recursive_call measure_var + |> fold Logic.all (rev ((if is_recursive_call then [] else [my_measure_var]) @ fn_arg_terms)) + |> Sign.no_vars ctxt' + |> Thm.cterm_of ctxt' + val ([thm], ctxt''') = Assumption.add_assumes [assumption] ctxt' + + (* Generate a morphism for escaping this context. *) + val m = (Assumption.export_morphism ctxt''' ctxt') + $> (Variable.export_morphism ctxt' ctxt) + in + (fn_free, thm, ctxt''', m) + end + + (* Apply each assumption. *) + val (res, (ctxt', m)) = fold_map ( + fn (fn_name, is_recursive_call) => + fn (ctxt, m) => + let + val (free, thm, ctxt', m') = + assume_func ctxt fn_name is_recursive_call + in + ((fn_name, (is_recursive_call, free, thm)), (ctxt', m' $> m)) + end) + callees (ctxt, Morphism.identity) +in + (ctxt', m, res) +end; + +(* +fun assume_called_functions_corres ctxt fn_info callees + get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var +*) +assume_called_functions_corres @{context} () [("a", false), ("r", true)] + (K @{typ "nat \ nat \ string \ string"}) + (fn ctxt => fn name => fn term => fn args => fn is_rec => fn meas => + HOLogic.mk_Trueprop (@{term "my_corres :: (string \ string) \ bool"} $ betapplys (term, meas :: args))) + (fn f => if f = "a" then [("arg_a", @{typ nat})] else [("arg_r", @{typ nat})]) + I @{term "rec_measure :: nat"} +\ + end diff --git a/tools/autocorres/autocorres_util.ML b/tools/autocorres/autocorres_util.ML index 2de6fa703..722f84ebd 100644 --- a/tools/autocorres/autocorres_util.ML +++ b/tools/autocorres/autocorres_util.ML @@ -57,10 +57,36 @@ sig (* result *) -> (Proof.context * FunctionInfo.fn_info) + val define_funcs: + FunctionInfo.phase -> + string -> + FunctionInfo.fn_info -> + (string -> string) -> + (string -> typ) -> + (Proof.context -> string -> term -> term list -> bool -> term -> term) -> + (string -> (string * typ) list) -> + thm -> + Proof.context -> + thm Symtab.table -> + 'a -> + (string * (thm * term * (string * AutoCorresData.Trace) list)) list -> + thm list * 'a * Proof.context + val map_all : Proof.context -> FunctionInfo.fn_info -> (string -> FunctionInfo.function_info -> 'a) -> 'a list val concurrent : bool Unsynchronized.ref val has_simpl_body_def : local_theory -> string -> bool val max_run_time : Time.time option Unsynchronized.ref + + val assume_called_functions_corres : + Proof.context -> + (string * bool) list -> + (string -> typ) -> + (Proof.context -> string -> term -> term list -> bool -> term -> term) -> + (string -> (string * typ) list) -> + (string -> string) -> + term -> + Proof.context * morphism * (string * (bool * term * thm)) list + val get_callees : FunctionInfo.fn_info -> string -> string list * string list (* (nonrecs, recs) *) end; structure AutoCorresUtil : AUTOCORRES_UTIL = @@ -262,14 +288,32 @@ fun is_Trueprop (Const (@{const_name "Trueprop"}, _) $ _) = true | is_Trueprop _ = false (* - * Assume the existence of the given list of functions. + * Assume theorems for called functions. * * A new context is returned with the assumptions in it, along with a morphism * used for exporting the theorems out, and a list of the functions assumed: * - * (, (, , )) + * (, (, , , )) + * + * In this context, the theorems refer to functions by fixed free variables. + * + * get_fn_args may return user-friendly argument names that clash with other names. + * We will process these names to avoid conflicts. + * + * get_fn_assumption should produce the desired theorems to assume. Its arguments: + * context (with fixed vars), callee name, callee term, arg terms, is recursive, measure term + * (all terms are fixed free vars). + * + * get_const_name generates names for the free function placeholders. + * FIXME: probably unnecessary and/or broken. + * + * We return two morphisms: + * - the first one makes the assumptions visible again + * - the second one automatically generalizes the assumed constants + * (this exists for backwards compat; all new code should explicitly use the + * returned free variable set) *) -fun assume_called_functions_corres ctxt fn_info callees +fun assume_called_functions_corres ctxt callees get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var = let (* Assume the existence of a function, along with a theorem about its @@ -354,7 +398,7 @@ fun gen_corres_for_function (ctxt : Proof.context) (fn_name : string) = let - val _ = writeln ("Converting (" ^ FunctionInfo.string_of_phase phase ^") " ^ fn_name) + val _ = writeln ("Converting (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name) val start_time = Timer.startRealTimer () (* Get a list of functions we call. *) @@ -364,21 +408,30 @@ let (map (fn x => (x, false)) normal_calls) @ (map (fn x => (x, true)) recursive_calls) + (* Make sure the desired function name is available. *) + val fn_target_name = get_const_name fn_name + val ([fn_free], ctxt') = Variable.variant_fixes [fn_target_name] ctxt + val _ = if fn_free = fn_target_name then () else + warning ("Variable clobbered: " ^ fn_target_name ^ " -> " ^ fn_free ^ ". Translating " ^ + fn_name ^ " may fail.") + val fn_var_morph = Variable.export_morphism ctxt' ctxt + (* Fix a measure variable that will be used to track recursion progress. *) - val ([measure_var_name], ctxt') = Variable.variant_fixes ["rec_measure'"] ctxt + val ([measure_var_name], ctxt'') = Variable.variant_fixes ["rec_measure'"] ctxt' val measure_var = Free (measure_var_name, @{typ nat}) - val measure_var_morph = Variable.export_morphism ctxt' ctxt + val measure_var_morph = Variable.export_morphism ctxt'' ctxt' (* Fix variables for function arguments. *) val fn_args = get_fn_args fn_name - val (arg_names, ctxt'') - = Variable.variant_fixes (map (fn (a, _) => a ^ "'arg") fn_args) ctxt' + val (arg_names, ctxt''') + = Variable.variant_fixes (map (fn (a, _) => a ^ "'arg") fn_args) ctxt'' val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args)) - val fn_args_morph = Variable.export_morphism ctxt'' ctxt' + val fn_args_morph = Variable.export_morphism ctxt''' ctxt'' + val _ = @{trace} ("Vars", fn_free, measure_var_name, arg_names) (* Enter a context where we assume our callees exist. *) - val (ctxt''', m, callee_info_and_proofs) - = assume_called_functions_corres ctxt'' fn_info callees + val (ctxt'''', m, callee_info_and_proofs) + = assume_called_functions_corres ctxt''' callees get_fn_type get_fn_assumption get_fn_args get_const_name measure_var @@ -387,7 +440,7 @@ let * term and a tactic for proving correspondence. *) val callee_tab = Symtab.make callee_info_and_proofs - val (body, thm, trace) = convert ctxt''' fn_name callee_tab measure_var fn_arg_terms + val (body, thm, trace) = convert ctxt'''' fn_name callee_tab measure_var fn_arg_terms (* * The returned body will have free variables as placeholders for the function's @@ -408,7 +461,7 @@ let |> fold lambda (rev (map (fn x => Symtab.lookup callee_tab x |> the |> #2) normal_calls)) (* Export the theorem out of our context. *) - val exported_thm = Morphism.thm (m $> fn_args_morph $> measure_var_morph) thm + val exported_thm = Morphism.thm (m $> fn_args_morph $> measure_var_morph $> fn_var_morph) thm (* TODO: allow this message to be configured *) val _ = @{trace} ("Converted (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name ^ " in " ^ @@ -423,15 +476,15 @@ end * "callee_thms" contains a table mapping function names to complete * corres proofs for those functions. * - * "functions" contains a list of (fn_name, (proof, callees)). We - * assume that all functions in this list are mutually recursive. (If - * not, you should call "define_funcs" multiple times, each + * "functions" contains a list of (fn_name, (proof, body, proof traces)). + * The body should be of the form generated by gen_corres_for_function, + * with lambda abstractions for all callees and arguments. + * + * We assume that all functions in this list are mutually recursive. + * (If not, you should call "define_funcs" multiple times, each * time with a single function.) * - * This code is quite complex in order to support mutual recursion, - * where function definitions and proofs must simultaneously take place - * for several functions: if we were only supporting non-recursive - * functions, life would be easier. + * The proof traces are stored into the theory. (This should probably be moved.) *) fun define_funcs (phase : FunctionInfo.phase) @@ -444,7 +497,7 @@ fun define_funcs (rec_base_case : thm) (ctxt : Proof.context) (callee_thms : thm Symtab.table) - _ + accum (* translate allows an accumulator, but we don't use it here *) (functions : (string * (thm * term * (string * AutoCorresData.Trace) list)) list) = let @@ -455,6 +508,8 @@ fun define_funcs val _ = writeln ("Defining (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ (Utils.commas (map get_const_name fn_names))) + val _ = @{trace} ("define_funcs function(s)", functions) + val _ = @{trace} ("define_funcs callee(s)", Symtab.dest callee_thms) (* * Determine if we are in a recursive case by checking to see if the @@ -495,10 +550,14 @@ fun define_funcs * Mutually recursive calls should be of the form "Free (fn_name, fn_type)". *) val defs = map ( - fn (fn_name, fn_body) => - (get_const_name fn_name, - ("rec_measure'", @{typ nat}) :: get_fn_args fn_name, - fill_body fn_name fn_body)) + fn (fn_name, fn_body) => let + val fn_args = get_fn_args fn_name + (* FIXME: this retraces assume_called_functions_corres *) + val (fn_free :: measure_free :: arg_frees, _) = Variable.variant_fixes + (get_const_name fn_name :: "rec_measure'" :: map fst fn_args) ctxt + in (get_const_name fn_name, (* be inflexible when it comes to fn_name *) + (measure_free, @{typ nat}) :: (arg_frees ~~ map snd fn_args), (* changing arg names is ok *) + fill_body fn_name fn_body) end) (fn_names ~~ fn_bodies) val (fn_def_thms, ctxt) = Utils.define_functions defs true is_recursive ctxt @@ -560,12 +619,12 @@ fun define_funcs (* Prove each of the predicates above, leaving any assumptions about called * functions unsolved. *) val pred_thms = map ( - fn (pred, thm, body_def) => + fn (pred, thm, body_def) => (@{trace} ("define_funcs applying rule", thm); Thm.trivial (Thm.cterm_of ctxt' pred) |> Utils.apply_tac "unfold body" (Hypsubst.stac ctxt' body_def 1) |> Utils.apply_tac "apply rule" (resolve_tac ctxt' [thm] 1) |> Goal.norm_result ctxt - |> singleton (Variable.export ctxt' ctxt) + |> singleton (Variable.export ctxt' ctxt)) ) (Utils.zip3 preds fn_thms fn_def_thms) @@ -630,7 +689,7 @@ fun define_funcs (fold (fn (phase, fn_name, trace) => AutoCorresData.add_trace filename phase fn_name trace) fn_traces) ctxt in - (new_thms, (), ctxt) + (new_thms, accum, ctxt) end (* diff --git a/tools/autocorres/autocorres_util2.ML b/tools/autocorres/autocorres_util2.ML new file mode 100644 index 000000000..cfb100b5d --- /dev/null +++ b/tools/autocorres/autocorres_util2.ML @@ -0,0 +1,664 @@ +(* + * Copyright 2014, NICTA + * + * This software may be distributed and modified according to the terms of + * the BSD 2-Clause license. Note that NO WARRANTY is provided. + * See "LICENSE_BSD2.txt" for details. + * + * @TAG(NICTA_BSD) + *) + +(* + * Common code for all translation phases: defining funcs, calculating dependencies, variable fixes, etc. + *) + +structure AutoCorresUtil2 = +struct + +(* + * Maximum time to let an individual function translation phase to run for. + * + * Note that this is wall time, and not CPU time, so it is a very rough + * tool. + * FIXME: convert to proper option + *) +val max_run_time = Unsynchronized.ref NONE +(* +val max_run_time = Unsynchronized.ref (SOME (seconds 900.0)) +*) + +exception AutocorresTimeout of string list + +fun time_limit f v = + case !max_run_time of + SOME t => + (TimeLimit.timeLimit t f () + handle TimeLimit.TimeOut => + raise AutocorresTimeout v) + | NONE => + f () + +(* Should we use concurrency? *) +val concurrent = Unsynchronized.ref true; + +(* + * Conditionally fork a group of tasks, depending on the value + * of "concurrent". + *) +datatype 'a maybe_fork = Future of 'a future | Boring of 'a + +(* Fork a group of tasks. *) +fun maybe_fork ctxt vals = + if ((!concurrent) andalso not (Config.get ctxt ML_Options.exception_trace)) then + map Future (Future.forks { + name = "", group = NONE, deps = [], pri = ~1, interrupts = true} + vals) + else + map (fn x => Boring (x ())) vals + +(* Ensure a forked task has completed. *) +fun maybe_join v = + case v of + Boring x => x + | Future x => Future.join x + +(* Functional map. *) +fun par_map ctxt f a = + if ((!concurrent) andalso not (Config.get ctxt ML_Options.exception_trace)) then + Par_List.map f a + else + map f a + +(* Does a SIMPL body exist for the given function name? *) +fun has_simpl_body_def lthy name = + try (fn name => Proof_Context.get_thm lthy (name ^ "_body_def")) name + |> is_some + + +(* + * A translation step transforms the program from one form to another; + * such as from SIMPL to a monadic type, or from one type of monad to + * another. + * + * "filename" is the name of the file we are translating: this is required as a + * key to fetch data stashed away by the C parser. + * + * "lthy" is the local theory. + * + * "convert" performs any proof work required for this translation step. All + * conversions are performed in parallel, so must be able to be completed + * without results from previous steps. + * + * "define" actually sets up any definitions required by the translation; + * definition steps occur serially, but may be in parallel with conversion + * steps whose results are not yet required. + * + * "prove_mono" should prove the monad_mono property for recursive functions. + * (* It is run in parallel over all recursive groups. *) + * + * Because functions handed to us by the C parser may be mutually recursive and + * such mutually recursive functions must typically be defined simultaneously, + * "define" is handed a list of functions which must all be defined in one + * step. + *) +fun translate lthy phase fn_info initial_callees convert define gen_new_info prove_mono v = + let + (* Get list of functions we need to translate. + * This is a bit complicated because we need to skip over functions that + * have already been translated, and hence we need to recalculate the + * function call graph. *) + val functions_to_translate = + Symtab.dest (FunctionInfo.get_all_functions fn_info) + |> map_filter (fn (name, info) => + case FunctionInfo.Phasetab.lookup (#phases info) phase of + NONE => SOME name + | SOME _ => NONE) + |> Symset.make + val fn_info_restricted = FunctionInfo.map_fn_info (fn info => + if Symset.contains functions_to_translate (#name info) then SOME info else NONE) fn_info + val function_groups = FunctionInfo.get_topo_sorted_functions fn_info_restricted + val all_functions = List.concat function_groups + + (* + * Convert every function. + * + * We perform the conversions using futures, which are run in parallel. + * This allows us to perform conversions while we start defining functions, + * hopefully speeding everything up on multicore systems. + *) + val converted_body_thms = + map (fn name => fn _ => + time_limit (fn _ => convert lthy name) [name]) all_functions + |> maybe_fork lthy + val converted_bodies = Symtab.make (all_functions ~~ converted_body_thms) + + (* In sorted order, define constants and proofs for the functions. *) + fun translate fn_names (callee_thms, new_phase_infos, v, lthy) = + let + val defs = map (fn fn_name => + Symtab.lookup converted_bodies fn_name |> the |> maybe_join) fn_names + val (proofs, v, lthy) + = time_limit (fn _ => + define lthy callee_thms v (fn_names ~~ defs)) fn_names + val new_callee_thms = fold Symtab.update_new + (fn_names ~~ proofs) callee_thms + val new_phase_infos = fold (fn n => + Symtab.update_new (n, gen_new_info lthy v (FunctionInfo.get_function_info fn_info n))) + fn_names new_phase_infos + in + (new_callee_thms, new_phase_infos, v, lthy) + end + + val (proof_table, new_phase_infos, v, lthy) + = fold translate function_groups (initial_callees, Symtab.empty, v, lthy) + + val mono_thms = + function_groups + |> map (fn funcs => if not (FunctionInfo.is_function_recursive fn_info (hd funcs)) + then K Symtab.empty else (fn _ => time_limit (fn _ => + (List.mapPartial (fn f => + case Symtab.lookup new_phase_infos f of + SOME phase_info => + SOME (FunctionInfo.function_info_add_phase phase_info + (FunctionInfo.get_function_info fn_info f)) + | _ => NONE) funcs + |> prove_mono lthy)) funcs)) + |> maybe_fork lthy |> map maybe_join + |> maps Symtab.dest |> Symtab.make + + val new_phase_infos = new_phase_infos |> + Symtab.map (fn func => FunctionInfo.phase_info_upd_mono_thm (Symtab.lookup mono_thms func)) + in + (lthy, FunctionInfo.add_phases (fn name => K (Symtab.lookup new_phase_infos name)) fn_info, v) + end + +(* + * A translation step that maps over every function in the program. + * + * "convert" performs any proof work required for this translation step. All + * conversions are performed in parallel, so must be able to be completed + * without results from previous steps. + * + * We return a list of all results. + *) +fun map_all ctxt fn_info convert = + par_map ctxt (uncurry convert) (FunctionInfo.get_all_functions fn_info |> Symtab.dest) + +(* + * Get functions called by a particular function. + * + * We split the result into standard calls and recursive calls (i.e., calls + * which may recursively call back into us). + *) +fun get_callees fn_info fn_name = +let + (* Get a list of functions we call. *) + val all_callees = FunctionInfo.get_function_callees fn_info fn_name + + (* Fetch calls that may recursively call back to us. *) + val recursive_calls = FunctionInfo.get_recursive_group fn_info fn_name + + (* Remove "recursive_calls" from the standard callee set. *) + val callees = + Symset.make all_callees + |> Symset.subtract (Symset.make recursive_calls) + |> Symset.dest +in + (callees, recursive_calls) +end + +(* Is the given term a Trueprop? *) +fun is_Trueprop (Const (@{const_name "Trueprop"}, _) $ _) = true + | is_Trueprop _ = false + +(* + * Assume theorems for called functions. + * + * A new context is returned with the assumptions in it, along with a morphism + * used for exporting the theorems out, and a list of the functions assumed: + * + * (, (, , , )) + * + * In this context, the theorems refer to functions by fixed free variables. + * + * get_fn_args may return user-friendly argument names that clash with other names. + * We will process these names to avoid conflicts. + * + * get_fn_assumption should produce the desired theorems to assume. Its arguments: + * context (with fixed vars), callee name, callee term, arg terms, is recursive, measure term + * (all terms are fixed free vars). + * + * get_const_name generates names for the free function placeholders. + * FIXME: probably unnecessary and/or broken. + * + * We return two morphisms: + * - the first one makes the assumptions visible again + * - the second one automatically generalizes the assumed constants + * (this exists for backwards compat; all new code should explicitly use the + * returned free variable set) + *) +fun assume_called_functions_corres ctxt callees + get_fn_type get_fn_assumption get_fn_args get_const_name callers_measure_var = +let + (* Assume the existence of a function, along with a theorem about its + * behaviour. *) + fun assume_func ctxt fn_name is_recursive_call = + let + val fn_args = get_fn_args fn_name + + (* Fix a variable for the function. *) + val ([fixed_fn_name], ctxt') = Variable.variant_fixes [get_const_name fn_name] ctxt + val fn_free = Free (fixed_fn_name, get_fn_type fn_name) + + (* Fix a variable for the measure and function arguments. *) + val (measure_var_name :: arg_names, ctxt'') + = Variable.variant_fixes ("rec_measure'" :: (map fst fn_args)) ctxt' + val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args)) + val my_measure_var = Free (measure_var_name, @{typ nat}) + + (* + * A measure variable is needed to handle recursion: for recursive calls, + * we need to decrement the caller's input measure value (and our + * assumption will need to assume this to). This is so we can later prove + * termination of our function definition: the measure always reaches zero. + * + * Non-recursive calls can have a fresh value. + *) + val measure_var = + if is_recursive_call then + @{const "recguard_dec"} $ callers_measure_var + else + my_measure_var + + (* Create our assumption. *) + val assumption = + get_fn_assumption ctxt'' fn_name fn_free fn_arg_terms + is_recursive_call measure_var + |> fold Logic.all (rev ((if is_recursive_call then [] else [my_measure_var]) @ fn_arg_terms)) + |> Sign.no_vars ctxt' + |> Thm.cterm_of ctxt' + val ([thm], ctxt''') = Assumption.add_assumes [assumption] ctxt' + + (* Generate a morphism for escaping this context. *) + val m = (Assumption.export_morphism ctxt''' ctxt') + $> (Variable.export_morphism ctxt' ctxt) + in + (fn_free, thm, ctxt''', m) + end + + (* Apply each assumption. *) + val (res, (ctxt', m)) = fold_map ( + fn (fn_name, is_recursive_call) => + fn (ctxt, m) => + let + val (free, thm, ctxt', m') = + assume_func ctxt fn_name is_recursive_call + in + ((fn_name, (is_recursive_call, free, thm)), (ctxt', m' $> m)) + end) + callees (ctxt, Morphism.identity) +in + (ctxt', m, res) +end + +(* + * Convert a single function. + * + * Given a single concrete function, abstract that function and + * return a theorem that shows the correspondence. + * + * A theorem is returned which has assumptions that called functions + * correspond, giving a goal that this given function corresponds. + *) +fun gen_corres_for_function + (phase : FunctionInfo.phase) + (fn_info : FunctionInfo.fn_info) + (get_fn_type : string -> typ) + (get_fn_assumption : Proof.context -> string -> term -> term list -> bool -> term -> term) + (get_fn_args : string -> (string * typ) list) + (get_const_name : string -> string) + (convert : Proof.context -> string -> ((bool * term * thm) Symtab.table) -> + term -> term list -> (term * thm * (string * AutoCorresData.Trace) list)) + (ctxt : Proof.context) + (fn_name : string) = +let + val _ = writeln ("Converting (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name) + val start_time = Timer.startRealTimer () + + (* Get a list of functions we call. *) + val (normal_calls, recursive_calls) + = get_callees fn_info fn_name + val callees = + (map (fn x => (x, false)) normal_calls) + @ (map (fn x => (x, true)) recursive_calls) + + (* Make sure the desired function name is available. *) + val fn_target_name = get_const_name fn_name + val ([fn_free], ctxt') = Variable.variant_fixes [fn_target_name] ctxt + val _ = if fn_free = fn_target_name then () else + warning ("Variable clobbered: " ^ fn_target_name ^ " -> " ^ fn_free ^ ". Translating " ^ + fn_name ^ " may fail.") + val fn_var_morph = Variable.export_morphism ctxt' ctxt + + (* Fix a measure variable that will be used to track recursion progress. *) + val ([measure_var_name], ctxt'') = Variable.variant_fixes ["rec_measure'"] ctxt' + val measure_var = Free (measure_var_name, @{typ nat}) + val measure_var_morph = Variable.export_morphism ctxt'' ctxt' + + (* Fix variables for function arguments. *) + val fn_args = get_fn_args fn_name + val (arg_names, ctxt''') + = Variable.variant_fixes (map (fn (a, _) => a ^ "'arg") fn_args) ctxt'' + val fn_arg_terms = map (fn (n, T) => Free (n, T)) (arg_names ~~ (map snd fn_args)) + val fn_args_morph = Variable.export_morphism ctxt''' ctxt'' + + val _ = @{trace} ("Vars", fn_free, measure_var_name, arg_names) + (* Enter a context where we assume our callees exist. *) + val (ctxt'''', m, callee_info_and_proofs) + = assume_called_functions_corres ctxt''' callees + get_fn_type get_fn_assumption get_fn_args get_const_name + measure_var + + (* + * Do the conversion. We receive a new monadic version of the SIMPL + * term and a tactic for proving correspondence. + *) + val callee_tab = Symtab.make callee_info_and_proofs + val (body, thm, trace) = convert ctxt'''' fn_name callee_tab measure_var fn_arg_terms + + (* + * The returned body will have free variables as placeholders for the function's + * input parameters, for the functions it calls, and for its measure variable. + * + * We modify the body to be of the form: + * + * %fun1 fun2 rec1 rec2 measure arg1 arg2. f <...> + * + * That is, all non-recursive calls are abstracted out the front, followed by + * recursive calls, followed by the measure variable, followed by function + * arguments. + *) + val body = + fold lambda (rev fn_arg_terms) body + |> lambda measure_var + |> fold lambda (rev (map (fn x => Symtab.lookup callee_tab x |> the |> #2) recursive_calls)) + |> fold lambda (rev (map (fn x => Symtab.lookup callee_tab x |> the |> #2) normal_calls)) + + (* Export the theorem out of our context. *) + val exported_thm = Morphism.thm (m $> fn_args_morph $> measure_var_morph $> fn_var_morph) thm + + (* TODO: allow this message to be configured *) + val _ = @{trace} ("Converted (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ fn_name ^ " in " ^ + Time.toString (Timer.checkRealTimer start_time) ^ " s") +in + (exported_thm, body, trace) +end + +(* + * Given a SIMPL function, define a constant and a proof for it. + * + * "callee_thms" contains a table mapping function names to complete + * corres proofs for those functions. + * + * "functions" contains a list of (fn_name, (proof, body, proof traces)). + * The body should be of the form generated by gen_corres_for_function, + * with lambda abstractions for all callees and arguments. + * + * We assume that all functions in this list are mutually recursive. + * (If not, you should call "define_funcs" multiple times, each + * time with a single function.) + * + * The proof traces are stored into the theory. (This should probably be moved.) + *) +fun define_funcs + (phase : FunctionInfo.phase) + (filename : string) + (fn_info : FunctionInfo.fn_info) + (get_const_name : string -> string) + (get_fn_type : string -> typ) + (get_fn_assumption : Proof.context -> string -> term -> term list -> bool -> term -> term) + (get_fn_args : string -> (string * typ) list) + (rec_base_case : thm) + (ctxt : Proof.context) + (callee_thms : thm Symtab.table) + accum (* translate allows an accumulator, but we don't use it here *) + (functions : (string * (thm * term * (string * AutoCorresData.Trace) list)) list) + = + let + val fn_names = map fst functions + val fn_thms = map (snd #> #1) functions + val fn_bodies = map (snd #> #2) functions + val fn_traces = map (fn (fn_name, (_, _, traces)) => map (fn (module, trace) => (module, fn_name, trace)) traces) functions |> List.concat + + val _ = writeln ("Defining (" ^ FunctionInfo.string_of_phase phase ^ ") " ^ + (Utils.commas (map get_const_name fn_names))) + val _ = @{trace} ("function(s)", functions) + val _ = @{trace} ("callee(s)", Symtab.dest callee_thms) + + (* + * Determine if we are in a recursive case by checking to see if the + * first function in our list makes recursive calls to any other + * function. (This "other function" will be itself if it is simple + * recursion, but may be a different function if we are mutually + * recursive.) + *) + val is_recursive = FunctionInfo.is_function_recursive fn_info (hd fn_names) + val _ = assert (length fn_names = 1 orelse is_recursive) + "define_funcs passed multiple functions, but they don't appear to be recursive." + + (* + * Patch in functions into our function body in the following order: + * + * * Non-recursive calls; + * * Recursive calls + *) + fun fill_body fn_name body = + let + val (normal_calls, recursive_calls) + = get_callees fn_info fn_name + val non_rec_calls = map (fn x => Utils.get_term ctxt (get_const_name x)) normal_calls + val rec_calls = map (fn x => Free (get_const_name x, get_fn_type x)) recursive_calls + in + body + |> (fn t => betapplys (t, non_rec_calls)) + |> (fn t => betapplys (t, rec_calls)) + end + + (* + * Define our functions. + * + * Definitions should be of the form: + * + * %arg1 arg2 arg3. (arg1 + arg2 + arg3) + * + * Mutually recursive calls should be of the form "Free (fn_name, fn_type)". + *) + val defs = map ( + fn (fn_name, fn_body) => let + val fn_args = get_fn_args fn_name + (* FIXME: this retraces assume_called_functions_corres *) + val (fn_free :: measure_free :: arg_frees, _) = Variable.variant_fixes + (get_const_name fn_name :: "rec_measure'" :: map fst fn_args) ctxt + in (get_const_name fn_name, (* be inflexible when it comes to fn_name *) + (measure_free, @{typ nat}) :: (arg_frees ~~ map snd fn_args), (* changing arg names is ok *) + fill_body fn_name fn_body) end) + (fn_names ~~ fn_bodies) + val (fn_def_thms, ctxt) = Utils.define_functions defs true is_recursive ctxt + + (* Record the constant in our theory data. *) + val ctxt = fold (fn (fn_name, def) => + Local_Theory.background_theory ( + AutoCorresData.add_def filename (FunctionInfo.string_of_phase phase ^ "def") fn_name def)) + (Utils.zip fn_names fn_def_thms) ctxt + + (* + * Instantiate schematic function calls in our theorems with their + * concrete definitions. + *) + val combined_callees = map (get_callees fn_info) (map fst functions) + val combined_normal_calls = + map fst combined_callees |> flat |> sort_distinct fast_string_ord + val combined_recursive_calls = + map snd combined_callees |> flat |> sort_distinct fast_string_ord + val callee_terms = + (combined_recursive_calls @ combined_normal_calls) + |> map (fn x => (get_const_name x, Utils.get_term ctxt (get_const_name x))) + |> Symtab.make + fun fill_proofs thm = + Utils.instantiate_thm_vars ctxt + (fn ((name, _), _) => + Symtab.lookup callee_terms name + |> Option.map (Thm.cterm_of ctxt)) thm + val fn_thms = map fill_proofs fn_thms + + (* Fix free variable for the measure. *) + val ([measure_var_name], ctxt') = Variable.variant_fixes ["m"] ctxt + val measure_var = Free (measure_var_name, @{typ nat}) + + (* Generate corres predicates for each function. *) + val preds = map ( + fn fn_name => + let + fun mk_forall v t = HOLogic.all_const (Term.fastype_of v) $ lambda v t + val fn_const = Utils.get_term ctxt' (get_const_name fn_name) + + (* Fetch parameters to this function. *) + val free_params = + get_fn_args fn_name + |> Variable.variant_frees ctxt' [measure_var] + |> map Free + in + (* Generate the prop. *) + get_fn_assumption ctxt' fn_name fn_const + free_params is_recursive measure_var + |> fold Logic.all (rev free_params) + end) fn_names + + (* We generate a goal which solves all the mutually recursive calls simultaneously. *) + val goal = map (Object_Logic.atomize_term ctxt') preds + |> Utils.mk_conj_list + |> HOLogic.mk_Trueprop + |> Thm.cterm_of ctxt' + + (* Prove each of the predicates above, leaving any assumptions about called + * functions unsolved. *) + val pred_thms = map ( + fn (pred, thm, body_def) => (@{trace} thm; + Thm.trivial (Thm.cterm_of ctxt' pred) + |> Utils.apply_tac "unfold body" (Hypsubst.stac ctxt' body_def 1) + |> Utils.apply_tac "apply rule" (resolve_tac ctxt' [thm] 1) + |> Goal.norm_result ctxt + |> singleton (Variable.export ctxt' ctxt)) + ) + (Utils.zip3 preds fn_thms fn_def_thms) + + (* Create a set of "helper theorems", which should be sufficient to discharge + * all assumptions that our callees refine. *) + val helper_thms = + (map (Symtab.lookup callee_thms #> the) combined_normal_calls) @ pred_thms + |> map (Drule.forall_intr_vars) + |> map (Conv.fconv_rule (Object_Logic.atomize ctxt)) + + (* Generate a proof term of equivalence using the folded definitions. *) + val new_thm = + Goal.init goal + |> (fn thm => + if is_recursive then ( + Utils.apply_tac "start induction" + (resolve_tac ctxt' + [Utils.named_cterm_instantiate ctxt' + [("n", Thm.cterm_of ctxt' measure_var)] @{thm recguard_induct}] + 1) thm + |> Utils.apply_tac "unfold bodies" + (EVERY (map (fn x => (EqSubst.eqsubst_tac ctxt' [1] [x] 1)) (rev fn_def_thms))) + |> Utils.apply_tac "solve induction base cases" + (SOLVES ((simp_tac (put_simpset HOL_ss ctxt' addsimps [rec_base_case]) 1))) + |> Utils.apply_tac "solve remaing goals" + (Utils.metis_insert_tac ctxt helper_thms 1) + ) else ( + Utils.apply_tac "solve remaing goals" + (Utils.metis_insert_tac ctxt helper_thms 1) thm + )) + |> Goal.finish ctxt' + + (* + * The proof above is of the form (L1corres a & L1corres b & ...). + * Split it up into several proofs. + *) + fun prove_partial_l1_corres thm pred = + Thm.cterm_of ctxt' pred + |> Goal.init + |> Utils.apply_tac "solving using metis" (Utils.metis_tac ctxt [thm] 1) + |> Goal.finish ctxt' + + (* Generate the final theorems. *) + val new_thms = + map (prove_partial_l1_corres new_thm) preds + |> (Variable.export ctxt' ctxt) + |> map (Goal.norm_result ctxt) + + (* Record the theorems in our theory data. *) + val ctxt = fold (fn (fn_name, thm) => + Local_Theory.background_theory + (AutoCorresData.add_thm filename (FunctionInfo.string_of_phase phase ^ "corres") fn_name thm)) + (fn_names ~~ new_thms) ctxt + + (* Add the theorems to the context. *) + val ctxt = fold (fn (fn_name, thm) => + Utils.define_lemma (fn_name ^ "_" ^ FunctionInfo.string_of_phase phase ^ "corres") thm #> snd) + (fn_names ~~ new_thms) ctxt + + (* Add the traces to the context. *) + val ctxt = Local_Theory.background_theory + (fold (fn (phase, fn_name, trace) => + AutoCorresData.add_trace filename phase fn_name trace) fn_traces) ctxt + in + (new_thms, accum, ctxt) + end + +(* + * Do a translation phase, converting every function from one form to another. + *) +fun do_translation_phase + (phase : FunctionInfo.phase) + (filename : string) + (prog_info : ProgramInfo.prog_info) + (fn_info : FunctionInfo.fn_info) + (get_fn_type : string -> typ) + (get_fn_assumption : local_theory -> string -> term -> term list -> bool -> term -> term) + (get_fn_args : string -> (string * typ) list) + (get_const_name : string -> string) + (convert : local_theory -> string -> ((bool * term * thm) Symtab.table) -> + term -> term list -> (term * thm * (string * AutoCorresData.Trace) list)) + (gen_new_info : local_theory -> FunctionInfo.function_info -> FunctionInfo.phase_info) + (prove_mono : local_theory -> FunctionInfo.function_info list -> thm Symtab.table) + (rec_base_case : thm) + (ctxt : Proof.context) = +let + val do_gen_corres = + gen_corres_for_function phase fn_info get_fn_type get_fn_assumption + get_fn_args get_const_name convert; + val do_define_funcs = + define_funcs phase filename fn_info get_const_name get_fn_type + get_fn_assumption get_fn_args rec_base_case + (* Lookup functions that have already been translated (i.e. phase exists) *) + val initial_callees = Symtab.dest (FunctionInfo.get_all_functions fn_info) + |> List.mapPartial (fn (fn_name, info) => + FunctionInfo.Phasetab.lookup (#phases info) phase + |> Option.mapPartial (fn phase_info => + AutoCorresData.get_thm (Proof_Context.theory_of ctxt) filename + (FunctionInfo.string_of_phase phase ^ "corres") fn_name) + |> Option.map (fn thm => (fn_name, thm))) + |> Symtab.make + + (* Do the translation. *) + val (ctxt', new_fn_info, _) = + translate ctxt phase fn_info initial_callees do_gen_corres do_define_funcs + (fn lthy => K (gen_new_info lthy)) prove_mono () + + (* Map function information. *) +in + (ctxt', new_fn_info) +end + +end diff --git a/tools/autocorres/function_info.ML b/tools/autocorres/function_info.ML index 16d7c7f46..086d7cd5d 100644 --- a/tools/autocorres/function_info.ML +++ b/tools/autocorres/function_info.ML @@ -25,21 +25,29 @@ sig (* Function info for a single phase. *) type phase_info = { + phase : phase, args : (string * typ) list, return_type : typ, const : term, raw_const : term, +(* + callees : string list, + rec_callees : string list, +*) definition : thm, - mono_thm : thm option, - phase : phase + mono_thm : thm option }; + val phase_info_upd_phase : phase -> phase_info -> phase_info; val phase_info_upd_args : (string * typ) list -> phase_info -> phase_info; val phase_info_upd_return_type : typ -> phase_info -> phase_info; (* also updates raw_const *) val phase_info_upd_const : term -> phase_info -> phase_info; +(* + val phase_info_upd_callees : string list -> phase_info -> phase_info; + val phase_info_upd_rec_callees : string list -> phase_info -> phase_info; +*) val phase_info_upd_definition : thm -> phase_info -> phase_info; val phase_info_upd_mono_thm : thm option -> phase_info -> phase_info; - val phase_info_upd_phase : phase -> phase_info -> phase_info; (* Function info for a single function. *) type function_info = { @@ -101,6 +109,9 @@ structure Phasetab = Table( val ord = phase_ord); type phase_info = { + (* The translation phase for this definition. *) + phase : phase, + (* Arguments of the function, in order, excluding measure variables. *) args : (string * typ) list, @@ -120,10 +131,7 @@ type phase_info = { definition : thm, (* monad_mono theorem for the function, if it is recursive. *) - mono_thm : thm option, - - (* The translation phase for this definition. *) - phase : phase + mono_thm : thm option }; type function_info = { diff --git a/tools/autocorres/local_var_extract2.ML b/tools/autocorres/local_var_extract2.ML new file mode 100644 index 000000000..944f1349d --- /dev/null +++ b/tools/autocorres/local_var_extract2.ML @@ -0,0 +1,1717 @@ +(* + * Copyright 2014, NICTA + * + * This software may be distributed and modified according to the terms of + * the BSD 2-Clause license. Note that NO WARRANTY is provided. + * See "LICENSE_BSD2.txt" for details. + * + * @TAG(NICTA_BSD) + *) + +(* + * Extract local variables out of converted L1 fragments. + *) +structure LocalVarExtract2 = +struct + +open Prog + +(* Convenience abbreviations for set manipulation. *) +infix 1 INTER MINUS UNION +val empty_set = Varset.empty +val make_set = Varset.make +val union_sets = Varset.union_sets +val dest_set = Varset.dest +fun (a INTER b) = Varset.inter a b +fun (a MINUS b) = Varset.subtract b a +fun (a UNION b) = Varset.union a b + +(* Convenience shortcuts. *) +val warning = Utils.ac_warning +val apply_tac = Utils.apply_tac +val the' = Utils.the' + +(* Simpset we use for automated tactics. *) +fun setup_l2_ss ctxt = put_simpset AUTOCORRES_SIMPSET ctxt + addsimps [@{thm ucast_id}, @{thm pred_conj_def}] + +(* Convert a set of variable names into an Isabelle list of strings. *) +fun var_set_to_isa_list s = + dest_set s + |> map fst + |> map ProgramInfo.demangle_name + |> map Utils.encode_isa_string + |> Utils.encode_isa_list @{typ string} + +(* + * Remove references to local variables in "term", replacing them with free + * variables. + * + * We return a list of variables that were successfully extracted, along with + * the modified term itself. + * + * For instance: + * + * convert_local_vars @{term "a_' s + b + c"} + * [("x", @{term "a_' s"}), ("y", @{term "b_' s"})] + * + * would return ("x", @{term "x + b + c"}). + *) +fun convert_local_vars name_map term [] = ([], term) + | convert_local_vars name_map term ((var_name, var_term) :: vars) = + if Utils.contains_subterm var_term term then + let + val free_var = name_map (var_name, fastype_of var_term) + + (* Pull out "term" from "var_term". *) + val abstracted = betapply (Utils.abs_over var_name var_term term, free_var) + + (* Pull out the other variables. *) + val (other_vars, other_term) = convert_local_vars name_map abstracted vars + in + (other_vars @ [(var_name, fastype_of var_term)], other_term) + end + else + convert_local_vars name_map term vars + +(* Get the set of variables a function accepts and returns. *) +fun get_fn_input_output_vars prog_info fn_info fn_name = +let + val fn_def = FunctionInfo.get_phase_info fn_info FunctionInfo.L1 fn_name + val inputs = #args fn_def |> Varset.make + + (* Get the return type of a function. *) + val return_ctype = + ProgramAnalysis.get_rettype fn_name (#csenv prog_info) + |> Utils.the' ("Function missing from C-parser's csenv: " ^ quote fn_name) + + val outputs = + if return_ctype = Absyn.Void then + empty_set + else + make_set [(NameGeneration.return_var_name return_ctype |> MString.dest, + #return_type fn_def)] +in + (inputs, outputs) +end + +(* Get the return variable of a particular function. *) +fun get_ret_var prog_info fn_info fn_name = +let + val (_, outputs) = get_fn_input_output_vars prog_info fn_info fn_name +in + hd ((Varset.dest outputs) @ [("void", @{typ unit})]) +end + +(* + * Determine the state, return and exception type of a monad. + * + * Monads have the form: + * + * 'a => 'b => ... => 's => ('x, 'y, 's) L2_monad. + * + * We return: + * + * (['a, 'b, ...], ('x, 'y, 's)) + *) +fun dest_l2monad_T t = +let + val (Type ("Product_Type.prod", [Type ("Set.set", [Type ("Product_Type.prod", [ + Type ("Sum_Type.sum", [ex, ret]) ,state])]), _])) + = body_type t + val args = binder_types t + val inputs = List.take (args, length args - 1) +in + (inputs, (state, ret, ex)) +end +fun l2monad_type monad = + dest_l2monad_T (fastype_of monad) |> snd +fun l2monad_state_type monad = #1 (l2monad_type monad) +fun l2monad_ret_type monad = #2 (l2monad_type monad) +fun l2monad_ex_type monad = #3 (l2monad_type monad) + +(* Get the abstract/concrete term from a "L2corres" predicate. *) +fun dest_L2corres_term_abs @{term_pat "L2corres _ _ _ _ ?t _"} = t +fun dest_L2corres_term_conc @{term_pat "L2corres _ _ _ _ _ ?t"} = t + +(* Make an L2 monad. *) +fun mk_l2monadT stateT retT exT = + Utils.gen_typ @{typ "('a, 'b, 'c) L2_monad"} [stateT, retT, exT] + +(* + * "Spec" expressions are of the form: + * + * {(s, t). f s t} + * + * where "s" and "t" are input/output states. We want to parse the expression, + * and convert it to an L2 expression dealing only with globals in "s" and "t". + * + * If the original SIMPL spec attempts to read or write to local variables, we + * just fail. + *) +fun parse_spec ctxt prog_info term = + let + (* + * If simplification was turned off in L1, the spec may still contain + * unions and intersections, i.e. be of the form + * {(s, t). f s t} \ {(s, t). g s t} ... + * We blithely rewrite them here. + *) + val term = Raw_Simplifier.rewrite_term (Proof_Context.theory_of ctxt) + (map safe_mk_meta_eq @{thms Collect_prod_inter Collect_prod_union}) [] term + + (* Apply a dummy old and new state variable to the term. *) + val dummy_s = Free ("_dummy_state1", #state_type prog_info) + val dummy_t = Free ("_dummy_state2", #state_type prog_info) + val dummy_tuple = HOLogic.mk_tuple [dummy_s, dummy_t] + val t = Envir.beta_eta_contract ( + (Const (@{const_name "Set.member"}, + fastype_of dummy_tuple --> fastype_of term --> @{typ bool}) + $ dummy_tuple $ term)) + + (* Pull apart the "split" at the beginning of the term, then apply + * to our dummy variables *) + val t = Raw_Simplifier.rewrite_term (Proof_Context.theory_of ctxt) + (map mk_meta_eq @{thms split_def fst_conv snd_conv mem_Collect_eq}) [] t + + (* + * Pull out any references to any other variables into a lambda + * function. + * + * We pull out the globals variable first, because we want it to end + * up inner-most compared to all the other lambdas we generate. + *) + val globals_getter = #globals_getter prog_info + val t = Utils.abs_over "t" (globals_getter $ dummy_t) t + |> Utils.abs_over "s" (globals_getter $ dummy_s) + |> HOLogic.mk_case_prod + val t_collect = @{mk_term "Collect :: (?'s \ bool) \ ?'s set" ('s)} + (domain_type (fastype_of t)) + in + (* Determine if there are any references left to the dummy state + * variable. If so, give up on the translation. *) + if Utils.contains_subterm dummy_s t + orelse Utils.contains_subterm dummy_t t then + (warning ("Can't parse spec term: " + ^ (Utils.term_to_string ctxt term)); NONE) + else + SOME (t_collect $ t) + end + + +(* + * Parse an L1 expression containing references to the global state. + * + * We assume that the input term is in the "abstracted" form "%s. f s" where + * "s" is the global state variable. + * + * Our return value is a list of variables abstracted, whether the global + * variable was used, and the abstracted term itself. + * + * The function will fail (and return NONE) if the input expression performs + * arbitrary transformations on the state. For example: + * + * "%s. a_' s" => ([a], False, SOME @{term "%a s. a"}) + * "%s. globals s" => ([], True, SOME @{term "%s. s"}) + * "%s. a_' s + b_' s" => ([a, b], False, SOME @{term "%a b s. a + b"}) + * "%s. False" => ([], False, SOME @{term "%s. False"}) + * "%s. bot s" => ([], False, NONE) + *) +fun parse_expr ctxt prog_info name_map term = + let + val dummy_state = Free ("_dummy_state", #state_type prog_info) + + (* Apply a dummy state variable to the term. This makes our later analysis + * easier. *) + val term = Envir.beta_eta_contract (term $ dummy_state) + + (* + * Pull out any references to any other variables into a lambda + * function. + * + * We pull out the globals variable first, because we want it to end + * up inner-most compared to all the other lambdas we generate. + *) + val globals_getter = #globals_getter prog_info $ dummy_state + val globals_used = Utils.contains_subterm globals_getter term + val t = Utils.abs_over "s" globals_getter term + + (* Pull out local variables. *) + val all_getters = #var_getters prog_info |> Symtab.dest |> map (fn (a,b) => (a, b $ dummy_state)) + val (v1, t) = convert_local_vars name_map t all_getters + + (* + * Determine if there are any references left to the dummy state + * variable. + * + * If so, we are stuck: we aren't pulling out a part of the state + * record, but instead performing an arbitrary transformation on it. + * The most likely reason for this is the C parser's dummy function + * "lvar_init", which attempts to set an uninitialised local + * variable to an invalid state. Other possibilities include "bot", + * the always-false guard. + *) + val t = if Utils.contains_subterm dummy_state t then + (warning ("Can't parse expression: " + ^ (Utils.term_to_string ctxt term)); NONE) + else + SOME t; + in + (v1, globals_used, t) + end + +(* + * Parse an "L1_modify" expression. + *) +local +fun parse_modify' ctxt prog_info name_map term = + let + val dummy_state = Free ("_dummy_state", #state_type prog_info) + + (* + * We expect modify clauses in two forms: both "%x. (foo x) x" and just + * "foo". We apply a state variable to the function and beta/eta contract + * to normalise our output for the next steps. + *) + val modify_clause = Envir.beta_eta_contract (term $ dummy_state) + + (* + * Extract "xxx" from "foo_'_update xxx". + * + * If the user has written custom "modifies" clauses (presumably + * using "AUXUPD" directives), this may fail. + *) + val (setter, modify_val, s) = case modify_clause of + (Const var $ value $ s) => (Const var, value, s) + | other => Utils.invalid_term "Const (x,y) $ z" other; + val (var_name, var_type) = ProgramInfo.guess_var_name_type_from_setter_term setter + + (* + * At this stage we have assume we have an update function "f" of + * type "'a => 'a" which expects the old value of the variable + * being updated, and returns a new value. + * + * We now want to convert this into a value of type "'a", returning + * the new value. We do this by applying "(field_' s)" to the + * function f, followed by normalisation. + *) + val getter = + case (Symtab.lookup (#var_getters prog_info) var_name) of + SOME v => v + | NONE => Utils.invalid_input "valid local variable getter" var_name + val modify_val = betapply (modify_val, getter $ dummy_state) + |> Envir.beta_eta_contract + + (* + * We are now in the form of "foo dummy_state". Pull out + * our dummy state variable, and parse the expression. + *) + fun remove_dummy_state_var t = Utils.abs_over "s" dummy_state t + val (vars, globals_used, modify_val) = parse_expr ctxt prog_info name_map (remove_dummy_state_var modify_val) + in + ((var_name, var_type), vars, globals_used, modify_val, remove_dummy_state_var s) + end +in +fun parse_modify ctxt prog_info name_map term = + let + val dummy_state = Free ("_dummy_state", #state_type prog_info) + in + if Envir.beta_eta_contract (term $ dummy_state) = dummy_state then + [] + else + let + val (updated_var, read_vars, reads_globals, term, residual) = parse_modify' ctxt prog_info name_map term + in + (updated_var, read_vars, reads_globals, term) :: parse_modify ctxt prog_info name_map residual + end + end +end + +(* + * Construct precondition from variable set. + * + * These preconditions are of the form: + * + * "(%s. n_' s = n) and (%s. i_' s = i) and ..." + *) +fun mk_precond prog_info name_map vars = +let + val myvarsT = #state_type prog_info + val dummy_state = Free ("_dummy_state", myvarsT) + + (* Fetch a variable getter, such as "n_'" from a variable's name. *) + fun var_getter var = + case Symtab.lookup (#var_getters prog_info) var of + SOME x => (x $ dummy_state) + | NONE => Utils.invalid_input "valid local variable name" var +in + Utils.chain_preds myvarsT + (map (fn (var_name, var_type) => Utils.abs_over "s" dummy_state + (HOLogic.mk_eq (var_getter var_name, name_map (var_name, var_type)))) + (dest_set vars)) +end + +(* + * Construct extraction functions, of the form: + * + * "%s. (a_' s, b_' s, c_' s)" + *) +fun mk_xf (prog_info : ProgramInfo.prog_info) vars = +let + val dummy_state = Free ("_dummy_state", #state_type prog_info) + fun var_getter var = + ((Symtab.lookup (#var_getters prog_info) var |> the) $ dummy_state) + handle Option => (Utils.invalid_input "valid local variable name" var) +in + Utils.abs_over "s" dummy_state + (HOLogic.mk_tuple (dest_set vars |> map fst |> map var_getter)) +end + +(* + * Construct a correspondence lemma between a given L2 and L1 terms. + *) +fun mk_corresXF_prop thy prog_info name_map return_vars except_vars precond_vars l2_term l1_term = + let + (* Construct precondition and extraction functions. *) + val precond = mk_precond prog_info name_map precond_vars + val return_xf = mk_xf prog_info return_vars + val except_xf = mk_xf prog_info except_vars + in + Utils.mk_term thy @{term L2corres} [#globals_getter prog_info, return_xf, + except_xf, precond, l2_term, l1_term] + |> HOLogic.mk_Trueprop + end + +(* + * Prove correspondence between L1 and L2. + * + * ctxt: Local theory context + * + * return_vars: Variables that are returned by the abstract spec's monad. + * + * except_vars: Variables that are thrown by the abstract spec's monad. + * + * precond_vars: Variables that must match between abstract and concrete. + * + * l2_term / l1_term: Abstract and concrete specs. + *) +fun mk_corresXF_thm ctxt prog_info name_map return_vars except_vars precond_vars l2_term l1_term tac = +let + val free_vars = precond_vars |> dest_set |> map name_map + val free_names = map (dest_Free #> fst) free_vars +in + mk_corresXF_prop (Proof_Context.theory_of ctxt) prog_info name_map + return_vars except_vars precond_vars l2_term l1_term + |> (fn x => Goal.prove ctxt free_names [] x (fn _ => tac)) +end + +fun mk_corresXF_thm' ctxt prog_info name_map return_vars except_vars precond_vars l2_term l1_term thm = + mk_corresXF_thm ctxt prog_info name_map return_vars except_vars precond_vars l2_term l1_term ( + (rewrite_goal_tac ctxt [mk_meta_eq @{thm split_def}] 1) + THEN + (resolve_tac ctxt [rewrite_rule ctxt [mk_meta_eq @{thm split_def}] thm] 1) + THEN + (REPEAT (CHANGED (asm_full_simp_tac (setup_l2_ss ctxt) 1))) + ) + +fun l1call_function_const t = case strip_comb t |> apsnd rev of + (Const c, (Const c' :: _)) => if String.isSuffix "_'proc" (fst c') + then Const c' else Const c + | (Const c, _) => Const c + | (Abs (_, _, t), []) => l1call_function_const t + | _ => raise TERM ("l1call_function_const", [t]) + +(* + * Parse an L1 term. + * + * In particular, we break down the structure of the program and parse the + * usage of local variables in all expressions and modifies clauses. + *) +fun parse_l1 ctxt prog_info fn_info name_map term = + case term of + (Const (@{const_name "L1_skip"}, _)) => + Modify (term, + (SOME (Abs ("s", #globals_type prog_info, @{term "()"})), empty_set, false), NONE) + + | (Const (@{const_name "L1_modify"}, _) $ m) => + let + val parsed_clause = parse_modify ctxt prog_info name_map m + val (updated_var, read_vars, is_globals_reader, parsed_expr) = + case parsed_clause of + [x] => x + | _ => Utils.invalid_term "Modifies clause too complex." m + in + Modify (term, (parsed_expr, make_set read_vars, is_globals_reader), SOME updated_var) + end + + | (Const (@{const_name "L1_seq"}, _) $ lhs $ rhs) => + Seq (term, parse_l1 ctxt prog_info fn_info name_map lhs, + parse_l1 ctxt prog_info fn_info name_map rhs) + + | (Const (@{const_name "L1_catch"}, _) $ lhs $ rhs) => + Catch (term, parse_l1 ctxt prog_info fn_info name_map lhs, + parse_l1 ctxt prog_info fn_info name_map rhs) + + | (Const (@{const_name "L1_guard"}, _) $ c) => + let + val (read_vars, is_globals_reader, parsed_expr) = parse_expr ctxt prog_info name_map c + in + Guard (term, (parsed_expr, make_set read_vars, is_globals_reader)) + end + + | (Const (@{const_name "L1_throw"}, _)) => + Throw term + + | (Const (@{const_name "L1_condition"}, _) $ cond $ lhs $ rhs) => + let + (* Parse the conditional. *) + val (read_vars, is_globals_reader, parsed_expr) = parse_expr ctxt prog_info name_map cond + in + Condition (term, (parsed_expr, make_set read_vars, is_globals_reader), + parse_l1 ctxt prog_info fn_info name_map lhs, + parse_l1 ctxt prog_info fn_info name_map rhs) + end + + | (Const (@{const_name "L1_call"}, L1_call_type) + $ arg_setup $ dest_fn_term $ globals_extract $ ret_extract) => + let + (* Parse arg setup. We treat this not as a modify, but as several + * expressions, as the modified variables are only in the scope of + * this L1_call command. *) + val arg_setup_exprs = parse_modify ctxt prog_info name_map arg_setup + |> map (fn (_, read_vars, is_globals_reader, term) => + (term, make_set read_vars, is_globals_reader)) + + val dest_fn_term = case dest_fn_term of + Const (@{const_name "measure_call"}, _) $ f => f + | _ => dest_fn_term + + (* Get the name of the variable the return value of the function will + * be placed into. *) + val dest_fn = FunctionInfo.get_function_from_const fn_info + (l1call_function_const dest_fn_term) + |> Utils.the' ("Unknown function " ^ quote (@{make_string} dest_fn_term)) + val dest_fn_name = #name dest_fn + + (* Parse the return arguments. *) + val ret_var = get_ret_var prog_info fn_info (#name dest_fn) + val parsed_clause = + parse_modify ctxt prog_info name_map (betapply (ret_extract, Free ("_dummy_state", #state_type prog_info))) + |> map (fn (target_var, read_vars, globals_read, expr) => + (target_var, (make_set read_vars) MINUS (make_set [ret_var]), + globals_read, Option.map (Utils.abs_over "ret" (name_map ret_var)) expr)) + + val (ret_expr, updated_var) = + case parsed_clause of + [(target_var, read_vars, globals_read, expr)] => + ((expr, read_vars, globals_read), SOME target_var) + | [] => ((NONE, empty_set, false), NONE) + | x => Utils.invalid_input "single return param" (PolyML.makestring x) + in + Call (term, arg_setup_exprs, ret_expr, updated_var, ()) + end + + | (Const (@{const_name "L1_while"}, _) $ cond $ body) => + let + (* Parse conditional. *) + val (read_vars, is_globals_reader, parsed_expr) = parse_expr ctxt prog_info name_map cond; + in + While (term, (parsed_expr, make_set read_vars, is_globals_reader), + parse_l1 ctxt prog_info fn_info name_map body) + end + + | (Const (@{const_name "L1_init"}, _) $ setter) => + let + val updated_var = ProgramInfo.guess_var_name_type_from_setter_term setter + in + Init (term, SOME updated_var) + end + + | (Const (@{const_name "L1_spec"}, _) $ c) => + (case parse_spec ctxt prog_info c of + SOME x => + Spec (term, (SOME x, empty_set, true)) + | NONE => + Spec (term, (NONE, empty_set, true))) + + | (Const (@{const_name "L1_fail"}, _)) => + Fail term + + | (Const (@{const_name "L1_recguard"}, _) $ var $ body) => + RecGuard (term, parse_l1 ctxt prog_info fn_info name_map body) + + | other => Utils.invalid_term "a L1 term" other + +(* + * Generate a proof showing that a particular variables "var" is not modified + * over the given input L1 term. + *) +fun mk_preservation_proof ctxt prog_info name_map var term = +let + val thy = Proof_Context.theory_of ctxt + + (* Apply a tactic then simplify all remaining subgoals. *) + fun s tac = + tac THEN (TRY (REPEAT (CHANGED (asm_full_simp_tac (setup_l2_ss ctxt) 1)))) + + (* Apply a rule then simplify all remaining subgoals. *) + fun r thm = s (resolve_tac ctxt [thm] 1) + + (* Generate the predicate. *) + val var_set = make_set [var] + val precond = mk_precond prog_info name_map var_set + val postcond_ret = absdummy @{typ unit} (mk_precond prog_info name_map var_set) + val postcond_ex = absdummy @{typ unit} (mk_precond prog_info name_map var_set) + val goal = + Utils.mk_term thy @{term validE} [precond, term, postcond_ret, postcond_ex] + |> HOLogic.mk_Trueprop + + (* Construct a tactic that solves the problem. *) + val tac = + (case term of + (Const (@{const_name "L1_skip"}, _)) => + r @{thm L1_skip_lp} + | (Const (@{const_name "L1_init"}, _) $ _) => + r @{thm L1_init_lp} + | (Const (@{const_name "L1_modify"}, _) $ _) => + r @{thm L1_modify_lp} + | (Const (@{const_name "L1_call"}, _) $ _ $ _ $ _ $ _) => + r @{thm L1_call_lp} + | (Const (@{const_name "L1_guard"}, _) $ _) => + r @{thm L1_guard_lp} + | (Const (@{const_name "L1_throw"}, _)) => + r @{thm L1_throw_lp} + | (Const (@{const_name "L1_spec"}, _) $ _) => + r @{thm hoareE_TrueI} + | (Const (@{const_name "L1_fail"}, _)) => + r @{thm L1_fail_lp} + | (Const (@{const_name "L1_while"}, _) $ _ $ body) => + let + val body' = mk_preservation_proof ctxt prog_info name_map var body + in + s (resolve_tac ctxt @{thms L1_while_lp} 1 THEN resolve_tac ctxt [body'] 1) + end + | (Const (@{const_name "L1_condition"}, _) $ _ $ lhs $ rhs) => + let + val lhs' = mk_preservation_proof ctxt prog_info name_map var lhs + val rhs' = mk_preservation_proof ctxt prog_info name_map var rhs + in + s (resolve_tac ctxt @{thms L1_condition_lp} 1 THEN resolve_tac ctxt [lhs'] 1 THEN resolve_tac ctxt [rhs'] 1) + end + | (Const (@{const_name "L1_seq"}, _) $ lhs $ rhs) => + let + val lhs' = mk_preservation_proof ctxt prog_info name_map var lhs + val rhs' = mk_preservation_proof ctxt prog_info name_map var rhs + in + s (resolve_tac ctxt @{thms L1_seq_lp} 1 THEN resolve_tac ctxt [lhs'] 1 THEN resolve_tac ctxt [rhs'] 1) + end + | (Const (@{const_name "L1_catch"}, _) $ lhs $ rhs) => + let + val lhs' = mk_preservation_proof ctxt prog_info name_map var lhs + val rhs' = mk_preservation_proof ctxt prog_info name_map var rhs + in + s (resolve_tac ctxt @{thms L1_catch_lp} 1 THEN resolve_tac ctxt [lhs'] 1 THEN resolve_tac ctxt [rhs'] 1) + end + | (Const (@{const_name "L1_recguard"}, _) $ _ $ body) => + let + val body' = mk_preservation_proof ctxt prog_info name_map var body + in + s (resolve_tac ctxt @{thms L1_recguard_lp} 1 THEN resolve_tac ctxt [body'] 1) + end + | other => Utils.invalid_term "a L1 term" other) +in + (* Generate proof. *) + Thm.cterm_of ctxt goal + |> Goal.init + |> Utils.apply_tac ("proving variable preservation for var '" ^ (fst var) ^ "'") tac + |> Goal.finish ctxt +end + +(* Generate a preservation proof for multiple variables. *) +fun mk_multivar_preservation_proof ctxt prog_info name_map term var_set = +let + val proofs = map (fn x => + mk_preservation_proof ctxt prog_info name_map x term) + (dest_set var_set) + val result = fold (fn x => fn y => @{thm combine_validE} OF [x,y]) + proofs @{thm hoareE_TrueI} +in + result +end +handle Option => error ("Preservation proof failed for " ^ quote (@{make_string} var_set)) + +(* + * Generate a well-typed L2 monad expression. + * + * "const" is the name of the monadic function (e.g., @{const_name "L2_gets"}) + * + * "ret"/"throw" are the variables being returned or thrown by this monadic + * expression. This is used only for determining the type of the output + * monad. + * + * "params" are the expressions to be beta applied to the monad. + *) +fun mk_l2monad (prog_info : ProgramInfo.prog_info) const ret throw params = +let + val retT = HOLogic.mk_tupleT (dest_set ret |> map snd) + val exT = HOLogic.mk_tupleT (dest_set throw |> map snd) + val monadT = mk_l2monadT (#globals_type prog_info) retT exT +in + betapplys ((Const (const, (map fastype_of params) ---> monadT)), params) +end + +(* Abstract over a tuple using the given name map. *) +fun abs_over_tuple_vars (name_map : (string * typ) -> term) (vars : varset) = + Utils.abs_over_tuple (map (fn (a, b) => (a, name_map (a, b))) (dest_set vars)) + +(* + * Take an L2corres theorem of the form: + * + * L2corres st ret_xf ex_xf P (foo a b c) X + * + * and convert it into the form: + * + * L2corres st ret_xf ex_xf P ((%(a, b, c). foo a b c) (a, b, c)) X + * + * This is used to ease unification in proofs where the abstract monad is + * expected to be of the form "A x", where "x" is the return value of another + * monad. + *) +fun abs_over_thm ctxt (name_map : (string * typ) -> term) (thm : thm) (vars : varset) = +let + fun convert_var_to_free x = + case x of + Var ((a, _), t) => Free (a, t) + | x => x + fun convert_free_to_var x = + case x of + Free (a, t) => Var ((a, 0), t) + | x => x + val (head $ st $ ret_xf $ ex_xf $ precond $ l2_term $ l1_term) = + map_aterms convert_var_to_free (Thm.concl_of thm) |> HOLogic.dest_Trueprop + val new_l2_term = (abs_over_tuple_vars name_map vars l2_term + $ Free ("r'", HOLogic.mk_tupleT (dest_set vars |> map snd))) + val new_concl = + head $ st $ ret_xf $ ex_xf $ precond $ new_l2_term $ l1_term + |> map_aterms convert_free_to_var + |> HOLogic.mk_Trueprop + val new_thm = list_implies (cprems_of thm, Thm.cterm_of ctxt new_concl) +in + Goal.init new_thm + |> asm_full_simp_tac (put_simpset HOL_basic_ss ctxt addsimps [mk_meta_eq @{thm split_def}]) 1 |> Seq.hd + |> resolve_tac ctxt [rewrite_rule ctxt [mk_meta_eq @{thm split_def}] thm] 1 |> Seq.hd + |> REPEAT (assume_tac ctxt 1) |> Seq.hd + |> Goal.finish ctxt +end + +(* + * Given a L2 monad that returns the variables "vars_returned", convert it into + * an L2 monad that returns "needed_returns". + * + * This is frequently needed when a particular monad is only capable of returning + * a particular variable (or set of variables), but needs to return a different set + * of these variables. For example, both branches in an "condition" block need + * to return the same set of variables. + * + * The injection is done by (if necessary) appending an additional "L2_seq" to + * the input monad, returning the desired set of variables. + * + * "allow_excess" is the output monad is allowed to return a superset of + * "needed_returns". By allowing such excess variables to be returned, the + * generated output can be neater than if we were more strict. + *) +fun inject_return_vals ctxt prog_info name_map needed_returns allow_excess throw_vars fn_vars + term (vars_read, vars_returned, output_monad, thm) = + if needed_returns = vars_returned then + (* We already have precisely what is needed --- no more to do. *) + (vars_read, vars_returned, output_monad, thm) + else if (allow_excess andalso Varset.subset (needed_returns, vars_returned)) then + (* We already provide a superset of what is needed, and this is allowed. *) + (vars_read, vars_returned, output_monad, thm) + else + let + val (l1_term, _, _) = get_node_data term + + (* Generate the return statement. *) + val injected_return = + mk_l2monad prog_info @{const_name L2_gets} needed_returns throw_vars + [absdummy (#globals_type prog_info) (HOLogic.mk_tuple (dest_set needed_returns |> map name_map)), + var_set_to_isa_list needed_returns] + |> abs_over_tuple_vars name_map vars_returned + + (* Append the return statement to the input term. *) + val generated_term = mk_l2monad prog_info @{const_name L2_seq} + needed_returns throw_vars [output_monad, injected_return] + val preserved_vals = needed_returns MINUS vars_returned + + (* Generate a proof of correctness. *) + val generated_thm = + let + val preserve_proof = mk_multivar_preservation_proof ctxt prog_info name_map l1_term preserved_vals + in + mk_corresXF_thm' ctxt prog_info name_map needed_returns throw_vars (vars_read UNION preserved_vals) + generated_term l1_term + (@{thm L2corres_inject_return} OF [thm, @{thm validE_weaken} OF [preserve_proof]]) + end + in + (vars_read UNION preserved_vals, needed_returns, generated_term, generated_thm) + end + +(* + * Convert an L1 function into an L2 function. + * + * We assume that our input term has come out of the L1 conversion functions. + * + * We have inputs of the following: + * + * ctxt: Isabelle context + * + * needed_vars: + * + * Variables that are read in later executions. + * + * These are passed into the conversion so that we know what variables + * we need to track for later execution, and what variables we can + * just discard on the spot. + * + * If we didn't know what we actually needed to track, then the + * converted code would be significantly bloated due to returning + * variables that aren't actually used. + * + * allow_excess: + * + * Are we allowed to return _more_ variables than otherwise needed + * according to needed_vars? By setting this to true, more efficient + * code can be generated. + * + * throw_vars: + * + * Variables that must be thrown in the event we decide to emit an + * "L2_throw" call. These are calculated as we enter a try/catch block + * to ensure that all sites are consistent in the values they throw. + * + * term: The L1 term to convert. + * + * The return value of this function is a tuple: + * + * (, , , ) + * + * The "vars returned" is the variables that are returned through the "bind" + * combinator. + *) +fun do_conv + (ctxt : Proof.context) + prog_info + fn_info + name_map + (fn_vars : varset) + (callee_proofs : (bool * term * thm) Symtab.table) + (needed_vars : varset) + (allow_excess : bool) + (throw_vars : varset) + (term : (term * varset * varset, term option * varset * bool, (string * typ) option, unit) prog) + : (varset * varset * term * thm) = +let + val l1_term = get_node_data term |> #1 + val live_vars = get_node_data term |> #2 + val modified_vars = get_node_data term |> #3 + val inject = + inject_return_vals ctxt prog_info name_map needed_vars allow_excess throw_vars fn_vars term + fun mkthm read_vars ret_vars generated_term thm = + mk_corresXF_thm' ctxt prog_info name_map ret_vars throw_vars read_vars generated_term l1_term thm + val mk_monad = mk_l2monad prog_info +in + case term of + Init (_, SOME output_var) => + let + val out_vars = make_set [output_var] + val generated_term = mk_monad @{const_name L2_unknown} out_vars throw_vars + [Utils.ml_str_list_to_isa [fst output_var]] + val thm = mkthm empty_set out_vars generated_term @{thm L2corres_spec_unknown} + in + inject (empty_set, out_vars, generated_term, thm) + end + + (* L1_skip. *) + | Modify (_, (SOME expr, _, _), NONE) => + let + val generated_term = mk_monad @{const_name L2_gets} + empty_set throw_vars [expr, var_set_to_isa_list empty_set] + val thm = mkthm empty_set empty_set generated_term @{thm L2corres_gets_skip} + in + inject (empty_set, empty_set, generated_term, thm) + end + + (* L1_modify with unparsable expression. *) + | Modify (_, (NONE, _, _), SOME output_var) => + let + val out_vars = make_set [output_var] + val generated_term = mk_monad @{const_name L2_unknown} out_vars throw_vars [] + val thm = mkthm empty_set out_vars generated_term @{thm L2corres_modify_unknown} + in + inject (empty_set, out_vars, generated_term, thm) + end + + (* L1_modify that only modifies globals. *) + | Modify (_, (SOME expr, read_vars, _), SOME ("globals'", _)) => + let + val generated_term = mk_monad @{const_name L2_modify} empty_set throw_vars [expr] + val thm = mkthm read_vars empty_set generated_term @{thm L2corres_modify_global} + in + inject (read_vars, empty_set, generated_term, thm) + end + + (* L1_modify that only modifies a local and also reads globals. *) + | Modify (_, (SOME expr, read_vars, _), SOME output_var) => + let + val generated_term = mk_monad @{const_name L2_gets} + (make_set [output_var]) throw_vars [expr, var_set_to_isa_list (make_set [output_var])] + val thm = mkthm read_vars (make_set [output_var]) generated_term @{thm L2corres_modify_gets} + in + inject (read_vars, make_set [output_var], generated_term, thm) + end + + | Throw _ => + let + val generated_term = mk_monad @{const_name L2_throw} needed_vars throw_vars + [HOLogic.mk_tuple (dest_set throw_vars |> map name_map), + var_set_to_isa_list throw_vars] + val thm = mkthm throw_vars needed_vars generated_term @{thm L2corres_throw} + in + (throw_vars, needed_vars, generated_term, thm) + end + + | Spec (_, (SOME expr, read_vars, _)) => + let + val generated_term = mk_monad @{const_name "L2_spec"} needed_vars throw_vars [expr] + val thm = mkthm read_vars needed_vars generated_term @{thm L2corres_spec} + in + inject (read_vars, needed_vars, generated_term, thm) + end + + | Spec (_, (NONE, _, _)) => + let + val generated_term = mk_monad @{const_name "L2_fail"} needed_vars throw_vars [] + val thm = mkthm empty_set needed_vars generated_term @{thm L2corres_fail} + in + inject (empty_set, needed_vars, generated_term, thm) + end + + | Guard (_, (SOME expr, read_vars, _)) => + let + val generated_term = mk_monad @{const_name "L2_guard"} empty_set throw_vars [expr] + val thm = mkthm read_vars empty_set generated_term @{thm L2corres_guard} + in + inject (read_vars, empty_set, generated_term, thm) + end + + | Guard (_, (NONE, _, _)) => + let + val generated_term = mk_monad @{const_name "L2_fail"} needed_vars throw_vars [] + val thm = mkthm empty_set needed_vars generated_term @{thm L2corres_fail} + in + (empty_set, needed_vars, generated_term, thm) + end + + | Fail _ => + let + val generated_term = mk_monad @{const_name "L2_fail"} needed_vars throw_vars [] + val thm = mkthm empty_set needed_vars generated_term @{thm L2corres_fail} + in + (empty_set, needed_vars, generated_term, thm) + end + + | Seq (_, lhs, rhs) => + let + val (_, rhs_live, rhs_modified) = get_node_data rhs + val (lhs_term, _, lhs_modified) = get_node_data lhs + + (* Convert LHS and RHS. *) + val ret_vars = rhs_live INTER lhs_modified + val (lhs_reads, lhs_rets, new_lhs, lhs_thm) + = do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs ret_vars true throw_vars lhs + val (rhs_reads, rhs_rets, new_rhs, rhs_thm) + = do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs needed_vars allow_excess throw_vars rhs + val block_reads = lhs_reads UNION (rhs_reads MINUS lhs_modified) + + (* Reconstruct body to support our input tuple. *) + val rhs_thm = abs_over_thm ctxt name_map rhs_thm lhs_rets + val new_rhs = abs_over_tuple_vars name_map lhs_rets new_rhs + + (* Generate the final term. *) + val generated_term = mk_monad @{const_name L2_seq} rhs_rets throw_vars [new_lhs, new_rhs] + + (* Generate a proof. *) + val thm = + let + (* Show that certain variables are preserved by the LHS. *) + val needed_preserves = (rhs_reads MINUS lhs_modified) + val preserve_proof = mk_multivar_preservation_proof ctxt prog_info name_map lhs_term needed_preserves + in + mkthm block_reads rhs_rets generated_term + (@{thm L2corres_seq} OF [lhs_thm, rhs_thm, + @{thm validE_weaken} OF [preserve_proof]]) + end + in + inject (block_reads, rhs_rets, generated_term, thm) + end + + | Catch (_, lhs, rhs) => + let + val (lhs_term, _, lhs_modified) = get_node_data lhs + val (_, rhs_live, _) = get_node_data rhs + + (* Convert LHS and RHS. *) + val lhs_throws = rhs_live INTER lhs_modified + val (lhs_reads, lhs_rets, new_lhs, lhs_thm) + = do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs (needed_vars) false lhs_throws lhs + val (rhs_reads, _, new_rhs, rhs_thm) + = do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs (needed_vars) false throw_vars rhs + val block_reads = lhs_reads UNION (rhs_reads MINUS lhs_throws) + + (* Reconstruct body to support our input tuple. *) + val rhs_thm = abs_over_thm ctxt name_map rhs_thm lhs_throws + val new_rhs = abs_over_tuple_vars name_map lhs_throws new_rhs + + (* Generate the final term. *) + val generated_term = mk_monad @{const_name L2_catch} needed_vars throw_vars [new_lhs, new_rhs] + + (* Generate a proof. *) + val thm = + let + (* Show that certain variables are preserved by the LHS. *) + val needed_preserves = (rhs_reads MINUS lhs_modified) + val preserve_proof = mk_multivar_preservation_proof ctxt prog_info name_map lhs_term needed_preserves + in + mkthm block_reads needed_vars generated_term + (@{thm L2corres_catch} OF [lhs_thm, rhs_thm, @{thm validE_weaken} OF [preserve_proof]]) + end + in + inject (block_reads, needed_vars, generated_term, thm) + end + + | RecGuard (_, body) => + let + (* Convert body. *) + val (body_reads, vars_returned, new_body, body_thm) = + do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs needed_vars false throw_vars body + + (* Get recguard variable. *) + val (_ $ var $ _) = l1_term + + (* Generate the final term. *) + val generated_term = + mk_monad @{const_name "L2_recguard"} vars_returned throw_vars [ + var, new_body] + val thm = mkthm body_reads vars_returned generated_term + (@{thm L2corres_recguard} OF [body_thm]) + in + inject (body_reads, vars_returned, generated_term, thm) + end + + | Condition (_, (SOME expr, read_vars, _), lhs, rhs) => + let + (* Convert LHS and RHS. *) + val requested_vars = needed_vars INTER modified_vars + val (lhs_reads, _, new_lhs, lhs_thm) + = do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs requested_vars false throw_vars lhs + val (rhs_reads, _, new_rhs, rhs_thm) + = do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs requested_vars false throw_vars rhs + val block_reads = lhs_reads UNION rhs_reads UNION read_vars + + (* Generate the final term. *) + val generated_term = mk_monad @{const_name "L2_condition"} + requested_vars throw_vars [expr, new_lhs, new_rhs] + val thm = mkthm block_reads requested_vars generated_term + (@{thm L2corres_cond} OF [lhs_thm, rhs_thm]) + in + inject (block_reads, requested_vars, generated_term, thm) + end + + | While (_, (SOME expr, read_vars, _), body) => + let + (* Convert body. *) + val loop_iterators = (needed_vars UNION live_vars) INTER modified_vars + val (body_reads, _, new_body, body_thm) = + do_conv ctxt prog_info fn_info name_map fn_vars callee_proofs loop_iterators false throw_vars body + val (body_term, _, body_modifies) = get_node_data body + + (* Reconstruct body to support our input tuple. *) + val new_body = abs_over_tuple_vars name_map loop_iterators new_body + val body_thm = abs_over_thm ctxt name_map body_thm loop_iterators + + (* Generate the final term. *) + val generated_term = + mk_monad @{const_name "L2_while"} loop_iterators throw_vars [ + abs_over_tuple_vars name_map loop_iterators expr, + new_body, + HOLogic.mk_tuple (dest_set loop_iterators |> map name_map), + var_set_to_isa_list loop_iterators] + + (* Generate a proof. *) + val thm = + let + (* Show that certain variables are preserved by the LHS. *) + val needed_preserves = ((body_reads UNION read_vars) MINUS body_modifies) + val preserve_proof = mk_multivar_preservation_proof ctxt prog_info name_map body_term needed_preserves + + (* Instantiate while loop rule to avoid ambiguous unification. *) + val tracked_vars = (body_reads UNION read_vars UNION loop_iterators) + val invariant_precond = abs_over_tuple_vars name_map loop_iterators + (mk_precond prog_info name_map tracked_vars) + val base_thm = Utils.named_cterm_instantiate ctxt [ + ("P", Thm.cterm_of ctxt invariant_precond), + ("A", Thm.cterm_of ctxt new_body) + ] @{thm L2corres_while} + in + mkthm (body_reads UNION read_vars UNION loop_iterators) loop_iterators generated_term + (base_thm OF [body_thm, @{thm validE_weaken} OF [preserve_proof]]) + end + in + inject (body_reads UNION read_vars UNION loop_iterators, loop_iterators, generated_term, thm) + end + + | Call (_, expr_list, (ret_expr, ret_read_vars, _), ret_var, measure_term) => + let + val (_ $ arg_setup $ dest_fn $ globals_extract $ ret_extract) = l1_term + + val (measure_term, dest_fn) = case dest_fn of + (c as Const (@{const_name "measure_call"}, _)) $ f => (c, f) + | f $ (c as Const (@{const_name "undefined"}, _)) => (c, f) + | f $ (c as Const (@{const_name "recguard_dec"}, _) $ _) => (c, f) + | _ => raise TERM ("local_var_extract: strange function call", [dest_fn]) + + (* Get destination function. *) + val dest_fn = FunctionInfo.get_function_from_const fn_info + (l1call_function_const dest_fn) + + (* Lookup the callee proof, if it exists. *) + val callee_proof = Option.mapPartial + (Symtab.lookup callee_proofs) (Option.map #name dest_fn) + in + (* Determine if we have a proof for the callee. *) + case callee_proof of + NONE => + (let + val generated_term = mk_monad @{const_name "L2_fail"} needed_vars throw_vars [] + val thm = mkthm empty_set needed_vars generated_term @{thm L2corres_fail} + in + (empty_set, needed_vars, generated_term, thm) + end) + + | SOME (is_recursive, callee_free, callee_thm) => + (let + (* Get information about the function. *) + val dest_fn = FunctionInfo.get_1_phase_info (the dest_fn) FunctionInfo.L1 + val args = #args dest_fn + + (* Parse argument setup. *) + val arg_setup_vals = parse_modify ctxt prog_info name_map arg_setup |> List.rev + + (* Ensure that we can parse everything. *) + val arg_setup_vals = + map (fn (a, b, c, parsed_expr) => + case parsed_expr of + NONE => + raise Utils.InvalidInput ("Could not parse function parameter '" ^ (fst a) ^ "'") + | SOME x => + (a, b, c, x) + ) arg_setup_vals + + (* Sanity check: ensure that we have the correct number of arguments. *) + val _ = if length arg_setup_vals <> length args then + raise TERM ("Argument list length does not match function definition.", [arg_setup]) + else + () + + (* Rename input parameter names. *) + val arg_setup_vals = map (fn ((a,t),b,c,d) => ((a ^ "'param", t), b, c, d)) arg_setup_vals + + (* Generate the call. *) + (* The measure is the first arg, so we need to skip it when applying the others. *) + val args = map (fn (a,_,_,_) => name_map a) arg_setup_vals + val call_args = let + val var = Free ("rec_measure'", @{typ "nat"}) + in + lambda var (betapplys (callee_free, var :: args)) + end + + val call_measure = case measure_term of + Const (@{const_name "measure_call"}, _) => @{mk_term "measure_call ?f" f} call_args + | _ => betapply (call_args, measure_term) + + val (call, ret_vars) = + case (ret_var, ret_expr) of + (SOME ("globals'", _), SOME e) => + (mk_monad @{const_name L2_modifycall} empty_set throw_vars + [call_measure, e], empty_set) + | (SOME x, SOME e) => + (mk_monad @{const_name L2_returncall} (make_set [x]) throw_vars + [call_measure, e], make_set [x]) + | (NONE, _) => + (mk_monad @{const_name L2_voidcall} empty_set throw_vars + [call_measure], empty_set) + + (* + * We have a list of arguments; some may be expressions that refer to + * global variables, while others will be purely local variables. We + * just emit them all as "L2_gets" calls, and will clean them up + * later. + *) + val extractors = foldr ( + fn ((updated_var, read_vars, is_globals_reader, expr), rest) => + let + val ret_type = (make_set [("x'", fastype_of expr |> body_type)]) + val rest_type = (make_set [("x'", l2monad_ret_type rest)]) + val getter = mk_monad @{const_name L2_folded_gets} ret_type throw_vars + [expr, Utils.ml_str_list_to_isa [fst updated_var]] + in + mk_monad @{const_name "L2_seq"} rest_type throw_vars [ + getter, + Utils.abs_over (fst updated_var) (name_map updated_var) rest] + end + ) + call + arg_setup_vals + val read_vars = union_sets (map #2 expr_list) UNION ret_read_vars + + (* Generate a proof. *) + val my_debug_tac = if false then print_tac ctxt else fn _ => all_tac + val L2_call_thms = @{thms L2corres_returncall L2corres_voidcall L2corres_modifycall} + val L2_reccall_thms = @{thms L2corres_recursive_returncall L2corres_recursive_voidcall L2corres_recursive_modifycall} + val thm = + mk_corresXF_thm ctxt prog_info name_map ret_vars throw_vars read_vars extractors l1_term ( + (my_debug_tac "unfold folded_gets" + THEN (REPEAT (resolve_tac ctxt @{thms L2corres_folded_gets} 1))) + THEN (my_debug_tac "apply callee proof" + THEN FIRST (map (fn thm => resolve_tac ctxt [thm] 1 THEN + resolve_tac ctxt (List.mapPartial I [#mono_thm dest_fn]) 1 THEN + resolve_tac ctxt [callee_thm] 1) + L2_call_thms @ + map (fn thm => resolve_tac ctxt [thm] 1 THEN + resolve_tac ctxt [callee_thm] 1) + L2_reccall_thms)) + THEN (my_debug_tac "final simp" + THEN (REPEAT (CHANGED (asm_full_simp_tac (setup_l2_ss ctxt) 1)))) + ) + in + inject (read_vars, ret_vars, extractors, thm) + end) + end + | _ => Utils.invalid_input "a parsed L1 term" + (l1_term |> head_of |> PolyML.makestring) +end + +val measureT = @{typ nat}; + +(* Get the expected type of a function from its name. *) +fun get_expected_l2_fn_type prog_info fn_info fn_name = +let + val fn_def = FunctionInfo.get_phase_info fn_info FunctionInfo.L1 fn_name + val fn_params_typ = measureT :: map snd (#args fn_def) +in + fn_params_typ ---> mk_l2monadT (#globals_type prog_info) (#return_type fn_def) @{typ unit} +end + +(* Avoid clashes with fixed free variables such as C-parser's "symbol_table". + * FIXME: use variant_fixes instead, also for rec_measure *) +fun to_free_var_name lthy x = if Variable.is_fixed lthy x then to_free_var_name lthy (x ^ "'") else x + +(* Get arguments passed into the function. *) +fun get_expected_l2_fn_args lthy prog_info fn_info fn_name = +let + val fn_def = FunctionInfo.get_phase_info fn_info FunctionInfo.L1 fn_name +in + map (apfst (to_free_var_name lthy o ProgramInfo.demangle_name)) (#args fn_def) +end + +fun get_expected_l2_fn_thm prog_info fn_info ctxt fn_name fn_free fn_args _ measure_var = +let + (* Fetch input/output params for monad type. *) + val (input_params, output_params) = get_fn_input_output_vars prog_info fn_info fn_name + + (* Get mapping from internal variable names that we use to the names passed + * in "fn_args". *) + val fn_def = FunctionInfo.get_phase_info fn_info FunctionInfo.L1 fn_name + val args = map fst (#args fn_def) + val m = Symtab.make (args ~~ fn_args) + fun name_map (n, _) = Symtab.lookup m n |> the +in + mk_corresXF_prop (Proof_Context.theory_of ctxt) prog_info name_map + output_params empty_set input_params + (betapplys (fn_free, measure_var :: fn_args)) + (betapply (#const fn_def, measure_var)) +end + +(* Extract the abstract body of a L2corres theorem. *) +fun get_body_of_thm ctxt thm = + Thm.concl_of (Variable.gen_all ctxt thm) + |> HOLogic.dest_Trueprop + |> dest_L2corres_term_abs + +fun get_l2corres_thm ctxt prog_info fn_info do_opt trace_opt fn_name + callee_terms fn_args l1_term init_unfold = let + + (* Get information about the return variable. *) + val fn_def = FunctionInfo.get_phase_info fn_info FunctionInfo.L1 fn_name + + (* Get return variables. *) + val (fn_input_vars, fn_ret_vars) = get_fn_input_output_vars prog_info fn_info fn_name + + (* Get mapping from internal variable names to external arguments. *) + val m = Symtab.make (map fst (#args fn_def) ~~ fn_args) + fun name_map_ext (n, T) = Symtab.lookup m n |> the + fun name_map_internal (n, T) = Free ("lvar'" ^ n, T) + + (* + * Many constructs from SIMPL (and also L1) are in set form, but we really + * need them to be in functional form to be able to effectively parse them. + * In particular we can parse: + * + * (%s. n_' s) + * + * but not: + * + * {s. n_' s} + * + * We do some basic conversions here to convert common sets into lambda + * functions. + *) + val init_rule = Thm.cterm_of ctxt l1_term + |> Conv.rewr_conv (safe_mk_meta_eq init_unfold) + + (* Extract the term we will be working with. *) + val source_term = Thm.concl_of init_rule |> Logic.dest_equals |> snd + |> Utils.unsafe_unvarify + + (* Do basic parsing. *) + val parsed_term = parse_l1 ctxt prog_info fn_info name_map_internal source_term + + (* Get a list of all variables either read from or written to. *) + val all_vars = Prog.fold_prog + (K I) + (fn (_, vars, _) => fn old_vars => vars UNION old_vars) + (fn mod_var => fn old_vars => + case mod_var of SOME x => (Varset.insert x old_vars) | NONE => old_vars) + (K I) + parsed_term empty_set + + (* Perform liveness analysis of the function. *) + val liveness_data = calc_live_vars parsed_term fn_ret_vars + + (* + * Get information about modified variables. + * + * "NONE" represents "modifies potentially all variables"; we modify + * the results to fit this. + *) + val modification_data = + get_modified_vars parsed_term + |> map_prog (fn x => Option.getOpt (x, all_vars)) I I I + + (* Combine collected data. *) + fun zip_node_data a b c = + zip_progs a (zip_progs b c) + |> map_prog (fn (a, (b, c)) => (a, b, c)) fst fst fst + val input_term = zip_node_data parsed_term liveness_data modification_data + + (* Ensure that the only live variables at the beginning of the function are + * those that are function inputs. *) + val fn_inputs = get_node_data liveness_data + val fn_params = #args fn_def + val excess_inputs = fn_inputs MINUS (make_set fn_params) + val _ = + if excess_inputs <> empty_set then + warning + ("Input function '" ^ fn_name ^ "' has unresolved variables: " + ^ PolyML.makestring (dest_set excess_inputs)) + else + () + + (* Do the conversion. *) + val (vars_read, _, term, thm) = do_conv ctxt prog_info fn_info name_map_internal fn_input_vars + callee_terms fn_ret_vars false empty_set input_term + + (* Replace our internal terms with external terms. *) + val replacements = (map name_map_internal fn_params) ~~ (map name_map_ext fn_params) + val term = Raw_Simplifier.rewrite_term (Proof_Context.theory_of ctxt) + [] [Termtab.lookup (Termtab.make replacements)] term + + (* + * Generate a theorem with a folded RHS, with the LHS unfolded. + * + * The idea here is that we must generate a theorem of the form we + * committed to in "get_expected_l2_fn_thm", but with schematic variables. + *) + val new_thm = + mk_corresXF_prop (Proof_Context.theory_of ctxt) prog_info name_map_ext + fn_ret_vars empty_set fn_input_vars + term l1_term + |> Thm.cterm_of ctxt + |> Goal.init + |> apply_tac "unfold RHS" (EqSubst.eqsubst_tac ctxt [0] [init_rule] 1) + |> apply_tac "generalise guard" (resolve_tac ctxt @{thms L2corres_guard_imp} 1) + |> apply_tac "solve main goal" (resolve_tac ctxt [thm] 1) + |> apply_tac "solve guard_imp" (REPEAT (FIRST [ + resolve_tac ctxt @{thms HOL.refl} 1, + resolve_tac ctxt @{thms pred_andI} 1, + resolve_tac ctxt @{thms conjI} 1, + CHANGED (asm_full_simp_tac (setup_l2_ss ctxt) 1)])) + |> Goal.finish ctxt + + (* Remove intermediate scaffolding. *) + val new_thm = Conv.fconv_rule ( + Utils.remove_meta_conv (fn ctxt => + Utils.nth_arg_conv 5 ( + Raw_Simplifier.rewrite ctxt false @{thms L2_remove_scaffolding_1} + then_conv + Raw_Simplifier.rewrite ctxt false @{thms L2_remove_scaffolding_2})) ctxt) new_thm + + (* Gather statistics. *) + val _ = Statistics.gather ctxt "L2" fn_name + (Thm.prop_of new_thm |> HOLogic.dest_Trueprop |> (fn t => Utils.term_nth_arg t 4)) + + (* Cleanup. *) + val _ = writeln ("Simplifying (L2) " ^ fn_name) + val new_thm = Simplifier.simplify (put_simpset HOL_basic_ss ctxt addsimps + (* this rule is expensive *) + (if do_opt then @{thms L2_unknown_bind} else [])) + new_thm + + val _ = writeln ("Simplifying (L2opt) " ^ fn_name) + (* HACK: we need to avoid these simps until heap_lift *) + val cleanup_del = @{thms ptr_coerce.simps ptr_add_0_id} + val (new_thm, traces) = L2Opt.cleanup_thm_tagged (ctxt delsimps cleanup_del) new_thm + (if do_opt then 0 else 2) 5 trace_opt "L2" + + (* Gather post-simplification statistics. *) + val _ = Statistics.gather ctxt "L2simp" fn_name + (Thm.prop_of new_thm |> HOLogic.dest_Trueprop |> (fn t => Utils.term_nth_arg t 4)) +in + (new_thm, traces) +end + +(* + * Prove monad_mono property for recursive functions. + * Note that this is also used for subsequent L2-based phases. + *) + +fun l2_monad_mono lthy (fn_infos: FunctionInfo.phase_info Symtab.table) = +let + (* + * For the induction, we need to have the form + * "\ m. (ALL a b... f m a b...) /\ (ALL a b... g m a b...) /\ ..." + * and this gets annoying pretty quickly. But it is probably unavoidable. + *) + val (fn_names, fn_defs) = split_list (Symtab.dest fn_infos); + val measure = Free ("rec_measure'", measureT) + fun make_mono_step_stmt current_def = + let + (* def should be of the form "func ?locale_args... ?measure ?args... = ..." *) + val (_, locale_args) = strip_comb (#const current_def) + val (_, all_args) = Utils.lhs_of_eq (term_of_thm (#definition current_def)) |> strip_comb + val _ :: args = drop (length locale_args) all_args + val args = args |> map (fn Var ((name, _), typ) => Free (name, typ)) + in + fold (fn arg => fn t => @{mk_term "All ?P" P} (lambda arg t)) args + (@{mk_term "monad_mono_step ?f ?m" (f, m)} + (lambda measure (betapplys (#const current_def, measure :: args)), measure)) + end + val mk_conj_list = foldr1 (fn (a, b) => @{term "conj"} $ a $ b) + + val func_expand = map (fn fn_def => EqSubst.eqsubst_tac lthy [0] + [Utils.abs_def lthy (#definition fn_def)]) fn_defs + val tac = + resolve_tac lthy @{thms nat.induct} 1 + THEN EVERY (map (fn expand => + TRY (resolve_tac lthy @{thms conjI} 1) + THEN REPEAT (resolve_tac lthy @{thms allI} 1) + THEN expand 1 + THEN resolve_tac lthy @{thms monad_mono_step_L2_recguard_0} 1) func_expand) + THEN REPEAT (eresolve_tac lthy @{thms conjE} 1) + THEN EVERY (map (fn expand => + TRY (resolve_tac lthy @{thms conjI} 1) + THEN REPEAT (resolve_tac lthy @{thms allI} 1) + THEN expand 1 + THEN REPEAT (FIRST ( + map (fn t => resolve_tac lthy [t] 1) @{thms L2_monad_mono_step_rules} + (* We use simp to solve assumptions (assume_tac doesn't work + * because the assumptions are ALL-quantified) and to + * split tuple cases. *) + @ [CHANGED (asm_full_simp_tac (clear_simpset lthy + addsimps @{thms split_conv split_tupled_all}) 1)]))) + func_expand) + + val mono_thm = map make_mono_step_stmt fn_defs + |> mk_conj_list + |> (fn t => Logic.all measure (@{term "Trueprop"} $ t)) + |> (fn t => Goal.prove lthy [] [] t (K tac)) + + (* We have finished the induction, now we extract the individual results. *) + fun make_mono_stmt L2_def = + let + val (_, locale_args) = strip_comb (#const L2_def) + val (_, all_args) = Utils.lhs_of_eq (term_of_thm (#definition L2_def)) |> strip_comb + val _ :: args = drop (length locale_args) all_args + val args = args |> map (fn Var ((name, _), typ) => Free (name, typ)) + in + fold Logic.all args + (@{mk_term "Trueprop (monad_mono ?f)" f} + (lambda measure (betapplys (#const L2_def, measure :: args)))) + end + val final_thms = fn_defs + |> map (fn fn_def => Goal.prove lthy [] [] (make_mono_stmt fn_def) + (K (asm_full_simp_tac (lthy addsimps [@{thm monad_mono_alt_def}, mono_thm]) 1))) +in + final_thms + |> (fn thms => fn_names ~~ thms) + |> Symtab.make +end + +(* For functions that are not translated, just generate a trivial wrapper. *) +fun mk_l2corres_call_simpl_thm prog_info fn_info ctxt fn_name fn_args = let + val fn_def = FunctionInfo.get_phase_info fn_info FunctionInfo.L1 fn_name + val const = #const fn_def + val args = #args fn_def + + (* Get return variables. *) + val (fn_input_vars, fn_ret_vars) = get_fn_input_output_vars prog_info fn_info fn_name + (* Get mapping from internal variable names to external arguments. *) + val m = Symtab.make (map fst args ~~ fn_args) + fun name_map_ext (n, T) = Symtab.lookup m n |> the + + val arg_xf = mk_precond prog_info name_map_ext fn_input_vars + val ret_xf = mk_xf prog_info fn_ret_vars + val ex_xf = Abs ("s", #state_type prog_info, HOLogic.unit) + + val thm = Utils.named_cterm_instantiate ctxt + (map (apsnd (Thm.cterm_of ctxt)) + [("l1_f", betapply (const, Free ("rec_measure'", @{typ "nat"}))), + ("ex_xf", ex_xf), ("gs", #globals_getter prog_info), + ("ret_xf", ret_xf), ("arg_xf", arg_xf)]) + @{thm L2corres_L2_call_simpl} + OF [#definition fn_def] + in thm end + +(* + * Convert a single function. Returns a thm that looks like + * \ L2corres ?callee1 l1_callee1; ... \ \ + * L2corres (conversion result...) l1_f + * i.e. with assumptions for called functions, which are parameterised as Vars. + *) +fun convert + (lthy: local_theory) (* must contain at least L1 callee defs, but no other requirements *) + (prog_info: ProgramInfo.prog_info) + (fn_info: FunctionInfo.fn_info) (* legacy for aux data, won't actually contain L1 data *) + (* L1 data for f and callees, may include extra funcs *) + (l1_info: (FunctionInfo.phase_info * thm) Symtab.table) + (do_opt: bool) + (trace_opt: bool) + (l2_function_name: string -> string) + (f_name: string) + : thm * (string * typ) list = let + val callee_names = FunctionInfo.get_function_callees fn_info f_name; + val _ = filter (fn f => not (isSome (Symtab.lookup l1_info f))) callee_names + |> (fn bad => if null bad then () else + error ("L2 conversion missing callees for " ^ f_name ^ ": " ^ commas bad)); + (* TODO sanity check: + - the required L1 defs exist *) + + (* Place the L1 data where the existing code expects it. *) + val fn_info' = fn_info |> FunctionInfo.add_phases (fn f => + K (Symtab.lookup l1_info f |> Option.map fst)); + + (* Fix measure variable. *) + val ([measure_var_name], lthy') = Variable.variant_fixes ["rec_measure'"] lthy; + val measure_var = Free (measure_var_name, measureT); + + (* Fix argument variables. *) + val (f_l1_info, f_l1corres) = the (Symtab.lookup l1_info f_name); + val f_args = #args f_l1_info; + val (arg_names, lthy'') = Variable.variant_fixes (map fst f_args) lthy'; + val arg_frees = arg_names ~~ map snd f_args; + + (* Add callee assumptions. Note that our define code has to use the same assumption order. *) + val group = Symset.make (FunctionInfo.get_recursive_group fn_info f_name); + val callee_terms = + map (fn callee => (callee, Symset.contains group callee)) callee_names; + val (lthy''', export_thm, callee_terms) = + AutoCorresUtil.assume_called_functions_corres lthy'' + callee_terms + (get_expected_l2_fn_type prog_info fn_info') + (get_expected_l2_fn_thm prog_info fn_info') + (get_expected_l2_fn_args lthy prog_info fn_info') + l2_function_name + measure_var; + + val f_l1_def = Utils.named_cterm_instantiate lthy''' + [("rec_measure'", Thm.cterm_of lthy''' measure_var)] + (#definition f_l1_info) + val (thm, opt_traces) = + if #is_simpl_wrapper (FunctionInfo.get_function_info fn_info' f_name) + then (mk_l2corres_call_simpl_thm prog_info fn_info' lthy''' f_name (map Free arg_frees), []) + else get_l2corres_thm lthy''' prog_info fn_info' do_opt trace_opt f_name + (Symtab.make callee_terms) (map Free f_args) + (betapply (#const f_l1_info, measure_var)) + f_l1_def; + in (Morphism.thm export_thm thm, + (* Provide the fixed vars so the user can generalize/instantiate them *) + dest_Free measure_var :: arg_frees) + end + + +(* Define a previously-converted function (or recursive function group). + * lthy must include all definitions from l2_callees. *) +fun define + (lthy: local_theory) + (filename: string) + (prog_info: ProgramInfo.prog_info) + (fn_info: FunctionInfo.fn_info) (* required to have L1 *) + (l2_callees: (FunctionInfo.phase_info * thm) Symtab.table) (* L2 callees & corres thms *) + (l2_function_name: string -> string) + (funcs: (string * thm * (string * typ) list) list) (* name, corres, arg frees *) + : (FunctionInfo.phase_info * thm) Symtab.table * local_theory = let + (* FIXME: dedup with convert *) + + (* FIXME: pass this from assume_called_functions_corres, etc. *) + fun guess_callee_var thm callee = let + val base_name = l2_function_name callee; + val mentioned_vars = Term.add_vars (Thm.prop_of thm) []; + in hd (filter (fn ((v, _), _) => v = base_name) mentioned_vars) end; + + fun prepare_fn_body (fn_name, corres_thm, arg_frees) = let + val _ = @{trace} ("prepare_fn_body", fn_name, corres_thm); + val @{term_pat "Trueprop (L2corres _ _ _ _ ?body _)"} = Thm.concl_of corres_thm; + val (callees, recursive_callees) = AutoCorresUtil.get_callees fn_info fn_name; + val calls = map (fn c => Var (guess_callee_var corres_thm c)) callees; + val recursive_calls = map (fn c => Var (guess_callee_var corres_thm c)) recursive_callees; + + (* + * The returned body will have free variables as placeholders for the function's + * measure parameter and other arguments, and schematic variables for the functions it calls. + * + * We modify the body to be of the form: + * + * %fun1 fun2 rec1 rec2 measure arg1 arg2. f <...> + * + * That is, all non-recursive calls are abstracted out the front, followed by + * recursive calls, followed by the measure variable, followed by function + * arguments. This is the format expected by define_funcs. + *) + val abs_body = body + |> fold lambda (rev (map Free arg_frees)) + |> fold lambda (rev recursive_calls) + |> fold lambda (rev calls); + in abs_body end; + + val fn_info' = FunctionInfo.add_phases (fn f => K (Option.map fst (Symtab.lookup l2_callees f))) fn_info; + val funcs' = funcs |> + map (fn (name, thm, frees) => + (name, (* FIXME: define_funcs needs this currently *) + (Thm.generalize ([], map fst frees) (Thm.maxidx_of thm + 1) thm, + prepare_fn_body (name, thm, frees), []))); + val _ = @{trace} ("L2.define", map (fn (name, (thm, body, _)) => (name, thm, Thm.cterm_of lthy body)) funcs'); + val (corres_thms, (), lthy') = + AutoCorresUtil.define_funcs FunctionInfo.L2 filename fn_info' + l2_function_name + (get_expected_l2_fn_type prog_info fn_info') + (get_expected_l2_fn_thm prog_info fn_info') + (get_expected_l2_fn_args lthy prog_info fn_info') + @{thm L2corres_recguard_0} + lthy (Symtab.map (K snd) l2_callees) () + funcs'; + val f_names = map (fn (name, _, _) => name) funcs; + val new_phases = map2 (fn f_name => fn corres => let + val old_phase = FunctionInfo.get_phase_info fn_info' FunctionInfo.L1 f_name; + val def = the (AutoCorresData.get_def (Proof_Context.theory_of lthy') filename "L2def" f_name); + val f_const = Utils.get_term lthy' (l2_function_name f_name); + in old_phase + |> FunctionInfo.phase_info_upd_phase FunctionInfo.L2 + |> FunctionInfo.phase_info_upd_definition def + |> FunctionInfo.phase_info_upd_const f_const + |> FunctionInfo.phase_info_upd_mono_thm NONE (* TODO *) + end) f_names corres_thms; + (* FIXME: return traces *) + in (Symtab.make (f_names ~~ (new_phases ~~ corres_thms)), lthy') end; + + +fun symtab_merge tabs = maps Symtab.dest tabs |> Symtab.make; + + +(* + * Translate all functions from L1 to L2 format. + *) +fun translate filename prog_info fn_info + (* lazy results from L1 *) + (l1_results: (string * (local_theory * (FunctionInfo.phase_info * thm) Symtab.table) future) list) + do_opt trace_opt l2_function_name = +(* if there's nothing to translate, we won't have a lthy *) +if null l1_results then [] else +let + val funcs_to_translate = Symtab.keys (FunctionInfo.get_all_functions fn_info); + + val get_l1_result = let + val table = Symtab.make l1_results; + in fn f => the' ("missing L1 lazy result for function: " ^ f) (Symtab.lookup table f) end; + + (* All conversions can run in parallel. + * Each conversion depends only on the corresponding L1 define phase + * (which necessarily also includes L1 callees). *) + val converted_funcs = + funcs_to_translate |> map (fn f => + (f, Future.fork (fn _ => let + (* L1 info for f and its callees *) + val (r_lthy, f_l1_info) = Future.join (get_l1_result f); + val callee_l1_infos = FunctionInfo.get_function_callees fn_info f + |> map (fn callee => snd (Future.join (get_l1_result callee))); + val l1_infos = symtab_merge (f_l1_info :: callee_l1_infos); + in convert r_lthy prog_info fn_info l1_infos do_opt trace_opt l2_function_name f end))) + |> Symtab.make; + + (* Definitions update lthy sequentially. + * We use the arbitrary (but deterministic) ordering defined by get_topo_sorted_functions. + * Each definition step produces a lthy that has a prefix of the update sequence, + * and can be used subsequently to convert a function that depends only on that prefix. + * Hence we produce the intermediate lthys lazily for maximum parallelism. *) + fun add_def f_names accum = Future.fork (fn _ => let + (* Wait for previous definition to finish *) + val (lthy, _, defined_so_far) = Future.join accum; + (* Wait for conversions to finish *) + val f_convs = map (fn f => let + val conv = the' ("didn't convert function: " ^ quote f ^ "??") (Symtab.lookup converted_funcs f); + val (corres_thm, arg_frees) = Future.join conv; + in (f, corres_thm, arg_frees) end) f_names; + (* Add L1 phase results. *) + val (_, l1_infos) = Future.join (get_l1_result (hd f_names)); + val fn_info' = FunctionInfo.add_phases (fn f => K (Symtab.lookup l1_infos f |> Option.map fst)) fn_info; + val (new_defs, lthy') = define lthy filename prog_info fn_info' + (defined_so_far: (FunctionInfo.phase_info * thm) Symtab.table) + l2_function_name f_convs; + in (lthy', new_defs, Symtab.merge (K false) (defined_so_far, new_defs)) end); + + val function_groups = FunctionInfo.get_topo_sorted_functions fn_info; + (* Get initial lthy from end of L1 defs *) + val l1_def_context = Future.map (fn (lthy, _) => (lthy, Symtab.empty, Symtab.empty)) + (snd (hd (rev l1_results))); + (* Chain of intermediate states: (lthy, new_defs, accumulator) *) + val (def_results, _) = Utils.accumulate add_def l1_def_context function_groups; + + (* Produce a mapping from each function group to its L1 phase_infos and the + * earliest intermediate lthy where it is defined. *) + val results = + def_results + |> map (Future.map (fn (lthy, f_defs, _ (* discard accum here *)) => let + (* Add monad_mono proofs. These are done in parallel as well + * (though in practice, they already run pretty quickly). *) + val mono_thms = if FunctionInfo.is_function_recursive fn_info (hd (Symtab.keys f_defs)) + then l2_monad_mono lthy (Symtab.map (K fst) f_defs) + else Symtab.empty; + val f_defs' = f_defs |> Symtab.map (fn f => + apfst (FunctionInfo.phase_info_upd_mono_thm (Symtab.lookup mono_thms f))); + in (lthy, f_defs') end)); +in + (* would be ready to become a symtab, but a list also preserves order *) + function_groups ~~ results +end + +end diff --git a/tools/autocorres/simpl_conv2.ML b/tools/autocorres/simpl_conv2.ML new file mode 100644 index 000000000..acf06dd5b --- /dev/null +++ b/tools/autocorres/simpl_conv2.ML @@ -0,0 +1,630 @@ +(* + * Copyright 2014, NICTA + * + * This software may be distributed and modified according to the terms of + * the BSD 2-Clause license. Note that NO WARRANTY is provided. + * See "LICENSE_BSD2.txt" for details. + * + * @TAG(NICTA_BSD) + *) + +(* + * Automatically convert SIMPL code fragments into a monadic form, with proofs + * of correspondence between the two. + *) +structure SimplConv2 = +struct + +(* Convenience shortcuts. *) +val warning = Utils.ac_warning +val apply_tac = Utils.apply_tac +val the' = Utils.the' + +exception FunctionNotFound of string + +val simpl_conv_ss = AUTOCORRES_SIMPSET + +(* + * Given a function constant name such as "Blah.foo_'proc", guess the underlying + * function name "foo". + *) +fun guess_function_name const_name = + const_name |> unsuffix "_'proc" |> Long_Name.base_name + +(* Generate a L1 monad type. *) +fun mk_l1monadT stateT = + Utils.gen_typ @{typ "'a L1_monad"} [stateT] + +(* + * Extract the L1 monadic term out of a L1corres constant. + *) +fun get_L1corres_monad @{term_pat "L1corres _ _ ?l1_monad _"} = l1_monad + | get_L1corres_monad t = raise TERM ("get_L1corres_monad", [t]) + +(* + * Generate a SIMPL term that calls the given function. + * + * For instance, we might return: + * + * "Call foo_'proc" + *) +fun mk_SIMPL_call_term ctxt prog_info fn_info target_fn = + @{mk_term "Call ?proc :: (?'s, int, strictc_errortype) com" (proc, 's)} + (FunctionInfo.get_phase_info fn_info FunctionInfo.CP target_fn |> #const, #state_type prog_info) + +(* + * Construct a correspondence lemma between a given monadic term and a SIMPL fragment. + * + * The term is of the form: + * + * L1corres check_termination \ monad simpl + *) +fun mk_L1corres_prop prog_info check_termination monad_term simpl_term = + @{mk_term "L1corres ?ct ?gamma ?monad ?simpl" (ct, gamma, monad, simpl)} + (Utils.mk_bool check_termination, #gamma prog_info, monad_term, simpl_term) + +(* + * Construct a prop claiming that the given term is equivalent to + * a call to the given SIMPL function: + * + * L1corres ct \ (Call foo_'proc) + * + *) +fun mk_L1corres_call_prop ctxt prog_info fn_info check_termination target_fn_name term = + mk_L1corres_prop prog_info check_termination term + (mk_SIMPL_call_term ctxt prog_info fn_info target_fn_name) + |> HOLogic.mk_Trueprop + +(* + * Convert a SIMPL fragment into a monadic term. + * + * We return the monadic version of the input fragment and a tactic + * to prove correspondence. + *) +fun simpl_conv' + (prog_info : ProgramInfo.prog_info) + (fn_info : FunctionInfo.fn_info) + (ctxt : Proof.context) + (callee_terms : (bool * term * thm) Symtab.table) + (measure_var : term) + (simpl_term : term) = + let + fun prove_term subterms base_thm result_term = + let + val subterms' = map (simpl_conv' prog_info fn_info ctxt + callee_terms measure_var) subterms; + val converted_terms = map fst subterms'; + val subproofs = map snd subterms'; + val new_term = (result_term converted_terms); + in + (new_term, (resolve_tac ctxt [base_thm] 1) THEN (EVERY subproofs)) + end + + (* Construct a "L1 monad" term with the given arguments applied to it. *) + fun mk_l1 (Const (a, _)) args = + Term.betapplys (Const (a, map fastype_of args + ---> mk_l1monadT (#state_type prog_info)), args) + + (* Convert a set construct into a predicate construct. *) + fun set_to_pred t = + (Const (@{const_name L1_set_to_pred}, + fastype_of t --> (HOLogic.dest_setT (fastype_of t) --> @{typ bool})) $ t) + in + (case simpl_term of + (* + * Various easy cases of SIMPL to monadic conversion. + *) + + (Const (@{const_name Skip}, _)) => + prove_term [] @{thm L1corres_skip} + (fn _ => mk_l1 @{term "L1_skip"} []) + + | (Const (@{const_name Seq}, _) $ left $ right) => + prove_term [left, right] @{thm L1corres_seq} + (fn [l, r] => mk_l1 @{term "L1_seq"} [l, r]) + + | (Const (@{const_name Basic}, _) $ m) => + prove_term [] @{thm L1corres_modify} + (fn _ => mk_l1 @{term "L1_modify"} [m]) + + | (Const (@{const_name Cond}, _) $ c $ left $ right) => + prove_term [left, right] @{thm L1corres_condition} + (fn [l, r] => mk_l1 @{term "L1_condition"} [set_to_pred c, l, r]) + + | (Const (@{const_name Catch}, _) $ left $ right) => + prove_term [left, right] @{thm L1corres_catch} + (fn [l, r] => mk_l1 @{term "L1_catch"} [l, r]) + + | (Const (@{const_name While}, _) $ c $ body) => + prove_term [body] @{thm L1corres_while} + (fn [body] => mk_l1 @{term "L1_while"} [set_to_pred c, body]) + + | (Const (@{const_name Throw}, _)) => + prove_term [] @{thm L1corres_throw} + (fn _ => mk_l1 @{term "L1_throw"} []) + + | (Const (@{const_name Guard}, _) $ _ $ c $ body) => + prove_term [body] @{thm L1corres_guard} + (fn [body] => mk_l1 @{term "L1_seq"} [mk_l1 @{term "L1_guard"} [set_to_pred c], body]) + + | @{term_pat "lvar_nondet_init _ ?upd"} => + prove_term [] @{thm L1corres_init} + (fn _ => mk_l1 @{term "L1_init"} [upd]) + + | (Const (@{const_name Spec}, _) $ s) => + prove_term [] @{thm L1corres_spec} + (fn _ => mk_l1 @{term "L1_spec"} [s]) + + | (Const (@{const_name guarded_spec_body}, _) $ _ $ s) => + prove_term [] @{thm L1corres_guarded_spec} + (fn _ => mk_l1 @{term "L1_spec"} [s]) + + (* + * "call": This is primarily what is output by the C parser. We + * accept input terms of the form: + * + * "call (%_ s. Basic ( s))". + * + * In particular, the last argument needs to be of precisely the + * form above. SIMPL, in theory, supports complex expressions in + * the last argument. In practice, the C parser only outputs + * the form above, and supporting more would be a pain. + *) + | (Const (@{const_name call}, _) $ a $ (fn_const as Const (b, _)) $ c $ (Abs (_, _, Abs (_, _, (Const (@{const_name Basic}, _) $ d))))) => + let + val state_type = #state_type prog_info + val target_fn_name = FunctionInfo.get_function_from_const fn_info fn_const |> Option.map #name + in + case Option.mapPartial (Symtab.lookup callee_terms) target_fn_name of + NONE => + (* If no proof of our callee could be found, we emit a call to + * "fail". This may happen for functions without bodies. *) + let + val _ = warning ("Function '" ^ guess_function_name b ^ "' contains no body. " + ^ "Replacing the function call with a \"fail\" command.") + in + prove_term [] @{thm L1corres_fail} (fn _ => mk_l1 @{term "L1_fail"} []) + end + | SOME (is_rec, term, thm) => + let + (* + * If this is an internal recursive call, decrement the measure. + * Or if this is calling a recursive function, use measure_call. + * If the callee isn't recursive, it doesn't use the measure var + * and we can just give an arbitrary value. + *) + val target_fn_name = (the target_fn_name) + val target_rec = FunctionInfo.is_function_recursive fn_info target_fn_name + val term' = + if is_rec then + term $ (@{term "recguard_dec"} $ measure_var) + else if target_rec then + @{mk_term "measure_call ?f" f} term + else + term $ @{term "undefined :: nat"} + in + (* Generate the term. *) + (mk_l1 @{term "L1_call"} + [a, term', c, absdummy state_type d], + resolve_tac ctxt [if is_rec orelse not target_rec then + @{thm L1corres_reccall} else @{thm L1corres_call}] 1 + THEN resolve_tac ctxt [thm] 1) + end + end + + (* TODO : Don't currently support DynCom *) + | other => Utils.invalid_term "a SIMPL term" other) + end + +(* Perform post-processing on a theorem. *) +fun cleanup_thm ctxt do_opt trace_opt prog_info fn_name thm = +let + (* Measure the term. *) + fun gather_stats phase thm = + Statistics.gather ctxt phase fn_name + (Thm.concl_of thm |> HOLogic.dest_Trueprop |> get_L1corres_monad) + val _ = gather_stats "L1" thm + + (* For each function, we want to prepend a statement that sets its return + * value undefined. It is actually always defined, but our analysis isn't + * sophisticated enough to realise. *) + fun prepend_undef thm fn_name = + let + val ret_var_name = + Symtab.lookup (ProgramAnalysis.get_fninfo (#csenv prog_info)) fn_name + |> the + |> (fn (ctype, _, _) => NameGeneration.return_var_name ctype |> MString.dest) + val ret_var_setter = Symtab.lookup (#var_setters prog_info) ret_var_name + val ret_var_getter = Symtab.lookup (#var_getters prog_info) ret_var_name + fun try_unify (x::xs) = + ((x ()) handle THM _ => try_unify xs) + in + case ret_var_setter of + SOME _ => + (* Prepend the L1_init code. *) + Utils.named_cterm_instantiate ctxt + [("X", Thm.cterm_of ctxt (the ret_var_setter)), + ("X'", Thm.cterm_of ctxt (the ret_var_getter))] + (try_unify [ + (fn _ => @{thm L1corres_prepend_unknown_var_recguard} OF [thm]), + (fn _ => @{thm L1corres_prepend_unknown_var} OF [thm]), + (fn _ => @{thm L1corres_prepend_unknown_var'} OF [thm])]) + + (* Discharge the given proof obligation. *) + |> simp_tac (put_simpset simpl_conv_ss ctxt) 1 |> Seq.hd + | NONE => thm + end + val thm = prepend_undef thm fn_name + + (* Conversion combinator to apply a conversion only to the L1 subterm of a + * L1corres term. *) + fun l1conv conv = (Conv.arg_conv (Utils.nth_arg_conv 3 conv)) + + (* Conversion to simplify guards. *) + fun guard_conv' c = + case (Thm.term_of c) of + (Const (@{const_name "L1_guard"}, _) $ _) => + Simplifier.asm_full_rewrite (put_simpset simpl_conv_ss ctxt) c + | _ => + Conv.all_conv c + val guard_conv = Conv.top_conv (K guard_conv') ctxt + + (* Apply all the conversions on the generated term. *) + val (thm, guard_opt_trace) = AutoCorresTrace.fconv_rule_maybe_traced ctxt (l1conv guard_conv) thm trace_opt + val (thm, peephole_opt_trace) = + AutoCorresTrace.fconv_rule_maybe_traced ctxt + (l1conv (Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps + (if do_opt then Utils.get_rules ctxt @{named_theorems L1opt} else [])))) + thm trace_opt + val _ = gather_stats "L1peep" thm + + (* Rewrite exceptions. *) + val (thm, exn_opt_trace) = AutoCorresTrace.fconv_rule_maybe_traced ctxt + (l1conv (ExceptionRewrite.except_rewrite_conv ctxt do_opt)) thm trace_opt + val _ = gather_stats "L1except" thm +in + (thm, + [("L1 guard opt", guard_opt_trace), ("L1 peephole opt", peephole_opt_trace), ("L1 exception opt", exn_opt_trace)] + |> List.mapPartial (fn (n, tr) => case tr of NONE => NONE | SOME x => SOME (n, AutoCorresData.SimpTrace x)) + ) +end + +(* + * Get theorems about a SIMPL body in a format convenient to reason about. + * + * In particular, we unfold parts of SIMPL where we would prefer to reason + * about raw definitions instead of more abstract constructs generated + * by the C parser. + *) +fun get_simpl_body ctxt fn_info fn_name = +let + (* Find the definition of the given function. *) + val simpl_thm = #definition (FunctionInfo.get_phase_info fn_info FunctionInfo.CP fn_name) + handle ERROR _ => raise FunctionNotFound fn_name; + + (* Unfold terms in the body which we don't want to deal with. *) + val unfolded_simpl_thm = + Conv.fconv_rule (Utils.rhs_conv + (Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps + (Utils.get_rules ctxt @{named_theorems L1unfold})))) + simpl_thm + val unfolded_simpl_term = Thm.concl_of unfolded_simpl_thm |> Utils.rhs_of; + + (* + * Get the implementation definition for this function. These rules are of + * the form "Gamma foo_'proc = Some foo_body". + *) + val impl_thm = + Proof_Context.get_thm ctxt (fn_name ^ "_impl") + |> Local_Defs.unfold ctxt [unfolded_simpl_thm] + |> SOME + handle (ERROR _) => NONE +in + (unfolded_simpl_term, unfolded_simpl_thm, impl_thm) +end + +fun get_l1corres_thm prog_info fn_info check_termination ctxt do_opt trace_opt fn_name + callee_terms measure_var = let + val thy = Proof_Context.theory_of ctxt + val (simpl_term, simpl_thm, impl_thm) = get_simpl_body ctxt fn_info fn_name + + (* Fetch stats on pre-converted term. *) + val _ = Statistics.gather ctxt "CParser" fn_name simpl_term + + (* + * Do the conversion. We receive a new monadic version of the SIMPL + * term and a tactic for proving correspondence. + *) + val (monad, tactic) = simpl_conv' prog_info fn_info ctxt + callee_terms measure_var simpl_term + + (* + * Wrap the monad in a "L1_recguard" statement, which triggers + * failure when the measure reaches zero. This lets us automatically + * prove termination of the recursive function. + *) + val is_recursive = FunctionInfo.is_function_recursive fn_info fn_name + val (monad, tactic) = + if is_recursive then + (Utils.mk_term thy @{term "L1_recguard"} [measure_var, monad], + (resolve_tac ctxt @{thms L1corres_recguard} 1 THEN tactic)) + else + (monad, tactic) + + (* + * Return a new theorem of correspondence between the original + * SIMPL body (with folded constants) and the output monad term. + *) +in + mk_L1corres_call_prop ctxt prog_info fn_info check_termination fn_name monad + |> Thm.cterm_of ctxt + |> Goal.init + |> (case impl_thm of + NONE => apply_tac "unfold SIMPL body" (resolve_tac ctxt @{thms L1corres_undefined_call} 1) + | SOME def => apply_tac "unfold SIMPL body" (resolve_tac ctxt @{thms L1corres_Call} 1 THEN + resolve_tac ctxt [def] 1) + #> apply_tac "solve L1corres" tactic) + |> Goal.finish ctxt + (* Apply simplifications to the L1 term. *) + |> cleanup_thm ctxt do_opt trace_opt prog_info fn_name +end + +fun get_body_of_l1corres_thm thm = + (* Extract the monad from the thm. *) + Thm.concl_of thm + |> HOLogic.dest_Trueprop + |> get_L1corres_monad + +fun split_conj thm = + (thm RS @{thm conjunct1}) :: split_conj (thm RS @{thm conjunct2}) + handle THM _ => [thm] + +(* Prove monad_mono for recursive functions. *) +fun l1_monad_mono lthy (l1_defs : FunctionInfo.phase_info Symtab.table) = +let + val l1_defs' = Symtab.dest l1_defs; + fun mk_stmt [func] = @{mk_term "monad_mono ?f" f} func + | mk_stmt (func :: funcs) = @{mk_term "monad_mono ?f \ ?g" (f, g)} (func, mk_stmt funcs); + val mono_thm = @{term "Trueprop"} $ mk_stmt (map (#const o snd) l1_defs'); + val func_expand = map (fn (_, l1_def) => + EqSubst.eqsubst_tac lthy [0] [Utils.abs_def lthy (#definition l1_def)]) l1_defs'; + val tac = + REPEAT (EqSubst.eqsubst_tac lthy [0] + [@{thm monad_mono_alt_def}, @{thm all_conj_distrib} RS @{thm sym}] 1) + THEN resolve_tac lthy @{thms allI} 1 THEN resolve_tac lthy @{thms nat.induct} 1 + THEN EVERY (map (fn expand => + TRY (resolve_tac lthy @{thms conjI} 1) + THEN expand 1 + THEN resolve_tac lthy @{thms monad_mono_step_L1_recguard_0} 1) func_expand) + THEN REPEAT (eresolve_tac lthy @{thms conjE} 1) + THEN EVERY (map (fn expand => + TRY (resolve_tac lthy @{thms conjI} 1) + THEN expand 1 + THEN REPEAT (FIRST [assume_tac lthy 1, + resolve_tac lthy @{thms L1_monad_mono_step_rules} 1])) + func_expand); +in + Goal.prove lthy [] [] mono_thm (K tac) + |> split_conj + |> (fn thms => map fst l1_defs' ~~ thms) + |> Symtab.make +end + + +(* For functions that are not translated, just generate a trivial wrapper. *) +fun mk_l1corres_call_simpl_thm fn_info check_termination ctxt fn_name = let + val info = FunctionInfo.get_phase_info fn_info FunctionInfo.CP fn_name + val const = #const info + val impl_thm = Proof_Context.get_thm ctxt (fn_name ^ "_impl") + val gamma = safe_mk_meta_eq impl_thm |> Thm.concl_of |> Logic.dest_equals + |> fst |> (fn (f $ _) => f | t => raise TERM ("gamma", [t])) + val thm = Utils.named_cterm_instantiate ctxt + [("ct", Thm.cterm_of ctxt (Utils.mk_bool check_termination)), + ("proc", Thm.cterm_of ctxt const), + ("Gamma", Thm.cterm_of ctxt gamma)] + @{thm L1corres_call_simpl} + in thm end + + +(* + * Convert a single function. Returns a thm that looks like + * \ L1corres ?callee1 (Call callee1_'proc); ... \ \ + * L1corres (conversion result...) (Call f_'proc) + * i.e. with assumptions for called functions, which are parameterised as Vars. + *) +fun convert + (lthy: local_theory) + (prog_info: ProgramInfo.prog_info) + (fn_info: FunctionInfo.fn_info) (* needs CP phase_info for each callee *) + (check_termination: bool) + (do_opt: bool) + (trace_opt: bool) + (l1_function_name: string -> string) + (f_name: string) + : thm * (string * typ) list = let + val callee_names = FunctionInfo.get_function_callees fn_info f_name; + (* TODO sanity check: + - all SIMPL defs exist *) + + val measureT = @{typ nat}; + + (* All L1 functions have the same signature: measure \ L1_monad *) + val l1_fn_type = measureT --> mk_l1monadT (#state_type prog_info); + + (* L1corres for f's callees. *) + fun get_l1_fn_assumption ctxt fn_name free _ _ measure_var = + mk_L1corres_call_prop ctxt prog_info fn_info check_termination fn_name (betapply (free, measure_var)); + + (* Fix measure variable. *) + val ([measure_var_name], lthy') = Variable.variant_fixes ["rec_measure'"] lthy; + val measure_var = Free (measure_var_name, measureT); + + (* Add callee assumptions. Note that our define code has to use the same assumption order. *) + val group = Symset.make (FunctionInfo.get_recursive_group fn_info f_name); + val callee_terms = + map (fn callee => (callee, Symset.contains group callee)) callee_names; + val (lthy'', export_thm, callee_terms) = + AutoCorresUtil.assume_called_functions_corres lthy' + callee_terms + (K l1_fn_type) + get_l1_fn_assumption + (K []) + l1_function_name + measure_var; + + val (thm, opt_traces) = + if #is_simpl_wrapper (FunctionInfo.get_function_info fn_info f_name) + then (mk_l1corres_call_simpl_thm fn_info check_termination lthy'' f_name, []) + else get_l1corres_thm prog_info fn_info check_termination lthy'' do_opt trace_opt f_name + (Symtab.make callee_terms) measure_var; + in (Morphism.thm export_thm thm, + (* Provide the fixed vars so the user can generalize/instantiate them *) + [dest_Free measure_var]) + end + + +(* Define a previously-converted function (or recursive function group). + * lthy must include all definitions from l1_callees *) +fun define + (lthy: local_theory) + (filename: string) + (prog_info: ProgramInfo.prog_info) + (fn_info: FunctionInfo.fn_info) (* doesn't need to have L1 info *) + (check_termination: bool) + (l1_callees: (FunctionInfo.phase_info * thm) Symtab.table) (* L1 callees & corres thms *) + (l1_function_name: string -> string) + (funcs: (string * thm * (string * typ) list) list) (* name, corres, arg frees *) + : (FunctionInfo.phase_info * thm) Symtab.table * local_theory = let + (* FIXME: dedup with convert *) + + (* All L1 functions have the same signature: measure \ L1_monad *) + val measureT = @{typ nat}; + val l1_fn_type = measureT --> mk_l1monadT (#state_type prog_info); + + (* L1corres for f's callees. *) + fun get_l1_fn_assumption ctxt fn_name free _ _ measure_var = + mk_L1corres_call_prop ctxt prog_info fn_info check_termination fn_name (betapply (free, measure_var)); + + (* FIXME: pass this from assume_called_functions_corres, etc. *) + fun guess_callee_var thm callee = let + val base_name = l1_function_name callee; + val mentioned_vars = Term.add_vars (Thm.prop_of thm) []; + in hd (filter (fn ((v, _), _) => v = base_name) mentioned_vars) end; + + fun prepare_fn_body (fn_name, corres_thm, measure_free) = let + val @{term_pat "Trueprop (L1corres _ _ ?body _)"} = Thm.concl_of corres_thm; + val (callees, recursive_callees) = AutoCorresUtil.get_callees fn_info fn_name; + val calls = map (fn c => Var (guess_callee_var corres_thm c)) callees; + val recursive_calls = map (fn c => Var (guess_callee_var corres_thm c)) recursive_callees; + (* + * The returned body will have free variables as placeholders for the function's + * measure parameter and other arguments, and schematic variables for the functions it calls. + * + * We modify the body to be of the form: + * + * %fun1 fun2 rec1 rec2 measure arg1 arg2. f <...> + * + * That is, all non-recursive calls are abstracted out the front, followed by + * recursive calls, followed by the measure variable, followed by function + * arguments (none for L1). This is the format expected by define_funcs. + *) + val abs_body = body + |> fold lambda (rev (map Free measure_free)) + |> fold lambda (rev recursive_calls) + |> fold lambda (rev calls); + in abs_body end; + + val fn_info' = FunctionInfo.add_phases (fn f => K (Option.map fst (Symtab.lookup l1_callees f))) fn_info; + val funcs' = funcs |> + map (fn (name, thm, frees) => + (name, (* FIXME: define_funcs needs this currently *) + (Thm.generalize ([], map fst frees) (Thm.maxidx_of thm + 1) thm, + prepare_fn_body (name, thm, frees), []))); + val _ = @{trace} ("L1.define", map (fn (name, (thm, body, _)) => (name, thm, Thm.cterm_of lthy body)) funcs'); + val (corres_thms, (), lthy') = + AutoCorresUtil.define_funcs FunctionInfo.L1 filename fn_info' + l1_function_name (K l1_fn_type) get_l1_fn_assumption (K []) @{thm L1corres_recguard_0} + lthy (Symtab.map (K snd) l1_callees) () + funcs'; + val f_names = map (fn (name, _, _) => name) funcs; + val new_phases = map2 (fn f_name => fn corres => let + val old_phase = FunctionInfo.get_phase_info fn_info FunctionInfo.CP f_name; + val def = the (AutoCorresData.get_def (Proof_Context.theory_of lthy') filename "L1def" f_name); + val f_const = Utils.get_term lthy' (l1_function_name f_name); + in old_phase + |> FunctionInfo.phase_info_upd_phase FunctionInfo.L1 + |> FunctionInfo.phase_info_upd_definition def + |> FunctionInfo.phase_info_upd_const f_const + |> FunctionInfo.phase_info_upd_mono_thm NONE (* done in translate *) + end) f_names corres_thms; + (* FIXME: return traces *) + in (Symtab.make (f_names ~~ (new_phases ~~ corres_thms)), lthy') end; + + +(* + * Top level translation from SIMPL to a monadic spec. + * + * We accept a filename (the same filename passed to the C parser; the + * parser stashes away important information using this filename as the + * key) and a local theory. + * + * We define a number of new functions (the converted monadic + * specifications of the SIMPL functions) and theorems (proving + * correspondence between our generated specs and the original SIMPL + * code). + *) +(* FIXME: use AutoCorresUtil.Future instead of default *) +fun translate filename prog_info fn_info check_termination do_opt trace_opt l1_function_name lthy = +let + val funcs_to_translate = Symtab.keys (FunctionInfo.get_all_functions fn_info); + + (* All conversions can run in parallel. *) + val converted_funcs = + funcs_to_translate |> map (fn f => + (f, Future.fork (fn _ => + convert lthy prog_info fn_info check_termination do_opt trace_opt l1_function_name f))) + |> Symtab.make; + + (* Definitions update lthy sequentially. + * We use the arbitrary (but deterministic) ordering defined by get_topo_sorted_functions. + * Each definition step produces a lthy that has a prefix of the update sequence, + * and can be used in an L2 translation that depends only on that prefix. + * Hence we return the intermediate lthys as futures. *) + fun add_def f_names accum = Future.fork (fn _ => let + (* Get output from previous definition. + * Technically we don't need all of defined_so_far, but we're guaranteed + * to have them at this point already. *) + val (lthy, _, defined_so_far) = Future.join accum; + (* Wait for conversions to finish *) + val f_convs = map (fn f => let + val conv = the' ("didn't convert function: " ^ quote f ^ "??") (Symtab.lookup converted_funcs f); + val (corres_thm, arg_frees) = Future.join conv; + in (f, corres_thm, arg_frees) end) f_names; + val (new_defs, lthy') = define lthy filename prog_info fn_info check_termination + (defined_so_far: (FunctionInfo.phase_info * thm) Symtab.table) + l1_function_name f_convs; + in (lthy', new_defs, Symtab.merge (K false) (defined_so_far, new_defs)) end); + + val function_groups = FunctionInfo.get_topo_sorted_functions fn_info; + (* Chain of intermediate states: (lthy, new_defs, accumulator) *) + val (def_results, _) = Utils.accumulate add_def (Future.value (lthy, Symtab.empty, Symtab.empty)) function_groups; + + (* Produce a mapping from each function group to its L1 phase_infos and the + * earliest intermediate lthy where it is defined. *) + val results = + def_results + |> map (Future.map (fn (lthy, f_defs, _ (* discard accum here *)) => let + (* Add monad_mono proofs. These are done in parallel as well + * (though in practice, they already run pretty quickly). *) + val mono_thms = if FunctionInfo.is_function_recursive fn_info (hd (Symtab.keys f_defs)) + then l1_monad_mono lthy (Symtab.map (K fst) f_defs) + else Symtab.empty; + val f_defs' = f_defs |> Symtab.map (fn f => + apfst (FunctionInfo.phase_info_upd_mono_thm (Symtab.lookup mono_thms f))); + in (lthy, f_defs') end)); +in + (* would be ready to become a symtab, but a list also preserves order *) + function_groups ~~ results +end + +end \ No newline at end of file diff --git a/tools/autocorres/tests/examples/type_strengthen.c b/tools/autocorres/tests/examples/type_strengthen.c index d6a5057bf..8652ba12b 100644 --- a/tools/autocorres/tests/examples/type_strengthen.c +++ b/tools/autocorres/tests/examples/type_strengthen.c @@ -149,6 +149,10 @@ unsigned opt_a(unsigned m, unsigned n) { return opt_a(m - 1, opt_a(m, n - 1)); } +/* Test for measure_call */ +unsigned opt_a2(unsigned n) { + return opt_a(n, n); +} /********************* diff --git a/tools/autocorres/tests/examples/type_strengthen_tricks.thy b/tools/autocorres/tests/examples/type_strengthen_tricks.thy index 3b7ba0812..7b9b59f89 100644 --- a/tools/autocorres/tests/examples/type_strengthen_tricks.thy +++ b/tools/autocorres/tests/examples/type_strengthen_tricks.thy @@ -22,6 +22,59 @@ install_C_file "type_strengthen.c" For example, suppose that we do not want to lift loops to the option monad: *) declare gets_theE_L2_while [ts_rule option del] +context type_strengthen begin +ML \ +let val fn_info = FunctionInfo.init_fn_info @{context} "type_strengthen.c" + val prog_info = ProgramInfo.get_prog_info @{context} "type_strengthen.c" + val (corres1, frees1) = SimplConv2.convert @{context} prog_info fn_info true true false (fn f => "l1_" ^ f) "opt_j"; + val (corres2, frees2) = SimplConv2.convert @{context} prog_info fn_info true true false (fn f => "l1_" ^ f) "st_i"; + (*val thm' = Thm.generalize ([], map fst frees) (Thm.maxidx_of thm + 1) thm*) + val lthy0 = @{context}; + val (l1_infos1, lthy1) = + SimplConv2.define lthy0 "type_strengthen.c" prog_info fn_info true + Symtab.empty + (fn f => "l1_" ^ f) + [("opt_j", corres1, frees1)] + val (l1_infos2, lthy2) = + SimplConv2.define lthy1 "type_strengthen.c" prog_info fn_info true + l1_infos1 + (fn f => "l1_" ^ f) + [("st_i", corres2, frees2)] + in (frees1, corres1, Symtab.dest l1_infos2) end +\ + +ML \ +let val filename = "type_strengthen.c"; + val fn_info = FunctionInfo.init_fn_info @{context} filename; + val prog_info = ProgramInfo.get_prog_info @{context} filename; + val l1_results = + SimplConv2.translate filename prog_info fn_info + true true false (fn f => "l1_" ^ f ^ "'") @{context}; +(* + val l2_results = + LocalVarExtract2.translate filename prog_info fn_info l1_results + true false (fn f => "l2_" ^ f ^ "'"); +*) +in l1_results |> map (snd #> Future.join) |> map (snd #> Symtab.dest) end +\ + +ML \ +FunctionInfo.is_function_recursive (FunctionInfo.init_fn_info @{context} "type_strengthen.c") "opt_a" +\ + +ML \ +FunctionInfo.get_topo_sorted_functions (FunctionInfo.init_fn_info @{context} "type_strengthen.c") +\ +end + +declare [[ML_print_depth=99]] +autocorres [ + ts_rules = nondet, + scope = st_i, + skip_heap_abs, skip_word_abs + ] "type_strengthen.c" + + (* We can also specify which monads are used for type strengthening. Here, we exclude the read-only monad completely, and specify rules for some individual functions. *) @@ -32,6 +85,7 @@ autocorres [ ] "type_strengthen.c" context type_strengthen begin + (* pure_f (and indirectly, pure_f2) are now lifted to the option monad. *) thm pure_f'_def pure_f2'_def thm pure_g'_def pure_h'_def diff --git a/tools/autocorres/utils.ML b/tools/autocorres/utils.ML index 8b6a9e67d..877184920 100644 --- a/tools/autocorres/utils.ML +++ b/tools/autocorres/utils.ML @@ -85,6 +85,13 @@ fun enumerate xs = let fun nubBy _ [] = [] | nubBy f (x::xs) = x :: filter (fn y => f x <> f y) (nubBy f xs) +fun accumulate f acc xs = let + fun walk results acc [] = (results [], acc) + | walk results acc (x::xs) = let + val acc' = f x acc; + in walk (results o cons acc') acc' xs end; + in walk I acc xs end; + (* Define a constant "name" of type "term" into the local theory "lthy". *) fun define_const name term lthy = let