1(*
2    Copyright (c) 2012,13,15 David C.J. Matthews
3
4    This library is free software; you can redistribute it and/or
5    modify it under the terms of the GNU Lesser General Public
6    License version 2.1 as published by the Free Software Foundation.
7    
8    This library is distributed in the hope that it will be useful,
9    but WITHOUT ANY WARRANTY; without even the implied warranty of
10    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
11    Lesser General Public License for more details.
12    
13    You should have received a copy of the GNU Lesser General Public
14    License along with this library; if not, write to the Free Software
15    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
16*)
17
18functor CODETREE_OPTIMISER(
19    structure BASECODETREE: BaseCodeTreeSig
20
21    structure CODETREE_FUNCTIONS: CodetreeFunctionsSig
22
23    structure REMOVE_REDUNDANT:
24    sig
25        type codetree
26        type loadForm
27        type codeUse
28        val cleanProc : (codetree * codeUse list * (int -> loadForm) * int) -> codetree
29        structure Sharing: sig type codetree = codetree and loadForm = loadForm and codeUse = codeUse end
30    end
31
32    structure SIMPLIFIER:
33    sig
34        type codetree and codeBinding and envSpecial
35
36        val simplifier:
37            codetree * int -> (codetree * codeBinding list * envSpecial) * int * bool
38        val specialToGeneral:
39            codetree * codeBinding list * envSpecial -> codetree
40
41        structure Sharing:
42        sig
43            type codetree = codetree
44            and codeBinding = codeBinding
45            and envSpecial = envSpecial
46        end
47    end
48
49    structure DEBUG: DEBUGSIG
50    structure PRETTY : PRETTYSIG
51
52    structure BACKEND:
53    sig
54        type codetree
55        type machineWord = Address.machineWord
56        val codeGenerate:
57            codetree * int * Universal.universal list -> (unit -> machineWord) * Universal.universal list
58        structure Sharing : sig type codetree = codetree end
59    end
60
61    sharing 
62        BASECODETREE.Sharing 
63    =   CODETREE_FUNCTIONS.Sharing
64    =   REMOVE_REDUNDANT.Sharing
65    =   SIMPLIFIER.Sharing
66    =   PRETTY.Sharing
67    =   BACKEND.Sharing
68
69) :
70    sig
71        type codetree and envSpecial and codeBinding
72        val codetreeOptimiser: codetree  * Universal.universal list * int ->
73            { numLocals: int, general: codetree, bindings: codeBinding list, special: envSpecial }
74        structure Sharing: sig type codetree = codetree and envSpecial = envSpecial and codeBinding = codeBinding end
75    end
76=
77struct
78    open BASECODETREE
79    open Address
80    open CODETREE_FUNCTIONS
81    open StretchArray
82    
83    infix 9 sub
84    
85    exception InternalError = Misc.InternalError
86
87 
88    datatype inlineTest =
89        TooBig
90    |   NonRecursive
91    |   TailRecursive of bool vector
92    |   NonTailRecursive of bool vector
93
94    fun evaluateInlining(function, numArgs, maxInlineSize) =
95    let
96        (* This checks for the possibility of inlining a function.  It sees if it is
97           small enough according to some rough estimate of the cost and it also looks
98           for recursive uses of the function.
99           Typically if the function is small enough to inline there will be only
100           one recursive use but we consider the possibility of more than one.  If
101           the only uses are tail recursive we can replace the recursive calls by
102           a Loop with a BeginLoop outside it.  If there are non-tail recursive
103           calls we may be able to lift out arguments that are unchanged.  For
104           example for fun map f [] = [] | map f (a::b) = f a :: map f b 
105           it may be worth lifting out f and generating specific mapping
106           functions for each application. *)
107        val hasRecursiveCall = ref false (* Set to true if rec call *)
108        val allTail = ref true (* Set to false if non recursive *)
109        (* An element of this is set to false if the actual value if anything
110           other than the original argument.  At the end we are then
111           left with the arguments that are unchanged. *)
112        val argMod = Array.array(numArgs, true)
113
114        infix 6 --
115        (* Subtract y from x but return 0 rather than a negative number. *)
116        fun x -- y = if x >= y then x-y else 0
117
118        (* Check for the code size and also recursive references.  N,B. We assume in hasLoop
119           that tail recursion applies only with Cond, Newenv and Handler. *)
120        fun checkUse _ (_, 0, _) = 0 (* The function is too big to inline. *)
121 
122        |   checkUse isMain (Newenv(decs, exp), cl, isTail) =
123            let
124                fun checkBind (Declar{value, ...}, cl) = checkUse isMain(value, cl, false)
125                |   checkBind (RecDecs decs, cl) = List.foldl(fn ({lambda, ...}, n) => checkUse isMain (Lambda lambda, n, false)) cl decs
126                |   checkBind (NullBinding c, cl) = checkUse isMain (c, cl, false)
127                |   checkBind (Container{setter, ...}, cl) = checkUse isMain(setter, cl -- 1, false)
128            in
129                checkUse isMain (exp, List.foldl checkBind cl decs, isTail)
130            end
131
132        |   checkUse _      (Constnt(w, _), cl, _) = if isShort w then cl else cl -- 1
133
134            (* A recursive reference in any context other than a call prevents any inlining. *)
135        |   checkUse true   (Extract LoadRecursive, _, _) = 0
136        |   checkUse _      (Extract _, cl, _) = cl -- 1
137
138        |   checkUse isMain (Indirect{base, ...}, cl, _) = checkUse isMain (base, cl -- 1, false)
139
140        |   checkUse _      (Lambda {body, argTypes, closure, ...}, cl, _) =
141                (* For the moment, any recursive use in an inner function prevents inlining. *)
142                if List.exists (fn LoadRecursive => true | _ => false) closure
143                then 0
144                else checkUse false (body, cl -- (List.length argTypes + List.length closure), false)
145
146        |   checkUse true (Eval{function = Extract LoadRecursive, argList, ...}, cl, isTail) =
147            let
148                (* If the actual argument is anything but the original argument
149                   then the corresponding entry in the array is set to false. *)
150                fun testArg((exp, _), n) =
151                (
152                    if (case exp of Extract(LoadArgument a) => n = a | _ => false)
153                    then ()
154                    else Array.update(argMod, n, false);
155                    n+1
156                )
157            in
158                List.foldl testArg 0 argList;
159                hasRecursiveCall := true;
160                if isTail then () else allTail := false;
161                List.foldl(fn ((e, _), n) => checkUse true (e, n, false)) (cl--3) argList
162            end
163
164        |   checkUse isMain (Eval{function, argList, ...}, cl, _) =
165                checkUse isMain (function, List.foldl(fn ((e, _), n) => checkUse isMain (e, n, false)) (cl--2) argList, false)
166
167        |   checkUse _ (GetThreadId, cl, _) = cl -- 1
168        |   checkUse isMain (Unary{arg1, ...}, cl, _) = checkUse isMain (arg1, cl -- 1, false)
169        |   checkUse isMain (Binary{arg1, arg2, ...}, cl, _) = checkUseList isMain ([arg1, arg2], cl -- 1)
170        |   checkUse isMain (Arbitrary{arg1, arg2, ...}, cl, _) = checkUseList isMain ([arg1, arg2], cl -- 4)
171        |   checkUse isMain (AllocateWordMemory {numWords, flags, initial}, cl, _) =
172                checkUseList isMain ([numWords, flags, initial], cl -- 1)
173
174        |   checkUse isMain (Cond(i, t, e), cl, isTail) =
175                checkUse isMain (i, checkUse isMain (t, checkUse isMain (e, cl -- 2, isTail), isTail), false)
176        |   checkUse isMain (BeginLoop { loop, arguments, ...}, cl, _) =
177                checkUse isMain (loop, List.foldl (fn (({value, ...}, _), n) => checkUse isMain (value, n, false)) cl arguments, false)
178        |   checkUse isMain (Loop args, cl, _) = List.foldl(fn ((e, _), n) => checkUse isMain (e, n, false)) cl args
179        |   checkUse isMain (Raise c, cl, _) = checkUse isMain (c, cl -- 1, false)
180        |   checkUse isMain (Handle {exp, handler, ...}, cl, isTail) =
181                checkUse isMain (exp, checkUse isMain (handler, cl, isTail), false)
182        |   checkUse isMain (Tuple{ fields, ...}, cl, _) = checkUseList isMain (fields, cl)
183
184        |   checkUse isMain (SetContainer{container, tuple = Tuple { fields, ...}, ...}, cl, _) =
185                (* This can be optimised *)
186                checkUse isMain (container, checkUseList isMain (fields, cl), false)
187        |   checkUse isMain (SetContainer{container, tuple, filter}, cl, _) =
188                checkUse isMain (container, checkUse isMain (tuple, cl -- (BoolVector.length filter), false), false)
189
190        |   checkUse isMain (TagTest{test, ...}, cl, _) = checkUse isMain (test, cl -- 1, false)
191
192        |   checkUse isMain (LoadOperation{address, ...}, cl, _) = checkUseAddress isMain (address, cl -- 1)
193
194        |   checkUse isMain (StoreOperation{address, value, ...}, cl, _) =
195                checkUse isMain (value, checkUseAddress isMain (address, cl -- 1), false)
196
197        |   checkUse isMain (BlockOperation{sourceLeft, destRight, length, ...}, cl, _) =
198                checkUse isMain (length,
199                    checkUseAddress isMain (destRight, checkUseAddress isMain (sourceLeft, cl -- 1)), false)
200        
201        and checkUseList isMain (elems, cl) =
202            List.foldl(fn (e, n) => checkUse isMain (e, n, false)) cl elems
203
204        and checkUseAddress isMain ({base, index=NONE, ...}, cl) = checkUse isMain (base, cl, false)
205        |   checkUseAddress isMain ({base, index=SOME index, ...}, cl) = checkUseList isMain ([base, index], cl)
206        
207        val costLeft = checkUse true (function, maxInlineSize, true)
208    in
209        if costLeft = 0
210        then TooBig
211        else if not (! hasRecursiveCall) 
212        then NonRecursive
213        else if ! allTail then TailRecursive(Array.vector argMod)
214        else NonTailRecursive(Array.vector argMod)
215    end
216
217    (* Turn a list of fields to use into a filter for SetContainer. *)
218    fun fieldsToFilter useList =
219    let
220        val maxDest = List.foldl Int.max ~1 useList
221        val fields = BoolArray.array(maxDest+1, false)
222        val _ = List.app(fn n => BoolArray.update(fields, n, true)) useList
223    in
224        BoolArray.vector fields
225    end
226
227    and filterToFields filter =
228        BoolVector.foldri (fn (i, true, l) => i :: l | (_, _, l) => l) [] filter
229
230    and setInFilter filter = BoolVector.foldl (fn (true, n) => n+1 | (false, n) => n) 0 filter
231
232    (* Work-around for bug in bytevector equality. *)
233    and boolVectorEq(a, b) = filterToFields a = filterToFields b
234 
235    fun buildFullTuple(filter, select) =
236    let
237        fun extArg(t, u) =
238            if t = BoolVector.length filter then []
239            else if BoolVector.sub(filter, t)
240            then select u :: extArg(t+1, u+1)
241            else CodeZero :: extArg (t+1, u)
242    in
243        mkTuple(extArg(0, 0))
244    end
245
246    (* When transforming code we only process one level and do not descend into sub-functions. *)
247    local
248        fun deExtract(Extract l) = l | deExtract _ = raise Misc.InternalError "deExtract"
249        fun onlyFunction repEntry (Lambda{ body, isInline, name, closure, argTypes, resultType, localCount, recUse }) =
250            SOME(
251                Lambda {
252                    body = body, isInline = isInline, name = name,
253                    closure = map (deExtract o mapCodetree repEntry o Extract) closure,
254                    argTypes = argTypes, resultType = resultType, localCount = localCount,
255                    recUse = recUse
256                }
257            )
258        |   onlyFunction repEntry code = repEntry code
259    in
260        fun mapFunctionCode repEntry = mapCodetree (onlyFunction repEntry)
261    end
262
263    local
264        (* This transforms the body of a "small" recursive function replacing any reference
265           to the arguments by the appropriate entry and the recursive calls themselves
266           by either a Loop or a recursive call. *)
267        fun mapCodeForFunctionRewriting(code, argMap, modVec, transformCall) =
268        let
269            fun repEntry(Extract(LoadArgument n)) = SOME(Extract(Vector.sub(argMap, n)))
270            |   repEntry(Eval { function = Extract LoadRecursive, argList, resultType }) =
271                let
272                    (* Filter arguments to include only those that are changed and map any values we pass.
273                       They may include references to the parameters. *)
274                    fun mapArg((arg, argT)::rest, n) =
275                        if Vector.sub(modVec, n) then mapArg(rest, n+1)
276                        else (mapCode arg, argT) :: mapArg(rest, n+1)
277                    |   mapArg([], _) = []
278                in
279                    SOME(transformCall(mapArg(argList, 0), resultType))
280                end
281            |   repEntry _ = NONE
282        
283            and mapCode code = mapFunctionCode repEntry code
284        in
285            mapCode code
286        end
287    in
288        (* If we have a tail recursive function we can replace the tail calls
289           by a loop.  modVec indicates the arguments that have not changed. *)
290        fun replaceTailRecursiveWithLoop(body, argTypes, modVec, nextAddress) =
291        let
292            (* We need to create local bindings for arguments that will change.
293               Those that do not can be reused. *)
294            local
295                fun mapArgs((argT, use):: rest, n, decs, mapList) =
296                    if Vector.sub(modVec, n)
297                    then mapArgs (rest, n+1, decs, LoadArgument n :: mapList)
298                    else
299                    let
300                        val na = ! nextAddress before nextAddress := !nextAddress + 1
301                    in
302                        mapArgs (rest, n+1, ({addr = na, value = mkLoadArgument n, use=use}, argT) :: decs, LoadLocal na :: mapList)
303                    end
304                |   mapArgs([], _, decs, mapList) = (List.rev decs, List.rev mapList)
305                val (decs, mapList) = mapArgs(argTypes, 0, [], [])
306            in
307                val argMap = Vector.fromList mapList
308                val loopArgs = decs
309            end
310        
311        in
312            BeginLoop { arguments = loopArgs, loop = mapCodeForFunctionRewriting(body, argMap, modVec, fn (l, _) => Loop l) }
313        end
314
315        (* If we have a small recursive function where some arguments are passed
316           through unchanged we can transform it by extracting the
317           stable arguments and only passing the changing arguments.  The
318           advantage is that this allows the stable arguments to be inserted
319           inline which is important if they are functions. The canonical
320           example is List.map. *)
321        fun liftRecursiveFunction(body, argTypes, modVec, closureSize, name, resultType, localCount) =
322        let
323            local
324                fun getArgs((argType, use)::rest, nArg, clCount, argCount, stable, change, mapList) =
325                    let
326                        (* This is the argument from the outer function.  It is either added
327                           to the closure or passed to the inner function. *)
328                        val argN = LoadArgument nArg
329                    in
330                        if Vector.sub(modVec, nArg)
331                        then getArgs(rest, nArg+1, clCount+1, argCount,
332                                    argN :: stable, change, LoadClosure clCount :: mapList)
333                        else getArgs(rest, nArg+1, clCount, argCount+1,
334                                    stable, (Extract argN, argType, use) :: change, LoadArgument argCount :: mapList)
335                    end
336                |   getArgs([], _, _, _, stable, change, mapList) =
337                        (List.rev stable, List.rev change, List.rev mapList)
338            in
339                (* The stable args go into the closure.  The changeable args are passed in. *)
340                val (stableArgs, changeArgsAndTypes, mapList) =
341                    getArgs(argTypes, 0, closureSize, 0, [], [], [])
342                val argMap = Vector.fromList mapList
343            end
344
345            val subFunction =
346                Lambda {
347                    body = mapCodeForFunctionRewriting(body, argMap, modVec, 
348                            fn (l, t) => Eval {
349                                function = Extract LoadRecursive, argList = l, resultType = t
350                            }),
351                    isInline = NonInline, (* Don't inline this function. *)
352                    name = name ^ "()",
353                    closure = List.tabulate(closureSize, fn n => LoadClosure n) @ stableArgs,
354                    argTypes = List.map (fn (_, t, u) => (t, u)) changeArgsAndTypes,
355                    resultType = resultType, localCount = localCount, recUse = [UseGeneral]
356                }
357        in
358            Eval {
359                function = subFunction,
360                argList = map (fn (c, t, _) => (c, t)) changeArgsAndTypes,
361                resultType = resultType
362            }
363        end
364    end
365
366    (* If the function arguments are used in a way that could be optimised the
367       data structure represents it. *)
368    datatype functionArgPattern =
369        ArgPattTuple of { filter: BoolVector.vector, allConst: bool, fromFields: bool }
370        (* ArgPattCurry is a list, one per level of application, of a
371           list, one per argument of the pattern for that argument. *)
372    |   ArgPattCurry of functionArgPattern list list * functionArgPattern
373    |   ArgPattSimple
374
375
376    (* Returns ArgPattCurry even if it is just a single application. *)
377    local
378        (* Control how we check for side-effects. *)
379        datatype curryControl =
380            CurryNoCheck | CurryCheck | CurryReorderable
381
382        local
383            open Address
384
385            (* Return the width of a tuple.  Returns 1 for non-tuples including
386               datatypes where different variants could have different widths.
387               Also returns a flag indicating if the value came from a constant.
388               Constants are already tupled so there's no advantage in untupling
389               them unless there are other non-constant arguments as well. *)
390            fun findTuple(Tuple{fields, isVariant=false}) = (List.length fields, false)
391            |   findTuple(Constnt(w, _)) =
392                    if isShort w orelse flags (toAddress w) <> F_words then (1, false)
393                    else (Word.toInt(length (toAddress w)), true)
394            |   findTuple(Extract _) = (1, false) (* TODO: record this for variables *)
395            |   findTuple(Cond(_, t, e)) =
396                    let
397                        val (tl, tc) = findTuple t
398                        and (el, ec) = findTuple e
399                    in
400                        if tl = el then (tl, tc andalso ec) else (1, false)
401                    end
402            |   findTuple(Newenv(_, e)) = findTuple e
403            |   findTuple _ = (1, false)
404            
405        in
406            fun mapArg c =
407            let
408                val (n, f) = findTuple c
409            in
410                if n <= 1
411                then ArgPattSimple
412                else ArgPattTuple{filter=BoolVector.tabulate(n, fn _ => true),
413                                  allConst=f, fromFields=false}
414            end
415        end
416
417        fun useToPattern _ [] = ArgPattSimple
418        |   useToPattern checkCurry (hd::tl) =
419            let
420                (* Construct a possible pattern from the head. *)
421                val p1 =
422                    case hd of
423                        UseApply(resl, arguments) =>
424                            let
425                                (* If the result is also curried extend the list. *)
426                                val subCheck =
427                                    case checkCurry of CurryCheck => CurryReorderable | c => c
428                                val (resultPatts, resultResult) =
429                                    case useToPattern subCheck resl of
430                                        ArgPattCurry l => l
431                                    |   tupleOrSimple => ([], tupleOrSimple)
432                                
433                                val thisArg = map mapArg arguments
434                            in
435                                (* If we have an argument that is a curried function we
436                                   can safely apply it to the first argument even if that
437                                   has a side-effect but we can't uncurry further than that
438                                   because the behaviour could rely on a side-effect of the
439                                   first application. *)
440                                if checkCurry = CurryReorderable
441                                    andalso List.exists(not o reorderable) arguments
442                                then ArgPattSimple
443                                else ArgPattCurry(thisArg :: resultPatts, resultResult)
444                            end
445
446                    |   UseField (n, _) =>
447                            ArgPattTuple{filter=BoolVector.tabulate(n+1, fn m => m=n), allConst=false, fromFields=true}
448
449                    |   _ => ArgPattSimple
450
451                fun mergePattern(ArgPattCurry(l1, r1), ArgPattCurry(l2, r2)) =
452                    let
453                        (* Each argument list should be the same length.
454                           The length here is the number of arguments
455                           provided to this application. *)
456                        fun mergeArgLists(al1, al2) =
457                            ListPair.mapEq mergePattern (al1, al2)
458                        (* The currying lists could be different lengths
459                           because some applications could only partially
460                           apply it.  It is essential not to assume more
461                           currying than the minimum so we stop with the
462                           shorter. *)
463                        val prefix = ListPair.map mergeArgLists (l1, l2)
464                    in
465                        if null prefix then ArgPattSimple else ArgPattCurry(prefix, mergePattern(r1, r2))
466                    end
467                    
468                |   mergePattern(ArgPattTuple{filter=n1, allConst=c1, fromFields=f1}, ArgPattTuple{filter=n2, allConst=c2, fromFields=f2}) =
469                        (* If the tuples are different sizes we can't use a tuple.
470                           Unlike currying it would be safe to assume tupling where
471                           there isn't (unless the function is actually polymorphic). *)
472                        if boolVectorEq(n1, n2)
473                        then ArgPattTuple{filter=n1, allConst=c1 andalso c2, fromFields = f1 andalso f2}
474                        else if f1 andalso f2
475                        then
476                        let
477                            open BoolVector
478                            val l1 = length n1 and l2 = length n2
479                            fun safesub(n, v) = if n < length v then v sub n else false
480                            val union = tabulate(Int.max(l1, l2), fn n => safesub(n, n1) orelse safesub(n, n2))
481                        in
482                            ArgPattTuple{filter=union, allConst=c1 andalso c2, fromFields = f1 andalso f2}
483                        end
484                        else ArgPattSimple
485
486                |   mergePattern _ = ArgPattSimple
487            in
488                case tl of
489                    [] => p1
490                |   tl => mergePattern(p1, useToPattern checkCurry tl)
491            end
492
493        (* If the result is just a function where all the arguments are simple
494           it's not actually curried. *)
495        fun usageToPattern checkCurry use =
496            case useToPattern checkCurry use of
497            (*    a as ArgPattCurry [s] =>
498                    if List.all(fn ArgPattSimple => true | _ => false) s
499                    then ArgPattSimple
500                    else a
501            |*)   patt => patt
502    in
503        (* Decurrying involves reordering (f exp1) exp2 into code
504           where any effects of evaluating exp2 are done before the
505           application.  That's only safe if either (f exp1) or exp2 have
506           no side-effects and do not depend on references.
507           In the case of the function body we can check that the body does
508           not depend on any references (typically it's a lambda) but for
509           function arguments we have to check how it is applied. *)
510        val usageForFunctionBody = usageToPattern CurryNoCheck
511        and usageForFunctionArg  = usageToPattern CurryCheck
512
513        (* To decide whether we want to detuple the argument we look to see
514           if the function is ever applied to a tuple.  This is rather different
515           to currying where we only decurry if every application is to multiple
516           arguments.  This information is then merged with information about the
517           arguments within the function. *)
518        fun existTupling (use: codeUse list): functionArgPattern list =
519        let
520            val argListLists =
521                List.foldl (fn (UseApply(_, args), l) => map mapArg args :: l | (_, l) => l) [] use
522            fun orMerge [] = raise Empty
523            |   orMerge [hd] = hd
524            |   orMerge (hd1 :: hd2 :: tl) =
525                let
526                    fun merge(a as ArgPattTuple _, _) = a
527                    |   merge(_, b) = b
528                in
529                    orMerge(ListPair.mapEq merge (hd1, hd2) :: tl)
530                end
531        in
532            orMerge argListLists
533        end
534
535        (* If the result of a function contains a tuple but it is not detupled on
536           every path, see if it is detupled on at least one. *)
537        fun existDetupling(UseApply(resl, _) :: rest) =
538            List.exists(fn UseField _ => true | _ => false) resl orelse
539                existDetupling rest
540        |   existDetupling(_ :: rest) = existDetupling rest
541        |   existDetupling [] = false
542    end
543
544    (* Return a tuple if any of the branches returns a tuple.  The idea is
545       that if the body actually constructs a tuple on the heap on at least
546       one branch it is probably worth attempting to detuple the result. *)
547    fun bodyReturnsTuple (Tuple{fields, isVariant=false}) =
548        ArgPattTuple{
549            filter=BoolVector.tabulate(List.length fields, fn _ => true),
550            allConst=false, fromFields=false
551        }
552
553    |   bodyReturnsTuple(Cond(_, t, e)) =
554        (
555            case bodyReturnsTuple t of
556                a as ArgPattTuple _ => a
557            |   _ => bodyReturnsTuple e
558        )
559
560    |   bodyReturnsTuple(Newenv(_, exp)) = bodyReturnsTuple exp
561
562    |   bodyReturnsTuple _ = ArgPattSimple
563
564    (* If the usage indicates that the body of the function should be transformed
565       these do the transformation.  It is possible that each of these cases could
566       apply and it would be possible to merge them all.  For the moment keep them
567       separate.  If another of the cases applies this will be re-entered on a
568       subsequent pass. *)
569    fun detupleResult({ argTypes, name, resultType, closure, isInline, localCount, body, ...}: lambdaForm , filter, makeAddress) =
570        (* The function returns a tuple or at least the uses of the function take apart a tuple.
571           Transform it to take a container as an argument and put the result in there. *)
572        let
573            local
574                fun mapArg f n ((t, _) :: tl) = (Extract(f n), t) :: mapArg f (n+1) tl
575                |   mapArg _ _ [] = []
576            in
577                fun mapArgs f l = mapArg f 0 l
578            end
579            val mainAddress = makeAddress() and shimAddress = makeAddress()
580
581            (* The main function performs the previous computation but puts the result into
582               the container.  We need to replace any recursive references with calls to the
583               shim.*)
584            local
585                val recEntry = LoadClosure(List.length closure)
586
587                fun doMap(Extract LoadRecursive) = SOME(Extract recEntry)
588                |   doMap _ = NONE
589            in
590                val transBody = mapFunctionCode doMap body
591            end
592
593            local
594                val containerArg = Extract(LoadArgument(List.length argTypes))
595                val newBody =
596                    SetContainer{container = containerArg, tuple = transBody, filter=filter }
597                val mainLambda: lambdaForm =
598                    {
599                        body = newBody, name = name, resultType=GeneralType,
600                        argTypes=argTypes @ [(GeneralType, [])],
601                        closure=closure @ [LoadLocal shimAddress],
602                        localCount=localCount + 1, isInline=isInline,
603                        recUse = [UseGeneral]
604                    }
605            in
606                val mainFunction = (mainAddress, mainLambda)
607            end
608
609            (* The shim function creates a container, passes it to the main function and then
610               builds a tuple from the container. *)
611            val shimBody =
612                mkEnv(
613                    [Container{addr = 0, use = [], size = setInFilter filter,
614                        setter= Eval {
615                                function = Extract(LoadClosure 0),
616                                argList = mapArgs LoadArgument argTypes @ [(Extract(LoadLocal 0), GeneralType)],
617                                resultType = GeneralType
618                            }
619                        }
620                    ],
621                    buildFullTuple(filter, fn n => mkInd(n, mkLoadLocal 0))
622                    )
623            val shimLambda =
624                { body = shimBody, name = name, argTypes = argTypes, closure = [LoadLocal mainAddress],
625                  resultType = resultType, isInline = Inline, localCount = 1, recUse = [UseGeneral] }
626            val shimFunction = (shimAddress, shimLambda)
627         in
628            (shimLambda, [mainFunction, shimFunction])
629        end
630
    (* Split a function whose arguments are tuples that are taken apart, or functions
       that take or return such tuples, into a pair of functions:
         - a transformed "main" function that receives the tuple fields as separate
           arguments and receives repackaged versions of function arguments;
         - an inline "shim" function with the original argument list that simply
           calls the main function with suitably unpacked/repacked arguments.
       "usage" is a list of argument patterns, paired element-by-element with
       argTypes (ListPair.foldrEq is applied to the two lists below).
       "makeAddress" is called twice to allocate the addresses at which the
       main function and the shim are to be bound.
       The result is (shimLambda, [(mainAddress, mainLambda), (shimAddress, shimLambda)]). *)
    fun transformFunctionArgs({ argTypes, name, resultType, closure, isInline, localCount, body, ...} , usage, makeAddress) =
        (* Not curried - just a single argument. *)
        let
            (* We need to construct an inline "shim" function that
               has the same calling pattern as the original.  This simply
               calls the transformed main function.
               We need to construct the arguments to call the transformed
               main function.  That needs, for example, to unpack tuples
               and repack argument functions.
               We need to produce an argument map to transform the main
               function.  This needs, for example, to pack the arguments
               into tuples.  Then when the code is run through the simplifier
               the tuples will be optimised away.  *)
            (* Allocator for new local addresses in the main function.  It starts
               after the locals of the original body. *)
            val localCounter = ref localCount

            (* mapPattern(patterns, n, m) processes the argument patterns in order.
               n is the index of the current argument of the original (shim) function,
               m is the index of the next argument of the transformed main function.
               It returns a triple:
                 - bindings to be put at the start of the main function body;
                 - the argument expressions the shim passes to the main function;
                 - for each original argument, the load form that replaces
                   references to that argument inside the main function body. *)
            fun mapPattern(ArgPattTuple{filter, allConst=false, ...} :: patts, n, m) =
                (* A tuple argument: the shim passes the selected fields as separate
                   arguments and the main function rebuilds a tuple from them in a
                   new local.  NOTE(review): setInFilter presumably counts the fields
                   selected by the filter and filterToFields lists their offsets -
                   confirm against their definitions. *)
                let
                    val fieldList = filterToFields filter
                    val (decs, args, mapList) = mapPattern(patts, n+1, m + setInFilter filter)
                    val newAddr = ! localCounter before localCounter := ! localCounter + 1
                    val tuple = buildFullTuple(filter, fn u => mkLoadArgument(m+u))
                    val thisDec = Declar { addr = newAddr, use = [], value = tuple }
                    (* Arguments for the call *)
                    val thisArg = List.map(fn p => mkInd(p, mkLoadArgument n)) fieldList
                in
                    (thisDec :: decs, thisArg @ args, LoadLocal newAddr :: mapList)
                end

            |   mapPattern(ArgPattCurry(currying as [_], ArgPattTuple{allConst=false, filter, ...}) :: patts, n, m) =
                (* It's a function that returns a tuple.  The function must not be curried because
                   otherwise it returns a function not a tuple. *)
                let
                    val (thisDec, thisArg, thisMap) =
                        transformFunctionArgument(currying, [LoadArgument m], [LoadArgument n], SOME filter)
                    val (decs, args, mapList) = mapPattern(patts, n+1, m+1)
                in
                    (thisDec :: decs, thisArg :: args, thisMap :: mapList)
                end

            |   mapPattern(ArgPattCurry(currying as firstArgSet :: _, _) :: patts, n, m) =
                (* Transform it if it's curried or if there is a tuple in the first arg. *)
                if (*List.length currying >= 2 orelse *) (* This transformation is unsafe. *)
                   List.exists(fn ArgPattTuple{allConst=false, ...} => true | _ => false) firstArgSet
                then
                let
                    val (thisDec, thisArg, thisMap) =
                        transformFunctionArgument(currying, [LoadArgument m], [LoadArgument n], NONE)
                    val (decs, args, mapList) = mapPattern(patts, n+1, m+1)
                in
                    (thisDec :: decs, thisArg :: args, thisMap :: mapList)
                end
                else
                (* Nothing to gain: pass the function argument through unchanged. *)
                let
                    val (decs, args, mapList) = mapPattern(patts, n+1, m+1)
                in
                    (decs, Extract(LoadArgument n) :: args, LoadArgument m :: mapList)
                end

            |   mapPattern(_ :: patts, n, m) =
                (* Any other pattern: the argument is passed through unchanged. *)
                let
                    val (decs, args, mapList) = mapPattern(patts, n+1, m+1)
                in
                    (decs, Extract(LoadArgument n) :: args, LoadArgument m :: mapList)
                end

            |   mapPattern([], _, _) = ([], [], [])

            (* Transform an argument that is itself a function.
               argumentArgs is the list of argument patterns for each level of currying,
               loadPack is the closure used by the packing function built for the
               main function, loadThisArg is the closure used by the lambda built
               in the shim and filterOpt is SOME filter if the argument function
               returns a tuple whose fields are passed back through a container.
               The result is (binding for the main function, argument expression
               for the shim's call, load form for use in the main function body). *)
            and transformFunctionArgument(argumentArgs, loadPack, loadThisArg, filterOpt) =
            let
                (* Disable the transformation of curried arguments for the moment.
                   This is unsafe.  See Test146.  The problem is that this transformation
                   is only safe if the function is applied immediately to all the arguments.
                   However the usage information is propagated so that if the result of
                   the first application is bound to a variable and then that variable is
                   applied it still appears as curried. *)
                val argumentArgs = [hd argumentArgs]
                (* We have a function that takes a series of curried arguments.
                   Change that so that the function takes a list of arguments. *)
                val newAddr = ! localCounter before localCounter := ! localCounter + 1
                (* In the main function we are expecting to call the argument in a curried
                   fashion.  We need to construct a function that packages up the
                   arguments and, when all of them have been provided, calls the actual
                   argument. *)
                local
                    (* curryPack(remaining levels, closure).  The head of the closure list
                       is the function to call; the rest are the arguments collected so far.
                       Note that the second clause binds "hd" and "tl" as pattern
                       variables; the list functions hd and tl are only used in
                       this first clause. *)
                    fun curryPack([], fnclosure) =
                        let
                            (* We're ready to call the function.  We now need to unpack any
                               tupled arguments. *)
                            fun mapArgs(c :: ctl, args) =
                            let
                                fun mapArg([], args) = mapArgs(ctl, args)
                                |   mapArg(ArgPattTuple{filter, allConst=false, ...} :: patts, arg :: argctl) =
                                    let
                                        val fields = filterToFields filter
                                    in
                                        List.map(fn p => (mkInd(p, Extract arg), GeneralType)) fields @
                                            mapArg(patts, argctl)
                                    end
                                |   mapArg(_ :: patts, arg :: argctl) =
                                        (Extract arg, GeneralType) :: mapArg(patts, argctl)
                                |   mapArg(_, []) = raise InternalError "mapArgs: mismatch"
                            in
                                mapArg(c, args)
                            end
                            |   mapArgs _ = []
                            val argList = mapArgs(argumentArgs, tl fnclosure)
                        in
                            case filterOpt of
                                NONE =>
                                    Eval { function = Extract(hd fnclosure), resultType = GeneralType,
                                            argList = argList }
                            |   SOME filter =>
                                    (* We need a container here for the result. *)
                                    mkEnv(
                                        [
                                            Container{addr=0, size=setInFilter filter, use=[UseGeneral], setter=
                                                Eval { function = Extract(hd fnclosure), resultType = GeneralType,
                                                    argList = argList @ [(mkLoadLocal 0, GeneralType)] }
                                            }
                                        ],
                                        buildFullTuple(filter, fn n => mkInd(n, mkLoadLocal 0))
                                    )
                        end
                    |   curryPack(hd :: tl, fnclosure) =
                        let
                            val nArgs = List.length hd
                            (* If this is the last then we need to include the container if required. *)
                            val needContainer = case (tl, filterOpt) of ([], SOME _) => true | _ => false
                        in
                            Lambda { closure = fnclosure,
                                isInline = Inline, name = name ^ "-P", resultType = GeneralType,
                                argTypes = List.tabulate(nArgs, fn _ => (GeneralType, [UseGeneral])),
                                localCount = if needContainer then 1 else 0, recUse = [],
                                body = curryPack(tl,
                                            (* The closure for the next level is the current closure
                                               together with all the arguments at this level. *)
                                            List.tabulate(List.length fnclosure, fn n => LoadClosure n) @
                                            List.tabulate(nArgs, LoadArgument))
                            }
                        end
                in
                    val packFn = curryPack(argumentArgs, loadPack)
                end
                (* The packing function is bound to a new local in the main function. *)
                val thisDec = Declar { addr = newAddr, use = [], value = packFn }
                (* Accumulate the number of arguments needed for a pattern: one per
                   selected field for an expanded tuple, otherwise one. *)
                fun argCount(ArgPattTuple{filter, allConst=false, ...}, m) = setInFilter filter + m
                |   argCount(_, m) = m+1
                local
                    (* In the shim function, i.e. the inline function outside, we have
                       a lambda that will be called when the main function wants to
                       call its argument function.  This is provided with all the arguments
                       and so it has to call the actual argument, which is expected to be
                       curried, an argument at a time. *)
                    (* curryApply(levels, n, c): n is the index of the first argument for
                       this level and c is the code for the function applied so far. *)
                    fun curryApply(hd :: tl, n, c) =
                        let
                            fun makeArgs(_, []) = []
                            |   makeArgs(q, ArgPattTuple{filter, allConst=false, ...} :: args) =
                                    (buildFullTuple(filter, fn r => mkLoadArgument(r+q)), GeneralType) ::
                                         makeArgs(q + setInFilter filter, args)
                            |   makeArgs(q, _ :: args) =
                                    (mkLoadArgument q, GeneralType) :: makeArgs(q+1, args)
                            val args = makeArgs(n, hd)
                        in
                            curryApply(tl, n + List.foldl argCount 0 hd,
                                Eval{function=c, resultType = GeneralType, argList=args})
                        end
                    |   curryApply([], _, c) = c
                in
                    val thisBody = curryApply (argumentArgs, 0, mkLoadClosure 0)
                end
                local
                    (* We have one argument for each argument at each level of currying, or
                       where we've expanded a tuple, one argument for each field.
                       If the function is returning a tuple we have an extra argument for
                       the container. *)
                    val totalArgCount =
                        List.foldl(fn (c, n) => n + List.foldl argCount 0 c) 0 argumentArgs +
                        (case filterOpt of SOME _ => 1 | _ => 0)
                    (* When a container is used it is the last argument and the
                       result is stored into it. *)
                    val functionBody =
                        case filterOpt of
                            NONE => thisBody
                        |   SOME filter => mkSetContainer(mkLoadArgument(totalArgCount-1), thisBody, filter)
                in
                    val thisArg =
                        Lambda {
                            closure = loadThisArg, isInline = Inline, name = name ^ "-E",
                            argTypes = List.tabulate(totalArgCount, fn _ => (GeneralType, [UseGeneral])),
                            resultType = GeneralType, localCount = 0, recUse = [UseGeneral], body = functionBody
                        }
                end
            in
                (thisDec, thisArg, LoadLocal newAddr)
            end

            (* Process all the argument patterns starting with argument zero in
               both the original and the transformed argument lists. *)
            val (extraBindings, transArgCode, argMapList) = mapPattern(usage, 0, 0)

            local
                (* Transform the body by replacing the arguments with the new arguments. *)
                val argMap = Vector.fromList argMapList
                (* If we have a recursive reference we have to replace it with a reference
                   to the shim. *)
                (* The shim is appended to the closure of the main function below so
                   its index is the length of the original closure. *)
                val recEntry = LoadClosure(List.length closure)

                fun doMap(Extract(LoadArgument n)) = SOME(Extract(Vector.sub(argMap, n)))
                |   doMap(Extract LoadRecursive) = SOME(Extract recEntry)
                |   doMap _ = NONE
            in
                val transBody = mapFunctionCode doMap body
            end

            local
                (* The argument types for the main function have the tuples expanded,  Functions
                   are not affected. *)
                fun expand(ArgPattTuple{filter, allConst=false, ...}, _, r) = List.tabulate(setInFilter filter, fn _ => (GeneralType, [])) @ r
                |   expand(_, a, r) = a :: r
            in
                val transArgTypes = ListPair.foldrEq expand [] (usage, argTypes)
            end

            (* Add the type information to the argument code. *)
            val transArgs = ListPair.mapEq(fn (c, (t, _)) => (c, t)) (transArgCode, transArgTypes)

            val mainAddress = makeAddress() and shimAddress = makeAddress()
            (* The main function: the extra bindings precede the transformed body and
               the shim is added as the last closure entry for recursive references. *)
            val transLambda =
                {
                    body = mkEnv(extraBindings, transBody), name = name, argTypes = transArgTypes,
                    closure = closure @ [LoadLocal shimAddress], resultType = resultType, isInline = isInline,
                    localCount = ! localCounter, recUse = [UseGeneral]
                }

            (* Return the pair of functions. *)
            val mainFunction = (mainAddress, transLambda)
            (* The shim has the original argument types and only the main function
               in its closure. *)
            val shimBody =
                Eval { function = Extract(LoadClosure 0), argList = transArgs, resultType = resultType }
            val shimLambda =
                { body = shimBody, name = name, argTypes = argTypes, closure = [LoadLocal mainAddress],
                  resultType = resultType, isInline = Inline, localCount = 0, recUse = [UseGeneral] }
            val shimFunction = (shimAddress, shimLambda)
            (* TODO:  We have two copies of the shim function here. *)
        in
            (shimLambda, [mainFunction, shimFunction])
        end
872
873    fun decurryFunction(
874            { argTypes, name, resultType, closure, isInline, localCount,
875              body as Lambda { argTypes=subArgTypes, resultType=subResultType, ... } , ...}, makeAddress) =
876        (* Curried - just unwind one level this time.  This case is normally dealt with by
877           the front-end at least for fun bindings. *)
878        let
879            local
880                fun mapArg f n ((t, _) :: tl) = (Extract(f n), t) :: mapArg f (n+1) tl
881                |   mapArg _ _ [] = []
882            in
883                fun mapArgs f l = mapArg f 0 l
884            end
885
886            val mainAddress = makeAddress() and shimAddress = makeAddress()
887            (* The main function calls the original body as a function.  The body
888               is a lambda which will contain references to the outer arguments but
889               because we're just adding arguments these will be as before. *)
890            (* We have to transform any recursive references to point to the shim. *)
891            local
892                val recEntry = LoadClosure(List.length closure)
893
894                fun doMap(Extract LoadRecursive) = SOME(Extract recEntry)
895                |   doMap _ = NONE
896            in
897                val transBody = mapFunctionCode doMap body
898            end
899
900            val arg1Count = List.length argTypes
901            val mainLambda =
902                {
903                    body =
904                        Eval{ function = transBody, resultType = subResultType,
905                            argList = mapArgs (fn n => LoadArgument(n+arg1Count)) subArgTypes
906                        },
907                    name = name, resultType = subResultType,
908                    closure = closure @ [LoadLocal shimAddress], isInline = isInline, localCount = localCount,
909                    argTypes = argTypes @ subArgTypes, recUse = [UseGeneral]
910                }
911            val mainFunction = (mainAddress, mainLambda)
912
913            val shimInnerLambda =
914                Lambda {
915                    (* The inner shim closure contains the main function and the outer arguments. *)
916                    closure = LoadClosure 0 :: List.tabulate(arg1Count, LoadArgument),
917                    body = Eval {
918                                function = Extract(LoadClosure 0),
919                                resultType = resultType,
920                                (* Calls main function with both sets of args. *)
921                                argList = mapArgs (fn n => LoadClosure(n+1)) argTypes @
922                                          mapArgs LoadArgument subArgTypes
923                            },
924                    name = name ^ "-", resultType = subResultType, localCount = 0, isInline = Inline,
925                    argTypes = subArgTypes, recUse = [UseGeneral]
926                }
927
928            val shimOuterLambda =
929                { body = shimInnerLambda, name = name, argTypes = argTypes, closure = [LoadLocal mainAddress],
930                  resultType = resultType, isInline = Inline, localCount = 0, recUse = [UseGeneral] }
931            val shimFunction = (shimAddress, shimOuterLambda)
932        in
933            (shimOuterLambda: lambdaForm, [mainFunction, shimFunction])
934        end
935
936    |   decurryFunction _ = raise InternalError "decurryFunction"
937
    (* Process a Lambda slightly differently in different contexts.
       LCNormal: a lambda encountered as an ordinary expression.
       LCRecursive: a lambda that is a member of a RecDecs binding.  optLambda does
         not try to replace such a lambda by the function it calls.
       LCImmediateCall: a lambda that is the function part of an Eval, i.e. it is
         applied immediately.  optLambda uses a larger inline size limit for this
         case unless the maximum inline size parameter is zero. *)
    datatype lambdaContext = LCNormal | LCRecursive | LCImmediateCall
940
941    (* Transforming a lambda may result in producing auxiliary functions that are in
942       general mutually recursive. *)
943    fun mapLambdaResult([], lambda) = lambda
944    |   mapLambdaResult(bindings, lambda) =
945            mkEnv([RecDecs(map(fn(addr, lam) => {addr=addr, use=[], lambda=lam}) bindings)], lambda)
946
    (* The optimiser function passed to mapCodetree.  "context" is the record
       containing makeAddr, reprocess and debugArgs and "use" is the list of ways
       in which the result of the expression is used.  It returns SOME code where
       the expression has been processed here and NONE otherwise.
       NOTE(review): mapCodetree presumably applies the function to the
       sub-expressions itself when NONE is returned - confirm against its definition. *)
    fun optimise (context, use) (Lambda lambda) =
            SOME(mapLambdaResult(optLambda(context, use, lambda, LCNormal)))

    |   optimise (context, use) (Newenv(envDecs, envExp)) =
        let
            (* Optimise an expression whose result is used as described by mapUse. *)
            fun mapExp mapUse = mapCodetree (optimise(context, mapUse))

            (* A declaration is optimised using the recorded use of the binding. *)
            fun mapbinding(Declar{value, addr, use}) = Declar{value=mapExp use value, addr=addr, use=use}
            |   mapbinding(RecDecs l) =
                let
                    fun mapRecDec({addr, lambda, use}, rest) =
                        case optLambda(context, use, lambda, LCRecursive) of
                            (bindings, Lambda lambdaRes) =>
                                (* Turn any bindings into extra mutually-recursive functions. *)
                                {addr=addr, use = use, lambda = lambdaRes } ::
                                    map (fn (addr, res) => {addr=addr, use=use, lambda=res }) bindings @ rest
                        |   _ => raise InternalError "mapbinding: not lambda"
                in
                    (* foldl prepends each processed group so the groups appear
                       in the reverse of their original order. *)
                    RecDecs(foldl mapRecDec [] l)
                end
            |   mapbinding(NullBinding exp) = NullBinding(mapExp [UseGeneral] exp)
            |   mapbinding(Container{addr, use, size, setter}) =
                    Container{addr=addr, use=use, size=size, setter = mapExp [UseGeneral] setter}
        in
            (* The final expression is used in the same way as the whole block. *)
            SOME(Newenv(map mapbinding envDecs, mapExp use envExp))
        end

        (* Immediate call to a function.  We may be able to expand this inline unless it
           is recursive. *)
    |   optimise (context, use) (Eval {function = Lambda lambda, argList, resultType}) =
        let
            val args = map (fn (c, t) => (optGeneral context c, t)) argList
            val argTuples = map #1 args
            (* The lambda is used by being applied to these arguments with the
               result used in the way the call is used. *)
            val (bindings, newLambda) = optLambda(context, [UseApply(use, argTuples)], lambda, LCImmediateCall)
            val call = Eval { function=newLambda, argList=args, resultType = resultType }
        in
            SOME(mapLambdaResult(bindings, call))
        end

    |   optimise (context as { reprocess, ...}, use) (Eval {function = Cond(i, t, e), argList, resultType}) =
        let
            (* Transform "(if i then t else e) x" into "if i then t x else e x".  This
               allows for other optimisations and inline expansion. *)
            (* We duplicate the function arguments which could cause the size of the code
               to blow-up if they involve complicated expressions. *)
            fun pushFunction l =
                 mapCodetree (optimise(context, use)) (Eval{function=l, argList=argList, resultType=resultType})
        in
            (* Record that the code has changed and should be processed again. *)
            reprocess := true;
            SOME(Cond(i, pushFunction t, pushFunction e))
        end

    |   optimise (context, use) (Eval {function, argList, resultType}) =
        (* If nothing else we need to ensure that "use" is correctly set on
           the function and arguments and we don't simply pass the original. *)
        let
            val args = map (fn (c, t) => (optGeneral context c, t)) argList
            val argTuples = map #1 args
        in
            SOME(
                Eval{
                    function= mapCodetree (optimise (context, [UseApply(use, argTuples)])) function,
                    argList=args, resultType = resultType
                })
        end

        (* A non-variant field selection: the base is used by having this
           field extracted from it. *)
    |   optimise (context, use) (Indirect{base, offset, isVariant = false}) =
        SOME(Indirect{base = mapCodetree (optimise(context, [UseField(offset, use)])) base,
                      offset = offset, isVariant = false})

    |   optimise (context, use) (code as Cond _) =
        (* If the result of the if-then-else is always taken apart as fields
           then we are better off taking it apart further down and putting
           the fields into a container on the stack. *)
        if List.all(fn UseField _ => true | _ => false) use
        then SOME(optFields(code, context, use))
        else NONE

    |   optimise (context, use) (code as BeginLoop _) =
        (* If the result of the loop is taken apart we should push
           this down as well. *)
        if List.all(fn UseField _ => true | _ => false) use
        then SOME(optFields(code, context, use))
        else NONE

        (* Everything else is not treated specially here. *)
    |   optimise _ _ = NONE
1033    
1034    and optGeneral context exp = mapCodetree (optimise(context, [UseGeneral])) exp
1035
1036    and optLambda(
1037            { debugArgs, reprocess, makeAddr, ... },
1038            contextUse,
1039            { body, name, argTypes, resultType, closure, localCount, isInline, recUse, ...},
1040            lambdaContext) : (int * lambdaForm) list * codetree =
1041    (*
1042        Optimisations on lambdas.
1043        1.  A lambda that simply calls another function with all its own arguments
1044            can be replaced by a reference to the function provided the "function"
1045            is a side-effect-free expression.
1046        2.  Don't attempt to optimise inline functions that are exported.
1047        3.  Transform lambdas that take tuples as arguments or are curried or where
1048            an argument is a function with tupled or curried arguments into a pair
1049            of an inline function with the original argument set and a new "main"
1050            function with register/stack arguments.
1051    *)
1052    let
1053        (* The overall use of the function is the context plus the recursive use. *)
1054        val use = contextUse @ recUse
1055        (* Check if it's a call to another function with all the original arguments.
1056           This is really wanted when we are passing this lambda as an argument to
1057           another function and really only when we have produced a shim function
1058           that has been inline expanded.  Otherwise this will be a "small" function
1059           and will be inline expanded when it's used. *)
1060        val replaceBody =
1061            case (body, lambdaContext = LCRecursive) of
1062                (Eval { function, argList, resultType=callresult }, false) =>
1063                let
1064                    fun argSequence((Extract(LoadArgument a), _) :: rest, b) = a = b andalso argSequence(rest, b+1)
1065                    |   argSequence([], _) = true
1066                    |   argSequence _ = false
1067        
1068                    val argumentsMatch =
1069                        argSequence(argList, 0) andalso 
1070                            ListPair.allEq(fn((_, a), (b, _)) => a = b) (argList, argTypes) andalso
1071                            callresult = resultType
1072                in
1073                    if not argumentsMatch
1074                    then NONE
1075                    else
1076                    case function of
1077                        (* This could be any function which has neither side-effects nor
1078                           depends on a reference nor depends on another argument but if
1079                           it has local variables they would have to be renumbered into
1080                           the surrounding scope.  In practice we're really only interested
1081                           in simple cases that arise as a result of using a "shim" function
1082                           created in the code below. *)
1083                        c as Constnt _ => SOME c
1084                    |   Extract(LoadClosure addr) => SOME(Extract(List.nth(closure, addr)))
1085                    |   _ => NONE
1086                end
1087            |   _ => NONE
1088    in
1089        case replaceBody of
1090            SOME c => ([], c)
1091        |   NONE =>
1092            if isInline = Inline andalso List.exists (fn UseExport => true | _ => false) use
1093            then
1094            let
                (* If it's inline any application of this will be optimised after
                   inline expansion.  We still apply any optimisations to the body
                   at this stage because we will compile and code-generate a version
                   for use if we want a "general" value. *)
1099                val addressAllocator = ref localCount
1100                val optContext =
1101                {
1102                    makeAddr = fn () => (! addressAllocator) before addressAllocator := ! addressAllocator + 1,
1103                    reprocess = reprocess,
1104                    debugArgs = debugArgs
1105                }
1106                val optBody = mapCodetree (optimise(optContext, [UseGeneral])) body
1107                val lambdaRes =
1108                    {
1109                        body = optBody,
1110                        isInline = isInline, name = name, closure = closure,
1111                        argTypes = argTypes, resultType = resultType, recUse = recUse,
1112                        localCount = !addressAllocator (* After optimising body. *)
1113                    }
1114            in
1115                ([], Lambda lambdaRes) 
1116            end
1117            else
1118            let
1119                (* Allocate any new addresses after the existing ones. *)
1120                val addressAllocator = ref localCount
1121                val optContext =
1122                {
1123                    makeAddr = fn () => (! addressAllocator) before addressAllocator := ! addressAllocator + 1,
1124                    reprocess = reprocess,
1125                    debugArgs = debugArgs
1126                }
1127                val optBody = mapCodetree (optimise(optContext, [UseGeneral])) body
1128
1129                (* See if this should be expanded inline.  If we are calling the lambda
1130                   immediately we try to expand it unless maxInlineSize is zero.  We
1131                   may not be able to expand it if it is recursive. (It may have been
1132                   inside an inline function). *)
1133                val maxInlineSize = DEBUG.getParameter DEBUG.maxInlineSizeTag debugArgs
1134                val (inlineType, updatedBody, localCount) =
1135                    case evaluateInlining(optBody, List.length argTypes,
1136                            if maxInlineSize <> 0 andalso lambdaContext = LCImmediateCall
1137                            then 1000 else FixedInt.toInt maxInlineSize) of
1138                        NonRecursive  => (Inline, optBody, ! addressAllocator)
1139                    |   TailRecursive bv =>
1140                            (Inline,
1141                                replaceTailRecursiveWithLoop(optBody, argTypes, bv, addressAllocator), ! addressAllocator)
1142                    |   NonTailRecursive bv =>
1143                            if Vector.exists (fn n => n) bv
1144                            then (Inline, 
1145                                    liftRecursiveFunction(
1146                                        optBody, argTypes, bv, List.length closure, name, resultType, !addressAllocator), 0)
1147                            else (NonInline, optBody, ! addressAllocator) (* All arguments have been modified *)
1148                    |   TooBig => (NonInline, optBody, ! addressAllocator)
1149
1150                val lambda: lambdaForm =
1151                {
1152                    body = updatedBody, name = name, argTypes = argTypes, closure = closure,
1153                    resultType = resultType, isInline = inlineType, localCount = localCount,
1154                    recUse = recUse
1155                }
1156
                (* See if it should be transformed.  We only do this if the function is not going
                   to be inlined.  If it is then there's no point because the transformation is
                   going to be done as part of the inlining process.  Even if it's marked for
                   inlining we may not actually call the function and instead pass it as an
                   argument or return it as result but in that case transformation doesn't
                   achieve anything because we are going to pass the untransformed "shim"
                   function anyway. *)
1164                val (newLambda, bindings) =
1165                    if isInline = NonInline
1166                    then
1167                    let
1168                        val functionPattern =
1169                            case usageForFunctionBody use of
1170                                ArgPattCurry(arg1 :: arg2 :: moreArgs, res) =>
1171                                    (* The function is always called with at least two curried arguments.
1172                                       We can decurry the function if the body is applicative - typically
1173                                       if it's a lambda - but not if applying the body would have a
1174                                       side-effect.  We only do it one level at this stage.  If it's
1175                                       curried more than that we'll come here again. *)
1176                                    (* In order to get the types we restrict this to the case of
1177                                       a body that is a lambda.  The result is a function and therefore
1178                                       ArgPattSimple unless we are using up all the args. *)
1179                                    if (*reorderable body*) case updatedBody of Lambda _ => true | _ => false
1180                                    then ArgPattCurry([arg1, arg2], if null moreArgs then res else ArgPattSimple)
1181                                    else ArgPattCurry([arg1], ArgPattSimple)
1182                            |   usage => usage
1183
1184                        val argPatterns = map (usageForFunctionArg o #2) argTypes
1185
1186                        (* fullArgPattern is a list, one per level of currying, of a list, one per argument of
1187                           the patterns.  resultPattern is used to detect whether the result is a tuple that
1188                           is taken apart. *)
1189                        val (fullArgPattern, resultPattern) =
1190                            case functionPattern of
1191                                ArgPattCurry(_ :: rest, resPattern) =>
1192                                let
1193                                    (* The function is always applied at least to the first set of arguments.
1194                                       (It's never just passed). Merge the applications of the function
1195                                       with the use of the arguments.  Return the usage within the
1196                                       function unless the function takes apart a tuple but no
1197                                       application passes in a tuple. *)
1198                                    fun merge(ArgPattTuple _, argUse as ArgPattTuple _) = argUse
1199                                    |   merge(_, ArgPattTuple _) = ArgPattSimple
1200                                    |   merge(_, argUse)  = argUse
1201
1202                                    val mergedArgs =
1203                                        (ListPair.mapEq merge (existTupling use, argPatterns)) :: rest
1204
1205                                    (* *)
1206                                    val mergedResult =
1207                                        case (bodyReturnsTuple updatedBody, resPattern) of
1208                                            (bodyTuple as ArgPattTuple _, ArgPattSimple) =>
1209                                                if existDetupling use
1210                                                then bodyTuple
1211                                                else ArgPattSimple
1212                                        |   _ => resPattern
1213                                            
1214                                in
1215                                    (mergedArgs, mergedResult)
1216                                end
1217                            |   _ => (* Not called: either exported or passed as a value. *)
1218                                (* This previously tried to see whether the body returned a tuple
1219                                   if the function was exported.  This caused an infinite loop
1220                                   (see Tests/Succeed/Test164.ML) and anyway doesn't seem to
1221                                   optimise the cases we want. *)
1222                                ([], ArgPattSimple)
1223                    in
1224                        case (fullArgPattern, resultPattern) of
1225                            (_ :: _ :: _, _) => (* Curried *)
1226                                ( reprocess := true; decurryFunction(lambda, makeAddr))
1227
1228                        |   (_, ArgPattTuple {filter, ...}) => (* Result is a tuple *)
1229                                ( reprocess := true; detupleResult(lambda, filter, makeAddr))
1230
1231                        |   (first :: _, _) =>
1232                            let
1233                                fun checkArg (ArgPattTuple{allConst=false, ...}) = true
1234                                        (* Function has at least one tupled arg. *)
1235                                |   checkArg (ArgPattCurry([_], ArgPattTuple{allConst=false, ...})) = true
1236                                        (* Function has an arg that is a function that returns a tuple.
1237                                           It must not be curried otherwise it returns a function not a tuple. *)
1238                                (* This transformation is unsafe.  See comment in transformFunctionArgument above. *)
1239                                (*|   checkArg (ArgPattCurry(_ :: _ :: _, _)) = true *)
1240                                        (* Function has an arg that is a curried function. *)
1241                                |   checkArg (ArgPattCurry(firstArgSet :: _, _)) =
1242                                        (* Function has an arg that is a function that
1243                                           takes a tuple in its first argument set. *)
1244                                        List.exists(fn ArgPattTuple{allConst=false, ...} => true | _ => false) firstArgSet
1245                                |   checkArg _ = false
1246                            in
1247                                (* It isn't curried - look at the arguments. *)
1248                                if List.exists checkArg first
1249                                then ( reprocess := true; transformFunctionArgs(lambda, first, makeAddr) )
1250                                else (lambda, [])
1251                            end
1252
1253                        |   _ => (lambda, [])
1254                    end
1255                    else (lambda, [])
1256            in
1257                (* If this is to be inlined but was not before we may need to reprocess.
1258                   We don't reprocess if this is only exported.  If it's only exported
1259                   we're not going to expand it within this code and we can end up with
1260                   repeated processing. *)
1261                if #isInline newLambda = Inline andalso isInline = NonInline andalso
1262                    (case use of [UseExport] => false | _ => true)
1263                then reprocess := true
1264                else ();
1265                (bindings, Lambda newLambda)
1266            end
1267    end
1268
1269    and optFields (code, context as { reprocess, makeAddr, ...}, use) =
1270    let
1271        (* We have an if-then-else or a loop whose result is only ever
1272           taken apart.  We push this down. *)
1273        (* Find the fields that are used.  Not all may be. *)
1274        local
1275            val maxField =
1276                List.foldl(fn (UseField(f, _), m) => Int.max(f, m) | (_, m) => m) 0 use
1277            val fieldUse = BoolArray.array(maxField+1, false)
1278            val _ =
1279                List.app(fn UseField(f, _) => BoolArray.update(fieldUse, f, true) | _ => ()) use
1280        in
1281            val maxField = maxField
1282            val useList = BoolArray.foldri (fn (i, true, l) => i :: l | (_, _, l) => l) [] fieldUse
1283        end
1284
1285        fun pushContainer(Cond(ifpt, thenpt, elsept), leafFn) =
1286                Cond(ifpt, pushContainer(thenpt, leafFn), pushContainer(elsept, leafFn))
1287
1288        |   pushContainer(Newenv(decs, exp), leafFn) =
1289                Newenv(decs, pushContainer(exp, leafFn))
1290
1291        |   pushContainer(BeginLoop{loop, arguments}, leafFn) =
1292                (* If we push it through a BeginLoop we MUST then push it through
1293                   anything that could contain the Loop i.e. Cond, Newenv, Handle. *)
1294                BeginLoop{loop = pushContainer(loop, leafFn), arguments=arguments}
1295
1296        |   pushContainer(l as Loop _, _) = l
1297                (* Within a BeginLoop only the non-Loop leaves return
1298                   values.  Loop entries go back to the BeginLoop so
1299                   these are unchanged. *)
1300
1301        |   pushContainer(Handle{exp, handler, exPacketAddr}, leafFn) =
1302                Handle{exp=pushContainer(exp, leafFn), handler=pushContainer(handler, leafFn), exPacketAddr=exPacketAddr}
1303
1304        |   pushContainer(tuple, leafFn) = leafFn tuple (* Anything else. *)
1305
1306        val () = reprocess := true
1307    in
1308        case useList of
1309            [offset] => (* We only want a single field.  Push down an Indirect. *)
1310            let
1311                (* However the context still requires a tuple.  We need to
1312                   reconstruct one with unused fields set to zero.  They will
1313                   be filtered out later by the simplifier pass. *)
1314                val field =
1315                    optGeneral context (pushContainer(code, fn t => mkInd(offset, t)))
1316                fun mkFields n = if n = offset then field else CodeZero
1317            in
1318                Tuple{ fields = List.tabulate(offset+1, mkFields), isVariant = false }
1319            end
1320
1321        |   _ =>
1322            let
1323                (* We require a container. *)
1324                val containerAddr = makeAddr()
1325                val width = List.length useList
1326                val loadContainer = Extract(LoadLocal containerAddr)
1327
1328                fun setContainer tuple = (* At the leaf set the container. *)
1329                    SetContainer{container = loadContainer, tuple = tuple, filter = fieldsToFilter useList }
1330
1331                val setCode = optGeneral context (pushContainer(code, setContainer))
1332                val makeContainer =
1333                    Container{addr=containerAddr, use=[], size=width, setter=setCode}
1334                (* The context requires a tuple of the original width.  We need
1335                   to add dummy fields where necessary. *)
1336                val container =
1337                    if width = maxField+1
1338                    then mkTupleFromContainer(containerAddr, width)
1339                    else
1340                    let
1341                        fun mkField(n, m, hd::tl) =
1342                            if n = hd 
1343                            then mkInd(m, loadContainer) :: mkField(n+1, m+1, tl)
1344                            else CodeZero :: mkField(n+1, m, hd::tl)
1345                        |   mkField _ = []
1346                    in
1347                        Tuple{fields = mkField(0, 0, useList), isVariant=false}
1348                    end
1349            in
1350                mkEnv([makeContainer], container)
1351            end
1352    end
1353
1354    (* TODO: convert "(if a then b else c) (args)" into if a then b(args) else c(args).  This would
1355       allow for possible inlining and also passing information about call patterns. *)
1356
1357    (* Once all the inlining is done we look for functions that can be compiled immediately.
1358       These are either functions with no free variables or functions where every use is a
1359       call, as opposed to being passed or returned as a closure.  Functions that have free
1360       variables but are called can be lambda-lifted where the free variables are turned into
1361       extra parameters.  The advantage compared with using a static-link or a closure on
1362       the stack is that they can be fully tail-recursive.  With a static-link or stack
1363       closure the free variables have to remain on the stack until the function returns. *)
1364    fun lambdaLiftAndConstantFunction(code, debugSwitches, numLocals) =
1365    let
1366        val needReprocess = ref false
1367        (* At the moment this just code-generates immediately any lambdas without
1368           free-variables.  The idea is to that we will get a constant which can
1369           then be inserted directly in references to the function.  In general
1370           this takes a list of mutually recursive functions which can be code-
1371           generated immediately if all the free variables are other functions
1372           in the list.  The simplifier has separated mutually recursive
1373           bindings into strongly connected components so we can consider
1374           the list as a single entity. *)
1375        fun processLambdas lambdaList =
1376        let
1377            (* First process the bodies of the functions. *)
1378            val needed = ! needReprocess
1379            val _ = needReprocess := false;
1380            val transLambdas =
1381                map (fn {lambda={body, isInline, name, closure, argTypes, resultType, localCount, recUse}, use, addr} =>
1382                        {lambda={body=mapChecks body, isInline=isInline, name=name, closure=closure,
1383                                  argTypes=argTypes, resultType=resultType, localCount=localCount, recUse=recUse},
1384                         use=use, addr=addr}) lambdaList
1385            val theseTransformed = ! needReprocess
1386            val _ = if needed then needReprocess := true else ()
1387
1388            fun hasFreeVariables{lambda={closure, ...}, ...} =
1389            let
1390                fun notInLambdas(LoadLocal lAddr) =
1391                    (* A local is allowed if it only refers to another lambda. *)
1392                        not (List.exists (fn {addr, ...} => addr = lAddr) lambdaList)
1393                |   notInLambdas _ = true (* Anything else is not allowed. *)
1394            in
1395                List.exists notInLambdas closure
1396            end
1397        in
1398            if theseTransformed orelse List.exists (fn {lambda={isInline=Inline, ...}, ...} => true | _ => false) lambdaList
1399               orelse List.exists hasFreeVariables lambdaList
1400            (* If we have transformed any of the bodies we need to reprocess so defer any
1401               code-generation.  Don't CG it if it is inline, or perhaps if it is inline and exported. 
1402               Don't CG it if it has free variables.  We still need to examine
1403               the bodies of the functions. *)
1404            then (transLambdas, [])
1405            else
1406            let
1407                (* Construct code to declare the functions and extract the values. *)
1408                val tupleFields = map (fn {addr, ...} => Extract(LoadLocal addr)) transLambdas
1409                val decsAndTuple = Newenv([RecDecs transLambdas], mkTuple tupleFields)
1410                val maxLocals = List.foldl(fn ({addr, ...}, n) => Int.max(addr, n)) 0 transLambdas
1411                val (code, props) = BACKEND.codeGenerate(decsAndTuple, maxLocals + 1, debugSwitches)
1412                val resultConstnt = Constnt(code(), props)
1413                fun getResults([], _) = []
1414                |   getResults({addr, use, ...} :: tail, n) =
1415                        Declar {value=mkInd(n, resultConstnt), addr=addr, use=use} :: getResults(tail, n+1)
1416                val () = needReprocess := true
1417            in
1418                ([], getResults(transLambdas, 0))
1419            end
1420        end
1421
1422        and runChecks (Lambda (lambda as { isInline=NonInline, closure=[], ... })) =
1423            (
1424                (* Bare lambda. *)
1425                case processLambdas[{lambda=lambda, use = [], addr = 0}] of
1426                    ([{lambda=unCGed, ...}], []) => SOME(Lambda unCGed)
1427                |   ([], [Declar{value, ...}]) => SOME value
1428                |   _ => raise InternalError "processLambdas"
1429            )
1430        
1431        |   runChecks (Newenv(bindings, exp)) =
1432            let 
1433                (* We have a block of bindings.  Are any of them functions that are only ever called? *)
1434                fun checkBindings(Declar{value=Lambda lambda, addr, use}, tail) =
1435                    (
1436                        (* Process this lambda and extract the result. *)
1437                        case processLambdas[{lambda=lambda, use = use, addr = addr}] of
1438                            ([{lambda=unCGed, use, addr}], []) =>
1439                                Declar{value=Lambda unCGed, use=use, addr=addr} :: tail
1440                        |   ([], cgedDec) => cgedDec @ tail
1441                        |   _ => raise InternalError "checkBindings"
1442                    )
1443
1444                |   checkBindings(Declar{value, addr, use}, tail) =
1445                        Declar{value=mapChecks value, addr=addr, use=use} :: tail
1446
1447                |   checkBindings(RecDecs l, tail) =
1448                    let
1449                        val (notConsts, asConsts) = processLambdas l
1450                    in
1451                        asConsts @
1452                            (if null notConsts then [] else [RecDecs notConsts]) @
1453                                tail
1454                    end
1455
1456                |   checkBindings(NullBinding exp, tail) = NullBinding(mapChecks exp) :: tail
1457
1458                |   checkBindings(Container{addr, use, size, setter}, tail) =
1459                        Container{addr=addr, use=use, size=size, setter=mapChecks setter} :: tail
1460
1461            in
1462                SOME(Newenv((List.foldr checkBindings [] bindings), mapChecks exp))
1463            end
1464
1465        |   runChecks _ = NONE
1466
1467        and mapChecks c = mapCodetree runChecks c
1468
1469    in
1470        (mapCodetree runChecks code, numLocals, !needReprocess)
1471    end
1472
1473    (* Main optimiser and simplifier loop. *)
1474    fun codetreeOptimiser(code, debugSwitches, numLocals) =
1475    let
1476        fun topLevel _ = raise InternalError "top level reached in optimiser"
1477
1478        fun processTree (code, nLocals, optAgain) =
1479        let
1480            (* First run the simplifier.  Among other things this does inline
1481               expansion and if it does any we at least need to run cleanProc
1482               on the code so it will have set simpAgain. *)
1483            val (simpCode, simpCount, simpAgain) = SIMPLIFIER.simplifier(code, nLocals)
1484        in
1485            if optAgain orelse simpAgain
1486            then
1487            let
1488                (* Identify usage information and remove redundant code. *)
1489                val printCodeTree      = DEBUG.getParameter DEBUG.codetreeTag debugSwitches
1490                and compilerOut        = PRETTY.getCompilerOutput debugSwitches
1491                val simpCode = SIMPLIFIER.specialToGeneral simpCode
1492                val () = if printCodeTree then compilerOut(PRETTY.PrettyString "Output of simplifier") else ()
1493                val () = if printCodeTree then compilerOut (BASECODETREE.pretty simpCode) else ()
1494                val preOptCode =
1495                    REMOVE_REDUNDANT.cleanProc(simpCode, [UseExport], topLevel, simpCount)
1496                (* Print the code with the use information before it goes into the optimiser. *)
1497                val () = if printCodeTree then compilerOut(PRETTY.PrettyString "Output of cleaner") else ()
1498                val () = if printCodeTree then compilerOut (BASECODETREE.pretty preOptCode) else ()
1499
1500                val reprocess = ref false (* May be set in the optimiser *)
1501                (* Allocate any new addresses after the existing ones. *)
1502                val addressAllocator = ref simpCount
1503                fun makeAddr() =
1504                    (! addressAllocator) before addressAllocator := ! addressAllocator + 1
1505                val optContext =
1506                {
1507                    makeAddr = makeAddr,
1508                    reprocess = reprocess,
1509                    debugArgs = debugSwitches
1510                }
1511                (* Optimise the code, rewriting it as necessary. *)
1512                val optCode = mapCodetree (optimise(optContext, [UseExport])) preOptCode
1513                
1514                val (llCode, llCount, llAgain) =
1515                    (* If we have optimised it or the simplifier has run something that it wants to
1516                       run again we must rerun these before we try to generate any code. *)
1517                    if ! reprocess (* Re-optimise *) orelse simpAgain (* The simplifier wants to run again on this. *)
1518                    then (optCode, ! addressAllocator, ! reprocess)
1519                    else (* We didn't detect any inlineable functions.  Check for lambda-lifting. *)
1520                        lambdaLiftAndConstantFunction(optCode, debugSwitches, ! addressAllocator)
1521
1522                (* Print the code after the optimiser. *)
1523                val () = if printCodeTree then compilerOut(PRETTY.PrettyString "Output of optimiser") else ()
1524                val () = if printCodeTree then compilerOut (BASECODETREE.pretty llCode) else ()
1525            in
1526                (* Rerun the simplifier at least. *)
1527                processTree(llCode, llCount, llAgain)
1528            end
1529            else (simpCode, simpCount) (* We're done *)
1530        end
1531
1532        val (postOptCode, postOptCount) = processTree(code, numLocals, true (* Once at least *))
1533        val (rGeneral, rDecs, rSpec) = postOptCode
1534    in
1535        { numLocals = postOptCount, general = rGeneral, bindings = rDecs, special = rSpec }
1536    end
1537
1538    structure Sharing = struct type codetree = codetree and envSpecial = envSpecial and codeBinding = codeBinding end
1539
1540end;
1541