From 06c6f3f35c6b819a45378fd9a07778391e1a64a2 Mon Sep 17 00:00:00 2001 From: Igor Berntein Date: Tue, 27 Aug 2024 15:36:08 -0400 Subject: [PATCH] Fix BigtableIO.write() client sharing This PR extends the refcount lease on the underlying Bigtable client: instead of being held only from StartBundle to FinishBundle, it is now held from the first StartBundle until Teardown. Previously, the client and its connections churned when all worker threads had similar load. --- .../beam/sdk/io/gcp/bigtable/BigtableIO.java | 103 ++++--- .../bigtable/BigtableSharedClientTest.java | 260 ++++++++++++++++++ 2 files changed, 310 insertions(+), 53 deletions(-) create mode 100644 sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableSharedClientTest.java diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java index 6d20109e947b..4ffc98d99cda 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java @@ -1359,15 +1359,19 @@ private static class BigtableWriterFn throttleReportThresMsecs = firstNonNull(writeOptions.getThrottlingReportTargetMs(), 180_000); LOG.debug("Created Bigtable Write Fn with writeOptions {} ", writeOptions); } - @StartBundle public void startBundle(StartBundleContext c) throws IOException { recordsWritten = 0; this.seenWindows = Maps.newHashMapWithExpectedSize(1); - if (bigtableWriter == null) { + // Ideally this would be in @Setup, but we need access to PipelineOptions and there is no easy + // way to plumb it to @Setup. 
+ if (serviceEntry == null) { serviceEntry = factory.getServiceForWriting(id, config, writeOptions, c.getPipelineOptions()); + } + + if (bigtableWriter == null) { bigtableWriter = serviceEntry.getService().openForWriting(writeOptions); } @@ -1458,65 +1462,58 @@ private static boolean isDataException(Throwable e) { @FinishBundle public void finishBundle(FinishBundleContext c) throws Exception { - try { - if (bigtableWriter != null) { - Instant closeStart = Instant.now(); - try { - bigtableWriter.close(); - } catch (IOException e) { - // If the writer fails due to a batching exception, but no failures were detected - // it means that error handling was enabled, and that errors were detected and routed - // to the error queue. Bigtable will successfully write other failures in the batch, - // so this exception should be ignored - if (!(e.getCause() instanceof BatchingException)) { - throttlingMsecs.inc(new Duration(closeStart, Instant.now()).getMillis()); - throw e; - } - } - // add the excessive amount to throttling metrics if elapsed time > target latency - if (throttleReportThresMsecs > 0) { - long excessTime = - new Duration(closeStart, Instant.now()).getMillis() - throttleReportThresMsecs; - if (excessTime > 0) { - throttlingMsecs.inc(excessTime); - } + if (bigtableWriter != null) { + Instant closeStart = Instant.now(); + try { + bigtableWriter.close(); + } catch (IOException e) { + // If the writer fails due to a batching exception, but no failures were detected + // it means that error handling was enabled, and that errors were detected and routed + // to the error queue. 
Bigtable will successfully write other failures in the batch, + // so this exception should be ignored + if (!(e.getCause() instanceof BatchingException)) { + throttlingMsecs.inc(new Duration(closeStart, Instant.now()).getMillis()); + throw e; } - if (!reportedLineage) { - bigtableWriter.reportLineage(); - reportedLineage = true; + } + // add the excessive amount to throttling metrics if elapsed time > target latency + if (throttleReportThresMsecs > 0) { + long excessTime = + new Duration(closeStart, Instant.now()).getMillis() - throttleReportThresMsecs; + if (excessTime > 0) { + throttlingMsecs.inc(excessTime); } - bigtableWriter = null; } + if (!reportedLineage) { + bigtableWriter.reportLineage(); + reportedLineage = true; + } + bigtableWriter = null; + } - for (KV badRecord : badRecords) { - try { - badRecordRouter.route( - c, - badRecord.getKey().getRecord(), - inputCoder, - (Exception) badRecord.getKey().getCause(), - "Failed to write malformed mutation to Bigtable", - badRecord.getValue()); - } catch (Exception e) { - failures.add(badRecord.getKey()); - } + for (KV badRecord : badRecords) { + try { + badRecordRouter.route( + c, + badRecord.getKey().getRecord(), + inputCoder, + (Exception) badRecord.getKey().getCause(), + "Failed to write malformed mutation to Bigtable", + badRecord.getValue()); + } catch (Exception e) { + failures.add(badRecord.getKey()); } + } - checkForFailures(); + checkForFailures(); - LOG.debug("Wrote {} records", recordsWritten); + LOG.debug("Wrote {} records", recordsWritten); - for (Map.Entry entry : seenWindows.entrySet()) { - c.output( - BigtableWriteResult.create(entry.getValue()), - entry.getKey().maxTimestamp(), - entry.getKey()); - } - } finally { - if (serviceEntry != null) { - serviceEntry.close(); - serviceEntry = null; - } + for (Map.Entry entry : seenWindows.entrySet()) { + c.output( + BigtableWriteResult.create(entry.getValue()), + entry.getKey().maxTimestamp(), + entry.getKey()); } } diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableSharedClientTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableSharedClientTest.java
new file mode 100644
index 000000000000..701d60368ed7
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableSharedClientTest.java
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigtable;
+
+import com.google.api.gax.grpc.ChannelPoolSettings;
+import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider;
+import com.google.bigtable.v2.BigtableGrpc;
+import com.google.bigtable.v2.MutateRowsRequest;
+import com.google.bigtable.v2.MutateRowsResponse;
+import com.google.bigtable.v2.MutateRowsResponse.Entry;
+import com.google.bigtable.v2.Mutation;
+import com.google.bigtable.v2.Mutation.SetCell;
+import com.google.bigtable.v2.PingAndWarmRequest;
+import com.google.bigtable.v2.PingAndWarmResponse;
+import com.google.cloud.bigtable.data.v2.BigtableDataSettings.Builder;
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.ByteString;
+import com.google.rpc.Code;
+import com.google.rpc.Status;
+import io.grpc.BindableService;
+import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener;
+import io.grpc.Grpc;
+import io.grpc.Metadata;
+import io.grpc.Server;
+import io.grpc.ServerBuilder;
+import io.grpc.ServerCall;
+import io.grpc.ServerCall.Listener;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.grpc.stub.StreamObserver;
+import java.io.IOException;
+import java.net.ServerSocket;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiFunction;
+import org.apache.beam.runners.direct.DirectRunner;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineResult.State;
+import org.apache.beam.sdk.io.GenerateSequence;
+import org.apache.beam.sdk.options.ExperimentalOptions;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.KV;
+import org.hamcrest.MatcherAssert;
+import org.hamcrest.Matchers;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Ensure that BigtableIO.write() reuses the same instance of the underlying bigtable client.
+ * This test will create a toy pipeline using DirectRunner and have it write to a local emulator.
+ * The emulator will record all of the client connections. Then the test will check that only a
+ * single connection was used.
+ */
+@RunWith(JUnit4.class)
+public class BigtableSharedClientTest {
+
+  private FakeBigtable fakeService;
+  private ServerClientConnectionCounterInterceptor clientConnectionInterceptor;
+  private Server fakeServer;
+
+  @Before
+  public void setUp() throws Exception {
+    clientConnectionInterceptor = new ServerClientConnectionCounterInterceptor();
+    this.fakeService = new FakeBigtable();
+
+    IOException lastError = null;
+
+    for (int i = 0; i < 10; i++) {
+      try {
+        this.fakeServer = createServer(fakeService, clientConnectionInterceptor);
+        lastError = null;
+        break;
+      } catch (IOException e) {
+        lastError = e;
+      }
+    }
+    if (lastError != null) {
+      throw lastError;
+    }
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    if (fakeServer != null) {
+      fakeServer.shutdownNow();
+    }
+  }
+
+  private static Server createServer(BindableService service, ServerInterceptor... interceptors) throws IOException {
+    int port;
+    try (ServerSocket ss = new ServerSocket(0)) {
+      port = ss.getLocalPort();
+    }
+
+    ServerBuilder<?> serverBuilder = ServerBuilder.forPort(port)
+        .addService(service);
+
+    for (ServerInterceptor interceptor : interceptors) {
+      serverBuilder.intercept(interceptor);
+    }
+    return serverBuilder.build().start();
+  }
+
+  @Test
+  public void testClientReusedAcrossBundles() {
+    PipelineOptions opts = PipelineOptionsFactory.create();
+    opts.setRunner(DirectRunner.class);
+    ExperimentalOptions.addExperiment(
+        opts.as(ExperimentalOptions.class),
+        String.format(
+            "%s=%s",
+            BigtableConfigTranslator.BIGTABLE_SETTINGS_OVERRIDE,
+            // Class#toString() would prepend "class "; pass the plain binary name instead.
+            ClientSettingsOverride.class.getName())
+    );
+
+    Pipeline pipeline = Pipeline.create(opts);
+
+    AtomicInteger bundleCount = new AtomicInteger();
+    MutationsDoFn dofn = new MutationsDoFn(bundleCount);
+
+    pipeline
+        .apply(GenerateSequence.from(0).to(10_000))
+        .apply(ParDo.of(dofn)) // create Mutations & count bundles
+        .apply(
+            BigtableIO.write()
+                .withProjectId("fake-project")
+                .withInstanceId("fake-instance")
+                .withTableId("fake-table")
+                .withEmulator("localhost:" + fakeServer.getPort())
+        );
+
+    Assert.assertEquals(State.DONE, pipeline.run().waitUntilFinish());
+    // Make sure that the test is valid by making sure that multiple bundles were processed
+    MatcherAssert.assertThat(dofn.bundleCount.get(), Matchers.greaterThan(1));
+    // Make sure that a single client was shared across all the bundles
+    MatcherAssert.assertThat(clientConnectionInterceptor.getClientConnections(), Matchers.hasSize(1));
+  }
+
+  /** Minimal implementation of a Bigtable emulator for BigtableIO.write() */
+  static class FakeBigtable extends BigtableGrpc.BigtableImplBase {
+    @Override
+    public void mutateRows(MutateRowsRequest request,
+        StreamObserver<MutateRowsResponse> responseObserver) {
+      MutateRowsResponse.Builder builder = MutateRowsResponse.newBuilder();
+
+      for (int i = 0; i < request.getEntriesCount(); i++) {
+        builder.addEntries(
+            Entry.newBuilder()
+                .setIndex(i)
+                .setStatus(Status.newBuilder().setCode(Code.OK_VALUE))
+                .build()
+        );
+      }
+      responseObserver.onNext(builder.build());
+      responseObserver.onCompleted();
+    }
+
+    @Override
+    public void pingAndWarm(PingAndWarmRequest request,
+        StreamObserver<PingAndWarmResponse> responseObserver) {
+      // A unary RPC must emit exactly one response before it is completed.
+      responseObserver.onNext(PingAndWarmResponse.getDefaultInstance());
+      responseObserver.onCompleted();
+    }
+  }
+
+  static class MutationsDoFn extends DoFn<Long, KV<ByteString, Iterable<Mutation>>> {
+    private final AtomicInteger bundleCount;
+
+    public MutationsDoFn(AtomicInteger bundleCount) {
+      this.bundleCount = bundleCount;
+    }
+
+    @StartBundle
+    public void startBundle(StartBundleContext ctx) {
+      bundleCount.incrementAndGet();
+    }
+
+    @ProcessElement
+    public void processElement(@Element Long input, OutputReceiver<KV<ByteString, Iterable<Mutation>>> output) {
+      output.output(
+          KV.of(
+              ByteString.copyFromUtf8(input.toString()),
+              ImmutableList.of(
+                  Mutation.newBuilder().setSetCell(
+                      SetCell.newBuilder()
+                          .setFamilyName("fake-family")
+                          .setColumnQualifier(ByteString.copyFromUtf8("fake-qualifier"))
+                          .setTimestampMicros(System.currentTimeMillis() * 1000)
+                          .setValue(ByteString.copyFromUtf8("fake-value"))
+                  )
+                  .build()
+              )
+          )
+      );
+    }
+  }
+
+  /** Overrides the default settings to ensure 1 channel per client */
+  public static class ClientSettingsOverride implements BiFunction<Builder, PipelineOptions, Builder> {
+    @Override
+    public Builder apply(Builder builder, PipelineOptions pipelineOptions) {
+      InstantiatingGrpcChannelProvider channelProvider =
+          ((InstantiatingGrpcChannelProvider) builder.stubSettings().getTransportChannelProvider())
+              .toBuilder()
+              .setChannelPoolSettings(ChannelPoolSettings.staticallySized(1))
+              .build();
+      builder.stubSettings().setTransportChannelProvider(channelProvider);
+
+      return builder;
+    }
+  }
+
+  /** Collects the distinct remote addresses of completed calls, as seen by the fake server. */
+  static class ServerClientConnectionCounterInterceptor implements ServerInterceptor {
+    private final Set<String> clientConnections = Collections.synchronizedSet(new HashSet<>());
+
+    @Override
+    public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
+        Metadata headers, ServerCallHandler<ReqT, RespT> next) {
+
+      return new SimpleForwardingServerCallListener<ReqT>(next.startCall(call, headers)) {
+        @Override
+        public void onComplete() {
+          clientConnections.add(call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString());
+          super.onComplete();
+        }
+      };
+    }
+
+    public Set<String> getClientConnections() {
+      return clientConnections;
+    }
+  }
+}