using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using RabbitMQ.Client;
using RabbitMQ.Client.Events;
using Svrnty.CQRS.Sagas.Abstractions.Messaging;

namespace Svrnty.CQRS.Sagas.RabbitMQ;

/// <summary>
/// RabbitMQ implementation of the saga message bus.
/// </summary>
/// <remarks>
/// Commands are published to <see cref="RabbitMqSagaOptions.CommandExchange"/> with routing key
/// <c>saga.command.{CommandType}</c>; responses are published to
/// <see cref="RabbitMqSagaOptions.ResponseExchange"/> with routing key <c>saga.response.{SagaId}</c>.
/// Both exchanges are declared as topic exchanges. A single channel is shared for publishing and
/// every subscription gets its own channel.
/// </remarks>
public class RabbitMqSagaMessageBus : ISagaMessageBus, IAsyncDisposable
{
    private readonly RabbitMqSagaOptions _options;
    private readonly ILogger<RabbitMqSagaMessageBus> _logger;

    // Every channel opened by a Subscribe* call. A bag (rather than a map keyed by type name) so that
    // repeated subscriptions for the same type are still closed by DisposeAsync instead of leaking.
    private readonly ConcurrentBag<IChannel> _subscriptionChannels = new();

    // Connections that were found closed and replaced. They are not torn down at replacement time
    // because subscription channels opened on them may still be attached; they are disposed with the bus.
    // Only mutated while holding _connectionLock.
    private readonly List<IConnection> _retiredConnections = new();

    private readonly SemaphoreSlim _connectionLock = new(1, 1);

    // Cancelled on dispose. Message handlers receive this token rather than the token passed to the
    // Subscribe* call, which only governs the subscription setup itself.
    private readonly CancellationTokenSource _disposeCts = new();

    private IConnection? _connection;
    private IChannel? _publishChannel;
    private bool _disposed;

    /// <summary>
    /// Creates a new RabbitMQ saga message bus.
    /// </summary>
    /// <param name="options">Connection, exchange and queue settings.</param>
    /// <param name="logger">Logger for connection and message-processing diagnostics.</param>
    public RabbitMqSagaMessageBus(
        IOptions<RabbitMqSagaOptions> options,
        ILogger<RabbitMqSagaMessageBus> logger)
    {
        _options = options.Value;
        _logger = logger;
    }

    /// <summary>True when both the connection and the shared publish channel are open.</summary>
    private bool IsReady => _connection is { IsOpen: true } && _publishChannel is { IsOpen: true };

    /// <inheritdoc />
    public async Task PublishAsync(SagaMessage message, CancellationToken cancellationToken = default)
    {
        if (message is null)
        {
            throw new ArgumentNullException(nameof(message));
        }

        await EnsureConnectionAsync(cancellationToken);

        var routingKey = $"saga.command.{message.CommandType}";
        var body = JsonSerializer.SerializeToUtf8Bytes(message);
        var properties = CreateProperties(
            message.MessageId.ToString(),
            message.CorrelationId.ToString(),
            new AmqpTimestamp(message.Timestamp.ToUnixTimeSeconds()),
            new Dictionary<string, object?>
            {
                ["saga-id"] = message.SagaId.ToString(),
                ["step-name"] = message.StepName,
                ["is-compensation"] = message.IsCompensation.ToString()
            });

        await _publishChannel!.BasicPublishAsync(
            exchange: _options.CommandExchange,
            routingKey: routingKey,
            mandatory: false,
            basicProperties: properties,
            body: body,
            cancellationToken: cancellationToken);

        _logger.LogDebug(
            "Published saga command {CommandType} for saga {SagaId}, step {StepName}",
            message.CommandType, message.SagaId, message.StepName);
    }

    /// <inheritdoc />
    public async Task PublishResponseAsync(SagaStepResponse response, CancellationToken cancellationToken = default)
    {
        if (response is null)
        {
            throw new ArgumentNullException(nameof(response));
        }

        await EnsureConnectionAsync(cancellationToken);

        var routingKey = $"saga.response.{response.SagaId}";
        var body = JsonSerializer.SerializeToUtf8Bytes(response);
        var properties = CreateProperties(
            response.MessageId.ToString(),
            response.CorrelationId.ToString(),
            new AmqpTimestamp(response.Timestamp.ToUnixTimeSeconds()),
            new Dictionary<string, object?>
            {
                ["saga-id"] = response.SagaId.ToString(),
                ["step-name"] = response.StepName,
                ["success"] = response.Success.ToString()
            });

        await _publishChannel!.BasicPublishAsync(
            exchange: _options.ResponseExchange,
            routingKey: routingKey,
            mandatory: false,
            basicProperties: properties,
            body: body,
            cancellationToken: cancellationToken);

        _logger.LogDebug(
            "Published saga response for saga {SagaId}, step {StepName}, success: {Success}",
            response.SagaId, response.StepName, response.Success);
    }

    /// <inheritdoc />
    /// <remarks>
    /// Deliveries whose body or payload cannot be deserialized are rejected without requeue, because
    /// redelivering them would fail identically forever. Deliveries whose handler or response
    /// publication fails are requeued. The handler receives a token that is cancelled when this bus is
    /// disposed; <paramref name="cancellationToken"/> only applies to setting up the subscription.
    /// </remarks>
    public async Task SubscribeAsync<TCommand>(
        Func<SagaMessage, TCommand, CancellationToken, Task<SagaStepResponse>> handler,
        CancellationToken cancellationToken = default)
        where TCommand : class
    {
        if (handler is null)
        {
            throw new ArgumentNullException(nameof(handler));
        }

        await EnsureConnectionAsync(cancellationToken);

        var commandTypeName = typeof(TCommand).FullName!;
        var queueName = $"{_options.QueuePrefix}.{SanitizeQueueName(commandTypeName)}";
        var routingKey = $"saga.command.{commandTypeName}";

        var channel = await OpenSubscriptionChannelAsync(
            queueName, _options.CommandExchange, routingKey, cancellationToken);
        var processingToken = _disposeCts.Token;

        var consumer = new AsyncEventingBasicConsumer(channel);
        consumer.ReceivedAsync += async (_, ea) =>
        {
            // Separates malformed deliveries (never retried) from failures inside the handler (retried).
            var handlerStarted = false;
            try
            {
                var message = JsonSerializer.Deserialize<SagaMessage>(ea.Body.Span);
                if (message is null)
                {
                    _logger.LogWarning("Received null saga message");
                    await TryNackAsync(channel, ea.DeliveryTag, requeue: false);
                    return;
                }

                var command = message.Payload is null
                    ? null
                    : JsonSerializer.Deserialize<TCommand>(message.Payload);
                if (command is null)
                {
                    _logger.LogWarning("Failed to deserialize command {CommandType}", commandTypeName);
                    await TryNackAsync(channel, ea.DeliveryTag, requeue: false);
                    return;
                }

                handlerStarted = true;
                var response = await handler(message, command, processingToken);
                await PublishResponseAsync(response, processingToken);

                // The work is complete, so acknowledge it even if the bus is being disposed.
                await channel.BasicAckAsync(ea.DeliveryTag, false, CancellationToken.None);
            }
            catch (JsonException ex) when (!handlerStarted)
            {
                _logger.LogError(ex, "Discarding malformed saga command {CommandType}", commandTypeName);
                await TryNackAsync(channel, ea.DeliveryTag, requeue: false);
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Error processing saga command {CommandType}", commandTypeName);
                await TryNackAsync(channel, ea.DeliveryTag, requeue: true);
            }
        };

        await channel.BasicConsumeAsync(queueName, false, consumer, cancellationToken);

        _logger.LogInformation(
            "Subscribed to saga commands of type {CommandType} on queue {QueueName}",
            commandTypeName, queueName);
    }

    /// <inheritdoc />
    /// <remarks>
    /// Deliveries that cannot be deserialized are rejected without requeue; deliveries whose handler
    /// fails are requeued. The handler receives a token that is cancelled when this bus is disposed;
    /// <paramref name="cancellationToken"/> only applies to setting up the subscription.
    /// </remarks>
    public async Task SubscribeToResponsesAsync(
        Func<SagaStepResponse, CancellationToken, Task> handler,
        CancellationToken cancellationToken = default)
    {
        if (handler is null)
        {
            throw new ArgumentNullException(nameof(handler));
        }

        await EnsureConnectionAsync(cancellationToken);

        var queueName = $"{_options.QueuePrefix}.responses";
        var routingKey = "saga.response.#";

        var channel = await OpenSubscriptionChannelAsync(
            queueName, _options.ResponseExchange, routingKey, cancellationToken);
        var processingToken = _disposeCts.Token;

        var consumer = new AsyncEventingBasicConsumer(channel);
        consumer.ReceivedAsync += async (_, ea) =>
        {
            // Separates malformed deliveries (never retried) from failures inside the handler (retried).
            var handlerStarted = false;
            try
            {
                var response = JsonSerializer.Deserialize<SagaStepResponse>(ea.Body.Span);
                if (response is null)
                {
                    _logger.LogWarning("Received null saga response");
                    await TryNackAsync(channel, ea.DeliveryTag, requeue: false);
                    return;
                }

                handlerStarted = true;
                await handler(response, processingToken);

                // The work is complete, so acknowledge it even if the bus is being disposed.
                await channel.BasicAckAsync(ea.DeliveryTag, false, CancellationToken.None);
            }
            catch (JsonException ex) when (!handlerStarted)
            {
                _logger.LogError(ex, "Discarding malformed saga response");
                await TryNackAsync(channel, ea.DeliveryTag, requeue: false);
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Error processing saga response");
                await TryNackAsync(channel, ea.DeliveryTag, requeue: true);
            }
        };

        await channel.BasicConsumeAsync(queueName, false, consumer, cancellationToken);

        _logger.LogInformation("Subscribed to saga responses on queue {QueueName}", queueName);
    }

    /// <summary>
    /// Ensures an open connection and an open publish channel exist, creating whichever is missing.
    /// The command and response exchanges are declared on every newly created publish channel.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The bus has been disposed.</exception>
    private async Task EnsureConnectionAsync(CancellationToken cancellationToken)
    {
        ThrowIfDisposed();

        // Fast path without taking the lock.
        if (IsReady)
        {
            return;
        }

        await _connectionLock.WaitAsync(cancellationToken);
        try
        {
            // Re-check: another caller may have reconnected while this one waited for the lock.
            if (IsReady)
            {
                return;
            }

            var connectionCreated = false;
            if (_connection is not { IsOpen: true })
            {
                if (_connection is not null)
                {
                    _retiredConnections.Add(_connection);
                }

                var factory = new ConnectionFactory
                {
                    HostName = _options.HostName,
                    Port = _options.Port,
                    UserName = _options.UserName,
                    Password = _options.Password,
                    VirtualHost = _options.VirtualHost
                };

                _connection = await factory.CreateConnectionAsync(cancellationToken);
                connectionCreated = true;
            }

            // Either the connection was replaced or only the publish channel died (e.g. after a
            // channel-level error). In both cases reuse the open connection and just replace the channel.
            if (_publishChannel is not null)
            {
                await CloseAndDisposeChannelAsync(_publishChannel);
                _publishChannel = null;
            }

            _publishChannel = await _connection.CreateChannelAsync(cancellationToken: cancellationToken);

            await _publishChannel.ExchangeDeclareAsync(
                exchange: _options.CommandExchange,
                type: ExchangeType.Topic,
                durable: _options.DurableQueues,
                autoDelete: false,
                cancellationToken: cancellationToken);

            await _publishChannel.ExchangeDeclareAsync(
                exchange: _options.ResponseExchange,
                type: ExchangeType.Topic,
                durable: _options.DurableQueues,
                autoDelete: false,
                cancellationToken: cancellationToken);

            if (connectionCreated)
            {
                _logger.LogInformation(
                    "Connected to RabbitMQ at {Host}:{Port}",
                    _options.HostName, _options.Port);
            }
        }
        finally
        {
            _connectionLock.Release();
        }
    }

    /// <summary>
    /// Opens a dedicated consumer channel, declares <paramref name="queueName"/>, binds it to
    /// <paramref name="exchange"/> with <paramref name="routingKey"/>, and applies the configured prefetch.
    /// </summary>
    private async Task<IChannel> OpenSubscriptionChannelAsync(
        string queueName,
        string exchange,
        string routingKey,
        CancellationToken cancellationToken)
    {
        var channel = await _connection!.CreateChannelAsync(cancellationToken: cancellationToken);

        // Track immediately so DisposeAsync closes the channel even if the setup below fails.
        _subscriptionChannels.Add(channel);

        await channel.QueueDeclareAsync(
            queue: queueName,
            durable: _options.DurableQueues,
            exclusive: false,
            autoDelete: false,
            cancellationToken: cancellationToken);

        await channel.QueueBindAsync(
            queue: queueName,
            exchange: exchange,
            routingKey: routingKey,
            cancellationToken: cancellationToken);

        await channel.BasicQosAsync(
            prefetchSize: 0,
            prefetchCount: _options.PrefetchCount,
            global: false,
            cancellationToken: cancellationToken);

        return channel;
    }

    /// <summary>
    /// Builds the AMQP properties shared by commands and responses: JSON content type, delivery mode
    /// derived from <see cref="RabbitMqSagaOptions.DurableQueues"/>, and the given ids, timestamp and headers.
    /// </summary>
    private BasicProperties CreateProperties(
        string messageId,
        string correlationId,
        AmqpTimestamp timestamp,
        Dictionary<string, object?> headers) => new()
    {
        MessageId = messageId,
        CorrelationId = correlationId,
        ContentType = "application/json",
        DeliveryMode = _options.DurableQueues ? DeliveryModes.Persistent : DeliveryModes.Transient,
        Timestamp = timestamp,
        Headers = headers
    };

    /// <summary>
    /// Negatively acknowledges a delivery. Failures (e.g. the channel already closed) are logged rather
    /// than thrown, so they never escape a consumer callback; an unacknowledged delivery is redelivered
    /// by the broker once its channel closes.
    /// </summary>
    private async Task TryNackAsync(IChannel channel, ulong deliveryTag, bool requeue)
    {
        try
        {
            await channel.BasicNackAsync(deliveryTag, false, requeue, CancellationToken.None);
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "Failed to nack delivery {DeliveryTag}", deliveryTag);
        }
    }

    /// <summary>Closes (if open) and disposes a channel, logging instead of throwing on failure.</summary>
    private async Task CloseAndDisposeChannelAsync(IChannel channel)
    {
        try
        {
            try
            {
                if (channel.IsOpen)
                {
                    await channel.CloseAsync();
                }
            }
            finally
            {
                channel.Dispose();
            }
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "Error while closing RabbitMQ channel");
        }
    }

    /// <summary>Closes (if open) and disposes a connection, logging instead of throwing on failure.</summary>
    private async Task CloseAndDisposeConnectionAsync(IConnection connection)
    {
        try
        {
            try
            {
                if (connection.IsOpen)
                {
                    await connection.CloseAsync();
                }
            }
            finally
            {
                connection.Dispose();
            }
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "Error while closing RabbitMQ connection");
        }
    }

    /// <summary>Throws <see cref="ObjectDisposedException"/> once <see cref="DisposeAsync"/> has run.</summary>
    private void ThrowIfDisposed()
    {
        if (_disposed)
        {
            throw new ObjectDisposedException(GetType().FullName);
        }
    }

    /// <summary>
    /// Turns a type name into a queue-name segment: '.' and '+' become '-', and the result is lower-cased.
    /// </summary>
    private static string SanitizeQueueName(string name)
    {
        return name.Replace(".", "-").Replace("+", "-").ToLowerInvariant();
    }

    /// <inheritdoc />
    public async ValueTask DisposeAsync()
    {
        if (_disposed)
        {
            return;
        }

        _disposed = true;

        // Signal in-flight message handlers before their channels are closed underneath them.
        _disposeCts.Cancel();

        // Each close is isolated so one failure cannot leak the remaining resources.
        foreach (var channel in _subscriptionChannels)
        {
            await CloseAndDisposeChannelAsync(channel);
        }

        if (_publishChannel is not null)
        {
            await CloseAndDisposeChannelAsync(_publishChannel);
        }

        foreach (var retired in _retiredConnections)
        {
            await CloseAndDisposeConnectionAsync(retired);
        }

        if (_connection is not null)
        {
            await CloseAndDisposeConnectionAsync(_connection);
        }

        _connectionLock.Dispose();
        _disposeCts.Dispose();
        GC.SuppressFinalize(this);
    }
}