1(*
2    Copyright (c) 2015 David C.J. Matthews
3
4    This library is free software; you can redistribute it and/or
5    modify it under the terms of the GNU Lesser General Public
6    License as published by the Free Software Foundation; either
7    version 2.1 of the License, or (at your option) any later version.
8    
9    This library is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12    Lesser General Public License for more details.
13    
14    You should have received a copy of the GNU Lesser General Public
15    License along with this library; if not, write to the Free Software
16    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17*)
18
19(*
20Lambda-lifting.  If every call point to a function can be identified we can
21lift the free variables as extra parameters.  This avoids the need for a
22closure on the heap.  It makes stack-closures largely redundant.  The
23advantages of lambda-lifting over stack closures are that the containing
24function of a stack-closure cannot call a stack-closure with tail-recursion
25because the closure must remain on the stack until the function returns.
26Also we can lambda-lift a function even if it is used in a function that
27requires a full closure whereas we cannot use a stack closure for a
28function if the closure would be used in a full, heap closure.
29
30This pass is called after optimisation and after any functions that have
31empty closures have been code-generated to constants.
32*)
33functor CODETREE_LAMBDA_LIFT (
34
35    structure BASECODETREE: BaseCodeTreeSig
36    structure CODETREE_FUNCTIONS: CodetreeFunctionsSig
37    structure BACKEND: CodegenTreeSig
38    structure DEBUG: DEBUGSIG
39    structure PRETTY : PRETTYSIG
40
41    sharing
42        BASECODETREE.Sharing
43    =   CODETREE_FUNCTIONS.Sharing
44    =   BACKEND.Sharing
45    =   PRETTY.Sharing
46): CodegenTreeSig =
47struct
48    open BASECODETREE
49    open CODETREE_FUNCTIONS
50    exception InternalError = Misc.InternalError
51    
    (* First pass: identify the functions whose only uses are calls.  This annotates the tree
       by setting the "use" of any bindings, or the "recUse" of any recursive uses, that
       require a closure to [UseGeneral]. *)
55    fun checkBody(code: codetree, closureRef: int -> unit, recursiveRef: unit -> unit, localCount) =
56    let
57        (* An entry for each local binding.  Set to true if we find a non-call reference. *)
58        val localsNeedClosures = BoolArray.array(localCount, false)
59
60        fun markExtract(LoadLocal n) = BoolArray.update(localsNeedClosures, n, true)
61        |   markExtract LoadRecursive = recursiveRef()
62        |   markExtract(LoadClosure n) = closureRef n
63        |   markExtract(LoadArgument _) = ()
64
65        fun checkCode(ext as Extract load) = (markExtract load; SOME ext)
66            (* These are loads which aren't calls.  If they are functions they need closures. *)
67
68        |   checkCode(Eval{function as Extract _, argList, resultType}) =
69            (* A call of a function.  We don't need to mark the function as needing a closure. *)
70                SOME(
71                    Eval{function=function, argList=map(fn (c, t) => (checkMapCode c, t)) argList,
72                         resultType=resultType})
73
74        |   checkCode(Lambda lambda) = SOME(Lambda(checkLambda lambda))
75
76        |   checkCode(Newenv(decs, exp)) =
77            (* We want to add [UseGeneral] to bindings that require closures.  To do that
78               we have to process the bindings in reverse order. *)
79            let
80                val processedExp = checkMapCode exp (* The expression first. *)
81                
82                fun getFlag addr =
83                    if BoolArray.sub(localsNeedClosures, addr) then [UseGeneral] else []
84
85                fun processDecs [] = []
86
87                |   processDecs ((Declar { value, addr, ...}) :: tail) =
88                    let
89                        val pTail = processDecs tail (* Tail first *)
90                        val pValue = checkMapCode value
91                    in
92                        Declar{value = pValue, addr=addr, use=getFlag addr} :: pTail
93                    end
94
95                |   processDecs (RecDecs l :: tail) =
96                    let
97                        val pTail = processDecs tail (* Tail first *)
98                        (* Process the lambdas.  Because they're mutually recursive this may set the
99                           closure flag for others in the set. *)
100                        val pLambdas =
101                            map (fn {lambda, addr, ...} =>
102                                    {addr=addr, use=[], lambda=checkLambda lambda}) l
103                        (* Can now pick up the closure flags. *)
104                        val pDecs =
105                            map(fn {lambda, addr, ...} =>
106                                    {lambda=lambda, addr=addr, use=getFlag addr}) pLambdas
107                    in
108                        RecDecs pDecs :: pTail
109                    end
110
111                |   processDecs (NullBinding c :: tail) =
112                    let
113                        val pTail = processDecs tail
114                    in
115                        NullBinding(checkMapCode c) :: pTail
116                    end
117
118                |   processDecs (Container{ addr, size, setter,... } :: tail) =
119                    let
120                        val pTail = processDecs tail
121                    in
122                        Container{addr=addr, use=[], size=size, setter=checkMapCode setter} :: pTail
123                    end
124            in
125                SOME(Newenv(processDecs decs, processedExp))
126            end
127
128        |   checkCode _ = NONE
129
130        and checkLambda({body, closure, localCount, name, argTypes, resultType, ...}) =
131        (* Lambdas - check the function body and any recursive uses. *)
132        let
133            val recNeedsClosure = ref false
134            fun refToRecursive() = recNeedsClosure := true
135            fun refToClosure n = markExtract(List.nth(closure, n))
136            val processedBody = checkBody(body, refToClosure, refToRecursive, localCount)
137        in
138            {body=processedBody, isInline=NonInline, closure=closure, localCount=localCount, name=name,
139             argTypes=argTypes, resultType=resultType, recUse=if !recNeedsClosure then [UseGeneral] else []}
140        end
141
142        and checkMapCode code = mapCodetree checkCode code
143    in
144        checkMapCode code
145    end
146
147    (* Second pass: Actually do the lambda-lifting. *)
148    datatype lift =
149        LiftLoad of loadForm (* Usually unlifted but also for recursive calls. *)
150    |   LiftConst of codetree (* A lifted function. *)
151
152    fun processBody(code: codetree, getClosure: int -> lift * loadForm list,
153                    getRecursive: unit -> loadForm list, localCount, debugArgs): codetree =
154    let
155 
156        val processedLambdas:
157            (codetree * loadForm list) option array = Array.array(localCount, NONE)
158
159        fun findBinding(ext as LoadLocal n) =
160            (
161                case Array.sub(processedLambdas, n) of 
162                    SOME (c, l) => (LiftConst c, l)
163                |   NONE => (LiftLoad ext, [])
164            )
165        |   findBinding(LoadRecursive) = (LiftLoad LoadRecursive, getRecursive())
166            (* The code for the recursive case is always LoadRecursive but depending
167               on whether it's been lifted or not there may be extra args. *)
168        |   findBinding(LoadClosure n) = getClosure n
169        |   findBinding(ext as LoadArgument _) = (LiftLoad ext, [])
170
171        fun processCode(Eval{function=Extract ext, argList, resultType}) =
172            let
173                (* If this has been lifted we have to add the extra arguments.
174                   The function may also now be a constant. *)
175                val (newFunction, extraArgs) =
176                    case findBinding ext of
177                        (LiftConst c, l) => (c, l)
178                    |   (LiftLoad e, l) => (Extract e, l)
179
180                (* Process the original args.  There may be functions in there. *)
181                val processedArgs = map(fn (c, t) => (processMapCode c, t)) argList
182            in
183                SOME(Eval{function=newFunction,
184                     argList=processedArgs @ map(fn c => (Extract c, GeneralType)) extraArgs,
185                     resultType=resultType})
186            end
187
188        |   processCode(Eval{function=Lambda(lambda as { recUse=[], ...}), argList, resultType}) =
189            (* We have a call to a lambda.  This must be a recursive function otherwise it
190               would have been expanded inline.  If the recursive references are just calls
191               we can lambda-lift it. *)
192            let
193                val (fnConstnt, extraArgs) = hd(liftLambdas([(lambda, NONE)]))
194                val processedArgs = map(fn (c, t) => (processMapCode c, t)) argList
195            in
196                SOME(Eval{function=fnConstnt,
197                     argList=processedArgs @ map(fn c => (Extract c, GeneralType)) extraArgs,
198                     resultType=resultType})
199            end
200
201        |   processCode(Extract ext) =
202            (
203                (* A load of a binding outside a call.  We need to process this to
204                   rebuild the closure but if we get a lifted function it's an error. *)
205                case findBinding ext of
206                    (LiftLoad e, []) => SOME(Extract e)
207                |   _ => raise InternalError "Lifted function out of context"
208            )
209
210        |   processCode(Lambda lambda) =
211                (* Bare lambda or lambda in binding where we need a closure.
212                   This can't be lambda-lifted but we still need to
213                   process the body and rebuild the closure. *)
214                SOME(Lambda(processLambdaWithClosure lambda))
215
216        |   processCode(Newenv(decs, exp)) =
217            let
218                fun processDecs [] = []
219
220                |   processDecs ((Declar { value = Lambda (lambda as { recUse=[], ...}), addr, use=[]}) :: tail) =
221                    let
222                        (* We can lambda-lift.  This results in a constant which is added to
223                           the table.  We don't need an entry for the binding. *)
224                        val constntAndArgs = hd(liftLambdas[(lambda, SOME addr(*or NONE*))])
225                    in
226                        Array.update(processedLambdas, addr, SOME constntAndArgs);
227                        processDecs tail
228                    end
229
230                |   processDecs ((Declar { value, addr, ...}) :: tail) =
231                        (* All other non-recursive bindings. *)
232                        Declar{value = processMapCode value, addr=addr, use=[]} :: processDecs tail
233
234                |   processDecs (RecDecs l :: tail) =
235                    let
236                        (* We only lambda-lift if all the functions are called.  We could
237                           actually lift all those that are called and leave the others
238                           but it's probably not worth it. *)
239                        fun checkLift({lambda={recUse=[], ...}, use=[], ...}, true) = true
240                        |   checkLift _ = false
241                    in
242                        if List.foldl checkLift true l
243                        then
244                        let
245                            val results =
246                                liftLambdas(map(fn{lambda, addr, ...} => (lambda, SOME addr)) l)
247                        in
248                            (* Add the code of the functions to the array. *)
249                            ListPair.appEq(
250                                fn (ca, {addr, ...}) => Array.update(processedLambdas, addr, SOME ca))
251                                (results, l);
252                            (* And just deal with the rest of the bindings. *)
253                            processDecs tail
254                        end
255                        else
256                        let
257                            val pLambdas =
258                                map (fn {lambda, addr, ...} =>
259                                        {addr=addr, use=[], lambda=processLambdaWithClosure lambda}) l
260                        in
261                            RecDecs pLambdas :: processDecs tail
262                        end
263                    end
264
265                |   processDecs (NullBinding c :: tail) =
266                        NullBinding(processMapCode c) :: processDecs tail
267
268                |   processDecs (Container{ addr, size, setter,... } :: tail) =
269                        Container{addr=addr, use=[], size=size, setter=processMapCode setter} :: processDecs tail
270            in
271                SOME(Newenv(processDecs decs, processMapCode exp))
272            end
273
274        |   processCode _ = NONE
275
276        and processLambdaWithClosure({body, closure, localCount, name, argTypes, resultType, ...}) =
277        (* Lambdas that are not to be lifted.  They may still have functions inside that can
278           be lifted.  They may also refer to functions that have been lifted. *)
279        let
280            (* We have to rebuild the closure.  If any of the closure entries were lifted
281               functions they are now constants but their arguments have to be added to
282               the closure. *)
283            val newClosure = makeClosure()
284
285            fun closureRef n =
286            let
287                val (localFunction, extraArgs) = findBinding(List.nth(closure, n))
288                (* If the function is a local we have to add it to the closure.
289                   If it is a lifted function the function itself will be a
290                   constant except in the case of a recursive call.  We do
291                   have to add the arguments to the closure. *)
292                val resFunction =
293                    case localFunction of
294                        LiftLoad ext => LiftLoad(addToClosure newClosure ext)
295                    |   c as LiftConst _ => c
296                val resArgs = map(fn ext => addToClosure newClosure ext) extraArgs
297            in
298                (resFunction, resArgs)
299            end
300
301            val processedBody = processBody(body, closureRef, fn () => [], localCount, debugArgs)
302        in
303            {body=processedBody, isInline=NonInline, closure=extractClosure newClosure, localCount=localCount, name=name,
304             argTypes=argTypes, resultType=resultType, recUse=[]}
305        end
306
307        and liftLambdas (bindings: (lambdaForm * int option) list) =
308        (* Lambda-lift one or more functions.  The general, but least common, case is a
309           set of mutually recursive functions.  More usually we have a single binding
310           of a function or a single anonymous lambda.
311           Lambda-lifting involves replacing the closure with arguments so it can only
312           be used when we can identify all the call sites of the function and add
313           the extra arguments. Because the transformed function has an empty closure
314           (but see below for the mutually-recursive case) it can be code-generated
315           immediately.  The code then becomes a constant.
316
317           There are a few complications.  Although the additional, "closure"
318           arguments are taken from the original function closure there may be
319           changes if some of the closure entries are actually lambda-lifted
320           functions.  In that case the function may become a constant, and
321           so not need to be included in the arguments, but the additional
322           arguments for that function may need to be added to the closure.
323           The other complication is recursion, especially mutual recursion.
324           If we have references to mutually recursive functions we actually
325           leave those references in the closure.  This means that we actually
326           code-generate mutually-recursive functions with non-empty closures
327           but that is allowed if the references are only to other functions
328           in the set.  The code-generator sorts that out. *)
329        let
330            (* We need to construct a new common closure.  This will be used by all
331               the functions. *)
332            val newClosure = makeClosure()
333
334            fun closureEntry clItem =
335            let
336                val (localFunction, extraArgs) = findBinding clItem
337                (* If the function is a local we have to add it to the closure.
338                   If it is a lifted function the function itself will be a
339                   constant except in the case of a recursive call.  We do
340                   have to add the arguments to the closure. *)
341                val resFunction =
342                    case localFunction of
343                        LiftLoad ext => LiftLoad(addToClosure newClosure ext)
344                    |   c as LiftConst _ => c
345                val resArgs = map(fn ext => addToClosure newClosure ext) extraArgs
346            in
347                (resFunction, resArgs)
348            end
349
350            local
351                (* Check for an address which is one of the recursive set. *)
352                val addressesUsed = List.mapPartial #2 bindings
353            in
354                fun isRecursive(LoadLocal n) = List.exists(fn p => p=n) addressesUsed
355                |   isRecursive _ = false
356            end
357
358            local
359                fun closureItem ext =
360                    (* If it's a local we have to check that it's not one of our
361                       mutually recursive set. These items aren't going to be
362                       passed as arguments. *)
363                    if isRecursive ext then () else (closureEntry ext; ())
364            in
365                val () = List.app(fn ({closure, ...}, _) => List.app closureItem closure) bindings
366            end
367
368            (* This composite closure is the set of additional arguments we need. *)
369            val transClosure = extractClosure newClosure
370
371            local
372                val extraArgs = List.map(fn _ => (GeneralType, [])) transClosure
373                val closureSize = List.length transClosure
374
375                (* Process the function bodies. *)
376                fun transformLambda({body, closure, localCount, name, argTypes, resultType, ...}, addr) =
377                let
378                    val argSize = List.length argTypes
379                    val recArgs = List.tabulate(closureSize, fn n => LoadArgument(n+argSize))
380
381                    (* References to other functions in the set are added to a
382                       residual closure. *)
383                    val residual = makeClosure()
384
385                    fun closureRef clItem =
386                    (* We have a reference to the (old) closure item.  We need to change that
387                       to return the appropriate argument.  The exception is that if we
388                       have a (recursive) reference to another function in the set we
389                       instead use an entry from the residual closure. *)
390                    let
391                        val oldClosureItem = List.nth(closure, clItem)
392                    in
393                        if isRecursive oldClosureItem
394                        then (LiftLoad(addToClosure residual oldClosureItem), recArgs)
395                        else
396                        let
397                            val (localFunction, resArgs) = closureEntry oldClosureItem
398
399                            fun mapToArg(LoadClosure n) = LoadArgument(n+argSize)
400                            |   mapToArg _ = raise InternalError "mapToArg" (* Not a closure item. *)
401
402                            val resFunction =
403                                case localFunction of
404                                    LiftLoad ext => LiftLoad(mapToArg ext)
405                                |   c as LiftConst _ => c
406                        in
407                            (resFunction, map mapToArg resArgs)
408                        end
409                    end
410
411                    (* Recursive case - add the extra args. *)
412                    and recursiveRef() = recArgs
413
414                    val processedBody = processBody(body, closureRef, recursiveRef, localCount, debugArgs)
415
416                    val lambda = 
417                        {body=processedBody, isInline=NonInline, closure=extractClosure residual,
418                         localCount=localCount, name=name,
419                         argTypes=argTypes @ extraArgs, resultType=resultType, recUse=[]}
420                in
421                    { lambda=lambda, addr=getOpt(addr, 0), use=[] }
422                end
423            
424            in
425                val bindingsForCode = List.map transformLambda bindings
426            end
427
428            local
429                (* We may have a single anonymous lambda.  In that case we can give it
430                   address zero. *)
431                val addresses = map (fn (_, addr) => getOpt(addr, 0)) bindings
432                (* Create "closures" for each entry.  These will be set by the
433                   code-generator to the code of each function and will become the
434                   closures we return.  Put them into the table. *)
435                val maxAddr = List.foldl(fn (addr, n) => Int.max(addr, n)) 0 addresses
436                (* To get the constant addresses we create bindings for the functions and
437                   return a tuple with one entry for each binding. *)
438                val extracts = List.map(Extract o LoadLocal) addresses
439                val code = Newenv([RecDecs bindingsForCode], mkTuple extracts)
440                (* Code-generate, "run" the code and extract the results. *)
441                open Address
442                val closure = allocWordData(0w1, Word8.orb(F_mutable, F_words), toMachineWord 0w1)
443                (* Turn this into a lambda to code-generate. *)
444                val lambda:lambdaForm =
445                {
446                    body = code,
447                    isInline = NonInline,
448                    name = "<top level>",
449                    closure = [],
450                    argTypes = [(GeneralType, [])],
451                    resultType = GeneralType,
452                    localCount = maxAddr+1,
453                    recUse = []
454                }
455                val props = BACKEND.codeGenerate(lambda, debugArgs, closure)
456                val code: unit -> machineWord = RunCall.unsafeCast closure
457                val codeConstnt = Constnt(code(), props)
458
459                fun getItem([], _) = []
460                |   getItem(_ :: l, n) = (mkInd(n, codeConstnt), transClosure) :: getItem(l, n+1)
461            in
462                (* Put in the results with the closures. *)
463                val results = getItem(bindings, 0)
464            end
465        in
466            results
467        end
468            
469        and processMapCode code = mapCodetree processCode code
470    in
471        processMapCode code
472    end
473
474    fun codeGenerate(original: lambdaForm, debugArgs, closure) =
475    let
476        fun toplevel _ = raise InternalError "Top level reached"
477        val checked = checkBody(Lambda original, toplevel, toplevel, 0)
478        val processed =
479            case processBody(checked, toplevel, toplevel, 0, debugArgs) of
480                Lambda p => p
481            |   _ => raise InternalError "CODETREE_LAMBDA_LIFT:codeGenerate"
482    in
483        BACKEND.codeGenerate(processed, debugArgs, closure)
484    end
485
486    structure Foreign = BACKEND.Foreign
487
488    structure Sharing = BASECODETREE.Sharing
489
490end;
491