diff --git a/vitrivr-engine-module-fes/src/main/kotlin/org/vitrivr/engine/base/features/external/implementations/classification/ImageClassification.kt b/vitrivr-engine-module-fes/src/main/kotlin/org/vitrivr/engine/base/features/external/implementations/classification/ImageClassification.kt index 5b8f5d036..70f0fac3e 100644 --- a/vitrivr-engine-module-fes/src/main/kotlin/org/vitrivr/engine/base/features/external/implementations/classification/ImageClassification.kt +++ b/vitrivr-engine-module-fes/src/main/kotlin/org/vitrivr/engine/base/features/external/implementations/classification/ImageClassification.kt @@ -4,11 +4,19 @@ import org.vitrivr.engine.base.features.external.api.AbstractApi import org.vitrivr.engine.base.features.external.common.ExternalFesAnalyser import org.vitrivr.engine.core.context.IndexContext import org.vitrivr.engine.core.context.QueryContext +import org.vitrivr.engine.core.features.AbstractRetriever +import org.vitrivr.engine.core.model.content.Content +import org.vitrivr.engine.core.model.content.element.ContentElement import org.vitrivr.engine.core.model.content.element.ImageContent +import org.vitrivr.engine.core.model.content.element.TextContent import org.vitrivr.engine.core.model.descriptor.struct.LabelDescriptor +import org.vitrivr.engine.core.model.descriptor.vector.FloatVectorDescriptor +import org.vitrivr.engine.core.model.metamodel.Analyser import org.vitrivr.engine.core.model.metamodel.Analyser.Companion.merge import org.vitrivr.engine.core.model.metamodel.Schema import org.vitrivr.engine.core.model.query.Query +import org.vitrivr.engine.core.model.query.bool.SimpleBooleanQuery +import org.vitrivr.engine.core.model.query.proximity.ProximityQuery import org.vitrivr.engine.core.model.retrievable.Retrievable import org.vitrivr.engine.core.model.types.Value import org.vitrivr.engine.core.operators.Operator @@ -22,14 +30,14 @@ import java.util.* * @author Fynn Faber * @version 1.0.0 */ -class ImageClassification : ExternalFesAnalyser() { +class ImageClassification : ExternalFesAnalyser, LabelDescriptor>() { companion object{ const val CLASSES_PARAMETER_NAME = "classes" const val THRESHOLD_PARAMETER_NAME = "threshold" const val TOPK_PARAMETER_NAME = "top_k" } - override val contentClasses = setOf(ImageContent::class) + override val contentClasses = setOf(ImageContent::class, TextContent::class) override val descriptorClass = LabelDescriptor::class /** @@ -60,13 +68,55 @@ class ImageClassification : ExternalFesAnalyser() * @param context The [IndexContext] to use with the [ImageClassification]. * @return [ImageClassification] */ - override fun newExtractor(field: Schema.Field, input: Operator, context: IndexContext) = ImageClassificationExtractor(input, field, this, merge(field, context)) + override fun newExtractor(field: Schema.Field, LabelDescriptor>, input: Operator, context: IndexContext) = ImageClassificationExtractor(input, field, this, merge(field, context)) - override fun newRetrieverForContent(field: Schema.Field, content: Collection, context: QueryContext): Retriever { - TODO("Not yet implemented") + /** + * Generates and returns a new [Retriever] instance for this [ImageClassification]. + * + * @param field The [Schema.Field] to create an [Retriever] for. + * @param query The [Query] to use with the [Retriever]. + * @param context The [QueryContext] to use with the [Retriever]. + * + * @return A new [Retriever] instance for this [Analyser] + */ + override fun newRetrieverForQuery(field: Schema.Field, LabelDescriptor>, query: Query, context: QueryContext): Retriever, LabelDescriptor> { + require(field.analyser == this) { "The field '${field.fieldName}' analyser does not correspond with this analyser. This is a programmer's error!" } + require(query is SimpleBooleanQuery<*>) { "The query is not a boolean query. This is a programmer's error!" } + return object : AbstractRetriever, LabelDescriptor>(field, query, context){} + } + + /** + * Generates and returns a new [Retriever] instance for this [ImageClassification]. + * + * Invoking this method involves converting the provided [FloatVectorDescriptor] into a [ProximityQuery] that can be used to retrieve similar [ImageContent] elements. + * + * @param field The [Schema.Field] to create an [Retriever] for. + * @param descriptors An array of [FloatVectorDescriptor] elements to use with the [Retriever] + * @param context The [QueryContext] to use with the [Retriever] + */ + override fun newRetrieverForDescriptors(field: Schema.Field, LabelDescriptor>, descriptors: Collection, context: QueryContext): Retriever, LabelDescriptor> { + val descriptor = descriptors.firstOrNull()?.label ?: throw IllegalArgumentException("No label descriptor provided.") + val query = SimpleBooleanQuery(value = descriptor, attributeName = "label") + return newRetrieverForQuery(field, query, context) } - override fun newRetrieverForQuery(field: Schema.Field, query: Query, context: QueryContext): Retriever { - TODO("Not yet implemented") + /** + * Generates and returns a new [Retriever] instance for this [ImageClassification]. + * + * Invoking this method involves converting the provided [ImageContent] and the [QueryContext] into a [FloatVectorDescriptor] + * that can be used to retrieve similar [ImageContent] elements. + * + * @param field The [Schema.Field] to create an [Retriever] for. + * @param content An array of [Content] elements to use with the [Retriever] + * @param context The [QueryContext] to use with the [Retriever] + */ + override fun newRetrieverForContent( + field: Schema.Field, LabelDescriptor>, + content: Collection>, + context: QueryContext + ): Retriever, LabelDescriptor> { + val first = content.filterIsInstance().firstOrNull() ?: throw IllegalArgumentException("The content does not contain any text.") + val query = SimpleBooleanQuery(value = Value.String(first.content), attributeName = "label") + return newRetrieverForQuery(field, query, context) } } \ No newline at end of file diff --git a/vitrivr-engine-module-fes/src/main/kotlin/org/vitrivr/engine/base/features/external/implementations/classification/ImageClassificationExtractor.kt b/vitrivr-engine-module-fes/src/main/kotlin/org/vitrivr/engine/base/features/external/implementations/classification/ImageClassificationExtractor.kt index 9684ef944..44155f7f0 100644 --- a/vitrivr-engine-module-fes/src/main/kotlin/org/vitrivr/engine/base/features/external/implementations/classification/ImageClassificationExtractor.kt +++ b/vitrivr-engine-module-fes/src/main/kotlin/org/vitrivr/engine/base/features/external/implementations/classification/ImageClassificationExtractor.kt @@ -6,6 +6,7 @@ import org.vitrivr.engine.base.features.external.common.FesExtractor import org.vitrivr.engine.base.features.external.implementations.classification.ImageClassification.Companion.CLASSES_PARAMETER_NAME import org.vitrivr.engine.base.features.external.implementations.classification.ImageClassification.Companion.THRESHOLD_PARAMETER_NAME import org.vitrivr.engine.base.features.external.implementations.classification.ImageClassification.Companion.TOPK_PARAMETER_NAME +import org.vitrivr.engine.core.model.content.element.ContentElement import org.vitrivr.engine.core.model.content.element.ImageContent import org.vitrivr.engine.core.model.descriptor.struct.LabelDescriptor import org.vitrivr.engine.core.model.metamodel.Schema @@ -21,10 +22,10 @@ import java.util.* */ class ImageClassificationExtractor( input: Operator, - field: Schema.Field?, - analyser: ExternalFesAnalyser, + field: Schema.Field, LabelDescriptor>?, + analyser: ExternalFesAnalyser, LabelDescriptor>, parameters: Map -) : FesExtractor(input, field, analyser, parameters) { +) : FesExtractor, LabelDescriptor>(input, field, analyser, parameters) { /** The [ZeroShotClassificationApi] used to perform extraction with. */ @@ -46,10 +47,11 @@ class ImageClassificationExtractor( val topK = this.parameters[TOPK_PARAMETER_NAME]?.toInt() ?: 1 val threshold = this.parameters[THRESHOLD_PARAMETER_NAME]?.toFloat() ?: 0.0f - val flatResults = this.api.analyseBatched( - retrievables.flatMap { - this.filterContent(it).map { it to classes } - }).mapIndexed { idx, result -> + val content = retrievables.mapIndexed { idx, retrievable -> + this.filterContent(retrievable).filterIsInstance().map { idx to (it to classes) } + }.flatten() + + return this.api.analyseBatched(content.map{it.second}).zip(content.map{it.first}).map { (result, idx) -> result.mapIndexed { idy, confidence -> LabelDescriptor( UUID.randomUUID(), @@ -62,6 +64,5 @@ class ImageClassificationExtractor( ) }.filter { it.confidence.value >= threshold }.sortedByDescending { it.confidence.value }.take(topK) } - return flatResults } } \ No newline at end of file