structure AC_Sort :> AC_Sort =
struct

(* Proof-producing sorting of terms built from an associative-commutative
   binary operator: every function below that returns a theorem does so
   by composing HOL conversions. *)

open HolKernel Rewrite Conv

(* ----------------------------------------------------------------------
    gstrip dest t

    Repeatedly applies the destructor [dest] (e.g. dest_disj) to [t] and
    returns the leaves -- the subterms on which [dest] fails -- in
    left-to-right order, whatever the association of the tree.  A term
    that [dest] cannot split at all yields the singleton list [t].
   ---------------------------------------------------------------------- *)
fun gstrip dest t = let
  (* Children are pushed right-child-first, so leaves are reached from
     right to left; consing each onto [acc] therefore leaves [acc] in
     left-to-right order. *)
  fun walk acc [] = acc
    | walk acc (u :: todo) =
        case Lib.total dest u of
          NONE => walk (u :: acc) todo
        | SOME (l, r) => walk acc (r :: l :: todo)
in
  walk [] [t]
end

(* lmk mk ts: folds the constructor [mk] over the list [ts]; raises Empty
   when [ts] is empty (hd/tl).
   NOTE(review): List.foldl passes (element, accumulator) to [mk], so
   lmk mk [a, b, c] = mk (c, mk (b, a)) -- the leaves come out reversed
   relative to gstrip's order.  No caller is visible in this file;
   confirm this shape is intended before relying on it. *)
fun lmk mk ts = List.foldl mk (hd ts) (tl ts)

(* ----------------------------------------------------------------------
    balance (dest, mk) assoc t

    Conversion proving  |- t = t'  where t' has the same leaves as t (as
    found by gstrip), in the same order, rebuilt with [mk] as a balanced
    binary tree: each node splits its leaf list in half, with the smaller
    half (length div 2) on the left.  Terms with fewer than four leaves
    are handed to ALL_CONV unchanged.

    The equation is obtained by rewriting both t and t' with [assoc] and
    chaining the results, so it relies on both normalising to the same
    term (TRANS fails otherwise).  Presumably the balancing exists to keep
    the merge-sort recursion in [sort] shallow.
   ---------------------------------------------------------------------- *)
fun balance (dest, mk) assoc t = let
  val leaves = gstrip dest t
  (* balanced tree over a non-empty list of leaves *)
  fun build ts = let
    val half = length ts div 2
    val (front, back) = split_after half ts
  in
    if half = 0 then hd ts
    else mk (build front, build back)
  end
in
  if length leaves < 4 then ALL_CONV t
  else let
      (* built once and used for both sides; QCONV turns a possible
         UNCHANGED exception into a reflexivity theorem *)
      val assoc_norm = QCONV (PURE_REWRITE_CONV [assoc])
      val btree_norm = assoc_norm (build leaves)
      val t_norm = assoc_norm t
    in
      TRANS t_norm (SYM btree_norm)
    end
end

(* ----------------------------------------------------------------------
    sort {cmp, combine, dest, mk, assoc, comm, preprocess}

    Builds a conversion that sorts the leaves of a tree of one binary AC
    operator, proving |- t = t'.

      dest, mk    -- destructor / constructor for the operator
      assoc, comm -- its associativity and commutativity theorems.  The
                     rewrite chains in [merge] are written for [assoc]
                     oriented as  x op (y op z) = (x op y) op z  (as with
                     DISJ_ASSOC and INT_MUL_ASSOC in the examples below)
      cmp         -- ordering on leaves
      combine     -- conversion applied to  a op b  whenever the two
                     operands compare EQUAL; operands may arrive in
                     either order
      preprocess  -- conversion tried on each leaf; if it changes the
                     leaf, the result is sorted again (it may itself be a
                     tree of the operator)

    The input is first rebalanced with [balance] and then merge-sorted
    bottom-up; each sorted piece is a right-nested list  h op rest.
   ---------------------------------------------------------------------- *)
fun sort {cmp, combine, dest, mk, assoc, comm, preprocess} = let
  val wassoc = REWR_CONV assoc              (* a op (b op c) --> (a op b) op c *)
  val wassoc' = REWR_CONV (GSYM assoc)      (* (a op b) op c --> a op (b op c) *)
  val wcomm = REWR_CONV comm                (* a op b --> b op a *)

  (* view a term as a list: its head, and its tail when it is a node *)
  fun toList t =
    case Lib.total dest t of
      NONE => (t, NONE)
    | SOME (t1, t2) => (t1, SOME t2)

  (* merge t: t is  l1 op l2  with l1 and l2 already sorted right-nested
     lists; proves t equal to their merge *)
  fun merge t = let
    val (t1, t2) = dest t
    val (h1, rest1) = toList t1
    val (h2, rest2) = toList t2
    val p = (isSome rest1, isSome rest2)
    (* heads compare EQUAL: bring h1 and h2 together as the left operand,
       apply [combine] there, and (when both have tails) merge the tails *)
    fun lift_equal (true, true) =
          (* (h1 op r1) op (h2 op r2) --> (h1 op h2) op (r1 op r2) *)
          RAND_CONV wcomm THENC wassoc' THENC
          RAND_CONV (wassoc THENC wcomm) THENC
          wassoc THENC LAND_CONV combine THENC RAND_CONV merge
      | lift_equal (false, true) =
          (* h1 op (h2 op r2) --> (h1 op h2) op r2 *)
          wassoc THENC LAND_CONV combine
      | lift_equal (true, false) =
          (* (h1 op r1) op h2 --> (h2 op h1) op r1 *)
          wcomm THENC wassoc THENC LAND_CONV combine
      | lift_equal (false, false) = combine
    (* h1 is smallest: keep it in front and merge what follows it *)
    fun lift_left (true, _) = wassoc' THENC RAND_CONV merge
      | lift_left (false, _) = ALL_CONV
    (* h2 is smallest: move it to the front and merge what follows it *)
    fun lift_right (_, true) = wcomm THENC wassoc' THENC RAND_CONV merge
      | lift_right (_, false) = wcomm
  in
    case cmp (h1, h2) of
      EQUAL => lift_equal p
    | LESS => lift_left p
    | GREATER => lift_right p
  end t

  (* sort a (sub)tree: a leaf is preprocessed and, if that changed it,
     sorted again; a node has both children sorted and then merged *)
  fun recurse t =
    case Lib.total dest t of
      NONE => TRY_CONV (QCHANGED_CONV preprocess THENC recurse) t
    | SOME _ => (BINOP_CONV recurse THENC merge) t
in
  balance (dest, mk) assoc THENC recurse
end

(* Interactive examples, kept commented out.

-- booleans over \/, with idempotency and cancellation

val boolcombine = let
  val porp = last (CONJUNCTS (SPEC_ALL OR_CLAUSES))
  val pornotp = EXCLUDED_MIDDLE |> SPEC_ALL |> EQT_INTRO
  val notporp = pornotp |> CONV_RULE (LAND_CONV (REWR_CONV DISJ_COMM))
in
  TRY_CONV (REWR_CONV porp ORELSEC REWR_CONV pornotp ORELSEC
            REWR_CONV notporp)
end

fun boolcompare (t1, t2) = let
  val (t1',_) = strip_neg t1
  val (t2',_) = strip_neg t2
in
  Term.compare(t1',t2')
end

val boolpreprocess = REPEATC (REWR_CONV (hd (CONJUNCTS NOT_CLAUSES)))

val booldisj_sort = sort {mk = mk_disj, dest = dest_disj,
                          cmp = boolcompare, comm = DISJ_COMM,
                          assoc = DISJ_ASSOC,
                          combine = boolcombine,
                          preprocess = boolpreprocess}

val b1 = time booldisj_sort ``p \/ r \/ q``
val b2 = time booldisj_sort ``~p \/ r \/ q \/ a \/ p``
val b3 = time booldisj_sort ``~~~p \/ r \/ q \/ a \/ p \/ b``
val b4 = time booldisj_sort ``p \/ r \/ q \/ p``
val b5 = time booldisj_sort ``~a \/ p \/ r \/ q \/ ~~~~p``

-- integers with coefficient gathering

fun intcombine t = let
  open intSyntax
  val (t1, t2) = intSyntax.dest_mult t
in
  if is_int_literal t1 then intLib.REDUCE_CONV t
  else ALL_CONV t
end

fun intcompare(t1, t2) = let
  open intSyntax
in
  case (is_int_literal t1, is_int_literal t2) of
    (true, true) => EQUAL
  | (true, false) => LESS
  | (false, true) => GREATER
  | (false, false) => Term.compare(t1, t2)
end

fun preprocess t =
  if intSyntax.is_int_literal t then NO_CONV t
  else REWR_CONV integerTheory.INT_NEG_MINUS1 t

val intmul =
  sort {mk = intSyntax.mk_mult, dest = intSyntax.dest_mult,
        cmp = intcompare, comm = integerTheory.INT_MUL_COMM,
        assoc = integerTheory.INT_MUL_ASSOC, combine = intcombine,
        preprocess = preprocess}

val test1 = time intmul ``2 * x * -7 * -y``
val test2 = time intmul ``2 * a * -x * a * 6 * b``
val test3 = time intmul ``x:int * y`` (* UNCHANGED *)
val test4 = time intmul ``y:int * x``
val test5 = time intmul ``2 * x * -1``
val test6 = time intmul ``1 * x:int`` (* UNCHANGED *)
val test7 = time intmul ``x * y * 0 : int``

*)
end (* struct *)