using System;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Svrnty.CQRS.Sagas.Abstractions;
using Svrnty.CQRS.Sagas.Abstractions.Messaging;
using Svrnty.CQRS.Sagas.Abstractions.Persistence;
using Svrnty.CQRS.Sagas.Builders;
using Svrnty.CQRS.Sagas.Configuration;
namespace Svrnty.CQRS.Sagas;
/// <summary>
/// Default <see cref="ISagaOrchestrator"/> implementation.
/// </summary>
/// <remarks>
/// Saga definitions are resolved from the <see cref="IServiceProvider"/> and configured through a
/// <see cref="SagaBuilder{TData}"/>. Progress is persisted through <see cref="ISagaStateStore"/>.
/// Local steps run in-process; remote steps and remote compensations are published on the optional
/// <see cref="ISagaMessageBus"/>, and the saga resumes when <c>HandleResponseAsync</c> is called
/// with the step's response.
/// </remarks>
public class SagaOrchestrator : ISagaOrchestrator
{
    private readonly IServiceProvider _serviceProvider;
    private readonly ISagaStateStore _stateStore;
    private readonly ISagaMessageBus? _messageBus;
    private readonly ILogger _logger;
    private readonly SagaOptions _options;

    /// <summary>
    /// Creates a new saga orchestrator.
    /// </summary>
    /// <param name="serviceProvider">Provider used to resolve saga definitions by type.</param>
    /// <param name="stateStore">Store used to create, load and update saga state records.</param>
    /// <param name="options">Orchestrator options; <c>AutoCompensateOnFailure</c> selects compensation vs. plain failure.</param>
    /// <param name="logger">Logger for saga lifecycle and error events.</param>
    /// <param name="messageBus">Transport for remote steps; only required when a saga defines remote steps.</param>
    /// <exception cref="ArgumentNullException">A required dependency is <see langword="null"/>.</exception>
    public SagaOrchestrator(
        IServiceProvider serviceProvider,
        ISagaStateStore stateStore,
        IOptions<SagaOptions> options,
        ILogger<SagaOrchestrator> logger,
        ISagaMessageBus? messageBus = null)
    {
        ArgumentNullException.ThrowIfNull(serviceProvider);
        ArgumentNullException.ThrowIfNull(stateStore);
        ArgumentNullException.ThrowIfNull(options);
        ArgumentNullException.ThrowIfNull(logger);

        _serviceProvider = serviceProvider;
        _stateStore = stateStore;
        _messageBus = messageBus;
        _logger = logger;
        _options = options.Value;
    }

    /// <inheritdoc />
    public Task<SagaState> StartAsync<TSaga, TData>(TData initialData, CancellationToken cancellationToken = default)
        where TSaga : ISaga<TData>
        where TData : class, ISagaData, new()
    {
        return StartAsync<TSaga, TData>(initialData, Guid.NewGuid(), cancellationToken);
    }

    /// <inheritdoc />
    public async Task<SagaState> StartAsync<TSaga, TData>(
        TData initialData,
        Guid correlationId,
        CancellationToken cancellationToken = default)
        where TSaga : ISaga<TData>
        where TData : class, ISagaData, new()
    {
        ArgumentNullException.ThrowIfNull(initialData);

        // Stamp the correlation id before the data is serialized into the state record.
        initialData.CorrelationId = correlationId;

        var saga = _serviceProvider.GetRequiredService<TSaga>();
        var builder = new SagaBuilder<TData>();
        saga.Configure(builder);
        var steps = builder.Steps;

        if (steps.Count == 0)
        {
            throw new InvalidOperationException($"Saga {typeof(TSaga).Name} has no steps configured.");
        }

        var state = new SagaState
        {
            // FullName is what existing records contain; ResolveSagaType handles lookup across assemblies.
            SagaType = typeof(TSaga).FullName!,
            CorrelationId = correlationId,
            Status = SagaStatus.InProgress,
            CurrentStepIndex = 0,
            CurrentStepName = steps[0].Name,
            SerializedData = JsonSerializer.Serialize(initialData)
        };

        state = await _stateStore.CreateAsync(state, cancellationToken);

        _logger.LogInformation(
            "Started saga {SagaType} with ID {SagaId} and CorrelationId {CorrelationId}",
            state.SagaType, state.SagaId, state.CorrelationId);

        await ExecuteNextStepAsync(state, steps, initialData, cancellationToken);
        return state;
    }

    /// <inheritdoc />
    public Task<SagaState?> GetStateAsync(Guid sagaId, CancellationToken cancellationToken = default)
    {
        return _stateStore.GetByIdAsync(sagaId, cancellationToken);
    }

    /// <inheritdoc />
    public Task<SagaState?> GetStateByCorrelationIdAsync(Guid correlationId, CancellationToken cancellationToken = default)
    {
        return _stateStore.GetByCorrelationIdAsync(correlationId, cancellationToken);
    }

    /// <summary>
    /// Handles a response from a remote step.
    /// </summary>
    /// <remarks>
    /// On success the step is recorded as completed and the saga either completes or continues with the
    /// next step. On failure the error is recorded and the saga is compensated or marked failed, depending
    /// on <c>AutoCompensateOnFailure</c>.
    /// A response is ignored (and logged) when the saga is unknown, is no longer in progress, or is
    /// currently waiting on a different step. Redelivered or late responses therefore cannot advance the
    /// saga twice or re-open a finished one.
    /// </remarks>
    /// <typeparam name="TData">The saga's data type, used to deserialize the persisted data.</typeparam>
    /// <param name="response">The remote step's response.</param>
    /// <param name="cancellationToken">Token passed to the state store and to further step execution.</param>
    /// <exception cref="ArgumentNullException"><paramref name="response"/> is <see langword="null"/>.</exception>
    public async Task HandleResponseAsync<TData>(
        SagaStepResponse response,
        CancellationToken cancellationToken = default)
        where TData : class, ISagaData, new()
    {
        ArgumentNullException.ThrowIfNull(response);

        var state = await _stateStore.GetByIdAsync(response.SagaId, cancellationToken);
        if (state == null)
        {
            _logger.LogWarning("Received response for unknown saga {SagaId}", response.SagaId);
            return;
        }

        if (state.Status != SagaStatus.InProgress)
        {
            _logger.LogWarning(
                "Ignoring response for step {StepName} of saga {SagaId} because the saga is {SagaStatus}",
                response.StepName, response.SagaId, state.Status);
            return;
        }

        if (!string.Equals(state.CurrentStepName, response.StepName, StringComparison.Ordinal))
        {
            _logger.LogWarning(
                "Ignoring response for step {StepName} of saga {SagaId}; the saga is waiting on step {CurrentStepName}",
                response.StepName, response.SagaId, state.CurrentStepName);
            return;
        }

        TData? data;
        try
        {
            data = string.IsNullOrEmpty(state.SerializedData)
                ? null
                : JsonSerializer.Deserialize<TData>(state.SerializedData);
        }
        catch (JsonException ex)
        {
            _logger.LogError(ex, "Failed to deserialize saga data for {SagaId}", response.SagaId);
            return;
        }

        if (data == null)
        {
            _logger.LogError("Failed to deserialize saga data for {SagaId}", response.SagaId);
            return;
        }

        var sagaType = ResolveSagaType(state.SagaType);
        if (sagaType == null)
        {
            _logger.LogError("Unknown saga type {SagaType}", state.SagaType);
            return;
        }

        if (_serviceProvider.GetService(sagaType) is not ISaga<TData> saga)
        {
            _logger.LogError("Could not resolve saga {SagaType}", state.SagaType);
            return;
        }

        var builder = new SagaBuilder<TData>();
        saga.Configure(builder);
        var steps = builder.Steps;

        if (!response.Success)
        {
            _logger.LogError(
                "Step {StepName} failed for saga {SagaId}: {Error}",
                response.StepName, response.SagaId, response.ErrorMessage);

            await FailStepAsync(
                state,
                steps,
                data,
                response.StepName,
                response.ErrorMessage ?? "Unknown error",
                response.StackTrace,
                cancellationToken);
            return;
        }

        _logger.LogInformation(
            "Step {StepName} completed successfully for saga {SagaId}",
            response.StepName, response.SagaId);

        state.CompletedSteps.Add(response.StepName);
        state.CurrentStepIndex++;

        if (state.CurrentStepIndex >= steps.Count)
        {
            await CompleteSagaAsync(state, cancellationToken);
            return;
        }

        state.CurrentStepName = steps[state.CurrentStepIndex].Name;
        await _stateStore.UpdateAsync(state, cancellationToken);
        await ExecuteNextStepAsync(state, steps, data, cancellationToken);
    }

    /// <summary>
    /// Runs steps starting at <c>state.CurrentStepIndex</c>.
    /// </summary>
    /// <remarks>
    /// Local steps run inline and the loop moves on to the next step. A remote step publishes its
    /// command and returns; execution resumes in <c>HandleResponseAsync</c>. If every step has run, the
    /// saga is marked completed. A step exception is recorded and the saga is compensated or failed.
    /// An iterative loop is used instead of recursion so that a failure raised while handling a step
    /// error is not caught again by an enclosing step's handler.
    /// </remarks>
    private async Task ExecuteNextStepAsync<TData>(
        SagaState state,
        System.Collections.Generic.IReadOnlyList<SagaStepDefinition> steps,
        TData data,
        CancellationToken cancellationToken)
        where TData : class, ISagaData
    {
        while (state.CurrentStepIndex < steps.Count)
        {
            var step = steps[state.CurrentStepIndex];
            var context = new SagaContext(state);

            _logger.LogDebug(
                "Executing step {StepName} ({StepIndex}/{TotalSteps}) for saga {SagaId}",
                step.Name, state.CurrentStepIndex + 1, steps.Count, state.SagaId);

            try
            {
                if (step.IsRemote)
                {
                    await ExecuteRemoteStepAsync(state, step, data, context, cancellationToken);
                    return;
                }

                await ExecuteLocalStepAsync(state, step, data, context, steps, cancellationToken);
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Error executing step {StepName} for saga {SagaId}", step.Name, state.SagaId);
                await FailStepAsync(state, steps, data, step.Name, ex.Message, ex.StackTrace, cancellationToken);
                return;
            }
        }

        await CompleteSagaAsync(state, cancellationToken);
    }

    /// <summary>
    /// Runs a local step's execute action (when one is defined).
    /// </summary>
    /// <remarks>
    /// Afterwards it records the step as completed, persists the possibly-mutated saga data and advances
    /// <c>CurrentStepIndex</c> / <c>CurrentStepName</c>. It does not run the following step; the caller's
    /// loop does.
    /// </remarks>
    private async Task ExecuteLocalStepAsync<TData>(
        SagaState state,
        SagaStepDefinition step,
        TData data,
        SagaContext context,
        System.Collections.Generic.IReadOnlyList<SagaStepDefinition> steps,
        CancellationToken cancellationToken)
        where TData : class, ISagaData
    {
        if (step is LocalSagaStepDefinition<TData> localStep && localStep.ExecuteAction != null)
        {
            await localStep.ExecuteAction(data, context, cancellationToken);
        }

        state.CompletedSteps.Add(step.Name);
        state.SerializedData = JsonSerializer.Serialize(data);
        state.CurrentStepIndex++;

        if (state.CurrentStepIndex < steps.Count)
        {
            state.CurrentStepName = steps[state.CurrentStepIndex].Name;
        }

        await _stateStore.UpdateAsync(state, cancellationToken);
    }

    /// <summary>
    /// Builds a remote step's command and publishes it on the message bus.
    /// </summary>
    /// <remarks>
    /// The command is built through the step's <c>CommandBuilder</c> delegate, obtained by reflection.
    /// State is persisted <em>before</em> publishing. Once the command is on the bus its response may be
    /// handled concurrently, and a write issued after publishing could overwrite the state that
    /// <c>HandleResponseAsync</c> saved.
    /// </remarks>
    /// <exception cref="InvalidOperationException">No message bus is configured, or the step produced no command.</exception>
    private async Task ExecuteRemoteStepAsync<TData>(
        SagaState state,
        SagaStepDefinition step,
        TData data,
        SagaContext context,
        CancellationToken cancellationToken)
        where TData : class, ISagaData
    {
        if (_messageBus == null)
        {
            throw new InvalidOperationException(
                "Remote saga steps require a message bus. Configure RabbitMQ or another transport.");
        }

        var command = step.GetType().GetProperty("CommandBuilder")?.GetValue(step) is Delegate commandBuilder
            ? InvokeBuilder(commandBuilder, data, context)
            : null;

        if (command == null)
        {
            throw new InvalidOperationException($"Step {step.Name} did not produce a command.");
        }

        var runtimeCommandType = command.GetType();
        var commandType = runtimeCommandType.FullName!;

        var message = new SagaMessage
        {
            SagaId = state.SagaId,
            CorrelationId = state.CorrelationId,
            StepName = step.Name,
            CommandType = commandType,
            Payload = JsonSerializer.Serialize(command, runtimeCommandType)
        };

        await _stateStore.UpdateAsync(state, cancellationToken);
        await _messageBus.PublishAsync(message, cancellationToken);

        _logger.LogDebug(
            "Published command {CommandType} for step {StepName} of saga {SagaId}",
            commandType, step.Name, state.SagaId);
    }

    /// <summary>
    /// Records a step failure, then compensates or fails the saga.
    /// </summary>
    /// <remarks>
    /// The failure is added to <c>state.Errors</c>. When <c>AutoCompensateOnFailure</c> is set the saga is
    /// compensated; otherwise it is marked failed and persisted.
    /// </remarks>
    private async Task FailStepAsync<TData>(
        SagaState state,
        System.Collections.Generic.IReadOnlyList<SagaStepDefinition> steps,
        TData data,
        string stepName,
        string errorMessage,
        string? stackTrace,
        CancellationToken cancellationToken)
        where TData : class, ISagaData
    {
        state.Errors.Add(new SagaStepError(
            stepName,
            errorMessage,
            stackTrace,
            DateTimeOffset.UtcNow));

        if (_options.AutoCompensateOnFailure)
        {
            await StartCompensationAsync(state, steps, data, cancellationToken);
        }
        else
        {
            state.Status = SagaStatus.Failed;
            await _stateStore.UpdateAsync(state, cancellationToken);
        }
    }

    /// <summary>
    /// Marks the saga completed, persists it and logs the completion.
    /// </summary>
    private async Task CompleteSagaAsync(SagaState state, CancellationToken cancellationToken)
    {
        state.Status = SagaStatus.Completed;
        state.CompletedAt = DateTimeOffset.UtcNow;
        await _stateStore.UpdateAsync(state, cancellationToken);

        _logger.LogInformation("Saga {SagaId} completed successfully", state.SagaId);
    }

    /// <summary>
    /// Compensates completed steps in reverse completion order.
    /// </summary>
    /// <remarks>
    /// Compensation is best-effort. A failing compensation is logged and recorded in <c>state.Errors</c>,
    /// and the remaining compensations still run. The saga is then marked compensated and persisted.
    /// Remote compensations are only published; this method does not wait for their outcome.
    /// </remarks>
    private async Task StartCompensationAsync<TData>(
        SagaState state,
        System.Collections.Generic.IReadOnlyList<SagaStepDefinition> steps,
        TData data,
        CancellationToken cancellationToken)
        where TData : class, ISagaData
    {
        _logger.LogInformation("Starting compensation for saga {SagaId}", state.SagaId);

        state.Status = SagaStatus.Compensating;
        await _stateStore.UpdateAsync(state, cancellationToken);

        var context = new SagaContext(state);

        // Snapshot so that adding compensation errors to the state cannot interfere with iteration.
        var completedSteps = state.CompletedSteps.ToList();

        for (var i = completedSteps.Count - 1; i >= 0; i--)
        {
            var stepName = completedSteps[i];
            var step = steps.FirstOrDefault(s => s.Name == stepName);

            if (step == null || !step.HasCompensation)
            {
                continue;
            }

            _logger.LogDebug("Compensating step {StepName} for saga {SagaId}", stepName, state.SagaId);

            try
            {
                if (step.IsRemote)
                {
                    await ExecuteRemoteCompensationAsync(state, step, data, context, cancellationToken);
                }
                else if (step is LocalSagaStepDefinition<TData> localStep && localStep.CompensateAction != null)
                {
                    await localStep.CompensateAction(data, context, cancellationToken);
                }
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Error during compensation of step {StepName} for saga {SagaId}",
                    stepName, state.SagaId);

                // Continue with other compensations even if one fails, but keep a durable record of it.
                state.Errors.Add(new SagaStepError(
                    stepName,
                    $"Compensation failed: {ex.Message}",
                    ex.StackTrace,
                    DateTimeOffset.UtcNow));
            }
        }

        state.Status = SagaStatus.Compensated;
        state.CompletedAt = DateTimeOffset.UtcNow;
        await _stateStore.UpdateAsync(state, cancellationToken);

        _logger.LogInformation("Saga {SagaId} compensation completed", state.SagaId);
    }

    /// <summary>
    /// Publishes a remote step's compensation command.
    /// </summary>
    /// <remarks>
    /// The command is built through the step's <c>CompensationBuilder</c> delegate and typed with its
    /// <c>CompensationCommandType</c>, both obtained by reflection. If no message bus is configured, a
    /// warning is logged and nothing is published. If the step does not expose both properties, or the
    /// builder returns <see langword="null"/>, nothing is published.
    /// </remarks>
    private async Task ExecuteRemoteCompensationAsync<TData>(
        SagaState state,
        SagaStepDefinition step,
        TData data,
        SagaContext context,
        CancellationToken cancellationToken)
        where TData : class, ISagaData
    {
        if (_messageBus == null)
        {
            _logger.LogWarning(
                "Cannot compensate remote step {StepName} of saga {SagaId}: no message bus is configured",
                step.Name, state.SagaId);
            return;
        }

        var stepType = step.GetType();
        if (stepType.GetProperty("CompensationBuilder")?.GetValue(step) is not Delegate compensationBuilder ||
            stepType.GetProperty("CompensationCommandType")?.GetValue(step) is not Type compensationType)
        {
            return;
        }

        var compensationCommand = InvokeBuilder(compensationBuilder, data, context);
        if (compensationCommand == null)
        {
            return;
        }

        var message = new SagaMessage
        {
            SagaId = state.SagaId,
            CorrelationId = state.CorrelationId,
            StepName = step.Name,
            CommandType = compensationType.FullName!,
            Payload = JsonSerializer.Serialize(compensationCommand, compensationType),
            IsCompensation = true
        };

        await _messageBus.PublishAsync(message, cancellationToken);

        _logger.LogDebug(
            "Published compensation command {CommandType} for step {StepName} of saga {SagaId}",
            compensationType.Name, step.Name, state.SagaId);
    }

    /// <summary>
    /// Invokes a reflectively obtained builder delegate.
    /// </summary>
    /// <remarks>
    /// <see cref="Delegate.DynamicInvoke"/> wraps exceptions thrown by the builder in a
    /// <see cref="System.Reflection.TargetInvocationException"/>. This helper rethrows the original
    /// exception with its stack trace preserved, so the recorded step error carries the real message
    /// rather than the generic wrapper message.
    /// </remarks>
    private static object? InvokeBuilder(Delegate builder, params object?[] args)
    {
        try
        {
            return builder.DynamicInvoke(args);
        }
        catch (System.Reflection.TargetInvocationException ex) when (ex.InnerException != null)
        {
            System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
            throw; // Unreachable; required because the compiler does not treat Throw() as terminating.
        }
    }

    /// <summary>
    /// Resolves a persisted saga type name.
    /// </summary>
    /// <remarks>
    /// <see cref="Type.GetType(string)"/> only finds namespace-qualified names in the calling assembly and
    /// the core library. Saga types declared in application assemblies, which are persisted as
    /// <see cref="Type.FullName"/>, are therefore also searched for in every loaded assembly.
    /// </remarks>
    /// <returns>The resolved type, or <see langword="null"/> when no loaded assembly defines it.</returns>
    private static Type? ResolveSagaType(string typeName)
    {
        var type = Type.GetType(typeName, throwOnError: false);
        if (type != null)
        {
            return type;
        }

        foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
        {
            type = assembly.GetType(typeName, throwOnError: false);
            if (type != null)
            {
                return type;
            }
        }

        return null;
    }
}