(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

theory ExpandAll
imports "~~/src/HOL/Main"
begin

(* If h factors as f o g with g surjective, then quantifying h over its
   domain is the same as quantifying f over the range of g. *)
lemma expand_forall:
  "\<lbrakk> \<And>x. f (g x) = h x; (\<forall>x. f x) = P; surj g \<rbrakk>
     \<Longrightarrow> (\<forall>x. h x) = P"
  apply (simp add: surj_def)
  apply metis
  done

(* The existential counterpart of expand_forall. *)
lemma expand_exists:
  "\<lbrakk> \<And>x. f (g x) = h x; (\<exists>x. f x) = P; surj g \<rbrakk>
     \<Longrightarrow> (\<exists>x. h x) = P"
  apply (simp add: surj_def)
  apply metis
  done

(* Peel one pair-split off a quantifier over a tuple: the quantifier over
   the pair becomes a quantifier over its first component, given an
   equation for the quantifier over the second component. *)
lemma expand_one_split:
  "\<lbrakk> \<And>a. (\<forall>b. P a b) = Q a \<rbrakk>
     \<Longrightarrow> (\<forall>v. (\<lambda>(a, b). P a b) v) = (\<forall>a. Q a)"
  "\<lbrakk> \<And>a. (\<exists>b. P a b) = Q a \<rbrakk>
     \<Longrightarrow> (\<exists>v. (\<lambda>(a, b). P a b) v) = (\<exists>a. Q a)"
  by simp+

ML {*
(* Quantifier expansion by patterns.

   Given patterns, e.g. fst and snd, rewrite
     \<forall>x. P (fst x) (snd x)     into   \<forall>x y. P x y
   This is more useful for things that are not actually tuples, e.g. rewrite
     \<forall>xs. P (xs ! 0) (xs ! 2)   into   \<forall>x y. P x y
   The same is done for \<exists>.

   The pats argument of expand_forall_pats is a function that finds such
   patterns: given the (freshly named) quantified variable x and the body
   instantiated at x, it returns the subterms of the body that are to be
   replaced by fresh variables. *)

(* lambda_tuple [y1, ..., yn] body builds the curried split abstraction
   \<lambda>(y1, ..., yn). body by nesting HOLogic.mk_split, one split per
   extra variable.  Raises TERM on an empty variable list. *)
fun lambda_tuple [x] body = lambda x body
  | lambda_tuple (x :: xs) body =
      HOLogic.mk_split (lambda x (lambda_tuple xs body))
  | lambda_tuple [] body = raise TERM ("lambda_tuple: empty", [body])

(* expand_forall_pats ctxt pats tac t

   t must have the form  \<forall>x. bdy x  or  \<exists>x. bdy x.  Proves an
   equation between t and the corresponding nested quantification over
   fresh variables y1 ... yn, where each yi replaces one of the subterms
   found by pats (see above).

   tac receives the list of patterns and must prove that
   \<lambda>x. (pat1, ..., patn)  is surjective.

   Raises TERM if t is neither \<forall> nor \<exists>, or if some occurrence of
   x in the body is not covered by a pattern.  Also raises (via Seq.hd) if
   the proof tactic produces no result. *)
fun expand_forall_pats ctxt pats tac t =
  let
    val (T, bdy, thm) = case t of
        (Const (@{const_name All}, T) $ bdy) => (T, bdy, @{thm expand_forall})
      | (Const (@{const_name Ex}, T) $ bdy) => (T, bdy, @{thm expand_exists})
      | _ => raise TERM ("expand_forall_pats: not All or Ex", [t])
    val thy = Proof_Context.theory_of ctxt

    (* Fresh Free for the bound variable.  T is the type of the quantifier
       constant, (a => bool) => bool, so its domain's domain is a. *)
    val x = Variable.variant_frees ctxt [bdy] [("x", domain_type (domain_type T))]
      |> the_single |> Free
    val bdy_x = betapply (bdy, x)

    (* Patterns, deduplicated; sort_distinct also fixes their order. *)
    val pat_xs = pats x bdy_x |> sort_distinct Term_Ord.fast_term_ord

    (* One fresh variable per pattern, avoiding names already in bdy_x. *)
    val ys = map (fn pat_x => ("y", fastype_of pat_x)) pat_xs
      |> Variable.variant_frees ctxt [bdy_x]
      |> map Free

    (* Replace each pattern by its variable.  Every occurrence of x must
       have been covered by some pattern, otherwise the expansion would not
       be an equivalence. *)
    val bdy_ys = Pattern.rewrite_term thy (pat_xs ~~ ys) [] bdy_x
    val _ =
      if exists_subterm (curry (op =) x) bdy_ys
      then raise TERM ("expand_forall_pats: not all converted",
                       [x] @ pat_xs @ ys)
      else ()
    val f = lambda_tuple ys bdy_ys
    val g = lambda x (HOLogic.mk_tuple pat_xs)

    (* A tuple of k components is taken apart by k - 1 splits. *)
    val numsplits = length pat_xs - 1
  in
    cterm_instantiate
      [(@{cpat "?f\<Colon>?'b \<Rightarrow> bool"}, cterm_of thy f),
       (@{cpat "?g\<Colon>?'a \<Rightarrow> ?'b"}, cterm_of thy g),
       (@{cpat "?h\<Colon>?'a \<Rightarrow> bool"}, cterm_of thy bdy)] thm
    |> EVERY (
         (* premise 1,  f (g x) = bdy x:  unfold each split, then refl *)
         replicate numsplits (rtac @{thm trans[OF split_conv]} 1)
         @ [rtac @{thm refl} 1]
         (* premise 2,  (Q x. f x) = ?P:  peel one split per step; the final
            refl instantiates ?P with the expanded quantification *)
         @ replicate numsplits (resolve_tac @{thms expand_one_split} 1)
         @ [rtac @{thm refl} 1,
            (* premise 3,  surj g:  delegated to the caller's tactic *)
            tac pat_xs])
    |> Seq.hd
  end

(* mk_expand_forall_simproc s T pats tac thy

   Builds a simproc named s that matches \<forall> and \<exists> over variables of
   type T and rewrites them with expand_forall_pats.  pats finds the
   patterns, and tac ctxt proves surjectivity.  Exceptions raised by the
   expansion are turned into NONE by try, so the simproc then simply does
   not fire. *)
fun mk_expand_forall_simproc s T pats tac thy =
  let
    val opT = (T --> HOLogic.boolT) --> HOLogic.boolT
    val P = Free ("P", T --> HOLogic.boolT)
  in
    Simplifier.simproc_i thy s
      [Const (@{const_name All}, opT) $ P, Const (@{const_name Ex}, opT) $ P]
      (fn _ => fn ss => try (let
          val ctxt = Simplifier.the_context ss
        in
          expand_forall_pats ctxt pats (tac ctxt) #> mk_meta_eq
        end))
  end
*}

ML {*
(* get_nths xs t: the subterms of t of the form  xs ! i  or  hd xs,  in
   left-to-right order, duplicates included.  It is meant to be used as the
   pats argument of expand_forall_pats for quantifiers over lists.

   An nth/hd applied to some other list is still searched recursively, so
   e.g.  ys ! (xs ! 0)  and  xs ! 0 ! 1  both yield  xs ! 0. *)
fun get_nths xs t =
  let
    fun is_pat (Const (@{const_name nth}, _) $ ys $ _) = xs aconv ys
      | is_pat (Const (@{const_name hd}, _) $ ys) = xs aconv ys
      | is_pat _ = false
  in
    if is_pat t then [t]
    else
      (case t of
        f $ u => get_nths xs f @ get_nths xs u
      | Abs (_, _, u) => get_nths xs u
      | _ => [])
  end
*}

(* Surjectivity of a function on lists follows from surjectivity of the
   same function on the first n elements, as generated by  map g [0 ..< n]. *)
lemma surj_via_mapI:
  "surj (\<lambda>g. f (map g [0 ..< n])) \<Longrightarrow> surj (\<lambda>xs. f xs)"
  by (auto simp add: surj_def)

(* Split the first component off a surjectivity goal on pairs by fixing
   the value of f at x. *)
lemma surj_tup_apply_eq:
  "surj (\<lambda>f. (f x, g f)) = (\<forall>v. surj (\<lambda>f. g (f (x := v))))"
  apply (simp add: surj_def)
  apply (rule arg_cong[where f=All, OF ext])+
  apply safe
   apply (metis fun_upd_triv)
  apply (rule_tac x="?f (x := ?y)" in exI, fastforce)
  done

(* Evaluation at a fixed point is surjective. *)
lemma surj_apply:
  "surj (\<lambda>f. f x)"
  by (auto intro: surjI)

lemma hd_map:
  "xs \<noteq> [] \<Longrightarrow> hd (map f xs) = f (hd xs)"
  by (clarsimp simp: neq_Nil_conv)

ML {*
(* inst_surj_via_mapI ctxt nths: tactic for the goal
     surj (\<lambda>xs. (p1, ..., pn))
   where the pi come from get_nths.  Each pi must be  xs ! k  with k a
   numeral, possibly wrapped in Suc, or  hd xs  (index 0).  Otherwise
   get_nth raises.

   surj_via_mapI is instantiated with n = largest index + 1, so the list is
   represented as  map g [0 ..< n].  The simplifier then finishes the goal
   using surj_tup_apply_eq, surj_apply and hd_map, added to the static
   simpset captured by @{simpset}. *)
fun inst_surj_via_mapI ctxt nths =
  let
    (* The index read by one nth/hd term. *)
    fun get_nth (f $ (@{term Suc} $ n)) = get_nth (f $ n) + 1
      | get_nth (Const (@{const_name nth}, _) $ _ $ n) =
          HOLogic.dest_number n |> snd
      | get_nth (Const (@{const_name hd}, _) $ _) = 0
      | get_nth t = raise TERM ("get_nth", [t])
    val max_n = map get_nth nths |> foldr1 (uncurry Integer.max)
    val n = HOLogic.mk_number @{typ nat} (max_n + 1)
      |> cterm_of (Proof_Context.theory_of ctxt)
    val t = cterm_instantiate [(@{cpat "?n\<Colon>nat"}, n)] @{thm surj_via_mapI}
    val ss = @{simpset} addsimps @{thms surj_tup_apply_eq surj_apply hd_map}
  in
    rtac t 1 THEN simp_tac ss 1
  end
*}

(* Sanity check: run the expansion directly on a concrete formula. *)
ML {*
val t = expand_forall_pats @{context} get_nths (inst_surj_via_mapI @{context})
  @{term "\<forall>xs. xs ! 1 + xs ! 3 + xs ! 42 + hd xs < (12 :: nat)"}
*}

ML {*
val expand_forall_nths_simproc =
  mk_expand_forall_simproc "expand_forall_nths" @{typ "'a list"}
    get_nths inst_surj_via_mapI @{theory}
*}

lemma test:
  "(\<exists>xs. xs ! Suc 0 = 1 \<and> xs ! 2 = 3 \<and> P (xs ! Suc 0 + xs ! 2))
     = P (1 + 3) \<and> (\<exists>xs. xs ! 3 = xs ! 4)"
  apply (tactic {* simp_tac (HOL_basic_ss addsimprocs [expand_forall_nths_simproc]) 1 *})
  apply simp
  done

end