lib/monads: style cleanup in MonadEq+MonadEq_Lemmas

Style and proof contraction.

Signed-off-by: Gerwin Klein <gerwin.klein@proofcraft.systems>
This commit is contained in:
Gerwin Klein 2023-01-30 09:52:50 +11:00
parent 4f44b1ce7e
commit b0da6b3ee9
No known key found for this signature in database
GPG Key ID: 20A847CE6AB7F5F3
2 changed files with 35 additions and 71 deletions

View File

@ -54,24 +54,22 @@ method_setup monad_eq = \<open>
Method.sections Clasimp.clasimp_modifiers >> (K (SIMPLE_METHOD o monad_eq_tac))\<close>
"prove equality on monads"
lemma monad_eq_simp_state [monad_eq]:
lemma monad_eq_simp_state[monad_eq]:
"((A :: ('s, 'a) nondet_monad) s = B s') =
((\<forall>r t. (r, t) \<in> fst (A s) \<longrightarrow> (r, t) \<in> fst (B s'))
\<and> (\<forall>r t. (r, t) \<in> fst (B s') \<longrightarrow> (r, t) \<in> fst (A s))
\<and> (snd (A s) = snd (B s')))"
apply (auto intro!: set_eqI prod_eqI)
done
by (auto intro!: set_eqI prod_eqI)
lemma monad_eq_simp [monad_eq]:
lemma monad_eq_simp[monad_eq]:
"((A :: ('s, 'a) nondet_monad) = B) =
((\<forall>r t s. (r, t) \<in> fst (A s) \<longrightarrow> (r, t) \<in> fst (B s))
\<and> (\<forall>r t s. (r, t) \<in> fst (B s) \<longrightarrow> (r, t) \<in> fst (A s))
\<and> (\<forall>x. snd (A x) = snd (B x)))"
apply (auto intro!: set_eqI prod_eqI)
done
by (auto intro!: set_eqI prod_eqI)
declare in_monad [monad_eq]
declare in_bindE [monad_eq]
declare in_monad[monad_eq]
declare in_bindE[monad_eq]
(* Test *)
lemma "returnOk 3 = liftE (return 3)"

View File

@ -59,11 +59,8 @@ lemma fst_return:
by (simp add: return_def)
lemma in_bind_split[monad_eq]:
"(rv \<in> fst ((f >>= g) s)) =
(\<exists>rv'. rv' \<in> fst (f s) \<and> rv \<in> fst (g (fst rv') (snd rv')))"
apply (cases rv)
apply (fastforce simp add: in_bind)
done
"(rv \<in> fst ((f >>= g) s)) = (\<exists>rv'. rv' \<in> fst (f s) \<and> rv \<in> fst (g (fst rv') (snd rv')))"
by (cases rv) (fastforce simp: in_bind)
lemma Inr_in_liftE_simp[monad_eq]:
"((Inr rv, x) \<in> fst (liftE fn s)) = ((rv, x) \<in> fst (fn s))"
@ -71,8 +68,7 @@ lemma Inr_in_liftE_simp[monad_eq]:
lemma gets_the_member:
"(x, s') \<in> fst (gets_the f s) = (f s = Some x \<and> s' = s)"
by (case_tac "f s", simp_all add: gets_the_def
simpler_gets_def bind_def in_assert_opt)
by (cases "f s"; simp add: gets_the_def simpler_gets_def bind_def in_assert_opt)
lemma fst_throwError_returnOk:
"fst (throwError e s) = {(Inl e, s)}"
@ -100,10 +96,7 @@ declare in_assert_opt[monad_eq]
lemma not_snd_bindD':
"\<lbrakk>\<not> snd ((a >>= b) s); \<not> snd (a s) \<Longrightarrow> (rv, s') \<in> fst (a s)\<rbrakk> \<Longrightarrow> \<not> snd (a s) \<and> \<not> snd (b rv s')"
apply (frule not_snd_bindI1)
apply (erule not_snd_bindD)
apply simp
done
by (metis not_snd_bindI1 not_snd_bindI2)
lemma snd_bind[monad_eq]:
"snd ((a >>= b) s) = (snd (a s) \<or> (\<exists>r s'. (r, s') \<in> fst (a s) \<and> snd (b r s')))"
@ -113,21 +106,18 @@ lemma snd_bind[monad_eq]:
lemma in_lift[monad_eq]:
"(rv, s') \<in> fst (lift M v s) =
(case v of Inl x \<Rightarrow> rv = Inl x \<and> s' = s
| Inr x \<Rightarrow> (rv, s') \<in> fst (M x s))"
apply (clarsimp simp: lift_def throwError_def return_def split: sum.splits)
done
(case v of Inl x \<Rightarrow> rv = Inl x \<and> s' = s
| Inr x \<Rightarrow> (rv, s') \<in> fst (M x s))"
by (clarsimp simp: lift_def throwError_def return_def split: sum.splits)
lemma snd_lift[monad_eq]:
"snd (lift M a b) = (\<exists>x. a = Inr x \<and> snd (M x b))"
apply (clarsimp simp: lift_def throwError_def return_def split: sum.splits)
done
by (clarsimp simp: lift_def throwError_def return_def split: sum.splits)
lemma snd_bindE[monad_eq]:
"snd ((a >>=E b) s) = (snd (a s) \<or> (\<exists>r s'. (r, s') \<in> fst (a s) \<and> (\<exists>a. r = Inr a \<and> snd (b a s'))))"
apply (clarsimp simp: bindE_def)
apply monad_eq
done
unfolding bindE_def
by monad_eq
lemma snd_get[monad_eq]:
"snd (get s) = False"
@ -141,41 +131,25 @@ lemma in_handleE'[monad_eq]:
"((rv, s') \<in> fst ((f <handle2> g) s)) =
((\<exists>ex. rv = Inr ex \<and> (Inr ex, s') \<in> fst (f s)) \<or>
(\<exists>rv' s''. (rv, s') \<in> fst (g rv' s'') \<and> (Inl rv', s'') \<in> fst (f s)))"
apply (clarsimp simp: handleE'_def)
apply (rule iffI)
apply (subst (asm) in_bind_split)
apply (clarsimp simp: return_def split: sum.splits)
apply (case_tac a)
apply (erule allE, erule (1) impE)
apply clarsimp
apply (erule allE, erule (1) impE)
apply clarsimp
apply (subst in_bind_split)
apply (clarsimp simp: return_def split: sum.splits)
apply blast
done
unfolding handleE'_def return_def
by (simp add: in_bind_split) (fastforce split: sum.splits)
lemma in_handleE[monad_eq]:
"(a, b) \<in> fst ((A <handle> B) s) =
((\<exists>x. a = Inr x \<and> (Inr x, b) \<in> fst (A s)) \<or>
(\<exists>r t. (Inl r, t) \<in> fst (A s) \<and> (a, b) \<in> fst (B r t)))"
apply (unfold handleE_def)
apply (monad_eq split: sum.splits)
apply blast
done
unfolding handleE_def
by (monad_eq split: sum.splits) blast
lemma snd_handleE'[monad_eq]:
"snd ((A <handle2> B) s) = (snd (A s) \<or> (\<exists>r s'. (r, s')\<in>fst (A s) \<and> (\<exists>a. r = Inl a \<and> snd (B a s'))))"
apply (clarsimp simp: handleE'_def)
apply (monad_eq simp: Bex_def split: sum.splits)
apply (metis sum.sel(1) sum.distinct(1) sumE)
done
unfolding handleE'_def
by (monad_eq simp: Bex_def split: sum.splits) fastforce
lemma snd_handleE[monad_eq]:
"snd ((A <handle> B) s) = (snd (A s) \<or> (\<exists>r s'. (r, s')\<in>fst (A s) \<and> (\<exists>a. r = Inl a \<and> snd (B a s'))))"
apply (unfold handleE_def)
apply (rule snd_handleE')
done
unfolding handleE_def
by (rule snd_handleE')
declare in_liftE[monad_eq]
@ -192,11 +166,11 @@ lemma snd_when[monad_eq]:
by (clarsimp simp: when_def return_def)
lemma in_condition[monad_eq]:
"((a, b) \<in> fst (condition C L R s)) = ((C s \<longrightarrow> (a, b) \<in> fst (L s)) \<and> (\<not> C s \<longrightarrow> (a, b) \<in> fst (R s)))"
"(a, b) \<in> fst (condition C L R s) = ((C s \<longrightarrow> (a, b) \<in> fst (L s)) \<and> (\<not>C s \<longrightarrow> (a, b) \<in> fst (R s)))"
by (rule condition_split)
lemma snd_condition[monad_eq]:
"(snd (condition C L R s)) = ((C s \<longrightarrow> snd (L s)) \<and> (\<not> C s \<longrightarrow> snd (R s)))"
"snd (condition C L R s) = ((C s \<longrightarrow> snd (L s)) \<and> (\<not>C s \<longrightarrow> snd (R s)))"
by (rule condition_split)
declare snd_fail [simp]
@ -205,35 +179,27 @@ declare snd_returnOk [simp, monad_eq]
lemma in_catch[monad_eq]:
"(r, t) \<in> fst ((M <catch> E) s)
= ((Inr r, t) \<in> fst (M s)
\<or> (\<exists>r' s'. ((Inl r', s') \<in> fst (M s)) \<and> (r, t) \<in> fst (E r' s')))"
apply (rule iffI)
apply (clarsimp simp: catch_def in_bind in_return split: sum.splits)
= ((Inr r, t) \<in> fst (M s) \<or> (\<exists>r' s'. ((Inl r', s') \<in> fst (M s)) \<and> (r, t) \<in> fst (E r' s')))"
apply (rule iffI; clarsimp simp: catch_def in_bind in_return split: sum.splits)
apply (metis sumE)
apply (clarsimp simp: catch_def in_bind in_return split: sum.splits)
apply (metis sum.sel(1) sum.distinct(1) sum.inject(2))
apply fastforce
done
lemma snd_catch[monad_eq]:
"snd ((M <catch> E) s)
= (snd (M s)
\<or> (\<exists>r' s'. ((Inl r', s') \<in> fst (M s)) \<and> snd (E r' s')))"
apply (rule iffI)
apply (clarsimp simp: catch_def snd_bind snd_return split: sum.splits)
apply (clarsimp simp: catch_def snd_bind snd_return split: sum.splits)
apply force
done
= (snd (M s) \<or> (\<exists>r' s'. ((Inl r', s') \<in> fst (M s)) \<and> snd (E r' s')))"
by (force simp: catch_def snd_bind snd_return split: sum.splits)
declare in_get[monad_eq]
lemma returnOk_cong: "\<lbrakk> \<And>s. B a s = B' a s \<rbrakk> \<Longrightarrow> ((returnOk a) >>=E B) = ((returnOk a) >>=E B')"
lemma returnOk_cong:
"\<lbrakk> \<And>s. B a s = B' a s \<rbrakk> \<Longrightarrow> ((returnOk a) >>=E B) = ((returnOk a) >>=E B')"
by monad_eq
lemma in_state_assert [monad_eq, simp]:
"(rv, s') \<in> fst (state_assert P s) = (rv = () \<and> s' = s \<and> P s)"
apply (monad_eq simp: state_assert_def)
apply metis
done
by (monad_eq simp: state_assert_def)
metis
lemma snd_state_assert[monad_eq]:
"snd (state_assert P s) = (\<not> P s)"