diff --git a/src/RabbitMQ.Next.Publisher/Publisher.cs b/src/RabbitMQ.Next.Publisher/Publisher.cs index b4f67be..1c6c214 100644 --- a/src/RabbitMQ.Next.Publisher/Publisher.cs +++ b/src/RabbitMQ.Next.Publisher/Publisher.cs @@ -110,8 +110,7 @@ private async Task PublishAsyncImpl(TContent content, MessageBuilder m this.messagePropsPool.Return(message); } } - - + private async Task InternalPublishAsync(IMessageBuilder message, TContent content) { var flags = ComposePublishFlags(message); @@ -120,11 +119,14 @@ private async Task InternalPublishAsync(IMessageBuilder message, TCont var deliveryTag = await ch.PublishAsync(this.exchange, message.RoutingKey, content, message, flags) .ConfigureAwait(false); - var confirmed = await this.confirms.WaitForConfirmAsync(deliveryTag, default).ConfigureAwait(false); - if (!confirmed) + if (this.confirms != null) { - // todo: provide some useful info here - throw new DeliveryFailedException(); + var confirmed = await this.confirms.WaitForConfirmAsync(deliveryTag, default).ConfigureAwait(false); + if (!confirmed) + { + // todo: provide some useful info here + throw new DeliveryFailedException(); + } } } diff --git a/src/RabbitMQ.Next/Connection.cs b/src/RabbitMQ.Next/Connection.cs index 53ae4f1..0cbea13 100644 --- a/src/RabbitMQ.Next/Connection.cs +++ b/src/RabbitMQ.Next/Connection.cs @@ -83,8 +83,8 @@ public async Task OpenConnectionAsync(CancellationToken cancellation) this.socket = await EndpointResolver.OpenSocketAsync(this.connectionDetails.Settings.Endpoints, cancellation).ConfigureAwait(false); this.socketIoCancellation = new CancellationTokenSource(); - Task.Factory.StartNew(() => this.ReceiveLoop(this.socketIoCancellation.Token), TaskCreationOptions.LongRunning); - Task.Factory.StartNew(this.SendLoop, TaskCreationOptions.LongRunning); + StartThread(() => this.ReceiveLoop(this.socketIoCancellation.Token), "RabbitMQ.Next-Receive loop"); + StartThread(this.SendLoop, "RabbitMQ.Next-Send loop"); this.connectionChannel = this.CreateChannel(ProtocolConstants.ConnectionChannel); var connectionCloseWait = new WaitMethodMessageHandler(default); @@ -110,38 +110,27 @@ public async Task OpenConnectionAsync(CancellationToken cancellation) var amqpHeaderMemory = new MemoryAccessor(ProtocolConstants.AmqpHeader); await this.socketSender.Writer.WriteAsync(amqpHeaderMemory, cancellation).ConfigureAwait(false); - this.connectionDetails.Negotiated = await negotiateTask.ConfigureAwait(false); - - // start heartbeat - Task.Factory.StartNew(() => this.HeartbeatLoop(this.connectionDetails.Negotiated.HeartbeatInterval, this.socketIoCancellation.Token), TaskCreationOptions.LongRunning); - + var negotiationResults = await negotiateTask.ConfigureAwait(false); + this.connectionDetails.PopulateWithNegotiationResults(negotiationResults); + this.State = ConnectionState.Configuring; this.State = ConnectionState.Open; } private IChannelInternal CreateChannel(ushort channelNumber) { - var maxFrameSize = this.connectionDetails?.Negotiated?.FrameMaxSize ?? ProtocolConstants.FrameMinSize; + var maxFrameSize = this.connectionDetails.FrameMaxSize ?? ProtocolConstants.FrameMinSize; var policy = new MessageBuilderPoolPolicy(this.memoryPool, channelNumber, maxFrameSize); var messageBuilderPool = new DefaultObjectPool(policy); return new Channel(this.socketSender.Writer, messageBuilderPool, this.connectionDetails.Settings.Serializer); } - private async Task HeartbeatLoop(TimeSpan interval, CancellationToken cancellation) - { - var heartbeatMemory = new MemoryAccessor(ProtocolConstants.HeartbeatFrame); - while (!cancellation.IsCancellationRequested) - { - await Task.Delay(interval, cancellation).ConfigureAwait(false); - await this.socketSender.Writer.WriteAsync(heartbeatMemory, cancellation).ConfigureAwait(false); - } - } - private void SendLoop() { + var heartbeatMemory = new MemoryAccessor(ProtocolConstants.HeartbeatFrame); var socketChannel = this.socketSender.Reader; - while (socketChannel.WaitToReadAsync().Wait()) + do { while (socketChannel.TryRead(out var memory)) { @@ -154,7 +143,21 @@ private void SendLoop() memory = next; } } - } + + var waitResult = socketChannel.WaitToReadAsync().Wait(this.connectionDetails.HeartbeatInterval); + if (waitResult.IsCompleted) + { + if (waitResult.Result) + { + continue; + } + + break; + } + + // wait long enough and nothing was sent, need to send heartbeat frame + this.socket.Send(heartbeatMemory); + } while (true); } private void ReceiveLoop(CancellationToken cancellationToken) @@ -298,6 +301,15 @@ private void ConnectionClose(Exception ex) this.connectionChannel.TryComplete(ex); } + private static void StartThread(Action threadStart, string threadName) + { + var thread = new Thread(new ThreadStart(threadStart)) + { + Name = threadName, + }; + thread.Start(); + } + private static async Task NegotiateConnectionAsync(IChannel channel, ConnectionSettings settings, CancellationToken cancellation) { // connection should be forcibly closed if negotiation phase take more then 10s. @@ -334,7 +346,6 @@ private static async Task NegotiateConnectionAsync(IChannel var tuneMethod = await tuneMethodTask.ConfigureAwait(false); var negotiationResult = new NegotiationResults( - settings.Auth.Type, tuneMethod.ChannelMax, Math.Min(settings.MaxFrameSize, (int)tuneMethod.MaxFrameSize), TimeSpan.FromSeconds(tuneMethod.HeartbeatInterval)); diff --git a/src/RabbitMQ.Next/ConnectionDetails.cs b/src/RabbitMQ.Next/ConnectionDetails.cs index 62b6d33..609dc9d 100644 --- a/src/RabbitMQ.Next/ConnectionDetails.cs +++ b/src/RabbitMQ.Next/ConnectionDetails.cs @@ -1,3 +1,5 @@ +using System; + namespace RabbitMQ.Next; internal class ConnectionDetails @@ -8,14 +10,19 @@ public ConnectionDetails(ConnectionSettings settings) } public ConnectionSettings Settings { get; } + + public int? ChannelMax { get; private set; } - public NegotiationResults Negotiated { get; set; } - - public string RemoteHost { get; set; } - - public string RemotePort { get; set; } + public int? FrameMaxSize { get; private set; } - public bool IsSsl { get; set; } - - public string VirtualHost { get; set; } -} \ No newline at end of file + public TimeSpan? HeartbeatInterval { get; private set; } + + public void PopulateWithNegotiationResults(NegotiationResults negotiationResults) + { + ArgumentNullException.ThrowIfNull(negotiationResults); + + this.ChannelMax = negotiationResults.ChannelMax; + this.FrameMaxSize = negotiationResults.FrameMaxSize; + this.HeartbeatInterval = negotiationResults.HeartbeatInterval; + } +} diff --git a/src/RabbitMQ.Next/NegotiationResults.cs b/src/RabbitMQ.Next/NegotiationResults.cs index 0ef416f..4f70cc6 100644 --- a/src/RabbitMQ.Next/NegotiationResults.cs +++ b/src/RabbitMQ.Next/NegotiationResults.cs @@ -4,9 +4,8 @@ namespace RabbitMQ.Next; internal class NegotiationResults { - public NegotiationResults(string authMechanism, int channelMax, int frameMaxSize, TimeSpan heartbeatInterval) + public NegotiationResults(int channelMax, int frameMaxSize, TimeSpan heartbeatInterval) { - this.AuthMechanism = authMechanism; this.ChannelMax = channelMax; this.FrameMaxSize = frameMaxSize; this.HeartbeatInterval = heartbeatInterval; @@ -17,6 +16,4 @@ public NegotiationResults(string authMechanism, int channelMax, int frameMaxSize public int FrameMaxSize { get; } public TimeSpan HeartbeatInterval { get; } - - public string AuthMechanism { get; } -} \ No newline at end of file +} diff --git a/src/RabbitMQ.Next/Tasks/TaskExtensions.cs b/src/RabbitMQ.Next/Tasks/TaskExtensions.cs index 141c1cf..759191c 100644 --- a/src/RabbitMQ.Next/Tasks/TaskExtensions.cs +++ b/src/RabbitMQ.Next/Tasks/TaskExtensions.cs @@ -1,3 +1,4 @@ +using System; using System.Threading; using System.Threading.Tasks; @@ -17,15 +18,26 @@ public static Task AsTask(this CancellationToken cancellation) return tcs.Task; } - public static T Wait(this ValueTask valueTask) + public static (bool IsCompleted, T Result) Wait(this ValueTask valueTask, TimeSpan? timeout) { if (valueTask.IsCompleted) { - return valueTask.Result; + return (true, valueTask.Result); } - + var task = valueTask.AsTask(); - return task.GetAwaiter().GetResult(); + var timeoutMs = -1; // -1 means infinite + if (timeout.HasValue) + { + timeoutMs = (int)timeout.Value.TotalMilliseconds; + } + + if (task.Wait(timeoutMs)) + { + return (true, task.Result); + } + + return (false, default); } public static Task WithCancellation(this Task task, CancellationToken cancellation)