Skip to content

Commit

Permalink
Introduces a distinction between unbound and similarity scores.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ralph Gasser committed Mar 27, 2024
1 parent 854ffdd commit e49c93c
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,36 @@ import kotlin.math.max
* Scores are expected to be in the range of [0, 1].
*
* @author Luca Rossetto
* @version 1.0.0
* @author Ralph Gasser
* @version 1.1.0
*/
data class ScoreAttribute(val score: Float) : MergingRetrievableAttribute {
sealed interface ScoreAttribute : MergingRetrievableAttribute {

companion object {
val ZERO = ScoreAttribute(0f)
}
/** The score associated with this [ScoreAttribute]. */
val score: Float

constructor(score: Double) : this(score.toFloat())
/**
* A similarity score. Strictly bound between 0 and 1.
*/
data class Similarity(override val score: Float): ScoreAttribute {
init {
require(score in 0f..1f) { "Similarity score '$score' outside of valid range (0, 1)" }
}

init {
require(score in 0f..1f) { "Score '$score' outside of valid range (0, 1)" }
override fun merge(other: MergingRetrievableAttribute): Similarity = Similarity(
max(this.score, (other as? Similarity)?.score ?: 0f)
)
}

override fun merge(other: MergingRetrievableAttribute): ScoreAttribute = ScoreAttribute(
max(this.score, (other as? ScoreAttribute)?.score ?: 0f)
)
/**
* An unbound score. Strictly bound between 0 and 1.
*/
data class Unbound(override val score: Float): ScoreAttribute {
init {
require(this.score >= 0f) { "Score '$score' outside of valid range (>= 0)." }
}
override fun merge(other: MergingRetrievableAttribute): Unbound = Unbound(
max(this.score, (other as? Unbound)?.score ?: 0f)
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ object ScoringFunctions {
* @param max Maximum value. Default is 1.0.
*/
fun max(retrieved: Retrieved, max: Float = 1.0f): ScoreAttribute {
val distance = retrieved.filteredAttribute<DistanceAttribute>()?.distance ?: return ScoreAttribute.ZERO
return ScoreAttribute(max - distance)
val distance = retrieved.filteredAttribute<DistanceAttribute>()?.distance ?: return ScoreAttribute.Unbound(0.0f)
return ScoreAttribute.Unbound(max - distance)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class StringDescriptorReader(field: Schema.Field<*, StringDescriptor>, connectio
val retrievableId = tuple.asUuidValue(RETRIEVABLE_ID_COLUMN_NAME)?.value ?: throw IllegalArgumentException("The provided tuple is missing the required field '${RETRIEVABLE_ID_COLUMN_NAME}'.")
val score = tuple.asDouble(SCORE_COLUMN_NAME) ?: 0.0
val retrieved = Retrieved(retrievableId, null, false)
retrieved.addAttribute(ScoreAttribute(score))
retrieved.addAttribute(ScoreAttribute.Unbound(score.toFloat()))
retrieved
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class StructDescriptorReader(field: Schema.Field<*, StructDescriptor>, connectio
val retrievableId = tuple.asUuidValue(RETRIEVABLE_ID_COLUMN_NAME)?.value ?: throw IllegalArgumentException("The provided tuple is missing the required field '${RETRIEVABLE_ID_COLUMN_NAME}'.")
val score = tuple.asDouble(SCORE_COLUMN_NAME) ?: 0.0
val retrieved = Retrieved(retrievableId, null, false)
retrieved.addAttribute(ScoreAttribute(score))
retrieved.addAttribute(ScoreAttribute.Unbound(score.toFloat()))
retrieved
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class AverageColorRetriever(
override fun toFlow(scope: CoroutineScope) = flow {
val reader = this@AverageColorRetriever.field.getReader()
reader.getAll(this@AverageColorRetriever.query).forEach {
it.addAttribute(ScoreAttribute(scoringFunction(it)))
it.addAttribute(ScoreAttribute.Similarity(scoringFunction(it)))
emit(it)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class TemporalSequenceAggregator(
id, "temporalSequence", true
)

retrieved.addAttribute(ScoreAttribute(score))
retrieved.addAttribute(ScoreAttribute.Unbound(score))
retrieved.addAttribute(RelationshipAttribute(relationships))

emit(retrieved)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class WeightedScoreFusion(
//make a copy and override score
val retrieved = first.copy()
retrieved.filteredAttribute<ScoreAttribute>()
retrieved.addAttribute(ScoreAttribute(score))
retrieved.addAttribute(ScoreAttribute.Unbound(score))

emit(retrieved)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ class ScoreAggregator(
}
}

retrieved.addAttribute(ScoreAttribute(score))
retrieved.addAttribute(ScoreAttribute.Unbound(score))

} else {
retrieved.addAttribute(ScoreAttribute(0f))
retrieved.addAttribute(ScoreAttribute.Unbound(0f))
}

retrieved
Expand Down

0 comments on commit e49c93c

Please sign in to comment.