diff --git a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/vector/VectorDescriptorReader.kt b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/vector/VectorDescriptorReader.kt index 79372c7c..93edef10 100644 --- a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/vector/VectorDescriptorReader.kt +++ b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/vector/VectorDescriptorReader.kt @@ -115,30 +115,26 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con * @param query The [ProximityPredicate] to execute. * @return [Sequence] of [VectorDescriptor]s. */ - private fun queryAndJoinProximity(query: ProximityPredicate<*>): Sequence { - val descriptors = mutableListOf, Float>>() - val statement = - "SELECT $DESCRIPTOR_ID_COLUMN_NAME, $RETRIEVABLE_ID_COLUMN_NAME, $VECTOR_ATTRIBUTE_NAME, $VECTOR_ATTRIBUTE_NAME ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME FROM \"${tableName.lowercase()}\" ORDER BY $DISTANCE_COLUMN_NAME ${query.order} LIMIT ${query.k}" + private fun queryAndJoinProximity(query: ProximityPredicate<*>): Sequence = sequence { + val cteTable = "\"${this@VectorDescriptorReader.tableName}_nns\"" + val statement = "WITH $cteTable AS (" + + "SELECT \"$DESCRIPTOR_ID_COLUMN_NAME\",\"$RETRIEVABLE_ID_COLUMN_NAME\",\"$VECTOR_ATTRIBUTE_NAME\",\"$VECTOR_ATTRIBUTE_NAME\" ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME " + + "FROM \"${tableName.lowercase()}\" ORDER BY \"$DISTANCE_COLUMN_NAME\" ${query.order} " + + "LIMIT ${query.k}" + + ") SELECT $cteTable.\"$DESCRIPTOR_ID_COLUMN_NAME\",$cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\",$cteTable.\"$VECTOR_ATTRIBUTE_NAME\",$cteTable.\"$DISTANCE_COLUMN_NAME\",\"$RETRIEVABLE_TYPE_COLUMN_NAME\" " + + "FROM $cteTable INNER JOIN \"$RETRIEVABLE_ENTITY_NAME\" ON (\"$RETRIEVABLE_ENTITY_NAME\".\"$RETRIEVABLE_ID_COLUMN_NAME\" = $cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\")" + + "ORDER BY $cteTable.\"$DISTANCE_COLUMN_NAME\" ${query.order}" + this@VectorDescriptorReader.connection.jdbc.prepareStatement(statement).use { stmt -> stmt.setValue(1, query.value) stmt.executeQuery().use { result -> while (result.next()) { - descriptors.add(this@VectorDescriptorReader.rowToDescriptor(result) to result.getFloat(DISTANCE_COLUMN_NAME)) - } - } - - /* Fetch retrievable ids. */ - val retrievables = this.connection.getRetrievableReader().getAll(descriptors.mapNotNull { it.first.retrievableId }.toSet()).map { it.id to it }.toMap() - return descriptors.asSequence().mapNotNull { (descriptor, distance) -> - val retrievable = retrievables[descriptor.retrievableId] - if (retrievable != null) { + val retrieved = Retrieved(result.getObject(RETRIEVABLE_ID_COLUMN_NAME, UUID::class.java), result.getString(RETRIEVABLE_TYPE_COLUMN_NAME), true) + retrieved.addAttribute(DistanceAttribute(result.getFloat(DISTANCE_COLUMN_NAME))) if (query.fetchVector) { - retrievable.addDescriptor(descriptor) + retrieved.addDescriptor(rowToDescriptor(result)) } - retrievable.addAttribute(DistanceAttribute(distance)) - retrievable as Retrieved - } else { - null + yield(retrieved) } } }