Skip to content

Commit

Permalink
Implement caching for proxy responses (#778)
Browse files Browse the repository at this point in the history

Signed-off-by: Paolo Di Tommaso <[email protected]>
  • Loading branch information
pditommaso authored Dec 20, 2024
1 parent 67f91c8 commit 00b6add
Show file tree
Hide file tree
Showing 16 changed files with 417 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ import io.micronaut.scheduling.annotation.ExecuteOn
import io.seqera.wave.ErrorHandler
import io.seqera.wave.configuration.HttpClientConfig
import io.seqera.wave.core.RegistryProxyService
import io.seqera.wave.core.RegistryProxyService.DelegateResponse
import io.seqera.wave.core.RouteHandler
import io.seqera.wave.core.RoutePath
import io.seqera.wave.exception.DockerRegistryException
import io.seqera.wave.exchange.RegistryErrorResponse
import io.seqera.wave.proxy.DelegateResponse
import io.seqera.wave.ratelimit.AcquireRequest
import io.seqera.wave.ratelimit.RateLimiterService
import io.seqera.wave.service.blob.BlobCacheService
Expand All @@ -54,7 +54,6 @@ import io.seqera.wave.storage.DigestStore
import io.seqera.wave.storage.DockerDigestStore
import io.seqera.wave.storage.HttpDigestStore
import io.seqera.wave.storage.Storage
import io.seqera.wave.util.Retryable
import jakarta.inject.Inject
import org.reactivestreams.Publisher
import reactor.core.publisher.Mono
Expand Down Expand Up @@ -274,7 +273,7 @@ class RegistryProxyController {
final resp = proxyService.handleRequest(route, headers)
HttpResponse
.status(HttpStatus.valueOf(resp.statusCode))
.body(resp.body.bytes)
.body(resp.body)
.headers(toMutableHeaders(resp.headers))
}

Expand Down Expand Up @@ -348,14 +347,9 @@ class RegistryProxyController {
}

MutableHttpResponse<?> fromContentResponse(DelegateResponse resp, RoutePath route) {
// create the retry logic on error
final retryable = Retryable
.<byte[]>of(httpConfig)
.onRetry((event) -> log.warn("Unable to read manifest body - request: $route; event: $event"))

HttpResponse
.status(HttpStatus.valueOf(resp.statusCode))
.body(retryable.apply(()-> resp.body.bytes))
.body(resp.body)
.headers(toMutableHeaders(resp.headers))
}

Expand Down
50 changes: 38 additions & 12 deletions src/main/groovy/io/seqera/wave/core/RegistryProxyService.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ package io.seqera.wave.core

import java.util.concurrent.CompletableFuture

import com.google.common.hash.Hashing
import groovy.transform.CompileStatic
import groovy.transform.ToString
import groovy.util.logging.Slf4j
import io.micronaut.cache.annotation.Cacheable
import io.micronaut.context.annotation.Context
Expand All @@ -36,6 +36,8 @@ import io.seqera.wave.auth.RegistryLookupService
import io.seqera.wave.configuration.HttpClientConfig
import io.seqera.wave.http.HttpClientFactory
import io.seqera.wave.model.ContainerCoordinates
import io.seqera.wave.proxy.DelegateResponse
import io.seqera.wave.proxy.ProxyCache
import io.seqera.wave.proxy.ProxyClient
import io.seqera.wave.service.CredentialsService
import io.seqera.wave.service.builder.BuildRequest
Expand All @@ -44,6 +46,7 @@ import io.seqera.wave.storage.DigestStore
import io.seqera.wave.storage.Storage
import io.seqera.wave.tower.PlatformId
import io.seqera.wave.util.RegHelper
import io.seqera.wave.util.Retryable
import jakarta.inject.Inject
import jakarta.inject.Singleton
import reactor.core.publisher.Flux
Expand Down Expand Up @@ -91,6 +94,9 @@ class RegistryProxyService {
@Client("stream-client")
private ReactorStreamingHttpClient streamClient

@Inject
private ProxyCache cache

private ContainerAugmenter scanner(ProxyClient proxyClient) {
return new ContainerAugmenter()
.withStorage(storage)
Expand Down Expand Up @@ -141,7 +147,31 @@ class RegistryProxyService {
}
}

DelegateResponse handleRequest(RoutePath route, Map<String,List<String>> headers){
/**
 * Compute the key used to cache the response for a proxied request.
 *
 * The key is the SipHash-2-4 digest of the route stable hash followed by every
 * header name and its non-empty values, each token terminated by a {@code '/'} separator.
 * Headers are hashed in the iteration order of the given map.
 *
 * @param route
 *      The route of the request being proxied
 * @param headers
 *      The request headers; a {@code null} or empty map contributes no header token
 * @return
 *      The digest rendered as a string
 */
static protected String requestKey(RoutePath route, Map<String,List<String>> headers) {
    final digest = Hashing.sipHash24().newHasher()
    // the route identity comes first
    digest
        .putUnencodedChars(route.stableHash())
        .putUnencodedChars('/')
    // then each header name followed by its values
    for( Map.Entry<String,List<String>> header : (headers ?: Map.of()) ) {
        digest.putUnencodedChars(header.key)
        for( String value : header.value ) {
            // empty or null values only contribute the separator
            if( value )
                digest.putUnencodedChars(value)
            digest.putUnencodedChars('/')
        }
        digest.putUnencodedChars('/')
    }
    return digest.hash().toString()
}

/**
 * Handle a proxied registry request going through the proxy cache.
 *
 * The response is looked up by the key returned by {@code requestKey}; on a miss it is
 * loaded via {@code handleRequest0}. The loaded response is offered to the cache only when
 * the route is a digest route and the response reports itself as cacheable.
 *
 * @param route The route of the request
 * @param headers The request headers
 * @return The {@link DelegateResponse} for the request
 */
DelegateResponse handleRequest(RoutePath route, Map<String,List<String>> headers) {
    final cacheKey = requestKey(route, headers)
    return cache.getOrCompute(
            cacheKey,
            (key)-> handleRequest0(route, headers),
            (loaded)-> route.isDigest() && loaded.isCacheable() )
}

private DelegateResponse handleRequest0(RoutePath route, Map<String,List<String>> headers) {
ProxyClient proxyClient = client(route)
final resp1 = proxyClient.getStream(route.path, headers, false)
final redirect = resp1.headers().firstValue('Location').orElse(null)
Expand Down Expand Up @@ -182,10 +212,15 @@ class RegistryProxyService {
// otherwise read it and include the body input stream in the response
// the caller must consume and close the body to prevent memory leaks
else {
// create the retry logic on error
final retryable = Retryable
.<byte[]>of(httpConfig)
.onRetry((event) -> log.warn("Unable to read blob body - request: $route; event: $event"))
// read the body and compose the response
return new DelegateResponse(
statusCode: resp1.statusCode(),
headers: resp1.headers().map(),
body: resp1.body() )
body: retryable.apply(()-> resp1.body().bytes) )
}
}

Expand Down Expand Up @@ -226,15 +261,6 @@ class RegistryProxyService {
return result
}

/**
 * Holds the outcome of a delegated registry request: status code, headers,
 * the body stream and, for redirections, the target location.
 */
@ToString(includeNames = true, includePackage = false)
static class DelegateResponse {
    int statusCode
    Map<String,List<String>> headers
    InputStream body
    String location

    /**
     * @return {@code true} when a non-empty location is set
     */
    boolean isRedirect() {
        return location ? true : false
    }
}

Flux<ByteBuffer<?>> streamBlob(RoutePath route, Map<String,List<String>> headers) {
ProxyClient proxyClient = client(route)
return proxyClient.stream(streamClient, route.path, headers)
Expand Down
6 changes: 6 additions & 0 deletions src/main/groovy/io/seqera/wave/core/RoutePath.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import io.micronaut.core.annotation.Nullable
import io.seqera.wave.model.ContainerCoordinates
import io.seqera.wave.service.request.ContainerRequest
import io.seqera.wave.tower.PlatformId
import io.seqera.wave.util.RegHelper
import static io.seqera.wave.WaveDefault.DOCKER_IO
/**
* Model a container registry route path
Expand Down Expand Up @@ -150,4 +151,9 @@ class RoutePath implements ContainerPath {
else
throw new IllegalArgumentException("Not a valid container path - offending value: '$location'")
}

/**
 * @return
 *      A hash computed via {@code RegHelper.sipHash} over the route type, registry, path,
 *      image, reference and the stable hash of the associated identity
 */
String stableHash() {
    return RegHelper.sipHash(
            type,
            registry,
            path,
            image,
            reference,
            identity.stableHash() )
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,22 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package io.seqera.wave.tower.client
package io.seqera.wave.proxy

import spock.lang.Specification
import groovy.transform.ToString
import io.seqera.wave.encoder.MoshiExchange

/**
*
* Model a response object to be forwarded to the client
*
* @author Paolo Di Tommaso <[email protected]>
*/
class TowerClientTest extends Specification {

// Checks that `TowerClient.makeKey` is deterministic: it pins the exact key for 'a',
// verifies repeated calls with the same arguments (including a null argument) agree,
// and that a URI / long argument yields the same key as its string form.
def 'should create consistent hash' () {
given:
def client = new TowerClient()

expect:
client.makeKey('a') == '92cf27ac76c18d8e'
and:
client.makeKey('a') == client.makeKey('a')
and:
client.makeKey('a','b','c') == client.makeKey('a','b','c')
and:
client.makeKey('a','b',null) == client.makeKey('a','b',null)
and:
client.makeKey(new URI('http://foo.com')) == client.makeKey('http://foo.com')
and:
client.makeKey(100l) == client.makeKey('100')
}

@ToString(includeNames = true, includePackage = false)
class DelegateResponse implements MoshiExchange {
    /** The http status code of the delegated response */
    int statusCode
    /** The response headers */
    Map<String,List<String>> headers
    /** The response body held as a byte array */
    byte[] body
    /** The redirection target, if any */
    String location

    /**
     * @return {@code true} when a non-empty location is set
     */
    boolean isRedirect() {
        return location ? true : false
    }

    /**
     * @return {@code true} when a location is set, or when a body is available and
     *      the status code is in the range 200 (inclusive) to 400 (exclusive)
     */
    boolean isCacheable() {
        if( location!=null )
            return true
        final boolean okStatus = statusCode>=200 && statusCode<400
        return body!=null && okStatus
    }
}
62 changes: 62 additions & 0 deletions src/main/groovy/io/seqera/wave/proxy/ProxyCache.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Wave, containers provisioning service
* Copyright (c) 2023-2024, Seqera Labs
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package io.seqera.wave.proxy

import java.time.Duration

import com.squareup.moshi.adapters.PolymorphicJsonAdapterFactory
import groovy.transform.CompileStatic
import io.micronaut.context.annotation.Value
import io.micronaut.core.annotation.Nullable
import io.seqera.wave.encoder.MoshiEncodeStrategy
import io.seqera.wave.encoder.MoshiExchange
import io.seqera.wave.store.cache.AbstractTieredCache
import io.seqera.wave.store.cache.L2TieredCache
import jakarta.inject.Singleton
/**
 * Tiered cache storing proxied http responses modelled as {@link DelegateResponse}
 *
 * @author Paolo Di Tommaso <[email protected]>
 */
@Singleton
@CompileStatic
class ProxyCache extends AbstractTieredCache<DelegateResponse> {

    /**
     * @param l2 The second level cache, it can be {@code null}
     * @param duration The cache duration, read from {@code wave.proxy-cache.duration}
     * @param maxSize The cache max size, read from {@code wave.proxy-cache.max-size}
     */
    ProxyCache(@Nullable L2TieredCache l2,
               @Value('${wave.proxy-cache.duration:1h}') Duration duration,
               @Value('${wave.proxy-cache.max-size:10000}') long maxSize) {
        super(l2, encoder(), duration, maxSize)
    }

    /**
     * @return The Moshi encoding strategy able to (de)serialise cache entries
     *      and the {@link DelegateResponse} objects they hold
     */
    static MoshiEncodeStrategy encoder() {
        // polymorphic adapter discriminating the concrete type via the "@type" attribute
        final adapterFactory = PolymorphicJsonAdapterFactory
                .of(MoshiExchange.class, "@type")
                .withSubtype(Entry.class, Entry.name)
                .withSubtype(DelegateResponse.class, DelegateResponse.simpleName)
        // anonymous subclass of the encoding strategy bound to the cache entry type
        return new MoshiEncodeStrategy<AbstractTieredCache.Entry>(adapterFactory) {}
    }

    String getName() {
        return 'proxy-cache'
    }

    String getPrefix() {
        return 'proxy-cache/v1'
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class StreamServiceImpl implements StreamService {
// when it's a response with a binary body, just return it
if( resp.body!=null ) {
log.debug "Streaming response body for route: $route"
return resp.body
return new ByteArrayInputStream(resp.body)
}
// otherwise cache the blob and stream the resulting uri
if( blobCacheService ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCac

private L2TieredCache<String,String> l2

private final Lock sync = new ReentrantLock()
// Per-key locks serialising the loading of the same cache key. WeakHashMap is not
// thread-safe and `computeIfAbsent` is invoked on it before any lock is acquired,
// therefore the map is wrapped to make concurrent accesses safe.
private final Map<String,Lock> locks = Collections.synchronizedMap(new WeakHashMap<String,Lock>())

AbstractTieredCache(L2TieredCache<String,String> l2, MoshiEncodeStrategy encoder, Duration duration, long maxSize) {
log.info "Cache '${getName()}' config - prefix=${getPrefix()}; ttl=${duration}; max-size: ${maxSize}; l2=${l2}"
log.info "Cache '${getName()}' config - prefix=${getPrefix()}; ttl=${duration}; max-size: ${maxSize}"
if( l2==null )
log.warn "Missing L2 cache for tiered cache '${getName()}'"
this.l2 = l2
this.ttl = duration
this.encoder = encoder
Expand All @@ -89,12 +91,49 @@ abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCac
}
}

/**
* Retrieve the value associated with the specified key
*
* @param key
* The key of the value to be retrieved
* @return
* The value associated with the specified key, or {@code null} otherwise
*/
@Override
V get(String key) {
getOrCompute(key, null)
getOrCompute(key, null, (v)->true)
}

/**
 * Retrieve the value associated with the specified key, loading it when missing.
 * It delegates to the three-argument variant using a cache condition that accepts any value.
 *
 * @param key
 *      The key of the value to be retrieved
 * @param loader
 *      A function invoked to load the value when the entry with the specified key is not available
 * @return
 *      The value associated with the specified key, or {@code null} otherwise
 */
V getOrCompute(String key, Function<String,V> loader) {
    return getOrCompute(key, loader, (loaded)->true)
}

/**
* Retrieve the value associated with the specified key
*
* @param key
* The key of the value to be retrieved
* @param loader
* The function invoked to load the value when the entry with the specified key is not available
* @param cacheCondition
* The function to determine if the loaded value should be cached
* @return
* The value associated with the specified key, or the result of the loader function otherwise
*/
V getOrCompute(String key, Function<String,V> loader, Function<V,Boolean> cacheCondition) {
assert key!=null, "Argument key cannot be null"
assert cacheCondition!=null, "Argument condition cannot be null"

log.trace "Cache '${name}' checking key=$key"
// Try L1 cache first
V value = l1.synchronous().getIfPresent(key)
Expand All @@ -103,6 +142,7 @@ abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCac
return value
}

final sync = locks.computeIfAbsent(key, (k)-> new ReentrantLock())
sync.lock()
try {
value = l1.synchronous().getIfPresent(key)
Expand All @@ -124,7 +164,7 @@ abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCac
if( value==null && loader!=null ) {
log.trace "Cache '${name}' invoking loader - key=$key"
value = loader.apply(key)
if( value!=null ) {
if( value!=null && cacheCondition.apply(value) ) {
l1.synchronous().put(key,value)
l2Put(key,value)
}
Expand Down
11 changes: 11 additions & 0 deletions src/main/groovy/io/seqera/wave/tower/PlatformId.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import groovy.transform.Canonical
import groovy.transform.CompileStatic
import io.seqera.wave.api.ContainerInspectRequest
import io.seqera.wave.api.SubmitContainerTokenRequest
import io.seqera.wave.util.RegHelper
import io.seqera.wave.util.StringUtils

/**
Expand Down Expand Up @@ -80,4 +81,14 @@ class PlatformId {
", workflowId=" + workflowId +
')';
}

/**
 * @return
 *      A hash computed via {@code RegHelper.sipHash} over the user id, user email,
 *      workspace id, access token, tower endpoint and workflow id
 */
String stableHash() {
    return RegHelper.sipHash(getUserId(), getUserEmail(), workspaceId, accessToken, towerEndpoint, workflowId)
}
}
Loading

0 comments on commit 00b6add

Please sign in to comment.