Skip to content

Commit

Permalink
Introduces UnitTests for AbstractDescriptorReaderTest.getAll() and fi…
Browse files Browse the repository at this point in the history
…xes the two methods for PGVector implementation.
  • Loading branch information
Ralph Gasser committed Aug 14, 2024
1 parent 4cfccbc commit 6142a6a
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ data class SchemaConfig(
) {

companion object {

/**
* Tries to load a [SchemaConfig] from the resources.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ abstract class AbstractDatabaseTest(schemaPath: String) {
init {
/* Loads schema. */
val schema = SchemaConfig.loadFromResource(schemaPath)
schema.name = SCHEMA_NAME
this.manager.load(schema)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,41 @@ abstract class AbstractFloatVectorDescriptorReaderTest(schemaPath: String) : Abs
/** The [Schema.Field] used for this [DescriptorInitializerTest]. */
private val field: Schema.Field<*, FloatVectorDescriptor> = this.testSchema["averagecolor"]!! as Schema.Field<*, FloatVectorDescriptor>

/**
* Tests [VectorDescriptorReader.getAll] method.
*/
@Test
fun testReadAll() {
val writer = this.testConnection.getDescriptorWriter(this.field)
val reader = this.testConnection.getDescriptorReader(this.field)
val random = SplittableRandom()

/* Generate and store test data. */
val descriptors = this.initialize(writer, random)
reader.getAll().forEachIndexed { index, floatVectorDescriptor ->
Assertions.assertEquals(descriptors[index].id, floatVectorDescriptor.id)
Assertions.assertEquals(descriptors[index].retrievableId, floatVectorDescriptor.retrievableId)
Assertions.assertArrayEquals(descriptors[index].vector.value, floatVectorDescriptor.vector.value)
}
}

/**
* Tests [VectorDescriptorReader.getAll] (with parameters) method.
*/
@Test
fun testGetAll() {
val writer = this.testConnection.getDescriptorWriter(this.field)
val reader = this.testConnection.getDescriptorReader(this.field)
val random = SplittableRandom()

/* Generate and store test data. */
val descriptors = this.initialize(writer, random)
val selection = descriptors.shuffled().take(100).map { it.id }
reader.getAll(selection).forEach{ floatVectorDescriptor ->
Assertions.assertTrue(selection.contains(floatVectorDescriptor.id))
}
}

/**
* Tests [VectorDescriptorReader.queryAndJoin] method.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,18 @@ abstract class AbstractDescriptorReader<D : Descriptor>(final override val field
* @param retrievableId The [RetrievableId] to search for.
* @return [Sequence] of [Descriptor] of type [D]
*/
override fun getForRetrievable(retrievableId: RetrievableId): Sequence<D> {
override fun getForRetrievable(retrievableId: RetrievableId): Sequence<D> = sequence {
try {
this.connection.jdbc.prepareStatement("SELECT * FROM $tableName WHERE $RETRIEVABLE_ID_COLUMN_NAME = ?").use { stmt ->
this@AbstractDescriptorReader.connection.jdbc.prepareStatement("SELECT * FROM $tableName WHERE $RETRIEVABLE_ID_COLUMN_NAME = ?").use { stmt ->
stmt.setObject(1, retrievableId)
val result = stmt.executeQuery()
return sequence {
result.use {
while (result.next()) {
yield(this@AbstractDescriptorReader.rowToDescriptor(result))
}
val result = stmt.executeQuery().use { result ->
while (result.next()) {
yield(this@AbstractDescriptorReader.rowToDescriptor(result))
}
}
}
} catch (e: Exception) {
LOGGER.error(e) { "Failed to fetch descriptor for retrievable $retrievableId from '$tableName' due to SQL error." }
return emptySequence()
}
}

Expand Down Expand Up @@ -101,21 +97,17 @@ abstract class AbstractDescriptorReader<D : Descriptor>(final override val field
*
* @return [Sequence] of all [Descriptor]s.
*/
override fun getAll(): Sequence<D> {
override fun getAll(): Sequence<D> = sequence {
try {
this.connection.jdbc.prepareStatement("SELECT * FROM $tableName").use { stmt ->
val result = stmt.executeQuery()
return sequence {
result.use {
while (result.next()) {
yield(this@AbstractDescriptorReader.rowToDescriptor(result))
}
this@AbstractDescriptorReader.connection.jdbc.prepareStatement("SELECT * FROM $tableName").use { stmt ->
stmt.executeQuery().use { result ->
while (result.next()) {
yield(this@AbstractDescriptorReader.rowToDescriptor(result))
}
}
}
} catch (e: Exception) {
} catch (e: Throwable) {
LOGGER.error(e) { "Failed to fetch descriptors from '$tableName' due to SQL error." }
return emptySequence()
}
}

Expand All @@ -125,23 +117,19 @@ abstract class AbstractDescriptorReader<D : Descriptor>(final override val field
* @param descriptorIds A [Iterable] of [DescriptorId]s to return.
* @return [Sequence] of [Descriptor] of type [D]
*/
override fun getAll(descriptorIds: Iterable<DescriptorId>): Sequence<D> {
override fun getAll(descriptorIds: Iterable<DescriptorId>): Sequence<D> = sequence {
try {
this.connection.jdbc.prepareStatement("SELECT * FROM $tableName WHERE $DESCRIPTOR_ID_COLUMN_NAME = ANY (?)").use { stmt ->
this@AbstractDescriptorReader.connection.jdbc.prepareStatement("SELECT * FROM $tableName WHERE $DESCRIPTOR_ID_COLUMN_NAME = ANY (?)").use { stmt ->
val values = descriptorIds.map { it }.toTypedArray()
stmt.setArray(1, this.connection.jdbc.createArrayOf("uuid", values))
val result = stmt.executeQuery()
return sequence {
result.use {
while (result.next()) {
yield(this@AbstractDescriptorReader.rowToDescriptor(result))
}
stmt.setArray(1, this@AbstractDescriptorReader.connection.jdbc.createArrayOf("uuid", values))
stmt.executeQuery().use { result ->
while (result.next()) {
yield(this@AbstractDescriptorReader.rowToDescriptor(result))
}
}
}
} catch (e: Exception) {
LOGGER.error(e) { "Failed to fetch descriptors from '$tableName' due to SQL error." }
return emptySequence()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,17 @@ class RetrievableReader(override val connection: PgVectorConnection): Retrievabl
*
* @return A [Sequence] of all [Retrievable]s in the database.
*/
override fun getAll(): Sequence<Retrievable> {
override fun getAll(): Sequence<Retrievable> = sequence {
try {
this.connection.jdbc.prepareStatement("SELECT * FROM $RETRIEVABLE_ENTITY_NAME").use { stmt ->
val result = stmt.executeQuery()
return sequence {
this@RetrievableReader.connection.jdbc.prepareStatement("SELECT * FROM $RETRIEVABLE_ENTITY_NAME").use { stmt ->
stmt.executeQuery().use { result ->
while (result.next()) {
yield(Retrieved(result.getObject(RETRIEVABLE_ID_COLUMN_NAME, UUID::class.java), result.getString(RETRIEVABLE_TYPE_COLUMN_NAME), false))
}
result.close()
}
}
} catch (e: Exception) {
LOGGER.error(e) { "Failed to check for retrievables due to SQL error." }
return emptySequence()
}
}

Expand Down

0 comments on commit 6142a6a

Please sign in to comment.