Modify recursive function call sites. #62

Jul 21, 2021

Modify recursive function call sites.

These wrappers are generated by the rmrec pass, and have the form

"\a1 a2 ... an -> f (D a1 a2 ... an)"

where f is a recursive function and D is a data constructor.

--TODO: Modify recursive function call sites.

{- |
Module : Fhw.Pass.Streamify.Streamify
Description : Lift programs into the world of Streams.

Currently we're making numerous assumptions about the form of the
        input module. List them!
      1. Only one call site for each recursive function. This call site 
          defines a top-level vdef e.g. v = f a1 a2 ... an.
      2. If the recursive function 'f' was modified by rmrec pass, it has only
          one call site within a wrapper function 'g' of the form
          g = \a1 a2 ... an -> f (D a1 a2 ... an).
      3. If a recursive function f was modified by rmrec pass, then its argument
          is of type "Action_f".
      4. Up to two recursive functions may exist, where the result of a call
          to one of them is passed as an argument to the other e.g.
          append (build 1 100 []) [].

This pass has four high-level steps:
  1. Modify or generate the Action data type for each recursive function. 
      This type encapsulates both the function's arguments (Call, Start)
      and its return value (Ret, Done). We also introduce the NOP
      variant here.
  2. Modify the definition of each recursive function 'f':
      1) Change type signature to "Action_f -> Action_f".
      2) Modify return type of each case expression to "Action_f".
      3) Introduce new top-level case matching on Action variants
          if f was ignored by rmrec pass.
      4) Wrap any non-recursive result in a "Done" variant. 
      5) Replace any recursive call "f (D a1 a2 ... an)" with "D a1 a2 ... an"
          if "f" was modified by rmrec pass.
      6) Add new case alternatives for Start, Done, and NOP variants if
          f was modified by rmrec pass.
  3. Modify call site of recursive function (or wrapper, if it exists).
     We assume it is of the form "v = f e1 e2 ...". (f may be a wrapper)
      1) If one argument e_i is the result of some other recursive
          function call "e_i = g a1 a2 ...", replace with  
          "v = map f (NOP_f :> zipWith merge e_i v)". Define merge function
          that passes a Start off to f once g is Done.
      2) Otherwise, replace with "v = map f (Start e1 e2 ... :> v)".
      3) Change type of 'v' to Stream Action_f.
  4. If main = repeat @ t e and e is a top-level vdef of type Stream Action_*,
      modify type of "main" to Stream Action_f and
      redefine main = e.
-}

module Fhw.Pass.Streamify.Streamify (streamify) where

--import Debug.Trace

import Fhw.Core.Core
import Fhw.Pass.RemoveRecursion.Utils
import Fhw.Haskgen.DeadCode

import Data.List
import Data.Maybe

-- | Everything this pass tracks about a single recursive function.
data FuncInfo = FuncInfo
  { func    :: Vdef        -- ^ the recursive function itself
  , wrapper :: Maybe Vdef  -- ^ its wrapper, when one exists
  , acTy    :: ActionTy    -- ^ Action type associated with the function
  } deriving (Show,Eq)

-- | The pieces of an Action type definition that the transformations
--   in this module need to refer to.
data ActionTy = ActionTy
  { aTdef     :: Tdef  -- ^ the type definition
  , aTy       :: Ty    -- ^ type constructor
  , startDcon :: Exp   -- ^ Start data constructor
  , doneDcon  :: Exp   -- ^ Done data constructor
  , callDcon  :: Exp   -- ^ Call data constructor
  , nopDcon   :: Exp   -- ^ NOP data constructor
  } deriving (Show,Eq)

-- | Entry point of the pass.  Splits the module's value definitions into
--   recursive and non-recursive ones, builds (or extends) an Action type
--   for every recursive function, rewrites the recursive functions and
--   their call sites, rewrites @main@, and finally hands the rebuilt
--   module to 'removeCode' together with the name list @["main"]@
--   (presumably the roots to keep during dead-code removal -- confirm
--   against "Fhw.Haskgen.DeadCode").
streamify :: Module -> Module
streamify (Module mname tdefs vdefs) = removeCode (Module mname tdefs' vdefs') ["main"]
  where
    (recDefs,nonRecDefs) = partition isRecursive vdefs
    -- Pair every recursive function with its rmrec-generated wrapper.
    recsAndWraps = map (getWrapper nonRecDefs) recDefs
    -- Action types (new or extended) plus the type definitions left alone.
    (funcInfo,oldTdefs) = modActionType recsAndWraps tdefs
    newInfo = map (modRecFuncs . buildInfo) funcInfo
    -- Call sites are rewritten one recursive function at a time; main last.
    newNonRecs = modMain $ foldl (modCallSites newInfo) nonRecDefs newInfo
    newRecs = map func newInfo
    vdefs' = newRecs ++ newNonRecs
    tdefs' = map (\(_,_,td) -> td) funcInfo ++ oldTdefs
    --TODO: Modify recursive function call sites.

    -- A definition is recursive unless exprRecursiveness says Nonrecursive.
    isRecursive (Vdef (_,n) _ ex) = exprRecursiveness n ex /= Nonrecursive

-- | Pair each recursive function with its wrapper, if it exists.
--   These wrappers are generated by the rmrec pass, and have the form
--      "\a1 a2 ... an -> f (D a1 a2 ... an)" 
--   where f is a recursive function and D is a data constructor.
--   The first argument is the list of candidate (non-recursive)
--   definitions to search; the first matching one is returned.
--   TODO: These wrappers may occur naturally; have rmrec pass send along
--          the functions it modified instead of searching for them based on
--          structure.
getWrapper :: [Vdef] -> Vdef -> (Vdef,Maybe Vdef)
getWrapper vdefs recDef = (recDef, wrap)
  where
    wrap = find (isWrapper (vdefName recDef)) vdefs

    -- A wrapper is a differently-named lambda whose body applies the
    -- recursive function (by name) to a single argument expression.
    isWrapper name (Vdef (_,n) _ (Lam _ (App (Var (_,var) _) _))) =
      n /= name && var == name
    isWrapper _ _ = False

-- | Given the recursive functions in our program, modify (or generate)
--   its associated Action type and return all modified and unmodified type
--   definitions.
--
--   The first component of the result carries, for every input pair, the
--   function, its optional wrapper and the new Action type definition.
--   The second component is the input type definitions minus those whose
--   name matches one of the new Action types.
--
--   Calls 'error' when a wrapped function does not take exactly one
--   argument whose type names a known type definition, or when a
--   function/wrapper/type triple has an unexpected shape.
modActionType :: [(Vdef,Maybe Vdef)] -> [Tdef] -> ([(Vdef,Maybe Vdef,Tdef)],[Tdef])
modActionType recsAndWraps tdefs = (newInfo,oldTdefs)
  where
    --New action types associated with each function and its possible wrapper.
    newInfo = map (modAction . getAction) recsAndWraps
    tyNames = map (\(_,_,ty) -> tdefName ty) newInfo

    oldTdefs = filter (flip notElem tyNames . tdefName) tdefs

    --Functions have Action types if they were modified by rmrec and every 
    --function so modified has a wrapper associated with it.
    getAction (vd@(Vdef n ty _), Just vdef) = 
      let t = snd $ collectArgTypes ty
          -- 'head t' is only forced when the length check below passes.
          actionTy = find ((==) (tyName $ head t) . tdefName) tdefs
      in if length t /= 1 || isNothing actionTy
          then error $ "Unexpected type for function " ++ show n
          else (vd,Just vdef,actionTy)
    getAction (vdef,Nothing) = (vdef,Nothing,Nothing)

    -- Existing Action type: prepend Start, NOP and Done variants.
    modAction (vd,Just vdef,Just (Data tname tb cdefs)) = --wrapper vdef and action type
      (vd,Just vdef,Data tname tb $ start : nop : done : cdefs)
      where
        name = vdefName vdef

        --TODO: Name reliant
        start = mkConstr name "Start_" $ getNonConts "Call"
        nop   = mkConstr name "NOP_"   [] 
        done  = mkConstr name "Done_"  $ getNonConts "Ret"

        -- Field types of the first constructor whose name starts with
        -- cname, with continuation-typed fields dropped.
        getNonConts cname = maybe [] (removeConts . constrTypes) $
                    find (\(Constr (_,n) _ _) -> cname `isPrefixOf` n) cdefs

        --TODO: Name reliant
        removeConts [] = []
        removeConts (ty:rest) 
          | "CT" `isPrefixOf` tyName ty = removeConts rest
          | otherwise = ty : removeConts rest

    -- No Action type yet: generate "Action_<name>" from the signature.
    modAction (vdef@(Vdef (_,name) ty _),Nothing,Nothing) =  --no action type
      (vdef,Nothing,Data (Nothing,"Action_"++name) [] [start,nop,done,call] )
      where
        (retTy,argTys) = collectArgTypes ty

        start = mkConstr name "Start_" argTys
        nop   = mkConstr name "NOP_" []
        done  = mkConstr name "Done_" [retTy]
        call  = mkConstr name "Call_" argTys

    modAction _ = error "Unexpected form in modAction"

    -- Constructor named <prefix><name>; the field types are supplied by
    -- the caller as the remaining argument.
    mkConstr name n = Constr (Nothing,n++name) []

-- | Build FuncInfo data type 
--   from a function, its optional wrapper and its Action type definition.
--   Each data constructor is looked up by name prefix ("Start", "Done",
--   "Call", "NOP") among the constructors of the type definition; a
--   missing constructor raises 'error' when the field is forced.
buildInfo :: (Vdef,Maybe Vdef,Tdef) -> FuncInfo
buildInfo (vd,mVd,td) = 
  FuncInfo { func = vd, wrapper = mVd, 
             acTy = ActionTy { aTdef = td 
                             , aTy = actTy
                             , startDcon = getDcon "Start"
                             , doneDcon = getDcon "Done"
                             , callDcon = getDcon "Call"
                             , nopDcon = getDcon "NOP" } }
  where
    actTy = Tcon (Nothing,tdefName td)
    getDcon name = Dcon dName dTy
      where
        (dName,dTy) = maybe (error $ "Cdef missing dcon " ++ name) getDefInfo cdef
        -- The constructor's type: its field types, ending in the Action type.
        getDefInfo (Constr cname _ tys) = (cname,tArrows $ tys ++ [actTy])
        cdef = find (\(Constr (_,n) _ _) -> name `isPrefixOf` n) $ tdefConstrs td

-- | Given a recursive function, a possible wrapper, and its associated
--   Action data type, transform recursive function into a non-recursive
--   transition table of type Action -> Action.
--   The modifications are numbered as follows:
--   1. Change type signature to Action -> Action.
--   2. Introduce new top-level lambda and case expression matching on 
--      Action variants if function was ignored by rmrec pass.
--   3. Modify return types of all case expressions to Action.
--   4. Add Start, Done, and NOP case alternatives.
--   5. Replace final expressions:
--      a. For modified functions 'f', if expr is 'f e' replace with 'e'.
--      b. For ignored functions 'f', if expr is 'f a1 a2 ...' replace 
--          with 'Call a1 a2 ..."
--      c. Otherwise, expr is 'e' so replace with 'Done e'.
--
--   Only the 'func' field of the input is replaced.  Calls 'error' when
--   the function body is not a lambda over a case expression, or when any
--   of the four Action constructor expressions is not a 'Dcon'.
modRecFuncs :: FuncInfo -> FuncInfo
modRecFuncs info@(FuncInfo { func = Vdef name _ (Lam bs (Case e vbs _ alts))
                      , wrapper = wrap
                      , acTy = ActionTy 
                        { aTy = actionTy 
                        , startDcon = Dcon startName startTy
                        , callDcon = callEx@(Dcon callName _)
                        , nopDcon = nopEx@(Dcon nopName _)
                        , doneDcon = doneEx@(Dcon doneName doneTy)}}) = 
  info {func = newVdef}
  where
    --Make choice based on whether this function was ignored by rmrec or not
    wrapChoice c1 c2 = if isNothing wrap then c1 else c2

    newVdef = Vdef name (tArrows $ replicate 2 actionTy) $ --(1)
              wrapChoice newLam $ Lam bs $ 
              mapReturn modFunc $
              Case e vbs actionTy $ 
              nopAcon : doneAcon : startAcon startBinds : moddedAlts

    -- One fresh binder per argument of the Start constructor.
    startBinds = let (_,argTys) = collectArgTypes startTy
                     maker t binds = ("arg_" ++ show (length binds),t) : binds
                     -- TODO: Name uniqueness
                 in foldr maker [] argTys

    -- (2)
    newLam = Lam [Vb ("actionArg",actionTy)] $
             mkVarCase "actionArg" actionTy actionTy
             [Acon callName [] origBinds $ 
                mapReturn modFunc $ Case e vbs actionTy moddedAlts
             ,startAcon origBinds]
      where
        -- The original value binders become the Call/Start pattern binders.
        origBinds = map removeConstr $ filter isVb bs
          where
            isVb (Vb _) = True
            isVb _ = False 
            removeConstr (Vb v) = v
            removeConstr _ = error "No type binds should exist"

    -- (3) Only alternatives that are themselves case expressions are retyped.
    moddedAlts = map (mapAlt changeRetTy) alts
      where
        changeRetTy (Case scrut vbinds _ ex) = Case scrut vbinds actionTy ex
        changeRetTy ex = ex

    -- (4) NOP and Done both step to NOP; Start steps to Call.
    nopAcon = Acon nopName [] [] nopEx
    doneAcon = Acon doneName [] [("_",doneArgTy)] nopEx
      where
        doneArgTy = let (_,argTy) = collectArgTypes doneTy
                    in if null argTy
                        then error "Done variant issue in modRecFuncs"
                        else head argTy
    startAcon binds = Acon startName [] binds $
                      mkFuncApp callEx [] $
                      map (\(v,t) -> Var (Nothing,v) t) binds ++ 
                      if length binds == length argTys - 1 --expecting one more arg
                        then wrapChoice [] [initialCont] 
                        else []
      where
        (_,argTys) = collectArgTypes $ exprType callEx 
        --TODO: Assumes that if we have a wrapper, then it is passing the
        --initial arguments in a data constructor whose final field is a 
        --continuation. I don't know if this is always the case.
        initialCont = case wrap of
          Just (Vdef _ _ (Lam _ (App _ ex))) -> 
            let (_,args,_) = collectArgs ex
            in if null args then wraperr else last args
          _ -> wraperr 
        wraperr = error $ "Wrapper does not have expected form" ++ show wrap

    -- (5) Rewrite a final expression of the function body.
    modFunc ex 
      | isAction call = ex
      | getVarName call == snd name = wrapChoice caller (head expArgs)
      | otherwise = App doneEx ex
      where
        (call,expArgs,_) = collectArgs ex
        caller = foldl App callEx expArgs
        -- Already an Action: a data constructor whose result type is actionTy.
        isAction (Dcon _ t) = actionTy == fst (collectArgTypes t)
        isAction _ = False

modRecFuncs _ = error "Unexpected form of recursive function"
-- | Modify the call site of each recursive function, 
--   or its wrapper if one exists.
--
--   Arguments: the FuncInfo of every recursive function, the current
--   non-recursive definitions, and the function whose call site is being
--   rewritten.  The result is the rewritten call site (preceded by a
--   generated merge function when one argument is itself the result of a
--   recursive function) followed by the untouched definitions.
--   Calls 'error' unless exactly one definition calls the function.
modCallSites :: [FuncInfo] -> [Vdef] -> FuncInfo -> [Vdef]
modCallSites allInfo vdefs info = newCallSite callSite ++ others
  where
    actionTy = aTy $ acTy info

    --Find call site
    (callSite,others) = partition callsRec vdefs
    callsRec (Vdef _ _ ex) = 
      let (call,_,_) = collectArgs ex
          checkName  = (==) (getVarName call) . vdefName
      in checkName $ fromMaybe (func info) (wrapper info)

    newCallSite [Vdef name ty ex] = 
      case streamArgs of
        []      -> [mapFunc]
        (sa:_)  -> zipFunc sa
      where
        Vdef funcName funcTy _ = func info
        (_,expArgs,_) = collectArgs ex
        startEx = startDcon $ acTy info
        nopEx = nopDcon $ acTy info

        -- Top-level definitions passed as arguments that are themselves
        -- calls to a known recursive function (or its wrapper).
        streamArgs = let names = map getVarName expArgs
                         topDefs = filter (flip elem names . vdefName) vdefs
                     in mapMaybe getStreamArgs topDefs

        --Get the FuncInfo entry corresponding to this streamArg, if
        --it exists.
        getStreamArgs vd@(Vdef _ _ e) = 
          let (call,exArgs,_) = collectArgs e
              callName = let x = getVarName call
                         --TODO: Hack to handle cases where we already 
                         --modified the other function's call site
                         in if x == "map" 
                              then getVarName $ head exArgs
                              else x
              funcAndWrap i = func i : mapMaybe wrapper [i]
              entry = find (elem callName . map vdefName . funcAndWrap) allInfo
          in fmap (\entryEx -> (vd,entryEx)) entry

        -- "name = map f (exs_0 :> exs_1)" typed as a stream of Actions.
        baseFunc exs = Vdef name (streamTy actionTy) $
                        mkFuncApp varMap [actionTy,actionTy]
                        [Var funcName funcTy
                        ,mkFuncApp streamDcon [actionTy] exs]

        mapFunc = baseFunc [foldl App startEx expArgs,Var name ty]

        zipFunc (Vdef n t _,otherInfo) = 
          [mergeFunc,baseFunc [nopEx,mkFuncApp varZW tyList
                                    [mergeVar,Var n t,Var name ty]]]
          where
            tyList = [otherTy,actionTy,actionTy]
            otherTy = aTy $ acTy otherInfo
            (Dcon dconName dconTy) = doneDcon $ acTy otherInfo
            doneArgTy = let (_,tys) = collectArgTypes dconTy
                        in if null tys
                            then error "Done should take an argument in zipFunc"
                            else head tys

            mergeVar = Var (Nothing,mergeName) mergeTy
            mergeTy = tArrows tyList
            mergeName = intercalate "_" ["merge",snd n,snd name]
            -- Emits a Start once the other function is Done; otherwise
            -- passes the second argument through unchanged.
            mergeFunc = Vdef (Nothing,mergeName) mergeTy $
                        Lam [Vb ("buildArg",otherTy),Vb ("arg",actionTy)] $
                        mkVarCase "buildArg" otherTy actionTy
                        [Acon dconName [] [("result",doneArgTy)] $
                          foldl App startEx startArgs
                        ,Adefault $ Var (Nothing,"arg") actionTy]

            --Index of stream argument
            argIndex = elemIndex (snd n) (map getVarName expArgs)
            -- Original arguments with the stream argument replaced by the
            -- "result" variable bound in the Done alternative.
            startArgs = let spot = fromMaybe (error "No stream arg found in zipFunc") argIndex
                            (preRest,postRes) = splitAt spot expArgs
                            resVar = Var (Nothing,"result") doneArgTy
                        in preRest ++ [resVar] ++ tail postRes

    newCallSite _ = error "Unexpected call site in modCallSites"

-- | Rewrite the definition named "main".  When its body, as split by
--   'collectArgs', has exactly one value argument and that argument's
--   variable name matches a definition in the input list, main is given
--   that definition's type and its body becomes the argument itself.
--   Otherwise main is left unchanged.  The rewritten main is placed first
--   in the result.  Calls 'error' unless exactly one definition is named
--   "main".
modMain :: [Vdef] -> [Vdef]
modMain vdefs = newMain mainF : others
  where
    (mainF,others) = partition isMain vdefs
    isMain (Vdef (_,"main") _ _) = True
    isMain _ = False

    -- Pattern matching replaces the length/isJust/fromJust/head/foldl1
    -- combination; with a single argument, folding App yields it as is.
    newMain [vd@(Vdef name _ ex)] =
      case expArgs of
        [arg] -> case find ((==) (getVarName arg) . vdefName) vdefs of
                   Just (Vdef _ ty _) -> Vdef name ty arg
                   Nothing            -> vd
        _     -> vd
      where
        (_,expArgs,_) = collectArgs ex
    newMain _ = error "Issue with newMain"

-- | Name of a type constructor; the empty string for any other type.
tyName :: Ty -> String
tyName ty = case ty of
  Tcon (_,name) -> name
  _             -> ""

-- | Unqualified name of a variable expression; the empty string for any
--   expression that is not a variable.
getVarName :: Exp -> Var
getVarName expr = case expr of
  Var (_,name) _ -> name
  _              -> ""

-- | Build a case expression that scrutinises the unqualified variable
--   @name@ of type @ty@.  The second field of the 'Case' is the pair
--   @("dummy",ty)@ (presumably the case binder -- confirm against the
--   Core definition); the remaining type and the alternatives are passed
--   through unchanged.
mkVarCase :: String -> Ty -> Ty -> [Alt] -> Exp
mkVarCase name ty resTy alts = Case scrutinee binder resTy alts
  where
    scrutinee = Var (Nothing,name) ty
    binder    = ("dummy",ty)

-- | The type variables @a@, @b@ and @c@.
tvA, tvB, tvC :: Ty
tvA = Tvar "a"
tvB = Tvar "b"
tvC = Tvar "c"

-- | The @:>@ data constructor, qualified with 'fhwPrimMname', typed as
--   taking an @a@ and a @streamTy a@ and yielding a @streamTy a@.
--   NOTE(review): unlike 'varMap' and 'varZW' this type is not wrapped in
--   'tForalls' -- confirm that is intended.
streamDcon :: Exp
streamDcon = Dcon (Just fhwPrimMname, ":>") ty
  where
    ty = tArrows [tvA, streamTy tvA, streamTy tvA]

-- | The variable @map@, qualified with 'fhwPrimMname', typed as
--   @forall a b. (a -> b) -> streamTy a -> streamTy b@.
varMap :: Exp
varMap = Var (Just fhwPrimMname, "map") mapType
  where
    mapType = tForalls [("a",Klifted),("b",Klifted)] $
              tArrows [tArrow tvA tvB
                      ,streamTy tvA
                      ,streamTy tvB]

-- | The variable @zipWith@, qualified with 'fhwPrimMname', typed as
--   @forall a b c. (a -> b -> c) -> streamTy a -> streamTy b -> streamTy c@.
varZW :: Exp
varZW  = Var (Just fhwPrimMname, "zipWith") zwTy
  where
    zwTy = tForalls [("a",Klifted),("b",Klifted),("c",Klifted)] $
           tArrows [tArrows [tvA,tvB,tvC]
                   ,streamTy tvA
                   ,streamTy tvB
                   ,streamTy tvC]


