Skip to content

Commit

Permalink
Separate type constraints and AUTOMAP constraints.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Jan 28, 2025
1 parent 9f8ff45 commit d04c366
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 30 deletions.
9 changes: 3 additions & 6 deletions src/Language/Futhark/TypeChecker/Constraints.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
-- | Constraint solver for solving type equations produced
-- post-AUTOMAP.
module Language.Futhark.TypeChecker.Constraints
( Reason (..),
SVar,
Expand Down Expand Up @@ -85,21 +87,17 @@ instance Located Reason where
locOf (ReasonApplySplit l _ _ _) = l
locOf (ReasonBranches l _ _) = l

data Ct
= CtEq Reason Type Type
| CtAM Reason SVar SVar (Shape SComp)
data Ct = CtEq Reason Type Type
deriving (Show)

ctReason :: Ct -> Reason
ctReason (CtEq r _ _) = r
ctReason (CtAM r _ _ _) = r

instance Located Ct where
locOf = locOf . ctReason

instance Pretty Ct where
pretty (CtEq _ t1 t2) = pretty t1 <+> "~" <+> pretty t2
pretty (CtAM _ r m _) = prettyName r <+> "=" <+> "" <+> "" <+> prettyName m <+> "=" <+> ""

type Constraints = [Ct]

Expand Down Expand Up @@ -684,7 +682,6 @@ solveCt :: Ct -> SolveM ()
solveCt ct =
case ct of
CtEq reason t1 t2 -> solveEq reason mempty t1 t2
CtAM {} -> pure () -- Good vibes only.

scopeCheck :: Reason -> TyVar -> Int -> Type -> SolveM ()
scopeCheck reason v v_lvl ty = do
Expand Down
46 changes: 26 additions & 20 deletions src/Language/Futhark/TypeChecker/Rank.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Language.Futhark.TypeChecker.Rank
( rankAnalysis,
rankAnalysis1,
CtAM (..),
)
where

Expand All @@ -25,6 +26,14 @@ import Language.Futhark.TypeChecker.Constraints
import Language.Futhark.TypeChecker.Monad
import System.IO.Unsafe

data CtAM = CtAM Reason SVar SVar (Shape SComp)

instance Located CtAM where
locOf (CtAM r _ _ _) = locOf r

instance Pretty CtAM where
pretty (CtAM _ r m _) = prettyName r <+> "=" <+> "" <+> "" <+> prettyName m <+> "=" <+> ""

type LSum = LP.LSum VName Int

type Constraint = LP.Constraint VName Int
Expand Down Expand Up @@ -84,7 +93,6 @@ distribAndSplitArrows (CtEq r t1 t2) =
t1r' = t1r `setUniqueness` NoUniqueness
t2r' = t2r `setUniqueness` NoUniqueness
splitArrows c = [c]
distribAndSplitArrows ct = [ct]

distribAndSplitCnstrs :: Ct -> [Ct]
distribAndSplitCnstrs ct@(CtEq r t1 t2) =
Expand All @@ -103,7 +111,6 @@ distribAndSplitCnstrs ct@(CtEq r t1 t2) =
splitCnstrs (CtEq reason (Scalar (Sum cs1)) (Scalar (Sum cs2))) =
concat $ concat $ (zipWith . zipWith) (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems cs1) (M.elems cs2)
splitCnstrs _ = []
distribAndSplitCnstrs ct = [ct]

data RankState = RankState
{ rankBinVars :: Map VName VName,
Expand Down Expand Up @@ -148,7 +155,9 @@ addObj sv =

addCt :: Ct -> RankM ()
addCt (CtEq _ t1 t2) = addConstraint $ rank t1 ~==~ rank t2
addCt (CtAM _ r m f) = do

addCtAM :: CtAM -> RankM ()
addCtAM (CtAM _ r m f) = do
b_r <- binVar r
b_m <- binVar m
b_max <- VName "c_max" <$> incCounter
Expand All @@ -168,8 +177,8 @@ addTyVarInfo tv (_, TyVarRecord {}) =
addTyVarInfo tv (_, TyVarSum {}) =
addConstraint $ rank tv ~==~ constant 0

mkLinearProg :: [Ct] -> TyVars -> LinearProg
mkLinearProg cs tyVars =
mkLinearProg :: [Ct] -> [CtAM] -> TyVars -> LinearProg
mkLinearProg cs cs_am tyVars =
LP.LinearProg
{ optType = Minimize,
objective = rankObj finalState,
Expand All @@ -187,6 +196,7 @@ mkLinearProg cs tyVars =
}
buildLP = do
mapM_ addCt cs
mapM_ addCtAM cs_am
mapM_ (uncurry addTyVarInfo) $ M.toList tyVars
finalState = flip execState initState $ runRankM buildLP

Expand Down Expand Up @@ -249,7 +259,7 @@ solveRankILP loc prog = do
rankAnalysis1 ::
(MonadTypeChecker m) =>
SrcLoc ->
[Ct] ->
([Ct], [CtAM]) ->
TyVars ->
M.Map TyVar Type ->
[Pat ParamType] ->
Expand All @@ -261,8 +271,8 @@ rankAnalysis1 ::
Exp,
Maybe (TypeExp Exp VName)
)
rankAnalysis1 loc cs tyVars artificial params body retdecl = do
solutions <- rankAnalysis loc cs tyVars artificial params body retdecl
rankAnalysis1 loc (cs, cs_am) tyVars artificial params body retdecl = do
solutions <- rankAnalysis loc (cs, cs_am) tyVars artificial params body retdecl
case solutions of
[sol] -> pure sol
sols -> do
Expand All @@ -277,7 +287,7 @@ rankAnalysis1 loc cs tyVars artificial params body retdecl = do
rankAnalysis ::
(MonadTypeChecker m) =>
SrcLoc ->
[Ct] ->
([Ct], [CtAM]) ->
TyVars ->
M.Map TyVar Type ->
[Pat ParamType] ->
Expand All @@ -290,9 +300,9 @@ rankAnalysis ::
Maybe (TypeExp Exp VName)
)
]
rankAnalysis _ [] tyVars artificial params body retdecl =
rankAnalysis _ ([], []) tyVars artificial params body retdecl =
pure [(([], artificial, tyVars), params, body, retdecl)]
rankAnalysis loc cs tyVars artificial params body retdecl = do
rankAnalysis loc (cs, cs_am) tyVars artificial params body retdecl = do
debugTraceM 3 $
unlines
[ "##rankAnalysis",
Expand All @@ -301,8 +311,8 @@ rankAnalysis loc cs tyVars artificial params body retdecl = do
"cs':",
unlines $ map prettyString cs'
]
rank_maps <- solveRankILP loc (mkLinearProg cs' tyVars)
cts_tyvars' <- mapM (substRankInfo cs artificial tyVars) rank_maps
rank_maps <- solveRankILP loc (mkLinearProg cs' cs_am tyVars)
cts_tyvars' <- mapM (substRankInfo (cs, cs_am) artificial tyVars) rank_maps
let bodys = map (`updAM` body) rank_maps
params' = map ((`map` params) . updAMPat) rank_maps
retdecls = map ((<$> retdecl) . updAMTypeExp) rank_maps
Expand All @@ -316,19 +326,16 @@ type RankMap = M.Map VName Int

substRankInfo ::
(MonadTypeChecker m) =>
[Ct] ->
([Ct], [CtAM]) ->
M.Map VName Type ->
TyVars ->
RankMap ->
m ([Ct], M.Map VName Type, TyVars)
substRankInfo cs artificial tyVars rankmap = do
substRankInfo (cs, _cs_am) artificial tyVars rankmap = do
((cs', artificial', tyVars'), new_cs, new_tyVars) <-
runSubstT tyVars rankmap $
(,,) <$> substRanks (filter (not . isCtAM) cs) <*> traverse substRanks artificial <*> traverse substRanks tyVars
(,,) <$> substRanks cs <*> traverse substRanks artificial <*> traverse substRanks tyVars
pure (cs' <> new_cs, artificial', new_tyVars <> tyVars')
where
isCtAM (CtAM {}) = True
isCtAM _ = False

runSubstT :: (MonadTypeChecker m) => TyVars -> RankMap -> SubstT m a -> m (a, [Ct], TyVars)
runSubstT tyVars rankmap (SubstT m) = do
Expand Down Expand Up @@ -443,7 +450,6 @@ instance SubstRanks (TypeBase SComp u) where

instance SubstRanks Ct where
substRanks (CtEq r t1 t2) = CtEq r <$> substRanks t1 <*> substRanks t2
substRanks _ = error ""

instance SubstRanks TyVarInfo where
substRanks tv@TyVarFree {} = pure tv
Expand Down
16 changes: 12 additions & 4 deletions src/Language/Futhark/TypeChecker/Terms2.hs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ data TermEnv = TermEnv
-- generating unique names, as these will be user-visible.
data TermState = TermState
{ termConstraints :: Constraints,
termAM :: [CtAM],
termTyVars :: TyVars,
termTyParams :: TyParams,
termCounter :: !Int,
Expand Down Expand Up @@ -193,6 +194,7 @@ runTermM (TermM m) = do
initial_state =
TermState
{ termConstraints = mempty,
termAM = mempty,
termTyVars = mempty,
termTyParams = mempty,
termWarnings = mempty,
Expand Down Expand Up @@ -311,7 +313,10 @@ ctEq reason t1 t2 =
t2' = t2 `setUniqueness` NoUniqueness

ctAM :: Reason -> SVar -> SVar -> Shape SComp -> TermM ()
ctAM reason r m f = addCt $ CtAM reason r m f
ctAM reason r m f =
modify $ \s -> s {termAM = ct : termAM s}
where
ct = CtAM reason r m f

localScope :: (TermScope -> TermScope) -> TermM a -> TermM a
localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv}
Expand Down Expand Up @@ -1370,6 +1375,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do
pure (params', body', retdecl')

cts <- gets termConstraints
cts_am <- gets termAM
tyvars <- gets termTyVars
typarams <- gets termTyParams
artificial <- gets termArtificial
Expand All @@ -1389,7 +1395,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do
]

onRankSolution typarams
=<< rankAnalysis1 loc cts tyvars artificial params' body' retdecl'
=<< rankAnalysis1 loc (cts, cts_am) tyvars artificial params' body' retdecl'
where
onRankSolution typarams ((cts', artificial, tyvars'), params', body'', retdecl') = do
solution <-
Expand Down Expand Up @@ -1430,11 +1436,12 @@ checkSingleExp ::
checkSingleExp e = runTermM $ do
e' <- checkExp e
cts <- gets termConstraints
cts_am <- gets termAM
tyvars <- gets termTyVars
typarams <- gets termTyParams
artificial <- gets termArtificial
((cts', _artificial', tyvars'), _, e'', _) <-
rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' Nothing
rankAnalysis1 (srclocOf e') (cts, cts_am) tyvars artificial [] e' Nothing
case solve cts' typarams tyvars' of
Left err -> pure (Left err, e'')
Right (unconstrained, solution) -> do
Expand All @@ -1450,12 +1457,13 @@ checkSizeExp ::
checkSizeExp e = runTermM $ do
e' <- checkSizeExp' e
cts <- gets termConstraints
cts_am <- gets termAM
tyvars <- gets termTyVars
typarams <- gets termTyParams
artificial <- gets termArtificial

(cts_tyvars', _, es', _) <-
L.unzip4 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' Nothing
L.unzip4 <$> rankAnalysis (srclocOf e) (cts, cts_am) tyvars artificial [] e' Nothing

solutions <-
forM cts_tyvars' $ \(cts', _artificial', tyvars') ->
Expand Down

0 comments on commit d04c366

Please sign in to comment.