diff --git a/Hiding_Type_Variables.thy b/Hiding_Type_Variables.thy index 341f636..9eefc1b 100644 --- a/Hiding_Type_Variables.thy +++ b/Hiding_Type_Variables.thy @@ -54,6 +54,7 @@ signature HIDE_TVAR = sig type hide_varT = { name: string, tvars: typ list, + typ_syn_tab : (string * typ list) Symtab.table, print_mode: print_mode, parse_mode: parse_mode } @@ -64,7 +65,7 @@ signature HIDE_TVAR = sig val update_mode : string -> print_mode option -> parse_mode option -> theory -> theory val lookup : theory -> string -> hide_varT option - val hide_tvar_tr' : string -> Proof.context -> typ -> term list -> term + val hide_tvar_tr' : string -> Proof.context -> term list -> term val hide_tvar_ast_tr : Proof.context -> Ast.ast list -> Ast.ast end @@ -74,11 +75,13 @@ structure Hide_Tvar : HIDE_TVAR = struct type hide_varT = { name: string, tvars: typ list, + typ_syn_tab : (string * typ list) Symtab.table, print_mode: print_mode, parse_mode: parse_mode } type hide_tvar_tab = (hide_varT) Symtab.table - fun merge_assert_tab (tab,tab') = Symtab.merge (op =) (tab,tab') + fun hide_tvar_eq (a, a') = (#name a) = (#name a') + fun merge_assert_tab (tab,tab') = Symtab.merge hide_tvar_eq (tab,tab') structure Data = Generic_Data ( @@ -117,10 +120,11 @@ structure Hide_Tvar : HIDE_TVAR = struct SOME m => m | NONE => #parse_mode old_entry val entry = { - name = name, - tvars = #tvars old_entry, - print_mode = print_m, - parse_mode = parse_m + name = name, + tvars = #tvars old_entry, + typ_syn_tab = #typ_syn_tab old_entry, + print_mode = print_m, + parse_mode = parse_m } in Symtab.update (name,entry) tab @@ -136,23 +140,119 @@ structure Hide_Tvar : HIDE_TVAR = struct Symtab.lookup tab name end - fun hide_tvar_tr' tname ctx typ terms = + fun obtain_normalized_vname lookup_table vname = + case List.find (fn e => fst e = vname) lookup_table of + SOME (_,idx) => (lookup_table, Int.toString idx) + | NONE => let + fun max_idx [] = 0 + | max_idx ((_,idx)::lt) = Int.max(idx,max_idx lt) + + val idx = (max_idx lookup_table ) + 1 + in + ((vname,idx)::lookup_table, Int.toString idx) end + + fun normalize_typvar_type lt (Type (a, Ts)) = + let + fun switch (a,b) = (b,a) + val (Ts', lt') = fold_map (fn t => fn lt => switch (normalize_typvar_type lt t)) Ts lt + in + (lt', Type (a, Ts')) + end + | normalize_typvar_type lt (TFree (vname, S)) = + let + val (lt, vname) = obtain_normalized_vname lt (vname) + in + (lt, TFree( vname, S)) + end + | normalize_typvar_type lt (TVar (xi, S)) = + let + val (lt, vname) = obtain_normalized_vname lt (Term.string_of_vname xi) + in + (lt, TFree( vname, S)) + end + + fun normalize_typvar_type' t = snd ( normalize_typvar_type [] t) + + fun mk_p s = s (* "("^s^")" *) + + fun key_of_type (Type(a, TS)) = mk_p (a^String.concat(map key_of_type TS)) + | key_of_type (TFree (vname, _)) = mk_p vname + | key_of_type (TVar (xi, _ )) = error("TVar not supported in key_of_type: "^ + (Term.string_of_vname xi)) + val key_of_type' = key_of_type o normalize_typvar_type' + + + fun normalize_typvar_term lt (Const (a, t)) = (lt, Const(a, t)) + | normalize_typvar_term lt (Free (a, t)) = let + val (lt, vname) = obtain_normalized_vname lt a + in + (lt, Free(vname,t)) + end + | normalize_typvar_term lt (Var (xi, t)) = + let + val (lt, vname) = obtain_normalized_vname lt (Term.string_of_vname xi) + in + (lt, Free(vname,t)) + end + | normalize_typvar_term lt (Bound (i)) = (lt, Bound(i)) + | normalize_typvar_term lt (Abs(s,ty,tr)) = + let + val (lt,tr) = normalize_typvar_term lt tr + in + (lt, Abs(s,ty,tr)) + end + | normalize_typvar_term lt (t1$t2) = + let + val (lt,t1) = normalize_typvar_term lt t1 + val (lt,t2) = normalize_typvar_term lt t2 + in + (lt, t1$t2) + end + + + fun normalize_typvar_term' t = snd(normalize_typvar_term [] t) + + fun key_of_term (Const(s,_)) = if String.isPrefix "\<^type>" s + then Lexicon.unmark_type s + else "" + | key_of_term (Free(s,_)) = s + | key_of_term (Var(_,_)) = error("Var() not supported in key_of_term") + | key_of_term (Bound(_)) = error("Bound() not supported in key_of_term") + | key_of_term (Abs(_,_,_)) = error("Abs() not supported in key_of_term") + | key_of_term (t1$t2) = (key_of_term t1)^(key_of_term t2) + + val key_of_term' = key_of_term o normalize_typvar_term' + + + fun hide_tvar_tr' tname ctx terms = let + val mtyp = Syntax.parse_typ ctx tname (* no type checking *) - val fq_name = case mtyp of - Type(s,_) => s + + val (fq_name,_) = case mtyp of + Type(s,ts) => (s,ts) | _ => error("Complex type not (yet) supported.") + val local_name_of = hd o rev o String.fields (fn c => c = #".") - val local_tname = local_name_of tname - val hide_type = Syntax.const("(_) "^(local_tname)) - val reg_type = Term.list_comb(Const(local_tname,typ),terms) + + fun hide_type tname = Syntax.const("(_) "^tname) + + val reg_type_as_term = Term.list_comb(Const(Lexicon.mark_type tname,dummyT),terms) + val key = key_of_term' reg_type_as_term + in case lookup (Proof_Context.theory_of ctx) fq_name of - NONE => reg_type - | SOME e => case (#print_mode e) of - always => hide_type - | default_only => hide_type (* TODO *) - | noprint => reg_type + NONE => raise Match + | SOME e => let + val tname = case Symtab.lookup (#typ_syn_tab e) key of + NONE => local_name_of tname + | SOME (s,_) => local_name_of s + in + case (#print_mode e) of + always => hide_type tname + | default_only => hide_type tname (* TODO *) + | noprint => raise Match + end end fun hide_tvar_ast_tr ctx (a::_)= @@ -184,6 +284,13 @@ structure Hide_Tvar : HIDE_TVAR = struct val typ = Syntax.parse_typ ctx typ_str val (name,tvars) = case typ of Type(name,tvars) => (name,tvars) | _ => error("Unsupported type structure.") + + val base_typ = Syntax.read_typ ctx typ_str + val (base_name,base_tvars) = case base_typ of Type(name,tvars) => (name,tvars) + | _ => error("Unsupported type structure.") + + val base_key = key_of_type' base_typ + val print_m = case print_mode of SOME m => m | NONE => always @@ -193,12 +300,45 @@ structure Hide_Tvar : HIDE_TVAR = struct val entry = { name = name, tvars = tvars, + typ_syn_tab = Symtab.empty:((string * typ list) Symtab.table), print_mode = print_m, parse_mode = parse_m } - fun reg tab = Symtab.update_new(name, entry) tab - val thy = Sign.typed_print_translation - [("\<^type>"^name, hide_tvar_tr' name)] thy + + val base_entry = if name = base_name + then + { + name = "", + tvars = [], + typ_syn_tab = Symtab.empty:((string * typ list) Symtab.table), + print_mode = noprint, + parse_mode = noparse + } + else case lookup thy base_name of + SOME e => e + | NONE => error ("No entry found for "^base_name^ + " (via "^name^")") + + val base_entry = { + name = #name base_entry, + tvars = #tvars base_entry, + typ_syn_tab = Symtab.update (base_key, (name, base_tvars)) + (#typ_syn_tab (base_entry)), + print_mode = #print_mode base_entry, + parse_mode = #parse_mode base_entry + } + + fun reg tab = let + val tab = Symtab.update_new(name, entry) tab + val tab = if name = base_name + then tab + else Symtab.update(base_name, base_entry) tab + in + tab + end + + val thy = Sign.print_translation + [(Lexicon.mark_type name, hide_tvar_tr' name)] thy in Context.theory_of ( (Data.map reg) (Context.Theory thy)) @@ -241,7 +381,7 @@ ML\ section\Examples\ subsection\Print Translation\ -datatype ('alpha, 'beta) foobar = foo 'alpha | bar 'beta +datatype ('a, 'b) foobar = foo 'a | bar 'b type_synonym ('a, 'b, 'c, 'd) baz = "('a+'b, 'a \ 'b) foobar" definition f::"('a, 'b) foobar \ ('a, 'b) foobar \ ('a, 'b) foobar" @@ -261,7 +401,7 @@ register_default_tvars "('alpha, 'beta, 'gamma, 'delta) baz" (always,active) update_default_tvars_mode "_ foobar" (noprint,noparse) assert[string_of_thm_equal, - thm_def="f_def", + thm_def="f_def", str="f (a::('a, 'b) foobar) (b::('a, 'b) foobar) = a"] assert[string_of_thm_equal, thm_def="g_def", @@ -272,7 +412,7 @@ update_default_tvars_mode "_ foobar" (always,noparse) assert[string_of_thm_equal, thm_def="f_def", str="f (a::(_) foobar) (b::(_) foobar) = a"] assert[string_of_thm_equal, - thm_def="g_def", str="g (a::(_) foobar) (b::(_) foobar) = a"] + thm_def="g_def", str="g (a::(_) baz) (b::(_) baz) = a"] subsection\Parse Translation\ update_default_tvars_mode "_ foobar" (noprint,active)