1(*  Title:      Pure/sorts.ML
2    Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
3
4The order-sorted algebra of type classes.
5
6Classes denote (possibly empty) collections of types that are
7partially ordered by class inclusion. They are represented
8symbolically by strings.
9
10Sorts are intersections of finitely many classes. They are represented
11by lists of classes.  Normal forms of sorts are sorted lists of
12minimal classes (wrt. current class inclusion).
13*)
14
(*Interface of the order-sorted algebra: ordered lists of sorts, the abstract
  algebra type with its class and sort relations, operations that build,
  merge and project algebras, and sort-checking of types.*)
signature SORTS =
sig
  (*ordered lists of sorts*)
  val make: sort list -> sort Ord_List.T
  val subset: sort Ord_List.T * sort Ord_List.T -> bool
  val union: sort Ord_List.T -> sort Ord_List.T -> sort Ord_List.T
  val subtract: sort Ord_List.T -> sort Ord_List.T -> sort Ord_List.T
  val remove_sort: sort -> sort Ord_List.T -> sort Ord_List.T
  val insert_sort: sort -> sort Ord_List.T -> sort Ord_List.T
  val insert_typ: typ -> sort Ord_List.T -> sort Ord_List.T
  val insert_typs: typ list -> sort Ord_List.T -> sort Ord_List.T
  val insert_term: term -> sort Ord_List.T -> sort Ord_List.T
  val insert_terms: term list -> sort Ord_List.T -> sort Ord_List.T
  (*order-sorted algebra: components, class relations, sort relations*)
  type algebra
  val classes_of: algebra -> serial Graph.T
  val arities_of: algebra -> (class * sort list) list Symtab.table
  val all_classes: algebra -> class list
  val super_classes: algebra -> class -> class list
  val class_less: algebra -> class * class -> bool
  val class_le: algebra -> class * class -> bool
  val sort_eq: algebra -> sort * sort -> bool
  val sort_le: algebra -> sort * sort -> bool
  val sorts_le: algebra -> sort list * sort list -> bool
  val inter_sort: algebra -> sort * sort -> sort
  val minimize_sort: algebra -> sort -> sort
  val complete_sort: algebra -> sort -> sort
  (*build, merge, destruct and project algebras*)
  val add_class: Context.generic -> class * class list -> algebra -> algebra
  val add_classrel: Context.generic -> class * class -> algebra -> algebra
  val add_arities: Context.generic -> string * (class * sort list) list -> algebra -> algebra
  val empty_algebra: algebra
  val merge_algebra: Context.generic -> algebra * algebra -> algebra
  val dest_algebra: algebra list -> algebra ->
    {classrel: (class * class list) list,
     arities: (string * sort list * class) list}
  val subalgebra: Context.generic -> (class -> bool) -> (class * string -> sort list option)
    -> algebra -> (sort -> sort) * algebra
  (*sorts of types; errors are reported as structured class_error values*)
  type class_error
  val class_error: Context.generic -> class_error -> string
  exception CLASS_ERROR of class_error
  val has_instance: algebra -> string -> sort -> bool
  val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
  val meet_sort: algebra -> typ * sort
    -> sort Vartab.table -> sort Vartab.table   (*exception CLASS_ERROR*)
  val meet_sort_typ: algebra -> typ * sort -> typ -> typ   (*exception CLASS_ERROR*)
  val of_sort: algebra -> typ * sort -> bool
  (*derivations driven by user-supplied callbacks*)
  val of_sort_derivation: algebra ->
    {class_relation: typ -> bool -> 'a * class -> class -> 'a,
     type_constructor: string * typ list -> ('a * class) list list -> class -> 'a,
     type_variable: typ -> ('a * class) list} ->
    typ * sort -> 'a list   (*exception CLASS_ERROR*)
  val classrel_derivation: algebra ->
    ('a * class -> class -> 'a) -> 'a * class -> class -> 'a  (*exception CLASS_ERROR*)
  val witness_sorts: algebra -> string list -> (typ * sort) list -> sort list -> (typ * sort) list
end;
68
69structure Sorts: SORTS =
70struct
71
72
73(** ordered lists of sorts **)
74
(*The generic Ord_List operations, all instantiated with the term ordering
  on sorts.*)
local
  val ord = Term_Ord.sort_ord;
in

val make = Ord_List.make ord;
val subset = Ord_List.subset ord;
val union = Ord_List.union ord;
val subtract = Ord_List.subtract ord;

val remove_sort = Ord_List.remove ord;
val insert_sort = Ord_List.insert ord;

end;
82
(*Insert the sorts of all type variables (free and schematic) occurring in
  a type, resp. in a list of types, into an ordered list of sorts.*)
fun insert_typ T Ss =
  (case T of
    Type (_, Ts) => insert_typs Ts Ss
  | TFree (_, S) => insert_sort S Ss
  | TVar (_, S) => insert_sort S Ss)
and insert_typs Ts Ss = fold insert_typ Ts Ss;
88
(*Insert the sorts occurring in the types of a term: the types of constants,
  free and schematic variables, and of abstractions; bound variables carry
  no type.  For an application the argument is traversed before the
  function part.*)
fun insert_term tm Ss =
  (case tm of
    t $ u => insert_term t (insert_term u Ss)
  | Abs (_, T, t) => insert_term t (insert_typ T Ss)
  | Bound _ => Ss
  | Const (_, T) => insert_typ T Ss
  | Free (_, T) => insert_typ T Ss
  | Var (_, T) => insert_typ T Ss);
95
(*Insert the sorts of several terms, processing the list from left to right.*)
fun insert_terms ts Ss = fold insert_term ts Ss;
98
99
100
101(** order-sorted algebra **)
102
103(*
104  classes: graph representing class declarations together with proper
105    subclass relation, which needs to be transitive and acyclic.
106
107  arities: table of association lists of all type arities; (t, ars)
108    means that type constructor t has the arities ars; an element
109    (c, Ss) of ars represents the arity t::(Ss)c.  "Coregularity" of
110    the arities structure requires that for any two declarations
111    t::(Ss1)c1 and t::(Ss2)c2 such that c1 <= c2 holds Ss1 <= Ss2.
112*)
113
(*An algebra consists of the class graph (each node carries a serial number,
  edges are consulted for the proper subclass test) and a table that maps
  each type constructor to its list of arities.*)
datatype algebra = Algebra of
 {classes: serial Graph.T,
  arities: (class * sort list) list Symtab.table};
117
(*Projections to the two components of an algebra.*)
fun classes_of (Algebra algebra) = #classes algebra;
fun arities_of (Algebra algebra) = #arities algebra;
120
(*Constructor from a pair of components.*)
fun make_algebra (classes, arities) =
  Algebra {arities = arities, classes = classes};

(*Apply f to one component, leaving the other one unchanged.*)
fun map_classes f algebra = make_algebra (f (classes_of algebra), arities_of algebra);
fun map_arities f algebra = make_algebra (classes_of algebra, f (arities_of algebra));
126
127
128(* classes *)
129
(*All classes of the algebra: every predecessor of the maximal nodes of the
  class graph.*)
fun all_classes algebra =
  let val classes = classes_of algebra
  in Graph.all_preds classes (Graph.maximals classes) end;
131
(*Immediate successors of a class in the class graph.*)
fun super_classes algebra = Graph.immediate_succs (classes_of algebra);
133
134
135(* class relations *)
136
(*Proper subclass test: an edge of the class graph.*)
fun class_less algebra (rel: class * class) = Graph.is_edge (classes_of algebra) rel;

(*Reflexive subclass test.*)
fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
139
140
141(* sort relations *)
142
(*S1 <= S2: the sorts are syntactically equal, or every class of S2 lies
  above some class of S1.*)
fun sort_le algebra (S1: sort, S2: sort) =
  let
    fun implied c2 = exists (fn c1 => class_le algebra (c1, c2)) S1;
  in S1 = S2 orelse forall implied S2 end;
145
(*Pointwise sort_le on two lists of sorts, via ListPair.all.*)
fun sorts_le algebra (Ss1, Ss2) =
  let fun le pair = sort_le algebra pair
  in ListPair.all le (Ss1, Ss2) end;
148
(*Sort equivalence: sort_le in both directions.*)
fun sort_eq algebra (S1, S2) =
  let fun le pair = sort_le algebra pair
  in le (S1, S2) andalso le (S2, S1) end;
151
152
153(* intersection *)
154
(*Intersect class c with sort S.  The list is scanned from the left: on the
  first class below c the remaining list is returned as it is; classes that
  lie above c are dropped; c itself is appended if the end is reached.  The
  order of the two tests matters when a class equals c.*)
fun inter_class algebra c S =
  let
    fun le rel = class_le algebra rel;
    fun intr [] = [c]
      | intr (rest as d :: ds) =
          if le (d, c) then rest
          else if le (c, d) then intr ds
          else d :: intr ds;
  in intr S end;
163
(*Intersection of two sorts: fold the classes of S1 into S2, then sort the
  resulting list of class names.*)
fun inter_sort algebra (S1, S2) =
  S2 |> fold (inter_class algebra) S1 |> sort_strings;
166
167
168(* normal forms *)
169
(*Normal form of a sort: drop every class that has a proper subclass within
  the same sort, then sort and remove duplicates.  Empty and singleton sorts
  are returned unchanged.*)
fun minimize_sort algebra S =
  (case S of
    [] => []
  | [_] => S
  | _ =>
      let fun redundant c = exists (fn c' => class_less algebra (c', c)) S
      in sort_distinct string_ord (filter_out redundant S) end);
175
(*All successors in the class graph of the classes of the minimized sort.*)
fun complete_sort algebra S =
  Graph.all_succs (classes_of algebra) (minimize_sort algebra S);
178
179
180
181(** build algebras **)
182
183(* classes *)
184
(*Error for a class that is declared twice.*)
fun err_dup_class c =
  let val msg = "Duplicate declaration of class: " ^ quote c
  in error msg end;
186
(*Error for cycles in the class relation: one message line per cycle.*)
fun err_cyclic_classes context css =
  let
    fun message cs =
      "Cycle in class relation: " ^ Syntax.string_of_classrel (Syntax.init_pretty context) cs;
  in error (cat_lines (map message css)) end;
190
(*Declare class c with superclasses cs: add a fresh node (tagged with a new
  serial number) to the class graph, then add an edge from c to each
  superclass.  A duplicate node or a cycle is turned into an error.*)
fun add_class context (c, cs) =
  let
    fun declare classes =
      Graph.new_node (c, serial ()) classes
        handle Graph.DUP dup => err_dup_class dup;
    fun relate classes =
      fold (fn c' => Graph.add_edge_trans_acyclic (c, c')) cs classes
        handle Graph.CYCLES css => err_cyclic_classes context css;
  in map_classes (relate o declare) end;
198
199
200(* arities *)
201
202local
203
(*Optional message fragment that names the class relation involved in an
  arity conflict.*)
fun for_classes ctxt cc =
  (case cc of
    SOME (c1, c2) => " for classes " ^ Syntax.string_of_classrel ctxt [c1, c2]
  | NONE => "");
206
(*Error for two conflicting arities of type constructor t; cc optionally
  gives the class relation that exposes the conflict.*)
fun err_conflict context t cc (c, Ss) (c', Ss') =
  let
    val ctxt = Syntax.init_pretty context;
    fun arity (d, Ds) = Syntax.string_of_arity ctxt (t, Ds, [d]);
  in
    error (String.concat
      ["Conflict of type arities", for_classes ctxt cc, ":\n  ",
        arity (c, Ss), " and\n  ", arity (c', Ss')])
  end;
213
(*Add arity (c, Ss) of type constructor t in front of ars, after checking it
  against every existing arity (c', Ss'): whenever the classes are related by
  class_le, the argument sorts must be related the same way by sorts_le.
  The first violation found is reported as an error.*)
fun coregular context algebra t (ar as (c, Ss)) ars =
  let
    fun le rel = class_le algebra rel;
    fun incompatible sorts = not (sorts_le algebra sorts);
    fun conflict (ar' as (c', Ss')) =
      if le (c, c') andalso incompatible (Ss, Ss') then SOME ((c, c'), ar')
      else if le (c', c) andalso incompatible (Ss', Ss) then SOME ((c', c), ar')
      else NONE;
  in
    (case get_first conflict ars of
      NONE => ar :: ars
    | SOME (rel, ar') => err_conflict context t (SOME rel) ar ar')
  end;
227
(*Repeat an arity for class c and each of its immediate superclasses, with
  the same argument sorts.*)
fun complete algebra (c, Ss) =
  map (fn c' => (c', Ss)) (c :: super_classes algebra c);
229
(*Insert arity (c, Ss) of type constructor t into ars.  If class c has no
  entry yet, the arity is added after the coregularity check.  If it has an
  entry with argument sorts Ss': ars is kept when sorts_le (Ss, Ss') holds;
  the old entry is replaced when only sorts_le (Ss', Ss) holds; otherwise the
  two arities conflict.*)
fun insert context algebra t (ar as (c, Ss)) ars =
  let
    fun add ars' = coregular context algebra t ar ars';
  in
    (case AList.lookup (op =) ars c of
      SOME Ss' =>
        if sorts_le algebra (Ss, Ss') then ars
        else if sorts_le algebra (Ss', Ss) then add (remove (op =) (c, Ss') ars)
        else err_conflict context t NONE ar (c, Ss')
    | NONE => add ars)
  end;
238
239in
240
(*Insert a list of arities of type constructor t, starting from the last one.*)
fun insert_ars context algebra t new_ars ars =
  fold_rev (insert context algebra t) new_ars ars;
242
(*Complete each given arity of t wrt. the superclasses of its class, insert
  the results into the arities already recorded for t, and update the table.*)
fun insert_complete_ars context algebra (t, ars) arities =
  let
    val completed = map (complete algebra) ars;
    val ars' =
      fold_rev (insert_ars context algebra t) completed (Symtab.lookup_list arities t);
  in Symtab.update (t, ars') arities end;
248
(*Add arities arg = (t, ars) to the algebra, completed wrt. that algebra.*)
fun add_arities context arg algebra =
  map_arities (insert_complete_ars context algebra arg) algebra;
251
(*Insert all entries of an arities table, each one completed wrt. the given
  algebra, into another arities table.*)
fun add_arities_table context algebra tab =
  Symtab.fold (insert_complete_ars context algebra) tab;
254
255end;
256
257
258(* classrel *)
259
(*Recompute the arities table from scratch: every entry is inserted again,
  completed wrt. the given algebra, starting from the empty table.*)
fun rebuild_arities context algebra =
  let
    fun rebuild arities = add_arities_table context algebra arities Symtab.empty;
  in map_arities rebuild algebra end;
263
(*Add a class relation as an edge of the class graph (a cycle becomes an
  error), then rebuild the arities wrt. the extended graph.*)
fun add_classrel context rel =
  let
    fun extend classes =
      Graph.add_edge_trans_acyclic rel classes
        handle Graph.CYCLES css => err_cyclic_classes context css;
  in rebuild_arities context o map_classes extend end;
267
268
269(* empty and merge *)
270
(*The algebra with no classes and no arities.*)
val empty_algebra = Algebra {classes = Graph.empty, arities = Symtab.empty};
272
(*Merge two algebras.  The class graphs are always merged; duplicate nodes
  and cycles are turned into errors.  Pointer equality of the components
  decides how much work the arities need: nothing if both components are
  shared, a plain join of the tables if only the classes are shared, and
  otherwise re-insertion with completion wrt. the merged classes.*)
fun merge_algebra context
   (Algebra {classes = classes1, arities = arities1},
    Algebra {classes = classes2, arities = arities2}) =
  let
    val classes' =
      Graph.merge_trans_acyclic (op =) (classes1, classes2)
        handle Graph.DUP c => err_dup_class c
          | Graph.CYCLES css => err_cyclic_classes context css;
    val algebra0 = make_algebra (classes', Symtab.empty);
    fun complete_into arities tab = add_arities_table context algebra0 arities tab;

    val same_classes = pointer_eq (classes1, classes2);
    val same_arities = pointer_eq (arities1, arities2);
    val arities' =
      if same_classes then
        if same_arities then arities1
        else  (*no completion*)
          Symtab.join (fn t => fn (ars1, ars2) =>
            if pointer_eq (ars1, ars2) then raise Symtab.SAME
            else insert_ars context algebra0 t ars2 ars1) (arities1, arities2)
      else if same_arities then  (*unary completion*)
        Symtab.empty |> complete_into arities1
      else  (*binary completion*)
        Symtab.empty |> complete_into arities1 |> complete_into arities2;
  in make_algebra (classes', arities') end;
296
297
298(* destruct *)
299
(*Destruct an algebra into sorted lists of its class relations and arities,
  omitting whatever is already present in one of the parent algebras.*)
fun dest_algebra parents (Algebra {classes, arities}) =
  let
    (*class relations: per class, the sorted successors not known to a parent*)
    fun known_classrel rel = exists (fn algebra => class_less algebra rel) parents;
    fun collect_classrel (c, (_, (_, ds))) =
      (case filter_out (fn d => known_classrel (c, d)) (Graph.Keys.dest ds) of
        [] => I
      | ds' => cons (c, sort_strings ds'));
    val classrel = Graph.fold collect_classrel classes [] |> sort_by #1;

    (*arities: entries that no parent has for the same type constructor*)
    fun known_arity t ar =
      exists (fn algebra => member (op =) (Symtab.lookup_list (arities_of algebra) t) ar) parents;
    fun collect_arity t (c, Ss) = if known_arity t (c, Ss) then I else cons (t, Ss, c);
    fun collect_arities (t, ars) = fold_rev (collect_arity t) (sort_by #1 ars);
    val arities' = Symtab.fold collect_arities arities [] |> sort_by #1;
  in {classrel = classrel, arities = arities'} end;
316
317
318(* algebra projections *)  (* FIXME potentially violates abstract type integrity *)
319
(*Project an algebra to the classes that satisfy P.  An arity of class c and
  type constructor t survives only if P c holds and sargs (c, t) yields
  argument sorts; these are intersected with the recorded ones and then
  restricted.  Returns the sort restriction function together with the
  projected algebra, whose arities are rebuilt.*)
fun subalgebra context P sargs (algebra as Algebra {classes, arities}) =
  let
    fun restrict_sort S = minimize_sort algebra (filter P (Graph.all_succs classes S));
    fun restrict_arity t (c, Ss) =
      if not (P c) then NONE
      else
        sargs (c, t) |> Option.map (fn sorts =>
          (c, map restrict_sort (map2 (curry (inter_sort algebra)) sorts Ss)));
    val classes' = Graph.restrict P classes;
    val arities' = Symtab.map (fn t => map_filter (restrict_arity t)) arities;
  in (restrict_sort, rebuild_arities context (make_algebra (classes', arities'))) end;
333
334
335
336(** sorts of types **)
337
338(* errors -- performance tuning via delayed message composition *)
339
(*Structured errors; the message text is only composed on demand by
  function class_error.*)
datatype class_error =
  No_Classrel of class * class |  (*no path between two classes, see classrel_derivation*)
  No_Arity of string * class |  (*type constructor without arity for a class, see mg_domain*)
  No_Subsort of sort * sort;  (*underivable subsort relation, see meet_sort and of_sort_derivation*)
344
(*Compose the message for a class_error; the pretty-printing context is
  initialized once, when the generic context is supplied.*)
fun class_error context =
  let
    val ctxt = Syntax.init_pretty context;
    fun message err =
      (case err of
        No_Subsort (S1, S2) =>
          "Cannot derive subsort relation " ^
            Syntax.string_of_sort ctxt S1 ^ " < " ^ Syntax.string_of_sort ctxt S2
      | No_Arity (a, c) => "No type arity " ^ Syntax.string_of_arity ctxt (a, [], [c])
      | No_Classrel (c1, c2) =>
          "No class relation " ^ Syntax.string_of_classrel ctxt [c1, c2]);
  in message end;
353
(*Carries a structured error; function class_error composes its message.*)
exception CLASS_ERROR of class_error;
355
356
357(* instances *)
358
(*Type constructor a has an arity for every class of the given sort.*)
fun has_instance algebra a S =
  let val ars = Symtab.lookup_list (arities_of algebra) a
  in forall (fn c => AList.defined (op =) ars c) S end;
361
(*Argument sorts of type constructor a for sort S: the pointwise intersection
  of the argument sorts recorded for each class of S.  Raises CLASS_ERROR
  (No_Arity) if a class has no arity for a, and Fail for the empty sort.*)
fun mg_domain algebra a S =
  let
    val ars = Symtab.lookup_list (arities_of algebra) a;
    fun dom c =
      (case AList.lookup (op =) ars c of
        SOME Ss => Ss
      | NONE => raise CLASS_ERROR (No_Arity (a, c)));
    fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
  in
    if null S then raise Fail "Unknown domain of empty intersection"
    else fold dom_inter (tl S) (dom (hd S))
  end;
375
376
377(* meet_sort *)
378
(*Constraints on schematic type variables that make T :: S hold, recorded in
  a table of sorts.  The empty sort imposes nothing.  A type constructor
  passes the requirement on to its arguments via mg_domain.  A free type
  variable must already satisfy S, otherwise CLASS_ERROR (No_Subsort) is
  raised.  A schematic variable whose own sort does not satisfy S gets S
  intersected into its table entry, which starts out as its own sort.*)
fun meet_sort algebra =
  let
    fun le sorts = sort_le algebra sorts;
    fun meet T S =
      if null S then I
      else
        (case T of
          Type (a, Ts) => fold2 meet Ts (mg_domain algebra a S)
        | TFree (_, S') =>
            if le (S', S) then I
            else raise CLASS_ERROR (No_Subsort (S', S))
        | TVar (v, S') =>
            if le (S', S) then I
            else Vartab.map_default (v, S') (fn S'' => inter_sort algebra (S, S'')));
  in uncurry meet end;
391
(*Strengthen the sorts of schematic type variables according to the
  requirement T :: S.  The result is a mapping on types that replaces the
  sort of each TVar by the entry computed by meet_sort.  A variable without
  an entry keeps its present sort: meet_sort records nothing for a variable
  whose sort already satisfies the requirement, and the mapped type may
  contain variables that do not occur in T.  (An unconditional lookup would
  raise exception Option on such variables.)  May raise CLASS_ERROR via
  meet_sort.*)
fun meet_sort_typ algebra (T, S) =
  let val tab = meet_sort algebra (T, S) Vartab.empty;
  in Term.map_type_tvar (fn (v, S') => TVar (v, the_default S' (Vartab.lookup tab v))) end;
395
396
397(* of_sort *)
398
(*Check T :: S.  The empty sort always holds.  Type variables are compared
  via sort_le on their own sort.  For a type constructor the arguments are
  checked against mg_domain; a CLASS_ERROR raised there means failure.*)
fun of_sort algebra =
  let
    fun le sorts = sort_le algebra sorts;
    fun ofS (T, S) =
      null S orelse
        (case T of
          Type (a, Ts) =>
            (ListPair.all ofS (Ts, mg_domain algebra a S)
              handle CLASS_ERROR _ => false)
        | TFree (_, S') => le (S', S)
        | TVar (_, S') => le (S', S));
  in ofS end;
409
410
411(* animating derivations *)
412
(*Derive T :: S by means of three callbacks, yielding one result per class of
  S.  type_variable supplies the derivations available for a type without
  constructor, type_constructor combines the derivations of the arguments,
  and class_relation weakens a derivation for class c1 to one for c2; its
  boolean argument is true iff exactly one candidate c1 <= c2 was found.
  Raises CLASS_ERROR if no candidate exists or mg_domain fails.*)
fun of_sort_derivation algebra {class_relation, type_constructor, type_variable} =
  let
    val arities = arities_of algebra;

    (*turn derivations D1 (paired with their classes) into ones for sort S2*)
    fun weaken T D1 S2 =
      let
        val S1 = map snd D1;
        fun below c2 (_, c1) = class_le algebra (c1, c2);
        fun weaken_class c2 =
          (case filter (below c2) D1 of
            [] => raise CLASS_ERROR (No_Subsort (S1, S2))
          | [d1] => class_relation T true d1 c2
          | d1 :: _ => class_relation T false d1 c2);
      in if S1 = S2 then map fst D1 else map weaken_class S2 end;

    fun derive (_, []) = []
      | derive (Type (a, Us), S) =
          let
            val Ss = mg_domain algebra a S;
            val dom = map2 (fn U => fn S => derive (U, S) ~~ S) Us Ss;
            (*per class: weaken the argument derivations to the sorts of its arity*)
            fun construct c =
              let
                val Ss' = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
                val dom' = map (fn ((U, d), S') => weaken U d S' ~~ S') ((Us ~~ dom) ~~ Ss');
              in type_constructor (a, Us) dom' c end;
          in map construct S end
      | derive (T, S) = weaken T (type_variable T) S;
  in derive end;
442
(*Derive c2 from (x, c1) by applying class_relation along the first
  irreducible path from c1 to c2 in the class graph.  Raises CLASS_ERROR
  (No_Classrel) if there is no such path.*)
fun classrel_derivation algebra class_relation =
  let
    fun follow x (c1 :: (rest as c2 :: _)) = follow (class_relation (x, c1) c2) rest
      | follow x _ = x;
    fun derive (x, c1) c2 =
      (case Graph.irreducible_paths (classes_of algebra) (c1, c2) of
        cs :: _ => follow x cs
      | [] => raise CLASS_ERROR (No_Classrel (c1, c2)));
  in derive end;
453
454
455(* witness_sorts *)
456
(*Search for witnesses of the given sorts: pairs (T, S) of a type T and the
  sort S it witnesses.  Witnesses are taken from hyps or built from the type
  constructors in ground_types.  The state (solved, failed) is threaded
  through the search: witnesses found so far and sorts already known to have
  none.  Sorts without a witness are dropped from the result.*)
fun witness_sorts algebra ground_types hyps sorts =
  let
    fun le S1 S2 = sort_le algebra (S1, S2);
    (*reuse a witness of S1 for S2, provided that S1 <= S2*)
    fun get S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
    (*argument sorts of t for S, or NONE if mg_domain raises CLASS_ERROR*)
    fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;

    (*the empty sort is witnessed by propT*)
    fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
      | witn_sort path S (solved, failed) =
          (*give up at once if S is below a sort that has already failed*)
          if exists (le S) failed then (NONE, (solved, failed))
          else
            (case get_first (get S) solved of
              SOME w => (SOME w, (solved, failed))
            | NONE =>
                (case get_first (get S) hyps of
                  (*a witness taken from hyps is recorded as solved*)
                  SOME w => (SOME w, (w :: solved, failed))
                | NONE => witn_types path ground_types S (solved, failed)))

    and witn_sorts path x = fold_map (witn_sort path) x

    (*try the type constructors in turn; S is recorded as failed once all
      of them are exhausted*)
    and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
      | witn_types path (t :: ts) S solved_failed =
          (case mg_dom t S of
            SOME SS =>
              (*do not descend into stronger args (achieving termination)*)
              if exists (fn D => le D S orelse exists (le D) path) SS then
                witn_types path ts S solved_failed
              else
                let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
                  if forall is_some ws then
                    (*all arguments are witnessed: build and record the witness*)
                    let val w = (Type (t, map (#1 o the) ws), S)
                    in (SOME w, (w :: solved', failed')) end
                  else witn_types path ts S (solved', failed')
                end
          | NONE => witn_types path ts S solved_failed);

  in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
493
494end;
495