diff --git a/Tapeti.Flow/Default/DelegateYieldPoint.cs b/Tapeti.Flow/Default/DelegateYieldPoint.cs index 105492e..1c0a28b 100644 --- a/Tapeti.Flow/Default/DelegateYieldPoint.cs +++ b/Tapeti.Flow/Default/DelegateYieldPoint.cs @@ -3,15 +3,7 @@ using System.Threading.Tasks; namespace Tapeti.Flow.Default { - internal interface IExecutableYieldPoint : IYieldPoint - { - bool StoreState { get; } - - Task Execute(FlowContext context); - } - - - internal class DelegateYieldPoint : IYieldPoint + internal class DelegateYieldPoint : IStateYieldPoint, IExecutableYieldPoint { public bool StoreState { get; } diff --git a/Tapeti.Flow/Default/FlowBindingMiddleware.cs b/Tapeti.Flow/Default/FlowBindingMiddleware.cs index bb11bda..66f1e12 100644 --- a/Tapeti.Flow/Default/FlowBindingMiddleware.cs +++ b/Tapeti.Flow/Default/FlowBindingMiddleware.cs @@ -12,8 +12,8 @@ namespace Tapeti.Flow.Default { public void Handle(IBindingContext context, Action next) { - RegisterContinuationFilter(context); RegisterYieldPointResult(context); + RegisterContinuationFilter(context); next(); @@ -29,6 +29,26 @@ namespace Tapeti.Flow.Default context.Use(new FlowMessageFilterMiddleware()); context.Use(new FlowMessageMiddleware()); + + if (context.Result.HasHandler) + return; + + // Continuation without IYieldPoint indicates a ParallelRequestBuilder response handler, + // make sure to store it's state as well + if (context.Result.Info.ParameterType == typeof(Task)) + { + context.Result.SetHandler(async (messageContext, value) => + { + await (Task)value; + await HandleParallelResponse(messageContext); + }); + } + else if (context.Result.Info.ParameterType == typeof(void)) + { + context.Result.SetHandler((messageContext, value) => HandleParallelResponse(messageContext)); + } + else + throw new ArgumentException($"Result type must be IYieldPoint, Task or void in controller {context. Method.DeclaringType?.FullName}, method {context.Method.Name}"); } @@ -59,6 +79,13 @@ namespace Tapeti.Flow.Default } + private static Task HandleParallelResponse(IMessageContext context) + { + var flowHandler = context.DependencyResolver.Resolve(); + return flowHandler.Execute(context, new StateYieldPoint(true)); + } + + private static void ValidateRequestResponse(IBindingContext context) { var request = context.MessageClass?.GetCustomAttribute(); diff --git a/Tapeti.Flow/Default/FlowMessageMiddleware.cs b/Tapeti.Flow/Default/FlowMessageMiddleware.cs index 85f0925..394ae0b 100644 --- a/Tapeti.Flow/Default/FlowMessageMiddleware.cs +++ b/Tapeti.Flow/Default/FlowMessageMiddleware.cs @@ -1,4 +1,5 @@ using System; +using System.Reflection; using System.Threading.Tasks; using Tapeti.Config; @@ -13,12 +14,41 @@ namespace Tapeti.Flow.Default { Newtonsoft.Json.JsonConvert.PopulateObject(flowContext.FlowState.Data, context.Controller); + // Remove Continuation now because the IYieldPoint result handler will store the new state + flowContext.FlowState.Continuations.Remove(flowContext.ContinuationID); + var converge = flowContext.FlowState.Continuations.Count == 0 && + flowContext.ContinuationMetadata.ConvergeMethodName != null; + await next(); - flowContext.FlowState.Continuations.Remove(flowContext.ContinuationID); + if (converge) + await CallConvergeMethod(context, + flowContext.ContinuationMetadata.ConvergeMethodName, + flowContext.ContinuationMetadata.ConvergeMethodSync); } else await next(); } + + + private static async Task CallConvergeMethod(IMessageContext context, string methodName, bool sync) + { + IYieldPoint yieldPoint; + + var method = context.Controller.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance); + if (method == null) + throw new ArgumentException($"Unknown converge method in controller {context.Controller.GetType().Name}: {methodName}"); + + if (sync) + yieldPoint = (IYieldPoint)method.Invoke(context.Controller, new object[] {}); + else + yieldPoint = await (Task)method.Invoke(context.Controller, new object[] { }); + + if (yieldPoint == null) + throw new YieldPointException($"Yield point is required in controller {context.Controller.GetType().Name} for converge method {methodName}"); + + var flowHandler = context.DependencyResolver.Resolve(); + await flowHandler.Execute(context, yieldPoint); + } } } diff --git a/Tapeti.Flow/Default/FlowProvider.cs b/Tapeti.Flow/Default/FlowProvider.cs index 232d11a..aeab792 100644 --- a/Tapeti.Flow/Default/FlowProvider.cs +++ b/Tapeti.Flow/Default/FlowProvider.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Reflection; using System.Threading.Tasks; using RabbitMQ.Client.Framing; @@ -36,8 +37,7 @@ namespace Tapeti.Flow.Default public IFlowParallelRequestBuilder YieldWithParallelRequest() { - throw new NotImplementedException(); - //return new ParallelRequestBuilder(); + return new ParallelRequestBuilder(config, SendRequest); } public IYieldPoint EndWithResponse(TResponse message) @@ -51,7 +51,8 @@ namespace Tapeti.Flow.Default } - private async Task SendRequest(FlowContext context, object message, ResponseHandlerInfo responseHandlerInfo) + private async Task SendRequest(FlowContext context, object message, ResponseHandlerInfo responseHandlerInfo, + string convergeMethodName = null, bool convergeMethodTaskSync = false) { var continuationID = Guid.NewGuid(); @@ -59,7 +60,8 @@ namespace Tapeti.Flow.Default new ContinuationMetadata { MethodName = responseHandlerInfo.MethodName, - ConvergeMethodName = null + ConvergeMethodName = convergeMethodName, + ConvergeMethodSync = convergeMethodTaskSync }); var properties = new BasicProperties @@ -144,8 +146,10 @@ namespace Tapeti.Flow.Default public async Task Execute(IMessageContext context, IYieldPoint yieldPoint) { - var delegateYieldPoint = (DelegateYieldPoint)yieldPoint; - var storeState = delegateYieldPoint.StoreState; + var stateYieldPoint = yieldPoint as IStateYieldPoint; + var executableYieldPoint = yieldPoint as IExecutableYieldPoint; + + var storeState = stateYieldPoint?.StoreState ?? false; FlowContext flowContext; object flowContextItem; @@ -181,7 +185,8 @@ namespace Tapeti.Flow.Default try { - await delegateYieldPoint.Execute(flowContext); + if (executableYieldPoint != null) + await executableYieldPoint.Execute(flowContext); } catch (YieldPointException e) { @@ -203,10 +208,16 @@ namespace Tapeti.Flow.Default } - /* private class ParallelRequestBuilder : IFlowParallelRequestBuilder { - internal class RequestInfo + public delegate Task SendRequestFunc(FlowContext context, + object message, + ResponseHandlerInfo responseHandlerInfo, + string convergeMethodName, + bool convergeMethodSync); + + + private class RequestInfo { public object Message { get; set; } public ResponseHandlerInfo ResponseHandlerInfo { get; set; } @@ -214,15 +225,13 @@ namespace Tapeti.Flow.Default private readonly IConfig config; - private readonly IFlowStore flowStore; - private readonly Func sendRequest; + private readonly SendRequestFunc sendRequest; private readonly List requests = new List(); - public ParallelRequestBuilder(IConfig config, IFlowStore flowStore, Func sendRequest) + public ParallelRequestBuilder(IConfig config, SendRequestFunc sendRequest) { this.config = config; - this.flowStore = flowStore; this.sendRequest = sendRequest; } @@ -253,15 +262,34 @@ namespace Tapeti.Flow.Default public IYieldPoint Yield(Func> continuation) { - return new YieldPoint(flowStore, true, context => Task.WhenAll(requests.Select(requestInfo => sendRequest(context, requestInfo.Message, requestInfo.ResponseHandlerInfo)))); + return BuildYieldPoint(continuation, false); } - public IYieldPoint Yield(Func continuation) + public IYieldPoint YieldSync(Func continuation) { - return new YieldPoint(flowStore, true, context => Task.WhenAll(requests.Select(requestInfo => sendRequest(context, requestInfo.Message, requestInfo.ResponseHandlerInfo)))); + return BuildYieldPoint(continuation, true); } - }*/ + + + private IYieldPoint BuildYieldPoint(Delegate convergeMethod, bool convergeMethodSync) + { + if (convergeMethod?.Method == null) + throw new ArgumentNullException(nameof(convergeMethod)); + + return new DelegateYieldPoint(true, context => + { + if (convergeMethod.Method.DeclaringType != context.MessageContext.Controller.GetType()) + throw new YieldPointException("Converge method must be in the same controller class"); + + return Task.WhenAll(requests.Select(requestInfo => + sendRequest(context, requestInfo.Message, + requestInfo.ResponseHandlerInfo, + convergeMethod.Method.Name, + convergeMethodSync))); + }); + } + } internal class ResponseHandlerInfo diff --git a/Tapeti.Flow/Default/FlowState.cs b/Tapeti.Flow/Default/FlowState.cs index d120370..e1cb0cf 100644 --- a/Tapeti.Flow/Default/FlowState.cs +++ b/Tapeti.Flow/Default/FlowState.cs @@ -81,6 +81,7 @@ namespace Tapeti.Flow.Default { public string MethodName { get; set; } public string ConvergeMethodName { get; set; } + public bool ConvergeMethodSync { get; set; } public ContinuationMetadata Clone() @@ -88,7 +89,8 @@ namespace Tapeti.Flow.Default return new ContinuationMetadata { MethodName = MethodName, - ConvergeMethodName = ConvergeMethodName + ConvergeMethodName = ConvergeMethodName, + ConvergeMethodSync = ConvergeMethodSync }; } } diff --git a/Tapeti.Flow/Default/IInternalYieldPoint.cs b/Tapeti.Flow/Default/IInternalYieldPoint.cs new file mode 100644 index 0000000..5415f35 --- /dev/null +++ b/Tapeti.Flow/Default/IInternalYieldPoint.cs @@ -0,0 +1,15 @@ +using System.Threading.Tasks; + +namespace Tapeti.Flow.Default +{ + internal interface IStateYieldPoint : IYieldPoint + { + bool StoreState { get; } + } + + + internal interface IExecutableYieldPoint : IYieldPoint + { + Task Execute(FlowContext context); + } +} diff --git a/Tapeti.Flow/Default/StateYieldPoint.cs b/Tapeti.Flow/Default/StateYieldPoint.cs new file mode 100644 index 0000000..df9c36e --- /dev/null +++ b/Tapeti.Flow/Default/StateYieldPoint.cs @@ -0,0 +1,13 @@ +namespace Tapeti.Flow.Default +{ + internal class StateYieldPoint : IStateYieldPoint + { + public bool StoreState { get; } + + + public StateYieldPoint(bool storeState) + { + StoreState = storeState; + } + } +} diff --git a/Tapeti.Flow/IFlowProvider.cs b/Tapeti.Flow/IFlowProvider.cs index 0b619e3..97894b9 100644 --- a/Tapeti.Flow/IFlowProvider.cs +++ b/Tapeti.Flow/IFlowProvider.cs @@ -30,7 +30,7 @@ namespace Tapeti.Flow IFlowParallelRequestBuilder AddRequestSync(TRequest message, Action responseHandler); IYieldPoint Yield(Func> continuation); - IYieldPoint Yield(Func continuation); + IYieldPoint YieldSync(Func continuation); } public interface IYieldPoint diff --git a/Tapeti.Flow/Tapeti.Flow.csproj b/Tapeti.Flow/Tapeti.Flow.csproj index bd90e39..0001a78 100644 --- a/Tapeti.Flow/Tapeti.Flow.csproj +++ b/Tapeti.Flow/Tapeti.Flow.csproj @@ -56,9 +56,11 @@ + + diff --git a/Test/MarcoController.cs b/Test/MarcoController.cs index 1f1f9a8..c0fa8bb 100644 --- a/Test/MarcoController.cs +++ b/Test/MarcoController.cs @@ -36,23 +36,42 @@ namespace Test Console.WriteLine(">> Marco (yielding with request)"); await myVisualizer.VisualizeMarco(); + StateTestGuid = Guid.NewGuid(); - return flowProvider.YieldWithRequestSync( - new PoloConfirmationRequestMessage() + return flowProvider.YieldWithParallelRequest() + .AddRequestSync(new PoloConfirmationRequestMessage { StoredInState = StateTestGuid - }, - HandlePoloConfirmationResponse); + }, HandlePoloConfirmationResponse1) + + .AddRequestSync(new PoloConfirmationRequestMessage + { + StoredInState = StateTestGuid + }, HandlePoloConfirmationResponse2) + + .YieldSync(ContinuePoloConfirmation); } [Continuation] - public IYieldPoint HandlePoloConfirmationResponse(PoloConfirmationResponseMessage message) + public void HandlePoloConfirmationResponse1(PoloConfirmationResponseMessage message) { - Console.WriteLine(">> HandlePoloConfirmationResponse (ending flow)"); + Console.WriteLine(">> HandlePoloConfirmationResponse1"); + Console.WriteLine(message.ShouldMatchState.Equals(StateTestGuid) ? "Confirmed!" : "Oops! Mismatch!"); + } + + [Continuation] + public void HandlePoloConfirmationResponse2(PoloConfirmationResponseMessage message) + { + Console.WriteLine(">> HandlePoloConfirmationResponse2"); Console.WriteLine(message.ShouldMatchState.Equals(StateTestGuid) ? "Confirmed!" : "Oops! Mismatch!"); + } + + private IYieldPoint ContinuePoloConfirmation() + { + Console.WriteLine("> ConvergePoloConfirmation (ending flow)"); return flowProvider.EndWithResponse(new PoloMessage()); } @@ -77,7 +96,6 @@ namespace Test public void Polo(PoloMessage message) { Console.WriteLine(">> Polo"); - StateTestGuid = Guid.NewGuid(); } } diff --git a/Test/MarcoEmitter.cs b/Test/MarcoEmitter.cs index 79c0911..7298937 100644 --- a/Test/MarcoEmitter.cs +++ b/Test/MarcoEmitter.cs @@ -17,8 +17,9 @@ namespace Test public async Task Run() { -// await publisher.Publish(new MarcoMessage()); + await publisher.Publish(new MarcoMessage()); + /* var concurrent = new SemaphoreSlim(20); while (true) @@ -38,12 +39,12 @@ namespace Test await Task.Delay(200); } + */ - /* while (true) { await Task.Delay(1000); - }*/ + } } } } diff --git a/Test/Program.cs b/Test/Program.cs index f4244d2..3a81a3d 100644 --- a/Test/Program.cs +++ b/Test/Program.cs @@ -33,7 +33,11 @@ namespace Test }) { Console.WriteLine("Subscribing..."); - connection.Subscribe().Wait(); + var subscriber = connection.Subscribe(false).Result; + + Console.WriteLine("Consuming..."); + subscriber.Resume().Wait(); + Console.WriteLine("Done!"); var emitter = container.GetInstance(); diff --git a/Test/Test.csproj b/Test/Test.csproj index d21d19d..62a3b1a 100644 --- a/Test/Test.csproj +++ b/Test/Test.csproj @@ -59,6 +59,10 @@ + + {c4897d64-d04e-4ae9-bd98-d64295d1d13a} + Tapeti.Annotations + {6de7b122-eb6a-46b8-aeaf-f84dde18f9c7} Tapeti.Flow.SQL