From 15bd825c1b55d0ff404dc08c2f1253a685f2e913 Mon Sep 17 00:00:00 2001 From: oSumAtrIX Date: Fri, 1 Nov 2024 02:46:25 +0100 Subject: [PATCH] finish --- .../repository/AnnouncementRepository.kt | 49 ++++++++++++------- .../api/configuration/routes/Announcements.kt | 21 +++++--- .../services/AnnouncementService.kt | 3 +- .../services/AnnouncementServiceTest.kt | 34 ++++++++----- 4 files changed, 71 insertions(+), 36 deletions(-) diff --git a/src/main/kotlin/app/revanced/api/configuration/repository/AnnouncementRepository.kt b/src/main/kotlin/app/revanced/api/configuration/repository/AnnouncementRepository.kt index e570004..3f6ff6f 100644 --- a/src/main/kotlin/app/revanced/api/configuration/repository/AnnouncementRepository.kt +++ b/src/main/kotlin/app/revanced/api/configuration/repository/AnnouncementRepository.kt @@ -12,10 +12,10 @@ import org.jetbrains.exposed.dao.IntEntityClass import org.jetbrains.exposed.dao.id.EntityID import org.jetbrains.exposed.dao.id.IntIdTable import org.jetbrains.exposed.sql.* -import org.jetbrains.exposed.sql.SqlExpressionBuilder.inList import org.jetbrains.exposed.sql.kotlin.datetime.CurrentDateTime import org.jetbrains.exposed.sql.kotlin.datetime.datetime import org.jetbrains.exposed.sql.transactions.experimental.newSuspendedTransaction +import java.time.LocalDateTime internal class AnnouncementRepository { // This is better than doing a maxByOrNull { it.id } on every request. @@ -73,20 +73,35 @@ internal class AnnouncementRepository { fun latestId(tags: Set) = tags.map { tag -> latestAnnouncementByTag[tag]?.id?.value }.toApiResponseAnnouncementId() - suspend fun paged(offset: Int, count: Int, tags: Set?) = transaction { - if (tags == null) { - Announcement.all() - } else { - @Suppress("NAME_SHADOWING") - val tags = tags.mapNotNull { Tag.findById(it)?.id } - - Announcement.find { - Announcements.id inSubQuery Announcements.innerJoin(AnnouncementTags) - .select(Announcements.id) - .where { AnnouncementTags.tag inList tags } - .withDistinct() + suspend fun paged(cursor: Int, count: Int, tags: Set?, archived: Boolean) = transaction { + Announcement.find { + fun idLessEq() = Announcements.id lessEq cursor + fun archivedAtIsNull() = Announcements.archivedAt.isNull() + fun archivedAtGreaterNow() = Announcements.archivedAt greater LocalDateTime.now().toKotlinLocalDateTime() + + if (tags == null) { + if (archived) { + idLessEq() + } else { + idLessEq() and (archivedAtIsNull() or archivedAtGreaterNow()) + } + } else { + fun archivedAtGreaterOrNullOrTrue() = if (archived) { + Op.TRUE + } else { + archivedAtIsNull() or archivedAtGreaterNow() + } + + fun hasTags() = tags.mapNotNull { Tag.findById(it)?.id }.let { tags -> + Announcements.id inSubQuery Announcements.leftJoin(AnnouncementTags) + .select(AnnouncementTags.announcement) + .where { AnnouncementTags.tag inList tags } + .withDistinct() + } + + idLessEq() and archivedAtGreaterOrNullOrTrue() and hasTags() } - }.orderBy(Announcements.id to SortOrder.DESC).limit(count, offset.toLong()).map { it }.toApiAnnouncement() + }.orderBy(Announcements.id to SortOrder.DESC).limit(count).toApiAnnouncement() } suspend fun get(id: Int) = transaction { @@ -242,11 +257,11 @@ internal class AnnouncementRepository { ) } - private fun List.toApiAnnouncement() = map { it.toApiResponseAnnouncement()!! } + private fun Iterable.toApiAnnouncement() = map { it.toApiResponseAnnouncement()!! } - private fun List.toApiTag() = map { ApiAnnouncementTag(it.id.value, it.name) } + private fun Iterable.toApiTag() = map { ApiAnnouncementTag(it.id.value, it.name) } private fun Int?.toApiResponseAnnouncementId() = this?.let { ApiResponseAnnouncementId(this) } - private fun List.toApiResponseAnnouncementId() = map { it.toApiResponseAnnouncementId() } + private fun Iterable.toApiResponseAnnouncementId() = map { it.toApiResponseAnnouncementId() } } diff --git a/src/main/kotlin/app/revanced/api/configuration/routes/Announcements.kt b/src/main/kotlin/app/revanced/api/configuration/routes/Announcements.kt index 6c75731..aeb1c0f 100644 --- a/src/main/kotlin/app/revanced/api/configuration/routes/Announcements.kt +++ b/src/main/kotlin/app/revanced/api/configuration/routes/Announcements.kt @@ -15,7 +15,6 @@ import io.ktor.http.* import io.ktor.server.application.* import io.ktor.server.auth.* import io.ktor.server.plugins.ratelimit.* -import io.ktor.server.request.* import io.ktor.server.response.* import io.ktor.server.routing.* import io.ktor.server.util.* @@ -31,11 +30,12 @@ internal fun Route.announcementsRoute() = route("announcements") { rateLimit(RateLimitName("strong")) { get { - val offset = call.parameters["offset"]?.toInt() ?: 0 + val cursor = call.parameters["cursor"]?.toInt() ?: Int.MAX_VALUE val count = call.parameters["count"]?.toInt() ?: 16 val tags = call.parameters.getAll("tag") + val archived = call.parameters["archived"]?.toBoolean() ?: true - call.respond(announcementService.paged(offset, count, tags?.map { it.toInt() }?.toSet())) + call.respond(announcementService.paged(cursor, count, tags?.map { it.toInt() }?.toSet(), archived)) } } @@ -130,24 +130,31 @@ private fun Route.installAnnouncementsRouteDocumentation() = installNotarizedRou summary("Get announcements") parameters( Parameter( - name = "offset", + name = "cursor", `in` = Parameter.Location.query, schema = TypeDefinition.INT, - description = "The offset of the announcements", + description = "The offset of the announcements. Default is Int.MAX_VALUE (Newest first)", required = false, ), Parameter( name = "count", `in` = Parameter.Location.query, schema = TypeDefinition.INT, - description = "The count of the announcements", + description = "The count of the announcements. Default is 16", required = false, ), Parameter( name = "tag", `in` = Parameter.Location.query, schema = TypeDefinition.INT, - description = "The tag IDs to filter the announcements by", + description = "The tag IDs to filter the announcements by. Default is all tags", + required = false, + ), + Parameter( + name = "archived", + `in` = Parameter.Location.query, + schema = TypeDefinition.BOOLEAN, + description = "Whether to include archived announcements. Default is true", required = false, ), ) diff --git a/src/main/kotlin/app/revanced/api/configuration/services/AnnouncementService.kt b/src/main/kotlin/app/revanced/api/configuration/services/AnnouncementService.kt index 91bf890..434f0a5 100644 --- a/src/main/kotlin/app/revanced/api/configuration/services/AnnouncementService.kt +++ b/src/main/kotlin/app/revanced/api/configuration/services/AnnouncementService.kt @@ -14,7 +14,8 @@ internal class AnnouncementService( fun latestId() = announcementRepository.latestId() - suspend fun paged(offset: Int, limit: Int, tags: Set?) = announcementRepository.paged(offset, limit, tags) + suspend fun paged(cursor: Int, limit: Int, tags: Set?, archived: Boolean) = + announcementRepository.paged(cursor, limit, tags, archived) suspend fun get(id: Int) = announcementRepository.get(id) diff --git a/src/test/kotlin/app/revanced/api/configuration/services/AnnouncementServiceTest.kt b/src/test/kotlin/app/revanced/api/configuration/services/AnnouncementServiceTest.kt index 3c9ffcf..4b6e058 100644 --- a/src/test/kotlin/app/revanced/api/configuration/services/AnnouncementServiceTest.kt +++ b/src/test/kotlin/app/revanced/api/configuration/services/AnnouncementServiceTest.kt @@ -158,25 +158,37 @@ private object AnnouncementServiceTest { announcementService.new(ApiAnnouncement(title = "title$it")) } - val announcements = announcementService.paged(0, 5, null) - assertEquals(5, announcements.size) - assertEquals("title9", announcements.first().title) + val announcements = announcementService.paged(Int.MAX_VALUE, 5, null, true) + assertEquals(5, announcements.size, "Returns correct number of announcements") + assertEquals("title9", announcements.first().title, "Starts from the latest announcement") - val announcements2 = announcementService.paged(5, 5, null) - assertEquals(5, announcements2.size) - assertEquals("title4", announcements2.first().title) + val announcements2 = announcementService.paged(5, 5, null, true) + assertEquals(5, announcements2.size, "Returns correct number of announcements when starting from the cursor") + assertEquals("title4", announcements2.first().title, "Starts from the cursor") - announcements2.map { it.id }.forEach { id -> + (0..4).forEach { id -> announcementService.update( id, - ApiAnnouncement(title = "title$id", tags = (1..id).map { "tag$it" }), + ApiAnnouncement( + title = "title$id", + tags = (0..id).map { "tag$it" }, + archivedAt = if (id % 2 == 0) { + // Only two announcements will be archived. + LocalDateTime.now().plusDays(2).minusDays(id.toLong()).toKotlinLocalDateTime() + } else { + null + }, + ), ) } val tags = announcementService.tags() - assertEquals(5, tags.size) + assertEquals(5, tags.size, "Returns correct number of newly created tags") - val announcements3 = announcementService.paged(0, 5, setOf(tags.first().id)) - assertEquals(5, announcements3.size) + val announcements3 = announcementService.paged(5, 5, setOf(tags[1].id), true) + assertEquals(4, announcements3.size, "Filters announcements by tag") + + val announcements4 = announcementService.paged(Int.MAX_VALUE, 10, null, false) + assertEquals(8, announcements4.size, "Filters out archived announcements") } }