Skip to content

Commit

Permalink
Sketches a simple case of filter-pushdown based on @net-cscience-raphael
Browse files Browse the repository at this point in the history
's idea.

Signed-off-by: Ralph Gasser <[email protected]>
  • Loading branch information
ppanopticon committed Nov 28, 2024
1 parent 3344a29 commit c61a7d8
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ package org.vitrivr.engine.core.model.query.bool
* @version 1.0
*/
sealed interface Logical : BooleanPredicate {
data class Not(val predicate: BooleanPredicate) : Logical

data class And(val predicates: List<BooleanPredicate>) : Logical

data class Or(val predicates: List<BooleanPredicate>) : Logical
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.vitrivr.engine.core.model.query.fulltext

import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.query.Predicate
import org.vitrivr.engine.core.model.query.bool.BooleanPredicate
import org.vitrivr.engine.core.model.types.Value


Expand All @@ -23,5 +24,8 @@ data class SimpleFulltextPredicate(
*
* Typically, this is pre-determined by the analyser. However, in some cases, this must be specified (e.g., when querying struct fields).
*/
val attributeName: String? = null
val attributeName: String? = null,

/** Optional filter query for this [SimpleFulltextPredicate]. */
val filter: BooleanPredicate? = null
) : Predicate
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.query.Predicate
import org.vitrivr.engine.core.model.query.basics.Distance
import org.vitrivr.engine.core.model.query.basics.SortOrder
import org.vitrivr.engine.core.model.query.bool.BooleanPredicate
import org.vitrivr.engine.core.model.query.fulltext.SimpleFulltextPredicate
import org.vitrivr.engine.core.model.types.Value

/**
Expand All @@ -17,7 +19,6 @@ import org.vitrivr.engine.core.model.types.Value
* @author Ralph Gasser
* @version 1.1.0
*/

data class ProximityPredicate<T : Value.Vector<*>>(
/** The [Schema.Field] that this [Predicate] is applied to. */
val field: Schema.Field<*, *>,
Expand All @@ -42,5 +43,8 @@ data class ProximityPredicate<T : Value.Vector<*>>(
*
* Typically, this is pre-determined by the analyser. However, in some cases, this must be specified (e.g., when querying struct fields).
*/
val attributeName: String? = null
val attributeName: String? = null,

/** Optional filter query for this [SimpleFulltextPredicate]. */
val filter: BooleanPredicate? = null
) : Predicate
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ internal fun BooleanPredicate.toWhere(): String = when (this) {
is Comparison<*> -> this.toTerm()
is Logical.And -> this.predicates.joinToString(" AND ", "(", ")") { it.toWhere() }
is Logical.Or -> this.predicates.joinToString(" OR ", "(", ")") { it.toWhere() }
is Logical.Not -> "NOT (${this.predicate.toWhere()})"
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@ import org.vitrivr.engine.core.model.descriptor.DescriptorId
import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.query.Predicate
import org.vitrivr.engine.core.model.query.Query
import org.vitrivr.engine.core.model.query.bool.BooleanPredicate
import org.vitrivr.engine.core.model.query.bool.Comparison
import org.vitrivr.engine.core.model.query.bool.Logical
import org.vitrivr.engine.core.model.retrievable.RetrievableId
import org.vitrivr.engine.core.model.retrievable.Retrieved
import org.vitrivr.engine.database.pgvector.*
import org.vitrivr.engine.database.pgvector.descriptor.scalar.ScalarDescriptorReader
import org.vitrivr.engine.database.pgvector.descriptor.struct.StructDescriptorReader
import java.sql.ResultSet
import java.util.*

Expand Down Expand Up @@ -200,6 +205,44 @@ abstract class AbstractDescriptorReader<D : Descriptor<*>>(final override val fi
}
}

/**
* Resolves a complex [BooleanPredicate] into a set of [RetrievableId]s that match it.
*
* @param predicate [BooleanPredicate] to resolve.
* @return Set of [RetrievableId]s that match the [BooleanPredicate].
*/
internal fun resolveBooleanPredicate(predicate: BooleanPredicate): Set<RetrievableId> = when (predicate) {
is Comparison<*> -> {
val field = predicate.field
val reader = field.getReader()
when (reader) {
is ScalarDescriptorReader -> reader.queryComparison(predicate).map { it.id }.toSet()
is StructDescriptorReader -> reader.queryComparison(predicate).map { it.id }.toSet()
else -> throw IllegalArgumentException("Cannot resolve predicate $predicate.")
}
}

is Logical.And -> {
val intersection = mutableSetOf<RetrievableId>()
for ((index, child) in predicate.predicates.withIndex()) {
if (index == 0) {
intersection.addAll(resolveBooleanPredicate(child))
} else {
intersection.intersect(resolveBooleanPredicate(child))
}
}
intersection
}

is Logical.Or -> {
val union = mutableSetOf<RetrievableId>()
for (child in predicate.predicates) {
union.addAll(resolveBooleanPredicate(child))
}
union
}
}

/**
* Converts a [ResultSet] to a [Descriptor] of type [D].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ScalarDescriptorReader(field: Schema.Field<*, ScalarDescriptor<*, *>>, con
*/
override fun query(query: Query): Sequence<ScalarDescriptor<*, *>> = when (val predicate = query.predicate) {
is SimpleFulltextPredicate -> queryFulltext(predicate)
is Comparison<*> -> queryBoolean(predicate)
is Comparison<*> -> queryComparison(predicate)
else -> throw IllegalArgumentException("Query of type ${query::class} is not supported by ScalarDescriptorReader.")
}

Expand Down Expand Up @@ -84,7 +84,7 @@ class ScalarDescriptorReader(field: Schema.Field<*, ScalarDescriptor<*, *>>, con
* @param query The [Comparison] to execute.
* @return [Sequence] of [ScalarDescriptor]s.
*/
private fun queryBoolean(query: Comparison<*>): Sequence<ScalarDescriptor<*, *>> {
internal fun queryComparison(query: Comparison<*>): Sequence<ScalarDescriptor<*, *>> {
val statement = "SELECT * FROM \"$tableName\" WHERE ${query.toWhere()}"
return sequence {
this@ScalarDescriptorReader.connection.jdbc.prepareStatement(statement).use { stmt ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class StructDescriptorReader(field: Schema.Field<*, StructDescriptor<*>>, connec
*/
override fun query(query: Query): Sequence<StructDescriptor<*>> = when (val predicate = query.predicate) {
is SimpleFulltextPredicate -> queryFulltext(predicate)
is Comparison<*> -> queryBoolean(predicate)
is Comparison<*> -> queryComparison(predicate)
else -> throw IllegalArgumentException("Query of typ ${query::class} is not supported by StructDescriptorReader.")
}

Expand Down Expand Up @@ -106,7 +106,7 @@ class StructDescriptorReader(field: Schema.Field<*, StructDescriptor<*>>, connec
* @param query The [Comparison] to execute.
* @return [Sequence] of [StructDescriptor]s.
*/
private fun queryBoolean(query: Comparison<*>): Sequence<StructDescriptor<*>> {
internal fun queryComparison(query: Comparison<*>): Sequence<StructDescriptor<*>> {
require(query.attributeName != null) { "Query attribute must not be null for a fulltext query on a struct descriptor." }
val statement = "SELECT * FROM \"${tableName.lowercase()}\" WHERE ${query.toWhere()}"
return sequence {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,34 @@ import org.vitrivr.engine.database.pgvector.*
import org.vitrivr.engine.database.pgvector.descriptor.AbstractDescriptorReader
import org.vitrivr.engine.database.pgvector.descriptor.model.PgBitVector
import org.vitrivr.engine.database.pgvector.descriptor.model.PgVector
import java.sql.PreparedStatement
import java.sql.ResultSet
import java.util.*

/**
* An abstract implementation of a [DescriptorReader] for Cottontail DB.
*
* @author Ralph Gasser
* @version 1.0.0
* @version 1.1.0
*/
class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, connection: PgVectorConnection) : AbstractDescriptorReader<VectorDescriptor<*, *>>(field, connection) {
/**
* Executes the provided [Query] and returns a [Sequence] of [Retrieved]s that match it.
*
* @param query The [Query] to execute.
*/
override fun query(query: Query): Sequence<VectorDescriptor<*, *>> = when (val predicate = query.predicate) {
is ProximityPredicate<*> -> queryProximity(predicate)
else -> throw UnsupportedOperationException("Query of typ ${query::class} is not supported by VectorDescriptorReader.")
override fun query(query: Query): Sequence<VectorDescriptor<*, *>> = sequence {
when (val predicate = query.predicate) {
is ProximityPredicate<*> -> prepareProximity(predicate).use { stmt ->
stmt.executeQuery().use { result ->
while (result.next()) {
yield(rowToDescriptor(result))
}
}
}

else -> throw UnsupportedOperationException("Query of typ ${query::class} is not supported by VectorDescriptorReader.")
}
}

/**
Expand All @@ -41,9 +51,23 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con
* @param query The [Query] that should be executed.
* @return [Sequence] of [Retrieved].
*/
override fun queryAndJoin(query: Query): Sequence<Retrieved> = when (val predicate = query.predicate) {
is ProximityPredicate<*> -> queryAndJoinProximity(predicate)
else -> super.queryAndJoin(query)
override fun queryAndJoin(query: Query): Sequence<Retrieved> = sequence {
when (val predicate = query.predicate) {
is ProximityPredicate<*> -> queryAndJoinProximity(predicate).use { stmt ->
stmt.executeQuery().use { result ->
while (result.next()) {
val retrieved = Retrieved(result.getObject(RETRIEVABLE_ID_COLUMN_NAME, UUID::class.java), result.getString(RETRIEVABLE_TYPE_COLUMN_NAME), true)
retrieved.addAttribute(DistanceAttribute(result.getFloat(DISTANCE_COLUMN_NAME)))
if (predicate.fetchVector) {
retrieved.addDescriptor(rowToDescriptor(result))
}
yield(retrieved)
}
}
}

else -> throw UnsupportedOperationException("Query of typ ${query::class} is not supported by VectorDescriptorReader.")
}
}

/**
Expand Down Expand Up @@ -85,58 +109,79 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con
retrievableId,
result.getObject(VECTOR_ATTRIBUTE_NAME, PgBitVector::class.java)?.toBooleanVector() ?: throw IllegalArgumentException("The provided vector value is missing the required field '$VECTOR_ATTRIBUTE_NAME'.")
)

else -> throw IllegalArgumentException("Unsupported descriptor type ${this.prototype::class}.")
}
}

/**
* Executes a [ProximityPredicate] and returns a [Sequence] of [VectorDescriptor]s.
* Prepares a [ProximityPredicate] and returns a [PreparedStatement].
*
* @param query The [ProximityPredicate] to execute.
* @return [Sequence] of [VectorDescriptor]s.
* @return [PreparedStatement] for [ProximityPredicate].
*/
private fun queryProximity(query: ProximityPredicate<*>): Sequence<VectorDescriptor<*, *>> = sequence {
val statement =
"SELECT $DESCRIPTOR_ID_COLUMN_NAME, $RETRIEVABLE_ID_COLUMN_NAME, $VECTOR_ATTRIBUTE_NAME, $VECTOR_ATTRIBUTE_NAME ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME FROM \"${tableName.lowercase()}\" ORDER BY $DISTANCE_COLUMN_NAME ${query.order} LIMIT ${query.k}"
this@VectorDescriptorReader.connection.jdbc.prepareStatement(statement).use { stmt ->
private fun prepareProximity(query: ProximityPredicate<*>): PreparedStatement {
val tableName = this.tableName.lowercase()
val filter = query.filter
if (filter == null) {
val sql = "SELECT $DESCRIPTOR_ID_COLUMN_NAME, $RETRIEVABLE_ID_COLUMN_NAME, $VECTOR_ATTRIBUTE_NAME, $VECTOR_ATTRIBUTE_NAME ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME " +
"FROM \"$tableName\" " +
"ORDER BY $DISTANCE_COLUMN_NAME ${query.order} " +
"LIMIT ${query.k}"
val stmt = this@VectorDescriptorReader.connection.jdbc.prepareStatement(sql)
stmt.setValue(1, query.value)
stmt.executeQuery().use { result ->
while (result.next()) {
yield(rowToDescriptor(result))
}
}
return stmt
} else {
val sql = "SELECT $DESCRIPTOR_ID_COLUMN_NAME, $RETRIEVABLE_ID_COLUMN_NAME, $VECTOR_ATTRIBUTE_NAME, $VECTOR_ATTRIBUTE_NAME ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME " +
"FROM \"$tableName\" " +
"WHERE $RETRIEVABLE_ID_COLUMN_NAME = ANY(?) " +
"ORDER BY $DISTANCE_COLUMN_NAME ${query.order} " +
"LIMIT ${query.k}"

val retrievableIds = this.resolveBooleanPredicate(filter)
val stmt = this@VectorDescriptorReader.connection.jdbc.prepareStatement(sql)
stmt.setValue(1, query.value)
stmt.setArray(2, this.connection.jdbc.createArrayOf("OTHER", retrievableIds.toTypedArray()))
return stmt
}
}

/**
* Executes a [ProximityPredicate] and returns a [Sequence] of [VectorDescriptor]s.
* Prepares a [ProximityPredicate] and returns a [Sequence] of [PreparedStatement]s.
*
* @param query The [ProximityPredicate] to execute.
* @return [Sequence] of [VectorDescriptor]s.
* @return [PreparedStatement].
*/
private fun queryAndJoinProximity(query: ProximityPredicate<*>): Sequence<Retrieved> = sequence {
val cteTable = "\"${this@VectorDescriptorReader.tableName}_nns\""
val statement = "WITH $cteTable AS (" +
"SELECT \"$DESCRIPTOR_ID_COLUMN_NAME\",\"$RETRIEVABLE_ID_COLUMN_NAME\",\"$VECTOR_ATTRIBUTE_NAME\",\"$VECTOR_ATTRIBUTE_NAME\" ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME " +
"FROM \"${tableName.lowercase()}\" ORDER BY \"$DISTANCE_COLUMN_NAME\" ${query.order} " +
"LIMIT ${query.k}" +
") SELECT $cteTable.\"$DESCRIPTOR_ID_COLUMN_NAME\",$cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\",$cteTable.\"$VECTOR_ATTRIBUTE_NAME\",$cteTable.\"$DISTANCE_COLUMN_NAME\",\"$RETRIEVABLE_TYPE_COLUMN_NAME\" " +
"FROM $cteTable INNER JOIN \"$RETRIEVABLE_ENTITY_NAME\" ON (\"$RETRIEVABLE_ENTITY_NAME\".\"$RETRIEVABLE_ID_COLUMN_NAME\" = $cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\")" +
"ORDER BY $cteTable.\"$DISTANCE_COLUMN_NAME\" ${query.order}"
private fun queryAndJoinProximity(query: ProximityPredicate<*>): PreparedStatement {
val tableName = "\"${this@VectorDescriptorReader.tableName.lowercase()}\""
val cteTable = "\"${tableName}_nns\""
val filter = query.filter
if (filter == null) {
val sql = "WITH $cteTable AS (" +
"SELECT \"$DESCRIPTOR_ID_COLUMN_NAME\",\"$RETRIEVABLE_ID_COLUMN_NAME\",\"$VECTOR_ATTRIBUTE_NAME\",\"$VECTOR_ATTRIBUTE_NAME\" ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME " +
"FROM $tableName " +
"ORDER BY \"$DISTANCE_COLUMN_NAME\" ${query.order} " +
"LIMIT ${query.k}" +
") SELECT $cteTable.\"$DESCRIPTOR_ID_COLUMN_NAME\",$cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\",$cteTable.\"$VECTOR_ATTRIBUTE_NAME\",$cteTable.\"$DISTANCE_COLUMN_NAME\",\"$RETRIEVABLE_TYPE_COLUMN_NAME\" " +
"FROM $cteTable INNER JOIN \"$RETRIEVABLE_ENTITY_NAME\" ON (\"$RETRIEVABLE_ENTITY_NAME\".\"$RETRIEVABLE_ID_COLUMN_NAME\" = $cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\")" +
"ORDER BY $cteTable.\"$DISTANCE_COLUMN_NAME\" ${query.order}"
val stmt = this@VectorDescriptorReader.connection.jdbc.prepareStatement(sql)
stmt.setValue(1, query.value)
return stmt
} else {
val sql = "WITH $cteTable AS (" +
"SELECT \"$DESCRIPTOR_ID_COLUMN_NAME\",\"$RETRIEVABLE_ID_COLUMN_NAME\",\"$VECTOR_ATTRIBUTE_NAME\",\"$VECTOR_ATTRIBUTE_NAME\" ${query.distance.toSql()} ? AS $DISTANCE_COLUMN_NAME " +
"FROM $tableName " +
"WHERE $RETRIEVABLE_ID_COLUMN_NAME = ANY(?) " +
"ORDER BY \"$DISTANCE_COLUMN_NAME\" ${query.order} " +
"LIMIT ${query.k}" +
") SELECT $cteTable.\"$DESCRIPTOR_ID_COLUMN_NAME\",$cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\",$cteTable.\"$VECTOR_ATTRIBUTE_NAME\",$cteTable.\"$DISTANCE_COLUMN_NAME\",\"$RETRIEVABLE_TYPE_COLUMN_NAME\" " +
"FROM $cteTable INNER JOIN \"$RETRIEVABLE_ENTITY_NAME\" ON (\"$RETRIEVABLE_ENTITY_NAME\".\"$RETRIEVABLE_ID_COLUMN_NAME\" = $cteTable.\"$RETRIEVABLE_ID_COLUMN_NAME\")" +
"ORDER BY $cteTable.\"$DISTANCE_COLUMN_NAME\" ${query.order}"

this@VectorDescriptorReader.connection.jdbc.prepareStatement(statement).use { stmt ->
val retrievableIds = this.resolveBooleanPredicate(filter)
val stmt = this@VectorDescriptorReader.connection.jdbc.prepareStatement(sql)
stmt.setValue(1, query.value)
stmt.executeQuery().use { result ->
while (result.next()) {
val retrieved = Retrieved(result.getObject(RETRIEVABLE_ID_COLUMN_NAME, UUID::class.java), result.getString(RETRIEVABLE_TYPE_COLUMN_NAME), true)
retrieved.addAttribute(DistanceAttribute(result.getFloat(DISTANCE_COLUMN_NAME)))
if (query.fetchVector) {
retrieved.addDescriptor(rowToDescriptor(result))
}
yield(retrieved)
}
}
stmt.setArray(2, this.connection.jdbc.createArrayOf("OTHER", retrievableIds.toTypedArray()))
return stmt
}
}
}

0 comments on commit c61a7d8

Please sign in to comment.