(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

(*
   Author: Gerwin Klein

   The state and error monads in Isabelle.

   NOTE(review): in this copy of the theory, Isabelle symbols (arrows,
   quantifiers, logical connectives, Hoare-triple brackets) appear to have
   been stripped down to a bare backslash. The comments below read each such
   backslash by its evident role in context (function arrow, definitional
   equality, lambda, forall, implication, ...). Confirm against the original
   symbols before relying on these readings.
*)

header "Monads"

theory StateMonad
imports Lib
begin

(* A state-monad computation maps an initial state to a pair of a result and
   a final state (this is how return and bind below use it). *)
type_synonym ('s,'a) state_monad = "'s \ 'a \ 's"

(* runState is the identity: it simply exposes a computation as its
   underlying function on states. *)
definition runState :: "('s,'a) state_monad \ 's \ 'a \ 's" where
  "runState \ id"

(* return a produces the result a and leaves the state unchanged. *)
definition
  "return a \ \s. (a,s)"

(* bind f g runs f on the current state, then runs g applied to f's result
   on the state that f produced. Written infix as >>=. *)
definition bind :: "('s, 'a) state_monad \ ('a \ ('s, 'b) state_monad) \ ('s, 'b) state_monad"
  (infixl ">>=" 60) where
  "bind f g \ (\s. let (v,s') = f s in (g v) s')"

(* bind' f g sequences f then g, discarding f's result. Its definition is
   declared [iff] so automated proof methods unfold it. *)
definition
  "bind' f g \ bind f (\_. g)"

declare bind'_def [iff]

(* get returns the current state as its result without changing it;
   put s replaces the state with s and returns (). *)
definition
  "get \ \s. (s,s)"

definition
  "put s \ \_. ((),s)"

(* gets f returns f applied to the current state;
   modify f replaces the current state s with f s. *)
definition
  "gets f \ get >>= (\s. return $ f s)"

definition
  "modify f \ get >>= (\s. put $ f s)"

(* when p s runs s if p holds, otherwise returns () and leaves the state
   alone. unless p s is when applied to the negation of p (the negation
   symbol is garbled in this copy). *)
definition
  "when p s \ if p then s else return ()"

definition
  "unless p s \ when (\p) s"

text {* The monad laws: *}

lemma return_bind [simp]: "(return x >>= f) = f x"
  by (simp add: return_def bind_def runState_def)

lemma bind_return [simp]: "(m >>= return) = m"
  apply (unfold bind_def return_def runState_def)
  apply (simp add: Let_def split_def)
  done

(* Associativity is not declared [simp]; use it explicitly. *)
lemma bind_assoc:
  fixes m :: "('s,'a) state_monad"
  fixes f :: "'a \ ('s,'b) state_monad"
  fixes g :: "'b \ ('s,'c) state_monad"
  shows "(m >>= f) >>= g = m >>= (\x. f x >>= g)"
  apply (unfold bind_def)
  apply (clarsimp simp add: Let_def split_def)
  done

text {* An errorT state\_monad (returnOk=return, bindE=bind): *}

(* In the error monad a result is a sum: Inr carries a normal value and Inl
   carries an exception. returnOk and throwError inject into the two sides;
   Ok is an alias for Inr. *)
definition
  "returnOk \ return o Inr"

definition
  "throwError \ return o Inl"

definition
  "Ok \ Inr"

(* lift f v continues with f on a normal value Inr v' and re-throws an
   exception Inl e unchanged. *)
definition lift :: "('a \ ('s, 'e + 'b) state_monad) \ 'e+'a \ ('s, 'e + 'b) state_monad" where
  "lift f v \ case v of Inl e \ throwError e | Inr v' \ f v'"

(* lift2 handles two nested sum layers: an outer exception Inl e is
   re-thrown, an inner exception Inl e' is returned as Inr (Inl e'), and only
   a value of the form Inr (Inr v') is passed on to f. *)
definition lift2 :: "('c \ ('a, 'b + 'e + 'd) state_monad) \ 'b+'e+'c \ ('a, 'b+'e+'d) state_monad" where
  "lift2 f v \ case v of Inl e \ throwError e | Inr v'' \ (case v'' of Inl e' \ return $ Inr $ Inl e' | Inr v' \ f v')"

(* raise has the same body as throwError (return composed with Inl), but
   its type fixes the normal-result type to unit. Per the original note:
   raise is for throwing an error as a statement by itself, while throwError
   is convenient e.g. on the else branch of an if.
   NOTE(review): the original comment was cut off mid-sentence, so the rest
   of its rationale is lost. It was presumably about the unit result type;
   confirm. *)
definition raise :: "'a \ 's \ ('a + unit) \ 's" where
  "raise \ return \ Inl"

(* bindE f g runs f; on a normal result Inr v it continues with g v, and on
   an exception Inl e it stops with throwError e (via lift). Written infix as
   >>=E. *)
definition bindE :: "('s, 'e + 'a) state_monad \ ('a \ ('s, 'e + 'b) state_monad) \ ('s, 'e + 'b) state_monad"
  (infixl ">>=E" 60) where
  "bindE f g \ bind f (lift g)"

(* bindE' sequences two error-monad computations, discarding the first
   normal result. *)
definition
  "bindE' f g \ bindE f (\_. g)"

(* liftE f runs a plain state computation and tags its result with Inr. *)
definition liftE :: "('s,'a) state_monad \ ('s, 'e+'a) state_monad" where
  "liftE f \ \s. let (v,s') = f s in (Inr v, s')"

(* Error-monad versions of when/unless: the skipped branch is returnOk (). *)
definition
  "whenE P f \ if P then f else returnOk ()"

definition
  "unlessE P f \ if P then returnOk () else f"

(* throw_opt ex x returns the value v of Some v as a normal result, or
   throws ex when x is None. *)
definition
  "throw_opt ex x \ case x of None \ throwError ex | Some v \ returnOk v"

(* bindEE / bindEE' are the bind operators for the doubly nested error
   monad; the continuation is lifted with lift2. *)
definition
  "bindEE f g \ bind f (lift2 g)"

definition
  "bindEE' f g \ bindEE f (\_. g)"

(* Error-monad lifts of modify and gets. *)
definition
  "modifyE \ (liftE \ modify)"

definition
  "getsE x \ liftE $ gets x"

(* Infix notation >>=EE for bindEE, declared separately because the
   definition above carries no mixfix annotation. *)
syntax bindEE :: "'a \ 'b \ 'c" (infixl ">>=EE" 60)

(* Unfold the derived bind operators automatically in proofs. *)
declare bindE'_def [iff] bindEE_def [iff] bindEE'_def [iff]

(* Monad laws for the error monad. *)
lemma returnOk_bindE [simp]: "(returnOk x >>=E f) = f x"
  apply (unfold bindE_def return_def returnOk_def)
  apply (clarsimp simp: lift_def)
  done

lemma lift_return [simp]: "lift (return \ Inr) = return"
  apply (rule ext)
  apply (simp add: lift_def throwError_def split: sum.splits)
  done

lemma bindE_returnOk [simp]: "(m >>=E returnOk) = m"
  by (simp add: bindE_def returnOk_def)

lemma bindE_assoc:
  shows "(m >>=E f) >>=E g = m >>=E (\x. f x >>=E g)"
  apply (clarsimp simp add: Let_def bindE_def bind_def lift_def split_def split: sum.splits)
  apply (rule ext)
  apply (auto simp: Let_def runState_def throwError_def return_def lift_def split: sum.splits)
  done

(* A thrown error short-circuits the rest of a bindE chain. *)
lemma throwError_bindE [simp]: "throwError E >>=E f = throwError E"
  by (simp add: bindE_def bind_def throwError_def lift_def return_def)

subsection "Syntax for state monad"

(* do-notation: a binding statement x <- a followed by e is translated into
   a >>= applied to a lambda binding x in e; a statement without a binder is
   sequenced with bind'. Several statements nest to the right. *)
nonterminal
  dobinds and dobind and nobind

syntax
  "_dobind" :: "[pttrn, 'a] => dobind" ("(_ <-/ _)" 10)
  "" :: "dobind => dobinds" ("_")
  "_nobind" :: "'a => dobind" ("_")
  "_dobinds" :: "[dobind, dobinds] => dobinds" ("(_);//(_)")
  "_do" :: "[dobinds, 'a] => 'a" ("(do (_);// (_)//od)" 100)

(* Alternative symbol for the binding arrow under the xsymbols print mode. *)
syntax (xsymbols)
  "_dobind" :: "[pttrn, 'a] => dobind" ("(_ \/ _)" 10)

translations
  "_do (_dobinds b bs) e" == "_do b (_do bs e)"
  "_do (_nobind b) e" == "CONST bind' b e"
  "do x <- a; e od" == "a >>= (\x. e)"

(* Sanity check of the do-notation. *)
lemma "do x \ return 1; return 2; return x od = return 1"
  by simp

subsection "Syntax for errorT state monad"

(* doE ... odE: the same notation, built on bindE / bindE'. *)
syntax
  "_doE" :: "[dobinds, 'a] => 'a" ("(doE (_);// (_)//odE)" 100)

translations
  "_doE (_dobinds b bs) e" == "_doE b (_doE bs e)"
  "_doE (_nobind b) e" == "CONST bindE' b e"
  "doE x <- a; e odE" == "a >>=E (\x. e)"

subsection "Syntax for errorT errorT state monad"

(* doEE ... odEE: the same notation, built on bindEE / bindEE'. *)
syntax
  "_doEE" :: "[dobinds, 'a] => 'a" ("(doEE (_);// (_)//odEE)" 100)

translations
  "_doEE (_dobinds b bs) e" == "_doEE b (_doEE bs e)"
  "_doEE (_nobind b) e" == "CONST bindEE' b e"
  "doEE x <- a; e odEE" == "a >>=EE (\x. e)"

(* inc_forloop n current body runs body current, body (current+1), ... for n
   iterations in total, stopping at the first exception (doE sequencing);
   with n = 0 it is returnOk (). *)
primrec inc_forloop :: "nat \ 'g::{plus,one} \ ('g \ ('a, 'b + unit) state_monad) \ ('a, 'b + unit) state_monad" where
  "inc_forloop 0 current body = returnOk ()"
| "inc_forloop (Suc left) current body = doE body current ; inc_forloop left (current+1) body odE"

(* do_times n body increment runs body followed by increment, n times in
   sequence, stopping at the first exception; with n = 0 it is returnOk (). *)
primrec do_times :: "nat \ ('a, 'b + unit) state_monad \ ('a, 'b + unit) state_monad \ ('a, 'b + unit) state_monad" where
  "do_times 0 body increment = returnOk ()"
| "do_times (Suc left) body increment = doE body ; increment ; do_times left body increment odE"

(* function_update index modifier f agrees with f everywhere except at
   index, where the value becomes modifier (f index). *)
definition function_update :: "'a \ ('b \ 'b) \ ('a \ 'b) \ ('a \ 'b)" where
  "function_update index modifier f \ \x. if x = index then modifier (f x) else (f x)"

(* Sanity checks of the doE / doEE notation. *)
lemma "doE x \ returnOk 1; returnOk 2; returnOk x odE = returnOk 1"
  by simp

term "doEE x \ returnOk $ Ok 1; returnOk $ Ok 2; returnOk $ Ok x odEE"

(* skip returns the normal result Ok () of the doubly nested error monad
   and leaves the state unchanged. *)
definition
  "skip \ returnOk $ Ok ()"

(* liftM / liftME apply a pure function f to the result of m. *)
definition
  "liftM f m \ do x \ m; return (f x) od"

definition
  "liftME f m \ doE x \ m; returnOk (f x) odE"

(* sequence_x runs a list of computations left to right, discarding their
   results; zipWithM_x and mapM_x build that list with zipWith / map. *)
definition
  "sequence_x xs \ foldr (\x y. x >>= (\_. y)) xs (return ())"

definition
  "zipWithM_x f xs ys \ sequence_x (zipWith f xs ys)"

definition
  "mapM_x f xs \ sequence_x (map f xs)"

(* sequence runs a list of computations left to right and returns the list
   of their results in the same order; mapM maps f over xs first. *)
definition
  "sequence xs \ let mcons = (\p q. p >>= (\x. q >>= (\y. return (x#y)))) in foldr mcons xs (return [])"

definition
  "mapM f xs \ sequence (map f xs)"

(* Error-monad (and doubly nested error-monad) versions of sequence_x /
   mapM_x: results are discarded and the first exception stops the run. *)
definition
  "sequenceE_x xs \ foldr (\x y. doE uu <- x; y odE) xs (returnOk ())"

definition
  "mapME_x f xs \ sequenceE_x (map f xs)"

definition
  "sequenceEE_x xs \ foldr bindEE' xs (skip)"

definition
  "mapMEE_x f xs \ sequenceEE_x (map f xs)"

(* catch f handler runs f; a normal result Inr b is returned as b, and an
   exception Inl e is passed to handler. The result is no longer a sum. *)
definition catch :: "('s, 'a + 'b) state_monad \ ('a \ ('s, 'b) state_monad) \ ('s, 'b) state_monad" where
  "catch f handler \ do x \ f; case x of Inr b \ return b | Inl e \ handler e od"

(* handleE (infix): runs f; on an exception Inl e it runs handler e,
   otherwise it returns f's Inr result unchanged.
   NOTE(review): the infix symbol in the mixfix annotation and in the
   defining equation is garbled away (an empty string) in this copy;
   restore it from the original before use. *)
definition handleE :: "('s, 'x + 'a) state_monad \ ('x \ ('s, 'x + 'a) state_monad) \ ('s, 'x + 'a) state_monad" (infix "" 11) where
  "f handler \ do v \ f; case v of Inl e \ handler e | Inr v' \ return v od"

(* handle_elseE f handler continue: runs f; an exception Inl e goes to
   handler e, a normal result Inr v goes to continue v.
   NOTE(review): its mixfix symbols are also garbled in this copy. *)
definition handle_elseE :: "('s, 'x + 'a) state_monad \ ('x \ ('s, 'x + 'a) state_monad) \ ('a \ ('s, 'x + 'a) state_monad) \ ('s, 'x + 'a) state_monad" ("_ _ _" 10) where
  "f handler continue \ do v \ f; case v of Inl e \ handler e | Inr v \ continue v od"

(* isSkip m: for every initial state s, m returns some result together with
   that same state s, i.e. m never changes the state. *)
definition isSkip :: "('s, 'a) state_monad \ bool" where
  "isSkip m \ \s. \r. m s = (r,s)"

(* A while combinator for monads is not needed in the current version;
   a formalisation is available in revision 1397. *)

(* isSkip closure rules: computations built only from state-preserving
   parts are themselves state-preserving. *)
lemma isSkip_bindI: "\ isSkip f; \x. isSkip (g x) \ \ isSkip (f >>= g)"
  apply (clarsimp simp add: isSkip_def bind_def Let_def)
  apply (erule_tac x=s in allE)
  apply clarsimp
  done

lemma isSkip_return [simp,intro!]: "isSkip (return x)"
  by (simp add: isSkip_def return_def)

lemma isSkip_gets [simp,intro!]: "isSkip (gets x)"
  by (simp add: isSkip_def gets_def get_def bind_def return_def)

lemma isSkip_liftE [iff]: "isSkip (liftE f) = isSkip f"
  apply (simp add: isSkip_def liftE_def Let_def split_def)
  apply rule
   apply clarsimp
   apply (case_tac "f s")
   apply (erule_tac x = s in allE)
   apply simp
  apply clarsimp
  apply (case_tac "f s")
  apply (erule_tac x = s in allE)
  apply simp
  done

lemma isSkip_liftI [simp, intro!]: "\ \y. x = Inr y \ isSkip (f y) \ \ isSkip (lift f x)"
  by (simp add: lift_def throwError_def return_def isSkip_def split: sum.splits)

lemma isSkip_Error [iff]: "isSkip (throwError x)"
  by (simp add: throwError_def)

lemma isSkip_returnOk [iff]: "isSkip (returnOk x)"
  by (simp add: returnOk_def)

lemma isSkip_throw_opt [iff]: "isSkip (throw_opt e x)"
  by (simp add: throw_opt_def split: option.splits)

(* Flattening a bind whose first computation ends in return. *)
lemma nested_bind [simp]: "do x <- do y <- f; return (g y) od; h x od = do y <- f; h (g y) od"
  apply (clarsimp simp add: bind_def)
  apply (rule ext)
  apply (clarsimp simp add: Let_def split_def runState_def return_def)
  done

(* A state-preserving computation whose result is unused can be dropped. *)
lemma skip_bind: "isSkip s \ do _ \ s; g od = g"
  apply (clarsimp simp add: bind_def)
  apply (rule ext)
  apply (clarsimp simp add: isSkip_def Let_def)
  apply (erule_tac x=sa in allE)
  apply clarsimp
  done

lemma bind_eqI: "\ f = f'; \x. g x = g' x \ \ f >>= g = f' >>= g'"
  by (simp add: bind_def)

(* Congruence rules, registered as [fundef_cong]: the continuation only has
   to agree on the results and states that the first computation can
   actually produce (for bindE: on normal results Inr v). *)
lemma bind_cong [fundef_cong]: "\ f = f'; \v s s'. f' s = (v, s') \ g v s' = g' v s' \ \ f >>= g = f' >>= g'"
  by (simp add: bind_def Let_def split_def)

lemma bind'_cong [fundef_cong]: "\ f = f'; \v s s'. f' s = (v, s') \ g s' = g' s' \ \ bind' f g = bind' f' g'"
  apply (simp)
  apply (rule bind_cong, simp_all)
  done

lemma bindE_cong[fundef_cong]: "\ M = M' ; \v s s'. M' s = (Inr v, s') \ N v s' = N' v s' \ \ bindE M N = bindE M' N'"
  apply (simp add: bindE_def)
  apply (rule bind_cong)
   apply (rule refl)
  apply (unfold lift_def)
  apply (case_tac v, simp_all)
  done

lemma bindE'_cong[fundef_cong]: "\ M = M' ; \v s s'. M' s = (Inr v, s') \ N s' = N' s' \ \ bindE' M N = bindE' M' N'"
  apply (simp)
  apply (rule bindE_cong, simp_all)
  done

(* Hoare triples. valid P f Q: for every state s satisfying P, the
   result/final-state pair f s satisfies Q. *)
definition valid :: "('s \ bool) \ ('s,'a) state_monad \ ('a \ 's \ bool) \ bool" ("\_\ _ \_\") where
  "\P\ f \Q\ \ \s. P s \ split Q (f s)"

(* validE P f Q R: like valid, but Q must hold of a normal result Inr b and
   R of an exception Inl a. *)
definition validE :: "('s \ bool) \ ('s, 'a + 'b) state_monad \ ('b \ 's \ bool) \ ('a \ 's \ bool) \ bool" ("\_\ _ \_\, \_\") where
  "\P\ f \Q\,\R\ \ \s. P s \ split (\r s. case r of Inr b \ Q b s | Inl a \ R a s) (f s)"

(* validE expressed as a plain valid triple with a case-splitting
   postcondition. *)
lemma validE_def2: "\P\ f \Q\,\R\ \ \P\ f \ \r s. case r of Inr b \ Q b s | Inl a \ R a s \"
  by (unfold valid_def validE_def)

(* FIXME: modernize *)
(* Notation for the constant unary predicates: always true / always false. *)
syntax
  top :: "'a \ bool" ("\")
  bottom :: "'a \ bool" ("\")

translations
  "\" == "\_. CONST True"
  "\" == "\_. CONST False"

(* Pointwise conjunction (And), disjunction (Or) and negation (Not) of
   binary predicates, such as Hoare postconditions over result and state. *)
definition
  bipred_conj :: "('a \ 'b \ bool) \ ('a \ 'b \ bool) \ ('a \ 'b \ bool)" (infixl "And" 96)
where
  "bipred_conj P Q \ \x y. P x y \ Q x y"

definition
  bipred_disj :: "('a \ 'b \ bool) \ ('a \ 'b \ bool) \ ('a \ 'b \ bool)" (infixl "Or" 91)
where
  "bipred_disj P Q \ \x y. P x y \ Q x y"

definition
  bipred_neg :: "('a \ 'b \ bool) \ ('a \ 'b \ bool)" ("Not _")
where
  "bipred_neg P \ \x y. \ P x y"

(* Notation for the constant binary predicates: always true / always false. *)
syntax
  toptop :: "'a \ 'b \ bool" ("\\")
  botbot :: "'a \ 'b \ bool" ("\\")

translations
  "\\" == "\_ _. CONST True"
  "\\" == "\_ _. CONST False"

(* pred_lift_exact P Q is the binary predicate that holds of x and y when P
   holds of x and Q holds of y; pred_exact_split below confirms that the
   garbled connective is a conjunction. *)
definition
  pred_lift_exact :: "('a \ bool) \ ('b \ bool) \ ('a \ 'b \ bool)" ("\_,_\")
where
  "pred_lift_exact P Q \ \x y. P x \ Q y"

(* Simplification, introduction and elimination rules for the predicate
   combinators. The unary and / or / not are not defined in this theory
   (presumably in Lib, whose pred_conj_def etc. are used here). *)
lemma pred_lift_taut[simp]: "\\,\\ = \\"
  apply(simp add:pred_lift_exact_def)
  done

lemma pred_lift_cont_l[simp]: "\\,x\ = \\"
  apply(simp add:pred_lift_exact_def)
  done

lemma pred_lift_cont_r[simp]: "\x,\\ = \\"
  apply(simp add:pred_lift_exact_def)
  done

lemma pred_liftI[intro!]: "\ P x; Q y \ \ \P,Q\ x y"
  apply(simp add:pred_lift_exact_def)
  done

lemma pred_exact_split: "\P,Q\ = (\P,\\ And \\,Q\)"
  apply(simp add:pred_lift_exact_def bipred_conj_def)
  done

lemma pred_andE[elim!]: "\ (A and B) x; \ A x; B x \ \ R \ \ R"
  apply(simp add:pred_conj_def)
  done

lemma pred_andI[intro!]: "\ A x; B x \ \ (A and B) x"
  apply(simp add:pred_conj_def)
  done

lemma bipred_conj_app[simp]: "(P And Q) x = (P x and Q x)"
  apply(simp add:pred_conj_def bipred_conj_def)
  done

lemma bipred_disj_app[simp]: "(P Or Q) x = (P x or Q x)"
  apply(simp add:pred_disj_def bipred_disj_def)
  done

lemma pred_conj_app[simp]: "(P and Q) x = (P x \ Q x)"
  apply(simp add:pred_conj_def)
  done

lemma pred_disj_app[simp]: "(P or Q) x = (P x \ Q x)"
  apply(simp add:pred_disj_def)
  done

lemma pred_notnotD[simp]: "(not not P) = P"
  apply(simp add:pred_neg_def)
  done

lemma bipred_notnotD[simp]: "(Not Not P) = P"
  apply(simp add:bipred_neg_def)
  done

lemma pred_lift_add[simp]: "\P,Q\ x = ((\s. P x) and Q)"
  apply(simp add:pred_lift_exact_def pred_conj_def)
  done

lemma pred_and_true[simp]: "(P and \) = P"
  apply(simp add:pred_conj_def)
  done

lemma pred_and_true_var[simp]: "(\ and P) = P"
  apply(simp add:pred_conj_def)
  done

lemma pred_and_false[simp]: "(P and \) = \"
  apply(simp add:pred_conj_def)
  done

lemma pred_and_false_var[simp]: "(\ and P) = \"
  apply(simp add:pred_conj_def)
  done

(* Sequential-composition rules for Hoare triples over do-blocks. In seq'
   and seq, the postcondition B of f must establish both the side condition
   P on the intermediate result and the precondition C of g. *)
lemma seq':
  "\ \A\ f \B\;
     \x. P x \ \C\ g x \D\;
     \x s. B x s \ P x \ C s \ \
   \A\ do x \ f; g x od \D\"
  apply (clarsimp simp: valid_def runState_def bind_def Let_def split_def)
  apply (case_tac "f s")
  apply fastforce
  done

lemma seq:
  assumes f_valid: "\A\ f \B\"
  assumes g_valid: "\x. P x \ \C\ g x \D\"
  assumes bind: "\x s. B x s \ P x \ C s"
  shows "\A\ do x \ f; g x od \D\"
apply (insert f_valid g_valid bind)
apply (blast intro: seq')
done

(* Variants of seq where A is an invariant preserved by f and by g. *)
lemma seq_invar_nobind:
  assumes f_valid: "\A\ f \\\,A\\"
  assumes g_valid: "\x. \A\ g x \\\,A\\"
  shows "\A\ do x \ f; g x od \\\,A\\"
  apply(rule_tac B="\\,A\" and C="A" and P="\" in seq)
  apply(insert f_valid g_valid)
  apply(simp_all add:pred_lift_exact_def)
  done

lemma seq_invar_bind:
  assumes f_valid: "\A\ f \\B,A\\"
  assumes g_valid: "\x. P x \ \A\ g x \\\,A\\"
  assumes bind: "\x. B x \ P x"
  shows "\A\ do x \ f; g x od \\\,A\\"
  apply(rule_tac B="\B,A\" and C="A" and P="P" in seq)
  apply(insert f_valid g_valid bind)
  apply(simp_all add: pred_lift_exact_def)
  done

lemma seq_noimp:
  assumes f_valid: "\A\ f \\C,B\\"
  assumes g_valid: "\x. C x \ \B\ g x \D\"
  shows "\A\ do x \ f; g x od \D\"
  apply(rule_tac B="\C,B\" and C="B" and P="C" in seq)
  apply(insert f_valid g_valid, simp_all add:pred_lift_exact_def)
  done

(* seq_ext: the postcondition of f, specialised to its result x, is used
   directly as the precondition of g x. *)
lemma seq_ext':
  "\ \A\ f \B\;
     \x. \B x\ g x \C\ \ \
   \A\ do x \ f; g x od \C\"
  apply (clarsimp simp: valid_def runState_def bind_def Let_def split_def)
  done

lemma seq_ext:
  assumes f_valid: "\A\ f \B\"
  assumes g_valid: "\x. \B x\ g x \C\"
  shows "\A\ do x \ f; g x od \C\"
  apply(insert f_valid g_valid)
  apply(blast intro: seq_ext')
  done

(* The same rule for the error monad: both parts share the exception
   postcondition E. *)
lemma seqE':
  "\ \A\ f \B\,\E\ ;
     \x. \B x\ g x \C\,\E\ \ \
   \A\ doE x \ f; g x odE \C\,\E\"
  apply(simp add:bindE_def lift_def bind_def Let_def split_def)
  apply(clarsimp simp:validE_def)
  apply(case_tac "fst (f s)", simp_all)
   apply(case_tac a, simp_all)
   apply(fastforce simp:throwError_def return_def)
  apply(clarsimp simp:throwError_def return_def)
  apply(case_tac a, simp_all)
   apply(fastforce)+
  done

lemma seqE:
  assumes f_valid: "\A\ f \B\,\E\"
  assumes g_valid: "\x. \B x\ g x \C\,\E\"
  shows "\A\ doE x \ f; g x odE \C\,\E\"
  apply(insert f_valid g_valid)
  apply(blast intro: seqE')
  done

(* Postconditions for the primitive operations get, put and return. *)
lemma get_sp:
  "\P\ get \\a s. s = a \ P s\"
  apply(simp add:get_def valid_def)
  done

lemma put_sp:
  "\\\ put a \\_ s. s = a\"
  apply(simp add:put_def valid_def)
  done

lemma return_sp:
  "\P\ return a \\b s. b = a \ P s\"
  apply(simp add:return_def valid_def)
  done

(* Structural Hoare rules: combining postconditions and preconditions,
   trivial triples, and rules for return, get, gets, modify, put, if, when
   and unless. *)
lemma hoare_post_conj [intro!]:
  "\ \ P \ a \ Q \; \ P \ a \ R \ \ \ \ P \ a \ Q And R \"
  apply(simp add:valid_def split_def bipred_conj_def)
  done

lemma hoare_pre_disj [intro!]:
  "\ \ P \ a \ R \; \ Q \ a \ R \ \ \ \ P or Q \ a \ R \"
  apply(simp add:valid_def pred_disj_def)
  done

lemma hoare_post_taut [iff]: "\ P \ a \ \\ \"
  apply(simp add:valid_def)
  done

lemma hoare_pre_cont [iff]: "\ \ \ a \ P \"
  apply(simp add:valid_def)
  done

lemma hoare_return [intro!]: "\x. P x \ \ Q \ return x \ \P,Q\ \"
  apply(simp add:valid_def return_def pred_lift_exact_def)
  done

lemma hoare_return_drop [iff]: "\ Q \ return x \ \\,Q\ \"
  apply(simp add:valid_def return_def pred_lift_exact_def)
  done

lemma hoare_return_drop_var [iff]: "\ Q \ return x \ \r. Q \"
  apply(simp add:valid_def return_def pred_lift_exact_def)
  done

lemma hoare_return_only [intro!]: "\x. P x \ \ Q \ return x \ \P,\\ \"
  apply(simp add:valid_def return_def pred_lift_exact_def)
  done

lemma hoare_get [iff]: "\ P \ get \ \P,P\ \"
  apply(simp add:valid_def get_def pred_lift_exact_def)
  done

lemma hoare_gets [intro!]: "\ \s. P s \ Q (f s) s \ \ \ P \ gets f \ Q \"
  apply(simp add:valid_def gets_def get_def bind_def return_def)
  done

lemma hoare_modify [iff]: "\ P o f \ modify f \ \\,P\ \"
  apply(simp add:valid_def modify_def pred_lift_exact_def put_def bind_def get_def)
  done

lemma hoare_modifyE [intro!]: "\ \s. P s \ Q (f s) \ \ \ P \ modify f \ \\,Q\ \"
  apply(simp add:valid_def modify_def pred_lift_exact_def put_def bind_def get_def)
  done

lemma hoare_modifyE_var [intro!]: "\ \s. P s \ Q (f s) \ \ \ P \ modify f \ \r s. Q s \"
  apply(simp add:valid_def modify_def pred_lift_exact_def put_def bind_def get_def)
  done

lemma hoare_put [intro!]: "P x \ \ Q \ put x \ \\,P\\"
  apply(simp add:valid_def put_def pred_lift_exact_def)
  done

lemma hoare_if [intro!]:
  "\ P \ \ Q \ a \ R \; \ P \ \ Q \ b \ R \ \ \ \ Q \ if P then a else b \ R \"
  apply(simp add:valid_def)
  done

lemma hoare_when [intro!]:
  "\ \ P \ \ \ Q \ a \ \\,R\ \; \s. \ \ P; Q s \ \ R s \ \ \ Q \ when P a \ \\,R\ \"
  apply(simp add:valid_def when_def split_def return_def pred_lift_exact_def)
  done

lemma hoare_unless [intro!]:
  "\ \s. \ P; Q s \ \ R s; \ \ P \ \ \ Q \ a \ \\,R\ \ \ \ \ Q \ unless P a \ \\,R\ \"
  apply(simp add:valid_def unless_def split_def when_def return_def pred_lift_exact_def)
  done

(* Rewriting a pre- or postcondition by an equation, and splitting the
   precondition on a case distinction. *)
lemma hoare_pre_subst: "\ A = B; \A\ a \C\ \ \ \B\ a \C\"
  apply(clarsimp simp:valid_def split_def)
  done

lemma hoare_post_subst: "\ B = C; \A\ a \B\ \ \ \A\ a \C\"
  apply(clarsimp simp:valid_def split_def)
  done

lemma hoare_pre_tautI: "\ \A and P\ a \B\; \A and not P\ a \B\ \ \ \A\ a \B\"
  apply(clarsimp simp:valid_def split_def pred_conj_def pred_neg_def, blast)
  done

lemma hoare_return_var[intro!]: "\ \x. P x \ Q x \ \ (\x. P x \ \R\ return x \\Q,R\\)"
  apply(clarsimp simp:valid_def split_def return_def pred_lift_exact_def)
  done

lemma hoare_return_drop_imp[intro!]: "\ \s. P s \ Q s \ \ \P\ return x \\\,Q\\"
  apply(simp add:valid_def return_def)
  done

lemma hoare_case_option_inference: "\ \y. x = Some y \ P x; x = None \ P x \ \ P x"
  apply(case_tac "x", simp_all)
  done

(* Consequence rules: strengthen the precondition, weaken the
   postcondition(s). *)
lemma hoare_pre_imp: "\ \Q\ a \R\; \s. P s \ Q s \ \ \P\ a \R\"
  apply(simp add:valid_def)
  done

lemma hoare_post_imp: "\ \P\ a \Q\; \r s. Q r s \ R r s \ \ \P\ a \R\"
  apply(simp add:valid_def split_def)
  done

lemma hoare_post_impE: "\ \P\ a \Q\,\E\; \r s. Q r s \ R r s; \e s. E e s \ F e s \ \ \P\ a \R\,\F\"
  apply(clarsimp simp:validE_def)
  apply(case_tac aa, simp_all)
   apply(fastforce)+
  done

(* A state-preserving computation (isSkip) preserves every state predicate. *)
lemma "isSkip f \ \ P \ f \ \\,P\ \"
  apply (clarsimp simp: valid_def split_def isSkip_def)
  apply (case_tac "f s")
  apply (erule_tac x=s in allE)
  apply auto
  done

end