aarch64 ainvs: use machine_word for table indices

This replaces 'a word for indices with machine_word. Since we can't use
a specific word length for a generic table index (because different
tables can have different index types), we don't win much by using 'a
word, but we do lose something: we must instantiate 'a when we use the
term, which means we need to decide at that point which type of table
we are talking about. This forces early case distinctions in proofs.

Using machine_word allows us to delay committing to a particular table
type and instead write a generic condition on the width of the index.

We are using machine_word instead of nat or a different specific word
length, because the index into the table is a slice of either an
obj_ref (in ptes_of) or a vref (when we do page table walks), both of
which are compatible with machine_word.

Signed-off-by: Gerwin Klein <gerwin.klein@proofcraft.systems>
This commit is contained in:
Gerwin Klein 2022-05-05 08:54:55 +10:00 committed by Gerwin Klein
parent 8e92e1f702
commit 6c229d7b0d
2 changed files with 93 additions and 149 deletions

View File

@ -294,9 +294,7 @@ primrec valid_pte :: "vm_level \<Rightarrow> pte \<Rightarrow> 'z::state_ext sta
(\<lambda>s. normal_pt_at (ptrFromPAddr base) s \<and> 0 < level)"
definition pt_range :: "pt \<Rightarrow> pte set" where
"pt_range pt \<equiv> case pt of VSRootPT vs \<Rightarrow> range vs | NormalPT pt \<Rightarrow> range pt"
lemmas pt_range_simps[simp] = pt_range_def[split_simps pt.split]
"pt_range pt \<equiv> range (pt_apply pt)"
fun valid_vspace_obj :: "vm_level \<Rightarrow> arch_kernel_obj \<Rightarrow> 'z::state_ext state \<Rightarrow> bool" where
"valid_vspace_obj _ (ASIDPool pool) =
@ -1087,20 +1085,15 @@ lemma pte_at_def2:
"pte_at vsp p = (pt_at vsp (table_base vsp p) and K (is_aligned p pte_bits))"
by (auto simp: pte_at_def level_pte_of_pt)
lemmas pt_apply_def_simps[simp] = pt_apply_def[split_simps pt.split]
lemma level_ptes_of_pts:
"(level_pte_of vsp p (pts_of s) = Some pte) =
(\<exists>pt. pts_of s (table_base vsp p) = Some pt \<and> pt_pte pt p = pte \<and>
(\<exists>pt. pts_of s (table_base vsp p) = Some pt \<and> pt_apply pt (table_index vsp p) = pte \<and>
is_aligned p pte_bits \<and> vsp = is_VSRootPT pt)"
by (clarsimp simp: level_pte_of_def obj_at_def obind_def in_omonad opt_map_def
split: option.splits)
lemmas pt_pte_simps [simp] = pt_pte_def[split_simps pt.split]
definition pt_apply :: "pt \<Rightarrow> machine_word \<Rightarrow> pte" where
"pt_apply pt idx \<equiv> case pt of NormalPT npt \<Rightarrow> npt (ucast idx) | VSRootPT vs \<Rightarrow> vs (ucast idx)"
lemmas pt_apply_def_simps[simp] = pt_apply_def[split_simps pt.split]
lemma ptes_of_Some:
"(ptes_of s vsp p = Some pte) =
(is_aligned p pte_bits \<and>
@ -1168,6 +1161,29 @@ lemma valid_vcpu_default[simp]:
"valid_vcpu default_vcpu s"
by (simp add: valid_vcpu_def default_vcpu_def)
lemma pt_range_Normal[simp]:
"pt_range (NormalPT npt) = range npt"
unfolding pt_range_def
by (force intro: arg_cong[where f=npt] ucast_down_ucast_id[symmetric] simp: is_down)
lemma pt_range_VSRoot[simp]:
"pt_range (VSRootPT vs) = range vs"
unfolding pt_range_def
by (force intro: arg_cong[where f=vs] ucast_down_ucast_id[symmetric] simp: is_down bit_simps)
lemma pt_apply_pt_range[simp, intro!]:
"pt_apply pt idx \<in> pt_range pt"
by (auto simp: pt_range_def)
lemma pt_apply_mask:
"pt_apply pt (idx && mask (ptTranslationBits (is_VSRootPT pt))) = pt_apply pt idx"
by (cases pt; clarsimp simp: bit_simps ucast_mask_drop)
lemma pt_rangeD:
"pte \<in> pt_range pt \<Longrightarrow>
\<exists>idx. idx \<le> mask (ptTranslationBits (is_VSRootPT pt)) \<and> pt_apply pt idx = pte"
by (fastforce simp: pt_range_def intro!: word_and_le1 pt_apply_mask)
lemma wellformed_arch_default[simp]:
"arch_valid_obj (default_arch_object ao_type dev us) s"
unfolding arch_valid_obj_def default_arch_object_def
@ -1826,11 +1842,22 @@ lemma is_aligned_table_base_pte_bits[simp]:
unfolding pte_bits_def
by (simp add: bit_simps is_aligned_neg_mask)
lemma pt_slot_offset_offset:
lemma pt_index_bounded[simp, intro!]:
"pt_index level vref \<le> mask (ptTranslationBits (level = max_pt_level))"
by (simp add: pt_index_def word_and_le1)
lemma table_index_plus:
"\<lbrakk> is_aligned pt_ptr (pt_bits vsp); i \<le> mask (ptTranslationBits vsp) \<rbrakk> \<Longrightarrow>
table_index vsp (pt_ptr + (i << pte_bits)) = i"
unfolding is_aligned_mask bit_simps
apply (cases "vsp \<and> config_ARM_PA_SIZE_BITS_40"; simp only:)
by (subst word_plus_and_or_coroll; word_bitwise; simp add: word_size)+
lemma pt_slot_offset_offset[simp]:
"is_aligned pt (pt_bits (level = max_pt_level)) \<Longrightarrow>
pt_slot_offset level pt vref && mask (pt_bits (level = max_pt_level)) >> pte_bits = pt_index level vref"
by (simp add: pt_slot_offset_def pt_index_def bit_simps mask_add_aligned and_mask_shiftr_comm
word_size and_mask2 mask_twice)
table_index (level = max_pt_level) (pt_slot_offset level pt vref) = pt_index level vref"
unfolding pt_slot_offset_def
by (simp add: table_index_plus)
lemmas pt_slot_offset_minus_eq =
pt_slot_offset_vref_for_level_eq[where level="level - 1" for level, simplified]
@ -1848,52 +1875,6 @@ lemma table_base_plus:
apply (cases "vsp \<and> config_ARM_PA_SIZE_BITS_40"; simp only:)
by (subst word_plus_and_or_coroll; word_bitwise; simp add: word_size)+
lemma table_base_plus_ucast:
"is_aligned pt_ptr (pt_bits vsp) \<Longrightarrow>
table_base vsp (pt_ptr + (ucast (i::pt_index) << pte_bits)) = pt_ptr"
by (fastforce intro!: table_base_plus ucast_leq_mask simp: bit_simps)
lemma table_index_plus:
"\<lbrakk> is_aligned pt_ptr (pt_bits vsp); i \<le> mask (ptTranslationBits vsp) \<rbrakk> \<Longrightarrow>
table_index vsp (pt_ptr + (i << pte_bits)) = ucast i"
unfolding is_aligned_mask bit_simps
apply (cases "vsp \<and> config_ARM_PA_SIZE_BITS_40"; simp only:)
by (subst word_plus_and_or_coroll; word_bitwise; simp add: word_size)+
lemma table_index_plus_ucast:
"\<lbrakk> is_aligned pt_ptr (pt_bits vsp); LENGTH('a::len) = ptTranslationBits vsp \<rbrakk> \<Longrightarrow>
table_index vsp (pt_ptr + (ucast (i::'a word) << pte_bits)) = i"
apply (drule table_index_plus[where i="ucast i" and 'a="'a"])
apply (rule ucast_leq_mask, simp add: bit_simps)
apply (simp add: is_down_def target_size_def source_size_def word_size ucast_down_ucast_id bit_simps)
done
lemma table_index_plus_ucast_pt[simp]:
"is_aligned pt_ptr (pt_bits False) \<Longrightarrow>
table_index False (pt_ptr + (ucast (i::pt_index) << pte_bits)) = i"
by (rule table_index_plus_ucast; simp add: bit_simps)
lemma table_index_plus_ucast_vs[simp]:
"is_aligned pt_ptr (pt_bits True) \<Longrightarrow>
table_index True (pt_ptr + (ucast (i::vs_index) << pte_bits)) = i"
by (rule table_index_plus_ucast; simp add: bit_simps)
lemma table_index_offset_pt_bits_left_vs:
"\<lbrakk> is_aligned pt_ref (pt_bits True); lvl = max_pt_level \<rbrakk> \<Longrightarrow>
(table_index True (pt_slot_offset lvl pt_ref vref)::vs_index) =
ucast (vref >> pt_bits_left lvl)"
by (auto simp: table_index_plus_ucast pt_slot_offset_def pt_index_def vs_index_ptTranslationBits
ucast_ucast_mask[where 'a=vs_index_len, simplified, symmetric])
lemma table_index_offset_pt_bits_left_pt:
"\<lbrakk> is_aligned pt_ref (pt_bits False); lvl \<noteq> max_pt_level \<rbrakk> \<Longrightarrow>
(table_index False (pt_slot_offset lvl pt_ref vref)::pt_index) =
ucast (vref >> pt_bits_left lvl)"
by (simp add: table_index_plus_ucast pt_slot_offset_def pt_index_def
ptTranslationBits_def ucast_ucast_mask[where 'a=pt_index_len, simplified, symmetric])
lemma vs_lookup_slot_level:
"vs_lookup_slot bot_level asid vref s = Some (level, p) \<Longrightarrow>
vs_lookup_slot level asid vref s = Some (level, p)"
@ -2178,12 +2159,11 @@ lemma is_aligned_pt_slot_offset_pte:
unfolding pt_slot_offset_def
by (simp add: is_aligned_add bit_simps is_aligned_weaken is_aligned_shift)
lemma pt_slot_offset_pt_range:
"\<lbrakk> ptes_of s (level = max_pt_level) (pt_slot_offset level pt vref) = Some pte;
pts_of s pt = Some ptable; is_aligned pt (pt_bits (level = max_pt_level)) \<rbrakk>
\<Longrightarrow> pte \<in> pt_range ptable"
by (clarsimp simp: ptes_of_Some pt_apply_def split: pt.split)
by (clarsimp simp: ptes_of_Some)
lemma valid_vspace_objs_strongD:
"\<lbrakk> valid_vspace_objs s;
@ -2432,41 +2412,25 @@ lemma vs_lookup_table_ap_step:
apply (fastforce simp: vs_lookup_table_def vspace_for_pool_def entry_for_pool_def in_omonad)
done
locale_abbrev vref_for_index :: "'a::len word \<Rightarrow> vm_level \<Rightarrow> vspace_ref" where
"vref_for_index idx level \<equiv> ucast idx << pt_bits_left level"
locale_abbrev vref_for_index :: "machine_word \<Rightarrow> vm_level \<Rightarrow> vspace_ref" where
"vref_for_index idx level \<equiv> idx << pt_bits_left level"
locale_abbrev vref_for_level_idx :: "vspace_ref \<Rightarrow> 'a::len word \<Rightarrow> vm_level \<Rightarrow> vspace_ref" where
locale_abbrev vref_for_level_idx :: "vspace_ref \<Rightarrow> machine_word \<Rightarrow> vm_level \<Rightarrow> vspace_ref" where
"vref_for_level_idx vref idx level \<equiv> vref_for_level vref (level+1) || vref_for_index idx level"
lemma pt_index_vref_for_level[simp]:
"\<lbrakk> level \<le> max_pt_level; idx \<le> mask (ptTranslationBits (level = max_pt_level)) \<rbrakk> \<Longrightarrow>
pt_index level (vref_for_level vref (level + 1) || vref_for_index idx level) = idx"
using pt_bits_left_bound[of "level"]
apply (simp add: pt_index_def vref_for_level_def pt_bits_left_bound_def)
apply word_eqI
by (auto simp: bit_simps pt_bits_left_def size_max_pt_level split: if_split_asm)
lemma table_index_pt_slot_offset:
"\<lbrakk> is_aligned p (pt_bits (level = max_pt_level)); level \<le> max_pt_level;
LENGTH('a) = ptTranslationBits (level = max_pt_level) \<rbrakk> \<Longrightarrow>
idx \<le> mask (ptTranslationBits (level = max_pt_level)) \<rbrakk> \<Longrightarrow>
table_index (level = max_pt_level) (pt_slot_offset level p (vref_for_level_idx vref idx level)) = idx"
for idx::"'a::len word"
using pt_bits_left_bound[of "level"]
using pt_bits_left_bound[of "level+1"]
apply (simp add: pt_slot_offset_def pt_index_def vref_for_level_def pt_bits_left_bound_def)
apply (subst word_plus_and_or_coroll)
apply word_eqI
apply (clarsimp simp: bit_simps)
apply word_eqI
apply (clarsimp simp: bit_simps pt_bits_left_def split: if_split_asm)
done
lemma table_index_pt_slot_offset_pt:
"\<lbrakk> is_aligned p (pt_bits False); level \<le> max_pt_level; level \<noteq> max_pt_level \<rbrakk> \<Longrightarrow>
ptable_index (pt_slot_offset level p (vref_for_level_idx vref idx level)) = idx"
for idx :: pt_index
using table_index_pt_slot_offset[where level=level and 'a=pt_index_len]
by (simp add: bit_simps)
lemma table_index_pt_slot_offset_vs:
"\<lbrakk> is_aligned p (pt_bits True); level= max_pt_level \<rbrakk> \<Longrightarrow>
vsroot_index (pt_slot_offset level p (vref_for_level_idx vref idx level)) = idx"
for idx :: vs_index
using table_index_pt_slot_offset[where level=max_pt_level and 'a=vs_index_len]
by (simp add: bit_simps)
by simp
lemma vs_lookup_vref_for_level_eq1:
"vref_for_level vref' (bot_level+1) = vref_for_level vref (bot_level+1) \<Longrightarrow>
@ -2479,10 +2443,9 @@ lemma vs_lookup_vref_for_level_eq1:
done
lemma vref_for_level_idx[simp]:
"\<lbrakk> level \<le> max_pt_level; LENGTH('a) = ptTranslationBits (level = max_pt_level) \<rbrakk> \<Longrightarrow>
"\<lbrakk> level \<le> max_pt_level; idx \<le> mask (ptTranslationBits (level = max_pt_level)) \<rbrakk> \<Longrightarrow>
vref_for_level (vref_for_level_idx vref idx level) (level + 1) =
vref_for_level vref (level + 1)"
for idx :: "'a::len word"
apply (simp add: vref_for_level_def pt_bits_left_def)
apply (rule conjI, clarsimp)
apply (rule conjI; clarsimp; word_eqI_solve simp: bit_simps level_defs dest: bit_imp_possible_bit)
@ -2505,17 +2468,16 @@ lemma vref_for_level_user_regionD:
lemma vref_for_level_idx_canonical_user:
"\<lbrakk> vref \<le> canonical_user; level \<le> max_pt_level;
LENGTH('a) = ptTranslationBits (level = max_pt_level);
level = max_pt_level \<longrightarrow> ucast idx \<notin> invalid_mapping_slots \<rbrakk> \<Longrightarrow>
idx \<le> mask (ptTranslationBits (level = max_pt_level));
level = max_pt_level \<longrightarrow> ucast idx \<notin> invalid_mapping_slots \<rbrakk> \<Longrightarrow>
vref_for_level_idx vref idx level \<le> canonical_user"
for idx :: "'a::len word"
apply (simp add: canonical_user_def le_mask_high_bits ipa_size_def word_size split: if_split_asm)
apply (clarsimp simp: canonical_user_def le_mask_high_bits ipa_size_def word_size split: if_split_asm)
apply (cases "level = max_pt_level")
apply (clarsimp simp: bit_simps pt_bits_left_def size_max_pt_level)
apply (drule bit_imp_possible_bit)
apply (simp add: bit_simps pt_bits_left_def size_max_pt_level asid_pool_level_eq word_size)
apply (frule bit_imp_possible_bit)
apply simp
apply (clarsimp simp: bit_simps pt_bits_left_def size_max_pt_level asid_pool_level_eq)
apply (drule bit_imp_possible_bit)
apply (simp add: bit_simps pt_bits_left_def size_max_pt_level asid_pool_level_eq)
apply (frule bit_imp_possible_bit)
apply simp
apply (drule xt1(11), simp)
apply (subst (asm) vm_level_size_less[symmetric])
@ -2525,10 +2487,10 @@ lemma vref_for_level_idx_canonical_user:
apply (frule bit_imp_possible_bit)
apply simp
apply (prop_tac "idx \<le> mask valid_vs_slot_bits")
apply (simp add: invalid_mapping_slots_def bit_simps le_mask_high_bits word_size)
apply (fastforce simp: invalid_mapping_slots_def bit_simps le_mask_high_bits word_size)
apply (simp add: bit_simps le_mask_high_bits word_size)
apply (clarsimp simp: bit_simps pt_bits_left_def size_max_pt_level asid_pool_level_eq word_size)
apply (drule bit_imp_possible_bit)
apply (simp add: bit_simps pt_bits_left_def size_max_pt_level asid_pool_level_eq word_size)
apply (frule bit_imp_possible_bit)
apply (drule xt1(11), simp)
apply (subst (asm) vm_level_size_less[symmetric])
apply (simp add: size_max_pt_level)
@ -2541,28 +2503,9 @@ lemma vs_lookup_table_pt_step:
valid_pt_range pt; pspace_distinct s \<rbrakk> \<Longrightarrow>
\<exists>vref'. vs_lookup_target level asid vref' s = Some (level, p') \<and>
vref' \<in> user_region"
apply (cases pt; clarsimp)
apply (rename_tac vs idx)
apply (rule_tac x="vref_for_level vref (level+1) ||
(ucast (idx::vs_index) << pt_bits_left level)" in exI)
apply (simp add: vs_lookup_target_def vs_lookup_slot_def in_omonad)
apply (rule conjI)
apply (rule_tac x="pt_slot_offset level p (vref_for_level_idx vref idx level)" in exI)
apply (rule conjI, clarsimp)
apply (rule_tac x=level in exI)
apply (rule_tac x=p in exI)
apply clarsimp
apply (subst vs_lookup_vref_for_level_eq1)
prefer 2
apply assumption
apply (simp add: bit_simps)
apply (fastforce simp: ptes_of_Some in_omonad is_aligned_pt_slot_offset_pte
table_index_pt_slot_offset_vs)
apply (prop_tac "idx \<notin> invalid_mapping_slots", clarsimp simp: valid_pt_range_def)
apply (simp add: user_region_def vref_for_level_idx_canonical_user bit_simps)
apply (rename_tac ptn idx)
apply (rule_tac x="vref_for_level vref (level+1) ||
(ucast (idx::pt_index) << pt_bits_left level)" in exI)
apply (drule pt_rangeD, erule exE)
apply (rename_tac idx)
apply (rule_tac x="vref_for_level_idx vref idx level" in exI)
apply (simp add: vs_lookup_target_def vs_lookup_slot_def in_omonad)
apply (rule conjI)
apply (rule_tac x="pt_slot_offset level p (vref_for_level_idx vref idx level)" in exI)
@ -2575,10 +2518,12 @@ lemma vs_lookup_table_pt_step:
apply (subst vs_lookup_vref_for_level_eq1)
prefer 2
apply assumption
apply (simp add: bit_simps)
apply (fastforce simp: ptes_of_Some in_omonad is_aligned_pt_slot_offset_pte
table_index_pt_slot_offset_pt)
apply (simp add: user_region_def vref_for_level_idx_canonical_user bit_simps)
apply simp
apply (clarsimp simp: ptes_of_Some in_omonad is_aligned_pt_slot_offset_pte
table_index_pt_slot_offset)
apply (prop_tac "is_VSRootPT pt \<longrightarrow> ucast idx \<notin> invalid_mapping_slots")
apply (cases pt; clarsimp simp: valid_pt_range_def)
apply (simp add: user_region_def vref_for_level_idx_canonical_user)
done
lemma pte_rights_PagePTE[simp]:

View File

@ -71,20 +71,16 @@ text \<open>The base address of the table a page table entry at p is in (assumin
locale_abbrev table_base :: "bool \<Rightarrow> obj_ref \<Rightarrow> obj_ref" where
"table_base is_vspace p \<equiv> p && ~~mask (pt_bits is_vspace)"
text \<open>The index within the page table that a page table entry at p addresses\<close>
locale_abbrev table_index :: "bool \<Rightarrow> obj_ref \<Rightarrow> 'a::len word" where
"table_index is_vspace p \<equiv> ucast (p && mask (pt_bits is_vspace) >> pte_bits)"
text \<open>The index within the page table that a page table entry at p addresses. We return a
@{typ machine_word}, which is the slice of the provided address that represents the index in the
table of the specified table type.\<close>
locale_abbrev table_index :: "bool \<Rightarrow> obj_ref \<Rightarrow> machine_word" where
"table_index is_vspace p \<equiv> p && mask (pt_bits is_vspace) >> pte_bits"
locale_abbrev vsroot_index :: "obj_ref \<Rightarrow> vs_index" where
"vsroot_index \<equiv> table_index True"
locale_abbrev ptable_index :: "obj_ref \<Rightarrow> pt_index" where
"ptable_index \<equiv> table_index False"
definition pt_pte :: "pt \<Rightarrow> obj_ref \<Rightarrow> pte" where
"pt_pte pt p \<equiv> case pt of
VSRootPT vs \<Rightarrow> vs (vsroot_index p)
| NormalPT pt \<Rightarrow> pt (ptable_index p)"
text \<open>Use an index computed by @{const table_index} and apply it to a page table. Bits higher than
the table index width will be ignored.\<close>
definition pt_apply :: "pt \<Rightarrow> machine_word \<Rightarrow> pte" where
"pt_apply pt idx \<equiv> case pt of NormalPT npt \<Rightarrow> npt (ucast idx) | VSRootPT vs \<Rightarrow> vs (ucast idx)"
text \<open>Extract a PTE from the page table of a specific level\<close>
definition level_pte_of :: "bool \<Rightarrow> obj_ref \<Rightarrow> (obj_ref \<rightharpoonup> pt) \<rightharpoonup> pte" where
@ -92,7 +88,7 @@ definition level_pte_of :: "bool \<Rightarrow> obj_ref \<Rightarrow> (obj_ref \<
oassert (is_aligned p pte_bits);
pt \<leftarrow> oapply (table_base is_vspace p);
oassert (is_vspace = is_VSRootPT pt);
oreturn $ pt_pte pt p
oreturn $ pt_apply pt (table_index is_vspace p)
}"
locale_abbrev ptes_of :: "'z::state_ext state \<Rightarrow> bool \<Rightarrow> obj_ref \<rightharpoonup> pte" where
@ -103,17 +99,20 @@ text \<open>The following function takes a pointer to a PTE in kernel memory and
locale_abbrev get_pte :: "bool \<Rightarrow> obj_ref \<Rightarrow> (pte,'z::state_ext) s_monad" where
"get_pte is_vspace \<equiv> gets_map (swp ptes_of is_vspace)"
definition pt_upd :: "pt \<Rightarrow> obj_ref \<Rightarrow> pte \<Rightarrow> pt" where
"pt_upd pt p pte \<equiv> case pt of
VSRootPT vs \<Rightarrow> VSRootPT (vs(vsroot_index p := pte))
| NormalPT pt \<Rightarrow> NormalPT (pt(ptable_index p := pte))"
text \<open>The update function that corresponds to @{const pt_apply}. Also expects an index computed
with @{const table_index} for the correct page table type.\<close>
definition pt_upd :: "pt \<Rightarrow> machine_word \<Rightarrow> pte \<Rightarrow> pt" where
"pt_upd pt idx pte \<equiv> case pt of
VSRootPT vs \<Rightarrow> VSRootPT (vs(ucast idx := pte))
| NormalPT pt \<Rightarrow> NormalPT (pt(ucast idx := pte))"
definition store_pte :: "bool \<Rightarrow> obj_ref \<Rightarrow> pte \<Rightarrow> (unit,'z::state_ext) s_monad" where
"store_pte is_vspace p pte \<equiv> do
assert (is_aligned p pte_bits);
base \<leftarrow> return $ table_base is_vspace p;
pt \<leftarrow> get_pt base;
set_pt base (pt_upd pt p pte)
set_pt base (pt_upd pt (table_index is_vspace p) pte)
od"