diff --git a/src/Convert/Logic.hs b/src/Convert/Logic.hs index 6db8875..85415dd 100644 --- a/src/Convert/Logic.hs +++ b/src/Convert/Logic.hs @@ -25,6 +25,7 @@ module Convert.Logic (convert) where +import Control.Monad.State import Control.Monad.Writer import qualified Data.Map.Strict as Map import qualified Data.Set as Set @@ -33,7 +34,7 @@ import Convert.Traverse import Language.SystemVerilog.AST type Idents = Set.Set Identifier -type Ports = Map.Map (Identifier, Identifier) Direction +type Ports = Map.Map Identifier [(Identifier, Direction)] convert :: [AST] -> [AST] convert = @@ -43,19 +44,23 @@ convert = where collectPortsM :: Description -> Writer Ports () collectPortsM (orig @ (Part _ _ _ _ name portNames _)) = - collectModuleItemsM collectPortDirsM orig + tell $ Map.singleton name ports where - collectPortDirsM :: ModuleItem -> Writer Ports () - collectPortDirsM (MIPackageItem (Decl (Variable dir _ ident _ _))) = - if dir == Local then - return () - else if elem ident portNames then - tell $ Map.singleton (name, ident) dir - else - error $ "encountered decl with a dir that isn't a port: " - ++ show (dir, ident) - collectPortDirsM _ = return () + ports = zip portNames (map lookupDir portNames) + dirs = execWriter $ collectModuleItemsM collectDeclDirsM orig + lookupDir :: Identifier -> Direction + lookupDir portName = + case lookup portName dirs of + Just dir -> dir + Nothing -> error $ "Could not find dir for port " ++ + portName ++ " in module " ++ name collectPortsM _ = return () + collectDeclDirsM :: ModuleItem -> Writer [(Identifier, Direction)] () + collectDeclDirsM (MIPackageItem (Decl (Variable dir _ ident _ _))) = + if dir == Local + then return () + else tell [(ident, dir)] + collectDeclDirsM _ = return () convertDescription :: Ports -> Description -> Description convertDescription ports orig = @@ -93,15 +98,16 @@ convertDescription ports orig = where comment = MIPackageItem $ Decl $ CommentDecl "rewrote reg-to-output bindings" - (bindings', newItemsList) = unzip $ map fixBinding bindings + (bindings', newItemsList) = + unzip $ map (uncurry fixBinding) $ zip bindings [0..] newItems = concat newItemsList - fixBinding :: PortBinding -> (PortBinding, [ModuleItem]) - fixBinding (portName, Just expr) = + fixBinding :: PortBinding -> Int -> (PortBinding, [ModuleItem]) + fixBinding (portName, Just expr) portIdx = if portDir /= Just Output || Set.disjoint usedIdents origIdents then ((portName, Just expr), []) else ((portName, Just tmpExpr), items) where - portDir = Map.lookup (moduleName, portName) ports + portDir = lookupPortDir portName portIdx usedIdents = execWriter $ collectNestedExprsM exprIdents expr tmp = "sv2v_tmp_" ++ instanceName ++ "_" ++ portName @@ -117,7 +123,18 @@ convertDescription ports orig = error $ "bad non-lhs, non-net expr " ++ show expr ++ " connected to output port " ++ portName ++ " of " ++ instanceName - fixBinding other = (other, []) + fixBinding other _ = (other, []) + lookupPortDir :: Identifier -> Int -> Maybe Direction + lookupPortDir "" portIdx = + case Map.lookup moduleName ports of + Nothing -> Nothing + Just l -> if portIdx >= length l + then Nothing + else Just $ snd $ l !! portIdx + lookupPortDir portName _ = + case Map.lookup moduleName ports of + Nothing -> Nothing + Just l -> lookup portName l fixModuleItem other = other -- rewrite variable declarations to have the correct type @@ -141,26 +158,43 @@ convertDescription ports orig = convertDecl other = other regIdents :: ModuleItem -> Writer Idents () -regIdents (AlwaysC _ stmt) = do - collectNestedStmtsM collectReadMemsM stmt - collectNestedStmtsM (collectStmtLHSsM (collectNestedLHSsM lhsIdents)) $ - traverseNestedStmts removeTimings stmt - where - removeTimings :: Stmt -> Stmt - removeTimings (Timing _ s) = s - removeTimings other = other - collectReadMemsM :: Stmt -> Writer Idents () - collectReadMemsM (Subroutine (Ident f) (Args (_ : Just (Ident x) : _) [])) = - if f == "$readmemh" || f == "$readmemb" - then tell $ Set.singleton x - else return () - collectReadMemsM _ = return () -regIdents (Initial stmt) = - regIdents $ AlwaysC Always stmt -regIdents (Final stmt) = - regIdents $ AlwaysC Always stmt +regIdents (item @ AlwaysC{}) = regIdents' item +regIdents (item @ Initial{}) = regIdents' item +regIdents (item @ Final{}) = regIdents' item regIdents _ = return () +regIdents' :: ModuleItem -> Writer Idents () +regIdents' item = do + let write = traverseScopesM traverseDeclM return traverseStmtM item + leftovers <- execStateT write Set.empty + if Set.null leftovers + then return () + else error $ "regIdents' got leftovers: " ++ show leftovers + +traverseDeclM :: Monad m => Decl -> StateT Idents m Decl +traverseDeclM (decl @ (Variable _ _ x _ _)) = + modify (Set.insert x) >> return decl +traverseDeclM decl = return decl + +traverseStmtM :: Stmt -> StateT Idents (Writer Idents) Stmt +traverseStmtM (Timing _ stmt) = traverseStmtM stmt +traverseStmtM (Subroutine (Ident f) args) = do + case args of + Args [_, Just (Ident x), _] [] -> + -- assuming that no one will readmem into a local variable + if f == "$readmemh" || f == "$readmemb" + then lift $ tell $ Set.singleton x + else return () + _ -> return () + return $ Subroutine (Ident f) args +traverseStmtM stmt = do + -- only write down idents which aren't shadowed + let regs = execWriter $ collectStmtLHSsM (collectNestedLHSsM lhsIdents) stmt + locals <- get + let globals = Set.difference regs locals + lift $ tell globals + return stmt + lhsIdents :: LHS -> Writer Idents () lhsIdents (LHSIdent x) = tell $ Set.singleton x lhsIdents _ = return () -- the collector recurses for us diff --git a/test/basic/output_bound_reg.sv b/test/basic/output_bound_reg.sv new file mode 100644 index 0000000..0d79bd9 --- /dev/null +++ b/test/basic/output_bound_reg.sv @@ -0,0 +1,37 @@ +module Flip(x, y); + input x; + output y; + assign y = ~x; +endmodule + +module Test1(o); + output [1:0] o; + logic x = 0; + for (genvar i = 0; i < 1; ++i) begin + Flip flip(x, o[i]); + end + initial begin + integer i = 0; + end +endmodule + +module Test2(o); + output o; + logic x = 0; + Flip flip(x, o); + initial begin + integer o = 0; + x = 0; + end +endmodule + +module Test3(o); + output o; + logic [1:0] x; + Flip flip(x[0], x[1]); + assign o = x[0]; + initial x[0] = 0; + initial begin + integer x = 0; + end +endmodule diff --git a/test/basic/output_bound_reg.v b/test/basic/output_bound_reg.v new file mode 100644 index 0000000..c6a9e7d --- /dev/null +++ b/test/basic/output_bound_reg.v @@ -0,0 +1,43 @@ +module Flip(x, y); + input x; + output y; + assign y = ~x; +endmodule + +module Test1(o); + output [1:0] o; + wire x = 0; + generate + genvar i; + for (i = 0; i < 1; i = i + 1) begin + Flip flip(x, o[i]); + end + endgenerate + initial begin : blah + integer i; + i = 0; + end +endmodule + +module Test2(o); + output o; + wire x = 0; + Flip flip(x, o); + initial begin : blah + integer o; + o = 0; + end +endmodule + +module Test3(o); + output o; + reg x_0; + wire x_1; + Flip flip(x_0, x_1); + assign o = x_0; + initial x_0 = 0; + initial begin : blah + integer x; + x = 0; + end +endmodule diff --git a/test/basic/output_bound_reg_tb.v b/test/basic/output_bound_reg_tb.v new file mode 100644 index 0000000..541c825 --- /dev/null +++ b/test/basic/output_bound_reg_tb.v @@ -0,0 +1,8 @@ +module top; + wire [1:0] o1; + Test1 bar(o1); + wire o2; + Test2 test2(o2); + wire o3; + Test3 test3(o3); +endmodule