using System;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;
using Tapeti.Config;
namespace Tapeti.Flow.Default
{
    /// <inheritdoc />
    /// <summary>
    /// Default implementation for IFlowStarter.
    /// </summary>
    internal class FlowStarter : IFlowStarter
    {
        private readonly ITapetiConfig config;


        /// <summary>
        /// Creates a flow starter which resolves controllers and the flow handler through the
        /// dependency resolver of the specified configuration.
        /// </summary>
        /// <param name="config">The Tapeti configuration providing the dependency resolver.</param>
        /// <exception cref="ArgumentNullException"><paramref name="config"/> is null.</exception>
        public FlowStarter(ITapetiConfig config)
        {
            this.config = config ?? throw new ArgumentNullException(nameof(config));
        }


        /// <inheritdoc />
        public async Task Start<TController>(Expression<Func<TController, Func<IYieldPoint>>> methodSelector) where TController : class
        {
            await CallControllerMethod<TController>(GetExpressionMethod(methodSelector), value => Task.FromResult((IYieldPoint)value), Array.Empty<object?>());
        }


        /// <inheritdoc />
        public async Task Start<TController>(Expression<Func<TController, Func<Task<IYieldPoint>>>> methodSelector) where TController : class
        {
            await CallControllerMethod<TController>(GetExpressionMethod(methodSelector), value => (Task<IYieldPoint>)value, Array.Empty<object?>());
        }


        /// <inheritdoc />
        public async Task Start<TController, TParameter>(Expression<Func<TController, Func<TParameter, IYieldPoint>>> methodSelector, TParameter parameter) where TController : class
        {
            await CallControllerMethod<TController>(GetExpressionMethod(methodSelector), value => Task.FromResult((IYieldPoint)value), new object?[] { parameter });
        }


        /// <inheritdoc />
        public async Task Start<TController, TParameter>(Expression<Func<TController, Func<TParameter, Task<IYieldPoint>>>> methodSelector, TParameter parameter) where TController : class
        {
            await CallControllerMethod<TController>(GetExpressionMethod(methodSelector), value => (Task<IYieldPoint>)value, new object?[] { parameter });
        }


        /// <summary>
        /// Resolves the controller, invokes the selected start method on it and hands the
        /// resulting yield point to the flow handler.
        /// </summary>
        /// <typeparam name="TController">The controller type to resolve.</typeparam>
        /// <param name="method">The controller method to invoke.</param>
        /// <param name="getYieldPointResult">Converts the raw (non-null) return value of the method into an awaitable yield point.</param>
        /// <param name="parameters">The arguments passed to the method.</param>
        /// <exception cref="InvalidOperationException">The method returned null.</exception>
        private async Task CallControllerMethod<TController>(MethodInfo method, Func<object, Task<IYieldPoint>> getYieldPointResult, object?[] parameters) where TController : class
        {
            var controller = config.DependencyResolver.Resolve<TController>();

            // MethodInfo.Invoke wraps exceptions thrown by the controller method in a TargetInvocationException.
            var result = method.Invoke(controller, parameters);
            if (result == null)
                throw new InvalidOperationException($"Method {method.Name} must return an IYieldPoint or Task<IYieldPoint>, got null");

            var yieldPoint = await getYieldPointResult(result);

            var context = new FlowHandlerContext(config, controller, method);

            var flowHandler = config.DependencyResolver.Resolve<IFlowHandler>();
            await flowHandler.Execute(context, yieldPoint);
        }


        /// <summary>
        /// Determines the controller method referenced by a parameterless method group selector,
        /// e.g. <c>c => c.StartFlow</c>.
        /// </summary>
        private static MethodInfo GetExpressionMethod<TController, TResult>(Expression<Func<TController, Func<TResult>>> methodSelector)
        {
            return GetSelectedMethod(methodSelector, nameof(methodSelector));
        }


        /// <summary>
        /// Determines the controller method referenced by a single-parameter method group selector,
        /// e.g. <c>c => c.StartFlow</c>.
        /// </summary>
        private static MethodInfo GetExpressionMethod<TController, TResult, TParameter>(Expression<Func<TController, Func<TParameter, TResult>>> methodSelector)
        {
            return GetSelectedMethod(methodSelector, nameof(methodSelector));
        }


        /// <summary>
        /// Extracts the MethodInfo from a lambda whose body converts a method group to a delegate.
        /// </summary>
        /// <param name="methodSelector">The selector lambda.</param>
        /// <param name="paramName">The parameter name reported in exceptions.</param>
        /// <exception cref="ArgumentNullException"><paramref name="methodSelector"/> is null.</exception>
        /// <exception cref="ArgumentException">The lambda does not reference a method group.</exception>
        private static MethodInfo GetSelectedMethod(LambdaExpression? methodSelector, string paramName)
        {
            if (methodSelector == null)
                throw new ArgumentNullException(paramName);

            // A method group conversion is represented as Convert(Call(...)) where the call creates
            // the delegate. Depending on the APIs available to the compiler this is either the
            // instance call MethodInfo.CreateDelegate(Type, object) - with the MethodInfo constant as
            // the call target - or the static Delegate.CreateDelegate(Type, object, MethodInfo).
            if ((methodSelector.Body as UnaryExpression)?.Operand is MethodCallExpression callExpression)
            {
                if (callExpression.Object is ConstantExpression targetConstant && targetConstant.Value is MethodInfo targetMethod)
                    return targetMethod;

                if (callExpression.Object == null && callExpression.Method.Name == nameof(Delegate.CreateDelegate))
                {
                    foreach (var argument in callExpression.Arguments)
                    {
                        if (argument is ConstantExpression argumentConstant && argumentConstant.Value is MethodInfo argumentMethod)
                            return argumentMethod;
                    }
                }
            }

            throw new ArgumentException("Unable to determine the starting method", paramName);
        }
    }
}