diff --git a/src/Convert/NestPI.hs b/src/Convert/NestPI.hs index 5525dd0..8b38d3f 100644 --- a/src/Convert/NestPI.hs +++ b/src/Convert/NestPI.hs @@ -78,13 +78,10 @@ addItems pis existingPIs (item : items) = addItems pis existingPIs (head itemsToAdd : item : items) where thisPI = execWriter $ collectPIsM item - runner f = execWriter $ collectNestedModuleItemsM f item - usedPIs = Set.unions $ map runner - [ collectStmtsM collectSubroutinesM - , collectTypesM $ collectNestedTypesM collectTypenamesM - , collectExprsM $ collectNestedExprsM collectExprIdentsM - , collectLHSsM $ collectNestedLHSsM collectLHSIdentsM - ] + usedPIs = execWriter $ + traverseNestedModuleItemsM (traverseIdentsM writeIdent) item + writeIdent :: Identifier -> Writer Idents Identifier + writeIdent x = tell (Set.singleton x) >> return x neededPIs = Set.difference (Set.difference usedPIs existingPIs) thisPI itemsToAdd = map MIPackageItem $ Map.elems $ Map.restrictKeys pis neededPIs @@ -98,28 +95,66 @@ collectPIsM (MIPackageItem item) = ident -> tell $ Set.singleton ident collectPIsM _ = return () --- writes down the names of subroutine invocations -collectSubroutinesM :: Stmt -> Writer Idents () -collectSubroutinesM (Subroutine (Ident f) _) = tell $ Set.singleton f -collectSubroutinesM _ = return () +-- visits all identifiers in a module item +traverseIdentsM :: Monad m => MapperM m Identifier -> MapperM m ModuleItem +traverseIdentsM identMapper = traverseNodesM + (traverseExprIdentsM identMapper) + (traverseDeclIdentsM identMapper) + (traverseTypeIdentsM identMapper) + (traverseLHSIdentsM identMapper) + (traverseStmtIdentsM identMapper) --- writes down the names of function calls and identifiers -collectExprIdentsM :: Expr -> Writer Idents () -collectExprIdentsM (Call (Ident x) _) = tell $ Set.singleton x -collectExprIdentsM (Ident x) = tell $ Set.singleton x -collectExprIdentsM _ = return () +-- visits all identifiers in an expression +traverseExprIdentsM :: Monad m => MapperM m Identifier -> MapperM m Expr +traverseExprIdentsM identMapper = fullMapper + where + fullMapper = exprMapper >=> traverseSinglyNestedExprsM fullMapper + exprMapper (Call (Ident x) args) = + identMapper x >>= \x' -> return $ Call (Ident x') args + exprMapper (Ident x) = identMapper x >>= return . Ident + exprMapper other = return other --- writes down the names of identifiers -collectLHSIdentsM :: LHS -> Writer Idents () -collectLHSIdentsM (LHSIdent x) = tell $ Set.singleton x -collectLHSIdentsM _ = return () +-- visits all identifiers in a type +traverseTypeIdentsM :: Monad m => MapperM m Identifier -> MapperM m Type +traverseTypeIdentsM identMapper = fullMapper + where + fullMapper = typeMapper + >=> traverseTypeExprsM (traverseExprIdentsM identMapper) + >=> traverseSinglyNestedTypesM fullMapper + typeMapper (Alias x t) = aliasHelper (Alias ) x t + typeMapper (PSAlias p x t) = aliasHelper (PSAlias p ) x t + typeMapper (CSAlias c p x t) = aliasHelper (CSAlias c p) x t + typeMapper other = return other + aliasHelper constructor x t = + identMapper x >>= \x' -> return $ constructor x' t --- writes down aliased typenames -collectTypenamesM :: Type -> Writer Idents () -collectTypenamesM (Alias x _) = tell $ Set.singleton x -collectTypenamesM (PSAlias _ x _) = tell $ Set.singleton x -collectTypenamesM (CSAlias _ _ x _) = tell $ Set.singleton x -collectTypenamesM _ = return () +-- visits all identifiers in an LHS +traverseLHSIdentsM :: Monad m => MapperM m Identifier -> MapperM m LHS +traverseLHSIdentsM identMapper = fullMapper + where + fullMapper = lhsMapper + >=> traverseLHSExprsM (traverseExprIdentsM identMapper) + >=> traverseSinglyNestedLHSsM fullMapper + lhsMapper (LHSIdent x) = identMapper x >>= return . LHSIdent + lhsMapper other = return other + +-- visits all identifiers in a statement +traverseStmtIdentsM :: Monad m => MapperM m Identifier -> MapperM m Stmt +traverseStmtIdentsM identMapper = fullMapper + where + fullMapper = stmtMapper + >=> traverseStmtExprsM (traverseExprIdentsM identMapper) + >=> traverseStmtLHSsM (traverseLHSIdentsM identMapper) + >=> traverseSinglyNestedStmtsM fullMapper + stmtMapper (Subroutine (Ident x) args) = + identMapper x >>= \x' -> return $ Subroutine (Ident x') args + stmtMapper other = return other + +-- visits all identifiers in a declaration +traverseDeclIdentsM :: Monad m => MapperM m Identifier -> MapperM m Decl +traverseDeclIdentsM identMapper = + traverseDeclExprsM (traverseExprIdentsM identMapper) >=> + traverseDeclTypesM (traverseTypeIdentsM identMapper) -- returns the "name" of a package item, if it has one piName :: PackageItem -> Identifier diff --git a/src/Convert/Traverse.hs b/src/Convert/Traverse.hs index 92b858b..eae33d0 100644 --- a/src/Convert/Traverse.hs +++ b/src/Convert/Traverse.hs @@ -26,6 +26,7 @@ module Convert.Traverse , traverseExprsM , traverseExprs , collectExprsM +, traverseNodesM , traverseStmtExprsM , traverseStmtExprs , collectStmtExprsM @@ -84,6 +85,9 @@ module Convert.Traverse , traverseSinglyNestedExprsM , traverseSinglyNestedExprs , collectSinglyNestedExprsM +, traverseLHSExprsM +, traverseLHSExprs +, collectLHSExprsM , traverseNestedLHSsM , traverseNestedLHSs , collectNestedLHSsM @@ -503,6 +507,11 @@ traverseLHSExprsM exprMapper = return $ LHSStream o e' ls lhsMapper other = return other +traverseLHSExprs :: Mapper Expr -> Mapper LHS +traverseLHSExprs = unmonad traverseLHSExprsM +collectLHSExprsM :: Monad m => CollectorM m Expr -> CollectorM m LHS +collectLHSExprsM = collectify traverseLHSExprsM + mapBothM :: Monad m => MapperM m t -> MapperM m (t, t) mapBothM mapper (a, b) = do a' <- mapper a @@ -510,14 +519,31 @@ mapBothM mapper (a, b) = do return (a', b') traverseExprsM :: Monad m => MapperM m Expr -> MapperM m ModuleItem -traverseExprsM exprMapper = moduleItemMapper +traverseExprsM exprMapper = + traverseNodesM exprMapper declMapper typeMapper lhsMapper stmtMapper where - declMapper = traverseDeclExprsM exprMapper typeMapper = traverseNestedTypesM (traverseTypeExprsM exprMapper) lhsMapper = traverseNestedLHSsM (traverseLHSExprsM exprMapper) stmtMapper = traverseNestedStmtsM (traverseStmtExprsM exprMapper) +traverseExprs :: Mapper Expr -> Mapper ModuleItem +traverseExprs = unmonad traverseExprsM +collectExprsM :: Monad m => CollectorM m Expr -> CollectorM m ModuleItem +collectExprsM = collectify traverseExprsM + +traverseNodesM + :: Monad m + => MapperM m Expr + -> MapperM m Decl + -> MapperM m Type + -> MapperM m LHS + -> MapperM m Stmt + -> MapperM m ModuleItem +traverseNodesM exprMapper declMapper typeMapper lhsMapper stmtMapper = + moduleItemMapper + where + portBindingMapper (p, e) = exprMapper e >>= \e' -> return (p, e') @@ -600,11 +626,6 @@ traverseExprsM exprMapper = moduleItemMapper e' <- exprMapper e return (dir, ident, e') -traverseExprs :: Mapper Expr -> Mapper ModuleItem -traverseExprs = unmonad traverseExprsM -collectExprsM :: Monad m => CollectorM m Expr -> CollectorM m ModuleItem -collectExprsM = collectify traverseExprsM - traverseStmtExprsM :: Monad m => MapperM m Expr -> MapperM m Stmt traverseStmtExprsM exprMapper = flatStmtMapper where