Skip to content

Commit

Permalink
Cancel all handlers on finalizing the background process manager
Browse files Browse the repository at this point in the history
  • Loading branch information
eikek committed Sep 18, 2024
1 parent 1122112 commit 4ffd475
Showing 1 changed file with 67 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,62 +79,73 @@ object BackgroundProcessManage:
retryDelay: FiniteDuration,
maxRetries: Option[Int] = None
): Resource[F, BackgroundProcessManage[F]] =
val logger = scribe.cats.effect[F]
Supervisor[F](await = false).flatMap { supervisor =>
Resource.eval(Ref.of[F, State[F]](State.empty[F])).map { state =>
new BackgroundProcessManage[F] {
def register(name: TaskName, task: F[Unit]): F[Unit] =
state.update(_.put(name, wrapTask(name, task)))

def startAll: F[Unit] =
state.get
.flatMap(s => logger.info(s"Starting ${s.tasks.size} background tasks")) >>
background(_ => true)

def currentProcesses: F[Set[TaskName]] =
state.get.map(_.processes.keySet)

def background(taskFilter: TaskName => Boolean): F[Unit] =
for {
ts <- state.get.map(_.getTasks(taskFilter))
_ <- ts.toList
.traverse { case (name, task) =>
supervisor.supervise(task).map(t => name -> t)
}
.map(_.toMap)
.flatMap(ps => state.update(_.setProcesses(ps)))
} yield ()

/** Stop all tasks by filtering on their registered name. */
def cancelProcesses(filter: TaskName => Boolean): F[Unit] =
for
current <- state.get
ps = current.getProcesses(filter)
_ <- ps.toList.traverse_ { case (name, p) =>
logger.info(s"Cancel background process $name") >> p.cancel >> p.join
.flatMap(out => logger.info(s"Task $name cancelled: $out"))
}
_ <- state.update(_.removeProcesses(ps.keySet))
yield ()

private def wrapTask(name: TaskName, task: F[Unit]): F[Unit] =
def run(c: Ref[F, Long]): F[Unit] =
logger.info(s"Starting process for: ${name}") >>
task.handleErrorWith { err =>
c.updateAndGet(_ + 1).flatMap {
case n if maxRetries.exists(_ <= n) =>
logger.error(
s"Max retries ($maxRetries) for process ${name} exceeded"
) >> Async[F].raiseError(err)
case n =>
val maxRetriesLabel = maxRetries.map(m => s"/$m").getOrElse("")
logger.error(
s"Starting process for '${name}' failed ($n$maxRetriesLabel), retrying",
err
) >> Async[F].delayBy(run(c), retryDelay)
}
}
Ref.of[F, Long](0).flatMap(run) >> state.update(_.removeProcesses(Set(name)))
}
Resource.eval(Ref.of[F, State[F]](State.empty[F])).flatMap { state =>
Resource
.make(Async[F].pure(new Impl(supervisor, state, retryDelay, maxRetries)))(
_.cancelProcesses(_ => true)
)
}
}

private class Impl[F[_]: Async](
supervisor: Supervisor[F],
state: Ref[F, State[F]],
retryDelay: FiniteDuration,
maxRetries: Option[Int] = None
) extends BackgroundProcessManage[F] {
val logger = scribe.cats.effect[F]

def register(name: TaskName, task: F[Unit]): F[Unit] =
state.update(_.put(name, wrapTask(name, task)))

def startAll: F[Unit] =
state.get
.flatMap(s => logger.info(s"Starting ${s.tasks.size} background tasks")) >>
background(_ => true)

def currentProcesses: F[Set[TaskName]] =
state.get.map(_.processes.keySet)

def background(taskFilter: TaskName => Boolean): F[Unit] =
for {
ts <- state.get.map(_.getTasks(taskFilter))
_ <- ts.toList
.traverse { case (name, task) =>
supervisor.supervise(task).map(t => name -> t)
}
.map(_.toMap)
.flatMap(ps => state.update(_.setProcesses(ps)))
} yield ()

/** Stop all tasks by filtering on their registered name. */
def cancelProcesses(filter: TaskName => Boolean): F[Unit] =
for
current <- state.get
ps = current.getProcesses(filter)
_ <- ps.toList.traverse_ { case (name, p) =>
logger.info(s"Cancel background process $name") >> p.cancel >> p.join
.flatMap(out => logger.info(s"Task $name cancelled: $out"))
}
_ <- state.update(_.removeProcesses(ps.keySet))
yield ()

private def wrapTask(name: TaskName, task: F[Unit]): F[Unit] =
def run(c: Ref[F, Long]): F[Unit] =
logger.info(s"Starting process for: ${name}") >>
task.handleErrorWith { err =>
c.updateAndGet(_ + 1).flatMap {
case n if maxRetries.exists(_ <= n) =>
logger.error(
s"Max retries ($maxRetries) for process ${name} exceeded"
) >> Async[F].raiseError(err)
case n =>
val maxRetriesLabel = maxRetries.map(m => s"/$m").getOrElse("")
logger.error(
s"Starting process for '${name}' failed ($n$maxRetriesLabel), retrying",
err
) >> Async[F].delayBy(run(c), retryDelay)
}
}
Ref.of[F, Long](0).flatMap(run) >> state.update(_.removeProcesses(Set(name)))
}

0 comments on commit 4ffd475

Please sign in to comment.