diff --git a/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/query/bool/Logical.kt b/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/query/bool/Logical.kt index 562d107c..0e7fcd1e 100644 --- a/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/query/bool/Logical.kt +++ b/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/query/bool/Logical.kt @@ -6,8 +6,6 @@ package org.vitrivr.engine.core.model.query.bool * @version 1.0 */ sealed interface Logical : BooleanPredicate { - data class Not(val predicate: BooleanPredicate) : Logical - data class And(val predicates: List) : Logical data class Or(val predicates: List) : Logical diff --git a/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/query/fulltext/SimpleFulltextPredicate.kt b/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/query/fulltext/SimpleFulltextPredicate.kt index 3eece392..48a9a8b8 100644 --- a/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/query/fulltext/SimpleFulltextPredicate.kt +++ b/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/query/fulltext/SimpleFulltextPredicate.kt @@ -2,6 +2,7 @@ package org.vitrivr.engine.core.model.query.fulltext import org.vitrivr.engine.core.model.metamodel.Schema import org.vitrivr.engine.core.model.query.Predicate +import org.vitrivr.engine.core.model.query.bool.BooleanPredicate import org.vitrivr.engine.core.model.types.Value @@ -23,5 +24,8 @@ data class SimpleFulltextPredicate( * * Typically, this is pre-determined by the analyser. However, in some cases, this must be specified (e.g., when querying struct fields). */ - val attributeName: String? = null + val attributeName: String? = null, + + /** Optional filter query for this [SimpleFulltextPredicate]. */ + val filter: BooleanPredicate? 
= null ) : Predicate diff --git a/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/query/proximity/ProximityPredicate.kt b/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/query/proximity/ProximityPredicate.kt index 495fccd0..f900ef37 100644 --- a/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/query/proximity/ProximityPredicate.kt +++ b/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/query/proximity/ProximityPredicate.kt @@ -5,6 +5,8 @@ import org.vitrivr.engine.core.model.metamodel.Schema import org.vitrivr.engine.core.model.query.Predicate import org.vitrivr.engine.core.model.query.basics.Distance import org.vitrivr.engine.core.model.query.basics.SortOrder +import org.vitrivr.engine.core.model.query.bool.BooleanPredicate +import org.vitrivr.engine.core.model.query.fulltext.SimpleFulltextPredicate import org.vitrivr.engine.core.model.types.Value /** @@ -17,7 +19,6 @@ import org.vitrivr.engine.core.model.types.Value * @author Ralph Gasser * @version 1.1.0 */ - data class ProximityPredicate>( /** The [Schema.Field] that this [Predicate] is applied to. */ val field: Schema.Field<*, *>, @@ -42,5 +43,8 @@ data class ProximityPredicate>( * * Typically, this is pre-determined by the analyser. However, in some cases, this must be specified (e.g., when querying struct fields). */ - val attributeName: String? = null + val attributeName: String? = null, + + /** Optional filter query for this [ProximityPredicate]. */ + val filter: BooleanPredicate? 
= null ) : Predicate \ No newline at end of file diff --git a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/Utilities.kt b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/Utilities.kt index 1dd7bd4c..49dc7044 100644 --- a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/Utilities.kt +++ b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/Utilities.kt @@ -96,7 +96,6 @@ internal fun BooleanPredicate.toWhere(): String = when (this) { is Comparison<*> -> this.toTerm() is Logical.And -> this.predicates.joinToString(" AND ", "(", ")") { it.toWhere() } is Logical.Or -> this.predicates.joinToString(" OR ", "(", ")") { it.toWhere() } - is Logical.Not -> "NOT (${this.predicate.toWhere()})" } /** diff --git a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/AbstractDescriptorReader.kt b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/AbstractDescriptorReader.kt index 06fa03eb..8f018806 100644 --- a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/AbstractDescriptorReader.kt +++ b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/AbstractDescriptorReader.kt @@ -6,9 +6,14 @@ import org.vitrivr.engine.core.model.descriptor.DescriptorId import org.vitrivr.engine.core.model.metamodel.Schema import org.vitrivr.engine.core.model.query.Predicate import org.vitrivr.engine.core.model.query.Query +import org.vitrivr.engine.core.model.query.bool.BooleanPredicate +import org.vitrivr.engine.core.model.query.bool.Comparison +import org.vitrivr.engine.core.model.query.bool.Logical import org.vitrivr.engine.core.model.retrievable.RetrievableId import org.vitrivr.engine.core.model.retrievable.Retrieved import org.vitrivr.engine.database.pgvector.* +import 
org.vitrivr.engine.database.pgvector.descriptor.scalar.ScalarDescriptorReader +import org.vitrivr.engine.database.pgvector.descriptor.struct.StructDescriptorReader import java.sql.ResultSet import java.util.* @@ -200,6 +205,44 @@ abstract class AbstractDescriptorReader>(final override val fi } } + /** + * Resolves a complex [BooleanPredicate] into a set of [RetrievableId]s that match it. + * + * A [Comparison] is delegated to the reader of its field; [Logical.And] yields the intersection and [Logical.Or] the union of the sets its children resolve to. + * + * @param predicate [BooleanPredicate] to resolve. + * @return Set of [RetrievableId]s that match the [BooleanPredicate]. + * @throws IllegalArgumentException If a [Comparison] targets a field whose reader is neither a [ScalarDescriptorReader] nor a [StructDescriptorReader]. + */ + internal fun resolveBooleanPredicate(predicate: BooleanPredicate): Set = when (predicate) { + is Comparison<*> -> { + val field = predicate.field + val reader = field.getReader() + when (reader) { + /* Collect the IDs of the owning retrievables rather than the descriptors' own IDs; the result is matched against the retrievable ID column. */ + is ScalarDescriptorReader -> reader.queryComparison(predicate).mapNotNull { it.retrievableId }.toSet() + is StructDescriptorReader -> reader.queryComparison(predicate).mapNotNull { it.retrievableId }.toSet() + else -> throw IllegalArgumentException("Cannot resolve predicate $predicate.") + } + } + + is Logical.And -> { + val intersection = mutableSetOf() + for ((index, child) in predicate.predicates.withIndex()) { + if (index == 0) { + intersection.addAll(resolveBooleanPredicate(child)) + } else { + /* retainAll() narrows the set in place; intersect() would return a new set and leave this one untouched. */ + intersection.retainAll(resolveBooleanPredicate(child)) + } + } + intersection + } + + is Logical.Or -> { + val union = mutableSetOf() + for (child in predicate.predicates) { + union.addAll(resolveBooleanPredicate(child)) + } + union + } + } + /** * Converts a [ResultSet] to a [Descriptor] of type [D]. 
* diff --git a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/scalar/ScalarDescriptorReader.kt b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/scalar/ScalarDescriptorReader.kt index 61ab62be..f86c5bde 100644 --- a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/scalar/ScalarDescriptorReader.kt +++ b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/scalar/ScalarDescriptorReader.kt @@ -31,7 +31,7 @@ class ScalarDescriptorReader(field: Schema.Field<*, ScalarDescriptor<*, *>>, con */ override fun query(query: Query): Sequence> = when (val predicate = query.predicate) { is SimpleFulltextPredicate -> queryFulltext(predicate) - is Comparison<*> -> queryBoolean(predicate) + is Comparison<*> -> queryComparison(predicate) else -> throw IllegalArgumentException("Query of type ${query::class} is not supported by ScalarDescriptorReader.") } @@ -84,7 +84,7 @@ class ScalarDescriptorReader(field: Schema.Field<*, ScalarDescriptor<*, *>>, con * @param query The [Comparison] to execute. * @return [Sequence] of [ScalarDescriptor]s. 
*/ - private fun queryBoolean(query: Comparison<*>): Sequence> { + internal fun queryComparison(query: Comparison<*>): Sequence> { val statement = "SELECT * FROM \"$tableName\" WHERE ${query.toWhere()}" return sequence { this@ScalarDescriptorReader.connection.jdbc.prepareStatement(statement).use { stmt -> diff --git a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/struct/StructDescriptorReader.kt b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/struct/StructDescriptorReader.kt index bd761e12..4de4d782 100644 --- a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/struct/StructDescriptorReader.kt +++ b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/struct/StructDescriptorReader.kt @@ -32,7 +32,7 @@ class StructDescriptorReader(field: Schema.Field<*, StructDescriptor<*>>, connec */ override fun query(query: Query): Sequence> = when (val predicate = query.predicate) { is SimpleFulltextPredicate -> queryFulltext(predicate) - is Comparison<*> -> queryBoolean(predicate) + is Comparison<*> -> queryComparison(predicate) else -> throw IllegalArgumentException("Query of typ ${query::class} is not supported by StructDescriptorReader.") } @@ -106,7 +106,7 @@ class StructDescriptorReader(field: Schema.Field<*, StructDescriptor<*>>, connec * @param query The [Comparison] to execute. * @return [Sequence] of [StructDescriptor]s. */ - private fun queryBoolean(query: Comparison<*>): Sequence> { + internal fun queryComparison(query: Comparison<*>): Sequence> { require(query.attributeName != null) { "Query attribute must not be null for a fulltext query on a struct descriptor." 
} val statement = "SELECT * FROM \"${tableName.lowercase()}\" WHERE ${query.toWhere()}" return sequence { diff --git a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/vector/VectorDescriptorReader.kt b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/vector/VectorDescriptorReader.kt index 93edef10..0c3c41b2 100644 --- a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/vector/VectorDescriptorReader.kt +++ b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/vector/VectorDescriptorReader.kt @@ -13,6 +13,7 @@ import org.vitrivr.engine.database.pgvector.* import org.vitrivr.engine.database.pgvector.descriptor.AbstractDescriptorReader import org.vitrivr.engine.database.pgvector.descriptor.model.PgBitVector import org.vitrivr.engine.database.pgvector.descriptor.model.PgVector +import java.sql.PreparedStatement import java.sql.ResultSet import java.util.* @@ -20,7 +21,7 @@ import java.util.* * An abstract implementation of a [DescriptorReader] for Cottontail DB. * * @author Ralph Gasser - * @version 1.0.0 + * @version 1.1.0 */ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, connection: PgVectorConnection) : AbstractDescriptorReader>(field, connection) { /** @@ -28,9 +29,18 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con * * @param query The [Query] to execute. 
*/ - override fun query(query: Query): Sequence> = when (val predicate = query.predicate) { - is ProximityPredicate<*> -> queryProximity(predicate) - else -> throw UnsupportedOperationException("Query of typ ${query::class} is not supported by VectorDescriptorReader.") + override fun query(query: Query): Sequence> = when (val predicate = query.predicate) { + /* Dispatch eagerly so that an unsupported predicate is rejected when query() is called; only the database access is deferred to iteration. */ + is ProximityPredicate<*> -> sequence { + prepareProximity(predicate).use { stmt -> + stmt.executeQuery().use { result -> + while (result.next()) { + yield(rowToDescriptor(result)) + } + } + } + } + + else -> throw UnsupportedOperationException("Query of typ ${query::class} is not supported by VectorDescriptorReader.") } /** @@ -41,9 +51,23 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con * @param query The [Query] that should be executed. * @return [Sequence] of [Retrieved]. */ - override fun queryAndJoin(query: Query): Sequence = when (val predicate = query.predicate) { - is ProximityPredicate<*> -> queryAndJoinProximity(predicate) - else -> super.queryAndJoin(query) + override fun queryAndJoin(query: Query): Sequence = when (val predicate = query.predicate) { + /* Same eager dispatch as in query(): the statement is prepared, executed and closed lazily, the predicate type is checked immediately. */ + is ProximityPredicate<*> -> sequence { + queryAndJoinProximity(predicate).use { stmt -> + stmt.executeQuery().use { result -> + while (result.next()) { + val retrieved = Retrieved(result.getObject(RETRIEVABLE_ID_COLUMN_NAME, UUID::class.java), result.getString(RETRIEVABLE_TYPE_COLUMN_NAME), true) + retrieved.addAttribute(DistanceAttribute(result.getFloat(DISTANCE_COLUMN_NAME))) + if (predicate.fetchVector) { + retrieved.addDescriptor(rowToDescriptor(result)) + } + yield(retrieved) + } + } + } + } + + else -> throw UnsupportedOperationException("Query of typ ${query::class} is not supported by VectorDescriptorReader.") } /** @@ -85,58 +109,79 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con retrievableId, result.getObject(VECTOR_ATTRIBUTE_NAME, 
PgBitVector::class.java)?.toBooleanVector() ?: throw IllegalArgumentException("The provided vector value is missing the required field '$VECTOR_ATTRIBUTE_NAME'.") ) - - else -> throw IllegalArgumentException("Unsupported descriptor type ${this.prototype::class}.") } } /** - * Executes a [ProximityPredicate] and returns a [Sequence] of [VectorDescriptor]s. + * Prepares a [ProximityPredicate] and returns a [PreparedStatement]. + * + * If the predicate carries a filter, the filter is resolved to a set of retrievable IDs first and the search is restricted to those IDs. The statement is not closed here; closing it is left to the caller. * * @param query The [ProximityPredicate] to execute. - * @return [Sequence] of [VectorDescriptor]s. + * @return [PreparedStatement] for [ProximityPredicate]. */ - private fun queryProximity(query: ProximityPredicate<*>): Sequence> = sequence { - val statement = - "SELECT $DESCRIPTOR_ID_COLUMN_NAME, $RETRIEVABLE_ID_COLUMN_NAME, $VECTOR_ATTRIBUTE_NAME, $VECTOR_ATTRIBUTE_NAME ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME FROM \"${tableName.lowercase()}\" ORDER BY $DISTANCE_COLUMN_NAME ${query.order} LIMIT ${query.k}" - this@VectorDescriptorReader.connection.jdbc.prepareStatement(statement).use { stmt -> + private fun prepareProximity(query: ProximityPredicate<*>): PreparedStatement { + val tableName = this.tableName.lowercase() + val filter = query.filter + if (filter == null) { + val sql = "SELECT $DESCRIPTOR_ID_COLUMN_NAME, $RETRIEVABLE_ID_COLUMN_NAME, $VECTOR_ATTRIBUTE_NAME, $VECTOR_ATTRIBUTE_NAME ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME " + + "FROM \"$tableName\" " + + "ORDER BY $DISTANCE_COLUMN_NAME ${query.order} " + + "LIMIT ${query.k}" + val stmt = this@VectorDescriptorReader.connection.jdbc.prepareStatement(sql) stmt.setValue(1, query.value) - stmt.executeQuery().use { result -> - while (result.next()) { - yield(rowToDescriptor(result)) - } - } + return stmt + } else { + val sql = "SELECT $DESCRIPTOR_ID_COLUMN_NAME, $RETRIEVABLE_ID_COLUMN_NAME, $VECTOR_ATTRIBUTE_NAME, $VECTOR_ATTRIBUTE_NAME ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME " + + "FROM \"$tableName\" " + + "WHERE $RETRIEVABLE_ID_COLUMN_NAME = ANY(?) 
" + + "ORDER BY $DISTANCE_COLUMN_NAME ${query.order} " + + "LIMIT ${query.k}" + + val retrievableIds = this.resolveBooleanPredicate(filter) + val stmt = this@VectorDescriptorReader.connection.jdbc.prepareStatement(sql) + stmt.setValue(1, query.value) + /* createArrayOf() expects the server-side element type name; retrievable IDs are read as UUIDs in this class, hence "uuid" ("OTHER" is a JDBC type constant, not a PostgreSQL type name). */ + stmt.setArray(2, this.connection.jdbc.createArrayOf("uuid", retrievableIds.toTypedArray())) + return stmt } } /** - * Executes a [ProximityPredicate] and returns a [Sequence] of [VectorDescriptor]s. + * Prepares a [ProximityPredicate] that is joined with the retrievable entity and returns a [PreparedStatement]. * * @param query The [ProximityPredicate] to execute. - * @return [Sequence] of [VectorDescriptor]s. + * @return [PreparedStatement]. */ - private fun queryAndJoinProximity(query: ProximityPredicate<*>): Sequence = sequence { - val cteTable = "\"${this@VectorDescriptorReader.tableName}_nns\"" - val statement = "WITH $cteTable AS (" + - "SELECT \"$DESCRIPTOR_ID_COLUMN_NAME\",\"$RETRIEVABLE_ID_COLUMN_NAME\",\"$VECTOR_ATTRIBUTE_NAME\",\"$VECTOR_ATTRIBUTE_NAME\" ${query.distance.toSql()} ? 
AS $DISTANCE_COLUMN_NAME " + - "FROM \"${tableName.lowercase()}\" ORDER BY \"$DISTANCE_COLUMN_NAME\" ${query.order} " + - "LIMIT ${query.k}" + - ") SELECT $cteTable.\"$DESCRIPTOR_ID_COLUMN_NAME\",$cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\",$cteTable.\"$VECTOR_ATTRIBUTE_NAME\",$cteTable.\"$DISTANCE_COLUMN_NAME\",\"$RETRIEVABLE_TYPE_COLUMN_NAME\" " + - "FROM $cteTable INNER JOIN \"$RETRIEVABLE_ENTITY_NAME\" ON (\"$RETRIEVABLE_ENTITY_NAME\".\"$RETRIEVABLE_ID_COLUMN_NAME\" = $cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\")" + - "ORDER BY $cteTable.\"$DISTANCE_COLUMN_NAME\" ${query.order}" + private fun queryAndJoinProximity(query: ProximityPredicate<*>): PreparedStatement { + val tableName = "\"${this@VectorDescriptorReader.tableName.lowercase()}\"" + /* Derive the CTE name from the raw table name: tableName above is already quoted, and quoting it a second time would yield an invalid identifier. */ + val cteTable = "\"${this@VectorDescriptorReader.tableName.lowercase()}_nns\"" + val filter = query.filter + if (filter == null) { + val sql = "WITH $cteTable AS (" + + "SELECT \"$DESCRIPTOR_ID_COLUMN_NAME\",\"$RETRIEVABLE_ID_COLUMN_NAME\",\"$VECTOR_ATTRIBUTE_NAME\",\"$VECTOR_ATTRIBUTE_NAME\" ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME " + + "FROM $tableName " + + "ORDER BY \"$DISTANCE_COLUMN_NAME\" ${query.order} " + + "LIMIT ${query.k}" + + ") SELECT $cteTable.\"$DESCRIPTOR_ID_COLUMN_NAME\",$cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\",$cteTable.\"$VECTOR_ATTRIBUTE_NAME\",$cteTable.\"$DISTANCE_COLUMN_NAME\",\"$RETRIEVABLE_TYPE_COLUMN_NAME\" " + + "FROM $cteTable INNER JOIN \"$RETRIEVABLE_ENTITY_NAME\" ON (\"$RETRIEVABLE_ENTITY_NAME\".\"$RETRIEVABLE_ID_COLUMN_NAME\" = $cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\")" + + "ORDER BY $cteTable.\"$DISTANCE_COLUMN_NAME\" ${query.order}" + val stmt = this@VectorDescriptorReader.connection.jdbc.prepareStatement(sql) stmt.setValue(1, query.value) + return stmt + } else { + val sql = "WITH $cteTable AS (" + + "SELECT \"$DESCRIPTOR_ID_COLUMN_NAME\",\"$RETRIEVABLE_ID_COLUMN_NAME\",\"$VECTOR_ATTRIBUTE_NAME\",\"$VECTOR_ATTRIBUTE_NAME\" ${query.distance.toSql()} ? 
AS $DISTANCE_COLUMN_NAME " + + "FROM $tableName " + + "WHERE $RETRIEVABLE_ID_COLUMN_NAME = ANY(?) " + + "ORDER BY \"$DISTANCE_COLUMN_NAME\" ${query.order} " + + "LIMIT ${query.k}" + + ") SELECT $cteTable.\"$DESCRIPTOR_ID_COLUMN_NAME\",$cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\",$cteTable.\"$VECTOR_ATTRIBUTE_NAME\",$cteTable.\"$DISTANCE_COLUMN_NAME\",\"$RETRIEVABLE_TYPE_COLUMN_NAME\" " + + "FROM $cteTable INNER JOIN \"$RETRIEVABLE_ENTITY_NAME\" ON (\"$RETRIEVABLE_ENTITY_NAME\".\"$RETRIEVABLE_ID_COLUMN_NAME\" = $cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\")" + + "ORDER BY $cteTable.\"$DISTANCE_COLUMN_NAME\" ${query.order}" - this@VectorDescriptorReader.connection.jdbc.prepareStatement(statement).use { stmt -> + val retrievableIds = this.resolveBooleanPredicate(filter) + val stmt = this@VectorDescriptorReader.connection.jdbc.prepareStatement(sql) stmt.setValue(1, query.value) - stmt.executeQuery().use { result -> - while (result.next()) { - val retrieved = Retrieved(result.getObject(RETRIEVABLE_ID_COLUMN_NAME, UUID::class.java), result.getString(RETRIEVABLE_TYPE_COLUMN_NAME), true) - retrieved.addAttribute(DistanceAttribute(result.getFloat(DISTANCE_COLUMN_NAME))) - if (query.fetchVector) { - retrieved.addDescriptor(rowToDescriptor(result)) - } - yield(retrieved) - } - } + /* createArrayOf() expects the server-side element type name; retrievable IDs are read as UUIDs in this class, hence "uuid" ("OTHER" is a JDBC type constant, not a PostgreSQL type name). */ + stmt.setArray(2, this.connection.jdbc.createArrayOf("uuid", retrievableIds.toTypedArray())) + return stmt } } } \ No newline at end of file