diff --git a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs index b89ce9da79d14..10660f40b4c3e 100644 --- a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs +++ b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs @@ -98,11 +98,39 @@ public AsyncUnaryCall GetInfo(FlightDescriptor flightDescriptor, Met flightInfoResult.Dispose); } + /// + /// Start a Flight Put request. + /// + /// Descriptor for the data to be put + /// gRPC headers to send with the request + /// A object used to write data batches and receive responses public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers = null) { return StartPut(flightDescriptor, headers, null, CancellationToken.None); } + /// + /// Start a Flight Put request. + /// + /// Descriptor for the data to be put + /// The schema of the data + /// gRPC headers to send with the request + /// A object used to write data batches and receive responses + /// Using this method rather than a StartPut overload that doesn't accept a schema + /// means that the schema is sent even if no data batches are sent + public Task StartPut(FlightDescriptor flightDescriptor, Schema schema, Metadata headers = null) + { + return StartPut(flightDescriptor, schema, headers, null, CancellationToken.None); + } + + /// + /// Start a Flight Put request. + /// + /// Descriptor for the data to be put + /// gRPC headers to send with the request + /// Optional deadline. The request will be cancelled if this deadline is reached. + /// Optional token for cancelling the request + /// A object used to write data batches and receive responses public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default) { var channels = _client.DoPut(headers, deadline, cancellationToken); @@ -117,6 +145,33 @@ public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDesc channels.Dispose); } + /// + /// Start a Flight Put request. + /// + /// Descriptor for the data to be put + /// The schema of the data + /// gRPC headers to send with the request + /// Optional deadline. The request will be cancelled if this deadline is reached. + /// Optional token for cancelling the request + /// A object used to write data batches and receive responses + /// Using this method rather than a StartPut overload that doesn't accept a schema + /// means that the schema is sent even if no data batches are sent + public async Task StartPut(FlightDescriptor flightDescriptor, Schema schema, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default) + { + var channels = _client.DoPut(headers, deadline, cancellationToken); + var requestStream = new FlightClientRecordBatchStreamWriter(channels.RequestStream, flightDescriptor); + var readStream = new StreamReader(channels.ResponseStream, putResult => new FlightPutResult(putResult)); + var streamingCall = new FlightRecordBatchDuplexStreamingCall( + requestStream, + readStream, + channels.ResponseHeadersAsync, + channels.GetStatus, + channels.GetTrailers, + channels.Dispose); + await streamingCall.RequestStream.SetupStream(schema).ConfigureAwait(false); + return streamingCall; + } + public AsyncDuplexStreamingCall Handshake(Metadata headers = null) { return Handshake(headers, null, CancellationToken.None); diff --git a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs index 7a8a6fd677c68..314d46da00830 100644 --- a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs +++ b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs @@ -38,9 +38,22 @@ private protected FlightRecordBatchStreamWriter(IAsyncStreamWriter + /// Configure the data stream to write to. + /// + /// + /// The stream will be set up automatically when writing a RecordBatch if required, + /// but calling this method before writing any data allows handling empty streams. + /// + /// The schema of data to be written to this stream + public async Task SetupStream(Schema schema) { + if (_flightDataStream != null) + { + throw new InvalidOperationException("Flight data stream is already set"); + } _flightDataStream = new FlightDataStream(_clientStreamWriter, _flightDescriptor, schema); + await _flightDataStream.SendSchema().ConfigureAwait(false); } public WriteOptions WriteOptions { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } @@ -50,14 +63,14 @@ public Task WriteAsync(RecordBatch message) return WriteAsync(message, default); } - public Task WriteAsync(RecordBatch message, ByteString applicationMetadata) + public async Task WriteAsync(RecordBatch message, ByteString applicationMetadata) { if (_flightDataStream == null) { - SetupStream(message.Schema); + await SetupStream(message.Schema).ConfigureAwait(false); } - return _flightDataStream.Write(message, applicationMetadata); + await _flightDataStream.Write(message, applicationMetadata); } protected virtual void Dispose(bool disposing) diff --git a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs index 72c1551be2917..7cbbe66f40a94 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs @@ -44,7 +44,7 @@ public FlightDataStream(IAsyncStreamWriter clientStreamWrit _flightDescriptor = flightDescriptor; } - private async Task SendSchema() + public async Task SendSchema() { _currentFlightData = new Protocol.FlightData(); diff --git a/csharp/test/Apache.Arrow.Flight.IntegrationTest/Scenarios/JsonTestScenario.cs b/csharp/test/Apache.Arrow.Flight.IntegrationTest/Scenarios/JsonTestScenario.cs index 4f7fed74352fc..784751044065a 100644 --- a/csharp/test/Apache.Arrow.Flight.IntegrationTest/Scenarios/JsonTestScenario.cs +++ b/csharp/test/Apache.Arrow.Flight.IntegrationTest/Scenarios/JsonTestScenario.cs @@ -76,7 +76,7 @@ public async Task RunClient(int serverPort) var batches = jsonFile.Batches.Select(batch => batch.ToArrow(schema, dictionaries)).ToArray(); // 1. Put the data to the server. - await UploadBatches(client, descriptor, batches).ConfigureAwait(false); + await UploadBatches(client, descriptor, schema, batches).ConfigureAwait(false); // 2. Get the ticket for the data. var info = await client.GetInfo(descriptor).ConfigureAwait(false); @@ -112,9 +112,10 @@ public async Task RunClient(int serverPort) } } - private static async Task UploadBatches(FlightClient client, FlightDescriptor descriptor, RecordBatch[] batches) + private static async Task UploadBatches( + FlightClient client, FlightDescriptor descriptor, Schema schema, RecordBatch[] batches) { - using var putCall = client.StartPut(descriptor); + using var putCall = await client.StartPut(descriptor, schema); using var writer = putCall.RequestStream; try diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs index 46c5460912d8c..5689b45bfdec8 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs @@ -51,9 +51,10 @@ public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStr if(_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) { + await responseStream.SetupStream(flightHolder.GetFlightInfo().Schema); + var batches = flightHolder.GetRecordBatches(); - foreach(var batch in batches) { await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata); diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs index 350762c992769..241b3c006a003 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs @@ -57,6 +57,13 @@ private RecordBatch CreateTestBatch(int startValue, int length) return batchBuilder.Build(); } + private Schema GetStoreSchema(FlightDescriptor flightDescriptor) + { + Assert.Contains(flightDescriptor, (IReadOnlyDictionary)_flightStore.Flights); + + var flightHolder = _flightStore.Flights[flightDescriptor]; + return flightHolder.GetFlightInfo().Schema; + } private IEnumerable GetStoreBatch(FlightDescriptor flightDescriptor) { @@ -88,7 +95,7 @@ public async Task TestPutSingleRecordBatch() var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); var expectedBatch = CreateTestBatch(0, 100); - var putStream = _flightClient.StartPut(flightDescriptor); + var putStream = await _flightClient.StartPut(flightDescriptor, expectedBatch.Schema); await putStream.RequestStream.WriteAsync(expectedBatch); await putStream.RequestStream.CompleteAsync(); var putResults = await putStream.ResponseStream.ToListAsync(); @@ -108,7 +115,7 @@ public async Task TestPutTwoRecordBatches() var expectedBatch1 = CreateTestBatch(0, 100); var expectedBatch2 = CreateTestBatch(0, 100); - var putStream = _flightClient.StartPut(flightDescriptor); + var putStream = await _flightClient.StartPut(flightDescriptor, expectedBatch1.Schema); await putStream.RequestStream.WriteAsync(expectedBatch1); await putStream.RequestStream.WriteAsync(expectedBatch2); await putStream.RequestStream.CompleteAsync(); @@ -123,6 +130,23 @@ public async Task TestPutTwoRecordBatches() ArrowReaderVerifier.CompareBatches(expectedBatch2, actualBatches[1].RecordBatch); } + [Fact] + public async Task TestPutZeroRecordBatches() + { + var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); + var schema = CreateTestBatch(0, 1).Schema; + + var putStream = await _flightClient.StartPut(flightDescriptor, schema); + await putStream.RequestStream.CompleteAsync(); + var putResults = await putStream.ResponseStream.ToListAsync(); + + Assert.Empty(putResults); + + var actualSchema = GetStoreSchema(flightDescriptor); + + SchemaComparer.Compare(schema, actualSchema); + } + [Fact] public async Task TestGetRecordBatchWithDelayedSchema() { @@ -230,7 +254,7 @@ public async Task TestPutWithMetadata() var expectedBatch = CreateTestBatch(0, 100); var expectedMetadata = ByteString.CopyFromUtf8("test metadata"); - var putStream = _flightClient.StartPut(flightDescriptor); + var putStream = await _flightClient.StartPut(flightDescriptor, expectedBatch.Schema); await putStream.RequestStream.WriteAsync(expectedBatch, expectedMetadata); await putStream.RequestStream.CompleteAsync(); var putResults = await putStream.ResponseStream.ToListAsync(); @@ -471,8 +495,7 @@ public async Task EnsureCallRaisesDeadlineExceeded() exception = await Assert.ThrowsAsync(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch)); Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode); - var putStream = _flightClient.StartPut(flightDescriptor, null, deadline); - exception = await Assert.ThrowsAsync(async () => await putStream.RequestStream.WriteAsync(batch)); + exception = await Assert.ThrowsAsync(async () => await _flightClient.StartPut(flightDescriptor, batch.Schema, null, deadline)); Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode); exception = await Assert.ThrowsAsync(async () => await _flightClient.GetSchema(flightDescriptor, null, deadline)); @@ -514,8 +537,7 @@ public async Task EnsureCallRaisesRequestCancelled() exception = await Assert.ThrowsAsync(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch)); Assert.Equal(StatusCode.Cancelled, exception.StatusCode); - var putStream = _flightClient.StartPut(flightDescriptor, null, null, cts.Token); - exception = await Assert.ThrowsAsync(async () => await putStream.RequestStream.WriteAsync(batch)); + exception = await Assert.ThrowsAsync(async () => await _flightClient.StartPut(flightDescriptor, batch.Schema, null, null, cts.Token)); Assert.Equal(StatusCode.Cancelled, exception.StatusCode); exception = await Assert.ThrowsAsync(async () => await _flightClient.GetSchema(flightDescriptor, null, null, cts.Token)); diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py index b4fbbb2d41498..027e675792dbe 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py @@ -1890,10 +1890,7 @@ def _temp_path(): return file_objs = [ - generate_primitive_case([], name='primitive_no_batches') - # TODO(https://github.com/apache/arrow/issues/44363) - .skip_format(SKIP_FLIGHT, 'C#'), - + generate_primitive_case([], name='primitive_no_batches'), generate_primitive_case([17, 20], name='primitive'), generate_primitive_case([0, 0, 0], name='primitive_zerolength'),