diff --git a/Microsoft.Azure.Cosmos/src/Handler/HandlerConstants.cs b/Microsoft.Azure.Cosmos/src/Handler/HandlerConstants.cs index 8003ab7ad9..70c083c4a1 100644 --- a/Microsoft.Azure.Cosmos/src/Handler/HandlerConstants.cs +++ b/Microsoft.Azure.Cosmos/src/Handler/HandlerConstants.cs @@ -7,6 +7,13 @@ internal static class HandlerConstants { public const string StartEpkString = "x-ms-start-epk-string"; public const string EndEpkString = "x-ms-end-epk-string"; - public const string ResourceUri = "x-ms-resource-uri"; + public const string ResourceUri = "x-ms-resource-uri"; + + public const string RoutedViaProxy = "x-ms-thinclient-route-via-proxy"; + public const string ProxyStartEpk = "x-ms-thinclient-range-min"; + public const string ProxyEndEpk = "x-ms-thinclient-range-max"; + + public const string ProxyOperationType = "x-ms-thinclient-proxy-operation-type"; + public const string ProxyResourceType = "x-ms-thinclient-proxy-resource-type"; } } \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/src/ThinClientTransportSerializer.cs b/Microsoft.Azure.Cosmos/src/ThinClientTransportSerializer.cs new file mode 100644 index 0000000000..bff6c91e9b --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/ThinClientTransportSerializer.cs @@ -0,0 +1,279 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos +{ + using System; + using System.Buffers; + using System.Collections.Generic; + using System.IO; + using System.Linq; + using System.Net; + using System.Net.Http; + using System.Threading.Tasks; + using Microsoft.Azure.Documents; + using Microsoft.Azure.Documents.Collections; + + /// + /// The ThinClientTransportSerializer class provides methods for serializing and deserializing proxy requests and responses + /// to and from the RNTBD (Remote Network Transport Binary Data) protocol format. This class is used internally within the + /// Azure Cosmos DB SDK to handle communication with the backend services. + /// + internal static class ThinClientTransportSerializer + { + private static readonly PartitionKeyDefinition HashV2SinglePath; + + static ThinClientTransportSerializer() + { + HashV2SinglePath = new PartitionKeyDefinition + { + Kind = PartitionKind.Hash, + Version = Documents.PartitionKeyDefinitionVersion.V2, + }; + HashV2SinglePath.Paths.Add("/id"); + } + + /// + /// Wrapper to expose a public bufferprovider for the RNTBD stack. + /// + public class BufferProviderWrapper + { + internal BufferProvider Provider { get; set; } = new (); + } + + /// + /// Serialize the Proxy request to the RNTBD protocol format. + /// Today this takes the HttprequestMessage and reconstructs the DSR. + /// If the SDK can push properties to the HttpRequestMessage then the handler above having + /// the DSR can allow us to push that directly to the serialization. + /// + public static async Task SerializeProxyRequestAsync( + BufferProviderWrapper bufferProvider, + string accountName, + HttpRequestMessage requestMessage) + { + // Skip this and use the original DSR. + OperationType operationType = (OperationType)Enum.Parse(typeof(OperationType), requestMessage.Headers.GetValues(HandlerConstants.ProxyOperationType).First()); + ResourceType resourceType = (ResourceType)Enum.Parse(typeof(ResourceType), requestMessage.Headers.GetValues(HandlerConstants.ProxyResourceType).First()); + + Guid activityId = Guid.Parse(requestMessage.Headers.GetValues(HttpConstants.HttpHeaders.ActivityId).First()); + + Stream requestStream = null; + if (requestMessage.Content != null) + { + requestStream = await requestMessage.Content.ReadAsStreamAsync(); + } + + RequestNameValueCollection dictionaryCollection = new (); + foreach (KeyValuePair> header in requestMessage.Headers) + { + dictionaryCollection.Set(header.Key, string.Join(",", header.Value)); + } + + using DocumentServiceRequest request = new ( + operationType, + resourceType, + requestMessage.RequestUri.PathAndQuery, + requestStream, + AuthorizationTokenType.PrimaryMasterKey, + dictionaryCollection); + + if (operationType.IsPointOperation()) + { + string partitionKey = request.Headers.Get(HttpConstants.HttpHeaders.PartitionKey); + + if (string.IsNullOrEmpty(partitionKey)) + { + throw new InternalServerErrorException("Partition key is missing or empty."); + } + + string epk = GetEffectivePartitionKeyHash(partitionKey); + + request.Properties = new Dictionary + { + { "x-ms-effective-partition-key", HexStringUtility.HexStringToBytes(epk) } + }; + } + else if (request.Headers[HandlerConstants.ProxyStartEpk] != null) + { + // Re-add EPK headers removed by RequestInvokerHandler through Properties + request.Properties = new Dictionary + { + { WFConstants.BackendHeaders.StartEpkHash, HexStringUtility.HexStringToBytes(request.Headers[HandlerConstants.ProxyStartEpk]) }, + { WFConstants.BackendHeaders.EndEpkHash, HexStringUtility.HexStringToBytes(request.Headers[HandlerConstants.ProxyEndEpk]) } + }; + + request.Headers.Add(HttpConstants.HttpHeaders.ReadFeedKeyType, RntbdConstants.RntdbReadFeedKeyType.EffectivePartitionKeyRange.ToString()); + request.Headers.Add(HttpConstants.HttpHeaders.StartEpk, request.Headers[HandlerConstants.ProxyStartEpk]); + request.Headers.Add(HttpConstants.HttpHeaders.EndEpk, request.Headers[HandlerConstants.ProxyEndEpk]); + } + + await request.EnsureBufferedBodyAsync(); + + using Documents.Rntbd.TransportSerialization.SerializedRequest serializedRequest = + Documents.Rntbd.TransportSerialization.BuildRequestForProxy( + request, + new ResourceOperation(operationType, resourceType), + activityId, + bufferProvider.Provider, + accountName, + out _, + out _); + + // TODO: consider using the SerializedRequest directly. + MemoryStream memoryStream = new MemoryStream(serializedRequest.RequestSize); + await serializedRequest.CopyToStreamAsync(memoryStream); + memoryStream.Position = 0; + return memoryStream; + } + + public static string GetEffectivePartitionKeyHash(string partitionJson) + { + return Documents.PartitionKey.FromJsonString(partitionJson).InternalKey.GetEffectivePartitionKeyString(HashV2SinglePath); + } + + /// + /// Deserialize the Proxy Response from the RNTBD protocol format to the Http format needed by the caller. + /// Today this takes the HttpResponseMessage and reconstructs the modified Http response. + /// + public static async Task ConvertProxyResponseAsync(HttpResponseMessage responseMessage) + { + using Stream responseStream = await responseMessage.Content.ReadAsStreamAsync(); + + (StatusCodes status, byte[] metadata) = await ReadHeaderAndMetadataAsync(responseStream); + + if (responseMessage.StatusCode != (HttpStatusCode)status) + { + throw new InternalServerErrorException("Status code mismatch"); + } + + Rntbd.BytesDeserializer bytesDeserializer = new (metadata, metadata.Length); + if (!Documents.Rntbd.HeadersTransportSerialization.TryParseMandatoryResponseHeaders(ref bytesDeserializer, out bool payloadPresent, out _)) + { + throw new InternalServerErrorException("Length mismatch"); + } + + MemoryStream bodyStream = null; + if (payloadPresent) + { + int length = await ReadBodyLengthAsync(responseStream); + bodyStream = new MemoryStream(length); + await responseStream.CopyToAsync(bodyStream); + bodyStream.Position = 0; + } + + // TODO: Clean this up. + bytesDeserializer = new Rntbd.BytesDeserializer(metadata, metadata.Length); + StoreResponse storeResponse = Documents.Rntbd.TransportSerialization.MakeStoreResponse( + status, + Guid.NewGuid(), + bodyStream, + HttpConstants.Versions.CurrentVersion, + ref bytesDeserializer); + + HttpResponseMessage response = new ((HttpStatusCode)storeResponse.StatusCode) + { + RequestMessage = responseMessage.RequestMessage + }; + + if (bodyStream != null) + { + response.Content = new StreamContent(bodyStream); + } + + foreach (string header in storeResponse.Headers.Keys()) + { + if (header == HttpConstants.HttpHeaders.SessionToken) + { + string newSessionToken = $"{storeResponse.PartitionKeyRangeId}:{storeResponse.Headers.Get(header)}"; + response.Headers.TryAddWithoutValidation(header, newSessionToken); + } + else + { + response.Headers.TryAddWithoutValidation(header, storeResponse.Headers.Get(header)); + } + } + + response.Headers.TryAddWithoutValidation(HandlerConstants.RoutedViaProxy, "1"); + return response; + } + + private static async Task<(StatusCodes, byte[] metadata)> ReadHeaderAndMetadataAsync(Stream stream) + { + byte[] header = ArrayPool.Shared.Rent(24); + const int headerLength = 24; + try + { + int headerRead = 0; + while (headerRead < headerLength) + { + int read = await stream.ReadAsync(header, headerRead, headerLength - headerRead); + + if (read == 0) + { + throw new DocumentClientException("Unexpected end of stream while reading header bytes", HttpStatusCode.Gone, SubStatusCodes.Unknown); + } + + headerRead += read; + } + + uint totalLength = BitConverter.ToUInt32(header, 0); + StatusCodes status = (StatusCodes)BitConverter.ToUInt32(header, 4); + + if (totalLength < headerLength) + { + throw new InternalServerErrorException("Header length mismatch"); + } + + int metadataLength = (int)totalLength - headerLength; + byte[] metadata = new byte[metadataLength]; + int responseMetadataRead = 0; + while (responseMetadataRead < metadataLength) + { + int read = await stream.ReadAsync(metadata, responseMetadataRead, metadataLength - responseMetadataRead); + + if (read == 0) + { + throw new DocumentClientException("Unexpected end of stream while reading metadata bytes", HttpStatusCode.Gone, SubStatusCodes.Unknown); + } + + responseMetadataRead += read; + } + + return (status, metadata); + } + finally + { + ArrayPool.Shared.Return(header); + } + } + + private static async Task ReadBodyLengthAsync(Stream stream) + { + byte[] header = ArrayPool.Shared.Rent(4); + const int headerLength = 4; + try + { + int headerRead = 0; + while (headerRead < headerLength) + { + int read = await stream.ReadAsync(header, headerRead, headerLength - headerRead); + + if (read == 0) + { + throw new DocumentClientException("Unexpected end of stream while reading body length", HttpStatusCode.Gone, SubStatusCodes.Unknown); + } + + headerRead += read; + } + + return BitConverter.ToInt32(header, 0); + } + finally + { + ArrayPool.Shared.Return(header); + } + } + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ThinClientTransportSerializerTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ThinClientTransportSerializerTests.cs new file mode 100644 index 0000000000..a111434b2c --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ThinClientTransportSerializerTests.cs @@ -0,0 +1,113 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Tests +{ + using Moq; + using System; + using System.Collections.Generic; + using System.IO; + using System.Linq; + using System.Net; + using System.Net.Http; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos; + using Microsoft.Azure.Documents; + using Microsoft.Azure.Documents.Collections; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class ThinClientTransportSerializerTests + { + private readonly Mock mockBufferProviderWrapper; + private readonly string testAccountName = "testAccount"; + private readonly Uri testUri = new Uri("http://localhost/dbs/db1/colls/coll1/docs/doc1"); + + public ThinClientTransportSerializerTests() + { + this.mockBufferProviderWrapper = new Mock(); + } + + [TestMethod] + public async Task SerializeProxyRequestAsync_ShouldSerializeRequest() + { + // Arrange + HttpRequestMessage requestMessage = new HttpRequestMessage(HttpMethod.Get, this.testUri); + requestMessage.Headers.Add(HandlerConstants.ProxyOperationType, "Read"); + requestMessage.Headers.Add(HandlerConstants.ProxyResourceType, "Document"); + requestMessage.Headers.Add(HttpConstants.HttpHeaders.ActivityId, Guid.NewGuid().ToString()); + requestMessage.Headers.Add(HttpConstants.HttpHeaders.PartitionKey, "[\"testPartitionKey\"]"); + + // Act + Stream result = await ThinClientTransportSerializer.SerializeProxyRequestAsync( + this.mockBufferProviderWrapper.Object, + this.testAccountName, + requestMessage); + + // Assert + Assert.IsNotNull(result); + Assert.IsInstanceOfType(result, typeof(Stream)); + } + + [TestMethod] + public async Task SerializeProxyRequestAsync_ThrowsException_WhenPartitionKeyIsMissing() + { + // Arrange + HttpRequestMessage requestMessage = new HttpRequestMessage(HttpMethod.Get, this.testUri); + requestMessage.Headers.Add(HandlerConstants.ProxyOperationType, "Read"); + requestMessage.Headers.Add(HandlerConstants.ProxyResourceType, "Document"); + requestMessage.Headers.Add(HttpConstants.HttpHeaders.ActivityId, Guid.NewGuid().ToString()); + + // Act & Assert + await Assert.ThrowsExceptionAsync(() => + ThinClientTransportSerializer.SerializeProxyRequestAsync( + this.mockBufferProviderWrapper.Object, + this.testAccountName, + requestMessage)); + } + + [TestMethod] + public async Task SerializeProxyRequestAsync_InvalidOperationType_ThrowsException() + { + // Arrange + HttpRequestMessage requestMessage = new HttpRequestMessage(HttpMethod.Get, this.testUri); + requestMessage.Headers.Add(HandlerConstants.ProxyOperationType, "InvalidOperation"); + requestMessage.Headers.Add(HandlerConstants.ProxyResourceType, "Document"); + requestMessage.Headers.Add(HttpConstants.HttpHeaders.ActivityId, Guid.NewGuid().ToString()); + requestMessage.Headers.Add(HttpConstants.HttpHeaders.PartitionKey, "[\"testPartitionKey\"]"); + + // Act & Assert + await Assert.ThrowsExceptionAsync(() => + ThinClientTransportSerializer.SerializeProxyRequestAsync( + this.mockBufferProviderWrapper.Object, + this.testAccountName, + requestMessage)); + } + + [TestMethod] + public async Task SerializeProxyRequestAsync_WithRequestBody_ShouldSerializeRequest() + { + // Arrange + HttpRequestMessage requestMessage = new HttpRequestMessage(HttpMethod.Post, this.testUri) + { + Content = new StringContent("{ \"key\": \"value\" }") + }; + requestMessage.Headers.Add(HandlerConstants.ProxyOperationType, "Create"); + requestMessage.Headers.Add(HandlerConstants.ProxyResourceType, "Document"); + requestMessage.Headers.Add(HttpConstants.HttpHeaders.ActivityId, Guid.NewGuid().ToString()); + requestMessage.Headers.Add(HttpConstants.HttpHeaders.PartitionKey, "[\"testPartitionKey\"]"); + + // Act + Stream result = await ThinClientTransportSerializer.SerializeProxyRequestAsync( + this.mockBufferProviderWrapper.Object, + this.testAccountName, + requestMessage); + + // Assert + Assert.IsNotNull(result); + Assert.IsInstanceOfType(result, typeof(Stream)); + Assert.IsTrue(result.Length > 0); + } + } +} \ No newline at end of file