structure Diff :> Diff =
struct

open HolKernel Parse boolLib hol88Lib jrhUtils limTheory;

(* Local parser built from limTheory's grammars, with equality made "loose"
   via ParseExtras.grammar_loose_equality -- presumably so that "=" binds
   more loosely than "==>" in the tautologies below (TODO confirm). *)
structure Parse = struct
  open Parse
  val (Type, Term) =
      parse_from_grammars
        (apsnd ParseExtras.grammar_loose_equality limTheory.lim_grammars)
end
open Parse

val ERR = mk_HOL_ERR "Diff"

(* Fixed terms used when building and recognising derivative theorems. *)
val xreal    = Term`x:real`;
val lreal    = Term`l:real`;
val diffl_tm = Term`$diffl`;
val pow_tm   = Term`$pow`;

(*---------------------------------------------------------------------------*)
(* A conversion to differentiate expressions                                 *)
(*---------------------------------------------------------------------------*)

(* Extra derivative theorems that DIFF_CONV consults when an operator has no
   entry in the combining-rule table below. Each theorem is indexed by the
   head constant of the function it differentiates, and is used together with
   DIFF_CHAIN. It is SPEC'd with the inner function applied to x, so it is
   presumably quantified first over the argument (verify against users). *)
val basic_diffs = ref ([]:thm list);


(*---------------------------------------------------------------------------*)
(* DIFF_CONV "fn t => f[t]" =                                                *)
(*   |- !l1..ln x. conditions[x] ==> ((fn t => f[t]) diffl f'[x])(x)         *)
(* Where the li's are hypothetical derivatives for unknown sub-functions     *)
(*---------------------------------------------------------------------------*)

(* Tautologies used to merge two (possibly conditional) derivative theorems
   into a single implication; tried in order by ICONJ via MATCH_MP. *)
val iths = map TAUT_CONV
             [(``(a ==> c) /\ (b ==> d) ==> ((a /\ b) ==> (c /\ d))``),
              (``c /\ (b ==> d) ==> (b ==> (c /\ d))``),
              (``(a ==> c) /\ d ==> (a ==> (c /\ d))``),
              (``c /\ d ==> c /\ d``)];

(* DIFF_INV and DIFF_DIV with their antecedent conjunctions curried
   ("a /\ b ==> c" becomes "a ==> b ==> c") so that MATCH_MP can discharge
   the derivative premises first. Bound separately (rather than by a list
   pattern) so the binding is exhaustive. *)
local
  val curry_imp = ONCE_REWRITE_RULE [TAUT_CONV (``a /\ b ==> c = a ==> b ==> c``)]
in
  val DIFF_INV' = curry_imp DIFF_INV
  val DIFF_DIV' = curry_imp (REWRITE_RULE [CONJ_ASSOC] DIFF_DIV)
end;

(* Rules for operators whose derivative combines derivatives of their
   arguments (sum, product, difference, quotient, negation, inverse). *)
val comths = [DIFF_ADD, DIFF_MUL, DIFF_SUB, DIFF_DIV', DIFF_NEG, DIFF_INV'];

(* Uncurries "a ==> b ==> c" back into "a /\ b ==> c". *)
val CC = TAUT_CONV (``a ==> b ==> c = a /\ b ==> c``);

(* DIFF_CONV tm: tm must be a lambda abstraction of type real -> real.
   Returns a theorem
     |- !l1..ln x. conds[x] ==> (tm diffl f'[x])(x)
   where each li is a hypothesised derivative for a sub-function that none of
   the known rules could handle. The result is alpha-converted so that tm
   appears exactly as supplied.
   Raises HOL_ERR if tm is not an abstraction, or if a sub-term cannot even be
   stated as a diffl hypothesis. *)
fun DIFF_CONV tm =
  let
    (* Variable for the differentiation point, fresh with respect to tm. *)
    val xv = variant (frees tm) xreal

    (* Is t of the form ($diffl f l x)? *)
    fun is_diffl t =
      (funpow 3 rator t ~~ diffl_tm handle HOL_ERR _ => false)

    (* Index a derivative theorem by the head constant of the function it
       differentiates. If that function is not literally an abstraction,
       rewrite it to its eta-expanded form (\x. f x) so that it matches the
       abstractions produced by diff below. *)
    fun make_assoc th =
      let val goal = (snd o strip_imp o snd o strip_forall o concl) th
          val f = (rand o rator o rator) goal
      in
        if is_abs f then
          (fst (strip_comb (body f)), th)
        else
          let val eta = ETA_CONV (mk_abs (xreal, mk_comb (f, xreal)))
              val eta_diffl = AP_TERM diffl_tm (SYM eta)
          in
            (fst (strip_comb f), ONCE_REWRITE_RULE [eta_diffl] th)
          end
      end

    val cths = map make_assoc comths
    val bths = map make_assoc (!basic_diffs)

    (* Combine two theorems of the form !x. P x into a single !x. Q x by
       conjoining their instances at xv and applying the first matching
       tautology from iths. *)
    fun ICONJ th1 th2 =
      GEN xv (tryfind (C MATCH_MP (CONJ (SPEC xv th1) (SPEC xv th2))) iths)

    (* Last resort: assume the derivative of f at xv is a fresh variable l,
       giving |- !x. (f diffl l)(x) ==> (f diffl l)(x). *)
    fun assume_diff f =
      let val partial = mk_comb (diffl_tm, f)
          val l = variant (frees f) lreal
          val hyp = mk_comb (mk_comb (partial, l), xv)
      in
        GEN xv (DISCH hyp (ASSUME hyp))
      end

    (* diff f, for f = \v. bod: prove a derivative theorem for f at xv.
       Constants and the identity are handled directly. Otherwise try, in
       order: a combining rule (diff_compound), the chain rule through a
       basic function or pow (diff_chain), and finally assume_diff. *)
    fun diff f =
      let val (v, bod) = dest_abs f in
        if not (free_in v bod) then
          GEN xv (SPECL [bod, xv] DIFF_CONST)
        else if bod ~~ v then
          GEN xv (SPEC xv DIFF_X)
        else
          let val (opp, args) = strip_comb bod in
            diff_compound (v, opp, args)
            handle HOL_ERR _ =>
              (diff_chain (v, opp, args)
               handle HOL_ERR _ => assume_diff f)
          end
      end

    (* opp is in comths: differentiate every argument, merge the results
       with ICONJ and feed them to the combining rule. *)
    and diff_compound (v, opp, args) =
      let val rule = tassoc opp cths
          val nargs = map (curry mk_abs v) args
          val merged = end_itlist ICONJ (map diff nargs)
          val at_x = UNDISCH (SPEC xv merged)
                     handle HOL_ERR _ => SPEC xv merged
          val applied = MATCH_MP rule at_x
          (* DISCH_ALL is not expected to fail; the guard is kept only for
             HOL errors, so Interrupt and friends still propagate. *)
          val disch = DISCH_ALL applied handle HOL_ERR _ => applied
          val uncurried =
              Rewrite.GEN_REWRITE_RULE I Rewrite.empty_rewrites [CC] disch
              handle HOL_ERR _ => disch
      in
        GEN xv (CONV_RULE (REDEPTH_CONV BETA_CONV) uncurried)
      end

    (* opp is pow or has an entry in basic_diffs: apply DIFF_CHAIN to the
       outer rule and the derivative of the first argument. *)
    and diff_chain (v, opp, args) =
      let val narg = mk_abs (v, Lib.trye hd args)
          val outer = if opp ~~ pow_tm
                      then SPEC (last args) DIFF_POW
                      else tassoc opp bths
          val outer_x = GEN xv (SPEC (mk_comb (narg, xv)) outer)
          val both = SPEC xv (ICONJ outer_x (diff narg))
          val chained = MATCH_MP DIFF_CHAIN (UNDISCH both
                                             handle HOL_ERR _ => both)
      in
        GEN xv (CONV_RULE (REDEPTH_CONV BETA_CONV) (DISCH_ALL chained))
      end

    val tha = diff tm
    (* Generalise over the hypothesised derivatives l introduced by
       assume_diff, i.e. the l in every ($diffl f l x) antecedent conjunct. *)
    val cjs = strip_conj (fst (dest_imp (snd (strip_forall (concl tha)))))
              handle HOL_ERR _ => []
    val fvs = map (rand o rator) (filter is_diffl cjs)
    val thb = itlist GEN fvs tha
  in
    CONV_RULE (ONCE_DEPTH_CONV (C ALPHA tm)) thb
  end;

end;
127