diff --git a/Tapeti/Config/IMessageContext.cs b/Tapeti/Config/IMessageContext.cs index 6bad232..d0ecbd5 100644 --- a/Tapeti/Config/IMessageContext.cs +++ b/Tapeti/Config/IMessageContext.cs @@ -21,5 +21,7 @@ namespace Tapeti.Config object Controller { get; } IBinding Binding { get; } + + IMessageContext SetupNestedContext(); } } diff --git a/Tapeti/Connection/TapetiConsumer.cs b/Tapeti/Connection/TapetiConsumer.cs index 4d83df9..e55b024 100644 --- a/Tapeti/Connection/TapetiConsumer.cs +++ b/Tapeti/Connection/TapetiConsumer.cs @@ -6,6 +6,7 @@ using RabbitMQ.Client; using Tapeti.Config; using Tapeti.Default; using Tapeti.Helpers; +using System.Threading.Tasks; namespace Tapeti.Connection { @@ -56,30 +57,12 @@ namespace Tapeti.Connection { foreach (var binding in bindings) { - if (!binding.Accept(context, message)) - continue; + if (binding.Accept(context, message)) + { + InvokeUsingBinding(context, binding, message); - context.Binding = binding; - - // ReSharper disable AccessToDisposedClosure - MiddlewareHelper will not keep a reference to the lambdas - MiddlewareHelper.GoAsync( - binding.MessageFilterMiddleware, - async (handler, next) => await handler.Handle(context, next), - async () => - { - context.Controller = dependencyResolver.Resolve(binding.Controller); - - await MiddlewareHelper.GoAsync( - binding.MessageMiddleware != null - ? messageMiddleware.Concat(binding.MessageMiddleware).ToList() - : messageMiddleware, - async (handler, next) => await handler.Handle(context, next), - () => binding.Invoke(context, message) - ); - }).Wait(); - // ReSharper restore AccessToDisposedClosure - - validMessageType = true; + validMessageType = true; + } } if (!validMessageType) @@ -104,6 +87,60 @@ namespace Tapeti.Connection } + private void InvokeUsingBinding(MessageContext context, IBinding binding, object message) + { + context.Binding = binding; + + RecursiveCaller firstCaller = null; + RecursiveCaller currentCaller = null; + + Action addHandler = (Handler handle) => + { + var caller = new RecursiveCaller(handle); + if (currentCaller == null) + firstCaller = caller; + else + currentCaller.next = caller; + currentCaller = caller; + }; + + if (binding.MessageFilterMiddleware != null) + { + foreach (var m in binding.MessageFilterMiddleware) + { + addHandler(m.Handle); + } + } + + addHandler(async (c, next) => + { + c.Controller = dependencyResolver.Resolve(binding.Controller); + await next(); + }); + + foreach (var m in messageMiddleware) + { + addHandler(m.Handle); + } + + if (binding.MessageMiddleware != null) + { + foreach (var m in binding.MessageMiddleware) + { + addHandler(m.Handle); + } + } + + addHandler(async (c, next) => + { + await binding.Invoke(context, message); + }); + + firstCaller.Call(context) + .Wait(); + + } + private static Exception UnwrapException(Exception exception) { // In async/await style code this is handled similarly. For synchronous @@ -120,4 +157,61 @@ namespace Tapeti.Connection } } } + + public delegate Task Handler(MessageContext context, Func next); + + public class RecursiveCaller: ICallFrame + { + private Handler handle; + private MessageContext context; + private MessageContext nextContext; + public RecursiveCaller next; + + public RecursiveCaller(Handler handle) + { + this.handle = handle; + } + + internal async Task Call(MessageContext context) + { + if (this.context != null) + throw new InvalidOperationException("Cannot simultaneously call 'next' in Middleware."); + + try + { + this.context = context; + + if (next != null) + context.SetCallFrame(this); + + await handle(context, callNext); + } + finally + { + context = null; + } + } + + private Task callNext() + { + if (next == null) + return Task.CompletedTask; + + return next.Call(nextContext ?? context); + } + + void ICallFrame.UseNestedContext(MessageContext context) + { + if (nextContext != null) + throw new InvalidOperationException("Previous nested context was not yet disposed."); + nextContext = context; + } + + void ICallFrame.OnContextDisposed(MessageContext context) + { + if (nextContext == context) + nextContext = null; + } + } + } diff --git a/Tapeti/Default/MessageContext.cs b/Tapeti/Default/MessageContext.cs index 5872701..728655b 100644 --- a/Tapeti/Default/MessageContext.cs +++ b/Tapeti/Default/MessageContext.cs @@ -1,10 +1,17 @@ using System; +using System.Collections; using System.Collections.Generic; using RabbitMQ.Client; using Tapeti.Config; +using System.Linq; namespace Tapeti.Default { + public interface ICallFrame { + void UseNestedContext(MessageContext context); + void OnContextDisposed(MessageContext context); + } + public class MessageContext : IMessageContext { public IDependencyResolver DependencyResolver { get; set; } @@ -17,13 +24,198 @@ namespace Tapeti.Default public object Message { get; set; } public IBasicProperties Properties { get; set; } - public IDictionary Items { get; } = new Dictionary(); + public IDictionary Items { get; } + private readonly MessageContext outerContext; + private ICallFrame callFrame; + + public MessageContext() + { + Items = new Dictionary(); + } + + public MessageContext(ICallFrame callFrame) + { + Items = new Dictionary(); + + this.callFrame = callFrame; + } + + private MessageContext(MessageContext outerContext) + { + DependencyResolver = outerContext.DependencyResolver; + + Controller = outerContext.Controller; + Binding = outerContext.Binding; + + Queue = outerContext.Queue; + RoutingKey = outerContext.RoutingKey; + Message = outerContext.Message; + Properties = outerContext.Properties; + + Items = new DeferingDictionary(outerContext.Items); + + this.outerContext = outerContext; + } public void Dispose() { - foreach (var value in Items.Values) + var items = (Items as DeferingDictionary)?.MyState ?? Items; + + foreach (var value in items.Values) (value as IDisposable)?.Dispose(); + + callFrame?.OnContextDisposed(this); + } + + public void SetCallFrame(ICallFrame callFrame) + { + this.callFrame = callFrame; + } + + public IMessageContext SetupNestedContext() + { + if (callFrame == null) + throw new NotSupportedException("This context does not support creating nested contexts"); + + var nested = new MessageContext(this); + + callFrame.UseNestedContext(nested); + + return nested; + } + + private class DeferingDictionary : IDictionary + { + private IDictionary myState; + private IDictionary deferee; + + public DeferingDictionary(IDictionary deferee) + { + myState = new Dictionary(); + this.deferee = deferee; + } + + public IDictionary MyState => myState; + + object IDictionary.this[string key] + { + get + { + return myState.ContainsKey(key) ? myState[key] : deferee[key]; + } + + set + { + if (deferee.ContainsKey(key)) + throw new InvalidOperationException("Cannot hide an item set in an outer context."); + + myState[key] = value; + } + } + + int ICollection>.Count + { + get + { + return myState.Count + deferee.Count; + } + } + + bool ICollection>.IsReadOnly + { + get + { + return false; + } + } + + ICollection IDictionary.Keys + { + get + { + return myState.Keys.Concat(deferee.Keys).ToList().AsReadOnly(); + } + } + + ICollection IDictionary.Values + { + get + { + return myState.Values.Concat(deferee.Values).ToList().AsReadOnly(); + } + } + + void ICollection>.Add(KeyValuePair item) + { + if (deferee.ContainsKey(item.Key)) + throw new InvalidOperationException("Cannot hide an item set in an outer context."); + + myState.Add(item); + } + + void IDictionary.Add(string key, object value) + { + if (deferee.ContainsKey(key)) + throw new InvalidOperationException("Cannot hide an item set in an outer context."); + + myState.Add(key, value); + } + + void ICollection>.Clear() + { + throw new InvalidOperationException("Cannot influence the items in an outer context."); + } + + bool ICollection>.Contains(KeyValuePair item) + { + return myState.Contains(item) || deferee.Contains(item); + } + + bool IDictionary.ContainsKey(string key) + { + return myState.ContainsKey(key) || deferee.ContainsKey(key); + } + + void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) + { + foreach(var item in myState.Concat(deferee)) + { + array[arrayIndex++] = item; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return (IEnumerator)myState.Concat(deferee); + } + + IEnumerator> IEnumerable>.GetEnumerator() + { + return (IEnumerator < KeyValuePair < string, object>> )myState.Concat(deferee); + } + + bool ICollection>.Remove(KeyValuePair item) + { + if (deferee.ContainsKey(item.Key)) + throw new InvalidOperationException("Cannot remove an item set in an outer context."); + + return myState.Remove(item); + } + + bool IDictionary.Remove(string key) + { + if (deferee.ContainsKey(key)) + throw new InvalidOperationException("Cannot remove an item set in an outer context."); + + return myState.Remove(key); + } + + bool IDictionary.TryGetValue(string key, out object value) + { + return myState.TryGetValue(key, out value) + || deferee.TryGetValue(key, out value); + } } } }