(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

(*
 * Word abstraction: rewrite L2 specifications so that they use the "nat"
 * and "int" data-types instead of "word" data types. The former tend to be
 * easier to reason about.
 *
 * The main interface to this module is translate (and inner functions
 * convert and define). See AutoCorresUtil for a conceptual overview.
 *)
structure WordAbstract =
struct

(* Search depth bound handed to the solver; past it we assume divergence. *)
val WORD_ABS_MAX_DEPTH = 200

(* Short local names for utilities defined in Utils. *)
val warning = Utils.ac_warning
val apply_tac = Utils.apply_tac
val the' = Utils.the'

(*
 * One abstraction rule set for a single concrete word type:
 *
 *   signed   whether this rule set is for the signed word type
 *   lentype  the length type the word type is built from
 *   ctype    the concrete (word) type
 *   atype    the abstract type that replaces it
 *   abs_fn   function from the concrete type to the abstract type
 *   inv_fn   function from the abstract type back to the concrete type
 *   rules    theorems used when abstracting values of this type
 *)
type WARules = {
  signed : bool,
  lentype : typ,
  ctype : typ,
  atype : typ,
  abs_fn : term,
  inv_fn : term,
  rules : thm list
}

(* Rebuild a rule record with "rules_upd" applied to its "rules" field;
 * every other field is copied across untouched. *)
fun WARules_update rules_upd
      {signed, lentype, ctype, atype, abs_fn, inv_fn, rules} =
  { signed = signed,
    lentype = lentype,
    ctype = ctype,
    atype = atype,
    abs_fn = abs_fn,
    inv_fn = inv_fn,
    rules = rules_upd rules }

(* Length types for which rule sets are generated. *)
val word_sizes = [@{typ 64}, @{typ 32}, @{typ 16}, @{typ 8}]

(* Rule set abstracting the unsigned word type of length T to nat. *)
fun mk_word_abs_rule T =
let
  val wordT = fastype_of (@{mk_term "x :: (?'W::len) word" ('W)} T)
  val to_nat = @{mk_term "unat :: (?'W::len) word => nat" ('W)} T
  val from_nat = @{mk_term "of_nat :: nat => (?'W::len) word" ('W)} T
in
  { signed = false,
    lentype = T,
    ctype = wordT,
    atype = @{typ nat},
    abs_fn = to_nat,
    inv_fn = from_nat,
    rules = @{thms word_abs_word} }
end

val word_abs : WARules list = word_sizes |> map mk_word_abs_rule

(* Rule set abstracting the signed word type of length T to int. *)
fun mk_sword_abs_rule T =
let
  val swordT = fastype_of (@{mk_term "x :: (?'W::len) signed word" ('W)} T)
  val to_int = @{mk_term "sint :: (?'W::len) signed word => int" ('W)} T
  val from_int = @{mk_term "of_int :: int => (?'W::len) signed word" ('W)} T
in
  { signed = true,
    lentype = T,
    ctype = swordT,
    atype = @{typ int},
    abs_fn = to_int,
    inv_fn = from_int,
    rules = @{thms word_abs_sword} }
end

val sword_abs : WARules list = word_sizes |> map mk_sword_abs_rule

(* Make rules for signed word reads from the lifted heap.
 *
 * The lifted heap stores all words as unsigned, but we need to avoid
 * generating unsigned arith guards when accessing signed words.
* These rules are placed early in the ruleset and kick in before the
 * unsigned abstract_val rules get to run.
 *)
fun mk_sword_heap_get_rule ctxt heap_info (rules: WARules) =
let
  (* The getter is looked up under the plain word type of the same length. *)
  val unsigned_wordT = Type (@{type_name word}, [#lentype rules])

  (* Plug the heap getter into the template theorem. *)
  fun instantiate getter =
    Drule.infer_instantiate ctxt
      [(("heap_get", 0), Thm.cterm_of ctxt getter)]
      @{thm abstract_val_heap_sword_template}
in
  (* No rule is produced when the getter lookup fails. *)
  try (HeapLift.get_heap_getter heap_info) unsigned_wordT
  |> Option.map instantiate
end

(* Abstract counterpart of a HOL type; types without a rule set stay as
 * they are. *)
fun get_abs_type (rules : WARules list) T =
  case List.find (fn r => #ctype r = T) rules of
    SOME r => #atype r
  | NONE => T

(* Abstraction function for a HOL type; "id" at that type when no rule set
 * covers it. *)
fun get_abs_fn (rules : WARules list) T =
let
  val identity = @{mk_term "id :: ?'a => ?'a" ('a)} T
in
  case List.find (fn r => #ctype r = T) rules of
    SOME r => #abs_fn r
  | NONE => identity
end

(* Wrap a term in the inverse abstraction function of the first rule set
 * whose concrete type is the term's type; otherwise return the term
 * unchanged. *)
fun get_abs_inv_fn (rules : WARules list) t =
  case List.find (fn r => #ctype r = fastype_of t) rules of
    SOME r => #inv_fn r $ t
  | NONE => t

(*
 * Given the abstract arguments of a function, produce:
 *
 *   - the concrete argument types (taken from the function's l2 info),
 *   - fresh concrete argument variables (abstract name plus a prime),
 *   - a precondition relating each concrete argument to its abstract one,
 *   - the (concrete, abstract) argument pairs, and
 *   - the context in which the new names are fixed.
 *)
fun get_wa_conc_args rules l2_infos fn_name fn_args lthy =
let
  val l2_args = #args (the (Symtab.lookup l2_infos fn_name))
  val conc_types = map snd l2_args

  (* Concrete variables reuse the abstract names with a prime appended. *)
  val primed_names = map (fn Free (n, _) => n ^ "'") fn_args
  val (conc_names, lthy) = Variable.variant_fixes primed_names lthy
  val conc_args = map Free (conc_names ~~ conc_types)
  val arg_pairs = conc_args ~~ fn_args

  (* One "abs_var" fact per argument, linking concrete to abstract. *)
  fun link (conc, abs) =
    @{mk_term "abs_var ?n ?f ?o" (o, f, n)}
      (conc, get_abs_fn rules (fastype_of conc), abs)
  val precond = Utils.mk_conj_list (map link arg_pairs)
in
  (conc_types, conc_args, precond, arg_pairs, lthy)
end

(* Get the expected type of a function from its name.
*)
fun get_expected_fn_type rules l2_infos fn_name =
let
  val info = the (Symtab.lookup l2_infos fn_name)
  val abs_arg_typs = map (fn (_, T) => get_abs_type rules T) (#args info)
  val abs_ret_typ = get_abs_type rules (#return_type info)
  (* The globals type is read off the type of the existing constant. *)
  val globals_typ =
    #1 (snd (LocalVarExtract.dest_l2monad_T (fastype_of (#const info))))
in
  (* The measure argument comes first, followed by the abstracted arguments. *)
  (@{typ "nat"} :: abs_arg_typs)
    ---> LocalVarExtract.mk_l2monadT globals_typ abs_ret_typ @{typ unit}
end

(* Statement of the corresTA theorem expected for a function: the abstract
 * call (function_free applied to the measure and fn_args) against the
 * previous-phase constant applied to fresh concrete arguments, quantified
 * over those concrete arguments. *)
fun get_expected_fn_thm rules l2_infos ctxt fn_name
      function_free fn_args _ measure_var =
let
  val l2_info = the (Symtab.lookup l2_infos fn_name)
  val (_, conc_args, precond, arg_pairs, _) =
    get_wa_conc_args rules l2_infos fn_name fn_args ctxt
  val conc_call = betapplys (#const l2_info, measure_var :: conc_args)
  val abs_call = betapplys (function_free, measure_var :: fn_args)
  val corres_prop =
    @{mk_term "Trueprop (corresTA (%x. ?P) ?rt id ?A ?C)" (rt, A, C, P)}
      (get_abs_fn rules (#return_type l2_info), abs_call, conc_call, precond)
in
  fold Logic.all (rev (map fst arg_pairs)) corres_prop
end

(* Arguments of a function, with their types replaced by abstract types. *)
fun get_expected_fn_args rules l2_infos fn_name =
  the (Symtab.lookup l2_infos fn_name)
  |> #args
  |> map (apsnd (get_abs_type rules))

(*
 * Reshape a theorem of the form
 *
 *    corresTA (%_. abs_var True a f a' \<and> abs_var True b f b' \<and> ...) ...
 *
 * into
 *
 *   [| abstract_val A a f a'; abstract_val B b (f b') |]
 *       ==> corresTA (%_. A \<and> B \<and> ...) ...
 *
 * which suits our resolution-based proof construction better.
 *
 * The "init" rule is applied once, then the "step" rule for as long as it
 * resolves; when a step raises THM we close off with the "final" rule,
 * falling back to its primed variant.
 *)
fun extract_preconds_of_call thm =
let
  fun finish th =
    th RS @{thm corresTA_extract_preconds_of_call_final}
    handle THM _ => th RS @{thm corresTA_extract_preconds_of_call_final'}

  fun peel th =
    peel (th RS @{thm corresTA_extract_preconds_of_call_step})
    handle THM _ => finish th
in
  peel (thm RS @{thm corresTA_extract_preconds_of_call_init})
end

(* Convert a program by abstracting words.
*)
(*
 * Parameters, as far as their use in this function shows:
 *
 *   filename           passed on to AutoCorresUtil.define_funcs.
 *   prog_info          not used by the code below.
 *   heap_info          when present, used to build the signed heap-read rules.
 *   l2_results         results of the previous phase; an empty sequence
 *                      yields an empty result.
 *   existing_l2_infos,
 *   existing_wa_infos  passed on to par_convert / define_funcs_sequence.
 *   unsigned_abs       functions that get the unsigned (word_abs) rules.
 *   no_signed_abs      functions that do NOT get the signed (sword_abs) rules.
 *   trace_funcs        functions whose rule trace is kept in the result.
 *   do_opt, trace_opt  passed on to L2Opt.cleanup_thm_tagged (do_opt selects
 *                      the numeric argument 0 or 2).
 *   add_trace          passed on to AutoCorresUtil.par_convert.
 *   wa_function_name   passed on to the AutoCorresUtil helpers; presumably
 *                      maps a function name to its WA-phase name (confirm
 *                      against AutoCorresUtil).
 *
 * Returns the sequence of (lthy, function infos) produced by defining the
 * converted functions, with monad_mono theorems attached.
 *)
fun translate
      (filename: string)
      (prog_info: ProgramInfo.prog_info)
      (* Needed for mk_sword_heap_get_rule *)
      (heap_info: HeapLiftBase.heap_info option)
      (* Note that we refer to the previous phase as "l2"; it may be
       * from the L2 or HL phase. *)
      (l2_results: FunctionInfo.phase_results)
      (existing_l2_infos: FunctionInfo.function_info Symtab.table)
      (existing_wa_infos: FunctionInfo.function_info Symtab.table)
      (unsigned_abs: Symset.key Symset.set)
      (no_signed_abs: Symset.key Symset.set)
      (trace_funcs: string list)
      (do_opt: bool)
      (trace_opt: bool)
      (add_trace: string -> string -> AutoCorresData.Trace -> unit)
      (wa_function_name: string -> string)
      : FunctionInfo.phase_results =
if FSeq.null l2_results then FSeq.empty () else
let
  (*
   * Select the rules to translate each function.
   *
   * Unsigned rules are opt-in (the function must be in unsigned_abs);
   * signed rules are opt-out (the function must not be in no_signed_abs).
   *)
  fun rules_for fn_name =
      (if Symset.contains unsigned_abs fn_name then word_abs else []) @
      (if Symset.contains no_signed_abs fn_name then [] else sword_abs)

  (* Convert each function. *)
  fun convert lthy l2_infos f: AutoCorresUtil.convert_result =
  let
    val old_fn_info = the (Symtab.lookup l2_infos f);

    (* Add heap lift workaround for each word length that is in the heap. *)
    fun add_sword_heap_get_rules rules =
        if not (#signed rules) then [] else
          case heap_info of
              NONE => []
            | SOME hi => the_list (mk_sword_heap_get_rule lthy hi rules)

    val wa_rules = rules_for f
    val sword_heap_rules = map add_sword_heap_get_rules wa_rules

    (* Fix measure variable. *)
    val ([measure_var_name], lthy) = Variable.variant_fixes ["rec_measure'"] lthy;
    val measure_var = Free (measure_var_name, AutoCorresUtil.measureT);

    (* Add callee assumptions. *)
    val (lthy, export_thm, callee_terms) =
      AutoCorresUtil.assume_called_functions_corres lthy
        (#callees old_fn_info) (#rec_callees old_fn_info)
        (fn f => get_expected_fn_type (rules_for f) l2_infos f)
        (fn lthy => fn f => get_expected_fn_thm (rules_for f) l2_infos lthy f)
        (fn f => get_expected_fn_args (rules_for f) l2_infos f)
        wa_function_name
        measure_var;

    (* Fix argument variables. *)
    val new_fn_args = get_expected_fn_args wa_rules l2_infos f;
    val (arg_names, lthy) = Variable.variant_fixes (map fst new_fn_args) lthy;
    val arg_frees = map Free (arg_names ~~ map snd new_fn_args);

    (* Construct free variables to represent our concrete arguments. *)
    val (conc_types, conc_args, precond, arg_pairs, lthy)
        = get_wa_conc_args wa_rules l2_infos f arg_frees lthy

    (* Fetch the function definition, and instantiate its arguments. *)
    val old_body_def =
        #definition old_fn_info
        (* Instantiate the arguments. *)
        |> Utils.inst_args lthy (map (Thm.cterm_of lthy) (measure_var :: conc_args))

    (* Get old body definition with function arguments. *)
    val old_term = betapplys (#const old_fn_info, measure_var :: conc_args)

    (* Get a schematic variable accepting new arguments. *)
    val new_var = betapplys (
        Var (("A", 0), get_expected_fn_type wa_rules l2_infos f), measure_var :: arg_frees)

    (* Fetch monotonicity theorems of callees. *)
    val callee_mono_thms =
        Symset.dest (FunctionInfo.all_callees old_fn_info)
        |> List.mapPartial (fn callee =>
             if FunctionInfo.is_function_recursive (the (Symtab.lookup l2_infos callee))
             then the (Symtab.lookup l2_infos callee) |> #mono_thm
             else NONE)

    (*
     * Generate a schematic goal.
     *
     * We only want ?A to depend on abstracted variables and ?C to depend on
     * concrete variables. We force this by applying bound variables to each
     * of the schematics, giving us something like:
     *
     *     !!a a' b b'. corresTA ... (?A a b) (?C a' b')
     *
     * The abstract side will hence be prevented from capturing (i.e., using)
     * concrete variables, and vice-versa.
     *)
    val goal = @{mk_term "Trueprop (corresTA (%x. ?precond) ?ra id ?A ?C)"
        (ra, A, C, precond)}
        (get_abs_fn wa_rules (#return_type old_fn_info), new_var, old_term, precond)
      |> fold (fn t => fn v => Logic.all t v) (rev (arg_frees @ map fst arg_pairs))
      |> Thm.cterm_of lthy
      |> Goal.init
      |> Utils.apply_tac "move precond to assumption" (resolve_tac lthy @{thms corresTA_precond_to_asm} 1)
      |> Utils.apply_tac "split precond" (REPEAT (CHANGED (eresolve_tac lthy @{thms conjE} 1)))
      |> Utils.apply_tac "create schematic precond" (resolve_tac lthy @{thms corresTA_precond_to_guard} 1)
      |> Utils.apply_tac "unfold RHS" (CHANGED (Utils.unfold_once_tac lthy (Utils.abs_def lthy old_body_def) 1))

    (*
     * Fetch rules from the theory.
     * FIXME: move this out of convert?
     *)
    val rules = Utils.get_rules lthy @{named_theorems word_abs}
                @ List.concat (sword_heap_rules @ map #rules wa_rules)
                @ @{thms word_abs_default}
    val fo_rules = [@{thm abstract_val_fun_app}]

    (* Callee correspondence theorems (reshaped for resolution) and callee
     * monotonicity theorems are tried after the rules above. *)
    val rules = rules @ (map (snd #> #3 #> extract_preconds_of_call) callee_terms)
                @ callee_mono_thms

    (* Standard tactics. *)
    fun my_rtac ctxt thm n =
        Utils.trace_if_success ctxt thm (
          DETERM (EVERY' (resolve_tac ctxt [thm] ::
              replicate (Rule_Cases.get_consumes thm) (assume_tac ctxt)) n))

    (* Apply a conversion to the abstract/concrete side of the given "abstract_val" term. *)
    fun wa_conc_body_conv conv =
      Conv.params_conv (~1) (K (Conv.concl_conv (~1) ((Conv.arg_conv (Utils.nth_arg_conv 4 conv)))))

    (* Tactics and conversions for converting goals into first-order format. *)
    fun to_fo_tac ctxt =
        CONVERSION (Drule.beta_eta_conversion then_conv wa_conc_body_conv (HeapLift.mk_first_order ctxt) ctxt)
    fun from_fo_tac ctxt =
        CONVERSION (wa_conc_body_conv (HeapLift.dest_first_order ctxt then_conv Drule.beta_eta_conversion) ctxt)
    fun make_fo_tac tac ctxt = ((to_fo_tac ctxt THEN' tac) THEN_ALL_NEW from_fo_tac ctxt)

    (*
     * Recursively solve subgoals.
     *
     * We allow backtracking in order to solve a particular subgoal, but once a
     * subgoal is completed we don't ever try to solve it in a different way.
     *
     * This allows us to try different approaches to solving subgoals,
     * hopefully reducing exponential explosion (of many different combinations
     * of "good solutions") once we hit an unsolvable subgoal.
     *)

    (*
     * Eliminate a lambda term in the concrete state, but only if the
     * lambda is "real".
     *
     * That is, we don't attempt to eta-expand in order to apply the theorem
     * "abstract_val_lambda", because that may lead to an infinite loop with
     * "abstract_val_fun_app".
     *)
    fun lambda_tac n thm =
      case Logic.concl_of_goal (Thm.prop_of thm) n of
        (Const (@{const_name "Trueprop"}, _) $
            (Const (@{const_name "abstract_val"}, _) $ _ $ _ $ _ $ (
                Abs (_, _, _)))) =>
            resolve_tac lthy @{thms abstract_val_lambda} n thm
      | _ => no_tac thm

    (* All tactics we try, in the order we should try them.
     * The last entry never succeeds; it only reports the stuck subgoal
     * when exception tracing is switched on. *)
    val step_tacs =
        [(@{thm imp_refl}, assume_tac lthy 1)]
        @ (map (fn thm => (thm, my_rtac lthy thm 1)) rules)
        @ (map (fn thm => (thm, make_fo_tac (my_rtac lthy thm) lthy 1)) fo_rules)
        @ [(@{thm abstract_val_lambda}, lambda_tac 1)]
        @ [(@{thm reflexive},
            fn thm =>
            (if Config.get lthy ML_Options.exception_trace then
              warning ("Could not solve subgoal: " ^
                (Goal_Display.string_of_goal lthy thm))
            else (); no_tac thm))]

    (* Solve the goal. *)
    val replay_failure_start = 1
    val replay_failures = Unsynchronized.ref replay_failure_start
    val (thm, trace) =
        case AutoCorresTrace.maybe_trace_solve_tac lthy (member (op =) trace_funcs f)
                 true false (K step_tacs) goal (SOME WORD_ABS_MAX_DEPTH) replay_failures of
            NONE => (* intentionally generate a TRACE_SOLVE_TAC_FAIL *)
                    (AutoCorresTrace.trace_solve_tac lthy false false (K step_tacs) goal NONE (Unsynchronized.ref 0);
                     (* never reached *) error "word_abstract fail tac: impossible")
          | SOME (thm, [trace]) => (Goal.finish lthy thm, trace)
          | SOME (_, traces) =>
              (* Make the match exhaustive: any other number of traces used to
               * surface as an anonymous Match exception. *)
              error ("word_abstract: expected exactly one proof trace for " ^ f ^
                     ", got " ^ Int.toString (length traces))
    val _ = if !replay_failures < replay_failure_start then
              @{trace} (f ^ " WA: reverted to slow replay " ^
                        Int.toString(replay_failure_start - !replay_failures) ^ " time(s)") else ()

    (* Clean out any final function application ($) constants or "id" constants
     * generated by some rules. *)
    fun corresTA_abs_conv conv =
      Utils.remove_meta_conv (fn ctxt => Utils.nth_arg_conv 4 (conv ctxt)) lthy

    val thm =
      Conv.fconv_rule (
        corresTA_abs_conv (fn ctxt =>
          (HeapLift.dest_first_order ctxt)
          then_conv (Simplifier.rewrite (
              put_simpset HOL_basic_ss ctxt addsimps [@{thm id_def}]))
          then_conv Drule.beta_eta_conversion
        )
      ) thm

    (* Ensure no schematics remain in the goal. *)
    val _ = Sign.no_vars lthy (Thm.prop_of thm)

    (*
     * Instantiate abstract function's meta-forall variables with their actual values.
     *
     * That is, we go from:
     *
     *    !!a b c a' b' c'. corresTA (P a b c) ...
     *
     * to
     *
     *    !!a' b' c'. corresTA (P a b c) ...
     *)
    val thm = Drule.forall_elim_list (map (Thm.cterm_of lthy) arg_frees) thm

    (* Apply peephole optimisations to the theorem. *)
    val _ = writeln ("Simpifying (WA) " ^ f)
    val (thm, opt_traces) = L2Opt.cleanup_thm_tagged lthy thm (if do_opt then 0 else 2) 4 trace_opt "WA"

    (* We end up with an unwanted L2_guard outside the L2_recguard.
     * L2Opt should simplify the condition to (%_. True) even if (not do_opt),
     * so we match the guard and get rid of it here. *)
    val thm = Simplifier.rewrite_rule lthy @{thms corresTA_simp_trivial_guard} thm

    (* Convert all-quantified vars in the concrete body to schematics. *)
    val thm = Variable.gen_all lthy thm

    (* Extract the abstract term out of a corresTA thm. *)
    fun dest_corresTA_term_abs @{term_pat "corresTA _ _ _ ?t _"} = t
    val f_body = Thm.concl_of thm |> HOLogic.dest_Trueprop |> dest_corresTA_term_abs;

    (* Get actual recursive callees *)
    val rec_callees = AutoCorresUtil.get_rec_callees callee_terms f_body;

    (* Return the constants that we fixed. This will be used to process the returned body. *)
    val callee_consts =
          callee_terms |> map (fn (callee, (_, const, _)) => (callee, const)) |> Symtab.make;
  in
    { body = f_body,
      proof = Morphism.thm export_thm thm,
      rec_callees = rec_callees,
      callee_consts = callee_consts,
      arg_frees = map dest_Free (measure_var :: arg_frees),
      (* The rule trace is only kept for functions listed in trace_funcs. *)
      traces = (if member (op =) trace_funcs f
                  then [("WA", AutoCorresData.RuleTrace trace)] else []) @ opt_traces
    }
  end

  (* Define a previously-converted function (or recursive function group).
   * lthy must include all definitions from wa_callees. *)
  fun define
        (lthy: local_theory)
        (l2_infos: FunctionInfo.function_info Symtab.table)
        (wa_callees: FunctionInfo.function_info Symtab.table)
        (funcs: AutoCorresUtil.convert_result Symtab.table)
        : FunctionInfo.function_info Symtab.table * local_theory = let
    val funcs' = Symtab.dest funcs |>
          map (fn result as (name, {proof, arg_frees, ...}) =>
                     (name, (AutoCorresUtil.abstract_fn_body l2_infos result,
                             proof, arg_frees)));
    val (new_thms, lthy') =
          AutoCorresUtil.define_funcs
              FunctionInfo.WA filename l2_infos wa_function_name
              (fn f => get_expected_fn_type (rules_for f) l2_infos f)
              (fn lthy => fn f => get_expected_fn_thm (rules_for f) l2_infos lthy f)
              (fn f => get_expected_fn_args (rules_for f) l2_infos f)
              @{thm corresTA_recguard_0}
              lthy (Symtab.map (K #corres_thm) wa_callees)
              funcs';
    (* Start from the previous phase's info and overwrite the fields that
     * change in the WA phase. *)
    val new_infos = Symtab.map (fn f_name => fn (const, def, corres_thm) => let
          val old_info = the (Symtab.lookup l2_infos f_name);
          in old_info
             |> FunctionInfo.function_info_upd_phase FunctionInfo.WA
             |> FunctionInfo.function_info_upd_const const
             |> FunctionInfo.function_info_upd_definition def
             |> FunctionInfo.function_info_upd_corres_thm corres_thm
             |> FunctionInfo.function_info_upd_return_type
                  (get_abs_type (rules_for f_name) (#return_type old_info))
             |> FunctionInfo.function_info_upd_args
                  (map (apsnd (get_abs_type (rules_for f_name))) (#args old_info))
             |> FunctionInfo.function_info_upd_mono_thm NONE (* added later *)
          end) new_thms;
  in (new_infos, lthy') end;

  (* Do conversions in parallel. *)
  val converted_groups = AutoCorresUtil.par_convert convert existing_l2_infos l2_results add_trace;

  (* Sequence of intermediate states: (lthy, new_defs) *)
  val def_results = FSeq.mk (fn _ =>
        (* If there's nothing to translate, we won't have a lthy to use *)
        if FSeq.null l2_results then NONE else
          let (* Get initial lthy from end of L2 defs *)
              val (l2_lthy, _) = FSeq.list_of l2_results |> List.last;
              val results = AutoCorresUtil.define_funcs_sequence
                              l2_lthy define existing_l2_infos existing_wa_infos converted_groups;
          in FSeq.uncons results end);

  (* Produce a mapping from each function group to its WA phase_infos and the
   * earliest intermediate lthy where it is defined. *)
  val results =
        def_results
        |> FSeq.map (fn (lthy, f_defs) => let
             (* Add monad_mono proofs. These are done in parallel as well
              * (though in practice, they already run pretty quickly).
              *
              * NOTE(review): only the first entry of f_defs is inspected, so
              * this assumes each group is non-empty and is either wholly
              * recursive or not -- confirm against define_funcs_sequence. *)
             val mono_thms = if FunctionInfo.is_function_recursive (snd (hd (Symtab.dest f_defs)))
                             then LocalVarExtract.l2_monad_mono lthy f_defs
                             else Symtab.empty;
             val f_defs' = f_defs |> Symtab.map (fn f =>
                   FunctionInfo.function_info_upd_mono_thm (Symtab.lookup mono_thms f));
             in (lthy, f_defs') end);
in
  results
end;

end