Skip to content

Commit

Permalink
Merged image and text clip implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
lucaro committed Nov 18, 2023
1 parent fad807c commit 72d55b3
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 308 deletions.
51 changes: 49 additions & 2 deletions config.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
},
{
"name": "clip",
"factory": "CLIPImage"
"factory": "CLIP"
},
{
"name": "dino",
Expand Down Expand Up @@ -82,7 +82,7 @@
},
{
"name": "clip",
"factory": "CLIPImage"
"factory": "CLIP"
},
{
"name": "dino",
Expand Down Expand Up @@ -112,6 +112,53 @@
}
}
]
},
{
"name": "MVK",
"connection": {
"database": "CottontailConnectionProvider",
"parameters": {
"host": "127.0.0.1",
"port": "1865"
}
},
"fields": [
{
"name": "averagecolor",
"factory": "AverageColor"
},
{
"name": "clip",
"factory": "CLIP"
},
{
"name": "dino",
"factory": "DINO"
},
{
"name": "file",
"factory": "FileMetadata"
},
{
"name": "time",
"factory": "TemporalMetadata"
}
],
"exporters": [
{
"name": "thumbnail",
"factory": "ThumbnailExporter",
"parameters": {
"key": "ThumbnailExporter-value-schema"
},
"resolver": {
"factory": "DiskResolver",
"parameters": {
"key": "DiskResolver-value-schema"
}
}
}
]
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,12 @@ import org.vitrivr.engine.core.model.metamodel.Analyser
* @param T Type of [ContentElement] that this external analyzer operates on.
* @param U Type of [Descriptor] produced by this external analyzer.
*
* @property host The host address of the external feature extraction service.
* @property port The port of the external feature extraction service.
* @property endpoint The endpoint of the external feature to extract.
*
* @see [Analyser]
*
* @author Rahel Arnold
* @version 1.0.0
*/
abstract class ExternalAnalyser<T : ContentElement<*>, U : Descriptor> : Analyser<T, U> {
/** The host address of the external feature extraction service. */
abstract val host: String

/** The port of the external feature extraction service. */
abstract val port: Int

/** The endpoint for the external feature extraction service. */
abstract val endpoint: String

/**
* Requests the external feature descriptor for the given [ContentElement].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ abstract class ExternalWithFloatVectorDescriptorAnalyser<C : ContentElement<*>>
return resultList
}

fun httpRequest(content: ContentElement<*>): List<Float> {
val url = "http://$host:$port$endpoint"
fun httpRequest(content: ContentElement<*>, url: String): List<Float> {
val base64 = when (content) {
is TextContent -> encodeTextToBase64(content.content)
is ImageContent -> content.content.toDataURL()
Expand All @@ -147,6 +146,4 @@ abstract class ExternalWithFloatVectorDescriptorAnalyser<C : ContentElement<*>>
// Encode the byte array to base64
return Base64.getEncoder().encodeToString(textBytes)
}


}
Original file line number Diff line number Diff line change
@@ -1,54 +1,41 @@
package org.vitrivr.engine.base.features.external.implementations.clip.image
package org.vitrivr.engine.base.features.external.implementations.clip

import io.github.oshai.kotlinlogging.KLogger
import io.github.oshai.kotlinlogging.KotlinLogging
import org.vitrivr.engine.base.features.external.ExternalAnalyser
import org.vitrivr.engine.base.features.external.common.ExternalWithFloatVectorDescriptorAnalyser
import org.vitrivr.engine.base.features.external.implementations.dino.DINO
import org.vitrivr.engine.core.context.IndexContext
import org.vitrivr.engine.core.context.QueryContext
import org.vitrivr.engine.core.model.content.Content
import org.vitrivr.engine.core.model.content.element.ContentElement
import org.vitrivr.engine.core.model.content.element.ImageContent
import org.vitrivr.engine.core.model.descriptor.Descriptor
import org.vitrivr.engine.core.model.content.element.TextContent
import org.vitrivr.engine.core.model.descriptor.vector.FloatVectorDescriptor
import org.vitrivr.engine.core.model.metamodel.Analyser
import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.retrievable.Retrievable
import org.vitrivr.engine.core.operators.Operator
import org.vitrivr.engine.core.operators.ingest.Extractor
import org.vitrivr.engine.core.operators.retrieve.Retriever
import java.awt.Image
import java.util.*

private val logger: KLogger = KotlinLogging.logger {}

/**
* Implementation of the [CLIPImage] [ExternalAnalyser], which derives the CLIP feature from an [ImageContent] as [FloatVectorDescriptor].
* Implementation of the [CLIP] [ExternalAnalyser], which derives the CLIP feature from an [ImageContent] or [TextContent] as [FloatVectorDescriptor].
*
* @author Rahel Arnold
* @version 1.0.0
*/
class CLIPImage : ExternalWithFloatVectorDescriptorAnalyser<ImageContent>() {
class CLIP : ExternalWithFloatVectorDescriptorAnalyser<ContentElement<*>>() {

companion object {
/**
* Static method to request the CLIP feature descriptor for the given [ContentElement].
*
* @param content The [ContentElement] for which to request the CLIP feature descriptor.
* @return A list of CLIP feature descriptors.
*/
fun requestDescriptor(content: ContentElement<*>): List<Float> {
return CLIPImage().httpRequest(content)
}
private val logger: KLogger = KotlinLogging.logger {}
}

override val contentClasses = setOf(ImageContent::class)
override val contentClasses = setOf(ImageContent::class, TextContent::class)
override val descriptorClass = FloatVectorDescriptor::class

// Default values for external API
override val endpoint: String = "/extract/clip_image"
override val host: String = "localhost"
override val port: Int = 8888
// override val endpoint: String = "/extract/clip_image"
// override val host: String = "localhost"
// override val port: Int = 8888

// Size and list for prototypical descriptor
override val size = 512
Expand All @@ -60,10 +47,20 @@ class CLIPImage : ExternalWithFloatVectorDescriptorAnalyser<ImageContent>() {
* @param content The [ContentElement] for which to request the CLIP feature descriptor.
* @return A list of CLIP feature descriptors.
*/
override fun requestDescriptor(content: ContentElement<*>): List<Float> = httpRequest(content)
override fun requestDescriptor(content: ContentElement<*>): List<Float> {

//TODO make endpoints configurable
return when(content) {
is ImageContent -> httpRequest(content, "http://localhost:8888/extract/clip_image")
is TextContent -> httpRequest(content, "http://localhost:8888/extract/clip_text")
else -> throw IllegalArgumentException("Content '$content' not supported")
}


}

/**
* Generates a prototypical [FloatVectorDescriptor] for this [CLIPImage].
* Generates a prototypical [FloatVectorDescriptor] for this [CLIP].
*
* @return [FloatVectorDescriptor]
*/
Expand All @@ -81,15 +78,15 @@ class CLIPImage : ExternalWithFloatVectorDescriptorAnalyser<ImageContent>() {
* @throws [UnsupportedOperationException], if this [Analyser] does not support the creation of an [Extractor] instance.
*/
override fun newExtractor(
field: Schema.Field<ImageContent, FloatVectorDescriptor>,
field: Schema.Field<ContentElement<*>, FloatVectorDescriptor>,
input: Operator<Retrievable>,
context: IndexContext,
persisting: Boolean,
parameters: Map<String, Any>
): Extractor<ImageContent, FloatVectorDescriptor> {
): Extractor<ContentElement<*>, FloatVectorDescriptor> {
require(field.analyser == this) { "" }
logger.debug { "Creating new CLIPImageExtractor for field '${field.fieldName}' with parameters $parameters." }
return CLIPImageExtractor(input, field, persisting)
return CLIPExtractor(input, field, persisting, this)
}

/**
Expand All @@ -103,15 +100,15 @@ class CLIPImage : ExternalWithFloatVectorDescriptorAnalyser<ImageContent>() {
* @throws [UnsupportedOperationException], if this [Analyser] does not support the creation of an [Retriever] instance.
*/
override fun newRetrieverForContent(
field: Schema.Field<ImageContent, FloatVectorDescriptor>,
content: Collection<ImageContent>,
field: Schema.Field<ContentElement<*>, FloatVectorDescriptor>,
content: Collection<ContentElement<*>>,
context: QueryContext
): Retriever<ImageContent, FloatVectorDescriptor> {
): Retriever<ContentElement<*>, FloatVectorDescriptor> {
return this.newRetrieverForDescriptors(field, this.processContent(content), context)
}

/**
* Generates and returns a new [Retriever] instance for this [CLIPImage].
* Generates and returns a new [Retriever] instance for this [CLIP].
*
* @param field The [Schema.Field] to create an [Retriever] for.
* @param descriptors An array of [Descriptor] elements to use with the [Retriever]
Expand All @@ -121,11 +118,11 @@ class CLIPImage : ExternalWithFloatVectorDescriptorAnalyser<ImageContent>() {
* @throws [UnsupportedOperationException], if this [Analyser] does not support the creation of an [Retriever] instance.
*/
override fun newRetrieverForDescriptors(
field: Schema.Field<ImageContent, FloatVectorDescriptor>,
field: Schema.Field<ContentElement<*>, FloatVectorDescriptor>,
descriptors: Collection<FloatVectorDescriptor>,
context: QueryContext
): Retriever<ImageContent, FloatVectorDescriptor> {
): Retriever<ContentElement<*>, FloatVectorDescriptor> {
require(field.analyser == this) { }
return CLIPImageRetriever(field, descriptors.first(), context)
return CLIPRetriever(field, descriptors.first(), context)
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package org.vitrivr.engine.base.features.external.implementations.clip.image
package org.vitrivr.engine.base.features.external.implementations.clip

import org.vitrivr.engine.core.features.AbstractExtractor
import org.vitrivr.engine.core.features.metadata.file.FileMetadataExtractor
import org.vitrivr.engine.core.model.content.element.ContentElement
import org.vitrivr.engine.core.model.content.element.ImageContent
import org.vitrivr.engine.core.model.descriptor.Descriptor
import org.vitrivr.engine.core.model.descriptor.vector.FloatVectorDescriptor
Expand All @@ -14,7 +15,7 @@ import org.vitrivr.engine.core.operators.ingest.Extractor
import org.vitrivr.engine.core.source.file.FileSource

/**
* [CLIPImageExtractor] implementation of an [AbstractExtractor] for [CLIPImage].
* [CLIPExtractor] implementation of an [AbstractExtractor] for [CLIP].
*
* @param field Schema field for which the extractor generates descriptors.
* @param input Operator representing the input data source.
Expand All @@ -23,7 +24,7 @@ import org.vitrivr.engine.core.source.file.FileSource
* @author Rahel Arnold
* @version 1.0.0
*/
class CLIPImageExtractor(input: Operator<Retrievable>, field: Schema.Field<ImageContent, FloatVectorDescriptor>, persisting: Boolean = true) : AbstractExtractor<ImageContent, FloatVectorDescriptor>(input, field, persisting) {
class CLIPExtractor(input: Operator<Retrievable>, field: Schema.Field<ContentElement<*>, FloatVectorDescriptor>, persisting: Boolean = true, private val clip: CLIP) : AbstractExtractor<ContentElement<*>, FloatVectorDescriptor>(input, field, persisting) {
/**
* Internal method to check, if [Retrievable] matches this [Extractor] and should thus be processed.
*
Expand All @@ -43,6 +44,6 @@ class CLIPImageExtractor(input: Operator<Retrievable>, field: Schema.Field<Image
override fun extract(retrievable: Retrievable): List<FloatVectorDescriptor> {
check(retrievable is RetrievableWithContent) { "Incoming retrievable is not a retrievable with content. This is a programmer's error!" }
val content = retrievable.content.filterIsInstance<ImageContent>()
return content.map { c -> FloatVectorDescriptor(retrievableId = retrievable.id, vector = CLIPImage.requestDescriptor(c), transient = !this.persisting) }
return content.map { c -> FloatVectorDescriptor(retrievableId = retrievable.id, vector = clip.requestDescriptor(c), transient = !this.persisting) }
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
package org.vitrivr.engine.base.features.external.implementations.clip.image
package org.vitrivr.engine.base.features.external.implementations.clip

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import org.vitrivr.engine.core.context.QueryContext
import org.vitrivr.engine.core.features.AbstractRetriever
import org.vitrivr.engine.core.model.content.element.ContentElement
import org.vitrivr.engine.core.model.content.element.ImageContent
import org.vitrivr.engine.core.model.descriptor.vector.FloatVectorDescriptor
import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.query.proximity.ProximityQuery
import org.vitrivr.engine.core.model.retrievable.Retrieved

/**
* [CLIPImageRetriever] implementation for external CLIP image feature retrieval.
* [CLIPRetriever] implementation for external CLIP image feature retrieval.
*
* @param field Schema field for which the retriever operates.
* @param query The query vector for proximity-based retrieval.
Expand All @@ -24,10 +25,10 @@ import org.vitrivr.engine.core.model.retrievable.Retrieved
* @author Rahel Arnold
* @version 1.0.0
*/
class CLIPImageRetriever(field: Schema.Field<ImageContent, FloatVectorDescriptor>, query: FloatVectorDescriptor, context: QueryContext) : AbstractRetriever<ImageContent, FloatVectorDescriptor>(field, query, context) {
class CLIPRetriever(field: Schema.Field<ContentElement<*>, FloatVectorDescriptor>, query: FloatVectorDescriptor, context: QueryContext) : AbstractRetriever<ContentElement<*>, FloatVectorDescriptor>(field, query, context) {
override fun toFlow(scope: CoroutineScope): Flow<Retrieved> = flow {
val query = ProximityQuery(this@CLIPImageRetriever.query) /* TODO: Not sure, if the default setting should be used here. */
this@CLIPImageRetriever.reader.getAll(query).forEach {
val query = ProximityQuery(this@CLIPRetriever.query) /* TODO: Not sure, if the default setting should be used here. */
this@CLIPRetriever.reader.getAll(query).forEach {
emit(it)
}
}
Expand Down
Loading

0 comments on commit 72d55b3

Please sign in to comment.