236 lines
8.2 KiB
Standard ML
236 lines
8.2 KiB
Standard ML
(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)
|
|
|
|
(*
 * Code to manage converting between L2_monad and other monad types.
 *
 * TypeStrengthen provides a higher-level interface for converting entire programs.
 *)
|
|
|
|
structure Monad_Convert = struct
|
|
|
|
(* Utilities. *)

(*
 * Place [sep] between every pair of adjacent list elements, e.g.
 *   intersperse 0 [1, 2, 3] = [1, 0, 2, 0, 3]
 * Empty and singleton lists are returned unchanged.
 *)
fun intersperse _ [] = []
  | intersperse sep (first :: rest) =
      first :: List.concat (map (fn y => [sep, y]) rest)
|
|
|
|
(* Unwrap an option, raising the supplied exception [exc] when it is NONE. *)
fun theE opt exc =
  (case opt of
    SOME x => x
  | NONE => raise exc)
|
|
|
|
(* Return the head of a list, raising the supplied exception [exc] when it is empty. *)
fun oneE xs exc =
  (case xs of
    x :: _ => x
  | [] => raise exc)
|
|
|
|
|
|
|
|
(* From Find_Theorems *)

(*
 * Apply [tm] to one dummy pattern per leading lambda binder (each typed like
 * the binder it replaces), beta-reduce, and then run
 * Term.replace_dummy_patterns starting at index 1. That presumably turns the
 * dummies into schematic variables, so the result can serve as a match
 * pattern; confirm against Term.replace_dummy_patterns.
 *)
fun apply_dummies tm =
  let
    val binder_types = map snd (fst (Term.strip_abs tm));
    val applied = Term.betapplys (tm, map Term.dummy_pattern binder_types);
  in
    fst (Term.replace_dummy_patterns applied 1)
  end;
|
|
|
|
(*
 * Turn the user-supplied string [nm] into a term pattern.
 *
 * If [nm] names a constant that is an abbreviation, the pattern is the
 * abbreviation's (expanded) right-hand side applied to dummy arguments.
 * Otherwise [nm] is read directly as a term pattern.
 *)
fun parse_pattern ctxt nm =
  let
    val consts = Proof_Context.consts_of ctxt;
    (* Prefer the constant the parser resolves; fall back to name interning. *)
    val const_name =
      (case Syntax.parse_term ctxt nm of
        Const (c, _) => c
      | _ => Consts.intern consts nm);
  in
    (case try (Consts.the_abbreviation consts) const_name of
      NONE => Proof_Context.read_term_pattern ctxt nm
    | SOME (_, rhs) => apply_dummies (Proof_Context.expand_abbrevs ctxt rhs))
  end;
|
|
|
|
(* Breadth-first term search *)

(*
 * term_search_bf cont pred prune term
 *
 * Visit the subterms of [term] breadth-first, using a FIFO queue of
 * (vars, subterm) entries.
 *
 *  - If [pred t] holds, call [cont (vars, t) resume]. Here [vars] lists the
 *    Free names introduced on the path down to [t], and [resume ()] continues
 *    the search with the remaining queue. [cont] decides whether to resume.
 *  - Otherwise, if [prune t] holds, [t] is not explored further.
 *  - Otherwise an abstraction is opened by substituting a Free variable for
 *    its bound variable. The bound name is primed until it is not in [vars].
 *    An application enqueues its function part before its argument.
 *
 * Returns unit once the queue is exhausted (or whatever [cont] returns when it
 * declines to resume).
 *
 * NOTE(review): fresh names are only checked against [vars], not against Free
 * variables that already occur in [term].
 *)
fun term_search_bf cont pred prune =
  let
    fun fresh_var used name =
      if member (op =) used name then fresh_var used (name ^ "'") else name

    (* Examine one dequeued entry, then carry on with the pending queue. *)
    fun search ((bound, t), pending) =
      if pred t then
        cont (bound, t) (fn () => walk pending)
      else if prune t then
        walk pending
      else
        (case t of
          Abs (name, typ, _) =>
            let
              val name' = fresh_var bound name
            in
              walk (Queue.enqueue (name' :: bound, betapply (t, Free (name', typ))) pending)
            end
        | f $ x =>
            walk (pending |> Queue.enqueue (bound, f) |> Queue.enqueue (bound, x))
        | _ => walk pending)

    and walk pending =
      if Queue.is_empty pending then () else search (Queue.dequeue pending)
  in
    fn t => search (([], t), Queue.empty)
  end
|
|
|
|
(*
 * Return SOME (vars, subterm) for the first subterm of [term] (in
 * breadth-first order) satisfying [pred], or NONE if there is none.
 * The continuation records the match and never calls its resume thunk,
 * so the search stops at the first hit.
 *)
fun term_search_bf_first pred prune term =
  let
    val found = Unsynchronized.ref NONE
    fun record result _ = found := SOME result
  in
    term_search_bf record pred prune term;
    !found
  end
|
|
|
|
(* From Pure/Tools/find_theorems.ML, because Florian made it private *)

(*
 * True iff [pat] matches [obj] or any of its subterms. Abstractions are
 * opened with Term.dest_abs using Name.bound names indexed by binder depth.
 *)
fun matches_subterm thy (pat, obj) =
  let
    fun matches_at depth t =
      Pattern.matches thy (pat, t) orelse
        (case t of
          Abs (_, T, body) =>
            matches_at (depth + 1) (snd (Term.dest_abs (Name.bound depth, T, body)))
        | f $ x => matches_at depth f orelse matches_at depth x
        | _ => false)
  in
    matches_at 0 obj
  end;
|
|
|
|
(*
 * grep_term ctxt pattern : term -> (string list * term) option
 *
 * Breadth-first search for the first subterm matching [pattern]. Subterms
 * that contain no match anywhere inside them are pruned.
 *)
fun grep_term ctxt pattern =
  let
    val thy = Proof_Context.theory_of ctxt
    fun is_match t = Pattern.matches thy (pattern, t)
    fun contains_no_match t = not (matches_subterm thy (pattern, t))
  in
    term_search_bf_first is_match contains_no_match
  end
|
|
|
|
(*
 * Check whether the term is in L2_monad notation, by handing the L2_monad
 * constructors below to Monad_Types.check_lifting_head (presumably a check
 * on the term's head constant; see Monad_Types).
 *)
val term_is_L2 =
  let
    val l2_constructors =
      [@{term "L2_unknown"},
       @{term "L2_seq"},
       @{term "L2_modify"},
       @{term "L2_gets"},
       @{term "L2_condition"},
       @{term "L2_catch"},
       @{term "L2_while"},
       @{term "L2_throw"},
       @{term "L2_spec"},
       @{term "L2_guard"},
       @{term "L2_fail"},
       @{term "L2_recguard"},
       @{term "L2_call"}]
  in
    Monad_Types.check_lifting_head l2_constructors
  end
|
|
|
|
(*
 * Perform monad conversion on a term, taking into account any extra
 * simplifying facts. Only a successful conversion is returned.
 *
 * [forward] = true  : rewrite with the monad type's lift rules; success means
 *                     the result satisfies the monad type's #valid_term.
 * [forward] = false : rewrite with its unlift rules; success means the result
 *                     satisfies term_is_L2.
 *
 * Returns SOME (term == result) on success, NONE otherwise.
 *
 * For this conversion to be useful on recursive programs, it needs
 * to be given a fact representing the inductive assumption.
 *)
fun monad_rewrite (lthy : Proof.context) (mt : Monad_Types.monad_type)
    (more_facts : thm list) (forward : bool)
    (term : term) : thm option =
  let
    val ctxt = Utils.set_hidden_ctxt lthy
    val rules = Monad_Types.monad_type_rules mt
    val (conv_rules, accept) =
      if forward then (#lift_rules rules, #valid_term mt)
      else (#unlift_rules rules, term_is_L2)
    val ct = Thm.cterm_of ctxt term

    (* Just apply the simplifier, then check whether the result is acceptable. *)
    val simpset = put_simpset HOL_ss ctxt addsimps conv_rules addsimps more_facts
    val eqn = Simplifier.rewrite simpset ct
    val result = Utils.rhs_of (term_of_thm eqn)
  in
    if accept ctxt result then SOME eqn else NONE
  end
|
|
|
|
(*
 * Apply polish to a theorem of the form:
 *
 *   <LHS> == <lift> $ <some term to polish>
 *
 * The monad type's #polish_rules are always used. When [do_opt] is true, the
 * rules in the "polish" named_theorems collection are added as well.
 * Afterwards, "%x. case_prod s x" is eta-contracted to "case_prod s".
 *
 * Return the new theorem.
 *)
local
  val case_prod_eta_contract_thm =
    @{lemma "(%x. (case_prod s) x) == (case_prod s)" by simp}
in
fun polish ctxt (mt : Monad_Types.monad_type) do_opt thm =
  let
    val ctxt = Utils.set_hidden_ctxt ctxt
    val extra_simps =
      if do_opt then Utils.get_rules ctxt @{named_theorems polish} else []
    val polish_rules = #polish_rules (Monad_Types.monad_type_rules mt)

    (* Simplify using the polish rules. *)
    val simp_conv =
      Simplifier.rewrite (put_simpset HOL_ss ctxt addsimps (polish_rules @ extra_simps))

    (* Eta-contract case_prod clauses, so that they render as
     * "%(a, b). P a b" instead of "case x of (a, b) => P a b". *)
    val eta_conv = Conv.try_conv (Conv.rewr_conv case_prod_eta_contract_thm)
    val case_prod_conv = Conv.bottom_conv (K eta_conv) ctxt

    val rhs_conv = Utils.rhs_conv (simp_conv then_conv case_prod_conv)
  in
    Conv.fconv_rule (Conv.arg_conv rhs_conv) thm
  end
end
|
|
|
|
(*
 * Wrap a tactic that doesn't handle invalid subgoal numbers so that it
 * returns "Seq.empty" (via no_tac) when the subgoal number is invalid:
 * either below 1 or greater than the number of premises of the goal state.
 * Otherwise the wrapped tactic is applied to subgoal [n] as usual.
 *)
fun handle_invalid_subgoals (tac : int -> tactic) n =
  fn thm =>
    (* Subgoals are numbered from 1, so a non-positive index is never valid.
     * Without this check, index 0 would slip past the premise count and reach
     * the wrapped tactic. *)
    if n < 1 orelse Logic.count_prems (term_of_thm thm) < n then
      no_tac thm
    else
      tac n thm
|
|
|
|
(*
 * monad_convert tactic.
 *
 * In subgoal [n], find the first (breadth-first) subterm matching
 * [pattern_str]. Determine which registered monad type it belongs to,
 * unlift it back to L2_monad, lift it into the monad named [monad_name], and
 * polish the result (using the extra polish rules if [do_opt]). Finally,
 * substitute the resulting equation into the subgoal with EqSubst.
 *
 * Failures are reported by raising ERROR/TERM with a "monad_convert:" message
 * rather than by returning an empty sequence.
 *)
fun monad_convert_tac (ctxt : Proof.context) (monad_name : string) (do_opt : bool)
    (pattern_str : string) (n : int) : tactic =
  fn state =>
    let
      val monad_types = Monad_Types.TSRules.get (Context.Proof ctxt)

      (* The monad type we are converting into. *)
      val target_rule =
        theE (Symtab.lookup monad_types monad_name)
          (ERROR ("monad_convert: could not find monad type " ^ quote monad_name))

      (* Locate the matching subterm; m_vars are the Free names the search
       * introduced while opening binders on the way down to it. *)
      val pattern = parse_pattern ctxt pattern_str
      val subgoal = Logic.get_goal (term_of_thm state) n
      val (m_vars, m_term) =
        theE (grep_term ctxt pattern subgoal)
          (TERM ("monad_convert: failed to match pattern", [pattern]))

      (* Pick the first monad type whose #valid_term accepts m_term, instead of
       * trying every unlift rule. *)
      val source_rule =
        oneE (filter (fn mt => #valid_term mt ctxt m_term) (map snd (Symtab.dest monad_types)))
          (TERM ("monad_convert: could not determine monad type", [m_term]))

      (* m_term == unlift_term, where unlift_term is in L2_monad form. *)
      val unlift_thm =
        theE (monad_rewrite ctxt source_rule [] false m_term)
          (TERM ("monad_convert: could not unlift term (rule: " ^
                 #name source_rule ^ ")", [m_term]))
      val unlift_term = Utils.rhs_of (term_of_thm unlift_thm)

      (* unlift_term == <term in the target monad>, then polished. *)
      val relift_thm =
        theE (monad_rewrite ctxt target_rule [] true unlift_term)
          (TERM ("monad_convert: could not lift to " ^ #name target_rule,
                 [m_term, unlift_term]))
      val polished_thm = polish ctxt target_rule do_opt relift_thm

      val translate_thm = Thm.transitive unlift_thm polished_thm

      (* Re-prove the equation with m_vars fixed, which makes those variables
       * schematic in the resulting theorem. *)
      val schematic_thm =
        Goal.prove ctxt (sort_distinct string_ord m_vars) [] (term_of_thm translate_thm)
          (K (resolve_tac ctxt [translate_thm] 1))

      val result = EqSubst.eqsubst_tac ctxt [0] [schematic_thm] n state
    in
      (* Force the first result so that a failed substitution raises immediately. *)
      (case Seq.pull result of
        NONE =>
          raise TERM ("monad_convert: failed to apply conversion",
            [term_of_thm schematic_thm, subgoal])
      | SOME (x, xs) => Seq.cons x xs)
    end
|
|
|
|
(*
 * Register the "monad_convert" proof method:
 *   monad_convert [goal_spec] <monad name> <term pattern>
 * It runs monad_convert_tac with polishing enabled, guarded by
 * handle_invalid_subgoals.
 *)
val _ =
  let
    (* Based on subgoal_tac parser *)
    val args_parser =
      Args.goal_spec -- Scan.lift (Parse.name -- Args.embedded_inner_syntax)

    fun monad_convert_method (quant, (monad_name, term_str)) ctxt =
      SIMPLE_METHOD'' quant
        (handle_invalid_subgoals (monad_convert_tac ctxt monad_name true term_str))
  in
    Context.>> (Context.map_theory
      (Method.setup (Binding.name "monad_convert")
        (args_parser >> monad_convert_method)
        "autocorres monad conversion"))
  end
|
|
|
|
end
|