Skip to content

Commit

Permalink
Merge pull request #4208 from underskyer/Fix-MapRef-constructor-vulne…
Browse files Browse the repository at this point in the history
…rability

Fix MapRef.fromSeqRefs vulnerability
  • Loading branch information
djspiewak authored Dec 21, 2024
2 parents ec62beb + 3fa3927 commit c7d3a03
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions std/shared/src/main/scala/cats/effect/std/MapRef.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ object MapRef extends MapRefCompanionPlatform {

}

private def noShardsException = new IllegalArgumentException(
"Shards count should be greater than zero")

/**
* Creates a sharded map ref to reduce atomic contention on the Map, given an efficient and
* equally distributed hash, the contention should allow for interaction like a general
Expand All @@ -68,11 +71,13 @@ object MapRef extends MapRefCompanionPlatform {
def ofShardedImmutableMap[F[_]: Concurrent, K, V](
shardCount: Int
): F[MapRef[F, K, Option[V]]] = {
assert(shardCount >= 1, "MapRef.sharded should have at least 1 shard")
List
.fill(shardCount)(())
.traverse(_ => Concurrent[F].ref[Map[K, V]](Map.empty))
.map(fromSeqRefs(_))
if (shardCount >= 1)
List
.fill(shardCount)(())
.traverse(_ => Concurrent[F].ref[Map[K, V]](Map.empty))
.map(lst => fromNonEmptySeqRefs(NonEmptySeq.fromSeqUnsafe(lst)))
else
ApplicativeError[F, Throwable].raiseError(noShardsException)
}

/**
Expand All @@ -84,24 +89,39 @@ object MapRef extends MapRefCompanionPlatform {
*/
def inShardedImmutableMap[G[_]: Sync, F[_]: Sync, K, V](
shardCount: Int
): G[MapRef[F, K, Option[V]]] = Sync[G].defer {
assert(shardCount >= 1, "MapRef.sharded should have at least 1 shard")
List
.fill(shardCount)(())
.traverse(_ => Ref.in[G, F, Map[K, V]](Map.empty))
.map(fromSeqRefs(_))
): G[MapRef[F, K, Option[V]]] = {
if (shardCount >= 1)
List
.fill(shardCount)(())
.traverse(_ => Ref.in[G, F, Map[K, V]](Map.empty))
.map(lst => fromNonEmptySeqRefs(NonEmptySeq.fromSeqUnsafe(lst)))
else
ApplicativeError[G, Throwable].raiseError(noShardsException)
}

/**
* Creates a sharded map ref from a sequence of refs.
*
* This uses universal hashCode and equality on K.
*/
@deprecated("Use fromNonEmptySeqRefs instead", "3.6.0")
def fromSeqRefs[F[_]: Functor, K, V](
seq: scala.collection.immutable.Seq[Ref[F, Map[K, V]]]
): MapRef[F, K, Option[V]] =
fromNonEmptySeqRefs(
seq.toNeSeq.getOrElse(throw noShardsException)
)

/**
* Creates a sharded map ref from a nonempty sequence of refs.
*
* This uses universal hashCode and equality on K.
*/
def fromNonEmptySeqRefs[F[_]: Functor, K, V](
seq: NonEmptySeq[Ref[F, Map[K, V]]]
): MapRef[F, K, Option[V]] = {
val array = seq.toArray
val shardCount = seq.size
val array = seq.toSeq.toArray
val shardCount = array.length
val refFunction = { (k: K) =>
val location = Math.abs(k.## % shardCount)
array(location)
Expand Down

0 comments on commit c7d3a03

Please sign in to comment.