Skip to content

Commit

Permalink
Properly fix this rule.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Jan 23, 2025
1 parent 00bf0a0 commit b1857f5
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions src/Futhark/IR/SOACS/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -538,36 +538,35 @@ removeDeadReduction (_, used) pat aux (Screma w arrs form) =
let redlam_res = bodyResult $ lambdaBody redlam,
let redlam_params = lambdaParams redlam,
let (redlam_xparams, redlam_yparams) =
splitAt (length nes) (lambdaParams redlam),
splitAt (length nes) redlam_params,
let used_after =
map snd . filter ((`UT.used` used) . patElemName . fst) $
zip red_pes redlam_params,
zip (red_pes <> red_pes) redlam_params,
let necessary =
findNecessaryForReturned
(`elem` used_after)
(zip redlam_params $ map resSubExp $ redlam_res <> redlam_res)
redlam_deps,
let alive_mask = map ((`nameIn` necessary) . paramName) redlam_params,
let alive_mask =
zipWith
(||)
(map ((`nameIn` necessary) . paramName) redlam_xparams)
(map ((`nameIn` necessary) . paramName) redlam_yparams),
not $ and alive_mask = Simplify $ do
let fixDeadToNeutral lives ne = if lives then Nothing else Just ne
dead_fix = zipWith fixDeadToNeutral alive_mask nes
keep (_, (x, y), _) =
(paramName x `nameIn` necessary)
|| (paramName y `nameIn` necessary)
(used_red_pes, _, used_nes) =
unzip3 . filter keep $
zip3 red_pes (zip redlam_xparams redlam_yparams) nes
(used_red_pes, used_nes) =
unzip . map snd . filter fst $ zip alive_mask $ zip red_pes nes

when (used_nes == nes) cannotSimplify

let maplam' = removeLambdaResults (take (length nes) alive_mask) maplam
redlam' <- removeLambdaResults (take (length nes) alive_mask) <$> fixLambdaParams redlam (dead_fix ++ dead_fix)
let maplam' = removeLambdaResults alive_mask maplam
redlam' <-
removeLambdaResults alive_mask
<$> fixLambdaParams redlam (dead_fix ++ dead_fix)

auxing aux $
letBind (Pat $ used_red_pes ++ map_pes) $
Op $
Screma w arrs $
mkOp redlam' used_nes maplam'
auxing aux . letBind (Pat $ used_red_pes ++ map_pes) . Op $
Screma w arrs (mkOp redlam' used_nes maplam')
removeDeadReduction' _ _ _ _ = Skip
removeDeadReduction _ _ _ _ = Skip

Expand Down

0 comments on commit b1857f5

Please sign in to comment.