(*
 * Copyright 2023, Proofcraft Pty Ltd
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

(* Monadic functions over lists: sequence, mapM, filter, etc
   Definitions, equations, Hoare logic and no_fail/empty_fail setup. *)

theory Monad_Lists
  imports
    Monads.Nondet_In_Monad
    Monads.Nondet_Det
    Monads.Nondet_Empty_Fail
    Monads.Nondet_No_Fail
begin

(* Unfolding equations for mapME on the two list constructors. *)

lemma mapME_Cons:
  "mapME m (x # xs) = (doE y \<leftarrow> m x; ys \<leftarrow> (mapME m xs); returnOk (y # ys) odE)"
  by (simp add: mapME_def sequenceE_def Let_def)

lemma mapME_Nil : "mapME f [] = returnOk []"
  unfolding mapME_def by (simp add: sequenceE_def)

lemmas mapME_simps = mapME_Nil mapME_Cons

(* An invariant of every "m x y" is an invariant of "zipWithM_x m xs ys".
   No relationship between the lengths of xs and ys is required. *)
lemma zipWithM_x_inv':
  assumes x: "\<And>x y. m x y \<lbrace>P\<rbrace>"
  shows "zipWithM_x m xs ys \<lbrace>P\<rbrace>"
proof (induct xs arbitrary: ys)
  case Nil
  show ?case
    by (simp add: zipWithM_x_def sequence_x_def zipWith_def)
next
  case (Cons x xs)
  (* Local unfolding equations; these shadow the global lemmas of the same
     name only inside this proof. *)
  have zipWithM_x_Cons:
    "\<And>m x xs y ys. zipWithM_x m (x # xs) (y # ys) = do m x y; zipWithM_x m xs ys od"
    by (simp add: zipWithM_x_def sequence_x_def zipWith_def)
  have zipWithM_x_Nil:
    "\<And>m xs. zipWithM_x m xs [] = return ()"
    by (simp add: zipWithM_x_def sequence_x_def zipWith_def)
  show ?case
    by (cases ys; wpsimp simp: zipWithM_x_Nil zipWithM_x_Cons wp: Cons x)
qed

(* For compatibility with existing proofs. *)
lemma zipWithM_x_inv:
  assumes x: "\<And>x y. m x y \<lbrace>P\<rbrace>"
  shows "length xs = length ys \<Longrightarrow> zipWithM_x m xs ys \<lbrace>P\<rbrace>"
  by (rule zipWithM_x_inv', rule x)

(* Unfolding equations for sequence_x, mapM and mapM_x, and their relation
   to zipWithM_x. *)

lemma sequence_x_Cons:
  "\<And>x xs. sequence_x (x # xs) = (x >>= (\<lambda>_. sequence_x xs))"
  by (simp add: sequence_x_def)

lemma mapM_Cons:
  "mapM m (x # xs) = (do y \<leftarrow> m x; ys \<leftarrow> (mapM m xs); return (y # ys) od)"
  by (simp add: mapM_def sequence_def Let_def)

lemma mapM_Nil:
  "mapM m [] = return []"
  by (simp add: mapM_def sequence_def)

lemmas mapM_simps = mapM_Nil mapM_Cons

lemma zipWithM_x_mapM:
  "zipWithM_x f as bs = (mapM (case_prod f) (zip as bs) >>= (\<lambda>_. return ()))"
  apply (simp add: zipWithM_x_def zipWith_def)
  apply (induct ("zip as bs"))
   apply (simp add: sequence_x_def mapM_def sequence_def)
  apply (simp add: sequence_x_Cons mapM_Cons bind_assoc)
  done

lemma mapM_x_mapM:
  "mapM_x m l = (mapM m l >>= (\<lambda>x. return ()))"
  apply (simp add: mapM_x_def sequence_x_def mapM_def sequence_def)
  apply (induct l, simp_all add: Let_def bind_assoc)
  done

lemma mapM_x_Nil:
  "mapM_x f [] = return ()"
  unfolding mapM_x_def sequence_x_def
  by simp

lemma sequence_xappend1:
  "sequence_x (xs @ [x]) = (sequence_x xs >>= (\<lambda>_. x))"
  by (induct xs) (simp add: sequence_x_def, simp add: sequence_x_Cons bind_assoc)

lemma mapM_append_single:
  "mapM_x f (xs @ [y]) = (mapM_x f xs >>= (\<lambda>_. f y))"
  unfolding mapM_x_def
  by (simp add: sequence_xappend1)

lemma mapM_x_Cons:
  "mapM_x m (x # xs) = (do m x; mapM_x m xs od)"
  by (simp add: mapM_x_def sequence_x_def)

lemma zipWithM_x_mapM_x:
  "zipWithM_x f as bs = mapM_x (\<lambda>(x, y). f x y) (zip as bs)"
  apply (subst zipWithM_x_mapM)
  apply (subst mapM_x_mapM)
  apply (rule refl)
  done

lemma zipWithM_x_append1:
  fixes f :: "'b \<Rightarrow> 'c \<Rightarrow> ('a, unit) nondet_monad"
  assumes ls: "length xs = length ys"
  shows "(zipWithM_x f (xs @ [x]) (ys @ [y])) = (zipWithM_x f xs ys >>= (\<lambda>_. f x y))"
  unfolding zipWithM_x_def zipWith_def
  by (subst zip_append [OF ls], simp, rule sequence_xappend1)

(* NOTE(review): the proof below never refers to the premise ls; presumably it
   is kept so that existing instantiations of this lemma keep working --
   confirm against callers before dropping it. *)
lemma zipWithM_x_Cons:
  assumes ls: "length xs = length ys"
  shows "(zipWithM_x f (x # xs) (y # ys)) = (f x y >>= (\<lambda>_. zipWithM_x f xs ys))"
  unfolding zipWithM_x_def zipWith_def
  by (simp, rule sequence_x_Cons)

lemma mapME_x_map_simp:
  "mapME_x m (map f xs) = mapME_x (m o f) xs"
  by (simp add: mapME_x_def sequenceE_x_def)

lemma mapM_return:
  "mapM (\<lambda>x. return (f x)) xs = return (map f xs)"
  apply (induct xs)
   apply (simp add: mapM_def sequence_def)
  apply (simp add: mapM_Cons)
  done

lemma liftM_return [simp]:
  "liftM f (return x) = return (f x)"
  by (simp add: liftM_def)

lemma mapM_x_return :
  "mapM_x (\<lambda>_. return v) xs = return v"
  by (induct xs) (auto simp: mapM_x_Nil mapM_x_Cons)

(* If a commutes with every "b z", then a commutes with "mapM b zs". *)
lemma bind_comm_mapM_comm:
  assumes bind_comm:
    "\<And>n z. do x \<leftarrow> a; y \<leftarrow> b z; (n x y :: ('a, 's) nondet_monad) od =
           do y \<leftarrow> b z; x \<leftarrow> a; n x y od"
  shows "\<And>n'. do x \<leftarrow> a; ys \<leftarrow> mapM b zs; (n' x ys :: ('a, 's) nondet_monad) od =
         do ys \<leftarrow> mapM b zs; x \<leftarrow> a; n' x ys od"
proof (induct zs)
  case Nil
  thus ?case
    by (simp add: mapM_def sequence_def)
next
  case (Cons z zs')
  thus ?case
    by (clarsimp simp: mapM_Cons bind_assoc bind_comm intro!: bind_cong [OF refl])
qed

lemma liftE_handle :
  "(liftE f <handle> g) = liftE f"
  by (simp add: handleE_def handleE'_def liftE_def)

(* Same statement as mapM_Nil. *)
lemma mapM_empty:
  "mapM f [] = return []"
  unfolding mapM_def
  by (simp add: sequence_def)

(* mapM and mapM_x distribute over list append. *)

lemma mapM_append:
  "mapM f (xs @ ys) =
   (do x \<leftarrow> mapM f xs;
       y \<leftarrow> mapM f ys;
       return (x @ y)
    od)"
proof (induct xs)
  case Nil
  thus ?case by (simp add: mapM_empty)
next
  case (Cons x xs)
  show ?case
    by (simp add: mapM_Cons bind_assoc Cons.hyps)
qed

lemma mapM_x_append: (* FIXME: remove extra return, fix proofs *)
  "mapM_x f (xs @ ys) =
   (do x \<leftarrow> mapM_x f xs;
       y \<leftarrow> mapM_x f ys;
       return ()
    od)"
  by (simp add: mapM_x_mapM mapM_append bind_assoc)

(* FIXME: duplicate, but mapM_x_append has an extra useless return *)
lemma mapM_x_append2:
  "mapM_x f (xs @ ys) = do mapM_x f xs; mapM_x f ys od"
  apply (simp add: mapM_x_def sequence_x_def)
  apply (induct xs)
   apply simp
  apply (simp add: bind_assoc)
  done

lemma mapM_singleton:
  "mapM f [x] = do r \<leftarrow> f x; return [r] od"
  by (simp add: mapM_def sequence_def)

lemma mapM_x_singleton:
  "mapM_x f [x] = f x"
  by (simp add: mapM_x_mapM mapM_singleton)

lemma mapME_x_sequenceE:
  "mapME_x f xs \<equiv> doE _ \<leftarrow> sequenceE (map f xs); returnOk () odE"
  apply (induct xs, simp_all add: mapME_x_def sequenceE_def sequenceE_x_def)
  apply (simp add: Let_def bindE_assoc)
  done

lemma sequenceE_Cons:
  "sequenceE (x # xs) = (doE v \<leftarrow> x; vs \<leftarrow> sequenceE xs; returnOk (v # vs) odE)"
  by (simp add: sequenceE_def Let_def)

(* zipWithM / zipWithM_x when the second list is empty or a singleton. *)

lemma zipWithM_Nil [simp]:
  "zipWithM f xs [] = return []"
  by (simp add: zipWithM_def zipWith_def sequence_def)

lemma zipWithM_One:
  "zipWithM f (x#xs) [a] = (do z \<leftarrow> f x a; return [z] od)"
  by (simp add: zipWithM_def zipWith_def sequence_def)

lemma zipWithM_x_Nil[simp]:
  "zipWithM_x f xs [] = return ()"
  by (simp add: zipWithM_x_def zipWith_def sequence_x_def)

lemma zipWithM_x_One:
  "zipWithM_x f (x#xs) [a] = f x a"
  by (simp add: zipWithM_x_def zipWith_def sequence_x_def)

lemma mapM_last_Cons:
  "\<lbrakk> xs = [] \<Longrightarrow> g v = y;
     xs \<noteq> [] \<Longrightarrow> do x \<leftarrow> f (last xs); return (g x) od
                 = do x \<leftarrow> f (last xs); return y od \<rbrakk> \<Longrightarrow>
   do ys \<leftarrow> mapM f xs;
      return (g (last (v # ys))) od
   = do mapM_x f xs;
        return y od"
  apply (cases "xs = []")
   apply (simp add: mapM_x_Nil mapM_Nil)
  apply (simp add: mapM_x_mapM)
  apply (subst append_butlast_last_id[symmetric], assumption,
         subst mapM_append)+
  apply (simp add: bind_assoc mapM_Cons mapM_Nil)
  done

(* Congruence for map / mapM over two lists of equal length whose elements
   are related pointwise through zip. *)

lemma map_length_cong:
  "\<lbrakk> length xs = length ys; \<And>x y. (x, y) \<in> set (zip xs ys) \<Longrightarrow> f x = g y \<rbrakk>
     \<Longrightarrow> map f xs = map g ys"
  apply atomize
  apply (erule rev_mp, erule list_induct2)
   apply auto
  done

lemma mapM_length_cong:
  "\<lbrakk> length xs = length ys; \<And>x y. (x, y) \<in> set (zip xs ys) \<Longrightarrow> f x = g y \<rbrakk>
     \<Longrightarrow> mapM f xs = mapM g ys"
  by (simp add: mapM_def map_length_cong)

(* FIXME: duplicate *)
lemma zipWithM_mapM:
  "zipWithM f xs ys = mapM (case_prod f) (zip xs ys)"
  by (simp add: zipWithM_def zipWith_def mapM_def)

lemma zip_take_triv2:
  "length as \<le> n \<Longrightarrow> zip as (take n bs) = zip as bs"
  apply (induct as arbitrary: n bs; simp)
  apply (case_tac n; simp)
  apply (case_tac bs; simp)
  done

(* Splits a zipWithM over [0 ..< m] whose body branches on "a < n" into the
   part below n (using f) followed by the part from n upwards (using g). *)
lemma zipWithM_If_cut:
  "zipWithM (\<lambda>a b. if a < n then f a b else g a b) [0 ..< m] xs
     = do ys \<leftarrow> zipWithM f [0 ..< min n m] xs;
          zs \<leftarrow> zipWithM g [n ..< m] (drop n xs);
          return (ys @ zs)
       od"
  apply (cases "n < m")
   apply (cut_tac i=0 and j=n and k="m - n" in upt_add_eq_append)
    apply simp
   apply (simp add: zipWithM_mapM)
   apply (simp add: zip_append1 mapM_append zip_take_triv2 split_def)
   apply (intro bind_cong bind_apply_cong refl mapM_length_cong
                fun_cong[OF mapM_length_cong])
    apply (clarsimp simp: set_zip)
   apply (clarsimp simp: set_zip)
  apply (simp add: zipWithM_mapM mapM_Nil)
  apply (intro mapM_length_cong refl)
  apply (clarsimp simp: set_zip)
  done

lemma mapM_liftM_const:
  "mapM (\<lambda>x. liftM (\<lambda>y. f x) (g x)) xs = liftM (\<lambda>ys. map f xs) (mapM g xs)"
  apply (induct xs)
   apply (simp add: mapM_Nil)
  apply (simp add: mapM_Cons)
  apply (simp add: liftM_def bind_assoc)
  done

lemma mapM_discarded:
  "mapM f xs >>= (\<lambda>ys. g) = mapM_x f xs >>= (\<lambda>_. g)"
  by (simp add: mapM_x_mapM)

lemma mapM_x_map:
  "mapM_x f (map g xs) = mapM_x (\<lambda>x. f (g x)) xs"
  by (simp add: mapM_x_def o_def)

(* Equations for filterM. *)

lemma filterM_append:
  "filterM f (xs @ ys) =
   do xs' \<leftarrow> filterM f xs;
      ys' \<leftarrow> filterM f ys;
      return (xs' @ ys')
   od"
  apply (induct xs)
   apply simp
  apply (simp add: bind_assoc)
  apply (rule ext bind_apply_cong [OF refl])+
  apply simp
  done

lemma filterM_mapM:
  "filterM f xs =
   do ys \<leftarrow> mapM (\<lambda>x. do v \<leftarrow> f x; return (x, v) od) xs;
      return (map fst (filter snd ys))
   od"
  apply (induct xs)
   apply (simp add: mapM_def sequence_def)
  apply (simp add: mapM_Cons bind_assoc)
  apply (rule bind_cong [OF refl] bind_apply_cong[OF refl])+
  apply simp
  done

lemma mapM_gets:
  assumes P: "\<And>x. m x = gets (f x)"
  shows "mapM m xs = gets (\<lambda>s. map (\<lambda>x. f x s) xs)"
proof (induct xs)
  case Nil
  show ?case
    by (simp add: mapM_def sequence_def gets_def get_def bind_def)
next
  case (Cons y ys)
  thus ?case
    by (simp add: mapM_Cons P simpler_gets_def return_def bind_def)
qed

lemma mapM_map_simp:
  "mapM m (map f xs) = mapM (m \<circ> f) xs"
  apply (induct xs)
   apply (simp add: mapM_def sequence_def)
  apply (simp add: mapM_Cons)
  done

lemma filterM_voodoo:
  "\<forall>ys. P ys (do zs \<leftarrow> filterM m xs; return (ys @ zs) od)
     \<Longrightarrow> P [] (filterM m xs)"
  by (drule spec[where x=Nil], simp)

lemma mapME_x_Cons:
  "mapME_x f (x # xs) = (doE f x; mapME_x f xs odE)"
  by (simp add: mapME_x_def sequenceE_x_def)

lemma liftME_map_mapME:
  "liftME (map f) (mapME m xs) = mapME (liftME f o m) xs"
  apply (rule sym)
  apply (induct xs)
   apply (simp add: liftME_def mapME_Nil)
  apply (simp add: mapME_Cons liftME_def bindE_assoc)
  done

lemma mapM_x_split_append:
  "mapM_x f xs = do _ \<leftarrow> mapM_x f (take n xs); mapM_x f (drop n xs) od"
  using mapM_x_append[where f=f and xs="take n xs" and ys="drop n xs"]
  by simp

(* Hoare-logic rules: invariants are preserved by the list combinators. *)

lemma mapME_wp:
  assumes x: "\<And>x. x \<in> S \<Longrightarrow> \<lbrace>P\<rbrace> f x \<lbrace>\<lambda>_. P\<rbrace>, \<lbrace>\<lambda>_. E\<rbrace>"
  shows "set xs \<subseteq> S \<Longrightarrow> \<lbrace>P\<rbrace> mapME f xs \<lbrace>\<lambda>_. P\<rbrace>, \<lbrace>\<lambda>_. E\<rbrace>"
  apply (induct xs)
   apply (simp add: mapME_def sequenceE_def)
   apply wp
   apply simp
  apply (simp add: mapME_Cons)
  apply (wp x|simp)+
  done

lemmas mapME_wp' = mapME_wp [OF _ subset_refl]

(* Invariant rule for mapM_x in which the auxiliary predicate V tracks the
   prefix of xs that has been processed so far. *)
lemma mapM_x_inv_wp3:
  fixes m :: "'b \<Rightarrow> ('a, unit) nondet_monad"
  assumes hr: "\<And>a as bs. xs = as @ [a] @ bs \<Longrightarrow>
     \<lbrace>\<lambda>s. I s \<and> V as s\<rbrace> m a \<lbrace>\<lambda>r s. I s \<and> V (as @ [a]) s\<rbrace>"
  shows "\<lbrace>\<lambda>s. I s \<and> V [] s\<rbrace> mapM_x m xs \<lbrace>\<lambda>rv s. I s \<and> V xs s\<rbrace>"
  using hr
proof (induct xs rule: rev_induct)
  case Nil
  thus ?case
    by (simp add: mapM_x_Nil)
next
  case (snoc x xs)
  show ?case
    apply (simp add: mapM_append_single)
    apply (wp snoc.prems)
      apply simp
     apply (rule snoc.hyps [OF snoc.prems])
     apply simp
    apply assumption
    done
qed

lemma mapME_x_inv_wp:
  assumes x: "\<And>x. \<lbrace>P\<rbrace> f x \<lbrace>\<lambda>rv. P\<rbrace>,\<lbrace>E\<rbrace>"
  shows "\<lbrace>P\<rbrace> mapME_x f xs \<lbrace>\<lambda>rv. P\<rbrace>,\<lbrace>E\<rbrace>"
  apply (induct xs)
   apply (simp add: mapME_x_def sequenceE_x_def)
   apply wp
  apply (simp add: mapME_x_def sequenceE_x_def)
  apply (fold mapME_x_def sequenceE_x_def)
  apply wp
   apply (rule x)
  apply assumption
  done

(* If each "f x" (for x in xs) commutes with the state transformer g on its
   results, then so does "mapM f xs". *)
lemma mapM_upd:
  assumes "\<And>x rv s s'. (rv,s') \<in> fst (f x s) \<Longrightarrow> x \<in> set xs \<Longrightarrow> (rv, g s') \<in> fst (f x (g s))"
  shows "(rv,s') \<in> fst (mapM f xs s) \<Longrightarrow> (rv, g s') \<in> fst (mapM f xs (g s))"
  using assms
proof (induct xs arbitrary: rv s s')
  case Nil
  thus ?case by (simp add: mapM_Nil return_def)
next
  case (Cons z zs)
  from Cons.prems
  show ?case
    apply (clarsimp simp: mapM_Cons in_monad)
    apply (drule Cons.prems, simp)
    apply (rule exI, erule conjI)
    apply (erule Cons.hyps)
    apply (erule Cons.prems)
    apply simp
    done
qed

(* no_fail rules for mapM. *)

lemma no_fail_mapM_wp:
  assumes "\<And>x. x \<in> set xs \<Longrightarrow> no_fail (P x) (f x)"
  assumes "\<And>x y. \<lbrakk> x \<in> set xs; y \<in> set xs \<rbrakk> \<Longrightarrow> \<lbrace>P x\<rbrace> f y \<lbrace>\<lambda>_. P x\<rbrace>"
  shows "no_fail (\<lambda>s. \<forall>x \<in> set xs. P x s) (mapM f xs)"
  using assms
proof (induct xs)
  case Nil
  thus ?case by (simp add: mapM_empty)
next
  case (Cons z zs)
  show ?case
    apply (clarsimp simp: mapM_Cons)
    apply (wp Cons.prems Cons.hyps hoare_vcg_const_Ball_lift|simp)+
    done
qed

lemma no_fail_mapM:
  "\<forall>x. no_fail \<top> (f x) \<Longrightarrow> no_fail \<top> (mapM f xs)"
  apply (induct xs)
   apply (simp add: mapM_def sequence_def)
  apply (simp add: mapM_Cons)
  apply (wp|fastforce)+
  done

(* Hoare-logic rules for filterM. *)

lemma filterM_preserved:
  "\<lbrakk> \<And>x. x \<in> set xs \<Longrightarrow> \<lbrace>P\<rbrace> m x \<lbrace>\<lambda>rv. P\<rbrace> \<rbrakk>
     \<Longrightarrow> \<lbrace>P\<rbrace> filterM m xs \<lbrace>\<lambda>rv. P\<rbrace>"
  apply (induct xs)
   apply (wp | simp | erule meta_mp | drule meta_spec)+
  done

(* Combined statement; filterM_subset and filterM_distinct below instantiate
   P with False and True respectively. *)
lemma filterM_distinct1:
  "\<lbrace>\<top> and K (P \<longrightarrow> distinct xs)\<rbrace>
   filterM m xs
   \<lbrace>\<lambda>rv s. (P \<longrightarrow> distinct rv) \<and> set rv \<subseteq> set xs\<rbrace>"
  apply (rule hoare_gen_asm, erule rev_mp)
  apply (rule rev_induct [where xs=xs])
   apply (clarsimp | wp)+
  apply (simp add: filterM_append)
  apply (erule hoare_seq_ext[rotated])
  apply (rule hoare_seq_ext[rotated], rule hoare_vcg_prop)
  apply (wp, clarsimp)
  apply blast
  done

lemma filterM_subset:
  "\<lbrace>\<top>\<rbrace> filterM m xs \<lbrace>\<lambda>rv s. set rv \<subseteq> set xs\<rbrace>"
  by (rule hoare_chain, rule filterM_distinct1[where P=False], simp_all)

lemma filterM_all:
  "\<lbrakk> \<And>x y. \<lbrakk> x \<in> set xs; y \<in> set xs \<rbrakk> \<Longrightarrow> \<lbrace>P y\<rbrace> m x \<lbrace>\<lambda>rv. P y\<rbrace> \<rbrakk> \<Longrightarrow>
   \<lbrace>\<lambda>s. \<forall>x \<in> set xs. P x s\<rbrace> filterM m xs \<lbrace>\<lambda>rv s. \<forall>x \<in> set rv. P x s\<rbrace>"
  apply (rule_tac Q="\<lambda>rv s. set rv \<subseteq> set xs \<and> (\<forall>x \<in> set xs. P x s)"
              in hoare_strengthen_post)
   apply (wp filterM_subset hoare_vcg_const_Ball_lift filterM_preserved)
    apply simp+
  apply blast
  done

lemma filterM_distinct:
  "\<lbrace>K (distinct xs)\<rbrace> filterM m xs \<lbrace>\<lambda>rv s. distinct rv\<rbrace>"
  by (rule hoare_chain, rule filterM_distinct1[where P=True], simp_all)

(* Hoare-logic rules for mapM and mapM_x. *)

lemma mapM_wp:
  assumes x: "\<And>x. x \<in> S \<Longrightarrow> \<lbrace>P\<rbrace> f x \<lbrace>\<lambda>rv. P\<rbrace>"
  shows "set xs \<subseteq> S \<Longrightarrow> \<lbrace>P\<rbrace> mapM f xs \<lbrace>\<lambda>rv. P\<rbrace>"
  apply (induct xs)
   apply (simp add: mapM_def sequence_def)
  apply (simp add: mapM_Cons)
  apply wp
   apply (rule x, clarsimp)
  apply simp
  done

lemma mapM_wp':
  assumes x: "\<And>x. x \<in> set xs \<Longrightarrow> \<lbrace>P\<rbrace> f x \<lbrace>\<lambda>rv. P\<rbrace>"
  shows "\<lbrace>P\<rbrace> mapM f xs \<lbrace>\<lambda>rv. P\<rbrace>"
  apply (rule mapM_wp)
   apply (erule x)
  apply simp
  done

(* If each "f x" preserves P, establishes "Q x" from P, and preserves every
   "Q y" under P, then after "mapM f xs" all "Q x" for x in xs hold. *)
lemma mapM_set:
  assumes "\<And>x. x \<in> set xs \<Longrightarrow> \<lbrace>P\<rbrace> f x \<lbrace>\<lambda>_. P\<rbrace>"
  assumes "\<And>x. x \<in> set xs \<Longrightarrow> \<lbrace>P\<rbrace> f x \<lbrace>\<lambda>_. Q x\<rbrace>"
  assumes "\<And>x y. \<lbrakk> x \<in> set xs; y \<in> set xs \<rbrakk> \<Longrightarrow> \<lbrace>P and Q y\<rbrace> f x \<lbrace>\<lambda>_. Q y\<rbrace>"
  shows "\<lbrace>P\<rbrace> mapM f xs \<lbrace>\<lambda>_ s. \<forall>x \<in> set xs. Q x s\<rbrace>"
  using assms
proof (induct xs)
  case Nil
  show ?case
    by (simp add: mapM_def sequence_def) wp
next
  case (Cons y ys)
  have PQ_inv: "\<And>x. x \<in> set ys \<Longrightarrow> \<lbrace>P and Q y\<rbrace> f x \<lbrace>\<lambda>_. P and Q y\<rbrace>"
    by (wpsimp wp: Cons)
  show ?case
    apply (simp add: mapM_Cons)
    apply wp
     apply (rule hoare_vcg_conj_lift)
      apply (rule hoare_strengthen_post)
       apply (rule mapM_wp')
       apply (erule PQ_inv)
      apply simp
     apply (wp Cons|simp)+
    done
qed

lemma mapM_set_inv:
  assumes "\<And>x. x \<in> set xs \<Longrightarrow> \<lbrace>P\<rbrace> f x \<lbrace>\<lambda>_. P\<rbrace>"
  assumes "\<And>x. x \<in> set xs \<Longrightarrow> \<lbrace>P\<rbrace> f x \<lbrace>\<lambda>_. Q x\<rbrace>"
  assumes "\<And>x y. \<lbrakk> x \<in> set xs; y \<in> set xs \<rbrakk> \<Longrightarrow> \<lbrace>P and Q y\<rbrace> f x \<lbrace>\<lambda>_. Q y\<rbrace>"
  shows "\<lbrace>P\<rbrace> mapM f xs \<lbrace>\<lambda>_ s. P s \<and> (\<forall>x \<in> set xs. Q x s)\<rbrace>"
  apply (rule hoare_weaken_pre, rule hoare_vcg_conj_lift)
    apply (rule mapM_wp', erule assms)
   apply (rule mapM_set; rule assms; assumption)
  apply simp
  done

lemma mapM_x_wp:
  assumes x: "\<And>x. x \<in> S \<Longrightarrow> \<lbrace>P\<rbrace> f x \<lbrace>\<lambda>rv. P\<rbrace>"
  shows "set xs \<subseteq> S \<Longrightarrow> \<lbrace>P\<rbrace> mapM_x f xs \<lbrace>\<lambda>rv. P\<rbrace>"
  by (subst mapM_x_mapM) (wp mapM_wp x)

lemma no_fail_mapM':
  assumes rl: "\<And>x. no_fail (\<lambda>_. P x) (f x)"
  shows "no_fail (\<lambda>_. \<forall>x \<in> set xs. P x) (mapM f xs)"
proof (induct xs)
  case Nil
  thus ?case by (simp add: mapM_def sequence_def)
next
  case (Cons x xs)
  have nf: "no_fail (\<lambda>_. P x) (f x)"
    by (rule rl)
  have ih: "no_fail (\<lambda>_. \<forall>x \<in> set xs. P x) (mapM f xs)"
    by (rule Cons)
  show ?case
    apply (simp add: mapM_Cons)
    apply (rule no_fail_pre)
     apply (wp nf ih)
    apply simp
    done
qed

(* det and empty_fail rules. *)

lemma det_mapM:
  assumes x: "\<And>x. x \<in> S \<Longrightarrow> det (f x)"
  shows "set xs \<subseteq> S \<Longrightarrow> det (mapM f xs)"
  apply (induct xs)
   apply (simp add: mapM_def sequence_def)
  apply (simp add: mapM_Cons x)
  done

lemma det_zipWithM_x:
  assumes x: "\<And>x y. (x, y) \<in> set (zip xs ys) \<Longrightarrow> det (f x y)"
  shows "det (zipWithM_x f xs ys)"
  apply (simp add: zipWithM_x_mapM)
  apply (rule bind_detI)
   apply (rule det_mapM [where S="set (zip xs ys)"])
    apply (clarsimp simp add: x)
   apply simp
  apply simp
  done

lemma empty_fail_sequence_x :
  assumes "\<And>m. m \<in> set ms \<Longrightarrow> empty_fail m"
  shows "empty_fail (sequence_x ms)"
  using assms
  by (induct ms) (auto simp: sequence_x_def)

(* On normal (non-exception) return of "mapME f xs", P holds for every
   element of the result list, given that each "f x" establishes P for its
   result and preserves R and previously established P. *)
lemma mapME_set:
  assumes est: "\<And>x. \<lbrace>R\<rbrace> f x \<lbrace>P\<rbrace>, -"
  and invp: "\<And>x y. \<lbrace>R and P x\<rbrace> f y \<lbrace>\<lambda>_. P x\<rbrace>, -"
  and invr: "\<And>x. \<lbrace>R\<rbrace> f x \<lbrace>\<lambda>_. R\<rbrace>, -"
  shows "\<lbrace>R\<rbrace> mapME f xs \<lbrace>\<lambda>rv s. \<forall>x \<in> set rv. P x s\<rbrace>, -"
proof (rule hoare_post_imp_R [where Q' = "\<lambda>rv s. R s \<and> (\<forall>x \<in> set rv. P x s)"], induct xs)
  case Nil
  thus ?case by (simp add: mapME_Nil | wp returnOKE_R_wp)+
next
  case (Cons y ys)

  have minvp: "\<And>x. \<lbrace>R and P x\<rbrace> mapME f ys \<lbrace>\<lambda>_. P x\<rbrace>, -"
    apply (rule hoare_pre)
     apply (rule_tac Q' = "\<lambda>_ s. R s \<and> P x s" in hoare_post_imp_R)
      apply (wp mapME_wp' invr invp)+
      apply simp
     apply simp
    apply simp
    done

  show ?case
    apply (simp add: mapME_Cons)
    apply (wp)
     apply (rule_tac Q' = "\<lambda>xs s. (R s \<and> (\<forall>x \<in> set xs. P x s)) \<and> P x s" in hoare_post_imp_R)
      apply (wp Cons.hyps minvp)
     apply simp
    apply (fold validE_R_def)
    apply simp
    apply (wp invr est)
    apply simp
    done
qed clarsimp

lemma empty_fail_mapM_x [simp]:
  "(\<And>x. empty_fail (a x)) \<Longrightarrow> empty_fail (mapM_x a xs)"
  apply (induct_tac xs)
   apply (clarsimp simp: mapM_x_Nil)
  apply (clarsimp simp: mapM_x_Cons)
  done

(* Variant of mapM_upd for a fixed state s that every "f x" leaves
   unchanged. *)
lemma mapM_upd_inv:
  assumes f: "\<And>x rv. (rv,s) \<in> fst (f x s) \<Longrightarrow> x \<in> set xs \<Longrightarrow> (rv, g s) \<in> fst (f x (g s))"
  assumes inv: "\<And>x. \<lbrace>(=) s\<rbrace> f x \<lbrace>\<lambda>_. (=) s\<rbrace>"
  shows "(rv,s) \<in> fst (mapM f xs s) \<Longrightarrow> (rv, g s) \<in> fst (mapM f xs (g s))"
  using f inv
proof (induct xs arbitrary: rv s)
  case Nil
  thus ?case by (simp add: mapM_Nil return_def)
next
  case (Cons z zs)
  from Cons.prems
  show ?case
    apply (clarsimp simp: mapM_Cons in_monad)
    apply (frule use_valid, assumption, rule refl)
    apply clarsimp
    apply (drule Cons.prems, simp)
    apply (rule exI, erule conjI)
    apply (drule Cons.hyps)
      apply simp
     apply assumption
    apply simp
    done
qed

(* Expresses "find" followed by a case split as a mapME that throws the first
   element satisfying f. *)
lemma case_option_find_give_me_a_map:
  "case_option a return (find f xs)
    = liftM projl
      (mapME (\<lambda>x. if (f x) then throwError x else returnOk ()) xs
        >>=E (\<lambda>x. assert (\<forall>x \<in> set xs. \<not> f x)
                >>= (\<lambda>_. liftM (Inl :: 'a \<Rightarrow> 'a + unit) a)))"
  apply (induct xs)
   apply simp
   apply (simp add: liftM_def mapME_Nil)
  apply (simp add: mapME_Cons split: if_split)
  apply (clarsimp simp add: throwError_def bindE_def bind_assoc
                            liftM_def)
  apply (rule bind_cong [OF refl])
  apply (simp add: lift_def throwError_def returnOk_def split: sum.split)
  done

end