diff --git a/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/descriptor/struct/LabelDescriptor.kt b/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/descriptor/struct/LabelDescriptor.kt index 0e132554..c3eff9a9 100644 --- a/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/descriptor/struct/LabelDescriptor.kt +++ b/vitrivr-engine-core/src/main/kotlin/org/vitrivr/engine/core/model/descriptor/struct/LabelDescriptor.kt @@ -21,9 +21,13 @@ class LabelDescriptor( override val field: Schema.Field<*, LabelDescriptor>? = null ) : StructDescriptor(id, retrievableId, SCHEMA, values, field) { companion object { + + const val LABEL_FIELD_NAME = "label" + const val CONFIDENCE_FIELD_NAME = "label" + private val SCHEMA = listOf( - Attribute("label", Type.String), - Attribute("confidence", Type.Float), + Attribute(LABEL_FIELD_NAME, Type.String), + Attribute(CONFIDENCE_FIELD_NAME, Type.Float), ) } @@ -33,7 +37,7 @@ class LabelDescriptor( label: String, confidence: Float = 1f, field: Schema.Field<*, LabelDescriptor>? = null - ) : this(id, retrievableId, mapOf("label" to Value.String(label), "confidence" to Value.Float(confidence)), field) + ) : this(id, retrievableId, mapOf(LABEL_FIELD_NAME to Value.String(label), CONFIDENCE_FIELD_NAME to Value.Float(confidence)), field) /** The stored label. */ val label: Value.String by this.values diff --git a/vitrivr-engine-module-torchserve/src/main/kotlin/org/vitrivr/engine/features/external/torchserve/TSImageLabel.kt b/vitrivr-engine-module-torchserve/src/main/kotlin/org/vitrivr/engine/features/external/torchserve/TSImageLabel.kt index 8bfb8cb8..b513af1e 100644 --- a/vitrivr-engine-module-torchserve/src/main/kotlin/org/vitrivr/engine/features/external/torchserve/TSImageLabel.kt +++ b/vitrivr-engine-module-torchserve/src/main/kotlin/org/vitrivr/engine/features/external/torchserve/TSImageLabel.kt @@ -4,12 +4,15 @@ import com.google.protobuf.ByteString import kotlinx.serialization.json.Json import org.vitrivr.engine.core.context.IndexContext import org.vitrivr.engine.core.context.QueryContext +import org.vitrivr.engine.core.features.bool.StructBooleanRetriever import org.vitrivr.engine.core.model.content.Content import org.vitrivr.engine.core.model.content.element.ImageContent import org.vitrivr.engine.core.model.descriptor.struct.LabelDescriptor import org.vitrivr.engine.core.model.metamodel.Analyser import org.vitrivr.engine.core.model.metamodel.Schema import org.vitrivr.engine.core.model.query.Query +import org.vitrivr.engine.core.model.query.basics.ComparisonOperator +import org.vitrivr.engine.core.model.query.bool.SimpleBooleanQuery import org.vitrivr.engine.core.model.retrievable.Retrievable import org.vitrivr.engine.core.operators.Operator import org.vitrivr.engine.core.operators.retrieve.Retriever @@ -44,20 +47,31 @@ class TSImageLabel : TorchServe() { */ override fun prototype(field: Schema.Field<*, *>) = LabelDescriptor(UUID.randomUUID(), UUID.randomUUID(), "", 0.0f, null) - /** + * Generates and returns a new [StructBooleanRetriever] instance for this [TSImageLabel]. + * + * @param field The [Schema.Field] to create an [Retriever] for. + * @param query The [Query] to use with the [Retriever]. + * @param context The [QueryContext] to use with the [Retriever]. * + * @return A new [StructBooleanRetriever] instance for this [TSImageLabel] */ - override fun newRetrieverForQuery(field: Schema.Field, query: Query, context: QueryContext): Retriever { - TODO() + override fun newRetrieverForQuery(field: Schema.Field, query: Query, context: QueryContext): StructBooleanRetriever { + require(query is SimpleBooleanQuery<*>) { "TSImageLabel only supports boolean queries." } + return StructBooleanRetriever(field, query, context) } - /** + * Generates and returns a new [StructBooleanRetriever] instance for this [TSImageLabel]. * + * @param field The [Schema.Field] to create an [Retriever] for. + * @param descriptors An array of [LabelDescriptor] elements to use with the [Retriever] + * @param context The [QueryContext] to use with the [Retriever] */ - override fun newRetrieverForDescriptors(field: Schema.Field, descriptors: Collection, context: QueryContext): Retriever { - TODO() + override fun newRetrieverForDescriptors(field: Schema.Field, descriptors: Collection, context: QueryContext): StructBooleanRetriever { + val values = descriptors.map { it.label } + val query = SimpleBooleanQuery(values.first(), ComparisonOperator.EQ, LabelDescriptor.LABEL_FIELD_NAME) + return this.newRetrieverForQuery(field, query, context) } /**