1{-
    SockeyeTypeChecker.hs: AST type checker for Sockeye
3
4    Part of Sockeye
5
6    Copyright (c) 2017, ETH Zurich.
7
8    All rights reserved.
9
10    This file is distributed under the terms in the attached LICENSE file.
11    If you do not find this file, copies can be found by writing to:
12    ETH Zurich D-INFK, CAB F.78, Universitaetstrasse 6, CH-8092 Zurich,
13    Attn: Systems Group.
14-}
15
16{-# LANGUAGE MultiParamTypeClasses #-}
17{-# LANGUAGE FlexibleInstances #-}
18{-# LANGUAGE FlexibleContexts #-}
19
20module SockeyeTypeChecker
21( typeCheckSockeye ) where
22
23import Control.Monad
24
25import Data.Map(Map)
26import qualified Data.Map as Map
27import Data.Set (Set)
28import qualified Data.Set as Set
29import Data.Either
30
31import SockeyeChecks
32
33import qualified SockeyeASTParser as ParseAST
34import qualified SockeyeASTTypeChecker as CheckAST
35
-- | Failures reported by the type checker. The 'Show' instance below renders
-- each one as a human-readable message.
data TypeCheckFail
    -- | A module name is defined more than once (argument: module name).
    = DuplicateModule String
    -- | A module declares two parameters with the same name
    -- (argument: parameter name).
    | DuplicateParameter String
    -- | A @for@ variable is bound twice, either within one construct or by
    -- shadowing a variable of an enclosing construct (argument: variable name).
    | DuplicateVariable String
    -- | An instantiation refers to an undefined module (argument: module name).
    | NoSuchModule String
    -- | A referenced parameter is not declared by the module being checked
    -- (argument: parameter name).
    | NoSuchParameter String
    -- | An identifier uses a loop variable that is not in scope
    -- (argument: variable name).
    | NoSuchVariable String
    -- | Parameter name, expected type, declared (actual) type.
    | ParamTypeMismatch String CheckAST.ModuleParamType CheckAST.ModuleParamType
    -- | Instantiated module name, number of parameters it takes, number of
    -- arguments given.
    | WrongNumberOfArgs String Int Int
    -- | Instantiated module name, argument (parameter) name, expected type,
    -- type of the supplied argument.
    | ArgTypeMismatch String String CheckAST.ModuleParamType CheckAST.ModuleParamType
46
-- | Render a type-check failure as the message shown to the user.
instance Show TypeCheckFail where
    show failure = case failure of
        DuplicateModule name ->
            "Multiple definitions for module '" ++ name ++ "'"
        DuplicateParameter name ->
            "Multiple parameters named '" ++ name ++ "'"
        DuplicateVariable name ->
            "Multiple definitions for variable '" ++ name ++ "'"
        NoSuchModule name ->
            "No definition for module '" ++ name ++ "'"
        NoSuchParameter name ->
            "Parameter '" ++ name ++ "' not in scope"
        NoSuchVariable name ->
            "Variable '" ++ name ++ "' not in scope"
        WrongNumberOfArgs name takes given ->
            "Module '" ++ name ++ "' takes " ++ show takes
                ++ " argument(s), given " ++ show given
        ParamTypeMismatch name expected actual ->
            "Expected type '" ++ show expected ++ "' but '" ++ name
                ++ "' has type '" ++ show actual ++ "'"
        ArgTypeMismatch modName name expected actual ->
            "Type mismatch for argument '" ++ name ++ "' for module '" ++ modName
                ++ "': Expected '" ++ show expected ++ "', given '" ++ show actual ++ "'"
59
-- | Symbol-table entry for one module, built from its parameter list.
data ModuleSymbol = ModuleSymbol
    { paramNames :: [String]
      -- ^ Parameter names in the order of the module's parameter list;
      -- instantiation arguments are matched against them positionally.
    , paramTypes :: Map String CheckAST.ModuleParamType
      -- ^ Declared type of each parameter, keyed by parameter name.
    }
-- | Maps each module name to its 'ModuleSymbol'.
type SymbolTable = Map String ModuleSymbol
65
-- | State threaded through the body checks.
data Context = Context
    { symTable   :: SymbolTable
      -- ^ Symbols of all modules; while the top-level net is checked it
      -- additionally contains the synthetic root module.
    , curModule  :: !String
      -- ^ Name of the module whose body is being checked; used as the first
      -- argument to 'failCheck' for failures found in that body.
    , instModule :: !String
      -- ^ Name of the module being instantiated; only meaningful while the
      -- arguments of an instantiation are checked.
    , vars       :: Set String
      -- ^ Loop variables bound by the enclosing @for@ constructs.
    }
72
-- | Type-check a parsed Sockeye specification.
--
-- The symbol table is built first; if that reports failures, they are
-- returned without checking any module bodies. Otherwise every module and
-- the top-level net are checked and the checked AST is returned.
typeCheckSockeye :: ParseAST.SockeyeSpec -> Either (FailedChecks TypeCheckFail) CheckAST.SockeyeSpec
typeCheckSockeye ast = do
    table <- runChecks (buildSymbolTable ast)
    runChecks (check (initialContext table) ast)
    where
        -- Context at the top level: no enclosing module, no instantiation
        -- in progress and no loop variables bound.
        initialContext table = Context
            { symTable   = table
            , curModule  = ""
            , instModule = ""
            , vars       = Set.empty
            }
83
84--
85-- Build Symbol table
86--
-- | Parse-AST nodes from which (part of) the module symbol table is built.
-- Failures such as duplicate names are reported through 'Checks'.
class SymbolSource a where
    buildSymbolTable :: a -> Checks TypeCheckFail SymbolTable
89
-- | Build the symbol table of a whole specification by merging the
-- single-module tables of all its modules. A module name that occurs more
-- than once is reported as 'DuplicateModule' under the location \@all.
instance SymbolSource ParseAST.SockeyeSpec where
    buildSymbolTable ast = do
        perModule <- traverse buildSymbolTable (ParseAST.modules ast)
        checkDuplicates "@all" DuplicateModule (concatMap Map.keys perModule)
        return (Map.unions perModule)
97
-- | Build the single-entry symbol table for one module, recording its
-- parameter names (in list order) and their declared types. Repeated
-- parameter names are reported as 'DuplicateParameter'.
instance SymbolSource ParseAST.Module where
    buildSymbolTable ast = do
        checkDuplicates modName DuplicateParameter names
        return $ Map.singleton modName (ModuleSymbol
            { paramNames = names
            , paramTypes = Map.fromList typedParams
            })
        where
            modName = ParseAST.name ast
            typedParams =
                [ (ParseAST.paramName p, ParseAST.paramType p)
                | p <- ParseAST.parameters ast
                ]
            names = map fst typedParams
111
112--
113-- Check module bodies
114--
-- | Type-check a parse-AST node of type @a@ in a 'Context', producing the
-- corresponding checked-AST node of type @b@. Failures are reported
-- through 'Checks'.
class Checkable a b where
    check :: Context -> a -> Checks TypeCheckFail b
117
-- | Check a whole specification.
--
-- The top-level net is checked as the body of a synthetic, parameterless
-- module named \@root. That module is added to the symbol table only while
-- its net is checked. The result maps every module name, including \@root,
-- to its checked module, and uses a single instantiation of \@root as the
-- specification's root.
instance Checkable ParseAST.SockeyeSpec CheckAST.SockeyeSpec where
    check context ast = do
        -- The root net is checked before the modules, as in the module order
        -- of the result.
        checkedRootNet <- check rootContext (ParseAST.net ast)
        checkedModules <- check context mods
        let (rootDecls, rootInsts) = partitionEithers checkedRootNet
            rootModule = CheckAST.Module
                { CheckAST.paramNames   = []
                , CheckAST.paramTypeMap = Map.empty
                , CheckAST.ports        = []
                , CheckAST.nodeDecls    = rootDecls
                , CheckAST.moduleInsts  = rootInsts
                }
            namedModules = zip (map ParseAST.name mods) checkedModules
        return CheckAST.SockeyeSpec
            { CheckAST.root    = rootInst
            , CheckAST.modules = Map.fromList ((rootName, rootModule) : namedModules)
            }
        where
            mods = ParseAST.modules ast
            rootName = "@root"
            rootContext = context
                { symTable  = Map.insert rootName emptySymbol (symTable context)
                , curModule = rootName
                }
            emptySymbol = ModuleSymbol
                { paramNames = []
                , paramTypes = Map.empty
                }
            rootInst = CheckAST.ModuleInst
                { CheckAST.namespace  = Nothing
                , CheckAST.moduleName = rootName
                , CheckAST.arguments  = Map.empty
                , CheckAST.inPortMap  = []
                , CheckAST.outPortMap = []
                }
153
-- | Check a module body. Ports and net specs are checked with 'curModule'
-- set to the module's name. Parameter names and types come from the
-- module's symbol-table entry. Net specs are split into node declarations
-- and module instantiations.
instance Checkable ParseAST.Module CheckAST.Module where
    check context ast = do
        checkedPorts    <- check inModule (ParseAST.ports body)
        checkedNetSpecs <- check inModule (ParseAST.moduleNet body)
        let (decls, insts) = partitionEithers checkedNetSpecs
        return CheckAST.Module
            { CheckAST.paramNames   = paramNames symbol
            , CheckAST.paramTypeMap = paramTypes symbol
            , CheckAST.ports        = checkedPorts
            , CheckAST.nodeDecls    = decls
            , CheckAST.moduleInsts  = insts
            }
        where
            modName  = ParseAST.name ast
            body     = ParseAST.moduleBody ast
            -- Every module has an entry: the table is built from the same
            -- module list.
            symbol   = symTable context Map.! modName
            inModule = context { curModule = modName }
176
-- | Check a port declaration. The port identifier is checked and the width
-- is copied unchanged. Replicated ports are checked via the 'ParseAST.For'
-- instance.
instance Checkable ParseAST.Port CheckAST.Port where
    check context port = case port of
        ParseAST.InputPort portId portWidth ->
            flip CheckAST.InputPort portWidth <$> check context portId
        ParseAST.OutputPort portId portWidth ->
            flip CheckAST.OutputPort portWidth <$> check context portId
        ParseAST.MultiPort loop ->
            CheckAST.MultiPort <$> check context loop
187
-- | Check one net spec, tagging node declarations with 'Left' and module
-- instantiations with 'Right' so callers can separate them.
instance Checkable ParseAST.NetSpec (Either CheckAST.NodeDecl CheckAST.ModuleInst) where
    check context (ParseAST.NodeDeclSpec decl)   = Left  <$> check context decl
    check context (ParseAST.ModuleInstSpec inst) = Right <$> check context inst
195
-- | Check a module instantiation. Replicated instantiations are checked via
-- the 'ParseAST.For' instance. Otherwise:
--
-- * the instantiated module must exist ('NoSuchModule');
-- * the arguments are checked against that module's parameters;
-- * the namespace and the port mappings are checked in the current context.
--
-- Port mappings are split into input and output mappings.
instance Checkable ParseAST.ModuleInst CheckAST.ModuleInst where
    check context (ParseAST.MultiModuleInst loop) =
        CheckAST.MultiModuleInst <$> check context loop
    check context ast = do
        checkedArgs <-
            if name `Map.notMember` symTable context
                then failCheck (curModule context) (NoSuchModule name) >> return Map.empty
                else check (context { instModule = name }) (ParseAST.arguments ast)
        checkedNamespace <- check context (ParseAST.namespace ast)
        inPortMap        <- check context [ m | m <- portMaps, isInMap m ]
        outPortMap       <- check context [ m | m <- portMaps, isOutMap m ]
        return CheckAST.ModuleInst
            { CheckAST.namespace  = Just checkedNamespace
            , CheckAST.moduleName = name
            , CheckAST.arguments  = checkedArgs
            , CheckAST.inPortMap  = inPortMap
            , CheckAST.outPortMap = outPortMap
            }
        where
            name     = ParseAST.moduleName ast
            portMaps = ParseAST.portMappings ast
            -- A replicated mapping is classified by the mapping it replicates.
            isInMap portMap = case portMap of
                ParseAST.InputPortMap  {}  -> True
                ParseAST.OutputPortMap {}  -> False
                ParseAST.MultiPortMap loop -> isInMap (ParseAST.body loop)
            isOutMap portMap = case portMap of
                ParseAST.InputPortMap  {}  -> False
                ParseAST.OutputPortMap {}  -> True
                ParseAST.MultiPortMap loop -> isOutMap (ParseAST.body loop)
228
-- | Check the argument list of a module instantiation against the parameter
-- list of the instantiated module (the context's 'instModule'; the caller
-- only sets it to a module present in the symbol table).
--
-- Arguments are paired positionally with the module's parameter names, and
-- the result is keyed by parameter name. Reported failures:
--
-- * 'WrongNumberOfArgs' when the argument and parameter counts differ
--   (surplus arguments or parameters are then left out of the result);
-- * 'NoSuchParameter' when a parameter argument names something that is not
--   a parameter of the enclosing module ('curModule');
-- * 'ArgTypeMismatch' when a parameter argument's declared type differs
--   from the type the instantiated module expects at that position.
--
-- Numerical literal arguments are accepted for any parameter type.
instance Checkable [ParseAST.ModuleArg] (Map String CheckAST.ModuleArg) where
    check context ast = do
        checkArgCount
        checkedArgs <- zipWithM checkArgType (zip names expTypes) ast
        return $ Map.fromList $ zip names checkedArgs
        where
            curName    = curModule context
            instName   = instModule context
            instSymbol = symTable context Map.! instName
            names      = paramNames instSymbol
            expTypes   = map (paramTypes instSymbol Map.!) names
            -- Declared parameter types of the enclosing module; a 'ParamArg'
            -- forwards one of these parameters to the instantiated module.
            curParamTypes = paramTypes $ symTable context Map.! curName
            checkArgCount = do
                let paramc = length names
                    argc   = length ast
                unless (argc == paramc) $
                    failCheck curName $ WrongNumberOfArgs instName paramc argc
            checkArgType (name, expType) arg =
                case arg of
                    ParseAST.NumericalArg value ->
                        return $ CheckAST.NumericalArg value
                    ParseAST.ParamArg paramName -> do
                        -- Report a type mismatch as ArgTypeMismatch rather
                        -- than ParamTypeMismatch, so the message names the
                        -- instantiated module and the parameter being bound.
                        case Map.lookup paramName curParamTypes of
                            Nothing     -> failCheck curName $ NoSuchParameter paramName
                            Just actual -> unless (actual == expType) $ mismatch actual
                        return $ CheckAST.ParamArg paramName
                where
                    mismatch = failCheck curName . ArgTypeMismatch instName name expType
256
-- | Check a port mapping. Input and output mappings are handled alike: both
-- the mapped identifier and the mapped port are checked. Replicated
-- mappings are checked via the 'ParseAST.For' instance.
instance Checkable ParseAST.PortMap CheckAST.PortMap where
    check context portMap = case portMap of
        ParseAST.MultiPortMap loop ->
            CheckAST.MultiPortMap <$> check context loop
        _ -> do
            checkedId   <- check context (ParseAST.mappedId portMap)
            checkedPort <- check context (ParseAST.mappedPort portMap)
            return CheckAST.PortMap
                { CheckAST.mappedId   = checkedId
                , CheckAST.mappedPort = checkedPort
                }
271
-- | Check a node declaration: its identifier, then its node spec.
-- Replicated declarations are checked via the 'ParseAST.For' instance.
instance Checkable ParseAST.NodeDecl CheckAST.NodeDecl where
    check context decl = case decl of
        ParseAST.MultiNodeDecl loop ->
            CheckAST.MultiNodeDecl <$> check context loop
        _ -> do
            checkedId   <- check context (ParseAST.nodeId decl)
            checkedSpec <- check context (ParseAST.nodeSpec decl)
            return CheckAST.NodeDecl
                { CheckAST.nodeId   = checkedId
                , CheckAST.nodeSpec = checkedSpec
                }
286
-- | Check an identifier. Simple identifiers are copied unchanged. A
-- template identifier must use a loop variable that is in scope
-- ('NoSuchVariable' otherwise), and its optional suffix is checked
-- recursively.
instance Checkable ParseAST.Identifier CheckAST.Identifier where
    check context ident = case ident of
        ParseAST.SimpleIdent name ->
            return (CheckAST.SimpleIdent name)
        _ -> do
            let var = ParseAST.varName ident
            checkVarInScope context var
            checkedSuffix <- traverse (check context) (ParseAST.suffix ident)
            return CheckAST.TemplateIdent
                { CheckAST.prefix  = ParseAST.prefix ident
                , CheckAST.varName = var
                , CheckAST.suffix  = checkedSuffix
                }
305
-- | Check a node specification. The accept, translate and reserved specs
-- and the optional overlay are checked in the current context; the node
-- type is copied unchanged.
instance Checkable ParseAST.NodeSpec CheckAST.NodeSpec where
    check context spec = do
        checkedAccept    <- check context (ParseAST.accept spec)
        checkedTranslate <- check context (ParseAST.translate spec)
        checkedReserved  <- check context (ParseAST.reserved spec)
        checkedOverlay   <- traverse (check context) (ParseAST.overlay spec)
        return CheckAST.NodeSpec
            { CheckAST.nodeType  = ParseAST.nodeType spec
            , CheckAST.accept    = checkedAccept
            , CheckAST.translate = checkedTranslate
            , CheckAST.reserved  = checkedReserved
            , CheckAST.overlay   = checkedOverlay
            }
329
-- | Check an address block. The address expressions are checked; the
-- properties and the bit count are copied unchanged.
instance Checkable ParseAST.BlockSpec CheckAST.BlockSpec where
    check context block = case block of
        ParseAST.SingletonBlock address props -> do
            baseAddr <- check context address
            return CheckAST.SingletonBlock
                { CheckAST.base  = baseAddr
                , CheckAST.props = props
                }
        ParseAST.RangeBlock lo hi props -> do
            baseAddr  <- check context lo
            limitAddr <- check context hi
            return CheckAST.RangeBlock
                { CheckAST.base  = baseAddr
                , CheckAST.limit = limitAddr
                , CheckAST.props = props
                }
        ParseAST.LengthBlock lo width props -> do
            baseAddr <- check context lo
            return CheckAST.LengthBlock
                { CheckAST.base  = baseAddr
                , CheckAST.bits  = width
                , CheckAST.props = props
                }
351
-- | Check a mapping: the source block, the destination node and, if
-- present, the destination base address. Destination properties are copied
-- unchanged.
instance Checkable ParseAST.MapSpec CheckAST.MapSpec where
    check context spec = do
        checkedBlock    <- check context (ParseAST.block spec)
        checkedDestNode <- check context (ParseAST.destNode spec)
        checkedDestBase <- traverse (check context) (ParseAST.destBase spec)
        return CheckAST.MapSpec
            { CheckAST.block     = checkedBlock
            , CheckAST.destNode  = checkedDestNode
            , CheckAST.destBase  = checkedDestBase
            , CheckAST.destProps = ParseAST.destProps spec
            }
372
-- | Check an overlay: the overlaid node identifier is checked and the width
-- is copied unchanged.
instance Checkable ParseAST.OverlaySpec CheckAST.OverlaySpec where
    check context (ParseAST.OverlaySpec over width) =
        flip CheckAST.OverlaySpec width <$> check context over
377
-- | Check an address. Literal addresses are copied. A parameter address
-- must name a parameter of the current module declared with
-- 'CheckAST.AddressParam' type.
instance Checkable ParseAST.Address CheckAST.Address where
    check context address = case address of
        ParseAST.LiteralAddress value ->
            return (CheckAST.LiteralAddress value)
        ParseAST.ParamAddress name ->
            CheckAST.ParamAddress name <$ checkParamType context name CheckAST.AddressParam
384
-- | Check a @for@ construct.
--
-- * Its variables must be distinct from each other and from the variables
--   of enclosing constructs ('DuplicateVariable').
-- * Its ranges are checked in the outer context.
-- * Its body is checked with the new variables in scope.
--
-- The checked ranges are keyed by variable name.
instance Checkable a b => Checkable (ParseAST.For a) (CheckAST.For b) where
    check context ast = do
        checkDuplicates (curModule context) DuplicateVariable (loopVars ++ Set.elems outerVars)
        ranges      <- check context loopRanges
        checkedBody <- check innerContext (ParseAST.body ast)
        return CheckAST.For
            { CheckAST.varRanges = Map.fromList (zip loopVars ranges)
            , CheckAST.body      = checkedBody
            }
        where
            loopRanges   = ParseAST.varRanges ast
            loopVars     = map ParseAST.var loopRanges
            outerVars    = vars context
            innerContext = context { vars = Set.union outerVars (Set.fromList loopVars) }
405
-- | Check the start and end limits of a loop-variable range.
instance Checkable ParseAST.ForVarRange CheckAST.ForRange where
    check context range = do
        lo <- check context (ParseAST.start range)
        hi <- check context (ParseAST.end range)
        return CheckAST.ForRange
            { CheckAST.start = lo
            , CheckAST.end   = hi
            }
417
-- | Check a range limit. Literal limits are copied. A parameter limit must
-- name a parameter of the current module declared with
-- 'CheckAST.NaturalParam' type.
instance Checkable ParseAST.ForLimit CheckAST.ForLimit where
    check context limit = case limit of
        ParseAST.LiteralLimit value ->
            return (CheckAST.LiteralLimit value)
        ParseAST.ParamLimit name ->
            CheckAST.ParamLimit name <$ checkParamType context name CheckAST.NaturalParam
424
-- | Lift 'check' over any traversable container (lists, 'Maybe', ...). The
-- elements are checked in traversal order and the container's shape is
-- preserved.
instance (Traversable t, Checkable a b) => Checkable (t a) (t b) where
    check context = traverse (check context)
427
428--
429-- Helpers
430--
-- | Report 'NoSuchVariable' unless @name@ is one of the loop variables
-- currently in scope.
checkVarInScope :: Context -> String -> Checks TypeCheckFail ()
checkVarInScope context name =
    unless (Set.member name (vars context)) $
        failCheck (curModule context) (NoSuchVariable name)
436
437
-- | Check that @name@ is a parameter of the current module whose declared
-- type is @expected@. Reports 'NoSuchParameter' if it is not declared, or
-- 'ParamTypeMismatch' if its declared type differs.
checkParamType :: Context -> String -> CheckAST.ModuleParamType -> Checks TypeCheckFail ()
checkParamType context name expected =
    case Map.lookup name declared of
        Nothing -> failCheck modName (NoSuchParameter name)
        Just actual
            | actual == expected -> return ()
            | otherwise          -> failCheck modName (ParamTypeMismatch name expected actual)
    where
        modName  = curModule context
        -- The current module always has a symbol-table entry (the synthetic
        -- root module is inserted before the top-level net is checked).
        declared = paramTypes (symTable context Map.! modName)
447