Skip to content

Commit

Permalink
Merge pull request #98 from vitrivr/feature/imageclassificationretriever
Browse files Browse the repository at this point in the history
Adds very simple retriever for `ImageClassifier`.
  • Loading branch information
ppanopticon authored Aug 23, 2024
2 parents 917f138 + 11901cb commit b04f878
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,19 @@ import org.vitrivr.engine.base.features.external.api.AbstractApi
import org.vitrivr.engine.base.features.external.common.ExternalFesAnalyser
import org.vitrivr.engine.core.context.IndexContext
import org.vitrivr.engine.core.context.QueryContext
import org.vitrivr.engine.core.features.AbstractRetriever
import org.vitrivr.engine.core.model.content.Content
import org.vitrivr.engine.core.model.content.element.ContentElement
import org.vitrivr.engine.core.model.content.element.ImageContent
import org.vitrivr.engine.core.model.content.element.TextContent
import org.vitrivr.engine.core.model.descriptor.struct.LabelDescriptor
import org.vitrivr.engine.core.model.descriptor.vector.FloatVectorDescriptor
import org.vitrivr.engine.core.model.metamodel.Analyser
import org.vitrivr.engine.core.model.metamodel.Analyser.Companion.merge
import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.query.Query
import org.vitrivr.engine.core.model.query.bool.SimpleBooleanQuery
import org.vitrivr.engine.core.model.query.proximity.ProximityQuery
import org.vitrivr.engine.core.model.retrievable.Retrievable
import org.vitrivr.engine.core.model.types.Value
import org.vitrivr.engine.core.operators.Operator
Expand All @@ -22,14 +30,14 @@ import java.util.*
* @author Fynn Faber
* @version 1.0.0
*/
class ImageClassification : ExternalFesAnalyser<ImageContent, LabelDescriptor>() {
class ImageClassification : ExternalFesAnalyser<ContentElement<*>, LabelDescriptor>() {
companion object{
const val CLASSES_PARAMETER_NAME = "classes"
const val THRESHOLD_PARAMETER_NAME = "threshold"
const val TOPK_PARAMETER_NAME = "top_k"
}

override val contentClasses = setOf(ImageContent::class)
override val contentClasses = setOf(ImageContent::class, TextContent::class)
override val descriptorClass = LabelDescriptor::class

/**
Expand Down Expand Up @@ -60,13 +68,55 @@ class ImageClassification : ExternalFesAnalyser<ImageContent, LabelDescriptor>()
* @param context The [IndexContext] to use with the [ImageClassification].
* @return [ImageClassification]
*/
override fun newExtractor(field: Schema.Field<ImageContent, LabelDescriptor>, input: Operator<Retrievable>, context: IndexContext) = ImageClassificationExtractor(input, field, this, merge(field, context))
override fun newExtractor(field: Schema.Field<ContentElement<*>, LabelDescriptor>, input: Operator<Retrievable>, context: IndexContext) = ImageClassificationExtractor(input, field, this, merge(field, context))

override fun newRetrieverForContent(field: Schema.Field<ImageContent, LabelDescriptor>, content: Collection<ImageContent>, context: QueryContext): Retriever<ImageContent, LabelDescriptor> {
TODO("Not yet implemented")
/**
* Generates and returns a new [Retriever] instance for this [ImageClassification].
*
* @param field The [Schema.Field] to create an [Retriever] for.
* @param query The [Query] to use with the [Retriever].
* @param context The [QueryContext] to use with the [Retriever].
*
* @return A new [Retriever] instance for this [Analyser]
*/
override fun newRetrieverForQuery(field: Schema.Field<ContentElement<*>, LabelDescriptor>, query: Query, context: QueryContext): Retriever<ContentElement<*>, LabelDescriptor> {
require(field.analyser == this) { "The field '${field.fieldName}' analyser does not correspond with this analyser. This is a programmer's error!" }
require(query is SimpleBooleanQuery<*>) { "The query is not a boolean query. This is a programmer's error!" }
return object : AbstractRetriever<ContentElement<*>, LabelDescriptor>(field, query, context){}
}

/**
* Generates and returns a new [Retriever] instance for this [ImageClassification].
*
* Invoking this method involves converting the provided [FloatVectorDescriptor] into a [ProximityQuery] that can be used to retrieve similar [ImageContent] elements.
*
* @param field The [Schema.Field] to create an [Retriever] for.
* @param descriptors An array of [FloatVectorDescriptor] elements to use with the [Retriever]
* @param context The [QueryContext] to use with the [Retriever]
*/
override fun newRetrieverForDescriptors(field: Schema.Field<ContentElement<*>, LabelDescriptor>, descriptors: Collection<LabelDescriptor>, context: QueryContext): Retriever<ContentElement<*>, LabelDescriptor> {
val descriptor = descriptors.firstOrNull()?.label ?: throw IllegalArgumentException("No label descriptor provided.")
val query = SimpleBooleanQuery(value = descriptor, attributeName = "label")
return newRetrieverForQuery(field, query, context)
}

override fun newRetrieverForQuery(field: Schema.Field<ImageContent, LabelDescriptor>, query: Query, context: QueryContext): Retriever<ImageContent, LabelDescriptor> {
TODO("Not yet implemented")
/**
* Generates and returns a new [Retriever] instance for this [ImageClassification].
*
* Invoking this method involves converting the provided [ImageContent] and the [QueryContext] into a [FloatVectorDescriptor]
* that can be used to retrieve similar [ImageContent] elements.
*
* @param field The [Schema.Field] to create an [Retriever] for.
* @param content An array of [Content] elements to use with the [Retriever]
* @param context The [QueryContext] to use with the [Retriever]
*/
override fun newRetrieverForContent(
field: Schema.Field<ContentElement<*>, LabelDescriptor>,
content: Collection<ContentElement<*>>,
context: QueryContext
): Retriever<ContentElement<*>, LabelDescriptor> {
val first = content.filterIsInstance<TextContent>().firstOrNull() ?: throw IllegalArgumentException("The content does not contain any text.")
val query = SimpleBooleanQuery(value = Value.String(first.content), attributeName = "label")
return newRetrieverForQuery(field, query, context)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.vitrivr.engine.base.features.external.common.FesExtractor
import org.vitrivr.engine.base.features.external.implementations.classification.ImageClassification.Companion.CLASSES_PARAMETER_NAME
import org.vitrivr.engine.base.features.external.implementations.classification.ImageClassification.Companion.THRESHOLD_PARAMETER_NAME
import org.vitrivr.engine.base.features.external.implementations.classification.ImageClassification.Companion.TOPK_PARAMETER_NAME
import org.vitrivr.engine.core.model.content.element.ContentElement
import org.vitrivr.engine.core.model.content.element.ImageContent
import org.vitrivr.engine.core.model.descriptor.struct.LabelDescriptor
import org.vitrivr.engine.core.model.metamodel.Schema
Expand All @@ -21,10 +22,10 @@ import java.util.*
*/
class ImageClassificationExtractor(
input: Operator<Retrievable>,
field: Schema.Field<ImageContent, LabelDescriptor>?,
analyser: ExternalFesAnalyser<ImageContent, LabelDescriptor>,
field: Schema.Field<ContentElement<*>, LabelDescriptor>?,
analyser: ExternalFesAnalyser<ContentElement<*>, LabelDescriptor>,
parameters: Map<String, String>
) : FesExtractor<ImageContent, LabelDescriptor>(input, field, analyser, parameters) {
) : FesExtractor<ContentElement<*>, LabelDescriptor>(input, field, analyser, parameters) {


/** The [ZeroShotClassificationApi] used to perform extraction with. */
Expand All @@ -46,10 +47,11 @@ class ImageClassificationExtractor(
val topK = this.parameters[TOPK_PARAMETER_NAME]?.toInt() ?: 1
val threshold = this.parameters[THRESHOLD_PARAMETER_NAME]?.toFloat() ?: 0.0f

val flatResults = this.api.analyseBatched(
retrievables.flatMap {
this.filterContent(it).map { it to classes }
}).mapIndexed { idx, result ->
val content = retrievables.mapIndexed { idx, retrievable ->
this.filterContent(retrievable).filterIsInstance<ImageContent>().map { idx to (it to classes) }
}.flatten()

return this.api.analyseBatched(content.map{it.second}).zip(content.map{it.first}).map { (result, idx) ->
result.mapIndexed { idy, confidence ->
LabelDescriptor(
UUID.randomUUID(),
Expand All @@ -62,6 +64,5 @@ class ImageClassificationExtractor(
)
}.filter { it.confidence.value >= threshold }.sortedByDescending { it.confidence.value }.take(topK)
}
return flatResults
}
}

0 comments on commit b04f878

Please sign in to comment.