From eb1181bd81704f9d146f4a9b4d94d3999511d3a5 Mon Sep 17 00:00:00 2001
From: Eike Kettner
Date: Tue, 3 Sep 2024 17:18:55 +0200
Subject: [PATCH] Extend background process manager

It must be able to track processes so they can be restarted
---
 .../provision/BackgroundProcessManage.scala   | 140 ++++++++++++------
 .../search/provision/MessageHandlers.scala    |  67 ++++++---
 .../renku/search/provision/Microservice.scala |  12 +-
 .../io/renku/search/provision/Services.scala  |   7 +-
 .../BackgroundProcessManageSpec.scala         |  94 ++++++++++++
 5 files changed, 248 insertions(+), 72 deletions(-)
 create mode 100644 modules/search-provision/src/test/scala/io/renku/search/provision/BackgroundProcessManageSpec.scala

diff --git a/modules/search-provision/src/main/scala/io/renku/search/provision/BackgroundProcessManage.scala b/modules/search-provision/src/main/scala/io/renku/search/provision/BackgroundProcessManage.scala
index 83493877..610c6200 100644
--- a/modules/search-provision/src/main/scala/io/renku/search/provision/BackgroundProcessManage.scala
+++ b/modules/search-provision/src/main/scala/io/renku/search/provision/BackgroundProcessManage.scala
@@ -23,68 +23,118 @@ import scala.concurrent.duration.FiniteDuration
 
 import cats.effect.*
 import cats.effect.kernel.Fiber
 import cats.effect.kernel.Ref
+import cats.effect.std.Supervisor
 import cats.syntax.all.*
 
+import io.renku.search.provision.BackgroundProcessManage.TaskName
+
 trait BackgroundProcessManage[F[_]]:
-  def register(name: String, process: F[Unit]): F[Unit]
+  def register(name: TaskName, task: F[Unit]): F[Unit]
+
+  /** Starts all registered tasks in the background. */
+  def background(taskFilter: TaskName => Boolean): F[Unit]
 
-  /** Starts all registered tasks in the background, represented by `F[Unit]`. */
-  def background: Resource[F, F[Unit]]
+  def startAll: F[Unit]
 
-  /** Same as `.background.useForever` */
-  def startAll: F[Nothing]
+  /** Stop all tasks by filtering on their registered name. */
+  def cancelProcesses(filter: TaskName => Boolean): F[Unit]
+
+  /** Get the names of all processes currently running.
+    */
+  def currentProcesses: F[Set[TaskName]]
 
 object BackgroundProcessManage:
-  type Process[F[_]] = Fiber[F, Throwable, Unit]
+  private type Process[F[_]] = Fiber[F, Throwable, Unit]
+
+  trait TaskName:
+    def equals(x: Any): Boolean
+    def hashCode(): Int
+
+  object TaskName:
+    final case class Name(value: String) extends TaskName
+    def fromString(name: String): TaskName = Name(name)
+
+  private case class State[F[_]](
+      tasks: Map[TaskName, F[Unit]],
+      processes: Map[TaskName, Process[F]]
+  ):
+    def put(name: TaskName, p: F[Unit]): State[F] =
+      State(tasks.updated(name, p), processes)
+
+    def getTasks(filter: TaskName => Boolean): Map[TaskName, F[Unit]] =
+      tasks.view.filterKeys(filter).toMap
 
-  private case class State[F[_]](tasks: Map[String, F[Unit]]):
-    def put(name: String, p: F[Unit]): State[F] =
-      State(tasks.updated(name, p))
+    def getProcesses(filter: TaskName => Boolean): Map[TaskName, Process[F]] =
+      processes.view.filterKeys(filter).toMap
 
-    def getTasks: List[F[Unit]] = tasks.values.toList
+    def setProcesses(ps: Map[TaskName, Process[F]]): State[F] =
+      copy(processes = ps)
+
+    def removeProcesses(names: Set[TaskName]): State[F] =
+      copy(processes = processes.view.filterKeys(n => !names.contains(n)).toMap)
 
   private object State:
-    def empty[F[_]]: State[F] = State[F](Map.empty)
+    def empty[F[_]]: State[F] = State[F](Map.empty, Map.empty)
 
   def apply[F[_]: Async](
      retryDelay: FiniteDuration,
      maxRetries: Option[Int] = None
-  ): F[BackgroundProcessManage[F]] =
+  ): Resource[F, BackgroundProcessManage[F]] =
     val logger = scribe.cats.effect[F]
-    Ref.of[F, State[F]](State.empty[F]).map { state =>
-      new BackgroundProcessManage[F] {
-        def register(name: String, task: F[Unit]): F[Unit] =
-          state.update(_.put(name, wrapTask(name, task)))
-
-        def startAll: F[Nothing] =
-          state.get
-            .flatMap(s => logger.info(s"Starting ${s.tasks.size} background tasks")) >>
-            background.useForever
-
-        def background: Resource[F, F[Unit]] =
-          for {
-            ts <- Resource.eval(state.get.map(_.getTasks))
-            x <- ts.traverse(t => Async[F].background(t))
-            y = x.traverse_(_.map(_.embed(logger.info(s"Got cancelled"))))
-          } yield y
-
-        def wrapTask(name: String, task: F[Unit]): F[Unit] =
-          def run(c: Ref[F, Long]): F[Unit] =
-            logger.info(s"Starting process for: ${name}") >>
-              task.handleErrorWith { err =>
-                c.updateAndGet(_ + 1).flatMap {
-                  case n if maxRetries.exists(_ <= n) =>
-                    logger.error(
-                      s"Max retries ($maxRetries) for process ${name} exceeded"
-                    ) >> Async[F].raiseError(err)
-                  case n =>
-                    val maxRetriesLabel = maxRetries.map(m => s"/$m").getOrElse("")
-                    logger.error(
-                      s"Starting process for '${name}' failed ($n$maxRetriesLabel), retrying",
-                      err
-                    ) >> Async[F].delayBy(run(c), retryDelay)
+    Supervisor[F](await = false).flatMap { supervisor =>
+      Resource.eval(Ref.of[F, State[F]](State.empty[F])).map { state =>
+        new BackgroundProcessManage[F] {
+          def register(name: TaskName, task: F[Unit]): F[Unit] =
+            state.update(_.put(name, wrapTask(name, task)))
+
+          def startAll: F[Unit] =
+            state.get
+              .flatMap(s => logger.info(s"Starting ${s.tasks.size} background tasks")) >>
+              background(_ => true)
+
+          def currentProcesses: F[Set[TaskName]] =
+            state.get.map(_.processes.keySet)
+
+          def background(taskFilter: TaskName => Boolean): F[Unit] =
+            for {
+              ts <- state.get.map(_.getTasks(taskFilter))
+              _ <- ts.toList
+                .traverse { case (name, task) =>
+                  supervisor.supervise(task).map(t => name -> t)
                }
+                .map(_.toMap)
+                .flatMap(ps => state.update(_.setProcesses(ps)))
+            } yield ()
+
+          /** Stop all tasks by filtering on their
+            registered name. */
+          def cancelProcesses(filter: TaskName => Boolean): F[Unit] =
+            for
+              current <- state.get
+              ps = current.getProcesses(filter)
+              _ <- ps.toList.traverse_ { case (name, p) =>
+                logger.info(s"Cancel background process $name") >> p.cancel >> p.join
+                  .flatMap(out => logger.info(s"Task $name cancelled: $out"))
              }
-          Ref.of[F, Long](0).flatMap(run)
+              _ <- state.update(_.removeProcesses(ps.keySet))
+            yield ()
+
+          private def wrapTask(name: TaskName, task: F[Unit]): F[Unit] =
+            def run(c: Ref[F, Long]): F[Unit] =
+              logger.info(s"Starting process for: ${name}") >>
+                task.handleErrorWith { err =>
+                  c.updateAndGet(_ + 1).flatMap {
+                    case n if maxRetries.exists(_ <= n) =>
+                      logger.error(
+                        s"Max retries ($maxRetries) for process ${name} exceeded"
+                      ) >> Async[F].raiseError(err)
+                    case n =>
+                      val maxRetriesLabel = maxRetries.map(m => s"/$m").getOrElse("")
+                      logger.error(
+                        s"Starting process for '${name}' failed ($n$maxRetriesLabel), retrying",
+                        err
+                      ) >> Async[F].delayBy(run(c), retryDelay)
+                  }
+                }
+            Ref.of[F, Long](0).flatMap(run) >> state.update(_.removeProcesses(Set(name)))
+        }
       }
     }
   }
diff --git a/modules/search-provision/src/main/scala/io/renku/search/provision/MessageHandlers.scala b/modules/search-provision/src/main/scala/io/renku/search/provision/MessageHandlers.scala
index 526322c3..99ce4a72 100644
--- a/modules/search-provision/src/main/scala/io/renku/search/provision/MessageHandlers.scala
+++ b/modules/search-provision/src/main/scala/io/renku/search/provision/MessageHandlers.scala
@@ -24,6 +24,8 @@ import fs2.Stream
 
 import io.renku.redis.client.QueueName
 import io.renku.search.config.QueuesConfig
+import io.renku.search.provision.BackgroundProcessManage.TaskName
+import io.renku.search.provision.MessageHandlers.MessageHandlerKey
 import io.renku.search.provision.handler.*
 
 /** The entry point for defining all message handlers.
@@ -39,48 +41,48 @@ final class MessageHandlers[F[_]: Async](
   assert(maxConflictRetries >= 0, "maxConflictRetries must be >= 0")
 
   private val logger = scribe.cats.effect[F]
-  private var tasks: Map[String, F[Unit]] = Map.empty
-  private def add[A](queue: QueueName, task: Stream[F, A]): Stream[F, Unit] =
-    tasks = tasks.updated(queue.name, task.compile.drain)
+  private var tasks: Map[TaskName, F[Unit]] = Map.empty
+  private def add[A](name: MessageHandlerKey, task: Stream[F, A]): Stream[F, Unit] =
+    tasks = tasks.updated(name, task.compile.drain)
     task.void
 
   private[provision] def withMaxConflictRetries(n: Int): MessageHandlers[F] =
     new MessageHandlers[F](steps, cfg, n)
 
-  def getAll: Map[String, F[Unit]] = tasks
+  def getAll: Map[TaskName, F[Unit]] = tasks
 
   val allEvents = add(
-    cfg.dataServiceAllEvents,
+    MessageHandlerKey.DataServiceAllEvents,
     SyncMessageHandler(steps(cfg.dataServiceAllEvents), maxConflictRetries).create
   )
 
   val projectCreated: Stream[F, Unit] = add(
-    cfg.projectCreated,
+    MessageHandlerKey.ProjectCreated,
     SyncMessageHandler(steps(cfg.projectCreated), maxConflictRetries).create
   )
 
   val projectUpdated: Stream[F, Unit] = add(
-    cfg.projectUpdated,
+    MessageHandlerKey.ProjectUpdated,
     SyncMessageHandler(steps(cfg.projectUpdated), maxConflictRetries).create
   )
 
   val projectRemoved: Stream[F, Unit] = add(
-    cfg.projectRemoved,
+    MessageHandlerKey.ProjectRemoved,
     SyncMessageHandler(steps(cfg.projectRemoved), maxConflictRetries).create
   )
 
   val projectAuthAdded: Stream[F, Unit] = add(
-    cfg.projectAuthorizationAdded,
+    MessageHandlerKey.ProjectAuthorizationAdded,
     SyncMessageHandler(steps(cfg.projectAuthorizationAdded), maxConflictRetries).create
   )
 
   val projectAuthUpdated: Stream[F, Unit] = add(
-    cfg.projectAuthorizationUpdated,
+    MessageHandlerKey.ProjectAuthorizationUpdated,
     SyncMessageHandler(
       steps(cfg.projectAuthorizationUpdated),
       maxConflictRetries
@@ -88,60 +90,85 @@ final class MessageHandlers[F[_]: Async](
     ).create
   )
 
   val projectAuthRemoved: Stream[F, Unit] = add(
-    cfg.projectAuthorizationRemoved,
+    MessageHandlerKey.ProjectAuthorizationRemoved,
     SyncMessageHandler(steps(cfg.projectAuthorizationRemoved), maxConflictRetries).create
   )
 
   val userAdded: Stream[F, Unit] = add(
-    cfg.userAdded,
+    MessageHandlerKey.UserAdded,
     SyncMessageHandler(steps(cfg.userAdded), maxConflictRetries).create
   )
 
   val userUpdated: Stream[F, Unit] = add(
-    cfg.userUpdated,
+    MessageHandlerKey.UserUpdated,
     SyncMessageHandler(steps(cfg.userUpdated), maxConflictRetries).create
   )
 
   val userRemoved: Stream[F, Unit] = add(
-    cfg.userRemoved,
+    MessageHandlerKey.UserRemoved,
     SyncMessageHandler(steps(cfg.userRemoved), maxConflictRetries).create
   )
 
   val groupAdded: Stream[F, Unit] = add(
-    cfg.groupAdded,
+    MessageHandlerKey.GroupAdded,
     SyncMessageHandler(steps(cfg.groupAdded), maxConflictRetries).create
   )
 
   val groupUpdated: Stream[F, Unit] = add(
-    cfg.groupUpdated,
+    MessageHandlerKey.GroupUpdated,
     SyncMessageHandler(steps(cfg.groupUpdated), maxConflictRetries).create
   )
 
   val groupRemove: Stream[F, Unit] = add(
-    cfg.groupRemoved,
+    MessageHandlerKey.GroupRemoved,
     SyncMessageHandler(steps(cfg.groupRemoved), maxConflictRetries).create
   )
 
   val groupMemberAdded: Stream[F, Unit] = add(
-    cfg.groupMemberAdded,
+    MessageHandlerKey.GroupMemberAdded,
     SyncMessageHandler(steps(cfg.groupMemberAdded), maxConflictRetries).create
   )
 
   val groupMemberUpdated: Stream[F, Unit] = add(
-    cfg.groupMemberUpdated,
+    MessageHandlerKey.GroupMemberUpdated,
     SyncMessageHandler(steps(cfg.groupMemberUpdated), maxConflictRetries).create
   )
 
   val groupMemberRemoved: Stream[F, Unit] = add(
-    cfg.groupMemberRemoved,
+    MessageHandlerKey.GroupMemberRemoved,
     SyncMessageHandler(steps(cfg.groupMemberRemoved), maxConflictRetries).create
   )
+
+object MessageHandlers:
+
+  enum MessageHandlerKey extends TaskName:
+    case DataServiceAllEvents
+    case GroupMemberRemoved
+    case GroupMemberUpdated
+    case GroupMemberAdded
+    case GroupRemoved
+    case GroupUpdated
+    case GroupAdded
+    case UserRemoved
+    case UserAdded
+    case UserUpdated
+    case ProjectAuthorizationRemoved
+    case ProjectAuthorizationUpdated
+    case ProjectAuthorizationAdded
+    case ProjectRemoved
+    case ProjectUpdated
+    case ProjectCreated
+
+  object MessageHandlerKey:
+    def isInstance(tn: TaskName): Boolean = tn match
+      case _: MessageHandlerKey => true
+      case _ => false
diff --git a/modules/search-provision/src/main/scala/io/renku/search/provision/Microservice.scala b/modules/search-provision/src/main/scala/io/renku/search/provision/Microservice.scala
index e7609e2b..16e421ec 100644
--- a/modules/search-provision/src/main/scala/io/renku/search/provision/Microservice.scala
+++ b/modules/search-provision/src/main/scala/io/renku/search/provision/Microservice.scala
@@ -26,6 +26,7 @@ import cats.syntax.all.*
 import io.renku.logging.LoggingSetup
 import io.renku.search.http.HttpServer
 import io.renku.search.metrics.CollectorRegistryBuilder
+import io.renku.search.provision.BackgroundProcessManage.TaskName
 import io.renku.search.provision.metrics.*
 import io.renku.search.solr.schema.Migrations
 import io.renku.solr.client.migration.SchemaMigrator
@@ -46,22 +47,23 @@ object Microservice extends IOApp:
       metrics = metricsUpdaterTask(services)
       httpServer = httpServerTask(registryBuilder, services.config)
       tasks = services.messageHandlers.getAll + metrics + httpServer
-      pm <- BackgroundProcessManage[IO](services.config.retryOnErrorDelay)
+      pm = services.backgroundManage
       _ <- tasks.toList.traverse_(pm.register.tupled)
       _ <- pm.startAll
+      _ <- IO.never
     } yield ExitCode.Success
   }
 
   private def httpServerTask(
       registryBuilder: CollectorRegistryBuilder[IO],
       config: SearchProvisionConfig
-  ) =
+  ): (TaskName, IO[Unit]) =
     val io = Routes[IO](registryBuilder)
       .flatMap(HttpServer.build(_, config.httpServerConfig))
       .use(_ => IO.never)
-    "http server" -> io
+    TaskName.fromString("http server") -> io
 
-  private def metricsUpdaterTask(services: Services[IO]) =
+  private def metricsUpdaterTask(services: Services[IO]): (TaskName, IO[Unit]) =
     val updateInterval = services.config.metricsUpdateInterval
     val io =
       if (updateInterval <= Duration.Zero)
@@ -76,7 +78,7 @@ object Microservice extends IOApp:
         services.queueClient,
         services.solrClient
       ).run()
-    "metrics updater" -> io
+    TaskName.fromString("metrics updater") -> io
 
   private def runSolrMigrations(cfg: SearchProvisionConfig): IO[Unit] =
     SchemaMigrator[IO](cfg.solrConfig)
diff --git a/modules/search-provision/src/main/scala/io/renku/search/provision/Services.scala b/modules/search-provision/src/main/scala/io/renku/search/provision/Services.scala
index d10934db..e262b886 100644
--- a/modules/search-provision/src/main/scala/io/renku/search/provision/Services.scala
+++ b/modules/search-provision/src/main/scala/io/renku/search/provision/Services.scala
@@ -30,7 +30,8 @@ final case class Services[F[_]](
     config: SearchProvisionConfig,
     solrClient: SearchSolrClient[F],
     queueClient: Stream[F, QueueClient[F]],
-    messageHandlers: MessageHandlers[F]
+    messageHandlers: MessageHandlers[F],
+    backgroundManage: BackgroundProcessManage[F]
 )
 
 object Services:
@@ -49,4 +50,6 @@ object Services:
         inChunkSize = 1
       )
     handlers = MessageHandlers[F](steps, cfg.queuesConfig)
-  } yield Services(cfg, solr, redis, handlers)
+
+    bm <- BackgroundProcessManage[F](cfg.retryOnErrorDelay)
+  } yield Services(cfg, solr, redis, handlers, bm)
diff --git a/modules/search-provision/src/test/scala/io/renku/search/provision/BackgroundProcessManageSpec.scala b/modules/search-provision/src/test/scala/io/renku/search/provision/BackgroundProcessManageSpec.scala
new file mode 100644
index 00000000..a1afe67a
--- /dev/null
+++ b/modules/search-provision/src/test/scala/io/renku/search/provision/BackgroundProcessManageSpec.scala
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2024 Swiss Data Science Center (SDSC)
+ * A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
+ * Eidgenössische Technische Hochschule Zürich (ETHZ).
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.renku.search.provision
+
+import scala.concurrent.duration.*
+
+import cats.effect.*
+import fs2.Stream
+
+import io.renku.search.LoggingConfigure
+import io.renku.search.provision.BackgroundProcessManageSpec.Key
+import munit.CatsEffectSuite
+
+class BackgroundProcessManageSpec extends CatsEffectSuite with LoggingConfigure:
+
+  def makeInfiniteTask(effect: IO[Unit], pause: FiniteDuration = 20.millis) = Stream
+    .repeatEval(effect)
+    .interleave(Stream.sleep[IO](pause).repeat)
+    .compile
+    .drain
+
+  val manager = BackgroundProcessManage[IO](10.millis)
+
+  test("register tasks, start and cancel"):
+    manager.use { m =>
+      for
+        counter <- Ref[IO].of(0)
+        task = makeInfiniteTask(counter.update(_ + 1))
+        _ <- m.register(Key.Count, task)
+        _ <- m.startAll
+        ps <- m.currentProcesses
+        _ = assertEquals(ps, Set(Key.Count.widen))
+
+        _ <- IO.sleep(50.millis)
+        _ <- m.cancelProcesses(_ => true)
+        ps2 <- m.currentProcesses
+        _ = assert(ps2.isEmpty)
+        _ <- IO.sleep(50.millis)
+        n <- counter.get
+        _ = assert(n >= 1 && n <= 3, s"$n is not between 1 and 3 (inclusive)")
+      yield ()
+    }
+
+  test("remove process when done"):
+    manager.use { m =>
+      for
+        _ <- m.register(Key.Nothing, IO.unit)
+        _ <- m.startAll
+        _ <- IO.sleep(10.millis)
+        ps <- m.currentProcesses
+        _ = assert(ps.isEmpty)
+      yield ()
+    }
+
+  test("restart on error"):
+    manager.use { m =>
+      for
+        counter <- Ref[IO].of(0)
+        task = makeInfiniteTask(counter.updateAndGet(_ + 1).flatMap { n =>
+          if (n % 3 == 0) IO.raiseError(new Exception("boom")) else IO.unit
+        })
+        _ <- m.register(Key.Errors, task)
+        _ <- m.startAll
+        _ <- IO.sleep(100.millis)
+        _ <- m.cancelProcesses(_ => true)
+        n <- counter.get
+        _ = assert(n > 3, s"$n is not greater than 3")
+      yield ()
+    }
+
+object BackgroundProcessManageSpec:
+
+  enum Key extends BackgroundProcessManage.TaskName:
+    case Count
+    case Errors
+    case Nothing
+
+    def widen: BackgroundProcessManage.TaskName = this
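
Usage sketch (not part of the patch to apply): a minimal example of how the extended manager is meant to be driven, based only on the API introduced above. The object name BackgroundProcessManageExample and the task name "tick" are hypothetical, chosen for illustration; BackgroundProcessManage[IO](...) returning a Resource, register, startAll, currentProcesses and cancelProcesses are taken from the code in this patch.

    import scala.concurrent.duration.*

    import cats.effect.*

    import io.renku.search.provision.BackgroundProcessManage
    import io.renku.search.provision.BackgroundProcessManage.TaskName

    object BackgroundProcessManageExample extends IOApp.Simple:

      // hypothetical task name, used only in this sketch
      val tick: TaskName = TaskName.fromString("tick")

      def run: IO[Unit] =
        // the manager is a Resource; releasing it shuts down the underlying Supervisor
        BackgroundProcessManage[IO](retryDelay = 100.millis).use { pm =>
          for
            // register a long-running task; a failing task is restarted after retryDelay
            _ <- pm.register(tick, (IO.println("tick") >> IO.sleep(1.second)).foreverM)
            _ <- pm.startAll
            running <- pm.currentProcesses // Set(tick) while the fiber is alive
            _ <- IO.println(s"running: $running")
            _ <- IO.sleep(3.seconds)
            _ <- pm.cancelProcesses(_ == tick) // stop tasks selected by their name
          yield ()
        }

Because tasks are supervised and tracked by name, a caller can cancel a subset with cancelProcesses(filter) and later start it again with background(filter), which is the restart capability this change is after.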