(*---------------------------------------------------------------------------*)
(* Automatic differentiation of real-valued lambda terms (DIFF_CONV),       *)
(* built on the derivative theorems of limTheory.                            *)
(*---------------------------------------------------------------------------*)
structure Diff :> Diff =
struct

open HolKernel Parse boolLib hol88Lib jrhUtils limTheory;

(* Parse quotations with the grammar exported by limTheory, so that
   constants such as diffl and pow resolve as they do in that theory. *)
structure Parse = struct
  open Parse
  val (Type, Term) = parse_from_grammars limTheory.lim_grammars
end
open Parse

val ERR = mk_HOL_ERR "Diff"

(* Fixed terms used when building and recognising diffl terms. *)
val xreal    = Term`x:real`;
val lreal    = Term`l:real`
val diffl_tm = Term`$diffl`;
val pow_tm   = Term`$pow`;

(*---------------------------------------------------------------------------*)
(* A conversion to differentiate expressions                                 *)
(*---------------------------------------------------------------------------*)

(* User-extensible table of derivative theorems for "basic" functions.
   DIFF_CONV indexes each theorem by the head operator of the function it
   differentiates. It uses the theorem as the outer function of a
   chain-rule step (via DIFF_CHAIN). *)
val basic_diffs = ref ([]:thm list);

(* Tautologies used by ICONJ inside DIFF_CONV to merge two derivative
   theorems into one whose conclusion is their conjunction. The four
   entries cover the cases where both, only the second, only the first, or
   neither theorem carries an antecedent. *)
val iths = map TAUT_CONV
             [(``(a ==> c) /\ (b ==> d) ==> ((a /\ b) ==> (c /\ d))``),
              (``c /\ (b ==> d) ==> (b ==> (c /\ d))``),
              (``(a ==> c) /\ d ==> (a ==> (c /\ d))``),
              (``c /\ d ==> c /\ d``)];

(* Curried forms of DIFF_INV and DIFF_DIV: an antecedent "a /\ b ==> c" is
   rewritten to "a ==> b ==> c". This lets MATCH_MP against the derivative
   premise leave the remaining side condition as an antecedent. The two
   theorems are bound separately, not through a list pattern, so the
   binding is exhaustive. *)
local
  val curry_imp =
    ONCE_REWRITE_RULE [TAUT_CONV (``a /\ b ==> c = a ==> b ==> c``)]
in
  val DIFF_INV' = curry_imp DIFF_INV
  val DIFF_DIV' = curry_imp (REWRITE_RULE [CONJ_ASSOC] DIFF_DIV)
end;

(* Derivative rules for the arithmetic combinators. For each one, DIFF_CONV
   differentiates every argument and then combines the results. *)
val comths = [DIFF_ADD, DIFF_MUL, DIFF_SUB, DIFF_DIV', DIFF_NEG, DIFF_INV'];

(* Uncurries a nested implication. DIFF_CONV uses it so that the side
   conditions of a combinator step end up as one conjoined antecedent. *)
val CC = TAUT_CONV (``a ==> b ==> c = a /\ b ==> c``);

(*---------------------------------------------------------------------------*)
(* DIFF_CONV "\t. f[t]" =                                                    *)
(*   |- !l1..ln x. conditions[x] ==> ((\t. f[t]) diffl f'[x])(x)             *)
(* where the li's are hypothetical derivatives for unknown sub-functions.    *)
(*                                                                           *)
(* The body of the abstraction is differentiated structurally:               *)
(*   - a body not mentioning the bound variable uses DIFF_CONST, and the     *)
(*     bound variable itself uses DIFF_X;                                    *)
(*   - a head operator found in comths: differentiate every argument;        *)
(*   - pow, or a head operator found in !basic_diffs: chain rule on the      *)
(*     first argument;                                                       *)
(*   - otherwise: assume "(f diffl l)(x)" as an antecedent; each such l is   *)
(*     universally quantified in the result.                                 *)
(* Fails with HOL_ERR if the argument is not a lambda abstraction.           *)
(*---------------------------------------------------------------------------*)
fun DIFF_CONV tm =
  let
    (* Point of differentiation, chosen so that it does not clash with any
       free variable of the input. *)
    val xv = variant (frees tm) xreal

    (* Recognise terms of the shape "diffl f l x". *)
    fun is_diffl tm =
      (funpow 3 rator tm = diffl_tm handle HOL_ERR _ => false)

    (* Pair a derivative theorem with the head operator of the function it
       differentiates. If that function is not already a lambda, it is
       eta-expanded inside the theorem so it reads "(\x. op ... x) diffl ...". *)
    fun make_assoc th =
      let val tm1 = (snd o strip_imp o snd o strip_forall o concl) th
          val tm2 = (rand o rator o rator) tm1
      in
        if is_abs tm2 then
          (fst (strip_comb (body tm2)), th)
        else
          let val th1 = ETA_CONV (mk_abs (xreal, mk_comb (tm2, xreal)))
              val th2 = AP_TERM diffl_tm (SYM th1)
              val th3 = ONCE_REWRITE_RULE [th2] th
          in
            (fst (strip_comb tm2), th3)
          end
      end

    (* Lookup tables, evaluated in the same order as before: combinators
       first, then the user's basic functions. *)
    val cths = map make_assoc comths
    val bths = map make_assoc (!basic_diffs)

    (* Merge two theorems "!x. [A ==>] B" and "!x. [C ==>] D" into a single
       "!x. [..] ==> B /\ D" by trying each tautology in iths. *)
    fun ICONJ th1 th2 =
      let val th1a = SPEC xv th1
          val th2a = SPEC xv th2
      in
        GEN xv (tryfind (C MATCH_MP (CONJ th1a th2a)) iths)
      end

    (* Differentiate the abstraction tm, returning
       |- !x. conditions ==> (tm diffl f')(x). Each case falls through to
       the next one on HOL_ERR (e.g. when assoc finds no entry). *)
    fun diff tm =
      let val (v, bod) = dest_abs tm in
        if not (free_in v bod) then
          GEN xv (SPECL [bod, xv] DIFF_CONST)
        else if bod = v then
          GEN xv (SPEC xv DIFF_X)
        else
          let val (opp, args) = strip_comb bod in
            (* Case 1: arithmetic combinator from comths. *)
            (let val th1 = snd (assoc opp cths)
                 val nargs = map (curry mk_abs v) args
                 val dargs = map diff nargs
                 val th2 = end_itlist ICONJ dargs
                 val th3 = UNDISCH (SPEC xv th2)
                           handle HOL_ERR _ => SPEC xv th2
                 val th4 = MATCH_MP th1 th3
                 (* Catch only HOL_ERR, so that Interrupt and other
                    non-HOL exceptions still propagate. *)
                 val th5 = DISCH_ALL th4 handle HOL_ERR _ => th4
                 val th6 = Rewrite.GEN_REWRITE_RULE I Rewrite.empty_rewrites
                             [CC] th5 handle HOL_ERR _ => th5
                 val th7 = CONV_RULE (REDEPTH_CONV BETA_CONV) th6
             in
               GEN xv th7
             end
             handle HOL_ERR _ =>
             (* Case 2: pow or a basic function -- chain rule on the first
                argument. *)
             (let val arg = Lib.trye hd args
                  val narg = mk_abs (v, arg)
                  val th1 = if opp = pow_tm
                            then SPEC (last args) DIFF_POW
                            else snd (assoc opp bths)
                  val th2 = GEN xv (SPEC (mk_comb (narg, xv)) th1)
                  val th3 = diff narg
                  val th4 = SPEC xv (ICONJ th2 th3)
                  val th5 = MATCH_MP DIFF_CHAIN (UNDISCH th4
                            handle HOL_ERR _ => th4)
                  val th6 = CONV_RULE (REDEPTH_CONV BETA_CONV) (DISCH_ALL th5)
              in
                GEN xv th6
              end
              handle HOL_ERR _ =>
              (* Case 3: unknown function -- assume a derivative l for it,
                 i.e. |- !x. (tm diffl l)(x) ==> (tm diffl l)(x). *)
              let val tm1 = mk_comb (diffl_tm, tm)
                  val var = variant (frees tm) lreal
                  val tm2 = mk_comb (tm1, var)
                  val tm3 = mk_comb (tm2, xv)
              in
                GEN xv (DISCH tm3 (ASSUME tm3))
              end))
          end
      end

    val tha = diff tm
    (* Collect the assumed derivatives "(f diffl l)(x)" among the
       antecedents and universally quantify their l's. *)
    val cjs = strip_conj (fst (dest_imp
                (snd (strip_forall (concl tha))))) handle HOL_ERR _ => []
    val cj2 = filter is_diffl cjs
    val fvs = map (rand o rator) cj2
    val thb = itlist GEN fvs tha
  in
    (* Turn any subterm alpha-equivalent to the input back into the input
       itself, so the result uses the caller's bound-variable names. *)
    CONV_RULE (ONCE_DEPTH_CONV (C ALPHA tm)) thb
  end;

end;
125