Skip to content

Commit

Permalink
Add bmv_monad_of_typ function
Browse files Browse the repository at this point in the history
  • Loading branch information
jvanbruegge committed Jan 4, 2025
1 parent 0a350f8 commit 85f1c48
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 82 deletions.
210 changes: 132 additions & 78 deletions Tools/bmv_monad_def.ML
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ signature BMV_MONAD_DEF = sig
val bmv_monad_def: BNF_Def.inline_policy -> (Proof.context -> BNF_Def.fact_policy)
-> (binding -> binding) -> (Proof.context -> tactic) bmv_monad_model -> local_theory -> (bmv_monad * thm list) * local_theory

val compose_bmv_monad: (binding -> binding) -> bmv_monad -> bmv_monad list -> local_theory
val compose_bmv_monad: (binding -> binding) -> bmv_monad -> (bmv_monad, typ) MRBNF_Util.either list -> local_theory
-> (bmv_monad * thm list) * local_theory
end

Expand Down Expand Up @@ -160,7 +160,7 @@ fun morph_bmv_monad_param phi f ({ Map, Supps, axioms, Map_Sb, Supp_Sb, Map_Vrs
Supp_Sb = map f Supp_Sb,
Map_Vrs = map (map (Option.map f)) Map_Vrs
}: 'b bmv_monad_param;

datatype bmv_monad = BMV of {
ops: typ list,
bd: term,
Expand Down Expand Up @@ -551,7 +551,7 @@ fun define_bmv_monad_consts const_policy fact_policy qualify (model: 'a bmv_mona
val (lthy, old_lthy) = `Local_Theory.end_nested lthy;
val phi = Proof_Context.export_morphism old_lthy lthy;

val vars = #frees model @ #lives model @ #lives' model;
val vars = map TFree (rev (Term.add_tfreesT (nth (#ops model) (#leader model)) [])) @ #lives' model;
val subst = (map (Morphism.typ phi) vars ~~ vars);

val phi' = Morphism.term_morphism "bmv_monad_export" (Term.subst_atomic_types subst o Morphism.term phi)
Expand Down Expand Up @@ -716,50 +716,75 @@ fun register_bnf_as_pbmv_monad name lthy =
- does not appear in the codomain of any (=of any **other** SOp) Injection,
*)

fun compose_bmv_monad qualify (outer : bmv_monad) (inners : bmv_monad list) lthy =
fun compose_bmv_monad qualify (outer : bmv_monad) (inners : (bmv_monad, typ) either list) lthy =
let
val _ = if length (lives_of_bmv_monad outer) <> length inners then
error "Outer needs exactly as many lives as there are inners" else ()

val filter_bmvs = map_filter (fn Inl x => SOME x | _ => NONE);

val frees = fold (fn a =>
let val (n, s) = dest_TFree a
in Symtab.map_default (n, s) (curry (Sign.inter_sort (Proof_Context.theory_of lthy)) s) end
) (frees_of_bmv_monad outer @ maps frees_of_bmv_monad (filter_bmvs inners)) Symtab.empty;

fun mk_sign_morph bmv =
morph_bmv_monad (MRBNF_Util.subst_typ_morphism (map (fn a =>
let val (n, _) = dest_TFree a;
in (a, TFree (n, the (Symtab.lookup frees n))) end
) (frees_of_bmv_monad bmv))) bmv;
fun mk_T_morph T =
let val vars = Term.add_tfreesT T [];
in Term.typ_subst_atomic (map (fn x =>
(TFree x, the_default (TFree x) (Option.map (TFree o pair (fst x)) (Symtab.lookup frees (fst x))))
) vars) T end
val outer = mk_sign_morph outer;
val inners = map (map_sum mk_sign_morph mk_T_morph) inners;

val bmvs = Typtab.make_distinct (flat (map (fn bmv => (#ops bmv ~~
((#params bmv) ~~ (#Injs bmv) ~~ (#Sbs bmv) ~~ (#Vrs bmv) ~~ map SOME (#axioms bmv) ~~ replicate (length (#Sbs bmv)) (SOME bmv))
)) (map Rep_bmv inners)));
)) (map_filter (fn Inl bmv => SOME (Rep_bmv bmv) | Inr _ => NONE) inners)));

val outer_ops' = map (fn T => if Typtab.defined bmvs T then NONE else SOME T) (
map (Term.typ_subst_atomic (lives_of_bmv_monad outer ~~ map (fn bmv =>
nth (ops_of_bmv_monad bmv) (leader_of_bmv_monad bmv)
map (Term.typ_subst_atomic (lives_of_bmv_monad outer ~~ map (
fn Inl bmv => nth (ops_of_bmv_monad bmv) (leader_of_bmv_monad bmv)
| Inr T => T
) inners)) (ops_of_bmv_monad outer)
);

val ((Sbs, Injs), Vrs) = apfst split_list (split_list (@{map 5} (fn NONE => K (K (K (K ((NONE, NONE), NONE)))) | SOME T => (fn NONE => K (K (K ((NONE, NONE), NONE)))
| SOME param => fn Sb => fn Injs => fn Vrs =>
let
val ((Sbs, Ts), (Injss, Vrsss)) = apfst split_list (apsnd split_list (split_list (map (fn bmv =>
val ((Sbs, Ts), (Injss, Vrsss)) = apfst split_list (apsnd split_list (split_list (map (fn Inl bmv =>
let fun pick xs = nth xs (leader_of_bmv_monad bmv)
in (
(pick (Sbs_of_bmv_monad bmv), pick (ops_of_bmv_monad bmv)),
(pick (Injs_of_bmv_monad bmv), pick (Vrs_of_bmv_monad bmv))
(SOME (pick (Sbs_of_bmv_monad bmv)), pick (ops_of_bmv_monad bmv)),
(SOME (pick (Injs_of_bmv_monad bmv)), SOME (pick (Vrs_of_bmv_monad bmv)))
) end
| Inr T => ((NONE, T), (NONE, NONE))
) inners)));
val subst = (lives_of_bmv_monad outer @ lives'_of_bmv_monad outer) ~~ (Ts @ Ts);
val Injs' = distinct ((op=) o apply2 fastype_of) (Injs @ flat Injss);
val Injs' = distinct ((op=) o apply2 fastype_of) (Injs @ flat (map_filter I Injss));
val ((fs, x), _) = lthy
|> mk_Frees "f" (map fastype_of Injs')
||>> apfst hd o mk_Frees "x" [T];

val Vrs' = @{fold 4} (fn i => fn inner => @{fold 2} (fn Inj => fn Vrs => fn tab =>
case Typtab.lookup tab (fastype_of Inj) of
NONE => tab
| SOME inner_tab =>
let val inner_tab' = @{fold 2} (fn NONE => K I | SOME Vrs => fn free =>
Typtab.map_default (free, [(i, Vrs)]) (cons (i, Vrs))
) Vrs (frees_of_bmv_monad inner) inner_tab;
in Typtab.update (fastype_of Inj, inner_tab') tab end
)) (0 upto length inners) (outer :: inners) (Injs :: Injss) (Vrs :: Vrsss) (Typtab.make (map (rpair Typtab.empty o fastype_of) Injs'));
val Vrs' = @{fold 4} (fn i => fn inner => fn Injs => fn Vrss => fn tab => case inner of
Inr _ => tab
| Inl inner => @{fold 2} (fn Inj => fn Vrs => fn tab =>
case Typtab.lookup tab (fastype_of Inj) of
NONE => tab
| SOME inner_tab =>
let val inner_tab' = @{fold 2} (fn NONE => K I | SOME Vrs => fn free =>
Typtab.map_default (free, [(i, Vrs)]) (cons (i, Vrs))
) Vrs (frees_of_bmv_monad inner) inner_tab;
in Typtab.update (fastype_of Inj, inner_tab') tab end
) (the Injs) (the Vrss) tab
) (0 upto length inners) (Inl outer :: inners) (SOME Injs :: Injss) (SOME Vrs :: Vrsss) (Typtab.make (map (rpair Typtab.empty o fastype_of) Injs'));

val frees = distinct (op=) (maps snd (Typtab.dest (Typtab.map (K Typtab.keys) Vrs')));
val Supps = map (Term.subst_atomic_types subst) (#Supps param);

val Supps = map (Term.subst_atomic_types subst) (#Supps param)
val Vrs' = map (fn Inj => map (fn free => Option.mapPartial (fn xs =>
let
val Vrss = distinct (op=) (rev xs);
Expand All @@ -779,12 +804,15 @@ fun compose_bmv_monad qualify (outer : bmv_monad) (inners : bmv_monad list) lthy
val find_fs = map (fn Inj =>
the (List.find (fn f => fastype_of f = fastype_of Inj) fs)
);
fun mk_comp t = if true orelse length (binder_types (fastype_of Sb)) > 1 then
HOLogic.mk_comp (t, Term.list_comb (Sb, find_fs Injs))
else t
in ((
SOME (Term.subst_atomic_types subst (
fold_rev (Term.absfree o dest_Free) fs (HOLogic.mk_comp (
Term.list_comb (#Map param,
map2 (fn Sb => fn Injs => Term.list_comb (Sb, find_fs Injs)) Sbs Injss
), Term.list_comb (Sb, find_fs Injs)
fold_rev (Term.absfree o dest_Free) fs (mk_comp (
Term.list_comb (#Map param, @{map 3} (fn Inr T => K (K (HOLogic.id_const T))
| _ => fn Sb => fn Injs => Term.list_comb (the Sb, find_fs (the Injs))
) inners Sbs Injss)
))
)),
SOME Injs'),
Expand All @@ -806,7 +834,7 @@ fun compose_bmv_monad qualify (outer : bmv_monad) (inners : bmv_monad list) lthy
let val bmvs' = Typtab.delete T bmvs
in (SOME (add_ops T Injs bmvs'), bmvs') end
end
) Injs bmvs)))
) Injs bmvs)));

fun pick xs = nth xs (leader_of_bmv_monad outer)
val ops = add_ops (the (pick outer_ops')) (the (pick Injs)) bmvs;
Expand All @@ -831,10 +859,12 @@ fun compose_bmv_monad qualify (outer : bmv_monad) (inners : bmv_monad list) lthy

val ops' = subtract (fn (bmv, T) => hd (ops_of_bmv_monad bmv) = T) bmv_ops ops;

val inners' = filter_bmvs inners;

val idxs = map (fn T => find_index (curry (op=) T) ops) ops';
val Vrs = map (the o nth Vrs) idxs;
val Injs = map (the o nth Injs) idxs;
val frees = distinct (op=) (maps frees_of_bmv_monad (outer :: inners));
val frees = distinct (op=) (maps frees_of_bmv_monad (outer :: inners'));
val outer_Vrs = map (nth (Vrs_of_bmv_monad outer)) idxs;

val model = {
Expand All @@ -844,58 +874,61 @@ fun compose_bmv_monad qualify (outer : bmv_monad) (inners : bmv_monad list) lthy
bd_infinite_regular_card_order = fn ctxt => rtac ctxt (bd_infinite_regular_card_order_of_bmv_monad outer) 1,
var_class = var_class_of_bmv_monad outer,
frees = frees,
lives = distinct (op=) (maps lives_of_bmv_monad inners),
lives' = distinct (op=) (maps lives'_of_bmv_monad inners),
lives = distinct (op=) (maps lives_of_bmv_monad inners'),
lives' = distinct (op=) (maps lives'_of_bmv_monad inners'),
params = replicate (length ops') NONE,
leader = 0,
Injs = Injs,
Sbs = map (the o nth Sbs) idxs,
Vrs = Vrs,
tacs = @{map 5} (fn axioms => fn param => fn Injs => fn Vrs => fn outer_Vrs => {
Sb_Inj = fn ctxt => EVERY [
Local_Defs.unfold0_tac ctxt (#Sb_Inj axioms :: @{thms o_id}),
Local_Defs.unfold0_tac ctxt (
#Map_id (#axioms param)
:: maps (map #Sb_Inj o axioms_of_bmv_monad) inners
),
rtac ctxt refl 1
Sb_Inj = fn ctxt => EVERY1 [
rtac ctxt trans,
rtac ctxt @{thm arg_cong2[OF _ refl, of _ _ "(\<circ>)"]},
rtac ctxt ext,
rtac ctxt (trans OF [#Map_cong (#axioms param), #Map_id (#axioms param) RS fun_cong]),
REPEAT_DETERM o resolve_tac ctxt (refl :: maps (map (fn ax =>
#Sb_Inj ax RS fun_cong
) o axioms_of_bmv_monad) inners'),
rtac ctxt @{thm trans[OF id_o]},
rtac ctxt (#Sb_Inj axioms)
],
Sb_comp_Injs = map (fn thm => fn ctxt =>
print_tac ctxt "Sb_comp_Inj"
) (#Sb_comp_Injs axioms),
Sb_comp = fn ctxt => EVERY1 [
K (Local_Defs.unfold0_tac ctxt @{thms comp_assoc}),
rtac ctxt @{thm trans[OF comp_assoc]},
rtac ctxt trans,
rtac ctxt @{thm arg_cong2[OF refl, of _ _ "(\<circ>)"]},
SELECT_GOAL (Local_Defs.unfold0_tac ctxt @{thms comp_assoc[symmetric]}),
rtac ctxt @{thm trans[OF comp_assoc[symmetric]]},
rtac ctxt trans,
rtac ctxt @{thm arg_cong2[OF _ refl, of _ _ "(\<circ>)"]},
rtac ctxt (#Map_Sb param RS sym),
REPEAT_DETERM o assume_tac ctxt,
SELECT_GOAL (Local_Defs.unfold0_tac ctxt @{thms comp_assoc}),
rtac ctxt @{thm trans[OF comp_assoc]},
rtac ctxt @{thm arg_cong2[OF refl, of _ _ "(\<circ>)"]},
rtac ctxt (#Sb_comp axioms),
REPEAT_DETERM o assume_tac ctxt,
K (Local_Defs.unfold0_tac ctxt @{thms comp_assoc[symmetric]}),
rtac ctxt @{thm trans[OF comp_assoc[symmetric]]},
rtac ctxt @{thm arg_cong2[OF _ refl, of _ _ "(\<circ>)"]},
rtac ctxt trans,
rtac ctxt (#Map_comp (#axioms param)),
rtac ctxt ext,
rtac ctxt (#Map_cong (#axioms param)),
REPEAT_DETERM o EVERY' [
EqSubst.eqsubst_tac ctxt [0] (maps (map #Sb_comp o axioms_of_bmv_monad) inners),
EqSubst.eqsubst_tac ctxt [0] (maps (map #Sb_comp o axioms_of_bmv_monad) inners'),
REPEAT_DETERM o assume_tac ctxt,
rtac ctxt refl
]
],
Vrs_bds = map (map (Option.map (K (fn ctxt => EVERY1 [
REPEAT_DETERM o resolve_tac ctxt (
map (fn thm =>
maps (map_filter I) (#Vrs_bds axioms)
@ maps (maps (maps (map_filter I) o #Vrs_bds) o axioms_of_bmv_monad) inners'
@ #Supp_bd (#axioms param)
@ map (fn thm =>
thm OF [bd_infinite_regular_card_order_of_bmv_monad outer]
) @{thms infinite_regular_card_order_Un infinite_regular_card_order_UN}
@ maps (map_filter I) (#Vrs_bds axioms)
@ maps (maps (maps (map_filter I) o #Vrs_bds) o axioms_of_bmv_monad) inners
@ #Supp_bd (#axioms param)
)
])))) Vrs,
Vrs_Injs = map (map (Option.map (fn thm => fn ctxt =>
Expand All @@ -904,37 +937,46 @@ fun compose_bmv_monad qualify (outer : bmv_monad) (inners : bmv_monad list) lthy
Vrs_Sbs = map (map (Option.map (K (fn ctxt => EVERY1 [
K (Local_Defs.unfold0_tac ctxt @{thms UN_Un}),
REPEAT_DETERM o rtac ctxt @{thm arg_cong2[of _ _ _ _ "(\<union>)"]},
TRY o EVERY' [
EqSubst.eqsubst_tac ctxt [0] [#Map_Sb param],
REPEAT_DETERM1 o assume_tac ctxt,
SELECT_GOAL (Local_Defs.unfold0_tac ctxt @{thms comp_def}),
rtac ctxt trans,
resolve_tac ctxt (maps (map_filter I) (#Vrs_Sbs axioms)),
REPEAT_DETERM1 o assume_tac ctxt,
EqSubst.eqsubst_tac ctxt [0] (maps (map_filter I) (#Map_Vrs param)),
rtac ctxt refl
],
K (Local_Defs.unfold0_tac ctxt (@{thms comp_def} @ #Supp_Map (#axioms param) @ #Supp_Sb param)),
K (Local_Defs.unfold0_tac ctxt (@{thms image_comp[unfolded comp_def] UN_UN_flatten}
@ maps (maps (maps (map_filter I) o #Vrs_Sbs) o axioms_of_bmv_monad) inners
)),
REPEAT_DETERM o rtac ctxt refl
REPEAT_DETERM o FIRST' [
EVERY' [
rtac ctxt @{thm trans[OF arg_cong[OF comp_apply]]},
rtac ctxt trans,
resolve_tac ctxt (maps (map_filter I) (#Map_Vrs param)),
rtac ctxt trans,
resolve_tac ctxt (maps (map_filter I) (#Vrs_Sbs axioms)),
REPEAT_DETERM o assume_tac ctxt,
rtac ctxt refl
],
EVERY' [
rtac ctxt trans,
rtac ctxt @{thm arg_cong[of _ _ "\<lambda>x. \<Union>(_ ` x)"]},
rtac ctxt trans,
rtac ctxt @{thm trans[OF arg_cong[OF comp_apply]]},
resolve_tac ctxt (#Supp_Map (#axioms param)),
rtac ctxt @{thm arg_cong[of _ _ "\<lambda>x. _ ` x"]},
resolve_tac ctxt (#Supp_Sb param),
SELECT_GOAL (Local_Defs.unfold0_tac ctxt @{thms UN_simps(10)}),
rtac ctxt trans,
rtac ctxt @{thm UN_cong},
resolve_tac ctxt (maps (maps (maps (map_filter I) o #Vrs_Sbs) o axioms_of_bmv_monad) inners'),
REPEAT_DETERM o assume_tac ctxt,
SELECT_GOAL (Local_Defs.unfold0_tac ctxt @{thms UN_extend_simps(9)}),
rtac ctxt refl
]
]
])))) Vrs,
Sb_cong = fn ctxt => EVERY1 [
K (Local_Defs.unfold0_tac ctxt @{thms comp_def}),
rtac ctxt @{thm comp_apply_eq},
Subgoal.FOCUS_PREMS (fn {context=ctxt, prems, ...} => EVERY1 [
EqSubst.eqsubst_tac ctxt [0] [Drule.rotate_prems (~(length (
maps (map_filter I) outer_Vrs
))) (#Sb_cong axioms)],
resolve_tac ctxt prems,
etac ctxt @{thm contrapos_pp},
SELECT_GOAL (Local_Defs.unfold0_tac ctxt @{thms Un_iff de_Morgan_disj}),
REPEAT_DETERM o etac ctxt conjE,
assume_tac ctxt,
REPEAT_DETERM o resolve_tac ctxt prems,
rtac ctxt @{thm trans[rotated]},
rtac ctxt (
let val n = length (lives_of_bmv_monad outer);
in mk_arg_cong lthy (n + 1) (#Map param) OF (replicate n refl) end
),
K (prefer_tac 2),
rtac ctxt (#Map_cong (#axioms param)),
K (Local_Defs.unfold0_tac ctxt (#Supp_Sb param)),
EVERY' (map (fn inner => EVERY' [
EVERY' (map (fn Inr _ => rtac ctxt refl | Inl inner => EVERY' [
resolve_tac ctxt (map #Sb_cong (axioms_of_bmv_monad inner)),
REPEAT_DETERM o EVERY' [
REPEAT_DETERM o resolve_tac ctxt prems,
Expand All @@ -946,7 +988,17 @@ fun compose_bmv_monad qualify (outer : bmv_monad) (inners : bmv_monad list) lthy
assume_tac ctxt,
assume_tac ctxt
]
]) inners)
]) inners),
rtac ctxt (#Sb_cong axioms),
REPEAT_DETERM o EVERY' [
resolve_tac ctxt prems,
TRY o EVERY' [
etac ctxt @{thm contrapos_pp},
SELECT_GOAL (Local_Defs.unfold0_tac ctxt @{thms Un_iff de_Morgan_disj}),
REPEAT_DETERM o etac ctxt conjE,
assume_tac ctxt
]
]
]) ctxt
]
} : (Proof.context -> tactic) bmv_monad_axioms)
Expand All @@ -958,12 +1010,14 @@ fun compose_bmv_monad qualify (outer : bmv_monad) (inners : bmv_monad list) lthy
val (res, lthy) = bmv_monad_def BNF_Def.Smart_Inline (K BNF_Def.Dont_Note) qualify model lthy
in (res, lthy) end;

fun pbmv_monad_cmd (((((((b, ops), frees), Sbs), Injs), Vrs), param_opt), bd) lthy =
fun pbmv_monad_cmd ((((((b, ops), Sbs), Injs), Vrs), param_opt), bd) lthy =
let
val ops = map (Syntax.read_typ lthy) ops;
val bd = Syntax.read_term lthy bd;
val frees = map (Syntax.read_typ lthy) frees;
val Sbs = map (Syntax.read_term lthy) Sbs;
val frees = distinct (op=) (maps (
map (fst o dest_funT) o fst o split_last o binder_types o fastype_of
) Sbs);
val Injs = map (map (Syntax.read_term lthy)) Injs;
val Vrs = map (map (map (fn "_" => NONE | t => SOME (Syntax.read_term lthy t)))) Vrs;

Expand Down Expand Up @@ -1029,7 +1083,7 @@ fun pbmv_monad_cmd (((((((b, ops), frees), Sbs), Injs), Vrs), param_opt), bd) lt

fun print_pbmv_monads ctxt =
let
fun pretty_mrbnf (key, BMV {ops, frees, lives, bd, ...}) =
fun pretty_mrbnf (key, BMV {ops, frees, lives, bd, Sbs, ...}) =
Pretty.big_list
(Pretty.string_of (Pretty.block ([Pretty.str key, Pretty.str ":", Pretty.brk 1] @
map (Pretty.quote o Syntax.pretty_typ ctxt) ops)))
Expand All @@ -1039,7 +1093,8 @@ fun print_pbmv_monads ctxt =
[Pretty.block [Pretty.str "live:", Pretty.brk 1, Pretty.str (string_of_int (length lives)),
Pretty.brk 3, Pretty.list "[" "]" (List.map (Syntax.pretty_typ ctxt) lives)]]
else []) @
[Pretty.block [Pretty.str ("bd:"), Pretty.brk 1,
[ Pretty.block ([Pretty.str "Sb:", Pretty.brk 1] @ map (Pretty.quote o Syntax.pretty_term ctxt) Sbs),
Pretty.block [Pretty.str ("bd:"), Pretty.brk 1,
Pretty.quote (Syntax.pretty_term ctxt bd)]]);
in
Pretty.big_list "Registered parametrized bounded multi-variate monads:"
Expand All @@ -1054,8 +1109,7 @@ val _ =

val _ = Outer_Syntax.local_theory_to_proof @{command_keyword pbmv_monad}
"register a parametrized bounded multi-variate monad"
(parse_opt_binding_colon -- Scan.repeat1 (Scan.unless (Parse.reserved "frees") Parse.typ) --|
(Parse.reserved "frees" -- @{keyword ":"}) -- Scan.repeat1 (Scan.unless (Parse.reserved "Sbs") Parse.typ) --|
(parse_opt_binding_colon -- Scan.repeat1 (Scan.unless (Parse.reserved "Sbs") Parse.typ) --|
(Parse.reserved "Sbs" -- @{keyword ":"}) -- Scan.repeat1 (Scan.unless (Parse.reserved "Injs") Parse.term) --|
(Parse.reserved "Injs" -- @{keyword ":"}) -- Parse.list (Scan.repeat1 (Scan.unless (Parse.reserved "Vrs") Parse.term)) --|
(Parse.reserved "Vrs" -- @{keyword ":"}) -- Parse.and_list (Parse.list (
Expand Down
Loading

0 comments on commit 85f1c48

Please sign in to comment.