Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-44363: [C#] Handle Flight data with zero batches #45315

Merged
merged 4 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,39 @@ public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor flightDescriptor, Met
flightInfoResult.Dispose);
}

/// <summary>
/// Begin a Flight Put request with no deadline and no cancellation token.
/// </summary>
/// <param name="flightDescriptor">Descriptor identifying the data that will be uploaded</param>
/// <param name="headers">gRPC headers to attach to the request (defaults to null)</param>
/// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> used to write record batches and read the server's responses</returns>
public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers = null)
    // Forwards to the overload taking a deadline and token, passing "no deadline" and a token that never cancels.
    => StartPut(flightDescriptor, headers, null, CancellationToken.None);

/// <summary>
/// Begin a Flight Put request for data with a known schema, with no deadline and no cancellation token.
/// </summary>
/// <param name="flightDescriptor">Descriptor identifying the data that will be uploaded</param>
/// <param name="schema">Schema of the data that will be written</param>
/// <param name="headers">gRPC headers to attach to the request (defaults to null)</param>
/// <returns>A task producing the <see cref="FlightRecordBatchDuplexStreamingCall" /> used to write record batches and read the server's responses</returns>
/// <remarks>Unlike the StartPut overloads that take no schema, this overload sends the schema
/// even when no record batch is ever written.</remarks>
public Task<FlightRecordBatchDuplexStreamingCall> StartPut(FlightDescriptor flightDescriptor, Schema schema, Metadata headers = null)
    // Forwards to the schema-aware overload taking a deadline and token, passing "no deadline" and a token that never cancels.
    => StartPut(flightDescriptor, schema, headers, null, CancellationToken.None);

/// <summary>
/// Start a Flight Put request.
/// </summary>
/// <param name="flightDescriptor">Descriptor for the data to be put</param>
/// <param name="headers">gRPC headers to send with the request</param>
/// <param name="deadline">Optional deadline. The request will be cancelled if this deadline is reached.</param>
/// <param name="cancellationToken">Optional token for cancelling the request</param>
/// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> object used to write data batches and receive responses</returns>
public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var channels = _client.DoPut(headers, deadline, cancellationToken);
Expand All @@ -117,6 +145,33 @@ public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDesc
channels.Dispose);
}

/// <summary>
/// Start a Flight Put request and immediately send the schema of the data.
/// </summary>
/// <param name="flightDescriptor">Descriptor for the data to be put</param>
/// <param name="schema">The schema of the data</param>
/// <param name="headers">gRPC headers to send with the request</param>
/// <param name="deadline">Optional deadline. The request will be cancelled if this deadline is reached.</param>
/// <param name="cancellationToken">Optional token for cancelling the request</param>
/// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> object used to write data batches and receive responses</returns>
/// <remarks>Using this method rather than a StartPut overload that doesn't accept a schema
/// means that the schema is sent even if no data batches are sent.
/// If sending the schema fails, the underlying call is disposed before the exception is rethrown,
/// because the caller never receives the call object and so could not dispose it.</remarks>
public async Task<FlightRecordBatchDuplexStreamingCall> StartPut(FlightDescriptor flightDescriptor, Schema schema, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
    var channels = _client.DoPut(headers, deadline, cancellationToken);
    var requestStream = new FlightClientRecordBatchStreamWriter(channels.RequestStream, flightDescriptor);
    var readStream = new StreamReader<Protocol.PutResult, FlightPutResult>(channels.ResponseStream, putResult => new FlightPutResult(putResult));
    var streamingCall = new FlightRecordBatchDuplexStreamingCall(
        requestStream,
        readStream,
        channels.ResponseHeadersAsync,
        channels.GetStatus,
        channels.GetTrailers,
        channels.Dispose);

    try
    {
        // Send the schema up front so that a stream with zero record batches still carries it.
        await streamingCall.RequestStream.SetupStream(schema).ConfigureAwait(false);
    }
    catch
    {
        // The call has been opened but will never be handed to the caller; release it here
        // and rethrow the original exception with its stack trace intact.
        streamingCall.Dispose();
        throw;
    }

    return streamingCall;
}

public AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse> Handshake(Metadata headers = null)
{
return Handshake(headers, null, CancellationToken.None);
Expand Down
21 changes: 17 additions & 4 deletions csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,22 @@ private protected FlightRecordBatchStreamWriter(IAsyncStreamWriter<Protocol.Flig
_flightDescriptor = flightDescriptor;
}

private void SetupStream(Schema schema)
/// <summary>
/// Create the underlying Flight data stream for the given schema and send that schema.
/// </summary>
/// <remarks>
/// Writing a RecordBatch sets the stream up on demand when this has not been called,
/// but calling it explicitly before any data is written makes it possible to
/// handle streams that contain no batches.
/// </remarks>
/// <param name="schema">The schema of data to be written to this stream</param>
/// <exception cref="InvalidOperationException">The data stream has already been set up</exception>
public async Task SetupStream(Schema schema)
{
    if (_flightDataStream != null)
    {
        throw new InvalidOperationException("Flight data stream is already set");
    }

    // The field is assigned before the schema is sent, so the stream counts as
    // "set" from this point on even while the send is still in flight.
    var dataStream = new FlightDataStream(_clientStreamWriter, _flightDescriptor, schema);
    _flightDataStream = dataStream;

    await dataStream.SendSchema().ConfigureAwait(false);
}

/// <summary>
/// Write options are not implemented by this writer: both reading and assigning
/// this property throw <see cref="NotImplementedException"/>.
/// </summary>
public WriteOptions WriteOptions
{
    get { throw new NotImplementedException(); }
    set { throw new NotImplementedException(); }
}
Expand All @@ -50,14 +63,14 @@ public Task WriteAsync(RecordBatch message)
return WriteAsync(message, default);
}

public Task WriteAsync(RecordBatch message, ByteString applicationMetadata)
/// <summary>
/// Write a record batch, together with application metadata, to the Flight data stream.
/// </summary>
/// <remarks>
/// If the data stream has not been set up yet, it is set up from the schema of
/// <paramref name="message"/> before the batch is written.
/// </remarks>
/// <param name="message">The record batch to write</param>
/// <param name="applicationMetadata">Application-defined metadata passed along with the batch</param>
public async Task WriteAsync(RecordBatch message, ByteString applicationMetadata)
{
    if (_flightDataStream == null)
    {
        await SetupStream(message.Schema).ConfigureAwait(false);
    }

    // Library code: do not capture the caller's synchronization context, matching the awaits above.
    await _flightDataStream.Write(message, applicationMetadata).ConfigureAwait(false);
}

protected virtual void Dispose(bool disposing)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public FlightDataStream(IAsyncStreamWriter<Protocol.FlightData> clientStreamWrit
_flightDescriptor = flightDescriptor;
}

private async Task SendSchema()
public async Task SendSchema()
{
_currentFlightData = new Protocol.FlightData();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public async Task RunClient(int serverPort)
var batches = jsonFile.Batches.Select(batch => batch.ToArrow(schema, dictionaries)).ToArray();

// 1. Put the data to the server.
await UploadBatches(client, descriptor, batches).ConfigureAwait(false);
await UploadBatches(client, descriptor, schema, batches).ConfigureAwait(false);

// 2. Get the ticket for the data.
var info = await client.GetInfo(descriptor).ConfigureAwait(false);
Expand Down Expand Up @@ -112,9 +112,10 @@ public async Task RunClient(int serverPort)
}
}

private static async Task UploadBatches(FlightClient client, FlightDescriptor descriptor, RecordBatch[] batches)
private static async Task UploadBatches(
FlightClient client, FlightDescriptor descriptor, Schema schema, RecordBatch[] batches)
{
using var putCall = client.StartPut(descriptor);
using var putCall = await client.StartPut(descriptor, schema);
using var writer = putCall.RequestStream;

try
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStr

if(_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder))
{
await responseStream.SetupStream(flightHolder.GetFlightInfo().Schema);

var batches = flightHolder.GetRecordBatches();


foreach(var batch in batches)
{
await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata);
Expand Down
36 changes: 29 additions & 7 deletions csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ private RecordBatch CreateTestBatch(int startValue, int length)
return batchBuilder.Build();
}

/// <summary>
/// Look up the schema recorded in the test flight store for the given descriptor,
/// failing the test if the store holds no flight for it.
/// </summary>
private Schema GetStoreSchema(FlightDescriptor flightDescriptor)
{
    // Assert presence first so a missing flight is reported as an assertion failure
    // rather than as a lookup exception from the indexer below.
    Assert.Contains(flightDescriptor, (IReadOnlyDictionary<FlightDescriptor, FlightHolder>)_flightStore.Flights);

    return _flightStore.Flights[flightDescriptor].GetFlightInfo().Schema;
}

private IEnumerable<RecordBatchWithMetadata> GetStoreBatch(FlightDescriptor flightDescriptor)
{
Expand Down Expand Up @@ -88,7 +95,7 @@ public async Task TestPutSingleRecordBatch()
var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test");
var expectedBatch = CreateTestBatch(0, 100);

var putStream = _flightClient.StartPut(flightDescriptor);
var putStream = await _flightClient.StartPut(flightDescriptor, expectedBatch.Schema);
await putStream.RequestStream.WriteAsync(expectedBatch);
await putStream.RequestStream.CompleteAsync();
var putResults = await putStream.ResponseStream.ToListAsync();
Expand All @@ -108,7 +115,7 @@ public async Task TestPutTwoRecordBatches()
var expectedBatch1 = CreateTestBatch(0, 100);
var expectedBatch2 = CreateTestBatch(0, 100);

var putStream = _flightClient.StartPut(flightDescriptor);
var putStream = await _flightClient.StartPut(flightDescriptor, expectedBatch1.Schema);
await putStream.RequestStream.WriteAsync(expectedBatch1);
await putStream.RequestStream.WriteAsync(expectedBatch2);
await putStream.RequestStream.CompleteAsync();
Expand All @@ -123,6 +130,23 @@ public async Task TestPutTwoRecordBatches()
ArrowReaderVerifier.CompareBatches(expectedBatch2, actualBatches[1].RecordBatch);
}

/// <summary>
/// Putting a stream that carries a schema but no record batches must still
/// register the flight on the server with that schema.
/// </summary>
[Fact]
public async Task TestPutZeroRecordBatches()
{
    var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test");
    // A one-row batch is built only to obtain a schema; the batch itself is never written.
    var schema = CreateTestBatch(0, 1).Schema;

    // Dispose the duplex call once the exchange is finished so the gRPC call is always released,
    // even if an assertion inside the block fails.
    using (var putStream = await _flightClient.StartPut(flightDescriptor, schema))
    {
        await putStream.RequestStream.CompleteAsync();
        var putResults = await putStream.ResponseStream.ToListAsync();

        Assert.Empty(putResults);
    }

    var actualSchema = GetStoreSchema(flightDescriptor);

    SchemaComparer.Compare(schema, actualSchema);
}

[Fact]
public async Task TestGetRecordBatchWithDelayedSchema()
{
Expand Down Expand Up @@ -230,7 +254,7 @@ public async Task TestPutWithMetadata()
var expectedBatch = CreateTestBatch(0, 100);
var expectedMetadata = ByteString.CopyFromUtf8("test metadata");

var putStream = _flightClient.StartPut(flightDescriptor);
var putStream = await _flightClient.StartPut(flightDescriptor, expectedBatch.Schema);
await putStream.RequestStream.WriteAsync(expectedBatch, expectedMetadata);
await putStream.RequestStream.CompleteAsync();
var putResults = await putStream.ResponseStream.ToListAsync();
Expand Down Expand Up @@ -471,8 +495,7 @@ public async Task EnsureCallRaisesDeadlineExceeded()
exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

var putStream = _flightClient.StartPut(flightDescriptor, null, deadline);
exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.StartPut(flightDescriptor, batch.Schema, null, deadline));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, deadline));
Expand Down Expand Up @@ -514,8 +537,7 @@ public async Task EnsureCallRaisesRequestCancelled()
exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

var putStream = _flightClient.StartPut(flightDescriptor, null, null, cts.Token);
exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.StartPut(flightDescriptor, batch.Schema, null, null, cts.Token));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, null, cts.Token));
Expand Down
5 changes: 1 addition & 4 deletions dev/archery/archery/integration/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,10 +1890,7 @@ def _temp_path():
return

file_objs = [
generate_primitive_case([], name='primitive_no_batches')
# TODO(https://github.com/apache/arrow/issues/44363)
.skip_format(SKIP_FLIGHT, 'C#'),

generate_primitive_case([], name='primitive_no_batches'),
generate_primitive_case([17, 20], name='primitive'),
generate_primitive_case([0, 0, 0], name='primitive_zerolength'),

Expand Down
Loading