diff --git a/Tapeti/Config/IConfig.cs b/Tapeti/Config/IConfig.cs index 07494cb..6f7449d 100644 --- a/Tapeti/Config/IConfig.cs +++ b/Tapeti/Config/IConfig.cs @@ -30,6 +30,7 @@ namespace Tapeti.Config public interface IDynamicQueue : IQueue { + string GetDeclareQueueName(); void SetName(string name); } diff --git a/Tapeti/Connection/IConnectionEventListener.cs b/Tapeti/Connection/IConnectionEventListener.cs index c0e82df..d86feab 100644 --- a/Tapeti/Connection/IConnectionEventListener.cs +++ b/Tapeti/Connection/IConnectionEventListener.cs @@ -1,9 +1,16 @@ namespace Tapeti.Connection { + public class DisconnectedEventArgs + { + public ushort ReplyCode; + public string ReplyText; + } + + public interface IConnectionEventListener { void Connected(); void Reconnected(); - void Disconnected(); + void Disconnected(DisconnectedEventArgs e); } } diff --git a/Tapeti/Connection/TapetiWorker.cs b/Tapeti/Connection/TapetiWorker.cs index 79db3ed..f9d577a 100644 --- a/Tapeti/Connection/TapetiWorker.cs +++ b/Tapeti/Connection/TapetiWorker.cs @@ -1,8 +1,10 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using RabbitMQ.Client; +using RabbitMQ.Client.Events; using RabbitMQ.Client.Exceptions; using RabbitMQ.Client.Framing; using Tapeti.Config; @@ -16,7 +18,7 @@ namespace Tapeti.Connection { private const int ReconnectDelay = 5000; private const int MandatoryReturnTimeout = 30000; - private const int PublishMaxConnectAttempts = 3; + private const int MinimumConnectedReconnectDelay = 1000; private readonly IConfig config; private readonly ILogger logger; @@ -28,10 +30,30 @@ namespace Tapeti.Connection private readonly IExchangeStrategy exchangeStrategy; private readonly Lazy taskQueue = new Lazy(); + // These fields are for use in the taskQueue only! private RabbitMQ.Client.IConnection connection; + private bool isReconnect; private IModel channelInstance; - private TaskCompletionSource publishResultTaskSource; + private ulong lastDeliveryTag; + private DateTime connectedDateTime; + private readonly Dictionary confirmMessages = new Dictionary(); + private readonly Dictionary returnRoutingKeys = new Dictionary(); + + + private class ConfirmMessageInfo + { + public string ReturnKey; + public TaskCompletionSource CompletionSource; + } + + + private class ReturnInfo + { + public uint RefCount; + public int FirstReplyCode; + } + public TapetiWorker(IConfig config) @@ -64,7 +86,7 @@ namespace Tapeti.Connection return taskQueue.Value.Add(() => { - GetChannel().BasicConsume(queueName, false, new TapetiConsumer(this, queueName, config.DependencyResolver, bindings, config.MessageMiddleware, config.CleanupMiddleware)); + WithRetryableChannel(channel => channel.BasicConsume(queueName, false, new TapetiConsumer(this, queueName, config.DependencyResolver, bindings, config.MessageMiddleware, config.CleanupMiddleware))); }); } @@ -73,34 +95,38 @@ namespace Tapeti.Connection { return taskQueue.Value.Add(() => { - var channel = GetChannel(); - - if (queue.Dynamic) + WithRetryableChannel(channel => { - var dynamicQueue = channel.QueueDeclare(queue.Name); - (queue as IDynamicQueue)?.SetName(dynamicQueue.QueueName); - - foreach (var binding in queue.Bindings) + if (queue.Dynamic) { - if (binding.QueueBindingMode == QueueBindingMode.RoutingKey) + if (!(queue is IDynamicQueue dynamicQueue)) + throw new NullReferenceException("Queue with Dynamic = true must implement IDynamicQueue"); + + var declaredQueue = channel.QueueDeclare(dynamicQueue.GetDeclareQueueName()); + dynamicQueue.SetName(declaredQueue.QueueName); + + foreach (var binding in queue.Bindings) { - var routingKey = routingKeyStrategy.GetRoutingKey(binding.MessageClass); - var exchange = exchangeStrategy.GetExchange(binding.MessageClass); + if (binding.QueueBindingMode == QueueBindingMode.RoutingKey) + { + var routingKey = routingKeyStrategy.GetRoutingKey(binding.MessageClass); + var exchange = exchangeStrategy.GetExchange(binding.MessageClass); - channel.QueueBind(dynamicQueue.QueueName, exchange, routingKey); + channel.QueueBind(declaredQueue.QueueName, exchange, routingKey); + } + + (binding as IBuildBinding)?.SetQueueName(declaredQueue.QueueName); } - - (binding as IBuildBinding)?.SetQueueName(dynamicQueue.QueueName); } - } - else - { - channel.QueueDeclarePassive(queue.Name); - foreach (var binding in queue.Bindings) + else { - (binding as IBuildBinding)?.SetQueueName(queue.Name); + channel.QueueDeclarePassive(queue.Name); + foreach (var binding in queue.Bindings) + { + (binding as IBuildBinding)?.SetQueueName(queue.Name); + } } - } + }); }); } @@ -109,6 +135,8 @@ namespace Tapeti.Connection { return taskQueue.Value.Add(() => { + // No need for a retryable channel here, if the connection is lost we can't + // use the deliveryTag anymore. switch (response) { case ConsumeResponse.Ack: @@ -122,6 +150,9 @@ namespace Tapeti.Connection case ConsumeResponse.Requeue: GetChannel().BasicNack(deliveryTag, false, true); break; + + default: + throw new ArgumentOutOfRangeException(nameof(response), response, null); } }); @@ -175,27 +206,48 @@ namespace Tapeti.Connection return MiddlewareHelper.GoAsync( config.PublishMiddleware, async (handler, next) => await handler.Handle(context, next), - () => taskQueue.Value.Add(() => + () => taskQueue.Value.Add(async () => { var body = messageSerializer.Serialize(context.Message, context.Properties); + Task publishResultTask = null; - - if (config.UsePublisherConfirms) + var messageInfo = new ConfirmMessageInfo { - publishResultTaskSource = new TaskCompletionSource(); - publishResultTask = publishResultTaskSource.Task; - } - else - mandatory = false; + ReturnKey = GetReturnKey(context.Exchange, context.RoutingKey), + CompletionSource = new TaskCompletionSource() + }; + + + WithRetryableChannel(channel => + { + // The delivery tag is lost after a reconnect, register under the new tag + if (config.UsePublisherConfirms) + { + lastDeliveryTag++; + + confirmMessages.Add(lastDeliveryTag, messageInfo); + publishResultTask = messageInfo.CompletionSource.Task; + } + else + mandatory = false; + + channel.BasicPublish(context.Exchange, context.RoutingKey, mandatory, context.Properties, body); + }); - GetChannel(PublishMaxConnectAttempts).BasicPublish(context.Exchange, context.RoutingKey, mandatory, context.Properties, body); if (publishResultTask == null) return; - if (!publishResultTask.Wait(MandatoryReturnTimeout)) + var delayCancellationTokenSource = new CancellationTokenSource(); + var signalledTask = await Task.WhenAny(publishResultTask, Task.Delay(MandatoryReturnTimeout, delayCancellationTokenSource.Token)); + + if (signalledTask != publishResultTask) throw new TimeoutException($"Timeout while waiting for basic.return for message with class {context.Message?.GetType().FullName ?? "null"} and Id {context.Properties.MessageId}"); + delayCancellationTokenSource.Cancel(); + + if (publishResultTask.IsCanceled) + throw new NackException($"Mandatory message with class {context.Message?.GetType().FullName ?? "null"} was nacked"); var replyCode = publishResultTask.Result; @@ -210,16 +262,43 @@ namespace Tapeti.Connection // ReSharper restore ImplicitlyCapturedClosure } + /// /// Only call this from a task in the taskQueue to ensure IModel is only used /// by a single thread, as is recommended in the RabbitMQ .NET Client documentation. /// - private IModel GetChannel(int? maxAttempts = null) + private void WithRetryableChannel(Action operation) { - if (channelInstance != null) + while (true) + { + try + { + operation(GetChannel()); + break; + } + catch (AlreadyClosedException e) + { + // TODO log? + } + } + } + + + /// + /// Only call this from a task in the taskQueue to ensure IModel is only used + /// by a single thread, as is recommended in the RabbitMQ .NET Client documentation. + /// + private IModel GetChannel() + { + if (channelInstance != null && channelInstance.IsOpen) return channelInstance; - var attempts = 0; + // If the Disconnect quickly follows the Connect (when an error occurs that is reported back by RabbitMQ + // not related to the connection), wait for a bit to avoid spamming the connection + if ((DateTime.UtcNow - connectedDateTime).TotalMilliseconds <= MinimumConnectedReconnectDelay) + Thread.Sleep(ReconnectDelay); + + var connectionFactory = new ConnectionFactory { HostName = ConnectionParams.HostName, @@ -227,8 +306,8 @@ namespace Tapeti.Connection VirtualHost = ConnectionParams.VirtualHost, UserName = ConnectionParams.Username, Password = ConnectionParams.Password, - AutomaticRecoveryEnabled = true, // The created connection is an IRecoverable - TopologyRecoveryEnabled = false, // We'll manually redeclare all queues in the Reconnect event to update the internal state for dynamic queues + AutomaticRecoveryEnabled = false, + TopologyRecoveryEnabled = false, RequestedHeartbeat = 30 }; @@ -240,39 +319,50 @@ namespace Tapeti.Connection connection = connectionFactory.CreateConnection(); channelInstance = connection.CreateModel(); - channelInstance.ConfirmSelect(); + + if (channelInstance == null) + throw new BrokerUnreachableException(null); + + if (config.UsePublisherConfirms) + { + lastDeliveryTag = 0; + confirmMessages.Clear(); + channelInstance.ConfirmSelect(); + } if (ConnectionParams.PrefetchCount > 0) channelInstance.BasicQos(0, ConnectionParams.PrefetchCount, false); - ((IRecoverable)connection).Recovery += (sender, e) => ConnectionEventListener?.Reconnected(); - - channelInstance.ModelShutdown += (sender, eventArgs) => ConnectionEventListener?.Disconnected(); - channelInstance.BasicReturn += (sender, eventArgs) => + channelInstance.ModelShutdown += (sender, e) => { - publishResultTaskSource?.SetResult(eventArgs.ReplyCode); - publishResultTaskSource = null; + ConnectionEventListener?.Disconnected(new DisconnectedEventArgs + { + ReplyCode = e.ReplyCode, + ReplyText = e.ReplyText + }); + + channelInstance = null; }; - channelInstance.BasicAcks += (sender, eventArgs) => - { - publishResultTaskSource?.SetResult(0); - publishResultTaskSource = null; - }; + channelInstance.BasicReturn += HandleBasicReturn; + channelInstance.BasicAcks += HandleBasicAck; + channelInstance.BasicNacks += HandleBasicNack; + + connectedDateTime = DateTime.UtcNow; + + if (isReconnect) + ConnectionEventListener?.Reconnected(); + else + ConnectionEventListener?.Connected(); - ConnectionEventListener?.Connected(); logger.ConnectSuccess(ConnectionParams); + isReconnect = true; break; } catch (BrokerUnreachableException e) { logger.ConnectFailed(ConnectionParams, e); - - attempts++; - if (maxAttempts.HasValue && attempts > maxAttempts.Value) - throw; - Thread.Sleep(ReconnectDelay); } } @@ -281,6 +371,93 @@ namespace Tapeti.Connection } + private void HandleBasicReturn(object sender, BasicReturnEventArgs e) + { + /* + * "If the message is also published as mandatory, the basic.return is sent to the client before basic.ack." + * - https://www.rabbitmq.com/confirms.html + * + * Because there is no delivery tag included in the basic.return message. This solution is modeled after + * user OhJeez' answer on StackOverflow: + * + * "Since all messages with the same routing key are routed the same way. I assumed that once I get a + * basic.return about a specific routing key, all messages with this routing key can be considered undelivered" + * https://stackoverflow.com/questions/21336659/how-to-tell-which-amqp-message-was-not-routed-from-basic-return-response + */ + var key = GetReturnKey(e.Exchange, e.RoutingKey); + + if (!returnRoutingKeys.TryGetValue(key, out var returnInfo)) + { + returnInfo = new ReturnInfo + { + RefCount = 0, + FirstReplyCode = e.ReplyCode + }; + + returnRoutingKeys.Add(key, returnInfo); + } + + returnInfo.RefCount++; + } + + + private void HandleBasicAck(object sender, BasicAckEventArgs e) + { + foreach (var deliveryTag in GetDeliveryTags(e)) + { + if (!confirmMessages.TryGetValue(deliveryTag, out var messageInfo)) + continue; + + if (returnRoutingKeys.TryGetValue(messageInfo.ReturnKey, out var returnInfo)) + { + messageInfo.CompletionSource.SetResult(returnInfo.FirstReplyCode); + + returnInfo.RefCount--; + if (returnInfo.RefCount == 0) + returnRoutingKeys.Remove(messageInfo.ReturnKey); + } + + messageInfo.CompletionSource.SetResult(0); + confirmMessages.Remove(deliveryTag); + } + } + + + private void HandleBasicNack(object sender, BasicNackEventArgs e) + { + foreach (var deliveryTag in GetDeliveryTags(e)) + { + if (!confirmMessages.TryGetValue(deliveryTag, out var messageInfo)) + continue; + + messageInfo.CompletionSource.SetCanceled(); + confirmMessages.Remove(e.DeliveryTag); + } + } + + + private IEnumerable GetDeliveryTags(BasicAckEventArgs e) + { + return e.Multiple + ? confirmMessages.Keys.Where(tag => tag <= e.DeliveryTag).ToArray() + : new[] { e.DeliveryTag }; + } + + + private IEnumerable GetDeliveryTags(BasicNackEventArgs e) + { + return e.Multiple + ? confirmMessages.Keys.Where(tag => tag <= e.DeliveryTag).ToArray() + : new[] { e.DeliveryTag }; + } + + + private static string GetReturnKey(string exchange, string routingKey) + { + return exchange + ':' + routingKey; + } + + private class PublishContext : IPublishContext { public IDependencyResolver DependencyResolver { get; set; } diff --git a/Tapeti/Exceptions/NackException.cs b/Tapeti/Exceptions/NackException.cs new file mode 100644 index 0000000..408dd71 --- /dev/null +++ b/Tapeti/Exceptions/NackException.cs @@ -0,0 +1,9 @@ +using System; + +namespace Tapeti.Exceptions +{ + public class NackException : Exception + { + public NackException(string message) : base(message) { } + } +} diff --git a/Tapeti/TapetiConfig.cs b/Tapeti/TapetiConfig.cs index 9a0867f..c785408 100644 --- a/Tapeti/TapetiConfig.cs +++ b/Tapeti/TapetiConfig.cs @@ -462,6 +462,8 @@ namespace Tapeti protected class Queue : IDynamicQueue { + private readonly string declareQueueName; + public bool Dynamic { get; } public string Name { get; set; } public IEnumerable Bindings { get; } @@ -469,12 +471,20 @@ namespace Tapeti public Queue(QueueInfo queue, IEnumerable bindings) { + declareQueueName = queue.Name; + Dynamic = queue.Dynamic.GetValueOrDefault(); Name = queue.Name; Bindings = bindings; } + public string GetDeclareQueueName() + { + return declareQueueName; + } + + public void SetName(string name) { Name = name; diff --git a/Tapeti/TapetiConnection.cs b/Tapeti/TapetiConnection.cs index 238d320..d66f880 100644 --- a/Tapeti/TapetiConnection.cs +++ b/Tapeti/TapetiConnection.cs @@ -8,6 +8,8 @@ using Tapeti.Connection; namespace Tapeti { + public delegate void DisconnectedEventHandler(object sender, DisconnectedEventArgs e); + public class TapetiConnection : IDisposable { private readonly IConfig config; @@ -29,11 +31,10 @@ namespace Tapeti } public event EventHandler Connected; - - public event EventHandler Disconnected; - + public event DisconnectedEventHandler Disconnected; public event EventHandler Reconnected; + public async Task Subscribe(bool startConsuming = true) { if (subscriber == null) @@ -87,9 +88,9 @@ namespace Tapeti owner.OnConnected(new EventArgs()); } - public void Disconnected() + public void Disconnected(DisconnectedEventArgs e) { - owner.OnDisconnected(new EventArgs()); + owner.OnDisconnected(e); } public void Reconnected() @@ -114,7 +115,7 @@ namespace Tapeti }); } - protected virtual void OnDisconnected(EventArgs e) + protected virtual void OnDisconnected(DisconnectedEventArgs e) { Task.Run(() => Disconnected?.Invoke(this, e)); } diff --git a/Tapeti/Tasks/SingleThreadTaskQueue.cs b/Tapeti/Tasks/SingleThreadTaskQueue.cs index fa28949..f22f869 100644 --- a/Tapeti/Tasks/SingleThreadTaskQueue.cs +++ b/Tapeti/Tasks/SingleThreadTaskQueue.cs @@ -27,7 +27,7 @@ namespace Tapeti.Tasks } - public Task Add(Func func) + public Task Add(Func func) { lock (previousTaskLock) { @@ -36,7 +36,11 @@ namespace Tapeti.Tasks , singleThreadScheduler.Value); previousTask = task; - return task; + + // 'task' completes at the moment a Task is returned (for example, an await is encountered), + // this is used to chain the next. We return the unwrapped Task however, so that the caller + // awaits until the full task chain has completed. + return task.Unwrap(); } }