lh-l4v/tools/autocorres/heap_lift.ML
Gerwin Klein afb3c7291c isabelle2021-1 autocorres: context in convs
Conv.params_conv changes the context, and the inner conversion that
it runs needs to work on that inner context; otherwise information
is lost about which of the Free variables were formerly Bound.

Isabelle2021-1 has more thorough checking and fails when the wrong
context is provided.

Signed-off-by: Gerwin Klein <gerwin.klein@proofcraft.systems>
2022-03-29 08:38:25 +11:00

896 lines
41 KiB
Standard ML

(*
* Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
*
* SPDX-License-Identifier: BSD-2-Clause
*)
(*
* Rewrite L2 specifications to use a higher-level ("lifted") heap representation.
*
* The main interface to this module is translate (and inner functions
* convert and define). See AutoCorresUtil for a conceptual overview.
*)
structure HeapLift =
struct
(* Short aliases for Utils helpers used throughout this file. *)
val the' = Utils.the'
val apply_tac = Utils.apply_tac
val warning = Utils.ac_warning

(* Raised by fail_tac, carrying the label that was passed to it. *)
exception ProofFailed of string

(* Tactic that prints the current goal state (labelled with "s") and then
 * aborts the proof attempt by raising ProofFailed s. *)
fun fail_tac ctxt s =
  print_tac ctxt s
  THEN (fn _ => Seq.single (raise (ProofFailed s)))

(* Abbreviation for the heap description record from HeapLiftBase. *)
type heap_info = HeapLiftBase.heap_info
local
  (* Look up heap type "T" in one of the per-type constant tables of a
   * heap_info record and return the stored constant as a term.
   * "what" describes the table and is only used in the error report. *)
  fun heap_const tbl what T =
    (case Typtab.lookup tbl T of
      SOME c => Const c
    | NONE => Utils.invalid_typ what T)
in

(* Return the function for fetching an object of a particular type. *)
fun get_heap_getter (heap_info : heap_info) T =
  heap_const (#heap_getters heap_info) "heap type for getter" T

(* Return the function for updating an object of a particular type. *)
fun get_heap_setter (heap_info : heap_info) T =
  heap_const (#heap_setters heap_info) "heap type for setter" T

(* Return the function for determining if a given pointer is valid for a type. *)
fun get_heap_valid_getter (heap_info : heap_info) T =
  heap_const (#heap_valid_getters heap_info) "heap type for valid getter" T

(* Return the function for updating whether a given pointer is valid for a type. *)
fun get_heap_valid_setter (heap_info : heap_info) T =
  heap_const (#heap_valid_setters heap_info) "heap type for valid setter" T

end
(* Return the state (globals) type used by a function: the lifted globals type
 * if the function is heap-lifted, otherwise the original globals type.
 *
 * heap_info is annotated explicitly (like the get_heap_* functions) so that the
 * record selectors below do not depend on flexible-record resolution from
 * later call sites. *)
fun get_expected_fn_state_type (heap_info : heap_info) is_function_lifted fn_name =
  if is_function_lifted fn_name then
    #globals_type heap_info
  else
    #old_globals_type heap_info
(* Get the state translation function for the given function:
 * the full lifting function (#lift_fn_full) if the function is heap-lifted,
 * otherwise the identity on the old globals type.
 *
 * heap_info is annotated explicitly (like the get_heap_* functions) so that the
 * record selectors below do not depend on flexible-record resolution from
 * later call sites. *)
fun get_expected_st (heap_info : heap_info) is_function_lifted fn_name =
  if is_function_lifted fn_name then
    #lift_fn_full heap_info
  else
    @{mk_term "id :: ?'a => ?'a" ('a)} (#old_globals_type heap_info)
(* Compute the expected type of the heap-lifted version of "fn_name":
 *   nat => arg_1 => ... => arg_n => L2 monad over the expected state type
 * The leading "nat" is the recursion measure argument; argument and return
 * types are taken from the function's L2 info. "prog_info" is not used. *)
fun get_expected_hl_fn_type prog_info l2_infos (heap_info : HeapLiftBase.heap_info)
    is_function_lifted fn_name =
  let
    val l2_info = the (Symtab.lookup l2_infos fn_name)
    val arg_typs = @{typ "nat"} :: map snd (#args l2_info)
    val state_typ = get_expected_fn_state_type heap_info is_function_lifted fn_name
    val monad_typ =
      LocalVarExtract.mk_l2monadT state_typ (#return_type l2_info) @{typ unit}
  in
    arg_typs ---> monad_typ
  end
(* Build the expected correspondence statement for a function:
 *   Trueprop (L2Tcorres st (function_free measure args) (l2_const measure args))
 * where "st" is the state translation from get_expected_st. The abstract side is
 * "function_free" applied to the measure and arguments; the concrete side is
 * the L2 constant applied to the same. "prog_info", "ctxt" and the unnamed
 * parameter are not used here; presumably they are required by the callback
 * interface of AutoCorresUtil.assume_called_functions_corres — confirm. *)
fun get_expected_hl_fn_thm prog_info l2_infos (heap_info : HeapLiftBase.heap_info)
    is_function_lifted ctxt fn_name function_free fn_args _ measure_var =
  let
    val all_args = measure_var :: fn_args
    val concrete = betapplys (#const (the (Symtab.lookup l2_infos fn_name)), all_args)
    val abstract = betapplys (function_free, all_args)
    val st = get_expected_st heap_info is_function_lifted fn_name
  in
    @{mk_term "Trueprop (L2Tcorres ?st ?A ?C)" (st, A, C)} (st, abstract, concrete)
  end
(* Return the (name, type) list of the arguments of "fn_name", as recorded in
 * its L2 function info. "prog_info" is not used. *)
fun get_expected_hl_fn_args prog_info l2_infos fn_name =
  Symtab.lookup l2_infos fn_name |> the |> #args
(*
 * Guess whether a function can be lifted.
 *
 * A function is rejected if the right-hand side of its definition mentions
 * any of the constants in "bad_consts" (heap-type data such as "hrs_htd").
 * "lthy" and "prog_info" are not used.
 *)
fun can_lift_function lthy prog_info fn_info =
  let
    val bad_consts = [@{const_name hrs_htd}, @{const_name hrs_htd_update}, @{const_name ptr_retyp}]
    val body = #definition fn_info |> Thm.prop_of |> Utils.rhs_of
    fun is_bad (name, _) = member (op =) bad_consts name
  in
    not (exists_Const is_bad body)
  end
(*
* Convert a cterm from the format "f a (b n) c" into "((f $ a) $ (b $ n)) $ c".
*
* Return a "thm" of the form "old = new".
*
* Every application subterm "a b" is rewritten, bottom-up, by instantiating the
* lemma "a b == (a $ b)"; subterms where the instantiation fails are left
* unchanged (try_conv).
*
* NOTE(review): the per-subterm context supplied by Conv.bottom_conv is
* discarded (K), and expand_conv always uses the outer "ctxt". Confirm this is
* sufficient for subterms under binders.
*)
fun mk_first_order ctxt ct =
let
(* Rewrite one application "a b" to "a $ b". *)
fun expand_conv ct =
Utils.named_cterm_instantiate ctxt
[("a", Thm.dest_fun ct),("b", Thm.dest_arg ct)] @{lemma "a b == (a $ b)" by simp}
in
Conv.bottom_conv (K (Conv.try_conv expand_conv)) ctxt ct
end
(* The opposite to "mk_first_order" *)
(* Rewrite every occurrence of "($)" (bottom-up) to "%a b. a b". The result
 * still contains these beta-redexes; callers such as l2t_from_fo_tac in
 * translate follow this with a beta/eta conversion. *)
fun dest_first_order ctxt ct =
Conv.bottom_conv (K (Conv.try_conv (Conv.rewr_conv
@{lemma "($) == (%a b. a b)" by (rule meta_ext, rule ext, simp)}))) ctxt ct
(*
 * Resolve "subgoal_thm" into every assumption of "base_thm" that it can
 * discharge.
 *
 * Return a tuple: (<new thm>, <a change was made>).
 *
 * Assumptions are visited from last to first. Resolving with RSN at index i
 * replaces premise i by the premises of "subgoal_thm", removing it entirely
 * when "subgoal_thm" is unconditional, and so renumbers every later premise.
 * Walking backwards keeps the indices of not-yet-visited premises valid. The
 * previous ascending order could target the wrong premise, or run out of range,
 * after the first success.
 *)
fun greedy_thm_instantiate base_thm subgoal_thm =
  let
    val asms = Thm.prop_of base_thm |> Logic.strip_assums_hyp
    val concl = Thm.concl_of subgoal_thm
    (* Try to discharge premise i; a failed resolution keeps the current state. *)
    fun try_asm (i, asm) (thm, change_made) =
      if Term.could_unify (asm, concl) then
        ((subgoal_thm RSN (i, thm), true) handle THM _ => (thm, change_made))
      else
        (thm, change_made)
  in
    fold_rev try_asm (tag_list 1 asms) (base_thm, false)
  end
(* Resolve "base_thm" against each of "subgoal_thms" (see
 * greedy_thm_instantiate) and return only the results where something was
 * actually instantiated, in the order of "subgoal_thms". *)
fun instantiate_against_thms base_thm subgoal_thms =
  subgoal_thms |> map_filter (fn subgoal_thm =>
    (case greedy_thm_instantiate base_thm subgoal_thm of
      (thm, true) => SOME thm
    | (_, false) => NONE))
(*
 * Modify a list of thms to instantiate assumptions wherever possible.
 *
 * The lists in "subgoal_thm_lists" are processed in order. For each list,
 * every current thm is replaced by its instantiations against that list,
 * followed by the thm itself, so uninstantiated versions are always kept.
 *)
fun cross_instantiate base_thms subgoal_thm_lists =
  let
    fun expand subgoal_thms thms =
      maps (fn thm => instantiate_against_thms thm subgoal_thms @ [thm]) thms
  in
    fold expand subgoal_thm_lists base_thms
  end
(*
* EXPERIMENTAL: define wrappers and syntax for common heap operations.
* We use the notations "s[p]->r" for {p->r} and "s[p->r := q]" for {p->r = q}.
* For non-fields, "s[p]" and "s[p := q]".
* The wrappers are named like "get_type_field" and "update_type_field".
*
* Known issues:
* * Getter/setter and valid/setter lemmas should be generated for every pair
* of wrappers. If you find yourself unfolding one of the wrapper definitions,
* then some lemma wasn't generated correctly.
*
* * On that note, lemmas relating structs and struct fields
* (foo vs foo.field) are not being generated.
* * TODO: this problem appears in Suzuki.thy
*
* * The syntax looks as terrible as c-parser's. Well, at least you won't need
* to subscript greek letters.
*
* * Isabelle doesn't like overloaded syntax. Issue VER-412
*)
(* Internal signal used by field_syntax when the heap has no getter or setter
 * for a struct type. Not visible externally. *)
exception NO_GETTER_SETTER

(* Build a Mixfix annotation (with no source position) from
 * (template, argument priorities, result priority). *)
fun mixfix (template, arg_prios, prio) =
  Mixfix (Input.string template, arg_prios, prio, Position.no_range)
(* Define getter/setter wrappers and syntax for one struct field.
 *
 * For a struct and field (names with any "_C" suffix removed) this defines
 *   get_<struct>_<field> s ptr
 *     == field getter applied to (heap getter s ptr)
 *   update_<struct>_<field> s ptr val
 *     == heap setter updating only that field of the object at ptr
 * and registers the notations "s[ptr]->field" and "s[ptr->field := val]".
 *
 * Accumulates into (new_getters, new_setters, lthy). Both tables map the new
 * constant's name to (struct name, field name, constant, SOME definition thm).
 * If the heap has no getter or setter for the struct type, nothing is defined
 * and the accumulator is returned unchanged. *)
fun field_syntax (heap_info : HeapLiftBase.heap_info)
    (struct_info : HeapLiftBase.struct_info)
    (field_info : HeapLiftBase.field_info)
    (new_getters, new_setters, lthy) =
  let
    fun unsuffix' suffix str = if String.isSuffix suffix str then unsuffix suffix str else str
    val struct_pname = unsuffix' "_C" (#name struct_info)
    val field_pname = unsuffix' "_C" (#name field_info)
    val struct_typ = #struct_type struct_info
    val state_var = ("s", #globals_type heap_info)
    val ptr_var = ("ptr", Type (@{type_name "ptr"}, [struct_typ]))
    val val_var = ("val", #field_type field_info)
    val struct_getter = case Typtab.lookup (#heap_getters heap_info) struct_typ of
        SOME getter => Const getter
      | _ => raise NO_GETTER_SETTER
    val struct_setter = case Typtab.lookup (#heap_setters heap_info) struct_typ of
        SOME setter => Const setter
      | _ => raise NO_GETTER_SETTER
    (* We will modify lthy soon, so may not exit with NO_GETTER_SETTER after this point *)
    (* Define field accessor function *)
    val field_getter_term = @{mk_term "?field_get (?heap_get s ptr)" (heap_get, field_get)}
      (struct_getter, #getter field_info)
    val new_heap_get_name = "get_" ^ struct_pname ^ "_" ^ field_pname
    val (new_heap_get, new_heap_get_thm, lthy) =
      Utils.define_const_args new_heap_get_name false field_getter_term
        [state_var, ptr_var] lthy
    (* Define field update function *)
    val field_setter_term = @{mk_term "?heap_update (%old. old(ptr := ?field_update (%_. val) (old ptr))) s"
      (heap_update, field_update)} (struct_setter, #setter field_info)
    val new_heap_update_name = "update_" ^ struct_pname ^ "_" ^ field_pname
    val (new_heap_update, new_heap_update_thm, lthy) =
      Utils.define_const_args new_heap_update_name false field_setter_term
        [state_var, ptr_var, val_var] lthy
    val getter_mixfix = mixfix ("_[_]\<rightarrow>" ^ (Syntax_Ext.escape field_pname), [1000], 1000)
    val setter_mixfix = mixfix ("_[_\<rightarrow>" ^ (Syntax_Ext.escape field_pname) ^ " := _]", [1000], 1000)
    val lthy = Local_Theory.notation true Syntax.mode_default [
      (new_heap_get, getter_mixfix),
      (new_heap_update, setter_mixfix)] lthy
    (* The struct_pname returned here must match the type_pname returned in heap_syntax.
     * new_heap_update_thm relies on this to determine what kind of thm to generate. *)
    val new_getters = Symtab.update_new (new_heap_get_name,
      (struct_pname, field_pname, new_heap_get, SOME new_heap_get_thm)) new_getters
    val new_setters = Symtab.update_new (new_heap_update_name,
      (struct_pname, field_pname, new_heap_update, SOME new_heap_update_thm)) new_setters
  in
    (new_getters, new_setters, lthy)
  end
  handle NO_GETTER_SETTER => (new_getters, new_setters, lthy)
(* Define syntax for one C type. This also creates new wrappers for heap updates. *)
(*
* Defines "update_<type> s ptr val" as the heap setter applied to
* "%old. old(ptr := val)", and registers the notations "s[ptr]" (for the
* existing heap getter) and "s[ptr := val]" (for the new update wrapper).
*
* Accumulates into (new_getters, new_setters, lthy):
*   - new_getters gets the getter's base name mapped to
*     (type name, "", getter constant, NONE); the getter is not redefined.
*   - new_setters gets "update_<type>" mapped to
*     (type name, "", new constant, SOME definition thm).
* Raises TYPE if heap_info has no getter or setter for heap_type.
*)
fun heap_syntax (heap_info : HeapLiftBase.heap_info)
(heap_type : typ)
(new_getters, new_setters, lthy) =
let
val getter = case Typtab.lookup (#heap_getters heap_info) heap_type of
SOME x => x
| NONE => raise TYPE ("heap_lift/heap_syntax: no getter", [heap_type], [])
val setter = case Typtab.lookup (#heap_setters heap_info) heap_type of
SOME x => x
| NONE => raise TYPE ("heap_lift/heap_syntax: no setter", [heap_type], [])
(* Remove every occurrence of "_C" from a character list.
 * NOTE(review): this strips "_C" anywhere in the name, whereas field_syntax
 * only strips a trailing "_C" from the struct name. The comment in
 * field_syntax requires the two names to match; confirm for type names that
 * contain an inner "_C". *)
fun replace_C (#"_" :: #"C" :: xs) = replace_C xs
| replace_C (x :: xs) = x :: replace_C xs
| replace_C [] = []
val type_pname = HeapLiftBase.name_from_type heap_type
|> String.explode |> replace_C |> String.implode
val state_var = ("s", #globals_type heap_info)
val heap_ptr_type = Type (@{type_name "ptr"}, [heap_type])
val ptr_var = ("ptr", heap_ptr_type)
val val_var = ("val", heap_type)
(* Whole-object update: replace the value stored at ptr by val. *)
val setter_def = @{mk_term "?heap_update (%old. old(ptr := val)) s" heap_update} (Const setter)
val new_heap_update_name = "update_" ^ type_pname
val (new_heap_update, new_heap_update_thm, lthy) =
Utils.define_const_args new_heap_update_name false setter_def
[state_var, ptr_var, val_var] lthy
val getter_mixfix = mixfix ("_[_]", [1000], 1000)
val setter_mixfix = mixfix ("_[_ := _]", [1000], 1000)
val lthy = Local_Theory.notation true Syntax.mode_default
[(Const getter, getter_mixfix), (new_heap_update, setter_mixfix)] lthy
(* An empty field name marks whole-type (non-field) entries. *)
val new_getters = Symtab.update_new (Long_Name.base_name (fst getter), (type_pname, "", Const getter, NONE)) new_getters
val new_setters = Symtab.update_new (new_heap_update_name, (type_pname, "", new_heap_update, SOME new_heap_update_thm)) new_setters
in
(new_getters, new_setters, lthy)
end
(* Make all heap syntax and collect the results.
 * Starting from empty tables, run field_syntax for every field of every struct,
 * then heap_syntax for every type that has a heap getter. Returns the
 * accumulated (new_getters, new_setters, lthy). *)
fun make_heap_syntax heap_info lthy =
  let
    fun add_struct_fields (_, struct_info) acc =
      fold (field_syntax heap_info struct_info) (#field_info struct_info) acc
    val with_fields =
      Symtab.fold add_struct_fields (#structs heap_info) (Symtab.empty, Symtab.empty, lthy)
  in
    fold (heap_syntax heap_info) (Typtab.keys (#heap_getters heap_info)) with_fields
  end
(* Prove lemmas for the new getter/setter definitions. *)
(*
* The getter and setter arguments are entries of the tables built by
* make_heap_syntax: (type name, field name or "", constant, definition option).
*
* Returns NONE when both refer to the same type but exactly one of them is a
* field accessor (field name ""), i.e. whole-struct vs. field combinations,
* which are not handled yet (see TODO). Otherwise returns SOME thm, where the
* theorem is proved with the simplifier from the available definitions:
*   get (set s p v) = (get s)(p := v)   if type and field names both agree
*   get (set s p v) = get s             otherwise (separation)
*)
fun new_heap_update_thm (getter_type_name, getter_field_name, getter, getter_def)
(setter_type_name, setter_field_name, setter, setter_def)
lthy =
(* TODO: also generate lemmas relating whole-struct updates to field updates *)
if getter_type_name = setter_type_name
andalso not ((getter_field_name = "") = (setter_field_name = "")) then NONE else
let val lhs = @{mk_term "?get (?set s p v)" (get, set)} (getter, setter)
val rhs = if getter_type_name = setter_type_name andalso
getter_field_name = setter_field_name
(* functional update *)
then @{mk_term "(?get s) (p := v)" (get)} getter
(* separation *)
else @{mk_term "?get s" (get)} getter
val prop = @{mk_term "Trueprop (?lhs = ?rhs)" (lhs, rhs)} (lhs, rhs)
val defs = the_list getter_def @ the_list setter_def
val thm = Goal.prove_future lthy ["s", "p", "v"] [] prop
(fn params => (simp_tac ((#context params) addsimps
@{thms ext fun_upd_apply} @ defs) 1))
in SOME thm end
(* Prove that a wrapper heap update does not affect pointer validity:
 *   valid (set s p v) q = valid s q
 * "valid_term" is the (name, type) of a validity getter constant. The setter
 * entry is (type name, field name, constant, definition option); setters
 * without a definition theorem yield NONE. *)
fun new_heap_valid_thm _ (_, _, _, NONE) _ = NONE
  | new_heap_valid_thm valid_term (_, _, setter, SOME setter_def) lthy =
    let
      val goal = @{mk_term "Trueprop (?valid (?set s p v) q = ?valid s q)" (valid, set)}
                   (Const valid_term, setter)
      val simps = [@{thm fun_upd_apply}, setter_def]
    in
      SOME (Goal.prove_future lthy ["s", "p", "v", "q"] [] goal
              (fn params => simp_tac (#context params addsimps simps) 1))
    end
(* Take a definition and eta contract the RHS:
lhs = rhs s ==> (%s. lhs) = rhs
This allows us to rewrite a heap update even if the state is eta contracted away.
The input must have the exact shape "lhs == rhs ?s", where the final argument
is a schematic variable named "s"; any other shape makes the val binding below
fail with Bind. The new equation is proved with the simplifier using the
original definition together with atomize_eq and ext. *)
fun eta_rhs lthy thm = let
val Const (@{const_name "Pure.eq"}, typ) $ lhs $ (rhs $ Var (("s", s_n), s_typ)) = term_of_thm thm
val abs_term = @{mk_term "?a == ?b" (a, b)} (lambda (Var (("s", s_n), s_typ)) lhs, rhs)
val thm' = Goal.prove_future lthy [] [] abs_term
(fn params => simp_tac (put_simpset HOL_basic_ss (#context params) addsimps thm :: @{thms atomize_eq ext}) 1)
in thm' end
(* Extract the abstract term out of a L2Tcorres thm. *)
(* Given a term "L2Tcorres st A C", return the abstract program A (the middle
 * argument). Terms of any other shape raise Match. *)
fun dest_L2Tcorres_term_abs @{term_pat "L2Tcorres _ ?t _"} = t
(* Generate lifted_globals lemmas for the given heap_info.
 *
 * This function only proves the lemmas. They are instantiated into the generic
 * heap_abs rules later (see the comment at the end of this function).
 *
 * Returns a list of three thm lists, in this order:
 *   [ valid_typ_heap thms (one per heap type, plus derived signed-word variants),
 *     valid_struct_field / valid_struct_field_legacy thms for struct fields,
 *     valid_globals_field thms for globals embedded directly in the records ]
 * The proofs are started with Goal.prove_future (see the FIXME below). *)
fun lifted_globals_lemmas prog_info heap_info lthy = let
(* Tactic to solve subgoals below. *)
local
(* Fetch simp rules generated by the C Parser about structures. *)
val struct_simpset = UMM_Proof_Theorems.get (Proof_Context.theory_of lthy)
(* Look up a named rule list, treating a missing entry as empty. *)
fun lookup_the t k = case Symtab.lookup t k of SOME x => x | NONE => []
val struct_simps =
(lookup_the struct_simpset "typ_name_simps")
@ (lookup_the struct_simpset "typ_name_itself")
@ (lookup_the struct_simpset "fl_ti_simps")
@ (lookup_the struct_simpset "fl_simps")
@ (lookup_the struct_simpset "fg_cons_simps")
val base_ss = simpset_of @{theory_context HeapLift}
val record_ss = RecordUtils.get_record_simpset lthy
val merged_ss = merge_ss (base_ss, record_ss)
(* Generate a simpset containing everything we need. *)
val ss =
(Context_Position.set_visible false lthy)
|> put_simpset merged_ss
|> (fn ctxt => ctxt
addsimps [#lift_fn_thm heap_info]
@ @{thms typ_simple_heap_simps}
@ @{thms valid_globals_field_def}
@ @{thms the_fun_upd_lemmas}
@ struct_simps)
|> simpset_of
in
(* First try fast_force on subgoal 1 with the simpset "ss". Failing that, try the
 * intro rules conjI/ext followed by clarsimp on subgoal 1; CHANGED requires
 * this alternative to make progress. *)
fun subgoal_solver_tac ctxt =
(fast_force_tac (put_simpset ss ctxt) 1)
ORELSE (CHANGED (Method.try_intros_tac ctxt [@{thm conjI}, @{thm ext}] []
THEN clarsimp_tac (put_simpset ss ctxt) 1))
end
(* Generate "valid_typ_heap" predicates for each heap type we have. *)
(* NOTE(review): the initial resolve_tac uses the outer lthy rather than
 * #context params; presumably fine because these goals fix no variables. *)
fun mk_valid_typ_heap_thm typ =
@{mk_term "Trueprop (valid_typ_heap ?st ?getter ?setter ?valid_getter ?valid_setter ?t_hrs ?t_hrs_update)"
(st, getter, setter, valid_getter, valid_setter, t_hrs, t_hrs_update)}
(#lift_fn_full heap_info,
get_heap_getter heap_info typ,
get_heap_setter heap_info typ,
get_heap_valid_getter heap_info typ,
get_heap_valid_setter heap_info typ,
#t_hrs_getter prog_info,
#t_hrs_setter prog_info)
|> (fn prop => Goal.prove_future lthy [] [] prop
(fn params =>
((resolve_tac lthy @{thms valid_typ_heapI} 1) THEN (
REPEAT (subgoal_solver_tac (#context params))))))
(* Make thms for all types. *)
(* FIXME: these are currently auto-parallelised using prove_future, but perhaps
* we should exercise finer control over the evaluation, as prove_futures
* persist long after the AutoCorres command actually returns. *)
val heap_types = (#heap_getters heap_info |> Typtab.dest |> map fst)
val valid_typ_heap_thms = map mk_valid_typ_heap_thm heap_types
(* Generate "valid_typ_heap" thms for signed words. *)
(* Every rule of signed_valid_typ_heaps is combined (OF) with every
 * valid_typ_heap thm; combinations that fail are silently dropped. *)
val valid_typ_heap_thms =
valid_typ_heap_thms
@ (map_product
(fn a => fn b => try (fn _ => a OF [b]) ())
@{thms signed_valid_typ_heaps}
valid_typ_heap_thms
|> map_filter I)
(* Generate "valid_struct_field" for each field of each struct. *)
(* Returns a list with zero or one thm. "struct_name" is used to fetch the
 * records-package thms "<struct_name>_idupdates" and "<struct_name>_fold_congs".
 * The "typ" argument is not used. *)
fun mk_valid_struct_field_thm struct_name typ (field_info : HeapLiftBase.field_info) =
@{mk_term "Trueprop (valid_struct_field ?st [?fname] ?fgetter ?fsetter ?t_hrs ?t_hrs_update)"
(st, fname, fgetter, fsetter, t_hrs, t_hrs_update) }
(#lift_fn_full heap_info,
Utils.encode_isa_string (#name field_info),
#getter field_info,
#setter field_info,
#t_hrs_getter prog_info,
#t_hrs_setter prog_info)
|> (fn prop =>
(* HACK: valid_struct_field currently works only for packed types,
* so typecheck the prop first *)
case try (Syntax.check_term lthy) prop of
SOME _ =>
[Goal.prove_future lthy [] [] prop
(fn params =>
(resolve_tac lthy @{thms valid_struct_fieldI} 1) THEN
(* Need some extra thms from the records package for our struct type. *)
(EqSubst.eqsubst_tac lthy [0]
[hd (Proof_Context.get_thms lthy (struct_name ^ "_idupdates")) RS @{thm sym}] 1
THEN asm_full_simp_tac lthy 1) THEN
(FIRST (Proof_Context.get_thms lthy (struct_name ^ "_fold_congs")
|> map (fn t => resolve_tac lthy [t OF @{thms refl refl}] 1))
THEN asm_full_simp_tac lthy 1) THEN
(REPEAT (subgoal_solver_tac (#context params))))]
| NONE => [])
(* Generate "valid_struct_field_legacy" for each field of each struct. *)
(* Unlike mk_valid_struct_field_thm, this also mentions the struct type's heap
 * getter/setter and valid getter/setter, and always attempts the proof. *)
fun mk_valid_struct_field_legacy_thm typ (field_info : HeapLiftBase.field_info) =
@{mk_term "Trueprop (valid_struct_field_legacy ?st [?fname] ?fgetter (%v. ?fsetter (%_. v)) ?getter ?setter ?valid_getter ?valid_setter ?t_hrs ?t_hrs_update)"
(st, fname, fgetter, fsetter, getter, setter, valid_getter, valid_setter, t_hrs, t_hrs_update) }
(#lift_fn_full heap_info,
Utils.encode_isa_string (#name field_info),
#getter field_info,
#setter field_info,
get_heap_getter heap_info typ,
get_heap_setter heap_info typ,
get_heap_valid_getter heap_info typ,
get_heap_valid_setter heap_info typ,
#t_hrs_getter prog_info,
#t_hrs_setter prog_info)
|> (fn prop => Goal.prove_future lthy [] [] prop
(fn params =>
(resolve_tac lthy @{thms valid_struct_field_legacyI} 1) THEN (
REPEAT (subgoal_solver_tac (#context params)))))
(* Make thms for all fields of structs in our heap. *)
(* Heap types without an entry in #struct_types contribute nothing. *)
fun valid_struct_abs_thms T =
case (Typtab.lookup (#struct_types heap_info) T) of
NONE => []
| SOME struct_info =>
map (fn field =>
mk_valid_struct_field_thm (#name struct_info) T field
@ [mk_valid_struct_field_legacy_thm T field])
(#field_info struct_info)
|> List.concat
val valid_field_thms =
map valid_struct_abs_thms heap_types |> List.concat
(* Generate conversions from globals embedded directly in the "globals" and
* "lifted_globals" record. *)
(* The getter/setter tables map each global's name to an (old, new) pair:
 * fst is the accessor on the old record, snd the one on the lifted record. *)
fun mk_valid_globals_field_thm name =
@{mk_term "Trueprop (valid_globals_field ?st ?old_get ?old_set ?new_get ?new_set)"
(st, old_get, old_set, new_get, new_set)}
(#lift_fn_full heap_info,
Symtab.lookup (#global_field_getters heap_info) name |> the |> fst,
Symtab.lookup (#global_field_setters heap_info) name |> the |> fst,
Symtab.lookup (#global_field_getters heap_info) name |> the |> snd,
Symtab.lookup (#global_field_setters heap_info) name |> the |> snd)
|> (fn prop => Goal.prove_future lthy [] [] prop (fn params => subgoal_solver_tac (#context params)))
val valid_global_field_thms = map (#1 #> mk_valid_globals_field_thm) (#global_fields heap_info)
(* At this point, the lemmas are ready to be instantiated into the generic
* heap_abs rules (which will be fetched from the most recent lthy). *)
in
[ valid_typ_heap_thms, valid_field_thms, valid_global_field_thms ]
end;
(*
* Prepare for the heap lifting phase.
* We need to:
* - define a lifted_globals type
* - prove generic heap lifting lemmas for the lifted_globals type
* - define heap syntax and rewrite rules (if heap_abs_syntax is set)
* - store these new results into the HeapInfo theory data
* Note that because we are adding definitions that are required by all
* conversions, we need to wait for all previous L2 conversions to finish,
* limiting parallelism somewhat. This requires us to modify l2_results by
* updating its intermediate lthys.
*
* These results are cached in the local theory, so we attempt to fetch an
* existing definition (in the case that we are resuming a previous run).
* In this scenario, we don't have to modify l2_results.
*
* Returns the (possibly updated) l2_results together with the heap-lift setup
* that was found or created. The setup is stored in HeapInfo under "filename".
*)
fun prepare_heap_lift
(filename : string)
(prog_info : ProgramInfo.prog_info)
(l2_results : (local_theory * FunctionInfo.function_info Symtab.table) FSeq.fseq)
(* An initial lthy, used to check for an existing heap_lift_setup.
* Also used as fallback target in the unlikely case where l2_results = [] *)
(lthy0 : local_theory)
(* We define the lifted heap for all functions in the program, even if they are
* not included in this translation. This allows heap lifting to work with
* incremental translations. *)
(all_simpl_infos : FunctionInfo.function_info Symtab.table)
(* Settings *)
(make_lifted_globals_field_name : string -> string)
(gen_word_heaps : bool)
(heap_abs_syntax : bool)
: ((local_theory * FunctionInfo.function_info Symtab.table) FSeq.fseq
* HeapLiftBase.heap_lift_setup) =
let
(* Get target lthy for adding new definitions.
* This is the most recent l2_results lthy, except if there are no results,
* in which case the fallback lthy is used. *)
(* If no fallback lthy is given either, lthy0 is used. *)
fun get_target_lthy l2_results fallback_lthy =
if FSeq.null l2_results then Option.getOpt (fallback_lthy, lthy0)
else fst (List.last (FSeq.list_of l2_results));
(* Replace the lthy of every L2 result by "lthy", keeping the function infos. *)
fun update_results lthy = FSeq.map (apfst (K lthy));
(* Set up heap_info and associated lemmas *)
(* If lthy0's theory already has a setup for "filename", reuse it unchanged
 * (fallback_lthy = NONE). Otherwise build heap_info and its lemmas, store the
 * setup in HeapInfo, and propagate the new lthy into l2_results. *)
val (l2_results, HL_setup, fallback_lthy) =
case Symtab.lookup (HeapInfo.get (Proof_Context.theory_of lthy0)) filename of
SOME HL_setup => (l2_results, HL_setup, NONE)
| NONE => let
val lthy = get_target_lthy l2_results NONE;
val (heap_info, lthy) = HeapLiftBase.setup prog_info all_simpl_infos
make_lifted_globals_field_name gen_word_heaps lthy;
val lifted_heap_lemmas = lifted_globals_lemmas prog_info heap_info lthy;
val HL_setup = { heap_info = heap_info,
lifted_heap_lemmas = lifted_heap_lemmas,
heap_syntax_rewrs = [] };
val lthy = Local_Theory.background_theory (
HeapInfo.map (fn tbl => Symtab.update (filename, HL_setup) tbl)) lthy;
in (update_results lthy l2_results, HL_setup, SOME lthy) end;
(* Do some extra lifting and create syntax (see field_syntax comment).
* We do this separately because heap_abs_syntax could be enabled halfway
* through an incremental translation. *)
(* Skipped when heap_abs_syntax is off or the setup already has rewrite rules. *)
val (l2_results, HL_setup, fallback_lthy) =
if not heap_abs_syntax orelse not (null (#heap_syntax_rewrs HL_setup))
then (l2_results, HL_setup, fallback_lthy)
else
let val lthy = get_target_lthy l2_results fallback_lthy;
(* NOTE(review): the semantics of Utils.exec_background_result are not shown
 * here; presumably it runs the function in the background theory and returns
 * its result with the updated lthy. *)
val (heap_syntax_rewrs, lthy) =
Utils.exec_background_result (fn lthy => let
(* Keep only the SOME results. *)
val optcat = List.mapPartial I
val heap_info = #heap_info HL_setup
(* Define the new heap operations and their syntax. *)
val (new_getters, new_setters, lthy) =
make_heap_syntax heap_info lthy
(* Make simplification thms and add them to the simpset. *)
(* One attempt for every (getter, setter) pair and for every
 * (valid getter, setter) pair. *)
val update_thms = map (fn get => map (fn set => new_heap_update_thm get set lthy)
(Symtab.dest new_setters |> map snd))
(Symtab.dest new_getters |> map snd)
|> List.concat
val valid_thms = map (fn valid => map (fn set => new_heap_valid_thm valid set lthy)
(Symtab.dest new_setters |> map snd))
(Typtab.dest (#heap_valid_getters heap_info) |> map snd)
|> List.concat
val thms = update_thms @ valid_thms |> optcat
val lthy = Utils.simp_add thms lthy
(* Name the thms. (FIXME: do this elsewhere?) *)
val (_, lthy) = Utils.define_lemmas "heap_abs_simps" thms lthy
(* Rewrite rules for converting the program. *)
(* The wrapper definitions are reversed (symmetric), so that they rewrite the
 * unfolded heap operations into the new wrappers; setter definitions are
 * first eta-contracted with eta_rhs. *)
val getter_thms = Symtab.dest new_getters |> map (#4 o snd) |> optcat
val setter_thms = Symtab.dest new_setters |> map (#4 o snd) |> optcat
val eta_setter_thms = map (eta_rhs lthy) setter_thms
val rewrite_thms = map (fn thm => @{thm symmetric} OF [thm])
(getter_thms @ eta_setter_thms)
in (rewrite_thms, lthy) end) lthy;
val HL_setup = { heap_info = #heap_info HL_setup,
lifted_heap_lemmas = #lifted_heap_lemmas HL_setup,
heap_syntax_rewrs = heap_syntax_rewrs };
val lthy = Local_Theory.background_theory (
HeapInfo.map (fn tbl => Symtab.update (filename, HL_setup) tbl)) lthy;
in (update_results lthy l2_results, HL_setup, SOME lthy) end;
in (l2_results, HL_setup) end;
(* Convert a program to use a lifted heap. *)
fun translate
(filename : string)
(prog_info : ProgramInfo.prog_info)
(l2_results : FunctionInfo.phase_results)
(existing_l2_infos : FunctionInfo.function_info Symtab.table)
(existing_hl_infos : FunctionInfo.function_info Symtab.table)
(HL_setup : HeapLiftBase.heap_lift_setup)
(no_heap_abs : Symset.key Symset.set)
(force_heap_abs : Symset.key Symset.set)
(heap_abs_syntax : bool)
(keep_going : bool)
(trace_funcs : string list)
(do_opt : bool)
(trace_opt : bool)
(add_trace: string -> string -> AutoCorresData.Trace -> unit)
(hl_function_name : string -> string)
: FunctionInfo.phase_results =
if FSeq.null l2_results then FSeq.empty () else
let
(* lthy for conversion rules. This needs to be (at latest) the earliest lthy
* result so that the rules can be used in all conversions *)
val lthy0 = fst (FSeq.hd l2_results);
val heap_info = #heap_info HL_setup;
(*
* Fetch rules from the theory, instantiating any rule with the
* lifted_globals lemmas for "valid_globals_field", "valid_typ_heap" etc.
* that we generated previously.
*)
val base_rules = Utils.get_rules lthy0 @{named_theorems heap_abs}
val rules =
cross_instantiate base_rules (#lifted_heap_lemmas HL_setup)
(* Remove rules that haven't been fully instantiated *)
|> filter_out (Thm.prop_of #> exists_subterm (fn x =>
case x of Const (@{const_name "valid_globals_field"}, _) => true
| Const (@{const_name "valid_struct_field"}, _) => true
| Const (@{const_name "valid_struct_field_legacy"}, _) => true
| Const (@{const_name "valid_typ_heap"}, _) => true
| _ => false));
(* We only use this blanket rule for non-lifted functions;
* liftable expressions can be handled by specific struct_rewrite rules *)
val nolift_rules = @{thms struct_rewrite_expr_id}
(* This does a linear search. We will only need it in is_function_lifted, though *)
fun lookup_l2_results f_name =
FSeq.find (fn (_, l2_infos) => Symtab.defined l2_infos f_name) l2_results
|> the' ("HL: missing L2 results for " ^ f_name);
fun is_function_lifted f_name =
case Symtab.lookup existing_hl_infos f_name of
SOME info => let
(* We heap-lifted this function earlier. Check its state type. *)
val body = #definition info |> Thm.prop_of |> Utils.rhs_of_eq;
val stT = LocalVarExtract.l2monad_state_type body;
in stT = #globals_type heap_info end
| NONE => let
val (lthy, l2_infos) = lookup_l2_results f_name;
val can_lift =
if can_lift_function lthy prog_info (the (Symtab.lookup l2_infos f_name))
then not (Symset.contains no_heap_abs f_name)
else if Symset.contains force_heap_abs f_name
then true
else (* Report functions that we're not lifting,
* but not if the user has overridden explicitly *)
(if Symset.contains no_heap_abs f_name then () else
writeln ("HL: disabling heap lift for: " ^ f_name ^
" (use force_heap_abs to enable)");
false);
in can_lift end;
(* Cache answers for which functions we are lifting. *)
val is_function_lifted = String_Memo.memo is_function_lifted;
(* Convert to new heap format. *)
fun convert lthy l2_infos f: AutoCorresUtil.convert_result =
let
val f_l2_info = the (Symtab.lookup l2_infos f);
(* Fix measure variable. *)
val ([measure_var_name], lthy) = Variable.variant_fixes ["rec_measure'"] lthy;
val measure_var = Free (measure_var_name, AutoCorresUtil.measureT);
(* Add callee assumptions. *)
val (lthy, export_thm, callee_terms) =
AutoCorresUtil.assume_called_functions_corres lthy
(#callees f_l2_info) (#rec_callees f_l2_info)
(get_expected_hl_fn_type prog_info l2_infos heap_info is_function_lifted)
(get_expected_hl_fn_thm prog_info l2_infos heap_info is_function_lifted)
(get_expected_hl_fn_args prog_info l2_infos)
hl_function_name
measure_var;
(* Fix argument variables. *)
val new_fn_args = get_expected_hl_fn_args prog_info l2_infos f;
val (arg_names, lthy) = Variable.variant_fixes (map fst new_fn_args) lthy;
val arg_frees = map Free (arg_names ~~ map snd new_fn_args);
(* Fetch the function definition. *)
val l2_body_def =
#definition f_l2_info
(* Instantiate the arguments. *)
|> Utils.inst_args lthy (map (Thm.cterm_of lthy) (measure_var :: arg_frees))
(* Get L2 body definition with function arguments. *)
val l2_term = betapplys (#const f_l2_info, measure_var :: arg_frees)
(* Get our state translation function. *)
val st = get_expected_st heap_info is_function_lifted f
(* Generate a schematic goal. *)
val goal = @{mk_term "Trueprop (L2Tcorres ?st ?A ?C)" (st, C)}
(st, l2_term)
|> Thm.cterm_of lthy
|> Goal.init
|> Utils.apply_tac "unfold RHS" (EqSubst.eqsubst_tac lthy [0] [l2_body_def] 1)
val callee_mono_thms =
callee_terms |> map fst
|> List.mapPartial (fn callee =>
if FunctionInfo.is_function_recursive (the (Symtab.lookup l2_infos callee))
then #mono_thm (the (Symtab.lookup l2_infos callee))
else NONE)
val rules = rules @ (map (snd #> #3) callee_terms) @ callee_mono_thms
val rules = if is_function_lifted f then rules else rules @ nolift_rules
val fo_rules = Utils.get_rules lthy @{named_theorems heap_abs_fo}
(* Apply a conversion to the concrete side of the given L2T term.
* By convention, the concrete side is the last argument (index ~1). *)
fun l2t_conc_body_conv conv =
Conv.params_conv (~1) (fn ctxt => (Conv.arg_conv (Utils.nth_arg_conv (~1) (conv ctxt))))
(* Standard tactics. *)
val print_debug = f = ""
fun rtac_all r n = (APPEND_LIST (map (fn thm =>
resolve_tac lthy [thm] n THEN (fn x =>
(if print_debug then @{trace} thm else ();
Seq.succeed x))) r))
(* Convert the concrete side of the given L2T term to/from first-order form. *)
val l2t_to_fo_tac = CONVERSION (Drule.beta_eta_conversion then_conv l2t_conc_body_conv mk_first_order lthy)
val l2t_from_fo_tac = CONVERSION (l2t_conc_body_conv (fn ctxt => dest_first_order ctxt then_conv Drule.beta_eta_conversion) lthy)
val fo_tac = ((l2t_to_fo_tac THEN' rtac_all fo_rules) THEN_ALL_NEW l2t_from_fo_tac) 1
(*
* Recursively solve subgoals.
*
* We allow backtracking in order to solve a particular subgoal, but once a
* subgoal is completed we don't ever try to solve it in a different way.
*
* This allows us to try different approaches to solving subgoals without
* leading to exponential explosion (of many different combinations of
* "good solutions") once we hit an unsolvable subgoal.
*)
val tactics =
if #is_simpl_wrapper f_l2_info
then (* Solver for trivial simpl wrappers. *)
[(@{thm L2Tcorres_id}, resolve_tac lthy [@{thm L2Tcorres_id}] 1)]
else map (fn rule => (rule, resolve_tac lthy [rule] 1)) rules
@ [(@{thm fun_app_def}, fo_tac)]
val replay_failure_start = 1
val replay_failures = Unsynchronized.ref replay_failure_start
val (thm, trace) =
case AutoCorresTrace.maybe_trace_solve_tac lthy (member (op =) trace_funcs f)
true false (K tactics) goal NONE replay_failures of
NONE => (* intentionally generate a TRACE_SOLVE_TAC_FAIL *)
(AutoCorresTrace.trace_solve_tac lthy false false (K tactics) goal NONE (Unsynchronized.ref 0);
(* never reached *) error "heap_lift fail tac: impossible")
| SOME (thm, [trace]) => (Goal.finish lthy thm, trace)
val _ = if !replay_failures < replay_failure_start then
warning ("HL: " ^ f ^ ": reverted to slow replay " ^
Int.toString(replay_failure_start - !replay_failures) ^ " time(s)") else ()
(* DEBUG: make sure that all uses of field_lvalue and c_guard are rewritten.
* Also make sure that we cleaned up internal constants. *)
fun contains_const name = exists_subterm (fn x => case x of Const (n, _) => n = name | _ => false)
fun const_gone term name =
if not (contains_const name term) then ()
else Utils.TERM_non_critical keep_going
("Heap lift: could not remove " ^ name ^ " in " ^ f ^ ".") [term]
fun const_old_heap term name =
if not (contains_const name term) then ()
else warning ("Heap lift: could not remove " ^ name ^ " in " ^ f ^
". Output program may be unprovable.")
val _ = if not (is_function_lifted f) then []
else (map (const_gone (term_of_thm thm))
[@{const_name "heap_lift__h_val"}];
map (const_old_heap (term_of_thm thm))
[@{const_name "field_lvalue"}, @{const_name "c_guard"}]
)
(* Apply peephole optimisations to the theorem. *)
val _ = writeln ("Simplifying (HL) " ^ f)
val (thm, opt_traces) = L2Opt.cleanup_thm_tagged lthy thm (if do_opt then 0 else 2) 2 trace_opt "HL"
(* If we created extra heap wrappers, apply them now.
* Our simp rules don't seem to be enough for L2Opt,
* so we cannot change the program before that. *)
val thm = if not heap_abs_syntax then thm else
Raw_Simplifier.rewrite_rule lthy (#heap_syntax_rewrs HL_setup) thm
val f_body = dest_L2Tcorres_term_abs (HOLogic.dest_Trueprop (Thm.concl_of thm));
(* Get actual recursive callees *)
val rec_callees = AutoCorresUtil.get_rec_callees callee_terms f_body;
(* Return the constants that we fixed. This will be used to process the returned body. *)
val callee_consts =
callee_terms |> map (fn (callee, (_, const, _)) => (callee, const)) |> Symtab.make;
in
{ body = f_body,
proof = Morphism.thm export_thm thm,
rec_callees = rec_callees,
callee_consts = callee_consts,
arg_frees = map dest_Free (measure_var :: arg_frees),
traces = (if member (op =) trace_funcs f
then [("HL", AutoCorresData.RuleTrace trace)] else []) @ opt_traces
}
end
(*
 * Define the HL constants for one previously-converted function group
 * (a single function or a recursive function group) and return the
 * updated function_infos together with the extended local theory.
 *
 * lthy must already contain the definitions of every function in hl_callees.
 *)
fun define
    (lthy: local_theory)
    (l2_infos: FunctionInfo.function_info Symtab.table)
    (hl_callees: FunctionInfo.function_info Symtab.table)
    (funcs: AutoCorresUtil.convert_result Symtab.table)
    : FunctionInfo.function_info Symtab.table * local_theory =
  let
    (* One entry per converted function:
     * (name, (abstracted body, corres proof, argument frees)). *)
    val bodies =
      map (fn result as (name, {proof, arg_frees, ...}) =>
             (name, (AutoCorresUtil.abstract_fn_body l2_infos result, proof, arg_frees)))
          (Symtab.dest funcs);
    (* Correspondence theorems of the callees that are already defined. *)
    val callee_corres = Symtab.map (K #corres_thm) hl_callees;
    val (defined, lthy') =
      AutoCorresUtil.define_funcs
        FunctionInfo.HL filename l2_infos hl_function_name
        (get_expected_hl_fn_type prog_info l2_infos heap_info is_function_lifted)
        (get_expected_hl_fn_thm prog_info l2_infos heap_info is_function_lifted)
        (get_expected_hl_fn_args prog_info l2_infos)
        @{thm L2Tcorres_recguard_0}
        lthy callee_corres bodies;
    (* Start from the function's L2 info and switch it over to the HL phase. *)
    fun to_hl_info name (const, def, corres_thm) =
      the (Symtab.lookup l2_infos name)
      |> FunctionInfo.function_info_upd_phase FunctionInfo.HL
      |> FunctionInfo.function_info_upd_const const
      |> FunctionInfo.function_info_upd_definition def
      |> FunctionInfo.function_info_upd_corres_thm corres_thm
      (* mono_thm is cleared here; it is attached afterwards for recursive groups. *)
      |> FunctionInfo.function_info_upd_mono_thm NONE;
  in
    (Symtab.map to_hl_info defined, lthy')
  end;
(* Convert every L2 function group to its heap-lifted form, in parallel. *)
val converted_groups =
  AutoCorresUtil.par_convert
    convert existing_l2_infos l2_results add_trace;
(* Sequence of (intermediate lthy, new function_infos) pairs, one per group.
 * The work is deferred inside FSeq.mk's thunk until the sequence is forced. *)
val def_results = FSeq.mk (fn _ =>
  if FSeq.null l2_results
  then NONE  (* nothing was translated, so there is no L2 lthy to start from *)
  else
    let
      (* The last L2 result holds the lthy containing all L2 definitions. *)
      val l2_lthy = fst (List.last (FSeq.list_of l2_results));
    in
      FSeq.uncons
        (AutoCorresUtil.define_funcs_sequence
           l2_lthy define existing_l2_infos existing_hl_infos converted_groups)
    end);
(* Final sequence of (lthy, HL function_infos) pairs, one per function group.
 * Each lthy is the earliest intermediate theory in which that group is
 * defined; monad_mono theorems are attached to the infos of recursive groups. *)
val results =
  def_results
  |> FSeq.map (fn (lthy, f_defs) =>
       let
         (* Whether the group is recursive is read from its first entry. *)
         val group_is_recursive =
           Symtab.dest f_defs |> hd |> snd |> FunctionInfo.is_function_recursive;
         (* monad_mono proofs are only produced for recursive groups. *)
         val mono_thms =
           if group_is_recursive
           then LocalVarExtract.l2_monad_mono lthy f_defs
           else Symtab.empty;
         fun add_mono name =
           FunctionInfo.function_info_upd_mono_thm (Symtab.lookup mono_thms name);
       in
         (lthy, Symtab.map add_mono f_defs)
       end);
in results end
end