lh-l4v/tools/autocorres/function_info.ML

422 lines
16 KiB
Standard ML

(*
* Copyright 2014, NICTA
*
* This software may be distributed and modified according to the terms of
* the BSD 2-Clause license. Note that NO WARRANTY is provided.
* See "LICENSE_BSD2.txt" for details.
*
* @TAG(NICTA_BSD)
*)
(*
* Information about functions in the program we are translating,
* and the call-graph between them.
* To support incremental translation, we store the function information
* for every intermediate phase as well.
*)
signature FUNCTION_INFO2 =
sig
(*** Basic data types ***)
(* List of AutoCorres phases, in the order the translation runs them. *)
datatype phase = CP (* Initial definition we get from the C parser *)
| L1 (* SimplConv *)
| L2 (* LocalVarExtract *)
| HL (* HeapLift *) (* TODO: rename to HeapAbstract *)
| WA (* WordAbstract *)
| TS (* TypeStrengthen *);
(* Printable name of a phase; the same text as its constructor (e.g. "L1"). *)
val string_of_phase : phase -> string;
(* Orders phases by their position in the pipeline:
 * CP < L1 < L2 < HL < WA < TS. *)
val phase_ord : phase * phase -> order;
structure Phasetab : TABLE; (* currently unused *)
(* Function info for a single phase. *)
type function_info = {
(* Name of the function. *)
name : string,
(* The translation phase for this definition. *)
phase : phase,
(* Constant for the function, which can be inserted as a call to the
* function. Unlike "raw_const", this includes any locale parameters
* required by the function. *)
const: term,
(* Raw constant for the function. Existence of this constant in another
* function's body indicates that that function calls this one. *)
raw_const: term,
(* Arguments of the function, in order, excluding measure variables. *)
args : (string * typ) list,
(* Return type of the function ("unit" is used for void). *)
return_type : typ,
(* Function calls. Mutually recursive calls go in rec_callees. *)
callees : symset,
rec_callees : symset,
(* Definition of the function. *)
definition : thm,
(* corres theorem for the function. (TrueI when phase = CP.) *)
corres_thm : thm,
(* monad_mono theorem for the function, if it is recursive. *)
mono_thm : thm option,
(* Is this function actually being translated, or are we just
* wrapping the SIMPL code? *)
is_simpl_wrapper : bool,
(* Is this function generated by AutoCorres as a placeholder for
* a function we didn't have the source code to? *)
invented_body : bool
};
(* Standard result sequence that is passed between stages *)
type phase_results = (local_theory * function_info Symtab.table) FSeq.fseq;
(* Functional setters: each returns a copy of the given record with exactly
 * one field replaced (except upd_const, see below). *)
val function_info_upd_name : string -> function_info -> function_info;
val function_info_upd_phase : phase -> function_info -> function_info;
(* Also updates raw_const, setting it to the head symbol (head_of) of the
 * new const. *)
val function_info_upd_const : term -> function_info -> function_info;
val function_info_upd_args : (string * typ) list -> function_info -> function_info;
val function_info_upd_return_type : typ -> function_info -> function_info;
val function_info_upd_callees : symset -> function_info -> function_info;
val function_info_upd_rec_callees : symset -> function_info -> function_info;
val function_info_upd_definition : thm -> function_info -> function_info;
val function_info_upd_mono_thm : thm option -> function_info -> function_info;
val function_info_upd_corres_thm : thm -> function_info -> function_info;
val function_info_upd_invented_body : bool -> function_info -> function_info;
val function_info_upd_is_simpl_wrapper : bool -> function_info -> function_info;
(* Convenience getters. *)
(* True iff the function has at least one entry in rec_callees. *)
val is_function_recursive : function_info -> bool;
(* Union of callees and rec_callees. *)
val all_callees : function_info -> symset;
(* Generate initial function_info (phase CP) from the C Parser's output for
 * the given file. Functions without a "_body_def" theorem get a placeholder
 * definition and invented_body = true. Callee fields are filled in via
 * calc_call_graph. *)
val init_function_info : Proof.context -> string -> function_info Symtab.table;
type call_graph_info = {
(* Topologically sorted function calls, in dependency order.
* Each sub-list represents one function or recursive function group. *)
topo_sorted_functions : symset list,
(* Table mapping raw_consts to functions. *)
const_to_function : string Termtab.table,
(* Table mapping each recursive function to its recursive function group.
* Non-recursive functions do not appear in the table. *)
recursive_group_of : symset Symtab.table
};
(* Calculate call-graph information.
* Also updates the callees and rec_callees entries of its inputs,
* which are assumed to have outdated callee info.
*
* Ideally, we'd also have a pre_function_info type that doesn't have
* outdated callees, but dealing with ML records is annoying. *)
val calc_call_graph : function_info Symtab.table -> call_graph_info * function_info Symtab.table;
(* Update callees and rec_callees for the given recursive function group,
* relative to a set of background functions (that must not contain the
* given group).
*
* Returns a sequence of function groups in a valid topological order.
* We return multiple groups because function calls can be removed by
* dead code elimination and other transformations, which could cause
* the original group to split. *)
val recalc_callees :
function_info Symtab.table -> (* background *)
function_info Symtab.table -> (* group *)
function_info Symtab.table list;
end;
structure FunctionInfo : FUNCTION_INFO2 =
struct

(* The AutoCorres phases, in the order they are run (see encode_phase). *)
datatype phase = CP | L1 | L2 | HL | WA | TS;

(* Printable name of a phase; identical to its constructor name. *)
fun string_of_phase CP = "CP"
  | string_of_phase L1 = "L1"
  | string_of_phase L2 = "L2"
  | string_of_phase HL = "HL"
  | string_of_phase WA = "WA"
  | string_of_phase TS = "TS";

(* Position of a phase in the translation pipeline; phase_ord is derived
 * from it, so earlier phases compare LESS than later ones. *)
fun encode_phase CP = 0
  | encode_phase L1 = 1
  | encode_phase L2 = 2
  | encode_phase HL = 3
  | encode_phase WA = 4
  | encode_phase TS = 5;

val phase_ord = int_ord o apply2 encode_phase;

structure Phasetab = Table(
  type key = phase
  val ord = phase_ord);

type function_info = {
  name : string,
  phase : phase,
  const : term,
  raw_const : term,
  args : (string * typ) list,
  return_type : typ,
  callees : symset,
  rec_callees : symset,
  definition : thm,
  mono_thm : thm option,
  corres_thm : thm,
  invented_body : bool,
  is_simpl_wrapper : bool
};

type phase_results = (local_theory * function_info Symtab.table) FSeq.fseq;

(* We use FunctionalRecordUpdate internally to define the setters.
 * The three helpers below must stay in sync: "from" takes the 13 fields in a
 * fixed order, "from'" takes them in exactly the reverse order, and "to"
 * hands them to its continuation in "from"'s order. When adding a field,
 * update all three (and the makeUpdate arity) in matching positions. *)
open FunctionalRecordUpdate;
local
  fun from name phase const raw_const args return_type callees rec_callees
           definition mono_thm corres_thm is_simpl_wrapper invented_body =
    { name = name,
      phase = phase,
      const = const,
      raw_const = raw_const,
      args = args,
      return_type = return_type,
      callees = callees,
      rec_callees = rec_callees,
      definition = definition,
      mono_thm = mono_thm,
      corres_thm = corres_thm,
      is_simpl_wrapper = is_simpl_wrapper,
      invented_body = invented_body };
  fun from' invented_body is_simpl_wrapper corres_thm mono_thm definition
            rec_callees callees return_type args raw_const const phase name =
    { name = name,
      phase = phase,
      const = const,
      raw_const = raw_const,
      args = args,
      return_type = return_type,
      callees = callees,
      rec_callees = rec_callees,
      definition = definition,
      mono_thm = mono_thm,
      corres_thm = corres_thm,
      is_simpl_wrapper = is_simpl_wrapper,
      invented_body = invented_body };
  fun to f { name,
             phase,
             const,
             raw_const,
             args,
             return_type,
             callees,
             rec_callees,
             definition,
             mono_thm,
             corres_thm,
             is_simpl_wrapper,
             invented_body } =
    f name phase const raw_const args return_type callees rec_callees
      definition mono_thm corres_thm is_simpl_wrapper invented_body;
  fun update x = makeUpdate13 (from, from', to) x;
in
  fun function_info_upd_name name pinfo = update pinfo (U#name name) $$;
  fun function_info_upd_phase phase pinfo = update pinfo (U#phase phase) $$;
  fun function_info_upd_const_ const pinfo = update pinfo (U#const const) $$;
  fun function_info_upd_raw_const_ raw_const pinfo = update pinfo (U#raw_const raw_const) $$;
  fun function_info_upd_args args pinfo = update pinfo (U#args args) $$;
  fun function_info_upd_return_type return_type pinfo = update pinfo (U#return_type return_type) $$;
  fun function_info_upd_callees callees pinfo = update pinfo (U#callees callees) $$;
  fun function_info_upd_rec_callees rec_callees pinfo = update pinfo (U#rec_callees rec_callees) $$;
  fun function_info_upd_definition definition pinfo = update pinfo (U#definition definition) $$;
  fun function_info_upd_mono_thm mono_thm pinfo = update pinfo (U#mono_thm mono_thm) $$;
  fun function_info_upd_corres_thm corres_thm pinfo = update pinfo (U#corres_thm corres_thm) $$;
  fun function_info_upd_is_simpl_wrapper is_simpl_wrapper pinfo = update pinfo (U#is_simpl_wrapper is_simpl_wrapper) $$;
  fun function_info_upd_invented_body invented_body pinfo = update pinfo (U#invented_body invented_body) $$;
  (* Keep raw_const consistent with const: it is the head symbol of const. *)
  fun function_info_upd_const t = function_info_upd_const_ t o function_info_upd_raw_const_ (head_of t);
end;

(* A function is recursive iff it has at least one recursive callee. *)
fun is_function_recursive { rec_callees, ... } = not (Symset.is_empty rec_callees);

(* Every function called by this one, recursive or not. *)
fun all_callees { rec_callees, callees, ... } = Symset.union rec_callees callees;

type call_graph_info = {
  topo_sorted_functions : symset list,
  const_to_function : string Termtab.table,
  recursive_group_of : symset Symtab.table
};

(* Names of the functions called by fn_info, found by looking up every atomic
 * subterm of the RHS of its definition in const_table (raw_const -> name).
 * Duplicates are removed, keeping the first occurrence. SIMPL wrappers are
 * treated as calling nothing, since their bodies are not translated.
 * Shared by calc_call_graph and recalc_callees so both scan bodies the same
 * way; concl_of is used so a definition with premises still yields its
 * equation's RHS. *)
fun scan_callees const_table fn_info =
  if #is_simpl_wrapper fn_info then [] else
    Term.fold_aterms (fn t => fn acc =>
        case Termtab.lookup const_table t of
            SOME f => f :: acc
          | NONE => acc)
      (#definition fn_info |> Thm.concl_of |> Utils.rhs_of_eq) []
    |> distinct (op =);

fun calc_call_graph fn_infos = let
  (* Map each function's raw constant back to the function's name. *)
  val const_to_function =
    Symtab.dest fn_infos
    |> map (fn (name, info) => (#raw_const info, name))
    |> Termtab.make;

  (* Call graph of all functions: direct callees of each function. *)
  val fn_callees_lists = fn_infos |> Symtab.map (K (scan_callees const_to_function));

  (* Add each function to its own callees to get a complete inverse
   * (so that every function appears as a key of the caller table). *)
  val fn_callers_lists = flip_symtab (Symtab.map cons fn_callees_lists);

  val topo_sorted_functions =
    Topo_Sort.topo_sort {
      cmp = String.compare,
      graph = Symtab.lookup fn_callees_lists #> the,
      converse = Symtab.lookup fn_callers_lists #> the
    } (Symtab.keys fn_callees_lists |> sort String.compare)
    |> map Symset.make;

  val fn_callees = Symtab.map (K Symset.make) fn_callees_lists;

  (* A singleton group is recursive only if the function calls itself. *)
  fun is_recursive_singleton f =
    Symset.contains (Utils.the' ("is_recursive_singleton: " ^ f)
                       (Symtab.lookup fn_callees f)) f;

  (* Map each member of a recursive group to the whole group. *)
  val recursive_group_of =
    topo_sorted_functions
    |> maps (fn f_group =>
         (* Exclude non-recursive functions *)
         if Symset.card f_group = 1 andalso not (is_recursive_singleton (hd (Symset.dest f_group)))
         then []
         else map (rpair f_group) (Symset.dest f_group))
    |> Symtab.make;

  (* Now update callee info: calls into the function's own recursive group
   * are rec_callees, everything else is a plain callee. *)
  val fn_infos' =
    fn_infos |> Symtab.map (fn f => let
      val group = the_default Symset.empty (Symtab.lookup recursive_group_of f);
      val (rec_callees, callees) =
        Symset.dest (Utils.the' ("not in fn_callees: " ^ f) (Symtab.lookup fn_callees f))
        |> List.partition (Symset.contains group);
    in function_info_upd_callees (Symset.make callees) o
       function_info_upd_rec_callees (Symset.make rec_callees) end);
in ({ topo_sorted_functions = topo_sorted_functions,
      const_to_function = const_to_function,
      recursive_group_of = recursive_group_of
    }, fn_infos')
end;

fun recalc_callees base_infos fn_infos = let
  (* Raw constants of the background functions, mapped to their names. *)
  val base_consts =
    Symtab.dest base_infos
    |> map (fn (f, info) => (#raw_const info, f))
    |> Termtab.make;

  (* restrict_fn_infos has the correct call graph,
   * but omits functions outside fn_infos *)
  val (call_graph, restrict_fn_infos) = calc_call_graph fn_infos;

  fun update_info f info = let
    val restrict_info =
      Utils.the' ("recalc_callees: not in restricted call graph: " ^ f)
        (Symtab.lookup restrict_fn_infos f);
    (* Update calls into base_info *)
    val base_callees' = Symset.make (scan_callees base_consts info);
    (* base_infos should not include fn_infos, otherwise this weird call would exist *)
    val () = assert (not (Symset.contains base_callees' f))
               ("FunctionInfo.recalc_callees: background functions must not include " ^ f);
    (* rec_callees has been recalculated *)
    val rec_callees' = #rec_callees restrict_info;
    (* Some rec_callees may have become callees due to breaking recursive loops *)
    val callees' = Symset.union base_callees' (#callees restrict_info);
  in info
     |> function_info_upd_rec_callees rec_callees'
     |> function_info_upd_callees callees'
  end;
in #topo_sorted_functions call_graph
   |> map (fn group =>
        Symset.dest group
        |> map (fn f =>
             (f, update_info f (Utils.the' ("recalc_callees: unknown function " ^ f)
                                  (Symtab.lookup fn_infos f))))
        |> Symtab.make)
end;

fun init_function_info ctxt filename = let
  val thy = Proof_Context.theory_of ctxt;
  val prog_info = ProgramInfo.get_prog_info ctxt filename;
  val csenv = #csenv prog_info;

  (* Get information about a single function. *)
  fun gen_fn_info name (return_ctype, _, carg_list) = let
    (* Convert C Parser return type into a HOL return type. *)
    val return_type =
      if return_ctype = Absyn.Void then
        @{typ unit}
      else
        CalculateState.ctype_to_typ (thy, return_ctype);

    (* Convert arguments into a list of (name, HOL type) pairs. *)
    val arg_list = map (fn v =>
        (ProgramAnalysis.get_mname v |> MString.dest,
         CalculateState.ctype_to_typ (thy, ProgramAnalysis.get_vi_type v))
      ) carg_list;

    (*
     * Get constant, type signature and definition of the function.
     *
     * The definition may not exist if the function is declared "extern", but
     * never defined. In this case, we replace the body of the function with
     * what amounts to a "fail" command. Any C body is a valid refinement of
     * this, allowing our abstraction to succeed.
     *)
    val const = Utils.get_term ctxt (name ^ "_'proc");
    val myvars_typ = #state_type prog_info;
    val (definition, invented) =
      (Proof_Context.get_thm ctxt (name ^ "_body_def"), false)
      handle ERROR _ =>
        (Thm.instantiate ([((("'a", 0), ["HOL.type"]), Thm.ctyp_of ctxt myvars_typ)], [])
           @{thm undefined_function_body_def}, true);
  in {
    name = name,
    phase = CP,
    args = arg_list,
    return_type = return_type,
    const = const,
    raw_const = const,
    callees = Symset.empty, (* filled in later *)
    rec_callees = Symset.empty,
    definition = definition,
    mono_thm = NONE,
    corres_thm = @{thm TrueI},
    is_simpl_wrapper = false,
    invented_body = invented
  }
  end

  val raw_infos = ProgramAnalysis.get_fninfo csenv
                  |> Symtab.dest
                  |> map (uncurry gen_fn_info);

  (* We discard the call graph info here.
   * After calling init_function_info, we often want to change some of the entries,
   * which usually requires recalculating it anyway. *)
  val (_, fn_infos) =
    calc_call_graph (Symtab.make (map (fn info => (#name info, info)) raw_infos));
in
  fn_infos
end;

end; (* structure FunctionInfo *)