lh-l4v/tools/autocorres/heap_lift.ML

896 lines
41 KiB
Standard ML

(*
* Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
*
* SPDX-License-Identifier: BSD-2-Clause
*)
(*
* Rewrite L2 specifications to use a higher-level ("lifted") heap representation.
*
* The main interface to this module is translate (and inner functions
* convert and define). See AutoCorresUtil for a conceptual overview.
*)
structure HeapLift =
struct
(* Convenience shortcuts: short local names for helpers from Utils. Note that
 * "warning" is rebound here, so later uses of "warning" in this structure go
 * through Utils.ac_warning. *)
val warning = Utils.ac_warning
val apply_tac = Utils.apply_tac
val the' = Utils.the'
(* Debugging tactic: print the current goal state labelled with s, then abort
 * the proof attempt by raising ProofFailed s. It never produces a result. *)
exception ProofFailed of string
fun fail_tac ctxt s =
  let
    fun abort _ = raise ProofFailed s
  in
    print_tac ctxt s THEN abort
  end
(* Local abbreviation for the heap-lifting metadata record from HeapLiftBase. *)
type heap_info = HeapLiftBase.heap_info
(* Return the heap getter constant for objects of type T.
 * If T has no lifted heap, T is passed to Utils.invalid_typ
 * (presumably raising an exception; not visible here). *)
fun get_heap_getter (heap_info : heap_info) T =
  (case Typtab.lookup (#heap_getters heap_info) T of
     NONE => Utils.invalid_typ "heap type for getter" T
   | SOME getter_const => Const getter_const)
(* Return the heap setter (update) constant for objects of type T.
 * If T has no lifted heap, T is passed to Utils.invalid_typ
 * (presumably raising an exception; not visible here). *)
fun get_heap_setter (heap_info : heap_info) T =
  (case Typtab.lookup (#heap_setters heap_info) T of
     NONE => Utils.invalid_typ "heap type for setter" T
   | SOME setter_const => Const setter_const)
(* Return the constant that tells whether a pointer is valid for type T.
 * If T has no lifted heap, T is passed to Utils.invalid_typ
 * (presumably raising an exception; not visible here). *)
fun get_heap_valid_getter (heap_info : heap_info) T =
  (case Typtab.lookup (#heap_valid_getters heap_info) T of
     NONE => Utils.invalid_typ "heap type for valid getter" T
   | SOME valid_const => Const valid_const)
(* Return the constant that updates pointer validity for type T.
 * If T has no lifted heap, T is passed to Utils.invalid_typ
 * (presumably raising an exception; not visible here). *)
fun get_heap_valid_setter (heap_info : heap_info) T =
  (case Typtab.lookup (#heap_valid_setters heap_info) T of
     NONE => Utils.invalid_typ "heap type for valid setter" T
   | SOME valid_update_const => Const valid_update_const)
(* Return the state (globals) type that fn_name operates on: the lifted
 * globals type if the function is being heap-lifted, otherwise the original
 * globals type. *)
fun get_expected_fn_state_type (heap_info : heap_info) is_function_lifted fn_name =
  if not (is_function_lifted fn_name)
  then #old_globals_type heap_info
  else #globals_type heap_info
(* Return the state-translation term used in L2Tcorres for fn_name.
 * Heap-lifted functions use the full lifting function.
 * All other functions use the identity on the original globals type. *)
fun get_expected_st (heap_info : heap_info) is_function_lifted fn_name =
  if not (is_function_lifted fn_name)
  then @{mk_term "id :: ?'a => ?'a" ('a)} (#old_globals_type heap_info)
  else #lift_fn_full heap_info
(* Compute the expected type of the heap-lifted version of fn_name:
 *   nat (measure) => <L2 argument types> => <L2 monad type>
 * The monad is built by LocalVarExtract.mk_l2monadT from:
 *   - the state type chosen by get_expected_fn_state_type,
 *   - the L2 function's return type,
 *   - @{typ unit}.
 * prog_info is not used. Raises Option if fn_name is not in l2_infos. *)
fun get_expected_hl_fn_type prog_info l2_infos (heap_info : HeapLiftBase.heap_info)
    is_function_lifted fn_name =
  let
    val l2_info = the (Symtab.lookup l2_infos fn_name)
    val state_typ = get_expected_fn_state_type heap_info is_function_lifted fn_name
    val arg_typs = @{typ "nat"} :: map snd (#args l2_info)
    val monad_typ =
      LocalVarExtract.mk_l2monadT state_typ (#return_type l2_info) @{typ unit}
  in
    arg_typs ---> monad_typ
  end
(* Build the statement expected to hold for fn_name:
 *   Trueprop (L2Tcorres st (function_free measure args) (l2_const measure args))
 * Here st comes from get_expected_st and l2_const is the L2 constant recorded
 * in l2_infos. prog_info, ctxt and the unnamed argument are not used. *)
fun get_expected_hl_fn_thm prog_info l2_infos (heap_info : HeapLiftBase.heap_info)
    is_function_lifted ctxt fn_name function_free fn_args _ measure_var =
  let
    val applied_args = measure_var :: fn_args
    val l2_const = #const (the (Symtab.lookup l2_infos fn_name))
    val concrete_side = betapplys (l2_const, applied_args)
    val abstract_side = betapplys (function_free, applied_args)
    val st = get_expected_st heap_info is_function_lifted fn_name
  in
    @{mk_term "Trueprop (L2Tcorres ?st ?A ?C)" (st, A, C)}
      (st, abstract_side, concrete_side)
  end
(* Return the argument list (pairs of name and type) recorded for fn_name in
 * its L2 function info. prog_info is not used.
 * Raises Option if fn_name is not in l2_infos. *)
fun get_expected_hl_fn_args prog_info l2_infos fn_name =
  let
    val l2_info = the (Symtab.lookup l2_infos fn_name)
  in
    #args l2_info
  end
(*
 * Guess whether a function can be lifted.
 *
 * The function is considered liftable unless the right-hand side of its
 * definition mentions one of the heap-type-data constants listed below
 * (hrs_htd, hrs_htd_update, ptr_retyp). For example, we probably can't lift
 * functions that introspect "hrs_htd".
 * lthy and prog_info are not used.
 *)
fun can_lift_function lthy prog_info fn_info =
  let
    val body = Thm.prop_of (#definition fn_info) |> Utils.rhs_of
    val forbidden_consts =
      [@{const_name hrs_htd}, @{const_name hrs_htd_update}, @{const_name ptr_retyp}]
    fun is_forbidden (name, _) = member (op =) forbidden_consts name
  in
    not (exists_Const is_forbidden body)
  end
(*
 * Convert a cterm from the format "f a (b n) c" into "((f $ a) $ (b $ n)) $ c".
 *
 * Every application is rewritten bottom-up with the lemma "a b == (a $ b)".
 * Subterms for which this fails (e.g. non-applications) are left unchanged.
 *
 * Return a conversion "thm" of the form "old == new".
 *)
fun mk_first_order ctxt ct =
  let
    val app_to_dollar = @{lemma "a b == (a $ b)" by simp}
    fun expand app =
      Utils.named_cterm_instantiate ctxt
        [("a", Thm.dest_fun app), ("b", Thm.dest_arg app)] app_to_dollar
  in
    Conv.bottom_conv (K (Conv.try_conv expand)) ctxt ct
  end
(* The opposite of "mk_first_order": rewrite every "($)" bottom-up back into
 * ordinary application using "($) == (%a b. a b)". Subterms where the rewrite
 * does not apply are left unchanged. *)
fun dest_first_order ctxt ct =
  let
    val dollar_to_app =
      @{lemma "($) == (%a b. a b)" by (rule meta_ext, rule ext, simp)}
    val rewrite_once = Conv.try_conv (Conv.rewr_conv dollar_to_app)
  in
    Conv.bottom_conv (K rewrite_once) ctxt ct
  end
(*
 * Resolve "subgoal_thm" against every premise of "base_thm" where it is
 * possible to do so.
 *
 * A premise is only attempted when it could unify with the conclusion of
 * subgoal_thm. A resolution that fails (THM) leaves the theorem unchanged.
 *
 * Premises are visited from the last to the first. Resolving premise i
 * replaces it with the premises of subgoal_thm (often none), which renumbers
 * every premise after i. Working backwards keeps the indices of the premises
 * not yet visited valid.
 *
 * Return a tuple: (<new thm>, <a change was made>).
 *)
fun greedy_thm_instantiate base_thm subgoal_thm =
  let
    val asms = Thm.prop_of base_thm |> Logic.strip_assums_hyp
    val subgoal_concl = Thm.concl_of subgoal_thm
    fun try_premise (i, asm) (thm, change_made) =
      if Term.could_unify (asm, subgoal_concl) then
        ((subgoal_thm RSN (i, thm), true) handle THM _ => (thm, change_made))
      else
        (thm, change_made)
  in
    (* Descending premise order: see the comment above. *)
    fold try_premise (rev (tag_list 1 asms)) (base_thm, false)
  end
(* Resolve "base_thm" against each of "subgoal_thms" in turn (see
 * greedy_thm_instantiate). Keep only the results where at least one premise
 * was actually discharged, in the order of subgoal_thms. *)
fun instantiate_against_thms base_thm subgoal_thms =
  subgoal_thms |> map_filter (fn subgoal =>
    (case greedy_thm_instantiate base_thm subgoal of
       (thm, true) => SOME thm
     | (_, false) => NONE))
(*
 * Modify a list of thms to instantiate assumptions wherever possible.
 *
 * The lists in subgoal_thm_lists are processed in order. For each list, every
 * current theorem is replaced by:
 *   - its successful instantiations against that list, followed by
 *   - the theorem itself.
 * The un-instantiated originals are therefore always kept.
 *)
fun cross_instantiate base_thms subgoal_thm_lists =
  let
    fun expand_with subgoal_thms thms =
      maps (fn thm => instantiate_against_thms thm subgoal_thms @ [thm]) thms
  in
    fold expand_with subgoal_thm_lists base_thms
  end
(*
 * EXPERIMENTAL: define wrappers and syntax for common heap operations.
 * We use the notations "s[p]->r" for {p->r} and "s[p->r := q]" for {p->r = q}
 * (the arrow is the symbol \<rightarrow> in the actual syntax).
 * For non-fields, "s[p]" and "s[p := q]".
 * The wrappers are named like "get_type_field" and "update_type_field".
 *
 * Known issues:
 * * Every pair of getter/setter and valid/setter lemmas should be generated.
 *   If you find yourself expanding one of the wrapper definitions, then
 *   something wasn't generated correctly.
 *
 * * On that note, lemmas relating structs and struct fields
 *   (foo vs foo.field) are not being generated.
 *   TODO: this problem appears in Suzuki.thy
 *
 * * The syntax looks as terrible as c-parser's. Well, at least you won't need
 *   to subscript greek letters.
 *
 * * Isabelle doesn't like overloaded syntax. Issue VER-412
 *)
(* Raised and caught by field_syntax when a struct type has no heap getter or
 * no heap setter. *)
exception NO_GETTER_SETTER
(* Build an Isabelle Mixfix annotation from a template string, argument
 * priorities ps and result priority p, with no source position. *)
fun mixfix (sy, ps, p) = Mixfix (Input.string sy, ps, p, Position.no_range)
(*
 * Define getter/setter wrappers and syntax for one struct field.
 *
 * For a struct S_C with field f_C (the "_C" suffixes are stripped) this defines:
 *   get_S_f s ptr        == <field getter> (<heap getter of S_C> s ptr)
 *   update_S_f s ptr val == <heap setter of S_C>
 *                             (%old. old(ptr := <field setter> (%_. val) (old ptr))) s
 * It also adds the notations "s[ptr]\<rightarrow>f" and "s[ptr\<rightarrow>f := val]".
 *
 * Both constants are recorded in new_getters / new_setters, keyed by constant
 * name, as (struct name, field name, constant, SOME definition).
 *
 * If the struct type has no heap getter or heap setter, the inputs are
 * returned unchanged.
 *)
fun field_syntax (heap_info : HeapLiftBase.heap_info)
                 (struct_info : HeapLiftBase.struct_info)
                 (field_info: HeapLiftBase.field_info)
                 (new_getters, new_setters, lthy) =
  let
    fun unsuffix' suffix str = if String.isSuffix suffix str then unsuffix suffix str else str
    val struct_pname = unsuffix' "_C" (#name struct_info)
    val field_pname = unsuffix' "_C" (#name field_info)
    val struct_typ = #struct_type struct_info
    val state_var = ("s", #globals_type heap_info)
    val ptr_var = ("ptr", Type (@{type_name "ptr"}, [struct_typ]))
    val val_var = ("val", #field_type field_info)
    val struct_getter = (case Typtab.lookup (#heap_getters heap_info) struct_typ of
        SOME getter => Const getter
      | _ => raise NO_GETTER_SETTER)
    val struct_setter = (case Typtab.lookup (#heap_setters heap_info) struct_typ of
        SOME setter => Const setter
      | _ => raise NO_GETTER_SETTER)
    (* We will modify lthy soon, so may not exit with NO_GETTER_SETTER after this point *)

    (* Define field accessor function *)
    val field_getter_term = @{mk_term "?field_get (?heap_get s ptr)" (heap_get, field_get)}
      (struct_getter, #getter field_info)
    val new_heap_get_name = "get_" ^ struct_pname ^ "_" ^ field_pname
    val (new_heap_get, new_heap_get_thm, lthy) =
      Utils.define_const_args new_heap_get_name false field_getter_term
        [state_var, ptr_var] lthy

    (* Define field update function *)
    val field_setter_term = @{mk_term "?heap_update (%old. old(ptr := ?field_update (%_. val) (old ptr))) s"
      (heap_update, field_update)} (struct_setter, #setter field_info)
    val new_heap_update_name = "update_" ^ struct_pname ^ "_" ^ field_pname
    val (new_heap_update, new_heap_update_thm, lthy) =
      Utils.define_const_args new_heap_update_name false field_setter_term
        [state_var, ptr_var, val_var] lthy

    (* Notation: "s[p]\<rightarrow>f" for the getter, "s[p\<rightarrow>f := v]" for the setter. *)
    val getter_mixfix = mixfix ("_[_]\<rightarrow>" ^ (Syntax_Ext.escape field_pname), [1000], 1000)
    val setter_mixfix = mixfix ("_[_\<rightarrow>" ^ (Syntax_Ext.escape field_pname) ^ " := _]", [1000], 1000)
    val lthy = Specification.notation true Syntax.mode_default [
      (new_heap_get, getter_mixfix),
      (new_heap_update, setter_mixfix)] lthy

    (* The struct_pname returned here must match the type_pname returned in heap_syntax.
     * new_heap_update_thm relies on this to determine what kind of thm to generate. *)
    val new_getters = Symtab.update_new (new_heap_get_name,
      (struct_pname, field_pname, new_heap_get, SOME new_heap_get_thm)) new_getters
    val new_setters = Symtab.update_new (new_heap_update_name,
      (struct_pname, field_pname, new_heap_update, SOME new_heap_update_thm)) new_setters
  in
    (new_getters, new_setters, lthy)
  end
  handle NO_GETTER_SETTER => (new_getters, new_setters, lthy)
(*
 * Define syntax for one C heap type. This also creates a new wrapper for
 * whole-object heap updates:
 *   update_<T> s ptr val == <heap setter> (%old. old(ptr := val)) s
 * It attaches the notation "s[ptr]" to the existing heap getter and
 * "s[ptr := val]" to the new update wrapper.
 *
 * Both constants are recorded in new_getters / new_setters (keyed by constant
 * name) with an empty field name. new_heap_update_thm uses the empty name to
 * recognise whole-object accessors.
 *
 * Raises TYPE if heap_type has no heap getter or no heap setter.
 *)
fun heap_syntax (heap_info : HeapLiftBase.heap_info)
                (heap_type : typ)
                (new_getters, new_setters, lthy) =
  let
    val getter = (case Typtab.lookup (#heap_getters heap_info) heap_type of
        SOME x => x
      | NONE => raise TYPE ("heap_lift/heap_syntax: no getter", [heap_type], []))
    val setter = (case Typtab.lookup (#heap_setters heap_info) heap_type of
        SOME x => x
      | NONE => raise TYPE ("heap_lift/heap_syntax: no setter", [heap_type], []))
    (* Strip the C parser's "_C" suffix from each "_"-separated component of the
     * type name, e.g. "foo_C" -> "foo" and "foo_C_ptr" -> "foo_ptr".
     * "_C" is only removed when it ends the name or is followed by "_". A
     * component that merely starts with "C" is kept intact: "my_Cell_C" becomes
     * "my_Cell", not "myell".
     * This keeps type_pname consistent with the suffix-only stripping in
     * field_syntax. That assumes HeapLiftBase.name_from_type returns the
     * C-parser struct name for struct types; TODO confirm. *)
    fun replace_C (#"_" :: #"C" :: []) = []
      | replace_C (#"_" :: #"C" :: (rest as #"_" :: _)) = replace_C rest
      | replace_C (x :: xs) = x :: replace_C xs
      | replace_C [] = []
    val type_pname = HeapLiftBase.name_from_type heap_type
      |> String.explode |> replace_C |> String.implode
    val state_var = ("s", #globals_type heap_info)
    val heap_ptr_type = Type (@{type_name "ptr"}, [heap_type])
    val ptr_var = ("ptr", heap_ptr_type)
    val val_var = ("val", heap_type)
    val setter_def = @{mk_term "?heap_update (%old. old(ptr := val)) s" heap_update} (Const setter)
    val new_heap_update_name = "update_" ^ type_pname
    val (new_heap_update, new_heap_update_thm, lthy) =
      Utils.define_const_args new_heap_update_name false setter_def
        [state_var, ptr_var, val_var] lthy
    val getter_mixfix = mixfix ("_[_]", [1000], 1000)
    val setter_mixfix = mixfix ("_[_ := _]", [1000], 1000)
    val lthy = Specification.notation true Syntax.mode_default
      [(Const getter, getter_mixfix), (new_heap_update, setter_mixfix)] lthy
    val new_getters = Symtab.update_new (Long_Name.base_name (fst getter), (type_pname, "", Const getter, NONE)) new_getters
    val new_setters = Symtab.update_new (new_heap_update_name, (type_pname, "", new_heap_update, SOME new_heap_update_thm)) new_setters
  in
    (new_getters, new_setters, lthy)
  end
(* Make all heap syntax and collect the results. The steps run in this order:
 *   1. the per-field wrappers of every struct (field_syntax);
 *   2. the whole-object wrappers of every heap type (heap_syntax).
 * Returns (new_getters, new_setters, lthy). *)
fun make_heap_syntax (heap_info : HeapLiftBase.heap_info) lthy =
  let
    fun add_struct_fields (_, struct_info) acc =
      fold (field_syntax heap_info struct_info) (#field_info struct_info) acc
    val with_fields =
      Symtab.fold add_struct_fields (#structs heap_info) (Symtab.empty, Symtab.empty, lthy)
  in
    fold (heap_syntax heap_info) (Typtab.keys (#heap_getters heap_info)) with_fields
  end
(*
 * Prove a lemma about reading through a wrapper getter after writing through a
 * wrapper setter.
 *
 * Each argument tuple has the form (type name, field name, constant,
 * optional definition), as produced by field_syntax / heap_syntax. An empty
 * field name denotes a whole-object accessor.
 *
 * Lemmas produced:
 *   same type and same field:  get (set s p v) = (get s)(p := v)   (functional update)
 *   otherwise:                 get (set s p v) = get s             (separation)
 *
 * Returns NONE (no lemma) when both act on the same type but exactly one of
 * them is a whole-object accessor.
 *)
fun new_heap_update_thm (getter_type_name, getter_field_name, getter, getter_def)
                        (setter_type_name, setter_field_name, setter, setter_def)
                        lthy =
  let
    val same_type = getter_type_name = setter_type_name
    val getter_is_whole = getter_field_name = ""
    val setter_is_whole = setter_field_name = ""
  in
    (* TODO: also generate lemmas relating whole-struct updates to field updates *)
    if same_type andalso getter_is_whole <> setter_is_whole then NONE
    else
      let
        val lhs = @{mk_term "?get (?set s p v)" (get, set)} (getter, setter)
        val rhs =
          if same_type andalso getter_field_name = setter_field_name
          then @{mk_term "(?get s) (p := v)" (get)} getter (* functional update *)
          else @{mk_term "?get s" (get)} getter (* separation *)
        val prop = @{mk_term "Trueprop (?lhs = ?rhs)" (lhs, rhs)} (lhs, rhs)
        val defs = the_list getter_def @ the_list setter_def
      in
        SOME (Goal.prove_future lthy ["s", "p", "v"] [] prop
          (fn params => simp_tac ((#context params) addsimps
            @{thms ext fun_upd_apply} @ defs) 1))
      end
  end
(* Prove that a wrapper setter leaves pointer validity unchanged:
 *   valid (set s p v) q = valid s q
 * valid_term is the (name, type) of a heap-valid getter constant.
 * Setters without a definition (4th component NONE) yield NONE. *)
fun new_heap_valid_thm valid_term (_, _, setter, setter_def_opt) lthy =
  (case setter_def_opt of
     NONE => NONE
   | SOME setter_def =>
       let
         val prop = @{mk_term "Trueprop (?valid (?set s p v) q = ?valid s q)" (valid, set)}
           (Const valid_term, setter)
       in
         SOME (Goal.prove_future lthy ["s", "p", "v", "q"] [] prop
           (fn params => simp_tac ((#context params) addsimps
             [@{thm fun_upd_apply}, setter_def]) 1))
       end)
(* Take a definition whose RHS is applied to the schematic variable ?s and
 * eta-contract it:
 *   lhs == rhs ?s   gives   (%s. lhs) == rhs
 * This allows us to rewrite a heap update even if the state is eta-contracted
 * away. Raises Bind if thm does not have that shape. *)
fun eta_rhs lthy thm =
  let
    val Const (@{const_name "Pure.eq"}, _) $ lhs $ (rhs $ (state as Var (("s", _), _))) =
      term_of_thm thm
    val goal = @{mk_term "?a == ?b" (a, b)} (lambda state lhs, rhs)
  in
    Goal.prove_future lthy [] [] goal
      (fn params => simp_tac (put_simpset HOL_basic_ss (#context params)
        addsimps thm :: @{thms atomize_eq ext}) 1)
  end
(* Extract the abstract term (the middle argument) from an L2Tcorres term.
 * Only terms of the shape "L2Tcorres _ ?t _" match; anything else (including
 * a Trueprop-wrapped proposition) raises Match. *)
fun dest_L2Tcorres_term_abs @{term_pat "L2Tcorres _ ?t _"} = t
(*
 * Generate lifted_globals lemmas and instantiate them into the heap lifting rules.
 *
 * Returns three theorem lists, in this order (translate feeds them to
 * cross_instantiate):
 *   1. valid_typ_heap theorems: one per heap type, plus any instances derived
 *      via signed_valid_typ_heaps;
 *   2. valid_struct_field / valid_struct_field_legacy theorems for the fields
 *      of struct heap types;
 *   3. valid_globals_field theorems for globals embedded directly in the
 *      globals / lifted_globals records.
 * Every proof is started with Goal.prove_future (see the FIXME below).
 * prog_info supplies the t_hrs getter/setter.
 *)
fun lifted_globals_lemmas prog_info heap_info lthy = let
(* Tactic to solve subgoals below. *)
local
(* Fetch simp rules generated by the C Parser about structures. *)
val struct_simpset = UMM_Proof_Theorems.get (Proof_Context.theory_of lthy)
(* Missing entries are treated as empty rule lists. *)
fun lookup_the t k = case Symtab.lookup t k of SOME x => x | NONE => []
val struct_simps =
(lookup_the struct_simpset "typ_name_simps")
@ (lookup_the struct_simpset "typ_name_itself")
@ (lookup_the struct_simpset "fl_ti_simps")
@ (lookup_the struct_simpset "fl_simps")
@ (lookup_the struct_simpset "fg_cons_simps")
val base_ss = simpset_of @{theory_context HeapLift}
val record_ss = RecordUtils.get_record_simpset lthy
val merged_ss = merge_ss (base_ss, record_ss)
(* Generate a simpset containing everything we need. *)
val ss =
(Context_Position.set_visible false lthy)
|> put_simpset merged_ss
|> (fn ctxt => ctxt
addsimps [#lift_fn_thm heap_info]
@ @{thms typ_simple_heap_simps}
@ @{thms valid_globals_field_def}
@ @{thms the_fun_upd_lemmas}
@ struct_simps)
|> simpset_of
in
(* Solve the first subgoal in one of two ways:
 *   - fast_force with the simpset above, or
 *   - apply conjI / ext introduction, then clarsimp, provided this changes
 *     the goal state (CHANGED). *)
fun subgoal_solver_tac ctxt =
(fast_force_tac (put_simpset ss ctxt) 1)
ORELSE (CHANGED (Method.try_intros_tac ctxt [@{thm conjI}, @{thm ext}] []
THEN clarsimp_tac (put_simpset ss ctxt) 1))
end
(* Generate "valid_typ_heap" predicates for each heap type we have. *)
fun mk_valid_typ_heap_thm typ =
@{mk_term "Trueprop (valid_typ_heap ?st ?getter ?setter ?valid_getter ?valid_setter ?t_hrs ?t_hrs_update)"
(st, getter, setter, valid_getter, valid_setter, t_hrs, t_hrs_update)}
(#lift_fn_full heap_info,
get_heap_getter heap_info typ,
get_heap_setter heap_info typ,
get_heap_valid_getter heap_info typ,
get_heap_valid_setter heap_info typ,
#t_hrs_getter prog_info,
#t_hrs_setter prog_info)
|> (fn prop => Goal.prove_future lthy [] [] prop
(fn params =>
((resolve_tac lthy @{thms valid_typ_heapI} 1) THEN (
REPEAT (subgoal_solver_tac (#context params))))))
(* Make thms for all types. *)
(* FIXME: these are currently auto-parallelised using prove_future, but perhaps
* we should exercise finer control over the evaluation, as prove_futures
* persist long after the AutoCorres command actually returns. *)
val heap_types = (#heap_getters heap_info |> Typtab.dest |> map fst)
val valid_typ_heap_thms = map mk_valid_typ_heap_thm heap_types
(* Generate "valid_typ_heap" thms for signed words. *)
(* Each signed_valid_typ_heaps rule is combined (OF) with each theorem above;
 * combinations that fail are dropped. *)
val valid_typ_heap_thms =
valid_typ_heap_thms
@ (map_product
(fn a => fn b => try (fn _ => a OF [b]) ())
@{thms signed_valid_typ_heaps}
valid_typ_heap_thms
|> map_filter I)
(* Generate "valid_struct_field" for each field of each struct. *)
(* Returns [] when the statement does not typecheck (see HACK below);
 * the typ argument is not used. *)
fun mk_valid_struct_field_thm struct_name typ (field_info : HeapLiftBase.field_info) =
@{mk_term "Trueprop (valid_struct_field ?st [?fname] ?fgetter ?fsetter ?t_hrs ?t_hrs_update)"
(st, fname, fgetter, fsetter, t_hrs, t_hrs_update) }
(#lift_fn_full heap_info,
Utils.encode_isa_string (#name field_info),
#getter field_info,
#setter field_info,
#t_hrs_getter prog_info,
#t_hrs_setter prog_info)
|> (fn prop =>
(* HACK: valid_struct_field currently works only for packed types,
* so typecheck the prop first *)
case try (Syntax.check_term lthy) prop of
SOME _ =>
[Goal.prove_future lthy [] [] prop
(fn params =>
(resolve_tac lthy @{thms valid_struct_fieldI} 1) THEN
(* Need some extra thms from the records package for our struct type. *)
(* NOTE(review): these two steps use lthy rather than #context params. *)
(EqSubst.eqsubst_tac lthy [0]
[hd (Proof_Context.get_thms lthy (struct_name ^ "_idupdates")) RS @{thm sym}] 1
THEN asm_full_simp_tac lthy 1) THEN
(FIRST (Proof_Context.get_thms lthy (struct_name ^ "_fold_congs")
|> map (fn t => resolve_tac lthy [t OF @{thms refl refl}] 1))
THEN asm_full_simp_tac lthy 1) THEN
(REPEAT (subgoal_solver_tac (#context params))))]
| NONE => [])
(* Generate "valid_struct_field_legacy" for each field of each struct. *)
fun mk_valid_struct_field_legacy_thm typ (field_info : HeapLiftBase.field_info) =
@{mk_term "Trueprop (valid_struct_field_legacy ?st [?fname] ?fgetter (%v. ?fsetter (%_. v)) ?getter ?setter ?valid_getter ?valid_setter ?t_hrs ?t_hrs_update)"
(st, fname, fgetter, fsetter, getter, setter, valid_getter, valid_setter, t_hrs, t_hrs_update) }
(#lift_fn_full heap_info,
Utils.encode_isa_string (#name field_info),
#getter field_info,
#setter field_info,
get_heap_getter heap_info typ,
get_heap_setter heap_info typ,
get_heap_valid_getter heap_info typ,
get_heap_valid_setter heap_info typ,
#t_hrs_getter prog_info,
#t_hrs_setter prog_info)
|> (fn prop => Goal.prove_future lthy [] [] prop
(fn params =>
(resolve_tac lthy @{thms valid_struct_field_legacyI} 1) THEN (
REPEAT (subgoal_solver_tac (#context params)))))
(* Make thms for all fields of structs in our heap. *)
(* Non-struct heap types contribute no theorems. *)
fun valid_struct_abs_thms T =
case (Typtab.lookup (#struct_types heap_info) T) of
NONE => []
| SOME struct_info =>
map (fn field =>
mk_valid_struct_field_thm (#name struct_info) T field
@ [mk_valid_struct_field_legacy_thm T field])
(#field_info struct_info)
|> List.concat
val valid_field_thms =
map valid_struct_abs_thms heap_types |> List.concat
(* Generate conversions from globals embedded directly in the "globals" and
* "lifted_globals" record. *)
(* The global_field_getters / global_field_setters entries are pairs whose
 * first component is used as the old accessor and second as the new one. *)
fun mk_valid_globals_field_thm name =
@{mk_term "Trueprop (valid_globals_field ?st ?old_get ?old_set ?new_get ?new_set)"
(st, old_get, old_set, new_get, new_set)}
(#lift_fn_full heap_info,
Symtab.lookup (#global_field_getters heap_info) name |> the |> fst,
Symtab.lookup (#global_field_setters heap_info) name |> the |> fst,
Symtab.lookup (#global_field_getters heap_info) name |> the |> snd,
Symtab.lookup (#global_field_setters heap_info) name |> the |> snd)
|> (fn prop => Goal.prove_future lthy [] [] prop (fn params => subgoal_solver_tac (#context params)))
val valid_global_field_thms = map (#1 #> mk_valid_globals_field_thm) (#global_fields heap_info)
(* At this point, the lemmas are ready to be instantiated into the generic
* heap_abs rules (which will be fetched from the most recent lthy). *)
in
[ valid_typ_heap_thms, valid_field_thms, valid_global_field_thms ]
end;
(*
 * Prepare for the heap lifting phase.
 * We need to:
 * - define a lifted_globals type
 * - prove generic heap lifting lemmas for the lifted_globals type
 * - define heap syntax and rewrite rules (if heap_abs_syntax is set)
 * - store these new results into the HeapInfo theory data
 * Note that because we are adding definitions that are required by all
 * conversions, we need to wait for all previous L2 conversions to finish,
 * limiting parallelism somewhat. This requires us to modify l2_results by
 * updating its intermediate lthys.
 *
 * These results are cached in the local theory, so we attempt to fetch an
 * existing definition (in the case that we are resuming a previous run).
 * In this scenario, we don't have to modify l2_results.
 *
 * Returns the (possibly updated) l2_results together with the heap-lift setup
 * for filename.
 *)
fun prepare_heap_lift
(filename : string)
(prog_info : ProgramInfo.prog_info)
(l2_results : (local_theory * FunctionInfo.function_info Symtab.table) FSeq.fseq)
(* An initial lthy, used to check for an existing heap_lift_setup.
* Also used as fallback target in the unlikely case where l2_results = [] *)
(lthy0 : local_theory)
(* We define the lifted heap for all functions in the program, even if they are
* not included in this translation. This allows heap lifting to work with
* incremental translations. *)
(all_simpl_infos : FunctionInfo.function_info Symtab.table)
(* Settings *)
(make_lifted_globals_field_name : string -> string)
(gen_word_heaps : bool)
(heap_abs_syntax : bool)
: ((local_theory * FunctionInfo.function_info Symtab.table) FSeq.fseq
* HeapLiftBase.heap_lift_setup) =
let
(* Get target lthy for adding new definitions.
* This is the most recent l2_results lthy, except if there are no results,
* in which case the fallback lthy is used. *)
fun get_target_lthy l2_results fallback_lthy =
if FSeq.null l2_results then Option.getOpt (fallback_lthy, lthy0)
else fst (List.last (FSeq.list_of l2_results));
(* Replace the lthy of every L2 result with the given one. *)
fun update_results lthy = FSeq.map (apfst (K lthy));
(* Set up heap_info and associated lemmas *)
(* Stage 1: reuse a setup cached in HeapInfo for this file if there is one.
 * Otherwise build heap_info and the lifted_globals lemmas, cache them, and
 * point every L2 result at the new lthy. fallback_lthy is the newest lthy
 * produced here (NONE if nothing was produced). *)
val (l2_results, HL_setup, fallback_lthy) =
case Symtab.lookup (HeapInfo.get (Proof_Context.theory_of lthy0)) filename of
SOME HL_setup => (l2_results, HL_setup, NONE)
| NONE => let
val lthy = get_target_lthy l2_results NONE;
val (heap_info, lthy) = HeapLiftBase.setup prog_info all_simpl_infos
make_lifted_globals_field_name gen_word_heaps lthy;
val lifted_heap_lemmas = lifted_globals_lemmas prog_info heap_info lthy;
val HL_setup = { heap_info = heap_info,
lifted_heap_lemmas = lifted_heap_lemmas,
heap_syntax_rewrs = [] };
val lthy = Local_Theory.background_theory (
HeapInfo.map (fn tbl => Symtab.update (filename, HL_setup) tbl)) lthy;
in (update_results lthy l2_results, HL_setup, SOME lthy) end;
(* Do some extra lifting and create syntax (see field_syntax comment).
* We do this separately because heap_abs_syntax could be enabled halfway
* through an incremental translation. *)
(* Stage 2: skipped unless heap_abs_syntax is set and no syntax rewrites have
 * been generated yet. Otherwise define the wrapper constants and syntax, prove
 * their simp rules (registered as "heap_abs_simps"), and store the reversed
 * wrapper definitions as heap_syntax_rewrs in the cached setup. *)
val (l2_results, HL_setup, fallback_lthy) =
if not heap_abs_syntax orelse not (null (#heap_syntax_rewrs HL_setup))
then (l2_results, HL_setup, fallback_lthy)
else
let val lthy = get_target_lthy l2_results fallback_lthy;
val (heap_syntax_rewrs, lthy) =
Utils.exec_background_result (fn lthy => let
val optcat = List.mapPartial I
val heap_info = #heap_info HL_setup
(* Define the new heap operations and their syntax. *)
val (new_getters, new_setters, lthy) =
make_heap_syntax heap_info lthy
(* Make simplification thms and add them to the simpset. *)
(* One attempted lemma per (getter, setter) pair and per (valid getter, setter)
 * pair; NONE results are discarded by optcat. *)
val update_thms = map (fn get => map (fn set => new_heap_update_thm get set lthy)
(Symtab.dest new_setters |> map snd))
(Symtab.dest new_getters |> map snd)
|> List.concat
val valid_thms = map (fn valid => map (fn set => new_heap_valid_thm valid set lthy)
(Symtab.dest new_setters |> map snd))
(Typtab.dest (#heap_valid_getters heap_info) |> map snd)
|> List.concat
val thms = update_thms @ valid_thms |> optcat
val lthy = Utils.simp_add thms lthy
(* Name the thms. (FIXME: do this elsewhere?) *)
val (_, lthy) = Utils.define_lemmas "heap_abs_simps" thms lthy
(* Rewrite rules for converting the program. *)
(* Wrapper definitions are flipped (symmetric) so that they fold raw heap
 * operations into the wrappers; setter definitions are eta-contracted first. *)
val getter_thms = Symtab.dest new_getters |> map (#4 o snd) |> optcat
val setter_thms = Symtab.dest new_setters |> map (#4 o snd) |> optcat
val eta_setter_thms = map (eta_rhs lthy) setter_thms
val rewrite_thms = map (fn thm => @{thm symmetric} OF [thm])
(getter_thms @ eta_setter_thms)
in (rewrite_thms, lthy) end) lthy;
val HL_setup = { heap_info = #heap_info HL_setup,
lifted_heap_lemmas = #lifted_heap_lemmas HL_setup,
heap_syntax_rewrs = heap_syntax_rewrs };
val lthy = Local_Theory.background_theory (
HeapInfo.map (fn tbl => Symtab.update (filename, HL_setup) tbl)) lthy;
in (update_results lthy l2_results, HL_setup, SOME lthy) end;
in (l2_results, HL_setup) end;
(* Convert a program to use a lifted heap. *)
fun translate
(filename : string)
(prog_info : ProgramInfo.prog_info)
(l2_results : FunctionInfo.phase_results)
(existing_l2_infos : FunctionInfo.function_info Symtab.table)
(existing_hl_infos : FunctionInfo.function_info Symtab.table)
(HL_setup : HeapLiftBase.heap_lift_setup)
(no_heap_abs : Symset.key Symset.set)
(force_heap_abs : Symset.key Symset.set)
(heap_abs_syntax : bool)
(keep_going : bool)
(trace_funcs : string list)
(do_opt : bool)
(trace_opt : bool)
(add_trace: string -> string -> AutoCorresData.Trace -> unit)
(hl_function_name : string -> string)
: FunctionInfo.phase_results =
if FSeq.null l2_results then FSeq.empty () else
let
(* lthy for conversion rules. This needs to be (at latest) the earliest lthy
* result so that the rules can be used in all conversions *)
val lthy0 = fst (FSeq.hd l2_results);
val heap_info = #heap_info HL_setup;
(*
* Fetch rules from the theory, instantiating any rule with the
* lifted_globals lemmas for "valid_globals_field", "valid_typ_heap" etc.
* that we generated previously.
*)
val base_rules = Utils.get_rules lthy0 @{named_theorems heap_abs}
val rules =
cross_instantiate base_rules (#lifted_heap_lemmas HL_setup)
(* Remove rules that haven't been fully instantiated *)
|> filter_out (Thm.prop_of #> exists_subterm (fn x =>
case x of Const (@{const_name "valid_globals_field"}, _) => true
| Const (@{const_name "valid_struct_field"}, _) => true
| Const (@{const_name "valid_struct_field_legacy"}, _) => true
| Const (@{const_name "valid_typ_heap"}, _) => true
| _ => false));
(* We only use this blanket rule for non-lifted functions;
* liftable expressions can be handled by specific struct_rewrite rules *)
val nolift_rules = @{thms struct_rewrite_expr_id}
(* This does a linear search. We will only need it in is_function_lifted, though *)
fun lookup_l2_results f_name =
FSeq.find (fn (_, l2_infos) => Symtab.defined l2_infos f_name) l2_results
|> the' ("HL: missing L2 results for " ^ f_name);
fun is_function_lifted f_name =
case Symtab.lookup existing_hl_infos f_name of
SOME info => let
(* We heap-lifted this function earlier. Check its state type. *)
val body = #definition info |> Thm.prop_of |> Utils.rhs_of_eq;
val stT = LocalVarExtract.l2monad_state_type body;
in stT = #globals_type heap_info end
| NONE => let
val (lthy, l2_infos) = lookup_l2_results f_name;
val can_lift =
if can_lift_function lthy prog_info (the (Symtab.lookup l2_infos f_name))
then not (Symset.contains no_heap_abs f_name)
else if Symset.contains force_heap_abs f_name
then true
else (* Report functions that we're not lifting,
* but not if the user has overridden explicitly *)
(if Symset.contains no_heap_abs f_name then () else
writeln ("HL: disabling heap lift for: " ^ f_name ^
" (use force_heap_abs to enable)");
false);
in can_lift end;
(* Cache answers for which functions we are lifting. *)
val is_function_lifted = String_Memo.memo is_function_lifted;
(* Convert to new heap format. *)
fun convert lthy l2_infos f: AutoCorresUtil.convert_result =
let
val f_l2_info = the (Symtab.lookup l2_infos f);
(* Fix measure variable. *)
val ([measure_var_name], lthy) = Variable.variant_fixes ["rec_measure'"] lthy;
val measure_var = Free (measure_var_name, AutoCorresUtil.measureT);
(* Add callee assumptions. *)
val (lthy, export_thm, callee_terms) =
AutoCorresUtil.assume_called_functions_corres lthy
(#callees f_l2_info) (#rec_callees f_l2_info)
(get_expected_hl_fn_type prog_info l2_infos heap_info is_function_lifted)
(get_expected_hl_fn_thm prog_info l2_infos heap_info is_function_lifted)
(get_expected_hl_fn_args prog_info l2_infos)
hl_function_name
measure_var;
(* Fix argument variables. *)
val new_fn_args = get_expected_hl_fn_args prog_info l2_infos f;
val (arg_names, lthy) = Variable.variant_fixes (map fst new_fn_args) lthy;
val arg_frees = map Free (arg_names ~~ map snd new_fn_args);
(* Fetch the function definition. *)
val l2_body_def =
#definition f_l2_info
(* Instantiate the arguments. *)
|> Utils.inst_args lthy (map (Thm.cterm_of lthy) (measure_var :: arg_frees))
(* Get L2 body definition with function arguments. *)
val l2_term = betapplys (#const f_l2_info, measure_var :: arg_frees)
(* Get our state translation function. *)
val st = get_expected_st heap_info is_function_lifted f
(* Generate a schematic goal. *)
val goal = @{mk_term "Trueprop (L2Tcorres ?st ?A ?C)" (st, C)}
(st, l2_term)
|> Thm.cterm_of lthy
|> Goal.init
|> Utils.apply_tac "unfold RHS" (EqSubst.eqsubst_tac lthy [0] [l2_body_def] 1)
val callee_mono_thms =
callee_terms |> map fst
|> List.mapPartial (fn callee =>
if FunctionInfo.is_function_recursive (the (Symtab.lookup l2_infos callee))
then #mono_thm (the (Symtab.lookup l2_infos callee))
else NONE)
val rules = rules @ (map (snd #> #3) callee_terms) @ callee_mono_thms
val rules = if is_function_lifted f then rules else rules @ nolift_rules
val fo_rules = Utils.get_rules lthy @{named_theorems heap_abs_fo}
(* Apply a conversion to the concrete side of the given L2T term.
* By convention, the concrete side is the last argument (index ~1). *)
fun l2t_conc_body_conv conv =
Conv.params_conv (~1) (K (Conv.arg_conv (Utils.nth_arg_conv (~1) conv)))
(* Standard tactics. *)
val print_debug = f = ""
fun rtac_all r n = (APPEND_LIST (map (fn thm =>
resolve_tac lthy [thm] n THEN (fn x =>
(if print_debug then @{trace} thm else ();
Seq.succeed x))) r))
(* Convert the concrete side of the given L2T term to/from first-order form. *)
val l2t_to_fo_tac = CONVERSION (Drule.beta_eta_conversion then_conv l2t_conc_body_conv (mk_first_order lthy) lthy)
val l2t_from_fo_tac = CONVERSION (l2t_conc_body_conv (dest_first_order lthy then_conv Drule.beta_eta_conversion) lthy)
val fo_tac = ((l2t_to_fo_tac THEN' rtac_all fo_rules) THEN_ALL_NEW l2t_from_fo_tac) 1
(*
* Recursively solve subgoals.
*
* We allow backtracking in order to solve a particular subgoal, but once a
* subgoal is completed we don't ever try to solve it in a different way.
*
* This allows us to try different approaches to solving subgoals without
* leading to exponential explosion (of many different combinations of
* "good solutions") once we hit an unsolvable subgoal.
*)
val tactics =
if #is_simpl_wrapper f_l2_info
then (* Solver for trivial simpl wrappers. *)
[(@{thm L2Tcorres_id}, resolve_tac lthy [@{thm L2Tcorres_id}] 1)]
else map (fn rule => (rule, resolve_tac lthy [rule] 1)) rules
@ [(@{thm fun_app_def}, fo_tac)]
val replay_failure_start = 1
val replay_failures = Unsynchronized.ref replay_failure_start
val (thm, trace) =
case AutoCorresTrace.maybe_trace_solve_tac lthy (member (op =) trace_funcs f)
true false (K tactics) goal NONE replay_failures of
NONE => (* intentionally generate a TRACE_SOLVE_TAC_FAIL *)
(AutoCorresTrace.trace_solve_tac lthy false false (K tactics) goal NONE (Unsynchronized.ref 0);
(* never reached *) error "heap_lift fail tac: impossible")
| SOME (thm, [trace]) => (Goal.finish lthy thm, trace)
val _ = if !replay_failures < replay_failure_start then
warning ("HL: " ^ f ^ ": reverted to slow replay " ^
Int.toString(replay_failure_start - !replay_failures) ^ " time(s)") else ()
(* DEBUG: make sure that all uses of field_lvalue and c_guard are rewritten.
* Also make sure that we cleaned up internal constants. *)
fun contains_const name = exists_subterm (fn x => case x of Const (n, _) => n = name | _ => false)
fun const_gone term name =
if not (contains_const name term) then ()
else Utils.TERM_non_critical keep_going
("Heap lift: could not remove " ^ name ^ " in " ^ f ^ ".") [term]
fun const_old_heap term name =
if not (contains_const name term) then ()
else warning ("Heap lift: could not remove " ^ name ^ " in " ^ f ^
". Output program may be unprovable.")
val _ = if not (is_function_lifted f) then []
else (map (const_gone (term_of_thm thm))
[@{const_name "heap_lift__h_val"}];
map (const_old_heap (term_of_thm thm))
[@{const_name "field_lvalue"}, @{const_name "c_guard"}]
)
(* Apply peephole optimisations to the theorem. *)
val _ = writeln ("Simplifying (HL) " ^ f)
val (thm, opt_traces) = L2Opt.cleanup_thm_tagged lthy thm (if do_opt then 0 else 2) 2 trace_opt "HL"
(* If we created extra heap wrappers, apply them now.
* Our simp rules don't seem to be enough for L2Opt,
* so we cannot change the program before that. *)
val thm = if not heap_abs_syntax then thm else
Raw_Simplifier.rewrite_rule lthy (#heap_syntax_rewrs HL_setup) thm
val f_body = dest_L2Tcorres_term_abs (HOLogic.dest_Trueprop (Thm.concl_of thm));
(* Get actual recursive callees *)
val rec_callees = AutoCorresUtil.get_rec_callees callee_terms f_body;
(* Return the constants that we fixed. This will be used to process the returned body. *)
val callee_consts =
callee_terms |> map (fn (callee, (_, const, _)) => (callee, const)) |> Symtab.make;
in
{ body = f_body,
proof = Morphism.thm export_thm thm,
rec_callees = rec_callees,
callee_consts = callee_consts,
arg_frees = map dest_Free (measure_var :: arg_frees),
traces = (if member (op =) trace_funcs f
then [("HL", AutoCorresData.RuleTrace trace)] else []) @ opt_traces
}
end
(* Define the HL constants for one previously-converted function group
 * (a single function or a recursive function group).
 * lthy must already include all definitions from hl_callees.
 * Returns the group's updated function_infos (phase set to HL, mono_thm
 * reset to NONE) paired with the extended local theory. *)
fun define
    (lthy: local_theory)
    (l2_infos: FunctionInfo.function_info Symtab.table)
    (hl_callees: FunctionInfo.function_info Symtab.table)
    (funcs: AutoCorresUtil.convert_result Symtab.table)
    : FunctionInfo.function_info Symtab.table * local_theory =
let
  (* Package every converted result as (name, (abstracted body, proof, arg frees)). *)
  val to_define =
    map (fn result as (name, {proof, arg_frees, ...}) =>
           (name, (AutoCorresUtil.abstract_fn_body l2_infos result, proof, arg_frees)))
        (Symtab.dest funcs);
  (* Correspondence theorems of the HL callees that are already defined. *)
  val callee_corres_thms = Symtab.map (K #corres_thm) hl_callees;
  val (defined, lthy') =
    AutoCorresUtil.define_funcs
      FunctionInfo.HL filename l2_infos hl_function_name
      (get_expected_hl_fn_type prog_info l2_infos heap_info is_function_lifted)
      (get_expected_hl_fn_thm prog_info l2_infos heap_info is_function_lifted)
      (get_expected_hl_fn_args prog_info l2_infos)
      @{thm L2Tcorres_recguard_0}
      lthy callee_corres_thms
      to_define;
  (* Start from the function's L2 info and overwrite the HL-specific fields. *)
  fun update_info name (const, def, corres_thm) =
    the (Symtab.lookup l2_infos name)
    |> FunctionInfo.function_info_upd_phase FunctionInfo.HL
    |> FunctionInfo.function_info_upd_const const
    |> FunctionInfo.function_info_upd_definition def
    |> FunctionInfo.function_info_upd_corres_thm corres_thm
    |> FunctionInfo.function_info_upd_mono_thm NONE (* added later *);
in
  (Symtab.map update_info defined, lthy')
end;
(* Run the conversion of all function groups through par_convert. *)
val converted_groups =
  AutoCorresUtil.par_convert convert existing_l2_infos l2_results add_trace;
(* Lazily produced sequence of (intermediate lthy, new function_infos) pairs,
 * one per converted function group. *)
val def_results = FSeq.mk (fn _ =>
  (* No L2 results means there is no lthy to start defining from. *)
  if FSeq.null l2_results then NONE
  else let
    (* Definitions start from the lthy left by the last L2 definition. *)
    val (start_lthy, _) = List.last (FSeq.list_of l2_results);
    val defined_groups =
      AutoCorresUtil.define_funcs_sequence
        start_lthy define existing_l2_infos existing_hl_infos converted_groups;
  in
    FSeq.uncons defined_groups
  end);
(* Produce, for each function group, its HL function_infos together with the
 * earliest intermediate lthy in which that group is defined, adding
 * monad_mono theorems for recursive groups. *)
val results =
  def_results
  |> FSeq.map (fn (lthy, f_defs) => let
      (* Recursiveness is read from the group's first member (as before).
       * An empty group is treated as non-recursive instead of raising
       * Empty from hd. *)
      val is_recursive_group =
        (case Symtab.dest f_defs of
           (_, info) :: _ => FunctionInfo.is_function_recursive info
         | [] => false);
      (* Add monad_mono proofs via LocalVarExtract.l2_monad_mono
       * (in practice these already run pretty quickly). *)
      val mono_thms =
        if is_recursive_group
        then LocalVarExtract.l2_monad_mono lthy f_defs
        else Symtab.empty;
      (* Functions without a mono theorem get NONE. *)
      val f_defs' = f_defs |> Symtab.map (fn f =>
        FunctionInfo.function_info_upd_mono_thm (Symtab.lookup mono_thms f));
    in (lthy, f_defs') end);
in results end
end