diff --git a/src/ksc/AD.hs b/src/ksc/AD.hs index c9c2a032c..b5f86203f 100644 --- a/src/ksc/AD.hs +++ b/src/ksc/AD.hs @@ -35,9 +35,10 @@ gradV adp (Simple x) = Grad x adp gradV _ v = error ("gradV: bad variable: " ++ render (ppr v)) gradTFun :: HasCallStack => ADPlan -> TFun Typed -> Type -> TFun Typed -gradTFun adp (TFun res_ty f) arg_tys - = TFun (mkGradType adp arg_tys res_ty) - (gradF adp f) +gradTFun adp (TFun { tf_ret = res_ty, tf_fun = f, tf_targs = targs }) arg_tys + = TFun { tf_ret = mkGradType adp arg_tys res_ty + , tf_fun = gradF adp f + , tf_targs = targs } -- SLPJ: is this right? mkGradTVar :: HasCallStack => ADPlan -> Type -> Var -> Type -> TVar mkGradTVar adp s var ty @@ -58,10 +59,11 @@ gradDef adp = gradDefInner adp . noTupPatifyDef gradDefInner :: HasCallStack => ADPlan -> TDef -> Maybe TDef gradDefInner adp - (Def { def_fun = Fun JustFun f, def_pat = VarPat params + (Def { def_fun = Fun JustFun f, def_qvars = qvs, def_pat = VarPat params , def_rhs = UserRhs rhs, def_res_ty = res_ty }) = Just $ Def { def_fun = Fun (GradFun adp) f + , def_qvars = qvs , def_pat = VarPat params , def_res_ty = mkGradType adp s_ty res_ty , def_rhs = UserRhs (mkLets lets (gradE adp s rhs')) } @@ -306,10 +308,11 @@ applyD dir def@(Def { def_pat = TupPat {} }) -- Dt$f :: S1 S2 -> (T, (S1,S2) -o T) -- fwd$f :: (S1, S2) (dS1, dS2) -> dT -- fwdt$f :: (S1, S2) (dS1, dS2) -> (T, dT) -applyD Fwd (Def { def_fun = Fun (GradFun adp) f, def_res_ty = res_ty +applyD Fwd (Def { def_fun = Fun (GradFun adp) f, def_qvars = qvs, def_res_ty = res_ty , def_pat = VarPat x, def_rhs = UserRhs rhs }) = Def { def_fun = Fun (DrvFun (AD adp Fwd)) f - , def_pat = VarPat x_dx + , def_qvars = qvs + , def_pat = VarPat x_dx , def_rhs = UserRhs $ extract2args $ perhapsFstToo $ lmApply lm $ Var dx , def_res_ty = t } where @@ -336,9 +339,10 @@ applyD Fwd (Def { def_fun = Fun (GradFun adp) f, def_res_ty = res_ty -- D$f :: S1 S2 -> ((S1,S2) -o T) -- rev$f :: (S1, S2) dT -> (dS1,dS2) -applyD Rev (Def { def_fun = Fun (GradFun adp) f, def_res_ty = res_ty +applyD Rev (Def { def_fun = Fun (GradFun adp) f, def_qvars = qvs, def_res_ty = res_ty , def_pat = VarPat x, def_rhs = UserRhs rhs }) = Def { def_fun = Fun (DrvFun (AD adp Rev)) f + , def_qvars = qvs , def_pat = VarPat x_dr , def_rhs = UserRhs $ extract2args $ lmApplyR (Var dr) lm , def_res_ty = tangentType (typeof x) } diff --git a/src/ksc/Annotate.hs b/src/ksc/Annotate.hs index e60d20451..149f02437 100644 --- a/src/ksc/Annotate.hs +++ b/src/ksc/Annotate.hs @@ -12,7 +12,7 @@ module Annotate ( import Lang import LangUtils import KMonad -import Ksc.Traversal ( traverseOf, over ) +import Ksc.Traversal ( traverseOf ) import Prim import qualified Data.Map as Map import GHC.Stack @@ -55,11 +55,12 @@ annotDecls gbl_env decls -- We don't have a type-checked body yet, but that -- doesn't matter because we aren't doing inlining mk_rec_def :: DefX Parsed -> TcM TDef - mk_rec_def (Def { def_fun = fun, def_pat = pat, def_res_ty = res_ty }) + mk_rec_def (Def { def_fun = fun, def_qvars = qvs + , def_pat = pat, def_res_ty = res_ty }) = addCtxt (text "In the definition of" <+> ppr_fun) $ - tcWithPat pat $ \pat' -> - do { fun' <- tcUserFunArgTy @Parsed fun (typeof pat') - ; return (Def { def_fun = fun', def_pat = pat' + do { fun' <- tcUserFunArgTy @Parsed fun qvs pat + ; return (Def { def_fun = fun', def_qvars = qvs + , def_pat = pat , def_res_ty = res_ty , def_rhs = StubRhs }) } where ppr_fun = pprUserFun @Parsed fun @@ -95,23 +96,26 @@ tcDeclX (GDefDecl gdef) = do { gdef' <- tcGDef gdef tcDef :: forall p. (Pretty (BaseUserFun p), InPhase p) => DefX p -> TcM TDef tcDef (Def { def_fun = fun - , def_pat = pat + , def_qvars = qvs + , def_pat = pat , def_res_ty = res_ty , def_rhs = rhs }) = addCtxt (text "In the definition of" <+> ppr_fun) $ - do { tcWithPat pat $ \pat' -> + tcWithPat pat $ do { rhs' <- tcRhs fun rhs res_ty - ; fun' <- tcUserFunArgTy @p fun (typeof pat') - ; return (Def { def_fun = fun', def_pat = pat' + ; fun' <- tcUserFunArgTy @p fun qvs pat + ; return (Def { def_fun = fun' + , def_qvars = qvs + , def_pat = pat , def_rhs = rhs', def_res_ty = res_ty }) - }} + } where ppr_fun = pprUserFun @p fun -- Adds the variables in the pattern to the symbol table and -- runs the continuation on the pattern. -tcWithPat :: Pat -> (Pat -> TcM a) -> TcM a +tcWithPat :: Pat -> TcM a -> TcM a tcWithPat pat continueWithArg - = do { extendLclSTM (patVars pat) $ continueWithArg pat } + = extendLclSTM (patVars pat) continueWithArg -- Checks that the type of pat matches expected_ty and if so returns -- the type-annotated Pat. @@ -193,21 +197,30 @@ tcGDef g@(GDef d f) } tcUserFunArgTy :: forall p. (Pretty (BaseUserFun p), InPhase p) - => UserFun p -> Type + => UserFun p -> [TyVar] -> Pat -> TcM (UserFun Typed) -tcUserFunArgTy fun arg_ty = case baseFunArgTy_maybe fun arg_ty of - Right (Just baseTy) -> case addBaseTypeToUserFun @p fun baseTy of - Right r -> pure r - Left appliedTy -> - tcFail (text "The base type did not match the applied type" - $$ text "The argument type was" <+> ppr arg_ty - $$ text "from which the base type was determined to be" <+> ppr baseTy - $$ text "but the applied type was" <+> ppr appliedTy) - Right Nothing -> traverseOf (baseFunFun . baseUserFunType) f fun - where f = \case - Nothing -> tcFail (text "No type was supplied and I couldn't deduce it from the argument type") - Just appliedTy -> pure appliedTy - Left err -> tcFail err +tcUserFunArgTy fun qvs pat + | not (null qvs) + = case addArgTypeDescriptor @p fun Poly of + Right r -> pure r + Left _ -> tcFail (text "The function is polymorphic but its Fun is monomorphic") + + | otherwise + = case baseFunArgTy_maybe fun arg_ty of + Right (Just baseTy) -> case addArgTypeDescriptor @p fun (Mono baseTy) of + Right r -> pure r + Left appliedTy -> + tcFail (text "The base type did not match the applied type" + $$ text "The argument type was" <+> ppr arg_ty + $$ text "from which the base type was determined to be" <+> ppr baseTy + $$ text "but the applied type was" <+> ppr appliedTy) + Right Nothing -> traverseOf (baseFunFun . baseUserFunType) f fun + where f = \case + Nothing -> tcFail (text "No type was supplied and I couldn't deduce it from the argument type") + Just appliedTy -> pure appliedTy + Left err -> tcFail err + where + arg_ty = typeof pat tcExpr :: forall p. InPhase p => ExprX p -> TcM TypedExpr -- Naming conventions in this function: @@ -223,15 +236,19 @@ tcExpr (Var vx) tcExpr (Konst k) = return (TE (Konst k) (typeofKonst k)) -tcExpr (Call fx es) +tcExpr (Call fx arg) = do { let (fun, mb_ty) = getFun @p fx - ; pairs <- addCtxt (text "In the call of:" <+> ppr fun) $ - tcExpr es - ; (funTyped, res_ty) <- addCtxt (text "In the call of:" <+> ppr fun) $ - lookupGblTc fun pairs - ; res_ty <- checkTypes_maybe mb_ty res_ty $ + ; arg_typed <- addCtxt (text "In the call of:" <+> ppr fun) $ + tcExpr arg + ; fun_typed <- addCtxt (text "In the call of:" <+> ppr fun) $ + lookupFunTc fun arg_typed + ; let res_ty = tFunResTy fun_typed + ; _res_ty <- checkTypes_maybe mb_ty res_ty $ text "Function call type mismatch for" <+> ppr fun - ; let call' = Call (TFun res_ty funTyped) (exprOf pairs) + -- SLPJ: _res_ty comes from original TFun when checking; + -- Here we simply replace it with the one gotten from the argument + -- They are, after all, equal! + ; let call' = Call fun_typed (exprOf arg_typed) ; return (TE call' res_ty) } tcExpr (Let vx rhs body) @@ -294,57 +311,6 @@ tcVar var mb_ty -- The typecheck monad ----------------------------------------------- -userCallDef_maybe :: HasCallStack - => UserFun Typed -> GblSymTab -> Either SDoc TDef -userCallDef_maybe fn env - = case lookupGblST fn env of - Just def -> Right def - Nothing -> Left (text "Not in scope: userCall:" - <+> ppr_fn $$ message) - where message = if null similarEnv - then empty - else text "Did you mean one of these:" - $$ vcat similarEnv - similarEnv = (map ppr - . filter similar - . Map.keys) env - similar envfn = - editDistance (render ppr_fn) (render (ppr envfn)) - <= configEditDistanceThreshold - editDistance = E.levenshteinDistance E.defaultEditCosts - - ppr_fn = pprUserFun @Typed fn - -userCallResultTy_maybe :: HasCallStack - => UserFun Typed -> GblSymTab -> Type -> Either SDoc Type -userCallResultTy_maybe fn env args - = do { def <- userCallDef_maybe fn env - ; userCallResultTy_help def args } - -userCallResultTy_help :: HasCallStack - => TDef -> Type -> Either SDoc Type -userCallResultTy_help (Def { def_fun = fn - , def_res_ty = ret_ty - , def_pat = pat }) - args - = check_args (patType pat) (typeof args) - where - check_args :: Type -> Type -> Either SDoc Type - check_args bndr_ty arg_ty - | bndr_ty `compatibleType` arg_ty - = Right ret_ty - | otherwise - = Left (hang (text "Type mis-match in argument" - <+> text "of call to" <+> ppr_fn) - 2 (vcat [ text "Expected:" <+> ppr bndr_ty - , text "Actual: " <+> ppr arg_ty ])) - where ppr_fn = pprUserFun @Typed fn - - ------------------------------------------------ --- The typecheck monad ------------------------------------------------ - data TcEnv = TCE { tce_ctxt :: [SDoc] -- Context, innermost first , tce_st :: SymTab } @@ -478,37 +444,131 @@ lookupLclTc v , text "Envt:" <+> lclDoc st ]) Just ty -> return ty } -lookupGblTc :: Fun Parsed -> TypedExpr -> TcM (Fun Typed, Type) -lookupGblTc fun args - = do { st <- getSymTabTc - ; (funTyped, callResultTy_maybe) <- case perhapsUserFun fun of - Right userFun -> do - { userFun' <- tcUserFunArgTy @Parsed userFun ty - ; pure (userFunToFun userFun', - userCallResultTy_maybe userFun' (gblST st) ty) } - Left fun' -> pure (over baseFunFun PrimFun fun', primCallResultTy_maybe fun' ty) - - ; res_ty <- case callResultTy_maybe of - Left err -> tcFail $ hang err 2 (mk_extra funTyped st) - Right res_ty -> return res_ty - ; pure (funTyped, res_ty) - } +lookupFunTc :: Fun Parsed -> TypedExpr -> TcM (TFun Typed) +lookupFunTc fun@(Fun ds (PrimFun p)) arg + = do { case primCallResultTy_maybe (Fun ds p) (typeof arg) of { + Left err -> tcFail $ hang err 2 (mk_extra fun arg) ; + Right res_ty -> + + return (TFun { tf_fun = Fun ds (PrimFun p) + , tf_targs = [] + , tf_ret = res_ty }) } } where - ty = typeof args - mk_extra fun st + mk_extra fun arg = vcat [ text "In a call of:" <+> ppr fun - , text " Arg types:" <+> ppr ty - , text " Args:" <+> ppr (exprOf args) - , text "ST lookup:" <+> case maybeUserFun fun of - Nothing -> text "" - Just userFun -> ppr (lookupGblST userFun (gblST st)) - -- This is very verbose, and obscures error messages, but can be useful for typos. - -- Perhaps think about printing it only for failed lookup of userfun - -- , text "ST keys:" <+> gblDoc st - ] - --- gblDoc :: SymTab -> SDoc --- gblDoc st = vcat (map (text . show) (Map.keys (gblST st))) + , text " Arg types:" <+> ppr (typeof arg) + , text " Args:" <+> ppr (exprOf arg) ] + + + +lookupFunTc fun@(Fun ds (BaseUserFun (BaseUserFunId fun_name _))) arg + = do { let arg_ty = typeof arg + + ; base_ty <- findBaseFunArgTy fun arg_ty + ; def <- lookupFunDef (Fun ds fun_name) base_ty + + -- Check that arg_ty is acceptable, and return the instantiation + ; checkArgTy def arg_ty } + + +checkArgTy :: TDef -> Type -> TcM (TFun Typed) +-- Check to see if an argument of type arg_ty is acceptable +-- the supplied function. If so, return the instantiation +-- of the function's type parameters, and the (instantiated) type +-- of its result. If not, fail. +checkArgTy (Def { def_fun = Fun der user_fun, def_qvars = qvars + , def_pat = pat, def_res_ty = res_ty }) + arg_ty + | Just subst <- matchTy (typeof pat) arg_ty + = return (TFun { tf_fun = Fun der (BaseUserFun user_fun) + , tf_targs = map (tySubstTyVar subst) qvars + , tf_ret = tySubstTy subst res_ty }) + | otherwise + = tcFail (hang (text "Type mis-match in argument" + <+> text "of call to" <+> ppr user_fun) + 2 (vcat [ text "Expected:" <+> ppr pat_ty + , text "Actual: " <+> ppr arg_ty ])) + where + pat_ty = typeof pat + + +findBaseFunArgTy :: Fun Parsed -> Type -> TcM Type +-- Turns the argument type of the (perhaps derived) function +-- into the argument type of the /base/ function +-- The input (Fun Parsed) may, optionally, have a base-type supplied +-- in the source code; if not, we try to work it out from argument type +findBaseFunArgTy fun arg_ty + = case baseFunArgTy_maybe fun arg_ty of + Left err -> tcFail err + Right Nothing -> tcFail (text "No type was supplied and I couldn't deduce it from the argument type") + Right (Just base_ty) -> return base_ty + +lookupFunDef :: DerivedFun BaseUserFunName + -> Type + -> TcM TDef +lookupFunDef fn base_arg_ty + = do { st <- getSymTabTc + ; case lookupBindings fn (gblST st) of { + Nothing -> fail "no-bind" ; + Just bindings -> + + case bindings of { + PolyBind def -> return def ; + MonoBinds binds -> + + case Map.lookup base_arg_ty binds of + Just def -> return def + Nothing -> fail "mono-bind" }}} + where + fail s = tcFail $ + vcat [ text "Not in scope: userCall:" <+> ppr (fmap text fn) + , text s ] + + +userCallDef_maybe :: HasCallStack + => UserFun Typed -> GblSymTab -> Either SDoc TDef +userCallDef_maybe fn env + = case lookupDef fn env of + Just def -> Right def + Nothing -> Left (text "Not in scope: userCall:" + <+> ppr_fn $$ message) + where + message = if null similarEnv + then empty + else text "Did you mean one of these:" + $$ vcat similarEnv + similarEnv = (map ppr + . filter similar + . Map.keys) env + similar envfn = editDistance (render ppr_fn) (render (ppr envfn)) + <= configEditDistanceThreshold + editDistance = E.levenshteinDistance E.defaultEditCosts + + ppr_fn = pprUserFun @Typed fn + lclDoc :: SymTab -> SDoc lclDoc st = vcat (map (text . show) (Map.keys (lclST st))) + + +{- +userCallResultTy_help :: HasCallStack + => TDef -> Type -> Either SDoc Type +userCallResultTy_help (Def { def_fun = fn + , def_res_ty = ret_ty + , def_pat = pat }) + args + = check_args (patType pat) (typeof args) + where + check_args :: Type -> Type -> Either SDoc Type + check_args bndr_ty arg_ty + | bndr_ty `compatibleType` arg_ty + = Right ret_ty + | otherwise + = Left (hang (text "Type mis-match in argument" + <+> text "of call to" <+> ppr_fn) + 2 (vcat [ text "Expected:" <+> ppr bndr_ty + , text "Actual: " <+> ppr arg_ty ])) + where ppr_fn = pprUserFun @Typed fn +-} + diff --git a/src/ksc/Cgen.hs b/src/ksc/Cgen.hs index e081831fa..fac44df7f 100644 --- a/src/ksc/Cgen.hs +++ b/src/ksc/Cgen.hs @@ -185,6 +185,7 @@ allocatorUsageOfType = \case TypeLam {} -> UsesAllocator TypeLM {} -> UsesAllocator TypeUnknown -> error "Shouldn't see TypeUnknown at this stage of codegen" + TypeVar {} -> error "Shouldn't see TypeVar at this stage of codegen" allocatorUsageOfCType :: CType -> AllocatorUsage allocatorUsageOfCType = \case @@ -400,7 +401,7 @@ cgenExprWithoutResettingAlloc env = \case -- Special case for copydown. Mark the allocator before evaluating the -- expression, then copydown the result to the marked position. - Call (TFun _ (Fun JustFun (PrimFun P_copydown))) e -> do + Call (TFun { tf_fun = Fun JustFun (PrimFun P_copydown) }) e -> do CG cdecl cexpr ctype _callocusage <- cgenExprR env e ret <- freshCVar bumpmark <- freshCVar @@ -438,7 +439,7 @@ cgenExprWithoutResettingAlloc env = \case cftype (funAllocatorUsage tf cftype <> callocusage) - Call tf@(TFun _ fun) vs -> do + Call tf@(TFun { tf_fun = fun }) vs -> do cgvs <- cgenExprR env vs let cgargtype = typeof vs let cdecls = getDecl cgvs @@ -602,12 +603,16 @@ mangleType = \case TypeLam a b -> "l<" ++ mangleType a ++ mangleType b ++ ">" TypeLM _ _ -> error "Can't mangle TypeLM" TypeUnknown -> error "Can't mangle TypeUnknown" + TypeVar {} -> error "Can't mangle TypeVar" cgenBaseFun :: BaseFun Typed -> String cgenBaseFun = \case - (BaseUserFun (BaseUserFunId fun (TypeTuple []))) -> mangleFun fun - (BaseUserFun (BaseUserFunId fun (TypeTuple tys))) -> mangleFun (fun ++ "@" ++ concatMap mangleType tys) - (BaseUserFun (BaseUserFunId fun ty)) -> mangleFun (fun ++ "@" ++ mangleType ty) + (BaseUserFun (BaseUserFunId fun Poly)) -> pprPanic "cgenBaseFun:Poly" (ppr fun) + -- ToDo + (BaseUserFun (BaseUserFunId fun (Mono (TypeTuple [])))) -> mangleFun fun + (BaseUserFun (BaseUserFunId fun (Mono (TypeTuple tys)))) -> mangleFun (fun ++ "@" ++ concatMap mangleType tys) + (BaseUserFun (BaseUserFunId fun (Mono ty))) -> mangleFun (fun ++ "@" ++ mangleType ty) + (PrimFun (P_SelFun i _)) -> "ks::get<" ++ show (i - 1) ++ ">" (PrimFun fun) -> render (ppr fun) @@ -627,17 +632,17 @@ cgenUserFun f = case f of cgenAnyFun :: HasCallStack => TFun Typed -> CType -> String cgenAnyFun tf cftype = case tf of - TFun _ (Fun JustFun (PrimFun P_lmApply)) -> "lmApply" - TFun retty (Fun JustFun (PrimFun P_build)) -> + TFun { tf_fun = Fun JustFun (PrimFun P_lmApply) } -> "lmApply" + TFun { tf_fun = Fun JustFun (PrimFun P_build), tf_ret = retty } -> case retty of TypeTensor _ t -> "build<" ++ cgenType (mkCType t) ++ ">" _ -> error ("Unexpected return type for build: " ++ show retty) - TFun retty (Fun JustFun (PrimFun primname)) + TFun { tf_ret = retty, tf_fun = Fun JustFun (PrimFun primname) } | primname `elem` [P_sumbuild, P_buildFromSparse, P_buildFromSparseTupled] -> render (ppr primname) ++ "<" ++ cgenType (mkCType retty) ++ ">" -- This is one of the LM subtypes, e.g. HCat<...> Name is just HCat<...>::mk - TFun (TypeLM _ _) (Fun JustFun (PrimFun _)) -> cgenType cftype ++ "::mk" - TFun _ f -> cgenUserFun f + TFun { tf_ret = TypeLM {}, tf_fun = Fun JustFun (PrimFun {}) } -> cgenType cftype ++ "::mk" + TFun { tf_fun = f } -> cgenUserFun f {- Note [Allocator usage of function calls] @@ -666,8 +671,8 @@ are two cases: -} funUsesAllocator :: HasCallStack => TFun p -> Bool -funUsesAllocator (TFun _ (Fun JustFun (PrimFun (P_SelFun _ _)))) = False -funUsesAllocator (TFun _ (Fun JustFun (PrimFun fname))) = +funUsesAllocator (TFun { tf_fun = Fun JustFun (PrimFun (P_SelFun _ _)) }) = False +funUsesAllocator (TFun { tf_fun = Fun JustFun (PrimFun fname) }) = not $ fname `elem` [P_index, P_size, P_eq, P_ne, P_trace, P_print, P_ts_dot] funUsesAllocator _ = True @@ -707,13 +712,14 @@ cgenTypeLang = \case TypeTuple ts -> "tuple<" ++ intercalate "," (map cgenTypeLang ts) ++ ">" TypeTensor d t -> "tensor<" ++ show d ++ ", " ++ cgenTypeLang t ++ ">" TypeBool -> "bool" + TypeVar {} -> "void*" -- SPJ: guessing here TypeUnknown -> "void" TypeLam from to -> "std::function<" ++ cgenTypeLang to ++ "(" ++ cgenTypeLang from ++ ")>" TypeLM s t -> error $ "LM<" ++ cgenTypeLang s ++ "," ++ cgenTypeLang t ++ ">" ctypeofFun :: HasCallStack => CST -> TFun Typed -> [CType] -> CType -ctypeofFun env (TFun ty f) ctys +ctypeofFun env (TFun { tf_ret = ty, tf_fun = f }) ctys | Just f' <- maybeUserFun f , Just ret_ty <- cstMaybeLookupFun f' env -- trace ("Found fun " ++ show f) $ @@ -815,7 +821,7 @@ cppGen defs = isMainFunction :: TDef -> Bool isMainFunction Def{ def_fun = Fun JustFun f, def_res_ty = TypeInteger } - | BaseUserFunId "main" (TypeTuple []) <- f = True + | BaseUserFunId "main" (Mono (TypeTuple [])) <- f = True isMainFunction _ = False ksoGen :: [TDef] -> String diff --git a/src/ksc/Ksc/CatLang.hs b/src/ksc/Ksc/CatLang.hs index 321989a4c..0b2e255f9 100644 --- a/src/ksc/Ksc/CatLang.hs +++ b/src/ksc/Ksc/CatLang.hs @@ -14,7 +14,7 @@ data CLExpr = CLId | CLPrune [Int] Int CLExpr -- ? | CLKonst Konst - | CLCall Type (Fun Typed) -- The Type is the result type + | CLCall (TFun Typed) | CLComp CLExpr CLExpr -- Composition | CLTuple [CLExpr] -- Tuple | CLIf CLExpr CLExpr CLExpr -- If @@ -24,6 +24,7 @@ data CLExpr -- ^ Fold (Lam $t body) $acc $vector data CLDef = CLDef { cldef_fun :: BaseUserFun Typed + , cldef_qvars :: [TyVar] , cldef_arg :: Pat -- Arg type S , cldef_rhs :: CLExpr , cldef_res_ty :: Type } -- Result type T @@ -66,7 +67,7 @@ pprCLExpr p c@(CLComp {}) pprCLExpr _ CLId = text "Id" pprCLExpr _ (CLKonst k) = ppr k -pprCLExpr _ (CLCall _ f) = ppr f +pprCLExpr _ (CLCall f) = ppr f pprCLExpr p (CLPrune ts n c) = parensIf p precOne $ sep [ text "Prune" <> char '[' <> cat (punctuate comma (map ppr ts)) @@ -107,12 +108,14 @@ toCLDefs defs = mapMaybe toCLDef_maybe defs toCLDef_maybe :: TDef -> Maybe CLDef toCLDef_maybe (Def { def_fun = fun + , def_qvars = qvs , def_pat = pat , def_res_ty = res_ty , def_rhs = rhs }) | Fun JustFun f <- fun , UserRhs e <- rhs = Just CLDef { cldef_fun = f + , cldef_qvars = qvs , cldef_arg = pat , cldef_rhs = toCLExpr (patVars pat) e , cldef_res_ty = res_ty } @@ -177,8 +180,8 @@ to_cl_call pruned env f e Pruned -> CLFold t (toCLExpr (t:env) body) (toCLExpr env acc) (toCLExpr env v) - | TFun ty fun_id <- f - = CLCall ty fun_id `mkCLComp` to_cl_expr pruned env e + | otherwise + = CLCall f `mkCLComp` to_cl_expr pruned env e where call = Call f e @@ -223,10 +226,12 @@ fromCLDefs cldefs = map fromCLDef cldefs fromCLDef :: CLDef -> TDef fromCLDef (CLDef { cldef_fun = f + , cldef_qvars = qvs , cldef_arg = pat , cldef_rhs = rhs , cldef_res_ty = res_ty }) = Def { def_fun = Fun JustFun f + , def_qvars = qvs , def_pat = pat , def_res_ty = res_ty , def_rhs = UserRhs rhs' } @@ -247,12 +252,12 @@ fromCLExpr is arg (CLIf b t e) = If (fromCLExpr is arg b) (fromCLExpr is arg t) (fromCLExpr is arg e) -fromCLExpr _ arg (CLCall ty f) - = Call (TFun ty f) (mkTuple arg) +fromCLExpr _ arg (CLCall f) + = Call f (mkTuple arg) fromCLExpr is arg (CLComp e1 e2) - | CLCall ty f <- e1 -- Shortcut to avoid an unnecessary let - = Call (TFun ty f) (fromCLExpr is arg e2) + | CLCall f <- e1 -- Shortcut to avoid an unnecessary let + = Call f (fromCLExpr is arg e2) | otherwise = mkTempLet is "ax" (fromCLExpr is arg e2) $ \ is v2 -> fromCLExpr is [v2] e1 diff --git a/src/ksc/Ksc/Futhark.hs b/src/ksc/Ksc/Futhark.hs index 658904edb..f33bdf4b1 100644 --- a/src/ksc/Ksc/Futhark.hs +++ b/src/ksc/Ksc/Futhark.hs @@ -10,6 +10,7 @@ import Prelude hiding ( (<>) ) import qualified Cgen import qualified Lang as L +import Lang( TFun(..), DefX(..) ) import Lang (Pretty(..), text, render, empty, parensIf, (<>), (<+>), ($$), parens, brackets, punctuate, sep, integer, double, comma, PrimFun(..)) @@ -363,32 +364,34 @@ callPrimFun f _ args = -- case-by-case basis. toCall :: L.InPhase p => L.TFun p -> L.TExpr -> Exp -toCall (L.TFun _ (L.Fun L.JustFun (L.PrimFun (L.P_SelFun f _)))) e = +toCall (L.TFun { tf_fun = L.Fun L.JustFun (L.PrimFun (L.P_SelFun f _)) }) e = Project (toFutharkExp e) $ show f -toCall (L.TFun ret (L.Fun L.JustFun (L.PrimFun f))) args = +toCall (L.TFun { tf_ret = ret, tf_fun = L.Fun L.JustFun (L.PrimFun f) }) args = callPrimFun f ret args -toCall f@(L.TFun _ (L.Fun L.JustFun L.BaseUserFun{})) args = +toCall f@(L.TFun { tf_fun = L.Fun L.JustFun L.BaseUserFun{} }) args = Call (Var (toTypedName f (L.typeof args))) [toFutharkExp args] -toCall f@(L.TFun _ (L.Fun L.GradFun{} _)) args = +toCall f@(L.TFun { tf_fun = L.Fun L.GradFun{} _ }) args = Call (Var (toTypedName f (L.typeof args))) [toFutharkExp args] -toCall f@(L.TFun _ (L.Fun L.DrvFun{} _)) args = +toCall f@(L.TFun { tf_fun = L.Fun L.DrvFun{} _ }) args = Call (Var (toTypedName f (L.typeof args))) [toFutharkExp args] -toCall f@(L.TFun _ (L.Fun L.ShapeFun{} _)) args = +toCall f@(L.TFun { tf_fun = L.Fun L.ShapeFun{} _ }) args = Call (Var (toTypedName f (L.typeof args))) [toFutharkExp args] -toCall f@(L.TFun _ (L.Fun L.CLFun{} _)) args = +toCall f@(L.TFun { tf_fun = L.Fun L.CLFun{} _ }) args = Call (Var (toTypedName f (L.typeof args))) [toFutharkExp args] toCall _ _ = error "Unsupported Futhark call" toFuthark :: L.TDef -> Def toFuthark d = case LU.noTupPatifyDef d of { - L.Def f (L.VarPat args) res_ty (L.UserRhs e) -> + L.Def { def_fun = f, def_pat = L.VarPat args + , def_res_ty = res_ty + , def_rhs = L.UserRhs e } -> -- SLPJ: Ignoring def_qvars for now DefFun entry fname [] [param] res_ty' (toFutharkExp e) where fname = toTypedName f (L.typeof args) diff --git a/src/ksc/Ksc/Pipeline.hs b/src/ksc/Ksc/Pipeline.hs index 092069e5a..7de096f80 100644 --- a/src/ksc/Ksc/Pipeline.hs +++ b/src/ksc/Ksc/Pipeline.hs @@ -18,7 +18,7 @@ import Lang (Decl, DeclX(DefDecl), DerivedFun(Fun), Derivations(JustFun), def_fun, displayN, partitionDecls, ppr, renderSexp, (<+>)) import qualified Lang as L -import LangUtils (GblSymTab, emptyGblST, extendGblST, stInsertFun) +import LangUtils (GblSymTab, emptyGblST, extendGblST, stInsertFun, lookupDef) import qualified Ksc.Futhark import Parse (parseF) import Rules (mkRuleBase) @@ -29,7 +29,6 @@ import Ksc.SUF.AD (sufFwdRevPassDef, sufRevDef) import Data.Maybe (maybeToList) import Data.List (intercalate) -import qualified Data.Map as Map import GHC.Stack (HasCallStack) ------------------------------------- @@ -172,7 +171,8 @@ more complicated typechecking story. deriveDecl :: GblSymTab -> L.TDecl -> KM (GblSymTab, [L.TDecl]) deriveDecl = deriveDeclUsing $ \env (L.GDef derivation fun) -> do - { let tdef = case Map.lookup fun env of + { -- fun :: UserFun Typed + let tdef = case lookupDef fun env of Nothing -> error $ unwords [ "Internal bug. Error when attempting to gdef." , "TODO: This ought to have been caught by type checking." ] diff --git a/src/ksc/Ksc/SUF/AD.hs b/src/ksc/Ksc/SUF/AD.hs index b05badd66..7e626dbc4 100644 --- a/src/ksc/Ksc/SUF/AD.hs +++ b/src/ksc/Ksc/SUF/AD.hs @@ -10,7 +10,7 @@ import Ksc.Traversal (traverseState) import Lang import LangUtils (notInScopeTVs, stInsertFun, GblSymTab, - freeVarsOf, lookupGblST) + freeVarsOf, lookupDef) import OptLet (Subst, substVar, notInSubstTVs, substBndr, mkEmptySubst) import Prim @@ -127,7 +127,8 @@ sufFwdRevPass gst subst = \case -- REV{f e} dt b = { (b_e, b_f) = b -- ; da = [sufrevpass f] dt b_f } -- ++ REV{e} da b_e - Call (TFun res_ty f) e -> (gradf_call_f_e, typeof bog, sufRevPass_) + Call tfun@(TFun { tf_ret = res_ty }) e + -> (gradf_call_f_e, typeof bog, sufRevPass_) where (subst2, L4 arg b_e r b_f) = notInSubstTVs subst (L4 (TVar typeof_e (Simple "arg")) (TVar e_bog_ty (Simple "b_e")) @@ -136,7 +137,7 @@ sufFwdRevPass gst subst = \case (fwdpass_e, e_bog_ty, revpass_e) = sufFwdRevPass gst subst2 e - (fwdpass_f, revpass_f, f_bog_ty) = mkSufFuns gst f typeof_e + (fwdpass_f, revpass_f, f_bog_ty) = mkSufFuns gst tfun typeof_e bog = mkBog [b_e, b_f] @@ -421,12 +422,14 @@ sufFwdRevPassDefs gst__ = catMaybes . concat . snd . sufRevPassDefsMaybe gst__ -- revpass$f : (dT, B{f}) -> dS sufFwdRevPassDef :: GblSymTab -> TDef -> (Maybe TDef, Maybe TDef, GblSymTab) sufFwdRevPassDef gst Def{ def_fun = Fun JustFun f + , def_qvars = qvars , def_pat = s , def_rhs = UserRhs rhs , def_res_ty = t_ty } = let fwd = Def { def_fun = Fun SUFFwdPass f + , def_qvars = qvars , def_pat = s , def_rhs = UserRhs rhs' , def_res_ty = TypeTuple [t_ty, bog_ty] @@ -443,6 +446,7 @@ sufFwdRevPassDef gst Def{ def_fun = Fun JustFun f lets = foldr (\(p, er) rest -> Let p er . rest) id lets_ rev = Def { def_fun = Fun SUFRevPass f + , def_qvars = qvars , def_pat = TupPat [ dt, bog ] , def_rhs = UserRhs rhs'' , def_res_ty = ds_ty @@ -466,6 +470,7 @@ sufFwdRevPassDef gst _ = (Nothing, Nothing, gst) sufRevDef :: GblSymTab -> TDef -> Maybe TDef sufRevDef gst Def{ def_fun = fun@(Fun JustFun f) + , def_qvars = qvars , def_pat = s -- We don't actually use the rhs. We just look up -- the fwdpass and revpass in the GblSymTab. @@ -473,6 +478,7 @@ sufRevDef gst Def{ def_fun = fun@(Fun JustFun f) , def_res_ty = t_ty } = Just $ Def { def_fun = Fun SUFRev f + , def_qvars = qvars , def_pat = TupPat [s_var, dt] , def_rhs = UserRhs rhs' , def_res_ty = d s_ty @@ -485,8 +491,10 @@ sufRevDef gst Def{ def_fun = fun@(Fun JustFun f) ds = TVar (d s_ty) (Simple "dsres") bog_var = TVar bog_ty (Simple "bog") - (fwdPass, revPass, bog_ty) = - mkSufFuns gst (userFunToFun fun) (typeof s) + tfun = TFun { tf_fun = userFunToFun fun + , tf_targs = map TypeVar qvars + , tf_ret = t_ty } + (fwdPass, revPass, bog_ty) = mkSufFuns gst tfun (typeof s) rhs' = Let (TupPat [t, bog_var]) (pInline (Call fwdPass (Var s_var))) $ Let (VarPat ds) (pInline (Call revPass (Tuple [Var dt, Var bog_var]))) @@ -504,7 +512,7 @@ deltaOfSimple = \case callResultTy :: GblSymTab -> Fun Typed -> Type -> Either SDoc Type callResultTy env fun arg_ty = case perhapsUserFun fun of - Right user -> case lookupGblST user env of + Right user -> case lookupDef user env of Just f -> pure (def_res_ty f) Nothing -> Left (text "Not in scope" <+> ppr user) Left prim -> primCallResultTy_maybe prim arg_ty @@ -520,12 +528,14 @@ callResultTy env fun arg_ty = case perhapsUserFun fun of -- mkSufFuns env f S T -- -- returns (suffwdpass$f, sufrevpass$f, B) -mkSufFuns :: GblSymTab -> Fun Typed -> Type +mkSufFuns :: GblSymTab -> TFun Typed -> Type -> (TFun Typed, TFun Typed, Type) -mkSufFuns env fun arg_ty = (fwdTFun, revTFun, bog_ty) - where fwdTFun = TFun fwd_res_ty fwdFun - revTFun = TFun rev_res_ty (Fun SUFRevPass funid) - fwdFun = Fun SUFFwdPass funid +mkSufFuns env (TFun { tf_fun = fun, tf_targs = targs })arg_ty + = (fwdTFun, revTFun, bog_ty) + where fwdFun = Fun SUFFwdPass funid + revFun = Fun SUFRevPass funid + fwdTFun = TFun { tf_ret = fwd_res_ty, tf_fun = fwdFun, tf_targs = targs } + revTFun = TFun { tf_ret = rev_res_ty, tf_fun = revFun, tf_targs = targs } fwd_res_ty = case callResultTy env fwdFun arg_ty of Right res_ty' -> res_ty' diff --git a/src/ksc/Lang.hs b/src/ksc/Lang.hs index 51e1d22fd..cf46e4b12 100644 --- a/src/ksc/Lang.hs +++ b/src/ksc/Lang.hs @@ -52,6 +52,7 @@ type TDecl = DeclX Typed data DefX p -- f x = e = Def { def_fun :: UserFun p + , def_qvars :: [TyVar] , def_pat :: Pat -- See Note [Function arity] , def_res_ty :: TypeX -- Result type , def_rhs :: RhsX p } @@ -210,8 +211,11 @@ unzipTEs (TE e t : tes) = (e:es, t:ts) where (es, ts) = unzipTEs tes +type TyVar = String + data TypeX - = TypeBool + = TypeVar TyVar + | TypeBool | TypeInteger | TypeFloat | TypeString @@ -269,6 +273,21 @@ zeroIndexForDimension 1 = kInt 0 zeroIndexForDimension d = mkTuple (replicate d (kInt 0)) +---------------------------------- +--- A monotype has no variables in it + +isMonoType :: Type -> Bool +isMonoType TypeBool = True +isMonoType TypeInteger = True +isMonoType TypeFloat = True +isMonoType TypeString = True +isMonoType TypeUnknown = True +isMonoType (TypeVar {}) = False +isMonoType (TypeTuple ts) = all isMonoType ts +isMonoType (TypeTensor _ t) = isMonoType t +isMonoType (TypeLam t1 t2) = isMonoType t1 && isMonoType t2 +isMonoType (TypeLM t1 t2) = isMonoType t1 && isMonoType t2 + ---------------------------------- --- Tangent space @@ -295,9 +314,10 @@ shapeType TypeFloat = TypeTuple [] shapeType TypeString = TypeTuple [] shapeType (TypeTuple ts) = TypeTuple (map shapeType ts) shapeType (TypeTensor d vt) = TypeTensor d (shapeType vt) -shapeType (TypeLam _ _) = TypeUnknown -shapeType (TypeLM _ _) = TypeUnknown -- TBD -shapeType TypeUnknown = TypeUnknown +shapeType (TypeLam {}) = TypeUnknown +shapeType (TypeLM {}) = TypeUnknown -- TBD +shapeType (TypeVar {}) = TypeUnknown +shapeType TypeUnknown = TypeUnknown {- Note [Shapes] ~~~~~~~~~~~~~~~~ @@ -401,16 +421,21 @@ types. -} -data BaseUserFun p = BaseUserFunId String (BaseUserFunArgTy p) +type BaseUserFunName = String + +data BaseUserFun p = BaseUserFunId BaseUserFunName (BaseUserFunArgTy p) deriving instance Eq (BaseUserFunArgTy p) => Eq (BaseUserFun p) deriving instance Ord (BaseUserFunArgTy p) => Ord (BaseUserFun p) deriving instance Show (BaseUserFunArgTy p) => Show (BaseUserFun p) type family BaseUserFunArgTy p where - BaseUserFunArgTy Parsed = Maybe Type - BaseUserFunArgTy OccAnald = Type - BaseUserFunArgTy Typed = Type + BaseUserFunArgTy Parsed = Maybe ArgTypeDescriptor + BaseUserFunArgTy OccAnald = ArgTypeDescriptor --- was: Type + BaseUserFunArgTy Typed = ArgTypeDescriptor + +data ArgTypeDescriptor = Mono Type | Poly + deriving( Show, Eq, Ord ) data BaseFun (p :: Phase) = BaseUserFun (BaseUserFun p) -- BaseUserFuns have a Def @@ -433,7 +458,7 @@ data Derivations deriving (Eq, Ord, Show) data DerivedFun funid = Fun Derivations funid - deriving (Eq, Ord, Show) + deriving (Eq, Ord, Show, Functor) -- DerivedFun has just two instantiations -- @@ -461,7 +486,7 @@ baseFunFun f (Fun ds fi) = fmap (Fun ds) (f fi) userFunBaseType :: forall p. InPhase p => T.Lens (UserFun p) (UserFun Typed) - (Maybe Type) Type + (Maybe ArgTypeDescriptor) ArgTypeDescriptor userFunBaseType = baseFunFun . baseUserFunType @p funType :: T.Traversal (Fun p) (Fun q) @@ -472,17 +497,23 @@ funType = baseFunFun . baseUserFunBaseFun . baseUserFunT -- otherwise (and in other phases, where the type is there) check that -- the type matches. If mis-match return (Left -- type-that-was-in-UserFun) -addBaseTypeToUserFun :: forall p. InPhase p - => UserFun p -> Type -> Either Type (UserFun Typed) -addBaseTypeToUserFun userfun expectedBaseTy = T.traverseOf (userFunBaseType @p) checkBaseType userfun - where checkBaseType :: Maybe Type -> Either Type Type +addArgTypeDescriptor :: forall p. InPhase p + => UserFun p -> ArgTypeDescriptor + -> Either ArgTypeDescriptor (UserFun Typed) +addArgTypeDescriptor userfun expectedBaseTy + = T.traverseOf (userFunBaseType @p) checkBaseType userfun + where checkBaseType :: Maybe ArgTypeDescriptor -> Either ArgTypeDescriptor ArgTypeDescriptor checkBaseType maybeAppliedType | Just appliedTy <- maybeAppliedType - , not (eqType appliedTy expectedBaseTy) + , not (eqArgTypeDescriptor appliedTy expectedBaseTy) = Left appliedTy | otherwise = Right expectedBaseTy +eqArgTypeDescriptor :: ArgTypeDescriptor -> ArgTypeDescriptor -> Bool +eqArgTypeDescriptor Poly Poly = True +eqArgTypeDescriptor (Mono t1) (Mono t2) = t1 `eqType` t2 +eqArgTypeDescriptor _ _ = False userFunToFun :: UserFun p -> Fun p userFunToFun = T.over baseFunFun BaseUserFun @@ -502,8 +533,8 @@ maybeUserFun f = case perhapsUserFun f of baseFunToBaseUserFunE :: BaseFun p -> Either PrimFun (BaseUserFun p) baseFunToBaseUserFunE = \case - BaseUserFun u -> Right u - PrimFun p -> Left p + BaseUserFun u -> Right u + PrimFun p -> Left p baseFunToBaseUserFun :: BaseFun p -> Maybe (BaseUserFun p) baseFunToBaseUserFun f = case baseFunToBaseUserFunE f of @@ -531,7 +562,19 @@ data ADPlan = BasicAD | TupleAD data ADDir = Fwd | Rev deriving( Eq, Ord, Show ) -data TFun p = TFun Type (Fun p) -- Typed functions. The type is the /return/ +data TFun p -- Typed functions. The type is the /return/ + = TFun { tf_fun :: Fun p -- The function + , tf_targs :: [Type] -- Type arguments (always same length as the + -- def_qvars in its Def) + , tf_ret :: Type } -- Return type; can always be computed from + -- the function plus instantiating types, but + -- it is convenient to cache it + +tFunResTy :: TFun p -> Type +tFunResTy = tf_ret + +tFunFun :: TFun p -> Fun p +tFunFun = tf_fun deriving instance Eq (Fun p) => Eq (TFun p) deriving instance Ord (Fun p) => Ord (TFun p) @@ -540,7 +583,7 @@ deriving instance Ord (Fun p) => Ord (TFun p) -- GHC's machinery to allow that. coerceTFun :: BaseUserFunArgTy p ~ BaseUserFunArgTy q => TFun p -> TFun q -coerceTFun (TFun t f) = TFun t (T.over funType id f) +coerceTFun fun@(TFun { tf_fun = f }) = fun { tf_fun = T.over funType id f } data Var @@ -652,10 +695,12 @@ dropLast :: [a] -> [a] dropLast xs = take (length xs - 1) xs pSel :: Int -> Int -> TExpr -> TExpr -pSel i n e = Call (TFun el_ty - (Fun JustFun (PrimFun (P_SelFun i n)))) e +pSel i n e = Call tfun e where - el_ty = case typeof e of + tfun = TFun { tf_ret = el_ty + , tf_fun = Fun JustFun (PrimFun (P_SelFun i n)) + , tf_targs = [] } -- Dubious + el_ty = case typeof e of TypeTuple ts -> ts !! (i-1) _ -> TypeUnknown -- Better error from Lint @@ -680,7 +725,7 @@ instance HasType TypedExpr where typeof (TE _ ty) = ty instance HasType (TFun p) where - typeof (TFun ty _) = ty + typeof (TFun { tf_ret = ty }) = ty instance HasType TExpr where typeof (Dummy ty) = ty @@ -881,7 +926,7 @@ class InPhase p where getLetBndr :: LetBndrX p -> (Var, Maybe Type) baseUserFunType :: T.Lens (BaseUserFun p) (BaseUserFun Typed) - (Maybe Type) Type + (Maybe ArgTypeDescriptor) ArgTypeDescriptor instance InPhase Parsed where pprVar = ppr @@ -889,7 +934,7 @@ instance InPhase Parsed where pprFunOcc = ppr pprBaseUserFun (BaseUserFunId name mty) = case mty of Nothing -> text name - Just ty -> brackets (text name <+> pprParendType ty) + Just atd -> brackets (text name <+> ppr atd) getVar var = (var, Nothing) getFun fun = (fun, Nothing) @@ -901,11 +946,11 @@ instance InPhase Typed where pprVar = ppr pprLetBndr = pprTVar pprFunOcc = ppr - pprBaseUserFun (BaseUserFunId name ty) = - brackets (text name <+> pprParendType ty) + pprBaseUserFun (BaseUserFunId name atd) = + brackets (text name <+> ppr atd) getVar (TVar ty var) = (var, Just ty) - getFun (TFun ty fun) = (fun', Just ty) + getFun (TFun { tf_fun = fun, tf_ret = ty }) = (fun', Just ty) where fun' = T.over funType Just fun getLetBndr (TVar ty var) = (var, Just ty) @@ -915,18 +960,22 @@ instance InPhase OccAnald where pprVar = ppr pprLetBndr (n,tv) = pprTVar tv <> braces (int n) pprFunOcc = ppr - pprBaseUserFun (BaseUserFunId name ty) = - brackets (text name <+> pprParendType ty) + pprBaseUserFun (BaseUserFunId name atd) = + brackets (text name <+> ppr atd) getVar (TVar ty var) = (var, Just ty) - getFun (TFun ty fun) = (fun', Just ty) + getFun (TFun { tf_fun = fun, tf_ret = ty }) = (fun', Just ty) where fun' = T.over funType Just fun getLetBndr (_, TVar ty var) = (var, Just ty) baseUserFunType g (BaseUserFunId f t) = fmap (BaseUserFunId f) (g (Just t)) pprTFun :: InPhase p => TFun p -> SDoc -pprTFun (TFun ty f) = ppr f <+> text ":" <+> ppr ty +pprTFun (TFun { tf_fun = f, tf_targs = targs, tf_ret = ty }) + = ppr f <> pp_targs <+> text ":" <+> ppr ty + where + pp_targs | null targs = empty + | otherwise = char '@' <> parens (pprList ppr targs) class Pretty a where @@ -972,8 +1021,11 @@ instance Pretty Var where instance InPhase p => Pretty (BaseFun p) where ppr = pprBaseFun +instance Pretty SDoc where + ppr d = d + instance Pretty funid => Pretty (DerivedFun funid) where - ppr = pprDerivedFun ppr + ppr fun = pprDerivedFun (fmap ppr fun) instance Pretty PrimFun where ppr = pprPrimFun @@ -1036,18 +1088,18 @@ pprPrimFun = \case P_elim -> text "elim" pprUserFun :: forall p. InPhase p => UserFun p -> SDoc -pprUserFun = pprDerivedFun (pprBaseUserFun @p) - -pprDerivedFun :: (funid -> SDoc) -> DerivedFun funid -> SDoc -pprDerivedFun f (Fun JustFun s) = f s -pprDerivedFun f (Fun (GradFun adp) s) = brackets (char 'D' <> ppr adp <+> f s) -pprDerivedFun f (Fun (DrvFun (AD adp Fwd)) s) = brackets (text "fwd" <> ppr adp <+> f s) -pprDerivedFun f (Fun (DrvFun (AD adp Rev)) s) = brackets (text "rev" <> ppr adp <+> f s) -pprDerivedFun f (Fun (ShapeFun ds) sf) = brackets (text "shape" <+> pprDerivedFun f (Fun ds sf)) -pprDerivedFun f (Fun CLFun s) = brackets (text "CL" <+> f s) -pprDerivedFun f (Fun SUFFwdPass s) = brackets (text "suffwdpass" <+> f s) -pprDerivedFun f (Fun SUFRevPass s) = brackets (text "sufrevpass" <+> f s) -pprDerivedFun f (Fun SUFRev s) = brackets (text "sufrev" <+> f s) +pprUserFun fun = pprDerivedFun (fmap (pprBaseUserFun @p) fun) + +pprDerivedFun :: DerivedFun SDoc -> SDoc +pprDerivedFun (Fun JustFun s) = s +pprDerivedFun (Fun (GradFun adp) s) = brackets (char 'D' <> ppr adp <+> s) +pprDerivedFun (Fun (DrvFun (AD adp Fwd)) s) = brackets (text "fwd" <> ppr adp <+> s) +pprDerivedFun (Fun (DrvFun (AD adp Rev)) s) = brackets (text "rev" <> ppr adp <+> s) +pprDerivedFun (Fun (ShapeFun ds) sf) = brackets (text "shape" <+> pprDerivedFun (Fun ds sf)) +pprDerivedFun (Fun CLFun s) = brackets (text "CL" <+> s) +pprDerivedFun (Fun SUFFwdPass s) = brackets (text "suffwdpass" <+> s) +pprDerivedFun (Fun SUFRevPass s) = brackets (text "sufrevpass" <+> s) +pprDerivedFun (Fun SUFRev s) = brackets (text "sufrev" <+> s) instance Pretty Pat where pprPrec _ p = pprPat True p @@ -1056,7 +1108,7 @@ instance Pretty TVar where pprPrec _ (TVar _ v) = ppr v instance InPhase p => Pretty (TFun p) where - ppr (TFun _ f) = ppr f + ppr (TFun { tf_fun = f}) = ppr f instance Pretty Konst where pprPrec _ (KInteger i) = integer i @@ -1074,6 +1126,7 @@ instance Pretty TypeX where pprPrec p (TypeLam from to) = parensIf p precZero $ text "Lam" <+> ppr from <+> ppr to pprPrec p (TypeLM s t) = parensIf p precTyApp $ text "LM" <+> pprParendType s <+> pprParendType t + pprPrec _ (TypeVar tv) = text tv pprPrec _ TypeFloat = text "Float" pprPrec _ TypeInteger = text "Integer" pprPrec _ TypeString = text "String" @@ -1214,6 +1267,10 @@ pprDef (Def { def_fun = f, def_pat = vs, def_res_ty = res_ty, def_rhs = rhs }) instance InPhase p => Pretty (BaseUserFun p) where ppr = pprBaseUserFun +instance Pretty ArgTypeDescriptor where + ppr Poly = text "poly" + ppr (Mono ty) = pprParendType ty + pprPat :: Bool -> Pat -> SDoc -- True <=> wrap tuple pattern in parens pprPat _ (VarPat v) = pprTVar v diff --git a/src/ksc/LangUtils.hs b/src/ksc/LangUtils.hs index 63c5a454d..82864e744 100644 --- a/src/ksc/LangUtils.hs +++ b/src/ksc/LangUtils.hs @@ -14,6 +14,10 @@ module LangUtils ( -- Substitution substEMayCapture, + -- Matching + matchTy, + TySubst, mkTySubst, tySubstTy, tySubstTyVar, tySubstPat, tySubstExpr, + -- Equality cmpExpr, @@ -26,7 +30,9 @@ module LangUtils ( LangUtils.hspec, test_FreeIn, -- Symbol table - GblSymTab, extendGblST, lookupGblST, emptyGblST, modifyGblST, + GblSymTab, FunBindings(..), + extendGblST, lookupDef, lookupBindings, + emptyGblST, modifyGblST, stInsertFun, LclSymTab, extendLclST, SymTab(..), newSymTab, emptySymTab, @@ -100,6 +106,7 @@ substEMayCapture subst (Let v r b) = Let v (substEMayCapture subst r) $ where bindersAsMap :: PatG TVar -> M.Map TVar () bindersAsMap = M.fromList . map (\x -> (x, ())) . patVars + ----------------------------------------------- -- Free variables ----------------------------------------------- @@ -150,7 +157,9 @@ hspec = do let var :: String -> TVar var s = TVar TypeFloat (Simple s) fun :: String -> TFun Typed - fun s = TFun TypeFloat (Fun JustFun (BaseUserFun (BaseUserFunId s TypeFloat))) + fun s = TFun { tf_ret = TypeFloat + , tf_fun = Fun JustFun (BaseUserFun (BaseUserFunId s (Mono TypeFloat))) + , tf_targs = [] } e = Call (fun "f") (Var (var "i")) e2 = Call (fun "f") (Tuple [Var (var "_t1"), kInt 5]) describe "notFreeIn" $ do @@ -173,11 +182,15 @@ test_FreeIn = Test.Hspec.hspec LangUtils.hspec ----------------------------------------------- -- Global symbol table -type GblSymTab = M.Map (UserFun Typed) TDef - -- Maps a function to its definition, which lets us +type GblSymTab = M.Map (DerivedFun BaseUserFunName) FunBindings +data FunBindings = PolyBind TDef + | MonoBinds (M.Map Type TDef) + -- A function name like fwd$f maps to + -- either a /single/ polymorphic binding + -- or /multiple/ monomorphic bindings + -- This info lets us -- * Find its return type -- * Inline it - -- Domain is UserFun Typed -- Local symbol table type LclSymTab = M.Map Var Type @@ -189,6 +202,10 @@ data SymTab , lclST :: LclSymTab } +instance Pretty FunBindings where + ppr (PolyBind def) = text "Poly:" <+> ppr def + ppr (MonoBinds defs) = text "Mono:" <+> ppr defs + instance (Pretty k, Pretty v) => Pretty (M.Map k v) where ppr m = braces $ fsep $ punctuate comma $ [ ppr k <+> text ":->" <+> ppr v | (k,v) <- M.toList m ] @@ -210,10 +227,33 @@ newSymTab :: GblSymTab -> SymTab newSymTab gbl_env = ST { gblST = gbl_env, lclST = M.empty } stInsertFun :: TDef -> GblSymTab -> GblSymTab -stInsertFun def@(Def { def_fun = f }) = M.insert f def +stInsertFun def@(Def { def_fun = Fun der (BaseUserFunId name descr) }) gst + = case descr of + Poly -> M.insert top_key (PolyBind def) gst + Mono ty -> M.insertWith add_bind top_key (MonoBinds (M.singleton ty def)) gst + where + top_key = Fun der name + + add_bind :: FunBindings -> FunBindings -> FunBindings + add_bind (MonoBinds bs1) (MonoBinds bs2) = MonoBinds (M.union bs1 bs2) + add_bind _ bs = bs + +lookupDef :: UserFun Typed -> GblSymTab -> Maybe TDef +lookupDef (Fun der (BaseUserFunId name descr)) gst + = case descr of + Poly | Just (PolyBind def) <- M.lookup top_key gst + -> Just def -lookupGblST :: HasCallStack => UserFun Typed -> GblSymTab -> Maybe TDef -lookupGblST = M.lookup + Mono ty | Just (MonoBinds defs) <- M.lookup top_key gst + , Just def <- M.lookup ty defs + -> Just def + + _other -> Nothing + where + top_key = Fun der name + +lookupBindings :: DerivedFun BaseUserFunName -> GblSymTab -> Maybe FunBindings +lookupBindings = M.lookup extendGblST :: GblSymTab -> [TDef] -> GblSymTab extendGblST = foldl (flip stInsertFun) @@ -403,3 +443,89 @@ noTupPatifyExpr in_scope = \case Konst k -> Konst k Var v -> Var v Dummy d -> Dummy d + + +----------------------------------------------- +-- Type matching and substitution +----------------------------------------------- + +type TySubst = M.Map TyVar Type + +emptyTySubst :: TySubst +emptyTySubst = M.empty + +mkTySubst :: [(TyVar,Type)] -> TySubst +mkTySubst = M.fromList + +tySubstTyVar :: TySubst -> TyVar -> Type +tySubstTyVar s tv = case M.lookup tv s of + Just ty -> ty + Nothing -> TypeVar tv + +tySubstTy :: TySubst -> Type -> Type +tySubstTy s ty + = go ty + where + go (TypeVar tv) = tySubstTyVar s tv + go (TypeTuple tys) = TypeTuple (map go tys) + go (TypeTensor n ty) = TypeTensor n (go ty) + go (TypeLam t1 t2) = TypeLam (go t1) (go t2) + go (TypeLM t1 t2) = TypeLM (go t1) (go t2) + -- Other types have no free variables + go ty = ty + +tySubstPat :: TySubst -> Pat -> Pat +tySubstPat s = fmap (tySubstTVar s) + +tySubstTVar :: TySubst -> TVar -> TVar +tySubstTVar s (TVar ty v) = TVar (tySubstTy s ty) v + +tySubstTFun :: TySubst -> TFun Typed -> TFun Typed +tySubstTFun s tfun@(TFun { tf_targs = targs, tf_ret = ret }) + = tfun { tf_targs = map (tySubstTy s) targs, tf_ret = tySubstTy s ret } + +tySubstExpr :: TySubst -> TExpr -> TExpr +tySubstExpr s e = go e + where + go (Var v) = Var (tySubstTVar s v) + go e@(Dummy {}) = e + go e@(Konst _) = e + go (Tuple es) = Tuple (map go es) + go (If b t e) = If (go b) (go t) (go e) + go (Call f e) = Call (tySubstTFun s f) (go e) + go (App f a) = App (go f) (go a) + go (Let p r b) = Let (tySubstPat s p) (go r) (go b) + go (Lam v e) = Lam (tySubstTVar s v) (go e) + go (Assert e1 e2) = Assert (go e1) (go e2) + +matchTy :: Type -> Type -> Maybe TySubst +matchTy t1 t2 + = go emptyTySubst t1 t2 + where + go s (TypeVar tv) ty + = case M.lookup tv s of + Just ty' | ty `eqType` ty' -> Just s + | otherwise -> Nothing + Nothing -> Just (M.insert tv ty s) + go s (TypeTuple ts1) (TypeTuple ts2) + = go_s s ts1 ts2 + go s (TypeTensor n1 t1) (TypeTensor n2 t2) + | n1==n2 + = go s t1 t2 + go s (TypeLam a1 r1) (TypeLam a2 r2) + = do { s1 <- go s a1 a2; go s1 r1 r2 } + go s (TypeLM a1 r1) (TypeLM a2 r2) + = do { s1 <- go s a1 a2; go s1 r1 r2 } + go s TypeBool TypeBool = Just s + go s TypeInteger TypeInteger = Just s + go s TypeFloat TypeFloat = Just s + go s TypeString TypeString = Just s + + go s TypeUnknown _ = Just s + + go _ _ _ = Nothing + + go_s s [] [] = Just s + go_s s (t1:ts1) (t2:ts2) = do { s1 <- go s t1 t2 + ; go_s s1 ts1 ts2 } + go_s _ _ _ = Nothing \ No newline at end of file diff --git a/src/ksc/Opt.hs b/src/ksc/Opt.hs index 5436b6d4f..5223a7fc4 100644 --- a/src/ksc/Opt.hs +++ b/src/ksc/Opt.hs @@ -210,29 +210,28 @@ rewriteCall _ fun (Let v r arg) rewriteCall _ fun (If e1 e2 e3) = Just (If e1 (Call fun e2) (Call fun e3)) -rewriteCall env (TFun _ (Fun JustFun fun)) arg +rewriteCall env (TFun { tf_fun = Fun JustFun fun }) arg = optFun env fun arg -rewriteCall env (TFun ty (Fun (GradFun adm) f)) arg +rewriteCall env (TFun { tf_fun = Fun (GradFun adm) f, tf_ret = ty }) arg = optGradFun (optEnvInScope env) adm ty f arg -rewriteCall _ (TFun _ (Fun (DrvFun adm) f)) arg +rewriteCall _ (TFun { tf_fun = Fun (DrvFun adm) f }) arg = optDrvFun adm f arg -rewriteCall _ f@(TFun (TypeLM _ _) _) _ +rewriteCall _ f@(TFun { tf_ret = TypeLM {} }) _ = trace ("NOTE: Unmatched LM call {" ++ pps f ++ "}") Nothing -rewriteCall env (TFun _ to_inline) arg +rewriteCall env (TFun { tf_fun = to_inline, tf_targs = targs }) arg | Just to_inline <- maybeUserFun to_inline , shouldInline to_inline - , Just Def{ def_pat = pat, def_rhs = UserRhs body } - <- lookupGblST to_inline (optGblST env) - = Just (inlineCall env pat body arg) + , Just def <- lookupDef to_inline (optGblST env) + = inlineCall env def targs arg -rewriteCall _ (TFun _ (Fun SUFFwdPass (PrimFun fun))) arg +rewriteCall _ (TFun { tf_fun = Fun SUFFwdPass (PrimFun fun) }) arg = SUF.rewriteSUFFwdPass fun arg -rewriteCall _ (TFun _ (Fun SUFRevPass (PrimFun fun))) arg +rewriteCall _ (TFun { tf_fun = Fun SUFRevPass (PrimFun fun) }) arg = SUF.rewriteSUFRevPass fun arg rewriteCall _ _ _ @@ -247,9 +246,10 @@ shouldInline to_inline , fwd "add" ff , rev "add" ff ] - where fwd f t = Fun SUFFwdPass (BaseUserFunId f t) - rev f t = Fun SUFRevPass (BaseUserFunId f t) - ff = TypeTuple [TypeFloat, TypeFloat] + where + fwd f t = Fun SUFFwdPass (BaseUserFunId f t) + rev f t = Fun SUFRevPass (BaseUserFunId f t) + ff = Mono (TypeTuple [TypeFloat, TypeFloat]) ----------------------- optFun :: OptEnv -> BaseFun p -> TExpr -> Maybe TExpr @@ -265,11 +265,10 @@ optFun _ (PrimFun (P_SelFun i _)) arg -- $inline needs to look up the global symtab optFun env (PrimFun P_inline) arg - | Call (TFun _ fun) inner_arg <- arg + | Call (TFun { tf_fun = fun, tf_targs = targs }) inner_arg <- arg , Just userFun <- maybeUserFun fun - , Just fun_def <- lookupGblST userFun (optGblST env) - , Def { def_pat = pat, def_rhs = UserRhs body } <- fun_def - = Just (inlineCall env pat body inner_arg) + , Just fun_def <- lookupDef userFun (optGblST env) + = inlineCall env fun_def targs inner_arg -- Other prims are determined by their args optFun env (PrimFun f) e @@ -454,22 +453,36 @@ optLMCompose f g ----------------------- inlineCall :: OptEnv - -> Pat -> TExpr -- Function parameters and body - -> TExpr -- Arguments - -> TExpr -inlineCall _ (VarPat tv) body arg - = mkLet tv arg body - -inlineCall env (TupPat tvs) body arg - = mkLets (fresh_tvs `zip` args) $ - -- See Note [Avoid name clashes in inlineCall] - mkLets [ (tv, Var fresh_tv) - | (tv,fresh_tv) <- tvs `zip` fresh_tvs - , tv /= fresh_tv ] - body - where - args = splitTuple arg (length tvs) - (_, fresh_tvs) = notInScopeTVs (optEnvInScope env) tvs + -> TDef -- Function definition + -> [Type] -- Type arguments + -> TExpr -- Value argument + -> Maybe TExpr +inlineCall env def targs val_arg + | Just (pat, body) <- instantiateDef def targs + = Just $ case pat of + VarPat tv -> mkLet tv val_arg body + TupPat tvs -> mkLets (fresh_tvs `zip` args) $ + -- See Note [Avoid name clashes in inlineCall] + mkLets [ (tv, Var fresh_tv) + | (tv,fresh_tv) <- tvs `zip` fresh_tvs + , tv /= fresh_tv ] + body + where + args = splitTuple val_arg (length tvs) + (_, fresh_tvs) = notInScopeTVs (optEnvInScope env) tvs + | otherwise + = Nothing + +instantiateDef :: TDef -> [Type] -> Maybe (Pat, TExpr) +instantiateDef (Def { def_qvars = qvars, def_pat = pat, def_rhs = rhs }) targs + | UserRhs body <- rhs + , length qvars == length targs + , let ty_subst = mkTySubst (qvars `zip` targs) + = Just (if null qvars + then (pat, body) + else (tySubstPat ty_subst pat, tySubstExpr ty_subst body)) + | otherwise + = Nothing {- Note [Avoid name clashes in inlineCall] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -642,9 +655,11 @@ optGradFun _ _ _ (BaseUserFun {}) _ optGradFun env TupleAD ty f args | TypeTuple [res_ty, lm_ty] <- ty , Just opt_grad <- optGradFun env BasicAD lm_ty f new_args + , let tfun = TFun { tf_ret = res_ty, tf_fun = Fun JustFun f + , tf_targs = [] } -- SLPJ: fix me = Just $ mkLets binds $ - Tuple [ Call (TFun res_ty (Fun JustFun f)) new_args, opt_grad ] + Tuple [ Call tfun new_args, opt_grad ] | otherwise = Nothing where @@ -790,29 +805,35 @@ optLMApply _ (AD TupleAD dir) (Tuple [_, lm]) dx -- Called for (lmApply (lm* es) dx) -- In BasicAD only -optLMApply env (AD BasicAD dir) (Call (TFun _ (Fun JustFun (PrimFun f))) es) dx +optLMApply env (AD BasicAD dir) (Call (TFun { tf_fun = Fun JustFun (PrimFun f) }) es) dx = optLMApplyCall env dir f es dx -- Looking at: D$f(e1, e2) `lmApply` dx -- f :: S1 S2 -> T -- D$f :: S1 S2 -> ((S1,S2) -o T) -- fwd$f :: S1 S2 S1_t S2_t -> T_t -optLMApply _ (AD adp1 Fwd) (Call (TFun (TypeLM _ t) (Fun (GradFun adp2) f)) es) dx - | adp1 == adp2 +optLMApply _ (AD adp1 Fwd) (Call tfun es) dx + | TFun { tf_ret = TypeLM _ t, tf_fun = Fun (GradFun adp2) f } <- tfun + , adp1 == adp2 + , let grad_fun = TFun { tf_ret = tangentType t + , tf_fun = Fun (DrvFun (AD adp1 Fwd)) f + , tf_targs = [] } -- SLPJ: fix me = Just (Call grad_fun es_dx) where - grad_fun = TFun (tangentType t) (Fun (DrvFun (AD adp1 Fwd)) f) es_dx = Tuple [es, dx] -- Looking at: dr `lmApplyR` D$f(e1, e2) -- f :: S1 S2 -> T -- D$f :: S1 S2 -> ((S1,S2) -o T) -- rev$f :: S1 S2 T_ -> (S1_t,S2_t) -optLMApply _ (AD adp1 Rev) (Call (TFun (TypeLM s _) (Fun (GradFun adp2) f)) es) dx - | adp1 == adp2 +optLMApply _ (AD adp1 Rev) (Call tfun es) dx + | TFun { tf_ret = TypeLM s _, tf_fun = Fun (GradFun adp2) f } <- tfun + , adp1 == adp2 + , let grad_fun = TFun { tf_ret = tangentType s + , tf_fun = Fun (DrvFun (AD adp1 Rev)) f + , tf_targs = [] } -- SLPJ: fix me = Just (Call grad_fun es_dx) where - grad_fun = TFun (tangentType s) (Fun (DrvFun (AD adp1 Rev)) f) es_dx = Tuple [es, dx] {- diff --git a/src/ksc/Parse.hs b/src/ksc/Parse.hs index abd81bb77..fb58baca4 100644 --- a/src/ksc/Parse.hs +++ b/src/ksc/Parse.hs @@ -14,9 +14,12 @@ Here's the BNF for our language: ::= ( rule ) - ::= ( def ) + ::= ( def ) -- (def f Float ( (x : Float) (y : Vec Float) ) (...) ) + ::= -- Empty + | [ ... ] + ::= ( edef ) ::= ( gdef ) @@ -99,7 +102,8 @@ Notes: import Lang hiding (parens, brackets) -import Text.Parsec( (<|>), try, many, parse, eof, manyTill, ParseError, unexpected ) +import Text.Parsec( (<|>), try, many, many1, parse, eof, manyTill, ParseError + , unexpected, option ) import Text.Parsec.Char import Text.Parsec.String (Parser) @@ -245,6 +249,7 @@ pType = (pReserved "Integer" >> return TypeInteger) <|> (pReserved "Float" >> return TypeFloat) <|> (pReserved "String" >> return TypeString) <|> (pReserved "Bool" >> return TypeBool) + <|> (do { tv <- pIdentifier; return (TypeVar tv) }) <|> parens pKType pTypes :: Parser [TypeX] @@ -353,11 +358,12 @@ pSelFun = do { rest <- try $ do { f <- pIdentifier Just selfun -> pure selfun } -pBaseUserFunWithType :: (Type -> BaseUserFunArgTy p) -> Parser (BaseUserFun p) +pBaseUserFunWithType :: (ArgTypeDescriptor -> BaseUserFunArgTy p) -> Parser (BaseUserFun p) +-- No way to parse a user-written Poly ArgTypeDescriptor yet pBaseUserFunWithType add = brackets (do { f <- pIdentifier ; ty <- pType - ; pure (BaseUserFunId f (add ty)) + ; pure (BaseUserFunId f (add (Mono ty))) }) pBaseUserFunWithoutType :: Parser (BaseUserFun Parsed) @@ -402,6 +408,7 @@ pDef :: Parser Def -- (def f Type ((x1 : Type) (x2 : Type) (x3 : Type)) rhs) pDef = do { pReserved "def" ; f <- pFun + ; qvs <- pOptQVars ; ty <- pType ; xs <- pParams ; rhs <- pExpr @@ -411,10 +418,15 @@ pDef = do { pReserved "def" [x] -> VarPat x xs -> TupPat xs ; return (Def { def_fun = mk_fun_f + , def_qvars = qvs , def_pat = pat , def_rhs = UserRhs rhs , def_res_ty = ty }) } +pOptQVars :: Parser [TyVar] +pOptQVars = option [] $ + brackets (many1 pIdentifier) + pRule :: Parser Rule pRule = do { pReserved "rule" ; name <- pString @@ -427,10 +439,12 @@ pRule = do { pReserved "rule" pEdef :: Parser (DefX Parsed) pEdef = do { pReserved "edef" ; f <- pFun + ; qvs <- pOptQVars ; returnType <- pType ; argType <- pType ; mk_fun_name <- pIsUserFun f ; return (Def { def_fun = mk_fun_name + , def_qvars = qvs , def_res_ty = returnType -- See note [Function arity] , def_pat = VarPat (mkTVar argType "edefArgVar") diff --git a/src/ksc/Prim.hs b/src/ksc/Prim.hs index 52a12f421..76157009a 100644 --- a/src/ksc/Prim.hs +++ b/src/ksc/Prim.hs @@ -18,12 +18,19 @@ import Control.Monad (zipWithM) primCall :: PrimFun -> Type -> TExpr -> TExpr primCall fun res_ty - = Call (TFun res_ty (Fun JustFun (PrimFun fun))) + = Call (TFun { tf_ret = res_ty + , tf_fun = Fun JustFun (PrimFun fun) + , tf_targs = [] }) userCall :: String -> Type -> TExpr -> TExpr userCall fun res_ty arg - = Call (TFun res_ty (Fun JustFun (BaseUserFun (BaseUserFunId fun arg_ty)))) arg - where arg_ty = typeof arg + = assert (text "userCall") (isMonoType arg_ty) $ + Call tfun arg + where + arg_ty = typeof arg + tfun = TFun { tf_ret = res_ty + , tf_fun = Fun JustFun (BaseUserFun (BaseUserFunId fun (Mono arg_ty))) + , tf_targs = [] } mkPrimCall :: HasCallStack => PrimFun -> TExpr -> TExpr mkPrimCall fun args @@ -78,9 +85,10 @@ getZero tangent_type e -> mkAtomicNoFVs e $ \e -> Tuple $ map go $ [ pSel i n e | i <- [1..n] ] - TypeLam _ _ -> panic - TypeLM _ _ -> panic + TypeLam {} -> panic + TypeLM {} -> panic TypeUnknown -> panic + TypeVar {} -> panic where e_ty = typeof e panic = pprPanic "mkZero" (ppr e_ty $$ ppr e) @@ -245,7 +253,8 @@ lmCompose_Dir Fwd m1 m2 = m1 `lmCompose` m2 lmCompose_Dir Rev m1 m2 = m2 `lmCompose` m1 isThePrimFun :: TFun p -> PrimFun -> Bool -isThePrimFun (TFun _ (Fun JustFun (PrimFun f1))) f2 = f1 == f2 +isThePrimFun (TFun { tf_fun = Fun JustFun (PrimFun f1) }) f2 + = f1 == f2 isThePrimFun _ _ = False isLMOne :: TExpr -> Bool @@ -440,27 +449,32 @@ primFunCallResultTy fun args , ppr (typeof args)]) TypeUnknown --- Just the base function argument type given that the derived function has --- argument type derivedFunArgTy, or Nothing if we can't work it out -baseFunArgTy_maybe :: Pretty p => DerivedFun p -> Type -> Either SDoc (Maybe Type) -baseFunArgTy_maybe derivedFun derivedFunArgTy - = case derivedFun of - Fun JustFun _ -> it's derivedFunArgTy - Fun DrvFun{} _ -> case derivedFunArgTy of - TypeTuple [baseArgTy', _] -> it's baseArgTy' - _ -> Left (text "Expected pair argument type to" <+> pprDerivedFun ppr derivedFun - $$ text "but instead was:" <+> ppr derivedFunArgTy) - Fun GradFun{} _ -> it's derivedFunArgTy - Fun (ShapeFun ds) f -> baseFunArgTy_maybe (Fun ds f) derivedFunArgTy - Fun CLFun _ -> it's derivedFunArgTy - Fun SUFFwdPass _ -> it's derivedFunArgTy - Fun SUFRevPass _ -> don'tKnow - Fun SUFRev _ -> case derivedFunArgTy of - TypeTuple [baseArgTy', _] -> it's baseArgTy' - _ -> Left (text "Expected pair argument type to" <+> pprDerivedFun ppr derivedFun - $$ text "but instead was:" <+> ppr derivedFunArgTy) - where it's = pure . pure - don'tKnow = pure Nothing +baseFunArgTy_maybe :: Pretty f => DerivedFun f -> Type -> Either SDoc (Maybe Type) +-- Given the argument type, figure out the /base/ argument type +-- Right (Just ty) => figured it out +-- Right Nothing => nothing actually wrong, but the Derivations is one +-- that does not determine the base argument type +-- Left err => arg type had an unexpected shape +baseFunArgTy_maybe fun@(Fun ds _) arg_ty + = go ds + where + go (ShapeFun ds) = go ds + go JustFun = it's arg_ty + go GradFun{} = it's arg_ty + go CLFun = it's arg_ty + go SUFFwdPass = it's arg_ty + go SUFRevPass = don'tKnow + go DrvFun{} = case arg_ty of + TypeTuple [baseArgTy', _] -> it's baseArgTy' + _ -> bad "DrvFun" + go SUFRev = case arg_ty of + TypeTuple [baseArgTy', _] -> it's baseArgTy' + _ -> bad "SUFRev" + + it's ty = Right (Just ty) + don'tKnow = Right Nothing + bad _ = Left (text "Expected pair argument type to" <+> ppr fun + $$ text "but instead was:" <+> ppr arg_ty) -- If 'f : S -> T' then -- diff --git a/src/ksc/Shapes.hs b/src/ksc/Shapes.hs index 8b268ebf9..1ee6cbf05 100644 --- a/src/ksc/Shapes.hs +++ b/src/ksc/Shapes.hs @@ -20,11 +20,13 @@ shapeDef (Def { def_fun = Fun ShapeFun{} _ }) = Nothing shapeDef (Def { def_fun = Fun ds f + , def_qvars = qvs , def_pat = VarPat params , def_rhs = UserRhs def_rhs , def_res_ty = res_ty }) = Just $ Def { def_fun = Fun (ShapeFun ds) f + , def_qvars = qvs , def_pat = VarPat params , def_res_ty = shapeType res_ty , def_rhs = UserRhs (mkLetForShapeOfParameter params (shapeE def_rhs)) } @@ -76,16 +78,20 @@ shapeE (App{}) = error "Shape of App not supported" shapeCall :: HasCallStack => TFun Typed -> TExpr -> TExpr -shapeCall (TFun _ (Fun JustFun (PrimFun (P_SelFun i n)))) e +shapeCall (TFun { tf_fun = Fun JustFun (PrimFun (P_SelFun i n)) }) e = pSel i n (shapeE e) -shapeCall (TFun _ (Fun JustFun (PrimFun f))) e +shapeCall (TFun { tf_fun = Fun JustFun (PrimFun f) }) e | Just e' <- shapeCallPrim f e = e' -shapeCall (TFun ty (Fun ds f)) e +shapeCall (TFun { tf_ret = ty, tf_fun = Fun ds f, tf_targs = targs }) e | isBaseUserFun f - = Call (TFun (shapeType ty) (Fun (ShapeFun ds) f)) e + = Call shape_fun e + where + shape_fun = TFun { tf_ret = shapeType ty + , tf_fun = Fun (ShapeFun ds) f + , tf_targs = targs } shapeCall tf e = pShape (Call tf e) -- Fall back to calling the original function and evaluating the shape of the returned object diff --git a/test/ksc/poly1.ks b/test/ksc/poly1.ks new file mode 100644 index 000000000..0c1268d0a --- /dev/null +++ b/test/ksc/poly1.ks @@ -0,0 +1,7 @@ +; Copyright (c) Microsoft Corporation. +; Licensed under the MIT license. + +(def id [a] a ( x : a ) x) + + +(def f Float () (id 0.0))