Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add very simple retriever for image classifier. #98

Merged
merged 5 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,19 @@ import org.vitrivr.engine.base.features.external.api.AbstractApi
import org.vitrivr.engine.base.features.external.common.ExternalFesAnalyser
import org.vitrivr.engine.core.context.IndexContext
import org.vitrivr.engine.core.context.QueryContext
import org.vitrivr.engine.core.features.AbstractRetriever
import org.vitrivr.engine.core.model.content.Content
import org.vitrivr.engine.core.model.content.element.ContentElement
import org.vitrivr.engine.core.model.content.element.ImageContent
import org.vitrivr.engine.core.model.content.element.TextContent
import org.vitrivr.engine.core.model.descriptor.struct.LabelDescriptor
import org.vitrivr.engine.core.model.descriptor.vector.FloatVectorDescriptor
import org.vitrivr.engine.core.model.metamodel.Analyser
import org.vitrivr.engine.core.model.metamodel.Analyser.Companion.merge
import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.query.Query
import org.vitrivr.engine.core.model.query.bool.SimpleBooleanQuery
import org.vitrivr.engine.core.model.query.proximity.ProximityQuery
import org.vitrivr.engine.core.model.retrievable.Retrievable
import org.vitrivr.engine.core.model.types.Value
import org.vitrivr.engine.core.operators.Operator
Expand All @@ -22,14 +30,14 @@ import java.util.*
* @author Fynn Faber
* @version 1.0.0
*/
class ImageClassification : ExternalFesAnalyser<ImageContent, LabelDescriptor>() {
class ImageClassification : ExternalFesAnalyser<ContentElement<*>, LabelDescriptor>() {
/** Names of the parameters understood by the [ImageClassification] analyser. */
companion object{
/** Key of the `classes` parameter. NOTE(review): presumably the candidate class labels handed to the classifier - confirm against the extractor. */
const val CLASSES_PARAMETER_NAME = "classes"
/** Key of the `threshold` parameter; the extractor keeps only labels whose confidence is greater than or equal to this value. */
const val THRESHOLD_PARAMETER_NAME = "threshold"
/** Key of the `top_k` parameter; the extractor keeps at most this many labels per classification result. */
const val TOPK_PARAMETER_NAME = "top_k"
}

override val contentClasses = setOf(ImageContent::class)
override val contentClasses = setOf(ImageContent::class, TextContent::class)
override val descriptorClass = LabelDescriptor::class

/**
Expand Down Expand Up @@ -60,13 +68,55 @@ class ImageClassification : ExternalFesAnalyser<ImageContent, LabelDescriptor>()
* @param context The [IndexContext] to use with the [ImageClassification].
* @return [ImageClassification]
*/
override fun newExtractor(field: Schema.Field<ImageContent, LabelDescriptor>, input: Operator<Retrievable>, context: IndexContext) = ImageClassificationExtractor(input, field, this, merge(field, context))
override fun newExtractor(
    field: Schema.Field<ContentElement<*>, LabelDescriptor>,
    input: Operator<Retrievable>,
    context: IndexContext
): ImageClassificationExtractor {
    // The result of merge(field, context) is the parameter map handed to the extractor.
    val parameters = merge(field, context)
    return ImageClassificationExtractor(input, field, this, parameters)
}

override fun newRetrieverForContent(field: Schema.Field<ImageContent, LabelDescriptor>, content: Collection<ImageContent>, context: QueryContext): Retriever<ImageContent, LabelDescriptor> {
TODO("Not yet implemented")
/**
 * Creates a [Retriever] that runs the given [Query] against the provided [Schema.Field].
 *
 * Only [SimpleBooleanQuery] instances are accepted by this [ImageClassification].
 *
 * @param field The [Schema.Field] the [Retriever] is created for; its analyser must be this instance.
 * @param query The [Query] to use with the [Retriever]; must be a [SimpleBooleanQuery].
 * @param context The [QueryContext] to use with the [Retriever].
 *
 * @return A new, anonymous [AbstractRetriever] bound to the field, query and context.
 * @throws IllegalArgumentException If the field belongs to a different analyser or the query is not a [SimpleBooleanQuery].
 */
override fun newRetrieverForQuery(
    field: Schema.Field<ContentElement<*>, LabelDescriptor>,
    query: Query,
    context: QueryContext
): Retriever<ContentElement<*>, LabelDescriptor> {
    // Guard clauses: both conditions indicate a wiring mistake by the caller.
    require(field.analyser == this) { "The field '${field.fieldName}' analyser does not correspond with this analyser. This is a programmer's error!" }
    require(query is SimpleBooleanQuery<*>) { "The query is not a boolean query. This is a programmer's error!" }

    // No behaviour is overridden; the anonymous subclass only instantiates AbstractRetriever.
    val retriever = object : AbstractRetriever<ContentElement<*>, LabelDescriptor>(field, query, context) {}
    return retriever
}

/**
 * Generates and returns a new [Retriever] instance for this [ImageClassification].
 *
 * The label of the first [LabelDescriptor] is wrapped in a [SimpleBooleanQuery] on the `label` attribute,
 * which is then handed to [newRetrieverForQuery]. Any further descriptors in the collection are ignored.
 *
 * @param field The [Schema.Field] to create a [Retriever] for.
 * @param descriptors A collection of [LabelDescriptor] elements; only the first one is used.
 * @param context The [QueryContext] to use with the [Retriever].
 *
 * @return A new [Retriever] instance that looks up the label of the first descriptor.
 * @throws IllegalArgumentException If no label can be obtained from [descriptors] (e.g. the collection is empty).
 */
override fun newRetrieverForDescriptors(
    field: Schema.Field<ContentElement<*>, LabelDescriptor>,
    descriptors: Collection<LabelDescriptor>,
    context: QueryContext
): Retriever<ContentElement<*>, LabelDescriptor> {
    // requireNotNull raises the same IllegalArgumentException (and message) as the previous elvis-throw.
    val label = requireNotNull(descriptors.firstOrNull()?.label) { "No label descriptor provided." }
    val query = SimpleBooleanQuery(value = label, attributeName = "label")
    return newRetrieverForQuery(field, query, context)
}

override fun newRetrieverForQuery(field: Schema.Field<ImageContent, LabelDescriptor>, query: Query, context: QueryContext): Retriever<ImageContent, LabelDescriptor> {
TODO("Not yet implemented")
/**
 * Generates and returns a new [Retriever] instance for this [ImageClassification].
 *
 * The text of the first [TextContent] element found in the provided content is used as the value of a
 * [SimpleBooleanQuery] on the `label` attribute, which is then handed to [newRetrieverForQuery].
 * Elements that are not [TextContent], and any further [TextContent] elements, are ignored.
 *
 * @param field The [Schema.Field] to create a [Retriever] for.
 * @param content A collection of [ContentElement]s; must contain at least one [TextContent].
 * @param context The [QueryContext] to use with the [Retriever].
 *
 * @return A new [Retriever] instance that looks up the given text as a label.
 * @throws IllegalArgumentException If [content] does not contain any [TextContent].
 */
override fun newRetrieverForContent(
    field: Schema.Field<ContentElement<*>, LabelDescriptor>,
    content: Collection<ContentElement<*>>,
    context: QueryContext
): Retriever<ContentElement<*>, LabelDescriptor> {
    // requireNotNull raises the same IllegalArgumentException (and message) as the previous elvis-throw.
    val text = requireNotNull(content.filterIsInstance<TextContent>().firstOrNull()) { "The content does not contain any text." }
    val query = SimpleBooleanQuery(value = Value.String(text.content), attributeName = "label")
    return newRetrieverForQuery(field, query, context)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.vitrivr.engine.base.features.external.common.FesExtractor
import org.vitrivr.engine.base.features.external.implementations.classification.ImageClassification.Companion.CLASSES_PARAMETER_NAME
import org.vitrivr.engine.base.features.external.implementations.classification.ImageClassification.Companion.THRESHOLD_PARAMETER_NAME
import org.vitrivr.engine.base.features.external.implementations.classification.ImageClassification.Companion.TOPK_PARAMETER_NAME
import org.vitrivr.engine.core.model.content.element.ContentElement
import org.vitrivr.engine.core.model.content.element.ImageContent
import org.vitrivr.engine.core.model.descriptor.struct.LabelDescriptor
import org.vitrivr.engine.core.model.metamodel.Schema
Expand All @@ -21,10 +22,10 @@ import java.util.*
*/
class ImageClassificationExtractor(
input: Operator<Retrievable>,
field: Schema.Field<ImageContent, LabelDescriptor>?,
analyser: ExternalFesAnalyser<ImageContent, LabelDescriptor>,
field: Schema.Field<ContentElement<*>, LabelDescriptor>?,
analyser: ExternalFesAnalyser<ContentElement<*>, LabelDescriptor>,
parameters: Map<String, String>
) : FesExtractor<ImageContent, LabelDescriptor>(input, field, analyser, parameters) {
) : FesExtractor<ContentElement<*>, LabelDescriptor>(input, field, analyser, parameters) {


/** The [ZeroShotClassificationApi] used to perform extraction with. */
Expand All @@ -46,10 +47,11 @@ class ImageClassificationExtractor(
val topK = this.parameters[TOPK_PARAMETER_NAME]?.toInt() ?: 1
val threshold = this.parameters[THRESHOLD_PARAMETER_NAME]?.toFloat() ?: 0.0f

val flatResults = this.api.analyseBatched(
retrievables.flatMap {
this.filterContent(it).map { it to classes }
}).mapIndexed { idx, result ->
val content = retrievables.mapIndexed { idx, retrievable ->
this.filterContent(retrievable).filterIsInstance<ImageContent>().map { idx to (it to classes) }
}.flatten()

return this.api.analyseBatched(content.map{it.second}).zip(content.map{it.first}).map { (result, idx) ->
result.mapIndexed { idy, confidence ->
LabelDescriptor(
UUID.randomUUID(),
Expand All @@ -62,6 +64,5 @@ class ImageClassificationExtractor(
)
}.filter { it.confidence.value >= threshold }.sortedByDescending { it.confidence.value }.take(topK)
}
return flatResults
}
}