(*
    Copyright (c) 2015 David C.J. Matthews

    This library is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 2.1 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public
    License along with this library; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*)

(*
Lambda-lifting.  If every call point to a function can be identified we can
lift the free variables as extra parameters.  This avoids the need for a
closure on the heap.  It makes stack-closures largely redundant.  The
advantages of lambda-lifting over stack closures are that the containing
function of a stack-closure cannot call a stack-closure with tail-recursion
because the closure must remain on the stack until the function returns.
Also we can lambda-lift a function even if it is used in a function that
requires a full closure whereas we cannot use a stack closure for a
function if the closure would be used in a full, heap closure.

This pass is called after optimisation and after any functions that have
empty closures have been code-generated to constants.
*)
functor CODETREE_LAMBDA_LIFT (

    structure BASECODETREE: BaseCodeTreeSig
    structure CODETREE_FUNCTIONS: CodetreeFunctionsSig
    structure BACKEND: CodegenTreeSig
    structure DEBUG: DEBUGSIG
    structure PRETTY : PRETTYSIG

    sharing
        BASECODETREE.Sharing
    =   CODETREE_FUNCTIONS.Sharing
    =   BACKEND.Sharing
    =   PRETTY.Sharing
): CodegenTreeSig =
struct
    open BASECODETREE
    open CODETREE_FUNCTIONS
    exception InternalError = Misc.InternalError

    (* First pass: identify the functions whose only uses are calls.  This annotates
       the tree by setting the "use" of any bindings, or the "recUse" of any lambdas,
       that require a closure to [UseGeneral].  Bindings and lambdas that are only
       ever called are left with an empty list.

       code         - the tree to annotate.
       closureRef   - called with a closure index when the body makes a non-call
                      reference to that entry of the enclosing function's closure.
       recursiveRef - called when the body makes a non-call reference to the
                      enclosing function itself (LoadRecursive).
       localCount   - the number of local addresses used in "code"; it sizes the
                      table of flags.

       The result is the annotated copy of "code". *)
    fun checkBody(code: codetree, closureRef: int -> unit, recursiveRef: unit -> unit, localCount) =
    let
        (* An entry for each local binding.  Set to true if we find a non-call reference. *)
        val localsNeedClosures = BoolArray.array(localCount, false)

        (* Record a non-call reference.  Locals are flagged in the table; recursive
           and closure references are passed up to the enclosing function.
           Arguments never need marking. *)
        fun markExtract(LoadLocal n) = BoolArray.update(localsNeedClosures, n, true)
        |   markExtract LoadRecursive = recursiveRef()
        |   markExtract(LoadClosure n) = closureRef n
        |   markExtract(LoadArgument _) = ()

        (* The per-node function passed to mapCodetree.  Cases that are not handled
           here return NONE (presumably this leaves the node to mapCodetree's default
           traversal - confirm against CODETREE_FUNCTIONS). *)
        fun checkCode(ext as Extract load) = (markExtract load; SOME ext)
            (* These are loads which aren't calls.  If they are functions they need closures. *)

        |   checkCode(Eval{function as Extract _, argList, resultType}) =
            (* A call of a function.  We don't need to mark the function as needing a
               closure but the arguments still have to be checked. *)
                SOME(
                    Eval{function=function, argList=map(fn (c, t) => (checkMapCode c, t)) argList,
                         resultType=resultType})

        |   checkCode(Lambda lambda) = SOME(Lambda(checkLambda lambda))

        |   checkCode(Newenv(decs, exp)) =
            (* We want to add [UseGeneral] to bindings that require closures.  To do that
               we have to process the bindings in reverse order so that every use of a
               binding has been seen before its flag is read. *)
            let
                val processedExp = checkMapCode exp (* The expression first. *)

                (* Read the flag for a binding.  Only valid once everything that can
                   refer to the binding has been processed. *)
                fun getFlag addr =
                    if BoolArray.sub(localsNeedClosures, addr) then [UseGeneral] else []

                fun processDecs [] = []

                |   processDecs ((Declar { value, addr, ...}) :: tail) =
                    let
                        val pTail = processDecs tail (* Tail first *)
                        val pValue = checkMapCode value
                    in
                        Declar{value = pValue, addr=addr, use=getFlag addr} :: pTail
                    end

                |   processDecs (RecDecs l :: tail) =
                    let
                        val pTail = processDecs tail (* Tail first *)
                        (* Process the lambdas.  Because they're mutually recursive this may set the
                           closure flag for others in the set. *)
                        val pLambdas =
                            map (fn {lambda, addr, ...} =>
                                    {addr=addr, use=[], lambda=checkLambda lambda}) l
                        (* Can now pick up the closure flags. *)
                        val pDecs =
                            map(fn {lambda, addr, ...} =>
                                    {lambda=lambda, addr=addr, use=getFlag addr}) pLambdas
                    in
                        RecDecs pDecs :: pTail
                    end

                |   processDecs (NullBinding c :: tail) =
                    let
                        val pTail = processDecs tail
                    in
                        NullBinding(checkMapCode c) :: pTail
                    end

                |   processDecs (Container{ addr, size, setter,... } :: tail) =
                    let
                        val pTail = processDecs tail
                    in
                        Container{addr=addr, use=[], size=size, setter=checkMapCode setter} :: pTail
                    end
            in
                SOME(Newenv(processDecs decs, processedExp))
            end

        |   checkCode _ = NONE

        and checkLambda({body, closure, localCount, name, argTypes, resultType, ...}) =
            (* Lambdas - check the function body and any recursive uses.  The body is
               checked with a fresh table of locals.  A non-call reference to a closure
               entry in the body is translated into a non-call reference to the
               corresponding item in this, the enclosing, scope. *)
            let
                val recNeedsClosure = ref false
                fun refToRecursive() = recNeedsClosure := true
                fun refToClosure n = markExtract(List.nth(closure, n))
                val processedBody = checkBody(body, refToClosure, refToRecursive, localCount)
            in
                {body=processedBody, isInline=NonInline, closure=closure, localCount=localCount, name=name,
                 argTypes=argTypes, resultType=resultType, recUse=if !recNeedsClosure then [UseGeneral] else []}
            end

        and checkMapCode code = mapCodetree checkCode code
    in
        checkMapCode code
    end

    (* Second pass: Actually do the lambda-lifting. *)
    datatype lift =
        LiftLoad of loadForm (* Usually unlifted but also for recursive calls. *)
    |   LiftConst of codetree (* A lifted function. *)

    (* Transform a function body that has already been annotated by checkBody.

       code         - the annotated body.
       getClosure   - given an index into the enclosing function's (old) closure
                      returns how to refer to that item now, together with any
                      extra arguments that must be passed if it is called.
       getRecursive - returns the extra arguments that must be added to a
                      recursive call of the enclosing function.
       localCount   - the number of local addresses used in "code".
       debugArgs    - passed unchanged to BACKEND.codeGenerate.

       The result is the transformed body.  Functions that can be lifted are
       code-generated as a side-effect and appear in the result as constants. *)
    fun processBody(code: codetree, getClosure: int -> lift * loadForm list,
                    getRecursive: unit -> loadForm list, localCount, debugArgs): codetree =
    let

        (* Indexed by local address.  An entry is set when the binding at that address
           has been lifted: it holds the constant for the function and the loads,
           in this scope, of the extra arguments every call must pass. *)
        val processedLambdas:
            (codetree * loadForm list) option array = Array.array(localCount, NONE)

        (* Find how to refer to an item and which extra arguments a call needs. *)
        fun findBinding(ext as LoadLocal n) =
            (
                case Array.sub(processedLambdas, n) of
                    SOME (c, l) => (LiftConst c, l)
                |   NONE => (LiftLoad ext, [])
            )
        |   findBinding(LoadRecursive) = (LiftLoad LoadRecursive, getRecursive())
            (* The code for the recursive case is always LoadRecursive but depending
               on whether it's been lifted or not there may be extra args. *)
        |   findBinding(LoadClosure n) = getClosure n
        |   findBinding(ext as LoadArgument _) = (LiftLoad ext, [])

        (* Capture an item from this scope in a closure that is being built.
           If the function is a local we have to add it to the closure.
           If it is a lifted function the function itself will be a
           constant except in the case of a recursive call.  We do
           have to add the arguments to the closure.
           Returns the item and its extra arguments as seen through "newClosure". *)
        fun captureBinding newClosure item =
        let
            val (localFunction, extraArgs) = findBinding item
            val resFunction =
                case localFunction of
                    LiftLoad ext => LiftLoad(addToClosure newClosure ext)
                |   c as LiftConst _ => c
            val resArgs = map(fn ext => addToClosure newClosure ext) extraArgs
        in
            (resFunction, resArgs)
        end

        (* The per-node function passed to mapCodetree. *)
        fun processCode(Eval{function=Extract ext, argList, resultType}) =
            let
                (* If this has been lifted we have to add the extra arguments.
                   The function may also now be a constant. *)
                val (newFunction, extraArgs) =
                    case findBinding ext of
                        (LiftConst c, l) => (c, l)
                    |   (LiftLoad e, l) => (Extract e, l)

                (* Process the original args.  There may be functions in there. *)
                val processedArgs = map(fn (c, t) => (processMapCode c, t)) argList
            in
                (* The extra arguments go after the original ones. *)
                SOME(Eval{function=newFunction,
                          argList=processedArgs @ map(fn c => (Extract c, GeneralType)) extraArgs,
                          resultType=resultType})
            end

        |   processCode(Eval{function=Lambda(lambda as { recUse=[], ...}), argList, resultType}) =
            (* We have a call to a lambda.  This must be a recursive function otherwise it
               would have been expanded inline.  If the recursive references are just calls
               we can lambda-lift it. *)
            let
                val (fnConstnt, extraArgs) = hd(liftLambdas([(lambda, NONE)]))
                val processedArgs = map(fn (c, t) => (processMapCode c, t)) argList
            in
                SOME(Eval{function=fnConstnt,
                          argList=processedArgs @ map(fn c => (Extract c, GeneralType)) extraArgs,
                          resultType=resultType})
            end

        |   processCode(Extract ext) =
            (
                (* A load of a binding outside a call.  We need to process this to
                   rebuild the closure but if we get a lifted function it's an error. *)
                case findBinding ext of
                    (LiftLoad e, []) => SOME(Extract e)
                |   _ => raise InternalError "Lifted function out of context"
            )

        |   processCode(Lambda lambda) =
            (* Bare lambda or lambda in binding where we need a closure.
               This can't be lambda-lifted but we still need to
               process the body and rebuild the closure. *)
                SOME(Lambda(processLambdaWithClosure lambda))

        |   processCode(Newenv(decs, exp)) =
            let
                (* The bindings are processed in order, before the expression, so that
                   the table entry for a lifted binding is set before its uses are seen. *)
                fun processDecs [] = []

                |   processDecs ((Declar { value = Lambda (lambda as { recUse=[], ...}), addr, use=[]}) :: tail) =
                    let
                        (* We can lambda-lift.  This results in a constant which is added to
                           the table.  We don't need an entry for the binding. *)
                        val constntAndArgs = hd(liftLambdas[(lambda, SOME addr(*or NONE*))])
                    in
                        Array.update(processedLambdas, addr, SOME constntAndArgs);
                        processDecs tail
                    end

                |   processDecs ((Declar { value, addr, ...}) :: tail) =
                        (* All other non-recursive bindings. *)
                        Declar{value = processMapCode value, addr=addr, use=[]} :: processDecs tail

                |   processDecs (RecDecs l :: tail) =
                    let
                        (* We only lambda-lift if all the functions are called.  We could
                           actually lift all those that are called and leave the others
                           but it's probably not worth it. *)
                        fun checkLift({lambda={recUse=[], ...}, use=[], ...}, true) = true
                        |   checkLift _ = false
                    in
                        if List.foldl checkLift true l
                        then
                        let
                            val results =
                                liftLambdas(map(fn{lambda, addr, ...} => (lambda, SOME addr)) l)
                        in
                            (* Add the code of the functions to the array. *)
                            ListPair.appEq(
                                fn (ca, {addr, ...}) => Array.update(processedLambdas, addr, SOME ca))
                                    (results, l);
                            (* And just deal with the rest of the bindings. *)
                            processDecs tail
                        end
                        else
                        let
                            val pLambdas =
                                map (fn {lambda, addr, ...} =>
                                        {addr=addr, use=[], lambda=processLambdaWithClosure lambda}) l
                        in
                            RecDecs pLambdas :: processDecs tail
                        end
                    end

                |   processDecs (NullBinding c :: tail) =
                        NullBinding(processMapCode c) :: processDecs tail

                |   processDecs (Container{ addr, size, setter,... } :: tail) =
                        Container{addr=addr, use=[], size=size, setter=processMapCode setter} :: processDecs tail
            in
                SOME(Newenv(processDecs decs, processMapCode exp))
            end

        |   processCode _ = NONE

        and processLambdaWithClosure({body, closure, localCount, name, argTypes, resultType, ...}) =
            (* Lambdas that are not to be lifted.  They may still have functions inside that can
               be lifted.  They may also refer to functions that have been lifted. *)
            let
                (* We have to rebuild the closure.  If any of the closure entries were lifted
                   functions they are now constants but their arguments have to be added to
                   the closure. *)
                val newClosure = makeClosure()

                (* Called while the body is processed, with an index into the old closure.
                   Entries are added to the new closure only when they are referenced. *)
                fun closureRef n = captureBinding newClosure (List.nth(closure, n))

                (* The function keeps its closure so recursive calls need no extra arguments. *)
                val processedBody = processBody(body, closureRef, fn () => [], localCount, debugArgs)
            in
                {body=processedBody, isInline=NonInline, closure=extractClosure newClosure, localCount=localCount, name=name,
                 argTypes=argTypes, resultType=resultType, recUse=[]}
            end

        and liftLambdas (bindings: (lambdaForm * int option) list) =
            (* Lambda-lift one or more functions.  The general, but least common, case is a
               set of mutually recursive functions.  More usually we have a single binding
               of a function or a single anonymous lambda.
               Lambda-lifting involves replacing the closure with arguments so it can only
               be used when we can identify all the call sites of the function and add
               the extra arguments.  Because the transformed function has an empty closure
               (but see below for the mutually-recursive case) it can be code-generated
               immediately.  The code then becomes a constant.

               There are a few complications.  Although the additional, "closure"
               arguments are taken from the original function closure there may be
               changes if some of the closure entries are actually lambda-lifted
               functions.  In that case the function may become a constant, and
               so not need to be included in the arguments, but the additional
               arguments for that function may need to be added to the closure.
               The other complication is recursion, especially mutual recursion.
               If we have references to mutually recursive functions we actually
               leave those references in the closure.  This means that we actually
               code-generate mutually-recursive functions with non-empty closures
               but that is allowed if the references are only to other functions
               in the set.  The code-generator sorts that out.

               Each binding is a lambda paired with the address it is bound to, or
               NONE for an anonymous lambda.  The result has one entry for each
               binding, in the same order: the constant for the lifted function and
               the loads, in the current scope, of the extra arguments. *)
            let
                (* We need to construct a new common closure.  This will be used by all
                   the functions. *)
                val newClosure = makeClosure()

                (* Add an item of the current scope, and any extra arguments it needs,
                   to the common closure. *)
                fun closureEntry clItem = captureBinding newClosure clItem

                local
                    (* Check for an address which is one of the recursive set. *)
                    val addressesUsed = List.mapPartial #2 bindings
                in
                    fun isRecursive(LoadLocal n) = List.exists(fn p => p=n) addressesUsed
                    |   isRecursive _ = false
                end

                local
                    fun closureItem ext =
                        (* If it's a local we have to check that it's not one of our
                           mutually recursive set.  These items aren't going to be
                           passed as arguments. *)
                        if isRecursive ext then () else (closureEntry ext; ())
                in
                    (* Enter every closure item of every function before any body is
                       processed so that the number of extra arguments is known. *)
                    val () = List.app(fn ({closure, ...}, _) => List.app closureItem closure) bindings
                end

                (* This composite closure is the set of additional arguments we need. *)
                val transClosure = extractClosure newClosure

                local
                    (* One extra general argument for each item in the composite closure. *)
                    val extraArgs = List.map(fn _ => (GeneralType, [])) transClosure
                    val closureSize = List.length transClosure

                    (* Process the function bodies. *)
                    fun transformLambda({body, closure, localCount, name, argTypes, resultType, ...}, addr) =
                    let
                        val argSize = List.length argTypes
                        (* The extra arguments follow the original arguments. *)
                        val recArgs = List.tabulate(closureSize, fn n => LoadArgument(n+argSize))

                        (* References to other functions in the set are added to a
                           residual closure. *)
                        val residual = makeClosure()

                        fun closureRef clItem =
                            (* We have a reference to the (old) closure item.  We need to change that
                               to return the appropriate argument.  The exception is that if we
                               have a (recursive) reference to another function in the set we
                               instead use an entry from the residual closure. *)
                            let
                                val oldClosureItem = List.nth(closure, clItem)
                            in
                                if isRecursive oldClosureItem
                                then (LiftLoad(addToClosure residual oldClosureItem), recArgs)
                                else
                                let
                                    val (localFunction, resArgs) = closureEntry oldClosureItem

                                    (* An index in the composite closure becomes the
                                       extra argument at the same position. *)
                                    fun mapToArg(LoadClosure n) = LoadArgument(n+argSize)
                                    |   mapToArg _ = raise InternalError "mapToArg" (* Not a closure item. *)

                                    val resFunction =
                                        case localFunction of
                                            LiftLoad ext => LiftLoad(mapToArg ext)
                                        |   c as LiftConst _ => c
                                in
                                    (resFunction, map mapToArg resArgs)
                                end
                            end

                        (* Recursive case - add the extra args. *)
                        and recursiveRef() = recArgs

                        val processedBody = processBody(body, closureRef, recursiveRef, localCount, debugArgs)

                        val lambda =
                            {body=processedBody, isInline=NonInline, closure=extractClosure residual,
                             localCount=localCount, name=name,
                             argTypes=argTypes @ extraArgs, resultType=resultType, recUse=[]}
                    in
                        (* An anonymous lambda is given address zero. *)
                        { lambda=lambda, addr=getOpt(addr, 0), use=[] }
                    end

                in
                    val bindingsForCode = List.map transformLambda bindings
                end

                local
                    (* We may have a single anonymous lambda.  In that case we can give it
                       address zero. *)
                    val addresses = map (fn (_, addr) => getOpt(addr, 0)) bindings
                    (* Create "closures" for each entry.  These will be set by the
                       code-generator to the code of each function and will become the
                       closures we return.  Put them into the table. *)
                    val maxAddr = List.foldl(fn (addr, n) => Int.max(addr, n)) 0 addresses
                    (* To get the constant addresses we create bindings for the functions and
                       return a tuple with one entry for each binding. *)
                    val extracts = List.map(Extract o LoadLocal) addresses
                    val code = Newenv([RecDecs bindingsForCode], mkTuple extracts)
                    (* Code-generate, "run" the code and extract the results. *)
                    open Address
                    val closure = allocWordData(0w1, Word8.orb(F_mutable, F_words), toMachineWord 0w1)
                    (* Turn this into a lambda to code-generate. *)
                    val lambda:lambdaForm =
                    {
                        body = code,
                        isInline = NonInline,
                        name = "<top level>",
                        closure = [],
                        argTypes = [(GeneralType, [])],
                        resultType = GeneralType,
                        localCount = maxAddr+1,
                        recUse = []
                    }
                    val props = BACKEND.codeGenerate(lambda, debugArgs, closure)
                    (* The closure cell is now treated as a function of no arguments.
                       NOTE(review): this relies on BACKEND.codeGenerate having set the
                       cell to the generated code - confirm against the back-end. *)
                    val code: unit -> machineWord = RunCall.unsafeCast closure
                    val codeConstnt = Constnt(code(), props)

                    (* The n'th function is the n'th field of the tuple.  Every function
                       in the set takes the same extra arguments. *)
                    fun getItem([], _) = []
                    |   getItem(_ :: l, n) = (mkInd(n, codeConstnt), transClosure) :: getItem(l, n+1)
                in
                    (* Put in the results with the closures. *)
                    val results = getItem(bindings, 0)
                end
            in
                results
            end

        and processMapCode code = mapCodetree processCode code
    in
        processMapCode code
    end

    (* The entry point of the pass.  Annotates the function with checkBody, lifts
       what can be lifted with processBody and passes the result to the back-end.
       The arguments and result are those of BACKEND.codeGenerate. *)
    fun codeGenerate(original: lambdaForm, debugArgs, closure) =
    let
        (* The function has no enclosing function so there is nothing that a
           closure or recursive reference could be resolved against. *)
        fun toplevel _ = raise InternalError "Top level reached"
        val checked = checkBody(Lambda original, toplevel, toplevel, 0)
        val processed =
            case processBody(checked, toplevel, toplevel, 0, debugArgs) of
                Lambda p => p
            |   _ => raise InternalError "CODETREE_LAMBDA_LIFT:codeGenerate"
    in
        BACKEND.codeGenerate(processed, debugArgs, closure)
    end

    structure Foreign = BACKEND.Foreign

    structure Sharing = BASECODETREE.Sharing

end;