using System;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;
using Tapeti.Config;

namespace Tapeti.Flow.Default
{
    /// <inheritdoc />
    /// <summary>
    /// Default implementation for IFlowStarter.
    /// </summary>
    internal class FlowStarter : IFlowStarter
    {
        private readonly ITapetiConfig config;


        /// <summary>
        /// Creates a flow starter which resolves controllers and the flow handler through
        /// the dependency resolver of the specified configuration.
        /// </summary>
        /// <param name="config">The configuration whose DependencyResolver is used.</param>
        public FlowStarter(ITapetiConfig config)
        {
            this.config = config;
        }


        /// <inheritdoc />
        public async Task Start<TController>(Expression<Func<TController, Func<IYieldPoint>>> methodSelector) where TController : class
        {
            await CallControllerMethod<TController>(GetExpressionMethod(methodSelector), value => Task.FromResult((IYieldPoint)value), Array.Empty<object?>());
        }


        /// <inheritdoc />
        public async Task Start<TController>(Expression<Func<TController, Func<Task<IYieldPoint>>>> methodSelector) where TController : class
        {
            await CallControllerMethod<TController>(GetExpressionMethod(methodSelector), value => (Task<IYieldPoint>)value, Array.Empty<object?>());
        }


        /// <inheritdoc />
        public async Task Start<TController, TParameter>(Expression<Func<TController, Func<TParameter, IYieldPoint>>> methodSelector, TParameter parameter) where TController : class
        {
            await CallControllerMethod<TController>(GetExpressionMethod(methodSelector), value => Task.FromResult((IYieldPoint)value), new object?[] { parameter });
        }


        /// <inheritdoc />
        public async Task Start<TController, TParameter>(Expression<Func<TController, Func<TParameter, Task<IYieldPoint>>>> methodSelector, TParameter parameter) where TController : class
        {
            await CallControllerMethod<TController>(GetExpressionMethod(methodSelector), value => (Task<IYieldPoint>)value, new object?[] { parameter });
        }


        /// <summary>
        /// Resolves the controller, invokes the starting method on it and passes the resulting
        /// yield point, together with a new FlowHandlerContext, to the resolved IFlowHandler.
        /// </summary>
        /// <typeparam name="TController">The controller type to resolve.</typeparam>
        /// <param name="method">The controller method which starts the flow.</param>
        /// <param name="getYieldPointResult">Converts the raw return value of <paramref name="method"/> into a yield point.</param>
        /// <param name="parameters">Arguments passed to <paramref name="method"/>.</param>
        /// <exception cref="InvalidOperationException">Thrown when <paramref name="method"/> returns null.</exception>
        private async Task CallControllerMethod<TController>(MethodInfo method, Func<object, Task<IYieldPoint>> getYieldPointResult, object?[] parameters) where TController : class
        {
            var controller = config.DependencyResolver.Resolve<TController>();

            var result = method.Invoke(controller, parameters);
            if (result == null)
                throw new InvalidOperationException($"Method {method.Name} must return an IYieldPoint or Task<IYieldPoint>, got null");

            var yieldPoint = await getYieldPointResult(result);

            var context = new FlowHandlerContext(config, controller, method);

            var flowHandler = config.DependencyResolver.Resolve<IFlowHandler>();
            await flowHandler.Execute(context, yieldPoint);
        }


        /// <summary>
        /// Extracts the controller method referenced by a method group selector such as
        /// <c>controller => controller.StartMethod</c>.
        /// </summary>
        /// <remarks>
        /// A method group conversion inside an expression tree is represented as a Convert node
        /// wrapping a CreateDelegate call. Shared by all Start overloads, which previously used
        /// two identical private overloads differing only in the delegate signature.
        /// </remarks>
        /// <param name="methodSelector">The lambda selecting the method group.</param>
        /// <returns>The selected method.</returns>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="methodSelector"/> is null.</exception>
        /// <exception cref="ArgumentException">Thrown when no method can be determined from the expression.</exception>
        private static MethodInfo GetExpressionMethod(LambdaExpression methodSelector)
        {
            if (methodSelector == null)
                throw new ArgumentNullException(nameof(methodSelector));

            var callExpression = (methodSelector.Body as UnaryExpression)?.Operand as MethodCallExpression;

            var method = FindMethodInfoConstant(callExpression);
            if (method == null)
                throw new ArgumentException("Unable to determine the starting method", nameof(methodSelector));

            return method;
        }


        /// <summary>
        /// Returns the MethodInfo constant used by a CreateDelegate call, or null if none is found.
        /// </summary>
        private static MethodInfo? FindMethodInfoConstant(MethodCallExpression? callExpression)
        {
            if (callExpression == null)
                return null;

            // Shape: methodInfo.CreateDelegate(typeof(TDelegate), target), where the MethodInfo is the call's instance.
            if ((callExpression.Object as ConstantExpression)?.Value is MethodInfo targetMethod)
                return targetMethod;

            // Shape: Delegate.CreateDelegate(typeof(TDelegate), target, methodInfo), where the MethodInfo is an argument.
            // NOTE(review): emitted by some compiler/target combinations; previously this shape always threw.
            foreach (var argument in callExpression.Arguments)
            {
                if ((argument as ConstantExpression)?.Value is MethodInfo argumentMethod)
                    return argumentMethod;
            }

            return null;
        }
    }
}