1structure AC_Sort :> AC_Sort =
2struct
3
4open HolKernel Rewrite Conv
5
(* gstrip dest t

   Flattens t with respect to the destructor dest: every subterm on which
   dest succeeds is split into its two components, and every subterm on which
   it fails (Lib.total returns NONE) is treated as a leaf.  The leaves are
   returned in left-to-right order; the result always has at least one
   element. *)
fun gstrip dest t = let
  (* Depth-first walk driven by an explicit stack.  Right components are
     pushed on top so they are visited first; because each leaf is consed
     onto the front of the accumulator, the final list comes out in
     left-to-right order. *)
  fun walk leaves [] = leaves
    | walk leaves (tm :: stack) =
        (case Lib.total dest tm of
           SOME (l, r) => walk leaves (r :: l :: stack)
         | NONE => walk (tm :: leaves) stack)
in
  walk [] [t]
end
19
(* lmk mk ts

   Folds mk over ts, seeding the fold with the first element.  List.foldl
   hands mk its arguments as (element, accumulator), so
     lmk mk [t0, t1, ..., tn] = mk (tn, ... mk (t1, t0) ...)
   e.g. lmk mk [a, b, c] = mk (c, mk (b, a)).
   Raises Empty when ts is empty (via hd).
   NOTE(review): with this argument order the first list element ends up as
   the innermost right operand; confirm that is what callers expect. *)
fun lmk mk ts = let
  val seed = hd ts
  val others = tl ts
in
  List.foldl mk seed others
end
21
(* balance (dest, mk) assoc t

   Conversion proving t = t', where t' has the same leaves as t (as found by
   gstrip dest, in the same order) but is rebuilt with mk as a roughly
   balanced binary tree.  The equation is obtained by rewriting both sides to
   their normal form under the associativity theorem assoc and chaining the
   results with TRANS, which fails if the normal forms do not coincide.
   Terms with fewer than four leaves are left alone via ALL_CONV.
   Presumably the point is to keep the merge-sort recursion in `sort`
   shallow -- TODO confirm intent. *)
fun balance (dest, mk) assoc t = let
  val leaves = gstrip dest t
  (* Rebuild a non-empty leaf list as a tree: a singleton is returned as-is,
     otherwise the first (length div 2) leaves form the left subtree and the
     rest form the right subtree. *)
  fun build xs =
      case length xs div 2 of
        0 => hd xs
      | half => let
          val (front, back) = split_after half xs
        in
          mk (build front, build back)
        end
in
  if length leaves < 4 then ALL_CONV t
  else let
      val assoc_nf = QCONV (PURE_REWRITE_CONV [assoc])
      val balanced_th = assoc_nf (build leaves)
      val input_th = assoc_nf t
    in
      TRANS input_th (SYM balanced_th)
    end
end
40
41
(* sort {cmp, combine, dest, mk, assoc, comm, preprocess}

   Builds a conversion that normalises a term built from a single binary
   operator (taken apart by dest, built by mk) into a right-nested list
   whose elements are in ascending order under cmp, using the operator's
   associativity (assoc) and commutativity (comm) theorems.  The merge steps
   below assume assoc is oriented as  x op (y op z) = (x op y) op z.

   - Each leaf is first offered to preprocess; if that changes it, the
     result is sorted in turn (it may now be an operator term).
   - When two list heads compare EQUAL, they are brought together into one
     operand and combine is applied to it.  The combined element is not
     compared again against the rest of the list.

   The input is first rebalanced with `balance`, then sorted bottom-up by
   merging already-sorted operand lists. *)
fun sort {cmp, combine, dest, mk, assoc, comm, preprocess} = let
  (* Single-step rewrites, writing the operator as `op`:
       assoc_in  :  a op (b op c)  -->  (a op b) op c
       assoc_out :  (a op b) op c  -->  a op (b op c)
       comm_rw   :  a op b         -->  b op a          *)
  val assoc_in = REWR_CONV assoc
  val assoc_out = REWR_CONV (GSYM assoc)
  val comm_rw = REWR_CONV comm

  (* Read an operator term as a right-nested list: its first element and,
     when dest succeeds, the remaining list. *)
  fun uncons tm =
      case Lib.total dest tm of
        SOME (hd_tm, tl_tm) => (hd_tm, SOME tl_tm)
      | NONE => (tm, NONE)

  (* merge (l op r), where l and r are sorted right-nested lists with heads
     x and y and optional tails xs and ys. *)
  fun merge tm = let
    val (l, r) = dest tm
    val (x, xs) = uncons l
    val (y, ys) = uncons r
    val step =
        case (cmp (x, y), isSome xs, isSome ys) of
          (* x goes first: re-associate to x op (xs op r) and merge the rest *)
          (LESS, false, _) => ALL_CONV
        | (LESS, true, _) => assoc_out THENC RAND_CONV merge
          (* y goes first: swap the operands, then as above *)
        | (GREATER, _, false) => comm_rw
        | (GREATER, _, true) => comm_rw THENC assoc_out THENC RAND_CONV merge
          (* equal heads: gather them into one operand (as x op y, or as
             y op x when only l has a tail) and hand it to combine *)
        | (EQUAL, false, false) => combine
        | (EQUAL, false, true) => assoc_in THENC LAND_CONV combine
        | (EQUAL, true, false) => comm_rw THENC assoc_in THENC LAND_CONV combine
        | (EQUAL, true, true) =>
            (* (x op xs) op (y op ys)  -->  (x op y) op (xs op ys),
               then combine the heads and merge the tails *)
            RAND_CONV comm_rw THENC assoc_out THENC
            RAND_CONV (assoc_in THENC comm_rw) THENC
            assoc_in THENC LAND_CONV combine THENC RAND_CONV merge
  in
    step tm
  end

  (* Sort bottom-up.  For an operator term, sort both operands and merge
     them.  For a leaf, try preprocess; only if it changes the leaf is the
     result sorted again.  If that fails, TRY_CONV leaves the leaf as is. *)
  fun sort_tm tm =
      if isSome (Lib.total dest tm) then (BINOP_CONV sort_tm THENC merge) tm
      else TRY_CONV (QCHANGED_CONV preprocess THENC sort_tm) tm
in
  balance (dest, mk) assoc THENC sort_tm
end
79
80(*
81
82-- booleans over \/, with idempotency and cancellation
83
84val boolcombine = let
85  val porp = last (CONJUNCTS (SPEC_ALL OR_CLAUSES))
86  val pornotp = EXCLUDED_MIDDLE |> SPEC_ALL |> EQT_INTRO
87  val notporp = pornotp |> CONV_RULE (LAND_CONV (REWR_CONV DISJ_COMM))
88in
89  TRY_CONV (REWR_CONV porp ORELSEC REWR_CONV pornotp ORELSEC
90            REWR_CONV notporp)
91end
92
93fun boolcompare (t1, t2) = let
94  val (t1',_) = strip_neg t1
95  val (t2',_) = strip_neg t2
96in
97  Term.compare(t1',t2')
98end
99
100val boolpreprocess = REPEATC (REWR_CONV (hd (CONJUNCTS NOT_CLAUSES)))
101
102val booldisj_sort = sort {mk = mk_disj, dest = dest_disj,
103                          cmp = boolcompare, comm = DISJ_COMM,
104                          assoc = DISJ_ASSOC,
105                          combine = boolcombine,
106                          preprocess = boolpreprocess}
107
108val b1 = time booldisj_sort ``p \/ r \/ q``
109val b2 = time booldisj_sort ``~p \/ r \/ q \/ a \/ p``
110val b3 = time booldisj_sort ``~~~p \/ r \/ q \/ a \/ p \/ b``
111val b4 = time booldisj_sort ``p \/ r \/ q \/ p``
112val b5 = time booldisj_sort ``~a \/ p \/ r \/ q \/ ~~~~p``
113
114-- integers with coefficient gathering
115
116fun intcombine t = let
117  open intSyntax
118  val (t1, t2) = intSyntax.dest_mult t
119in
120  if is_int_literal t1 then intLib.REDUCE_CONV t
121  else ALL_CONV t
122end
123
124fun intcompare(t1, t2) = let
125  open intSyntax
126in
127  case (is_int_literal t1, is_int_literal t2) of
128    (true, true) => EQUAL
129  | (true, false) => LESS
130  | (false, true) => GREATER
131  | (false, false) => Term.compare(t1, t2)
132end
133
134fun preprocess t =
135    if intSyntax.is_int_literal t then NO_CONV t
136    else REWR_CONV integerTheory.INT_NEG_MINUS1 t
137
138val intmul =
139    sort {mk = intSyntax.mk_mult, dest = intSyntax.dest_mult,
140          cmp = intcompare, comm = integerTheory.INT_MUL_COMM,
141          assoc = integerTheory.INT_MUL_ASSOC, combine = intcombine,
142          preprocess = preprocess}
143
144val test1 = time intmul ``2 * x * -7 * -y``
145val test2 = time intmul ``2 * a * -x * a * 6 * b``
146val test3 = time intmul ``x:int * y``   (* UNCHANGED *)
147val test4 = time intmul ``y:int * x``
148val test5 = time intmul ``2 * x * -1``
149val test6 = time intmul ``1 * x:int``  (* UNCHANGED *)
150val test7 = time intmul ``x * y * 0 : int``
151
152*)
153end (* struct *)
154