using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; using System.Reflection; using System.Threading.Tasks; using Tapeti.Annotations; using Tapeti.Config; using Tapeti.Default; using Tapeti.Flow.Annotations; using Tapeti.Flow.FlowHelpers; namespace Tapeti.Flow.Default { /// /> /// /// Default implementation for IFlowProvider. /// public class FlowProvider : IFlowProvider, IFlowHandler { private readonly ITapetiConfig config; private readonly IInternalPublisher publisher; /// public FlowProvider(ITapetiConfig config, IPublisher publisher) { this.config = config; this.publisher = (IInternalPublisher)publisher; } /// public IYieldPoint YieldWithRequest(TRequest message, Func> responseHandler) { var responseHandlerInfo = GetResponseHandlerInfo(config, message, responseHandler); return new DelegateYieldPoint(context => SendRequest(context, message, responseHandlerInfo)); } /// public IYieldPoint YieldWithRequestSync(TRequest message, Func responseHandler) { var responseHandlerInfo = GetResponseHandlerInfo(config, message, responseHandler); return new DelegateYieldPoint(context => SendRequest(context, message, responseHandlerInfo)); } /// public IFlowParallelRequestBuilder YieldWithParallelRequest() { return new ParallelRequestBuilder(config, SendRequest); } /// public IYieldPoint EndWithResponse(TResponse message) { return new DelegateYieldPoint(context => SendResponse(context, message)); } /// public IYieldPoint End() { return new DelegateYieldPoint(EndFlow); } private async Task SendRequest(FlowContext context, object message, ResponseHandlerInfo responseHandlerInfo, string convergeMethodName = null, bool convergeMethodTaskSync = false) { if (context.FlowState == null) { await CreateNewFlowState(context); Debug.Assert(context.FlowState != null, "context.FlowState != null"); } var continuationID = Guid.NewGuid(); context.FlowState.Continuations.Add(continuationID, new ContinuationMetadata { MethodName = responseHandlerInfo.MethodName, ConvergeMethodName = convergeMethodName, ConvergeMethodSync = convergeMethodTaskSync }); var properties = new MessageProperties { CorrelationId = continuationID.ToString(), ReplyTo = responseHandlerInfo.ReplyToQueue }; await context.Store(); await publisher.Publish(message, properties, true); } private async Task SendResponse(FlowContext context, object message) { var reply = context.FlowState == null ? GetReply(context.MessageContext) : context.FlowState.Metadata.Reply; if (reply == null) throw new YieldPointException("No response is required"); if (message.GetType().FullName != reply.ResponseTypeName) throw new YieldPointException($"Flow must end with a response message of type {reply.ResponseTypeName}, {message.GetType().FullName} was returned instead"); var properties = new MessageProperties { CorrelationId = reply.CorrelationId }; // TODO disallow if replyto is not specified? if (reply.ReplyTo != null) await publisher.PublishDirect(message, reply.ReplyTo, properties, reply.Mandatory); else await publisher.Publish(message, properties, reply.Mandatory); await context.Delete(); } private static async Task EndFlow(FlowContext context) { await context.Delete(); if (context.FlowState?.Metadata.Reply != null) throw new YieldPointException($"Flow must end with a response message of type {context.FlowState.Metadata.Reply.ResponseTypeName}"); } private static ResponseHandlerInfo GetResponseHandlerInfo(ITapetiConfig config, object request, Delegate responseHandler) { var binding = config.Bindings.ForMethod(responseHandler); if (binding == null) throw new ArgumentException("responseHandler must be a registered message handler", nameof(responseHandler)); var requestAttribute = request.GetType().GetCustomAttribute(); if (requestAttribute?.Response != null && !binding.Accept(requestAttribute.Response)) throw new ArgumentException($"responseHandler must accept message of type {requestAttribute.Response}", nameof(responseHandler)); var continuationAttribute = binding.Method.GetCustomAttribute(); if (continuationAttribute == null) throw new ArgumentException("responseHandler must be marked with the Continuation attribute", nameof(responseHandler)); if (binding.QueueName == null) throw new ArgumentException("responseHandler must bind to a valid queue", nameof(responseHandler)); return new ResponseHandlerInfo { MethodName = MethodSerializer.Serialize(responseHandler.Method), ReplyToQueue = binding.QueueName }; } private static ReplyMetadata GetReply(IMessageContext context) { var requestAttribute = context.Message?.GetType().GetCustomAttribute(); if (requestAttribute?.Response == null) return null; return new ReplyMetadata { CorrelationId = context.Properties.CorrelationId, ReplyTo = context.Properties.ReplyTo, ResponseTypeName = requestAttribute.Response.FullName, Mandatory = context.Properties.Persistent.GetValueOrDefault(true) }; } private static async Task CreateNewFlowState(FlowContext flowContext) { var flowStore = flowContext.MessageContext.Config.DependencyResolver.Resolve(); var flowID = Guid.NewGuid(); flowContext.FlowStateLock = await flowStore.LockFlowState(flowID); if (flowContext.FlowStateLock == null) throw new InvalidOperationException("Unable to lock a new flow"); flowContext.FlowState = new FlowState { Metadata = new FlowMetadata { Reply = GetReply(flowContext.MessageContext) } }; } /// public async Task Execute(IControllerMessageContext context, IYieldPoint yieldPoint) { if (!(yieldPoint is DelegateYieldPoint executableYieldPoint)) throw new YieldPointException($"Yield point is required in controller {context.Controller.GetType().Name} for method {context.Binding.Method.Name}"); if (!context.Get(ContextItems.FlowContext, out FlowContext flowContext)) { flowContext = new FlowContext { MessageContext = context }; context.Store(ContextItems.FlowContext, flowContext); } try { await executableYieldPoint.Execute(flowContext); } catch (YieldPointException e) { // Useful for debugging e.Data["Tapeti.Controller.Name"] = context.Controller.GetType().FullName; e.Data["Tapeti.Controller.Method"] = context.Binding.Method.Name; throw; } flowContext.EnsureStoreOrDeleteIsCalled(); } private class ParallelRequestBuilder : IFlowParallelRequestBuilder { public delegate Task SendRequestFunc(FlowContext context, object message, ResponseHandlerInfo responseHandlerInfo, string convergeMethodName, bool convergeMethodSync); private class RequestInfo { public object Message { get; set; } public ResponseHandlerInfo ResponseHandlerInfo { get; set; } } private readonly ITapetiConfig config; private readonly SendRequestFunc sendRequest; private readonly List requests = new List(); public ParallelRequestBuilder(ITapetiConfig config, SendRequestFunc sendRequest) { this.config = config; this.sendRequest = sendRequest; } public IFlowParallelRequestBuilder AddRequest(TRequest message, Func responseHandler) { requests.Add(new RequestInfo { Message = message, ResponseHandlerInfo = GetResponseHandlerInfo(config, message, responseHandler) }); return this; } public IFlowParallelRequestBuilder AddRequestSync(TRequest message, Action responseHandler) { requests.Add(new RequestInfo { Message = message, ResponseHandlerInfo = GetResponseHandlerInfo(config, message, responseHandler) }); return this; } public IYieldPoint Yield(Func> continuation) { return BuildYieldPoint(continuation, false); } public IYieldPoint YieldSync(Func continuation) { return BuildYieldPoint(continuation, true); } private IYieldPoint BuildYieldPoint(Delegate convergeMethod, bool convergeMethodSync) { if (convergeMethod?.Method == null) throw new ArgumentNullException(nameof(convergeMethod)); return new DelegateYieldPoint(context => { if (convergeMethod.Method.DeclaringType != context.MessageContext.Controller.GetType()) throw new YieldPointException("Converge method must be in the same controller class"); return Task.WhenAll(requests.Select(requestInfo => sendRequest(context, requestInfo.Message, requestInfo.ResponseHandlerInfo, convergeMethod.Method.Name, convergeMethodSync))); }); } } internal class ResponseHandlerInfo { public string MethodName { get; set; } public string ReplyToQueue { get; set; } } } }