diff --git a/SAGAS_ROADMAP.md b/SAGAS_ROADMAP.md new file mode 100644 index 0000000..b9ce492 --- /dev/null +++ b/SAGAS_ROADMAP.md @@ -0,0 +1,122 @@ +# Saga Orchestration Roadmap + +## Completed (Phase 1) + +- [x] `Svrnty.CQRS.Sagas.Abstractions` - Core interfaces and contracts +- [x] `Svrnty.CQRS.Sagas` - Orchestration engine with fluent builder API +- [x] `Svrnty.CQRS.Sagas.RabbitMQ` - RabbitMQ message transport + +--- + +## Phase 1d: Testing & Sample + +### Unit Tests +- [ ] `SagaBuilder` step configuration tests +- [ ] `SagaOrchestrator` execution flow tests +- [ ] `SagaOrchestrator` compensation flow tests +- [ ] `InMemorySagaStateStore` persistence tests +- [ ] `RabbitMqSagaMessageBus` serialization tests + +### Integration Tests +- [ ] End-to-end saga execution with RabbitMQ +- [ ] Multi-step saga with compensation scenario +- [ ] Concurrent saga execution tests +- [ ] Connection recovery tests + +### Sample Implementation +- [ ] `OrderProcessingSaga` example in WarehouseManagement + - ReserveInventory step + - ProcessPayment step + - CreateShipment step + - Full compensation flow + +--- + +## Phase 2: Persistence + +### Svrnty.CQRS.Sagas.EntityFramework +- [ ] `EfCoreSagaStateStore` implementation +- [ ] `SagaState` entity configuration +- [ ] Migration support +- [ ] PostgreSQL/SQL Server compatibility +- [ ] Optimistic concurrency handling + +### Configuration +```csharp +cqrs.AddSagas() + .UseEntityFramework(); +``` + +--- + +## Phase 3: Reliability + +### Saga Timeout Service +- [ ] `SagaTimeoutHostedService` - background service for stalled sagas +- [ ] Configurable timeout per saga type +- [ ] Automatic compensation trigger on timeout +- [ ] Dead letter handling for failed compensations + +### Retry Policies +- [ ] Exponential backoff support +- [ ] Circuit breaker integration +- [ ] Polly integration option + +### Idempotency +- [ ] Message deduplication +- [ ] Idempotent step execution +- [ ] Inbox/Outbox pattern support + +--- + +## Phase 4: Observability + +### OpenTelemetry Integration +- [ ] Distributed tracing for saga execution +- [ ] Span per saga step +- [ ] Correlation ID propagation +- [ ] Metrics (saga duration, success/failure rates) + +### Saga Dashboard (Optional) +- [ ] Web UI for saga monitoring +- [ ] Real-time saga status +- [ ] Manual compensation trigger +- [ ] Saga history and audit log + +--- + +## Phase 5: Flutter Integration + +### gRPC Streaming for Saga Status +- [ ] `ISagaStatusStream` service +- [ ] Real-time saga progress updates +- [ ] Step completion notifications +- [ ] Error/compensation notifications + +### Flutter Client +- [ ] Dart client for saga status streaming +- [ ] Saga progress widget components + +--- + +## Phase 6: Alternative Transports + +### Svrnty.CQRS.Sagas.AzureServiceBus +- [ ] Azure Service Bus message transport +- [ ] Topic/Subscription topology +- [ ] Dead letter queue handling + +### Svrnty.CQRS.Sagas.Kafka +- [ ] Kafka message transport +- [ ] Consumer group management +- [ ] Partition key strategies + +--- + +## Future Considerations + +- **Event Sourcing**: Saga state as event stream +- **Saga Versioning**: Handle saga definition changes gracefully +- **Saga Composition**: Nested/child sagas +- **Saga Scheduling**: Delayed saga start +- **Multi-tenancy**: Tenant-aware saga execution diff --git a/Svrnty.CQRS.DynamicQuery.Abstractions/IQueryableProviderOverride.cs b/Svrnty.CQRS.DynamicQuery.Abstractions/IQueryableProviderOverride.cs new file mode 100644 index 0000000..e3c36e9 --- /dev/null +++ b/Svrnty.CQRS.DynamicQuery.Abstractions/IQueryableProviderOverride.cs @@ -0,0 +1,14 @@ +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace Svrnty.CQRS.DynamicQuery.Abstractions; + +/// +/// Marker interface for custom queryable providers that project entities to DTOs. +/// Extends for semantic clarity in registration. +/// +/// The DTO/Item type returned by the queryable. +public interface IQueryableProviderOverride : IQueryableProvider +{ +} diff --git a/Svrnty.CQRS.DynamicQuery.EntityFramework/DynamicQueryServicesBuilderExtensions.cs b/Svrnty.CQRS.DynamicQuery.EntityFramework/DynamicQueryServicesBuilderExtensions.cs new file mode 100644 index 0000000..1efabad --- /dev/null +++ b/Svrnty.CQRS.DynamicQuery.EntityFramework/DynamicQueryServicesBuilderExtensions.cs @@ -0,0 +1,26 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using PoweredSoft.Data.Core; +using PoweredSoft.Data.EntityFrameworkCore; + +namespace Svrnty.CQRS.DynamicQuery.EntityFramework; + +/// +/// Extensions for configuring DynamicQuery with Entity Framework Core. +/// +public static class DynamicQueryServicesBuilderExtensions +{ + /// + /// Uses Entity Framework Core for async queryable operations. + /// This replaces the default in-memory implementation with EF Core's async support. + /// + /// The DynamicQuery services builder. + /// The builder for chaining. + public static DynamicQueryServicesBuilder UseEntityFramework(this DynamicQueryServicesBuilder builder) + { + // Remove in-memory implementation and add EF Core implementation + builder.Services.RemoveAll(); + builder.Services.AddPoweredSoftEntityFrameworkCoreDataServices(); + return builder; + } +} diff --git a/Svrnty.CQRS.DynamicQuery.EntityFramework/Svrnty.CQRS.DynamicQuery.EntityFramework.csproj b/Svrnty.CQRS.DynamicQuery.EntityFramework/Svrnty.CQRS.DynamicQuery.EntityFramework.csproj new file mode 100644 index 0000000..1cd0ca9 --- /dev/null +++ b/Svrnty.CQRS.DynamicQuery.EntityFramework/Svrnty.CQRS.DynamicQuery.EntityFramework.csproj @@ -0,0 +1,36 @@ + + + net10.0 + false + 14 + enable + + Svrnty + David Lebee, Mathias Beaulieu-Duncan + icon.png + README.md + https://git.openharbor.io/svrnty/dotnet-cqrs + git + true + MIT + + portable + true + true + true + snupkg + + + + + + + + + + + + + + + diff --git a/Svrnty.CQRS.DynamicQuery/DynamicQueryServicesBuilder.cs b/Svrnty.CQRS.DynamicQuery/DynamicQueryServicesBuilder.cs new file mode 100644 index 0000000..8307cee --- /dev/null +++ b/Svrnty.CQRS.DynamicQuery/DynamicQueryServicesBuilder.cs @@ -0,0 +1,19 @@ +using Microsoft.Extensions.DependencyInjection; + +namespace Svrnty.CQRS.DynamicQuery; + +/// +/// Builder for configuring DynamicQuery services. +/// +public class DynamicQueryServicesBuilder +{ + /// + /// The service collection being configured. + /// + public IServiceCollection Services { get; } + + internal DynamicQueryServicesBuilder(IServiceCollection services) + { + Services = services; + } +} diff --git a/Svrnty.CQRS.DynamicQuery/InMemoryAsyncQueryableService.cs b/Svrnty.CQRS.DynamicQuery/InMemoryAsyncQueryableService.cs new file mode 100644 index 0000000..3d614f4 --- /dev/null +++ b/Svrnty.CQRS.DynamicQuery/InMemoryAsyncQueryableService.cs @@ -0,0 +1,78 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; +using PoweredSoft.Data.Core; + +namespace Svrnty.CQRS.DynamicQuery; + +/// +/// In-memory implementation of IAsyncQueryableService. +/// For EF Core projects, use AddDynamicQueryServices().UseEntityFramework() instead. +/// +public class InMemoryAsyncQueryableService : IAsyncQueryableService +{ + public IEnumerable Handlers { get; } = Array.Empty(); + + public Task> ToListAsync(IQueryable queryable, CancellationToken cancellationToken = default) + { + return Task.FromResult(queryable.ToList()); + } + + public Task FirstOrDefaultAsync(IQueryable queryable, CancellationToken cancellationToken = default) + { + return Task.FromResult(queryable.FirstOrDefault()); + } + + public Task FirstOrDefaultAsync(IQueryable queryable, Expression> predicate, CancellationToken cancellationToken = default) + { + return Task.FromResult(queryable.FirstOrDefault(predicate)); + } + + public Task LastOrDefaultAsync(IQueryable queryable, CancellationToken cancellationToken = default) + { + return Task.FromResult(queryable.LastOrDefault()); + } + + public Task LastOrDefaultAsync(IQueryable queryable, Expression> predicate, CancellationToken cancellationToken = default) + { + return Task.FromResult(queryable.LastOrDefault(predicate)); + } + + public Task AnyAsync(IQueryable queryable, CancellationToken cancellationToken = default) + { + return Task.FromResult(queryable.Any()); + } + + public Task AnyAsync(IQueryable queryable, Expression> predicate, CancellationToken cancellationToken = default) + { + return Task.FromResult(queryable.Any(predicate)); + } + + public Task AllAsync(IQueryable queryable, Expression> predicate, CancellationToken cancellationToken = default) + { + return Task.FromResult(queryable.All(predicate)); + } + + public Task CountAsync(IQueryable queryable, CancellationToken cancellationToken = default) + { + return Task.FromResult(queryable.Count()); + } + + public Task LongCountAsync(IQueryable queryable, CancellationToken cancellationToken = default) + { + return Task.FromResult(queryable.LongCount()); + } + + public Task SingleOrDefaultAsync(IQueryable queryable, Expression> predicate, CancellationToken cancellationToken = default) + { + return Task.FromResult(queryable.SingleOrDefault(predicate)); + } + + public IAsyncQueryableHandlerService? GetAsyncQueryableHandler(IQueryable queryable) + { + return null; + } +} diff --git a/Svrnty.CQRS.DynamicQuery/ServiceCollectionExtensions.cs b/Svrnty.CQRS.DynamicQuery/ServiceCollectionExtensions.cs index 50ce55f..687b9a5 100644 --- a/Svrnty.CQRS.DynamicQuery/ServiceCollectionExtensions.cs +++ b/Svrnty.CQRS.DynamicQuery/ServiceCollectionExtensions.cs @@ -1,16 +1,31 @@ using System.Diagnostics.CodeAnalysis; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; +using PoweredSoft.Data.Core; +using PoweredSoft.DynamicQuery; +using PoweredSoft.DynamicQuery.Core; using Svrnty.CQRS.Abstractions; using Svrnty.CQRS.Abstractions.Discovery; using Svrnty.CQRS.DynamicQuery.Abstractions; using Svrnty.CQRS.DynamicQuery.Discover; -using PoweredSoft.DynamicQuery.Core; namespace Svrnty.CQRS.DynamicQuery; public static class ServiceCollectionExtensions { + /// + /// Registers core DynamicQuery services with in-memory async queryable. + /// For EF Core projects, chain with .UseEntityFramework(). + /// + /// The service collection. + /// A builder for further configuration. + public static DynamicQueryServicesBuilder AddDynamicQueryServices(this IServiceCollection services) + { + services.TryAddTransient(); + services.TryAddTransient(); + return new DynamicQueryServicesBuilder(services); + } + public static IServiceCollection AddDynamicQuery(this IServiceCollection services, string name = null) where TSourceAndDestination : class => AddDynamicQuery(services, name: name); @@ -55,6 +70,22 @@ public static class ServiceCollectionExtensions return services; } + /// + /// Registers a custom queryable provider override for the specified source type. + /// Use this for DTOs that require custom projection from entities. + /// + /// The DTO/Item type returned by the queryable. + /// The provider implementation type. + /// The service collection. + /// The service collection for chaining. + public static IServiceCollection AddQueryableProviderOverride(this IServiceCollection services) + where TSource : class + where TProvider : class, IQueryableProviderOverride + { + services.AddTransient, TProvider>(); + return services; + } + public static IServiceCollection AddDynamicQueryWithParams(this IServiceCollection services, string name = null) where TSourceAndDestination : class where TParams : class diff --git a/Svrnty.CQRS.Events.Abstractions/IDomainEvent.cs b/Svrnty.CQRS.Events.Abstractions/IDomainEvent.cs new file mode 100644 index 0000000..279be19 --- /dev/null +++ b/Svrnty.CQRS.Events.Abstractions/IDomainEvent.cs @@ -0,0 +1,17 @@ +namespace Svrnty.CQRS.Events.Abstractions; + +/// +/// Marker interface for domain events. +/// +public interface IDomainEvent +{ + /// + /// Unique identifier for this event instance. + /// + Guid EventId { get; } + + /// + /// Timestamp when the event occurred. + /// + DateTime OccurredAt { get; } +} diff --git a/Svrnty.CQRS.Events.Abstractions/IDomainEventPublisher.cs b/Svrnty.CQRS.Events.Abstractions/IDomainEventPublisher.cs new file mode 100644 index 0000000..f1a14d2 --- /dev/null +++ b/Svrnty.CQRS.Events.Abstractions/IDomainEventPublisher.cs @@ -0,0 +1,16 @@ +namespace Svrnty.CQRS.Events.Abstractions; + +/// +/// Interface for publishing domain events to external systems. +/// +public interface IDomainEventPublisher +{ + /// + /// Publishes a domain event. + /// + /// The type of event to publish. + /// The event to publish. + /// Cancellation token. + Task PublishAsync(TEvent @event, CancellationToken cancellationToken = default) + where TEvent : IDomainEvent; +} diff --git a/Svrnty.CQRS.Events.Abstractions/Svrnty.CQRS.Events.Abstractions.csproj b/Svrnty.CQRS.Events.Abstractions/Svrnty.CQRS.Events.Abstractions.csproj new file mode 100644 index 0000000..658c2f0 --- /dev/null +++ b/Svrnty.CQRS.Events.Abstractions/Svrnty.CQRS.Events.Abstractions.csproj @@ -0,0 +1,29 @@ + + + net10.0 + true + 14 + enable + enable + + Svrnty + David Lebee, Mathias Beaulieu-Duncan + icon.png + README.md + https://git.openharbor.io/svrnty/dotnet-cqrs + git + true + MIT + + portable + true + true + true + snupkg + + + + + + + diff --git a/Svrnty.CQRS.Events.RabbitMQ/RabbitMqDomainEventPublisher.cs b/Svrnty.CQRS.Events.RabbitMQ/RabbitMqDomainEventPublisher.cs new file mode 100644 index 0000000..52f2ec9 --- /dev/null +++ b/Svrnty.CQRS.Events.RabbitMQ/RabbitMqDomainEventPublisher.cs @@ -0,0 +1,163 @@ +using System.Text; +using System.Text.Json; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using RabbitMQ.Client; +using Svrnty.CQRS.Events.Abstractions; + +namespace Svrnty.CQRS.Events.RabbitMQ; + +/// +/// RabbitMQ implementation of the domain event publisher. +/// +public class RabbitMqDomainEventPublisher : IDomainEventPublisher, IAsyncDisposable +{ + private readonly RabbitMqEventOptions _options; + private readonly ILogger _logger; + private IConnection? _connection; + private IChannel? _channel; + private readonly SemaphoreSlim _connectionLock = new(1, 1); + private bool _disposed; + + /// + /// Creates a new RabbitMQ domain event publisher. + /// + public RabbitMqDomainEventPublisher( + IOptions options, + ILogger logger) + { + _options = options.Value; + _logger = logger; + } + + /// + public async Task PublishAsync(TEvent @event, CancellationToken cancellationToken = default) + where TEvent : IDomainEvent + { + await EnsureConnectionAsync(cancellationToken); + + var eventTypeName = typeof(TEvent).Name; + var routingKey = GetRoutingKey(eventTypeName); + var body = JsonSerializer.SerializeToUtf8Bytes(@event); + + var properties = new BasicProperties + { + MessageId = @event.EventId.ToString(), + ContentType = "application/json", + DeliveryMode = _options.Durable ? DeliveryModes.Persistent : DeliveryModes.Transient, + Timestamp = new AmqpTimestamp(new DateTimeOffset(@event.OccurredAt).ToUnixTimeSeconds()), + Headers = new Dictionary + { + ["event-type"] = eventTypeName, + ["event-id"] = @event.EventId.ToString() + } + }; + + await _channel!.BasicPublishAsync( + exchange: _options.Exchange, + routingKey: routingKey, + mandatory: false, + basicProperties: properties, + body: body, + cancellationToken: cancellationToken); + + _logger.LogDebug( + "Published domain event {EventType} with ID {EventId} to routing key {RoutingKey}", + eventTypeName, @event.EventId, routingKey); + } + + private static string GetRoutingKey(string eventTypeName) + { + // Convert PascalCase to dot-notation, e.g., "InventoryMovementEvent" -> "events.inventory.movement" + var name = eventTypeName.Replace("Event", ""); + var words = new List(); + var currentWord = new StringBuilder(); + + foreach (var c in name) + { + if (char.IsUpper(c) && currentWord.Length > 0) + { + words.Add(currentWord.ToString().ToLowerInvariant()); + currentWord.Clear(); + } + currentWord.Append(c); + } + + if (currentWord.Length > 0) + { + words.Add(currentWord.ToString().ToLowerInvariant()); + } + + return "events." + string.Join(".", words); + } + + private async Task EnsureConnectionAsync(CancellationToken cancellationToken) + { + if (_connection?.IsOpen == true && _channel?.IsOpen == true) + { + return; + } + + await _connectionLock.WaitAsync(cancellationToken); + try + { + if (_connection?.IsOpen == true && _channel?.IsOpen == true) + { + return; + } + + var factory = new ConnectionFactory + { + HostName = _options.HostName, + Port = _options.Port, + UserName = _options.UserName, + Password = _options.Password, + VirtualHost = _options.VirtualHost + }; + + _connection = await factory.CreateConnectionAsync(cancellationToken); + _channel = await _connection.CreateChannelAsync(cancellationToken: cancellationToken); + + // Declare topic exchange for domain events + await _channel.ExchangeDeclareAsync( + exchange: _options.Exchange, + type: ExchangeType.Topic, + durable: _options.Durable, + autoDelete: false, + cancellationToken: cancellationToken); + + _logger.LogInformation( + "Connected to RabbitMQ at {Host}:{Port}, exchange: {Exchange}", + _options.HostName, _options.Port, _options.Exchange); + } + finally + { + _connectionLock.Release(); + } + } + + /// + public async ValueTask DisposeAsync() + { + if (_disposed) + { + return; + } + + _disposed = true; + + if (_channel?.IsOpen == true) + { + await _channel.CloseAsync(); + } + _channel?.Dispose(); + + if (_connection?.IsOpen == true) + { + await _connection.CloseAsync(); + } + _connection?.Dispose(); + + _connectionLock.Dispose(); + } +} diff --git a/Svrnty.CQRS.Events.RabbitMQ/RabbitMqEventOptions.cs b/Svrnty.CQRS.Events.RabbitMQ/RabbitMqEventOptions.cs new file mode 100644 index 0000000..40b6c9e --- /dev/null +++ b/Svrnty.CQRS.Events.RabbitMQ/RabbitMqEventOptions.cs @@ -0,0 +1,42 @@ +namespace Svrnty.CQRS.Events.RabbitMQ; + +/// +/// Configuration options for RabbitMQ domain event publishing. +/// +public class RabbitMqEventOptions +{ + /// + /// RabbitMQ host name. Default: localhost + /// + public string HostName { get; set; } = "localhost"; + + /// + /// RabbitMQ port. Default: 5672 + /// + public int Port { get; set; } = 5672; + + /// + /// RabbitMQ username. Default: guest + /// + public string UserName { get; set; } = "guest"; + + /// + /// RabbitMQ password. Default: guest + /// + public string Password { get; set; } = "guest"; + + /// + /// RabbitMQ virtual host. Default: / + /// + public string VirtualHost { get; set; } = "/"; + + /// + /// Exchange name for domain events. Default: domain.events + /// + public string Exchange { get; set; } = "domain.events"; + + /// + /// Whether to use durable exchanges. Default: true + /// + public bool Durable { get; set; } = true; +} diff --git a/Svrnty.CQRS.Events.RabbitMQ/ServiceCollectionExtensions.cs b/Svrnty.CQRS.Events.RabbitMQ/ServiceCollectionExtensions.cs new file mode 100644 index 0000000..d4cb46b --- /dev/null +++ b/Svrnty.CQRS.Events.RabbitMQ/ServiceCollectionExtensions.cs @@ -0,0 +1,30 @@ +using Microsoft.Extensions.DependencyInjection; +using Svrnty.CQRS.Events.Abstractions; + +namespace Svrnty.CQRS.Events.RabbitMQ; + +/// +/// Extension methods for registering RabbitMQ domain event publishing. +/// +public static class ServiceCollectionExtensions +{ + /// + /// Adds RabbitMQ domain event publishing to the service collection. + /// + /// The service collection. + /// Optional configuration action for RabbitMQ options. + /// The service collection for chaining. + public static IServiceCollection AddRabbitMqDomainEvents( + this IServiceCollection services, + Action? configure = null) + { + if (configure != null) + { + services.Configure(configure); + } + + services.AddSingleton(); + + return services; + } +} diff --git a/Svrnty.CQRS.Events.RabbitMQ/Svrnty.CQRS.Events.RabbitMQ.csproj b/Svrnty.CQRS.Events.RabbitMQ/Svrnty.CQRS.Events.RabbitMQ.csproj new file mode 100644 index 0000000..dda6ea8 --- /dev/null +++ b/Svrnty.CQRS.Events.RabbitMQ/Svrnty.CQRS.Events.RabbitMQ.csproj @@ -0,0 +1,40 @@ + + + net10.0 + false + 14 + enable + enable + + Svrnty + David Lebee, Mathias Beaulieu-Duncan + icon.png + README.md + https://git.openharbor.io/svrnty/dotnet-cqrs + git + true + MIT + + portable + true + true + true + snupkg + + + + + + + + + + + + + + + + + + diff --git a/Svrnty.CQRS.Grpc.Generators/GrpcGenerator.cs b/Svrnty.CQRS.Grpc.Generators/GrpcGenerator.cs index f01cabb..3015b1c 100644 --- a/Svrnty.CQRS.Grpc.Generators/GrpcGenerator.cs +++ b/Svrnty.CQRS.Grpc.Generators/GrpcGenerator.cs @@ -13,7 +13,7 @@ namespace Svrnty.CQRS.Grpc.Generators { public void Initialize(IncrementalGeneratorInitializationContext context) { - // Find all types that might be commands or queries + // Find all types that might be commands or queries from source var typeDeclarations = context.SyntaxProvider .CreateSyntaxProvider( predicate: static (node, _) => node is TypeDeclarationSyntax, @@ -34,8 +34,81 @@ namespace Svrnty.CQRS.Grpc.Generators return symbol as INamedTypeSymbol; } - private static void Execute(Compilation compilation, IEnumerable types, SourceProductionContext context) + /// + /// Collects all types from the compilation and all referenced assemblies + /// + private static IEnumerable GetAllTypesFromCompilation(Compilation compilation) { + var types = new List(); + + // Get types from the current assembly + CollectTypesFromNamespace(compilation.Assembly.GlobalNamespace, types); + + // Get types from all referenced assemblies + foreach (var reference in compilation.References) + { + var assemblySymbol = compilation.GetAssemblyOrModuleSymbol(reference) as IAssemblySymbol; + if (assemblySymbol != null) + { + CollectTypesFromNamespace(assemblySymbol.GlobalNamespace, types); + } + } + + return types; + } + + private static void CollectTypesFromNamespace(INamespaceSymbol ns, List types) + { + foreach (var type in ns.GetTypeMembers()) + { + types.Add(type); + // Also collect nested types + CollectNestedTypes(type, types); + } + + foreach (var nestedNs in ns.GetNamespaceMembers()) + { + CollectTypesFromNamespace(nestedNs, types); + } + } + + private static void CollectNestedTypes(INamedTypeSymbol type, List types) + { + foreach (var nestedType in type.GetTypeMembers()) + { + types.Add(nestedType); + CollectNestedTypes(nestedType, types); + } + } + + private static void Execute(Compilation compilation, IEnumerable sourceTypes, SourceProductionContext context) + { + // Get the expected namespace for proto-generated types + var rootNamespace = compilation.AssemblyName ?? "Generated"; + var grpcNamespace = $"{rootNamespace}.Grpc"; + + // Check if proto types are available (from Grpc.Tools compilation of .proto file) + // If not, skip generation - this happens on first build before proto file is compiled + var commandServiceBase = compilation.GetTypeByMetadataName($"{grpcNamespace}.CommandService+CommandServiceBase"); + var queryServiceBase = compilation.GetTypeByMetadataName($"{grpcNamespace}.QueryService+QueryServiceBase"); + var dynamicQueryServiceBase = compilation.GetTypeByMetadataName($"{grpcNamespace}.DynamicQueryService+DynamicQueryServiceBase"); + var notificationServiceBase = compilation.GetTypeByMetadataName($"{grpcNamespace}.NotificationService+NotificationServiceBase"); + + // If none of the service bases exist, the proto hasn't been compiled yet - skip generation + if (commandServiceBase == null && queryServiceBase == null && dynamicQueryServiceBase == null && notificationServiceBase == null) + { + // Report diagnostic for first build + var descriptor = new DiagnosticDescriptor( + "CQRSGRPC003", + "Proto types not yet available", + "gRPC service implementations will be generated on second build after proto file is compiled", + "Svrnty.CQRS.Grpc", + DiagnosticSeverity.Info, + isEnabledByDefault: true); + context.ReportDiagnostic(Diagnostic.Create(descriptor, Location.None)); + return; + } + var grpcIgnoreAttribute = compilation.GetTypeByMetadataName("Svrnty.CQRS.Grpc.Abstractions.Attributes.GrpcIgnoreAttribute"); var commandHandlerInterface = compilation.GetTypeByMetadataName("Svrnty.CQRS.Abstractions.ICommandHandler`1"); var commandHandlerWithResultInterface = compilation.GetTypeByMetadataName("Svrnty.CQRS.Abstractions.ICommandHandler`2"); @@ -53,10 +126,13 @@ namespace Svrnty.CQRS.Grpc.Generators var queryMap = new Dictionary(SymbolEqualityComparer.Default); // Query -> Result type var dynamicQueryMap = new List<(INamedTypeSymbol SourceType, INamedTypeSymbol DestinationType, INamedTypeSymbol? ParamsType)>(); // List of (Source, Destination, Params?) + // Get all types from the compilation and referenced assemblies + var allTypes = GetAllTypesFromCompilation(compilation); + // Find all command and query types by looking at handler implementations - foreach (var typeSymbol in types) + foreach (var typeSymbol in allTypes) { - if (typeSymbol == null || typeSymbol.IsAbstract || typeSymbol.IsStatic) + if (typeSymbol.IsAbstract || typeSymbol.IsStatic) continue; // Check if this type implements ICommandHandler or ICommandHandler @@ -162,7 +238,7 @@ namespace Svrnty.CQRS.Grpc.Generators var resultType = kvp.Value; // Skip if marked with [GrpcIgnore] - if (grpcIgnoreAttribute != null && HasAttribute(commandType, grpcIgnoreAttribute)) + if (HasGrpcIgnoreAttribute(commandType)) continue; var commandInfo = ExtractCommandInfo(commandType, resultType); @@ -177,7 +253,7 @@ namespace Svrnty.CQRS.Grpc.Generators var resultType = kvp.Value; // Skip if marked with [GrpcIgnore] - if (grpcIgnoreAttribute != null && HasAttribute(queryType, grpcIgnoreAttribute)) + if (HasGrpcIgnoreAttribute(queryType)) continue; var queryInfo = ExtractQueryInfo(queryType, resultType); @@ -194,10 +270,13 @@ namespace Svrnty.CQRS.Grpc.Generators dynamicQueries.Add(dynamicQueryInfo); } - // Generate services if we found any commands, queries, or dynamic queries - if (commands.Any() || queries.Any() || dynamicQueries.Any()) + // Process discovered notification types (marked with [StreamingNotification]) + var notifications = DiscoverNotifications(allTypes, compilation); + + // Generate services if we found any commands, queries, dynamic queries, or notifications + if (commands.Any() || queries.Any() || dynamicQueries.Any() || notifications.Any()) { - GenerateProtoAndServices(context, commands, queries, dynamicQueries, compilation); + GenerateProtoAndServices(context, commands, queries, dynamicQueries, notifications, compilation); } } @@ -207,6 +286,12 @@ namespace Svrnty.CQRS.Grpc.Generators SymbolEqualityComparer.Default.Equals(attr.AttributeClass, attributeSymbol)); } + private static bool HasGrpcIgnoreAttribute(INamedTypeSymbol typeSymbol) + { + return typeSymbol.GetAttributes().Any(attr => + attr.AttributeClass?.Name == "GrpcIgnoreAttribute"); + } + private static bool ImplementsInterface(INamedTypeSymbol typeSymbol, INamedTypeSymbol? interfaceSymbol) { if (interfaceSymbol == null) @@ -262,18 +347,387 @@ namespace Svrnty.CQRS.Grpc.Generators var propertyType = property.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); var protoType = ProtoTypeMapper.MapToProtoType(propertyType, out bool isRepeated, out bool isOptional); - commandInfo.Properties.Add(new PropertyInfo + var propInfo = new PropertyInfo { Name = property.Name, Type = propertyType, + FullyQualifiedType = propertyType, ProtoType = protoType, - FieldNumber = fieldNumber++ - }); + FieldNumber = fieldNumber++, + IsComplexType = IsUserDefinedComplexType(property.Type), + // New type metadata fields + IsNullable = IsNullableType(property.Type), + IsEnum = IsEnumType(property.Type), + IsDecimal = IsDecimalType(property.Type), + IsDateTime = IsDateTimeType(property.Type), + IsList = IsListOrCollection(property.Type), + }; + + // If it's a list, extract element type info + if (propInfo.IsList) + { + var elementType = GetListElementType(property.Type); + if (elementType != null) + { + propInfo.ElementType = elementType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + propInfo.IsElementComplexType = IsUserDefinedComplexType(elementType); + + // If element is complex, extract nested properties + if (propInfo.IsElementComplexType) + { + var unwrappedElement = UnwrapNullableType(elementType); + if (unwrappedElement is INamedTypeSymbol namedElementType) + { + propInfo.ElementNestedProperties = new List(); + ExtractNestedPropertiesWithTypeInfo(namedElementType, propInfo.ElementNestedProperties); + } + } + } + } + // If it's a complex type (not list), extract nested properties + else if (propInfo.IsComplexType) + { + var unwrapped = UnwrapNullableType(property.Type); + if (unwrapped is INamedTypeSymbol namedType) + { + ExtractNestedPropertiesWithTypeInfo(namedType, propInfo.NestedProperties); + } + } + + commandInfo.Properties.Add(propInfo); } return commandInfo; } + private static bool IsUserDefinedComplexType(ITypeSymbol type) + { + if (type == null) + return false; + + // Unwrap nullable first + var unwrapped = UnwrapNullableType(type); + + if (unwrapped.TypeKind != TypeKind.Class && unwrapped.TypeKind != TypeKind.Struct) + return false; + + var fullName = unwrapped.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + // Exclude system types and primitives + if (fullName.StartsWith("global::System.")) + return false; + if (IsPrimitiveType(fullName)) + return false; + + return true; + } + + private static ITypeSymbol UnwrapNullableType(ITypeSymbol type) + { + // Handle Nullable (value type nullability) + if (type is INamedTypeSymbol namedType && + namedType.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T && + namedType.TypeArguments.Length == 1) + { + return namedType.TypeArguments[0]; + } + return type; + } + + private static bool IsNullableType(ITypeSymbol type) + { + // Check for Nullable (value type nullability) + if (type is INamedTypeSymbol namedType && + namedType.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T) + { + return true; + } + // Check for reference type nullability (C# 8.0+) + if (type.NullableAnnotation == NullableAnnotation.Annotated) + { + return true; + } + return false; + } + + private static bool IsDecimalType(ITypeSymbol type) + { + var unwrapped = UnwrapNullableType(type); + return unwrapped.SpecialType == SpecialType.System_Decimal; + } + + private static bool IsDateTimeType(ITypeSymbol type) + { + var unwrapped = UnwrapNullableType(type); + return unwrapped.SpecialType == SpecialType.System_DateTime; + } + + private static bool IsEnumType(ITypeSymbol type) + { + var unwrapped = UnwrapNullableType(type); + return unwrapped.TypeKind == TypeKind.Enum; + } + + private static bool IsListOrCollection(ITypeSymbol type) + { + if (type is IArrayTypeSymbol) + return true; + + if (type is INamedTypeSymbol namedType && namedType.IsGenericType) + { + var typeName = namedType.OriginalDefinition.ToDisplayString(); + return typeName.StartsWith("System.Collections.Generic.List<") || + typeName.StartsWith("System.Collections.Generic.IList<") || + typeName.StartsWith("System.Collections.Generic.ICollection<") || + typeName.StartsWith("System.Collections.Generic.IEnumerable<"); + } + return false; + } + + private static ITypeSymbol? GetListElementType(ITypeSymbol type) + { + if (type is IArrayTypeSymbol arrayType) + return arrayType.ElementType; + + if (type is INamedTypeSymbol namedType && namedType.IsGenericType && namedType.TypeArguments.Length > 0) + { + return namedType.TypeArguments[0]; + } + return null; + } + + private static void ExtractNestedProperties(INamedTypeSymbol type, List nestedProperties) + { + var properties = type.GetMembers().OfType() + .Where(p => p.DeclaredAccessibility == Accessibility.Public && !p.IsStatic) + .ToList(); + + foreach (var property in properties) + { + var propertyType = property.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var propInfo = new PropertyInfo + { + Name = property.Name, + Type = propertyType, + FullyQualifiedType = propertyType, + ProtoType = string.Empty, + FieldNumber = 0, + IsComplexType = IsUserDefinedComplexType(property.Type), + }; + + // Recursively extract nested properties for complex types + if (propInfo.IsComplexType && property.Type is INamedTypeSymbol namedType) + { + ExtractNestedProperties(namedType, propInfo.NestedProperties); + } + + nestedProperties.Add(propInfo); + } + } + + private static void ExtractNestedPropertiesWithTypeInfo(INamedTypeSymbol type, List nestedProperties) + { + var properties = type.GetMembers().OfType() + .Where(p => p.DeclaredAccessibility == Accessibility.Public && !p.IsStatic) + .ToList(); + + foreach (var property in properties) + { + var propertyType = property.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var propInfo = new PropertyInfo + { + Name = property.Name, + Type = propertyType, + FullyQualifiedType = propertyType, + ProtoType = string.Empty, + FieldNumber = 0, + IsComplexType = IsUserDefinedComplexType(property.Type), + // Type metadata + IsNullable = IsNullableType(property.Type), + IsEnum = IsEnumType(property.Type), + IsDecimal = IsDecimalType(property.Type), + IsDateTime = IsDateTimeType(property.Type), + IsList = IsListOrCollection(property.Type), + }; + + // If it's a list, extract element type info + if (propInfo.IsList) + { + var elementType = GetListElementType(property.Type); + if (elementType != null) + { + propInfo.ElementType = elementType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + propInfo.IsElementComplexType = IsUserDefinedComplexType(elementType); + } + } + // Recursively extract nested properties for complex types + else if (propInfo.IsComplexType) + { + var unwrapped = UnwrapNullableType(property.Type); + if (unwrapped is INamedTypeSymbol namedType) + { + ExtractNestedPropertiesWithTypeInfo(namedType, propInfo.NestedProperties); + } + } + + nestedProperties.Add(propInfo); + } + } + + private static void GenerateNestedPropertyMapping(StringBuilder sb, List properties, string sourcePrefix, string indent) + { + foreach (var prop in properties) + { + var sourcePropName = char.ToUpper(prop.Name[0]) + prop.Name.Substring(1); + if (prop.IsComplexType) + { + // Generate nested object mapping + sb.AppendLine($"{indent}{prop.Name} = {sourcePrefix}.{sourcePropName} != null ? new {prop.FullyQualifiedType}"); + sb.AppendLine($"{indent}{{"); + GenerateNestedPropertyMapping(sb, prop.NestedProperties, $"{sourcePrefix}.{sourcePropName}", indent + " "); + sb.AppendLine($"{indent}}} : null!,"); + } + else + { + sb.AppendLine($"{indent}{prop.Name} = {sourcePrefix}.{sourcePropName},"); + } + } + } + + private static string GeneratePropertyAssignment(PropertyInfo prop, string requestVar, string indent) + { + var requestPropName = char.ToUpper(prop.Name[0]) + prop.Name.Substring(1); + var source = $"{requestVar}.{requestPropName}"; + + // Handle lists + if (prop.IsList) + { + if (prop.IsElementComplexType && prop.ElementNestedProperties != null && prop.ElementNestedProperties.Any()) + { + // Complex list: map each element + return GenerateComplexListMapping(prop, source, indent); + } + else + { + // Primitive list: just ToList() + return $"{indent}{prop.Name} = {source}?.ToList(),"; + } + } + + // Handle enums (proto int32 -> C# enum) + if (prop.IsEnum) + { + return $"{indent}{prop.Name} = ({prop.FullyQualifiedType}){source},"; + } + + // Handle decimals (proto string -> C# decimal) + if (prop.IsDecimal) + { + if (prop.IsNullable) + { + return $"{indent}{prop.Name} = string.IsNullOrEmpty({source}) ? null : decimal.Parse({source}),"; + } + else + { + return $"{indent}{prop.Name} = decimal.Parse({source}),"; + } + } + + // Handle DateTime (proto Timestamp -> C# DateTime) + if (prop.IsDateTime) + { + if (prop.IsNullable) + { + return $"{indent}{prop.Name} = {source} == null ? (System.DateTime?)null : {source}.ToDateTime(),"; + } + else + { + return $"{indent}{prop.Name} = {source}.ToDateTime(),"; + } + } + + // Handle complex types (single objects) + if (prop.IsComplexType) + { + return GenerateComplexObjectMapping(prop, source, indent); + } + + // Default: direct assignment + return $"{indent}{prop.Name} = {source},"; + } + + private static string GenerateComplexListMapping(PropertyInfo prop, string source, string indent) + { + var sb = new StringBuilder(); + sb.AppendLine($"{indent}{prop.Name} = {source}?.Select(x => new {prop.ElementType}"); + sb.AppendLine($"{indent}{{"); + + foreach (var nestedProp in prop.ElementNestedProperties!) + { + var nestedSourcePropName = char.ToUpper(nestedProp.Name[0]) + nestedProp.Name.Substring(1); + var nestedAssignment = GenerateNestedPropertyAssignment(nestedProp, "x", indent + " "); + sb.AppendLine(nestedAssignment); + } + + sb.Append($"{indent}}}).ToList(),"); + return sb.ToString(); + } + + private static string GenerateComplexObjectMapping(PropertyInfo prop, string source, string indent) + { + var sb = new StringBuilder(); + sb.AppendLine($"{indent}{prop.Name} = {source} != null ? new {prop.FullyQualifiedType}"); + sb.AppendLine($"{indent}{{"); + + foreach (var nestedProp in prop.NestedProperties) + { + var nestedAssignment = GenerateNestedPropertyAssignment(nestedProp, source, indent + " "); + sb.AppendLine(nestedAssignment); + } + + sb.Append($"{indent}}} : null!,"); + return sb.ToString(); + } + + private static string GenerateNestedPropertyAssignment(PropertyInfo prop, string sourceVar, string indent) + { + var sourcePropName = char.ToUpper(prop.Name[0]) + prop.Name.Substring(1); + var source = $"{sourceVar}.{sourcePropName}"; + + // Handle enums + if (prop.IsEnum) + { + return $"{indent}{prop.Name} = ({prop.FullyQualifiedType}){source},"; + } + + // Handle decimals + if (prop.IsDecimal) + { + if (prop.IsNullable) + { + return $"{indent}{prop.Name} = string.IsNullOrEmpty({source}) ? null : decimal.Parse({source}),"; + } + else + { + return $"{indent}{prop.Name} = decimal.Parse({source}),"; + } + } + + // Handle lists + if (prop.IsList) + { + return $"{indent}{prop.Name} = {source}?.ToList(),"; + } + + // Handle complex types + if (prop.IsComplexType && prop.NestedProperties.Any()) + { + return GenerateComplexObjectMapping(prop, source, indent); + } + + // Default: direct assignment + return $"{indent}{prop.Name} = {source},"; + } + private static QueryInfo? ExtractQueryInfo(INamedTypeSymbol queryType, INamedTypeSymbol resultType) { var queryInfo = new QueryInfo @@ -384,7 +838,7 @@ namespace Svrnty.CQRS.Grpc.Generators return word + "s"; } - private static void GenerateProtoAndServices(SourceProductionContext context, List commands, List queries, List dynamicQueries, Compilation compilation) + private static void GenerateProtoAndServices(SourceProductionContext context, List commands, List queries, List dynamicQueries, List notifications, Compilation compilation) { // Get root namespace from compilation var rootNamespace = compilation.AssemblyName ?? "Application"; @@ -410,8 +864,15 @@ namespace Svrnty.CQRS.Grpc.Generators context.AddSource("DynamicQueryServiceImpl.g.cs", dynamicQueryService); } + // Generate service implementations for notifications (streaming) + if (notifications.Any()) + { + var notificationService = GenerateNotificationServiceImpl(notifications, rootNamespace); + context.AddSource("NotificationServiceImpl.g.cs", notificationService); + } + // Generate registration extensions - var registrationExtensions = GenerateRegistrationExtensions(commands.Any(), queries.Any(), dynamicQueries.Any(), rootNamespace); + var registrationExtensions = GenerateRegistrationExtensions(commands.Any(), queries.Any(), dynamicQueries.Any(), notifications.Any(), rootNamespace); context.AddSource("GrpcServiceRegistration.g.cs", registrationExtensions); } @@ -668,7 +1129,7 @@ namespace Svrnty.CQRS.Grpc.Generators return sb.ToString(); } - private static string GenerateRegistrationExtensions(bool hasCommands, bool hasQueries, bool hasDynamicQueries, string rootNamespace) + private static string GenerateRegistrationExtensions(bool hasCommands, bool hasQueries, bool hasDynamicQueries, bool hasNotifications, string rootNamespace) { var sb = new StringBuilder(); sb.AppendLine("// "); @@ -676,6 +1137,10 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine("using Microsoft.AspNetCore.Routing;"); sb.AppendLine("using Microsoft.Extensions.DependencyInjection;"); sb.AppendLine($"using {rootNamespace}.Grpc.Services;"); + if (hasNotifications) + { + sb.AppendLine("using Svrnty.CQRS.Notifications.Grpc;"); + } sb.AppendLine(); sb.AppendLine($"namespace {rootNamespace}.Grpc.Extensions"); sb.AppendLine("{"); @@ -754,10 +1219,34 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine(); } - if (hasCommands || hasQueries || hasDynamicQueries) + if (hasNotifications) { sb.AppendLine(" /// "); - sb.AppendLine(" /// Registers all auto-generated gRPC services (Commands, Queries, and DynamicQueries)"); + sb.AppendLine(" /// Registers the auto-generated Notification streaming gRPC service"); + sb.AppendLine(" /// "); + sb.AppendLine(" public static IServiceCollection AddGrpcNotificationService(this IServiceCollection services)"); + sb.AppendLine(" {"); + sb.AppendLine(" services.AddGrpc();"); + sb.AppendLine(" services.AddStreamingNotifications();"); + sb.AppendLine(" services.AddSingleton();"); + sb.AppendLine(" return services;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" /// "); + sb.AppendLine(" /// Maps the auto-generated Notification streaming gRPC service endpoints"); + sb.AppendLine(" /// "); + sb.AppendLine(" public static IEndpointRouteBuilder MapGrpcNotifications(this IEndpointRouteBuilder endpoints)"); + sb.AppendLine(" {"); + sb.AppendLine(" endpoints.MapGrpcService();"); + sb.AppendLine(" return endpoints;"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + if (hasCommands || hasQueries || hasDynamicQueries || hasNotifications) + { + sb.AppendLine(" /// "); + sb.AppendLine(" /// Registers all auto-generated gRPC services (Commands, Queries, DynamicQueries, and Notifications)"); sb.AppendLine(" /// "); sb.AppendLine(" public static IServiceCollection AddGrpcCommandsAndQueries(this IServiceCollection services)"); sb.AppendLine(" {"); @@ -769,11 +1258,16 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine(" services.AddSingleton();"); if (hasDynamicQueries) sb.AppendLine(" services.AddSingleton();"); + if (hasNotifications) + { + sb.AppendLine(" services.AddStreamingNotifications();"); + sb.AppendLine(" services.AddSingleton();"); + } sb.AppendLine(" return services;"); sb.AppendLine(" }"); sb.AppendLine(); sb.AppendLine(" /// "); - sb.AppendLine(" /// Maps all auto-generated gRPC service endpoints (Commands, Queries, and DynamicQueries)"); + sb.AppendLine(" /// Maps all auto-generated gRPC service endpoints (Commands, Queries, DynamicQueries, and Notifications)"); sb.AppendLine(" /// "); sb.AppendLine(" public static IEndpointRouteBuilder MapGrpcCommandsAndQueries(this IEndpointRouteBuilder endpoints)"); sb.AppendLine(" {"); @@ -783,6 +1277,8 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine(" endpoints.MapGrpcService();"); if (hasDynamicQueries) sb.AppendLine(" endpoints.MapGrpcService();"); + if (hasNotifications) + sb.AppendLine(" endpoints.MapGrpcService();"); sb.AppendLine(" return endpoints;"); sb.AppendLine(" }"); sb.AppendLine(); @@ -816,6 +1312,12 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine(" if (grpcOptions.GetShouldMapQueries())"); sb.AppendLine(" services.AddSingleton();"); } + if (hasNotifications) + { + sb.AppendLine(" // Always register notification service if it exists"); + sb.AppendLine(" services.AddStreamingNotifications();"); + sb.AppendLine(" services.AddSingleton();"); + } sb.AppendLine(" }"); sb.AppendLine(" return services;"); sb.AppendLine(" }"); @@ -845,6 +1347,11 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine(" if (grpcOptions.GetShouldMapQueries())"); sb.AppendLine(" endpoints.MapGrpcService();"); } + if (hasNotifications) + { + sb.AppendLine(" // Always map notification service if it exists"); + sb.AppendLine(" endpoints.MapGrpcService();"); + } sb.AppendLine(); sb.AppendLine(" if (grpcOptions.ShouldEnableReflection)"); sb.AppendLine(" endpoints.MapGrpcReflectionService();"); @@ -983,11 +1490,11 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine(" /// "); sb.AppendLine(" public sealed class CommandServiceImpl : CommandService.CommandServiceBase"); sb.AppendLine(" {"); - sb.AppendLine(" private readonly IServiceProvider _serviceProvider;"); + sb.AppendLine(" private readonly IServiceScopeFactory _scopeFactory;"); sb.AppendLine(); - sb.AppendLine(" public CommandServiceImpl(IServiceProvider serviceProvider)"); + sb.AppendLine(" public CommandServiceImpl(IServiceScopeFactory scopeFactory)"); sb.AppendLine(" {"); - sb.AppendLine(" _serviceProvider = serviceProvider;"); + sb.AppendLine(" _scopeFactory = scopeFactory;"); sb.AppendLine(" }"); sb.AppendLine(); @@ -1001,16 +1508,20 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine($" {requestType} request,"); sb.AppendLine(" ServerCallContext context)"); sb.AppendLine(" {"); + sb.AppendLine(" using var scope = _scopeFactory.CreateScope();"); + sb.AppendLine(" var serviceProvider = scope.ServiceProvider;"); + sb.AppendLine(); sb.AppendLine($" var command = new {command.FullyQualifiedName}"); sb.AppendLine(" {"); foreach (var prop in command.Properties) { - sb.AppendLine($" {prop.Name} = request.{char.ToUpper(prop.Name[0]) + prop.Name.Substring(1)},"); + var assignment = GeneratePropertyAssignment(prop, "request", " "); + sb.AppendLine(assignment); } sb.AppendLine(" };"); sb.AppendLine(); sb.AppendLine(" // Validate command if validator is registered"); - sb.AppendLine($" var validator = _serviceProvider.GetService>();"); + sb.AppendLine($" var validator = serviceProvider.GetService>();"); sb.AppendLine(" if (validator != null)"); sb.AppendLine(" {"); sb.AppendLine(" var validationResult = await validator.ValidateAsync(command, context.CancellationToken);"); @@ -1038,7 +1549,7 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine(" }"); sb.AppendLine(" }"); sb.AppendLine(); - sb.AppendLine($" var handler = _serviceProvider.GetRequiredService<{command.HandlerInterfaceName}>();"); + sb.AppendLine($" var handler = serviceProvider.GetRequiredService<{command.HandlerInterfaceName}>();"); if (command.HasResult) { @@ -1080,11 +1591,11 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine(" /// "); sb.AppendLine(" public sealed class QueryServiceImpl : QueryService.QueryServiceBase"); sb.AppendLine(" {"); - sb.AppendLine(" private readonly IServiceProvider _serviceProvider;"); + sb.AppendLine(" private readonly IServiceScopeFactory _scopeFactory;"); sb.AppendLine(); - sb.AppendLine(" public QueryServiceImpl(IServiceProvider serviceProvider)"); + sb.AppendLine(" public QueryServiceImpl(IServiceScopeFactory scopeFactory)"); sb.AppendLine(" {"); - sb.AppendLine(" _serviceProvider = serviceProvider;"); + sb.AppendLine(" _scopeFactory = scopeFactory;"); sb.AppendLine(" }"); sb.AppendLine(); @@ -1098,7 +1609,10 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine($" {requestType} request,"); sb.AppendLine(" ServerCallContext context)"); sb.AppendLine(" {"); - sb.AppendLine($" var handler = _serviceProvider.GetRequiredService<{query.HandlerInterfaceName}>();"); + sb.AppendLine(" using var scope = _scopeFactory.CreateScope();"); + sb.AppendLine(" var serviceProvider = scope.ServiceProvider;"); + sb.AppendLine(); + sb.AppendLine($" var handler = serviceProvider.GetRequiredService<{query.HandlerInterfaceName}>();"); sb.AppendLine($" var query = new {query.FullyQualifiedName}"); sb.AppendLine(" {"); foreach (var prop in query.Properties) @@ -1362,11 +1876,11 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine(" /// "); sb.AppendLine(" public sealed class DynamicQueryServiceImpl : DynamicQueryService.DynamicQueryServiceBase"); sb.AppendLine(" {"); - sb.AppendLine(" private readonly IServiceProvider _serviceProvider;"); + sb.AppendLine(" private readonly IServiceScopeFactory _scopeFactory;"); sb.AppendLine(); - sb.AppendLine(" public DynamicQueryServiceImpl(IServiceProvider serviceProvider)"); + sb.AppendLine(" public DynamicQueryServiceImpl(IServiceScopeFactory scopeFactory)"); sb.AppendLine(" {"); - sb.AppendLine(" _serviceProvider = serviceProvider;"); + sb.AppendLine(" _scopeFactory = scopeFactory;"); sb.AppendLine(" }"); sb.AppendLine(); @@ -1380,6 +1894,9 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine($" {requestType} request,"); sb.AppendLine(" ServerCallContext context)"); sb.AppendLine(" {"); + sb.AppendLine(" using var scope = _scopeFactory.CreateScope();"); + sb.AppendLine(" var serviceProvider = scope.ServiceProvider;"); + sb.AppendLine(); // Build the dynamic query object if (dynamicQuery.HasParams) @@ -1401,7 +1918,7 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine(); // Get the handler and execute - sb.AppendLine($" var handler = _serviceProvider.GetRequiredService>>();"); + sb.AppendLine($" var handler = serviceProvider.GetRequiredService>>();"); sb.AppendLine(" var result = await handler.HandleAsync(query, context.CancellationToken);"); sb.AppendLine(); @@ -1526,6 +2043,130 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine(" }"); sb.AppendLine(); + // Add generic reflection-based mapper helper + sb.AppendLine(" private static TProto MapToProtoModel(TDomain domainModel) where TProto : Google.Protobuf.IMessage, new()"); + sb.AppendLine(" {"); + sb.AppendLine(" if (domainModel == null) return new TProto();"); + sb.AppendLine(" var proto = new TProto();"); + sb.AppendLine(" var domainProps = typeof(TDomain).GetProperties();"); + sb.AppendLine(" var protoDesc = proto.Descriptor;"); + sb.AppendLine(); + sb.AppendLine(" foreach (var domainProp in domainProps)"); + sb.AppendLine(" {"); + sb.AppendLine(" // Convert property name to proto field name (PascalCase to snake_case)"); + sb.AppendLine(" var protoFieldName = ToSnakeCase(domainProp.Name);"); + sb.AppendLine(" var protoField = protoDesc.FindFieldByName(protoFieldName);"); + sb.AppendLine(" if (protoField == null) continue;"); + sb.AppendLine(); + sb.AppendLine(" var domainValue = domainProp.GetValue(domainModel);"); + sb.AppendLine(" if (domainValue == null) continue;"); + sb.AppendLine(); + sb.AppendLine(" var protoAccessor = protoField.Accessor;"); + sb.AppendLine(); + sb.AppendLine(" // Handle DateTime -> Timestamp conversion"); + sb.AppendLine(" if (domainProp.PropertyType == typeof(DateTime) || domainProp.PropertyType == typeof(DateTime?))"); + sb.AppendLine(" {"); + sb.AppendLine(" var dateTime = (DateTime)domainValue;"); + sb.AppendLine(" // Ensure UTC for Timestamp conversion"); + sb.AppendLine(" if (dateTime.Kind != DateTimeKind.Utc)"); + sb.AppendLine(" dateTime = dateTime.ToUniversalTime();"); + sb.AppendLine(" protoAccessor.SetValue(proto, Google.Protobuf.WellKnownTypes.Timestamp.FromDateTime(dateTime));"); + sb.AppendLine(" }"); + sb.AppendLine(" else if (domainProp.PropertyType == typeof(DateTimeOffset) || domainProp.PropertyType == typeof(DateTimeOffset?))"); + sb.AppendLine(" {"); + sb.AppendLine(" var dateTimeOffset = (DateTimeOffset)domainValue;"); + sb.AppendLine(" protoAccessor.SetValue(proto, Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(dateTimeOffset));"); + sb.AppendLine(" }"); + sb.AppendLine(" // Handle collections (List, IList, etc.) - must check before complex types"); + sb.AppendLine(" else if (protoField.IsRepeated && domainValue is System.Collections.IEnumerable enumerable && domainProp.PropertyType != typeof(string))"); + sb.AppendLine(" {"); + sb.AppendLine(" var repeatedField = protoAccessor.GetValue(proto);"); + sb.AppendLine(" if (repeatedField == null) continue;"); + sb.AppendLine(); + sb.AppendLine(" // Get the element type of the RepeatedField"); + sb.AppendLine(" var repeatedFieldType = repeatedField.GetType();"); + sb.AppendLine(" var repeatedElementType = repeatedFieldType.IsGenericType ? repeatedFieldType.GetGenericArguments()[0] : null;"); + sb.AppendLine(" if (repeatedElementType == null) continue;"); + sb.AppendLine(); + sb.AppendLine(" // Get Add(T) method with specific parameter type to avoid ambiguity"); + sb.AppendLine(" var addMethod = repeatedFieldType.GetMethod(\"Add\", new[] { repeatedElementType });"); + sb.AppendLine(" if (addMethod == null) continue;"); + sb.AppendLine(); + sb.AppendLine(" // Get element types"); + sb.AppendLine(" var domainElementType = domainProp.PropertyType.IsArray"); + sb.AppendLine(" ? domainProp.PropertyType.GetElementType()"); + sb.AppendLine(" : domainProp.PropertyType.IsGenericType ? domainProp.PropertyType.GetGenericArguments()[0] : null;"); + sb.AppendLine(" var protoElementType = protoField.MessageType?.ClrType;"); + sb.AppendLine(); + sb.AppendLine(" foreach (var item in enumerable)"); + sb.AppendLine(" {"); + sb.AppendLine(" if (item == null) continue;"); + sb.AppendLine(); + sb.AppendLine(" // Check if elements need mapping (complex types)"); + sb.AppendLine(" if (protoElementType != null && typeof(Google.Protobuf.IMessage).IsAssignableFrom(protoElementType) && domainElementType != null)"); + sb.AppendLine(" {"); + sb.AppendLine(" var mapMethod = typeof(DynamicQueryServiceImpl).GetMethod(\"MapToProtoModel\","); + sb.AppendLine(" System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static)!"); + sb.AppendLine(" .MakeGenericMethod(domainElementType, protoElementType);"); + sb.AppendLine(" var mappedItem = mapMethod.Invoke(null, new[] { item });"); + sb.AppendLine(" if (mappedItem != null)"); + sb.AppendLine(" addMethod.Invoke(repeatedField, new[] { mappedItem });"); + sb.AppendLine(" }"); + sb.AppendLine(" else"); + sb.AppendLine(" {"); + sb.AppendLine(" // Primitive types, enums, strings - add directly"); + sb.AppendLine(" try { addMethod.Invoke(repeatedField, new[] { item }); }"); + sb.AppendLine(" catch { /* Type mismatch, skip */ }"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" // Handle nested complex types (non-primitive, non-enum, non-string, non-collection)"); + sb.AppendLine(" else if (!domainProp.PropertyType.IsPrimitive && "); + sb.AppendLine(" domainProp.PropertyType != typeof(string) && "); + sb.AppendLine(" !domainProp.PropertyType.IsEnum &&"); + sb.AppendLine(" !domainProp.PropertyType.IsValueType)"); + sb.AppendLine(" {"); + sb.AppendLine(" // Get the proto field type and recursively map"); + sb.AppendLine(" var protoFieldType = protoAccessor.GetValue(proto)?.GetType() ?? protoField.MessageType?.ClrType;"); + sb.AppendLine(" if (protoFieldType != null && typeof(Google.Protobuf.IMessage).IsAssignableFrom(protoFieldType))"); + sb.AppendLine(" {"); + sb.AppendLine(" var mapMethod = typeof(DynamicQueryServiceImpl).GetMethod(\"MapToProtoModel\", "); + sb.AppendLine(" System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static)!"); + sb.AppendLine(" .MakeGenericMethod(domainProp.PropertyType, protoFieldType);"); + sb.AppendLine(" var nestedProto = mapMethod.Invoke(null, new[] { domainValue });"); + sb.AppendLine(" if (nestedProto != null)"); + sb.AppendLine(" protoAccessor.SetValue(proto, nestedProto);"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" // Handle decimal -> string conversion"); + sb.AppendLine(" else if (domainProp.PropertyType == typeof(decimal) || domainProp.PropertyType == typeof(decimal?))"); + sb.AppendLine(" {"); + sb.AppendLine(" protoAccessor.SetValue(proto, ((decimal)domainValue).ToString(System.Globalization.CultureInfo.InvariantCulture));"); + sb.AppendLine(" }"); + sb.AppendLine(" else"); + sb.AppendLine(" {"); + sb.AppendLine(" // Direct assignment for primitives, strings, enums"); + sb.AppendLine(" try { protoAccessor.SetValue(proto, domainValue); }"); + sb.AppendLine(" catch { /* Type mismatch, skip */ }"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" return proto;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" private static string ToSnakeCase(string str)"); + sb.AppendLine(" {"); + sb.AppendLine(" if (string.IsNullOrEmpty(str)) return str;"); + sb.AppendLine(" var result = new System.Text.StringBuilder();"); + sb.AppendLine(" for (int i = 0; i < str.Length; i++)"); + sb.AppendLine(" {"); + sb.AppendLine(" if (i > 0 && char.IsUpper(str[i]))"); + sb.AppendLine(" result.Append('_');"); + sb.AppendLine(" result.Append(char.ToLowerInvariant(str[i]));"); + sb.AppendLine(" }"); + sb.AppendLine(" return result.ToString();"); + sb.AppendLine(" }"); + sb.AppendLine(); + // Add mapper methods for each entity type foreach (var dynamicQuery in dynamicQueries) { @@ -1534,9 +2175,7 @@ namespace Svrnty.CQRS.Grpc.Generators sb.AppendLine($" private static {protoTypeName} MapTo{entityName}ProtoModel({dynamicQuery.DestinationTypeFullyQualified} domainModel)"); sb.AppendLine(" {"); - sb.AppendLine($" // Use JSON serialization for mapping between domain and proto models"); - sb.AppendLine(" var json = System.Text.Json.JsonSerializer.Serialize(domainModel);"); - sb.AppendLine($" return System.Text.Json.JsonSerializer.Deserialize<{protoTypeName}>(json, new System.Text.Json.JsonSerializerOptions {{ PropertyNameCaseInsensitive = true }}) ?? new {protoTypeName}();"); + sb.AppendLine($" return MapToProtoModel<{dynamicQuery.DestinationTypeFullyQualified}, {protoTypeName}>(domainModel);"); sb.AppendLine(" }"); sb.AppendLine(); } @@ -1546,5 +2185,205 @@ namespace Svrnty.CQRS.Grpc.Generators return sb.ToString(); } + + /// + /// Discovers types marked with [StreamingNotification] attribute + /// + private static List DiscoverNotifications(IEnumerable allTypes, Compilation compilation) + { + var streamingNotificationAttribute = compilation.GetTypeByMetadataName( + "Svrnty.CQRS.Notifications.Abstractions.StreamingNotificationAttribute"); + + if (streamingNotificationAttribute == null) + return new List(); + + var notifications = new List(); + + foreach (var type in allTypes) + { + if (type.IsAbstract || type.IsStatic) + continue; + + var attr = type.GetAttributes() + .FirstOrDefault(a => SymbolEqualityComparer.Default.Equals( + a.AttributeClass, streamingNotificationAttribute)); + + if (attr == null) + continue; + + // Extract SubscriptionKey from attribute + var subscriptionKeyArg = attr.NamedArguments + .FirstOrDefault(a => a.Key == "SubscriptionKey"); + var subscriptionKeyProp = subscriptionKeyArg.Value.Value as string; + + if (string.IsNullOrEmpty(subscriptionKeyProp)) + continue; + + // Get all properties of the notification type + var properties = new List(); + int fieldNumber = 1; + + foreach (var prop in type.GetMembers().OfType() + .Where(p => p.DeclaredAccessibility == Accessibility.Public)) + { + var propType = prop.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var protoType = ProtoTypeMapper.MapToProtoType(propType, out _, out _); + + properties.Add(new PropertyInfo + { + Name = prop.Name, + Type = propType, + FullyQualifiedType = propType, + ProtoType = protoType, + FieldNumber = fieldNumber++, + IsEnum = prop.Type.TypeKind == TypeKind.Enum, + IsDecimal = propType.Contains("decimal") || propType.Contains("Decimal"), + IsDateTime = propType.Contains("DateTime") + }); + } + + // Find the subscription key property info + var keyPropInfo = properties.FirstOrDefault(p => p.Name == subscriptionKeyProp); + if (keyPropInfo == null) + continue; + + notifications.Add(new NotificationInfo + { + Name = type.Name, + FullyQualifiedName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + Namespace = type.ContainingNamespace?.ToDisplayString() ?? "", + SubscriptionKeyProperty = subscriptionKeyProp, + SubscriptionKeyInfo = keyPropInfo, + Properties = properties + }); + } + + return notifications; + } + + /// + /// Generates the NotificationServiceImpl class for streaming notifications + /// + private static string GenerateNotificationServiceImpl(List notifications, string rootNamespace) + { + var sb = new StringBuilder(); + sb.AppendLine("// "); + sb.AppendLine("#nullable enable"); + sb.AppendLine("using Grpc.Core;"); + sb.AppendLine("using System.Threading.Tasks;"); + sb.AppendLine("using System.Threading;"); + sb.AppendLine("using Google.Protobuf.WellKnownTypes;"); + sb.AppendLine($"using {rootNamespace}.Grpc;"); + sb.AppendLine("using Svrnty.CQRS.Notifications.Grpc;"); + sb.AppendLine(); + + sb.AppendLine($"namespace {rootNamespace}.Grpc.Services"); + sb.AppendLine("{"); + sb.AppendLine(" /// "); + sb.AppendLine(" /// Auto-generated gRPC service implementation for streaming Notifications"); + sb.AppendLine(" /// "); + sb.AppendLine(" public sealed class NotificationServiceImpl : NotificationService.NotificationServiceBase"); + sb.AppendLine(" {"); + sb.AppendLine(" private readonly NotificationSubscriptionManager _subscriptionManager;"); + sb.AppendLine(); + sb.AppendLine(" public NotificationServiceImpl(NotificationSubscriptionManager subscriptionManager)"); + sb.AppendLine(" {"); + sb.AppendLine(" _subscriptionManager = subscriptionManager;"); + sb.AppendLine(" }"); + + foreach (var notification in notifications) + { + var methodName = $"SubscribeTo{notification.Name}"; + var requestType = $"SubscribeTo{notification.Name}Request"; + var keyPropName = notification.SubscriptionKeyProperty; + // Proto uses PascalCase for C# properties + var keyPropPascal = ToPascalCaseHelper(ToSnakeCaseHelper(keyPropName)); + + sb.AppendLine(); + sb.AppendLine($" public override async Task {methodName}("); + sb.AppendLine($" {requestType} request,"); + sb.AppendLine($" IServerStreamWriter<{notification.Name}> responseStream,"); + sb.AppendLine(" ServerCallContext context)"); + sb.AppendLine(" {"); + sb.AppendLine($" // Subscribe with mapper from domain notification to proto message"); + sb.AppendLine($" using var subscription = _subscriptionManager.Subscribe<{notification.FullyQualifiedName}, {notification.Name}>("); + sb.AppendLine($" request.{keyPropPascal},"); + sb.AppendLine($" responseStream,"); + sb.AppendLine($" domainNotification => Map{notification.Name}(domainNotification));"); + sb.AppendLine(); + sb.AppendLine(" // Keep the stream alive until client disconnects"); + sb.AppendLine(" try"); + sb.AppendLine(" {"); + sb.AppendLine(" await Task.Delay(Timeout.Infinite, context.CancellationToken);"); + sb.AppendLine(" }"); + sb.AppendLine(" catch (OperationCanceledException)"); + sb.AppendLine(" {"); + sb.AppendLine(" // Client disconnected - normal behavior"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + } + + // Generate mapper methods + foreach (var notification in notifications) + { + sb.AppendLine(); + sb.AppendLine($" private static {notification.Name} Map{notification.Name}({notification.FullyQualifiedName} domain)"); + sb.AppendLine(" {"); + sb.AppendLine($" return new {notification.Name}"); + sb.AppendLine(" {"); + + foreach (var prop in notification.Properties) + { + var protoFieldName = ToPascalCaseHelper(ToSnakeCaseHelper(prop.Name)); + if (prop.IsDateTime) + { + sb.AppendLine($" {protoFieldName} = Timestamp.FromDateTime(domain.{prop.Name}.ToUniversalTime()),"); + } + else if (prop.IsDecimal) + { + sb.AppendLine($" {protoFieldName} = domain.{prop.Name}.ToString(),"); + } + else if (prop.IsEnum) + { + // Map domain enum to proto enum - get simple type name + var simpleTypeName = prop.Type.Replace("global::", "").Split('.').Last(); + sb.AppendLine($" {protoFieldName} = ({simpleTypeName})((int)domain.{prop.Name}),"); + } + else + { + sb.AppendLine($" {protoFieldName} = domain.{prop.Name},"); + } + } + + sb.AppendLine(" };"); + sb.AppendLine(" }"); + } + + sb.AppendLine(" }"); + sb.AppendLine("}"); + + return sb.ToString(); + } + + private static string ToSnakeCaseHelper(string str) + { + if (string.IsNullOrEmpty(str)) return str; + var result = new StringBuilder(); + for (int i = 0; i < str.Length; i++) + { + if (i > 0 && char.IsUpper(str[i])) + result.Append('_'); + result.Append(char.ToLowerInvariant(str[i])); + } + return result.ToString(); + } + + private static string ToPascalCaseHelper(string snakeCase) + { + if (string.IsNullOrEmpty(snakeCase)) return snakeCase; + var parts = snakeCase.Split('_'); + return string.Join("", parts.Select(p => + p.Length > 0 ? char.ToUpperInvariant(p[0]) + p.Substring(1).ToLowerInvariant() : "")); + } } } diff --git a/Svrnty.CQRS.Grpc.Generators/Models/CommandInfo.cs b/Svrnty.CQRS.Grpc.Generators/Models/CommandInfo.cs index d874ad6..bd9488c 100644 --- a/Svrnty.CQRS.Grpc.Generators/Models/CommandInfo.cs +++ b/Svrnty.CQRS.Grpc.Generators/Models/CommandInfo.cs @@ -35,6 +35,18 @@ namespace Svrnty.CQRS.Grpc.Generators.Models public string FullyQualifiedType { get; set; } public string ProtoType { get; set; } public int FieldNumber { get; set; } + public bool IsComplexType { get; set; } + public List NestedProperties { get; set; } + + // Type conversion metadata + public bool IsEnum { get; set; } + public bool IsList { get; set; } + public bool IsNullable { get; set; } + public bool IsDecimal { get; set; } + public bool IsDateTime { get; set; } + public string? ElementType { get; set; } + public bool IsElementComplexType { get; set; } + public List? ElementNestedProperties { get; set; } public PropertyInfo() { @@ -42,6 +54,14 @@ namespace Svrnty.CQRS.Grpc.Generators.Models Type = string.Empty; FullyQualifiedType = string.Empty; ProtoType = string.Empty; + IsComplexType = false; + NestedProperties = new List(); + IsEnum = false; + IsList = false; + IsNullable = false; + IsDecimal = false; + IsDateTime = false; + IsElementComplexType = false; } } } diff --git a/Svrnty.CQRS.Grpc.Generators/Models/NotificationInfo.cs b/Svrnty.CQRS.Grpc.Generators/Models/NotificationInfo.cs new file mode 100644 index 0000000..3285265 --- /dev/null +++ b/Svrnty.CQRS.Grpc.Generators/Models/NotificationInfo.cs @@ -0,0 +1,50 @@ +using System.Collections.Generic; + +namespace Svrnty.CQRS.Grpc.Generators.Models +{ + /// + /// Represents a discovered streaming notification type for proto/gRPC generation. + /// + public class NotificationInfo + { + /// + /// The notification type name (e.g., "InventoryChangeNotification"). + /// + public string Name { get; set; } + + /// + /// The fully qualified type name including namespace. + /// + public string FullyQualifiedName { get; set; } + + /// + /// The namespace of the notification type. + /// + public string Namespace { get; set; } + + /// + /// The property name used as the subscription key (from [StreamingNotification] attribute). + /// + public string SubscriptionKeyProperty { get; set; } + + /// + /// The subscription key property info. + /// + public PropertyInfo SubscriptionKeyInfo { get; set; } + + /// + /// All properties of the notification type. + /// + public List Properties { get; set; } + + public NotificationInfo() + { + Name = string.Empty; + FullyQualifiedName = string.Empty; + Namespace = string.Empty; + SubscriptionKeyProperty = string.Empty; + SubscriptionKeyInfo = new PropertyInfo(); + Properties = new List(); + } + } +} diff --git a/Svrnty.CQRS.Grpc.Generators/ProtoFileGenerator.cs b/Svrnty.CQRS.Grpc.Generators/ProtoFileGenerator.cs index dd66fa7..6f3102f 100644 --- a/Svrnty.CQRS.Grpc.Generators/ProtoFileGenerator.cs +++ b/Svrnty.CQRS.Grpc.Generators/ProtoFileGenerator.cs @@ -2,29 +2,90 @@ using System.Collections.Generic; using System.Linq; using System.Text; using Microsoft.CodeAnalysis; +using Svrnty.CQRS.Grpc.Generators.Models; namespace Svrnty.CQRS.Grpc.Generators; /// -/// Generates Protocol Buffer (.proto) files from C# Command and Query types +/// Generates Protocol Buffer (.proto) files from C# Command, Query, and Notification types /// internal class ProtoFileGenerator { private readonly Compilation _compilation; private readonly HashSet _requiredImports = new HashSet(); private readonly HashSet _generatedMessages = new HashSet(); + private readonly HashSet _generatedEnums = new HashSet(); + private readonly List _pendingEnums = new List(); private readonly StringBuilder _messagesBuilder = new StringBuilder(); + private readonly StringBuilder _enumsBuilder = new StringBuilder(); + private List? _allTypesCache; + + /// + /// Gets the discovered notifications after Generate() is called. + /// + public List DiscoveredNotifications { get; private set; } = new List(); public ProtoFileGenerator(Compilation compilation) { _compilation = compilation; } + /// + /// Gets all types from the compilation and all referenced assemblies + /// + private IEnumerable GetAllTypes() + { + if (_allTypesCache != null) + return _allTypesCache; + + _allTypesCache = new List(); + + // Get types from the current assembly + CollectTypesFromNamespace(_compilation.Assembly.GlobalNamespace, _allTypesCache); + + // Get types from all referenced assemblies + foreach (var reference in _compilation.References) + { + var assemblySymbol = _compilation.GetAssemblyOrModuleSymbol(reference) as IAssemblySymbol; + if (assemblySymbol != null) + { + CollectTypesFromNamespace(assemblySymbol.GlobalNamespace, _allTypesCache); + } + } + + return _allTypesCache; + } + + private static void CollectTypesFromNamespace(INamespaceSymbol ns, List types) + { + foreach (var type in ns.GetTypeMembers()) + { + types.Add(type); + CollectNestedTypes(type, types); + } + + foreach (var nestedNs in ns.GetNamespaceMembers()) + { + CollectTypesFromNamespace(nestedNs, types); + } + } + + private static void CollectNestedTypes(INamedTypeSymbol type, List types) + { + foreach (var nestedType in type.GetTypeMembers()) + { + types.Add(nestedType); + CollectNestedTypes(nestedType, types); + } + } + public string Generate(string packageName, string csharpNamespace) { var commands = DiscoverCommands(); var queries = DiscoverQueries(); var dynamicQueries = DiscoverDynamicQueries(); + var notifications = DiscoverNotifications(); + DiscoveredNotifications = notifications; var sb = new StringBuilder(); @@ -98,6 +159,24 @@ internal class ProtoFileGenerator sb.AppendLine(); } + // Notification Service (server streaming) + if (notifications.Any()) + { + sb.AppendLine("// NotificationService for real-time streaming notifications"); + sb.AppendLine("service NotificationService {"); + foreach (var notification in notifications) + { + var methodName = $"SubscribeTo{notification.Name}"; + var requestType = $"SubscribeTo{notification.Name}Request"; + + sb.AppendLine($" // Subscribe to {notification.Name} notifications"); + sb.AppendLine($" rpc {methodName} ({requestType}) returns (stream {notification.Name});"); + sb.AppendLine(); + } + sb.AppendLine("}"); + sb.AppendLine(); + } + // Generate messages for commands foreach (var command in commands) { @@ -118,7 +197,17 @@ internal class ProtoFileGenerator GenerateDynamicQueryMessages(dq); } - // Append all generated messages + // Generate messages for notifications + foreach (var notification in notifications) + { + GenerateNotificationMessages(notification); + } + + // Generate any pending enum definitions + GeneratePendingEnums(); + + // Append all generated enums first, then messages + sb.Append(_enumsBuilder); sb.Append(_messagesBuilder); // Insert imports if any were needed @@ -138,24 +227,78 @@ internal class ProtoFileGenerator private List DiscoverCommands() { - return _compilation.GetSymbolsWithName( - name => name.EndsWith("Command"), - SymbolFilter.Type) - .OfType() - .Where(t => !HasGrpcIgnoreAttribute(t)) - .Where(t => t.TypeKind == TypeKind.Class || t.TypeKind == TypeKind.Struct) - .ToList(); + // First, find all command handlers to know which commands are actually registered + var commandHandlerInterface = _compilation.GetTypeByMetadataName("Svrnty.CQRS.Abstractions.ICommandHandler`1"); + var commandHandlerWithResultInterface = _compilation.GetTypeByMetadataName("Svrnty.CQRS.Abstractions.ICommandHandler`2"); + + if (commandHandlerInterface == null && commandHandlerWithResultInterface == null) + return new List(); + + var registeredCommands = new HashSet(SymbolEqualityComparer.Default); + + foreach (var type in GetAllTypes()) + { + if (type.IsAbstract || type.IsStatic) + continue; + + foreach (var iface in type.AllInterfaces) + { + if (iface.IsGenericType) + { + if ((commandHandlerInterface != null && SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, commandHandlerInterface)) || + (commandHandlerWithResultInterface != null && SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, commandHandlerWithResultInterface))) + { + var commandType = iface.TypeArguments[0] as INamedTypeSymbol; + if (commandType != null && !HasGrpcIgnoreAttribute(commandType)) + { + registeredCommands.Add(commandType); + } + } + } + } + } + + return registeredCommands.ToList(); } private List DiscoverQueries() { - return _compilation.GetSymbolsWithName( - name => name.EndsWith("Query"), - SymbolFilter.Type) - .OfType() - .Where(t => !HasGrpcIgnoreAttribute(t)) - .Where(t => t.TypeKind == TypeKind.Class || t.TypeKind == TypeKind.Struct) - .ToList(); + // First, find all query handlers to know which queries are actually registered + var queryHandlerInterface = _compilation.GetTypeByMetadataName("Svrnty.CQRS.Abstractions.IQueryHandler`2"); + var dynamicQueryInterface2 = _compilation.GetTypeByMetadataName("Svrnty.CQRS.DynamicQuery.Abstractions.IDynamicQuery`2"); + var dynamicQueryInterface3 = _compilation.GetTypeByMetadataName("Svrnty.CQRS.DynamicQuery.Abstractions.IDynamicQuery`3"); + + if (queryHandlerInterface == null) + return new List(); + + var registeredQueries = new HashSet(SymbolEqualityComparer.Default); + + foreach (var type in GetAllTypes()) + { + if (type.IsAbstract || type.IsStatic) + continue; + + foreach (var iface in type.AllInterfaces) + { + if (iface.IsGenericType && SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, queryHandlerInterface)) + { + var queryType = iface.TypeArguments[0] as INamedTypeSymbol; + if (queryType != null && !HasGrpcIgnoreAttribute(queryType)) + { + // Skip dynamic queries - they're handled separately + if (queryType.IsGenericType && + ((dynamicQueryInterface2 != null && SymbolEqualityComparer.Default.Equals(queryType.OriginalDefinition, dynamicQueryInterface2)) || + (dynamicQueryInterface3 != null && SymbolEqualityComparer.Default.Equals(queryType.OriginalDefinition, dynamicQueryInterface3)))) + { + continue; + } + registeredQueries.Add(queryType); + } + } + } + } + + return registeredQueries.ToList(); } private bool HasGrpcIgnoreAttribute(INamedTypeSymbol type) @@ -180,6 +323,9 @@ internal class ProtoFileGenerator .Where(p => p.DeclaredAccessibility == Accessibility.Public) .ToList(); + // Collect nested complex types to generate after closing this message + var nestedComplexTypes = new List(); + int fieldNumber = 1; foreach (var prop in properties) { @@ -199,10 +345,19 @@ internal class ProtoFileGenerator var fieldName = ProtoFileTypeMapper.ToSnakeCase(prop.Name); _messagesBuilder.AppendLine($" {protoType} {fieldName} = {fieldNumber};"); - // If this is a complex type, generate its message too - if (IsComplexType(prop.Type)) + // Track enums for later generation + var enumType = ProtoFileTypeMapper.GetEnumType(prop.Type); + if (enumType != null) { - GenerateComplexTypeMessage(prop.Type as INamedTypeSymbol); + TrackEnumType(enumType); + } + + // Collect complex types to generate after this message is closed + // Use GetElementOrUnderlyingType to extract element type from collections + var underlyingType = ProtoFileTypeMapper.GetElementOrUnderlyingType(prop.Type); + if (IsComplexType(underlyingType) && underlyingType is INamedTypeSymbol namedType) + { + nestedComplexTypes.Add(namedType); } fieldNumber++; @@ -210,6 +365,12 @@ internal class ProtoFileGenerator _messagesBuilder.AppendLine("}"); _messagesBuilder.AppendLine(); + + // Now generate nested complex type messages + foreach (var nestedType in nestedComplexTypes) + { + GenerateComplexTypeMessage(nestedType); + } } private void GenerateResponseMessage(INamedTypeSymbol type) @@ -267,6 +428,9 @@ internal class ProtoFileGenerator .Where(p => p.DeclaredAccessibility == Accessibility.Public) .ToList(); + // Collect nested complex types to generate after closing this message + var nestedComplexTypes = new List(); + int fieldNumber = 1; foreach (var prop in properties) { @@ -285,10 +449,19 @@ internal class ProtoFileGenerator var fieldName = ProtoFileTypeMapper.ToSnakeCase(prop.Name); _messagesBuilder.AppendLine($" {protoType} {fieldName} = {fieldNumber};"); - // Recursively generate nested complex types - if (IsComplexType(prop.Type)) + // Track enums for later generation + var enumType = ProtoFileTypeMapper.GetEnumType(prop.Type); + if (enumType != null) { - GenerateComplexTypeMessage(prop.Type as INamedTypeSymbol); + TrackEnumType(enumType); + } + + // Collect complex types to generate after this message is closed + // Use GetElementOrUnderlyingType to extract element type from collections + var underlyingType = ProtoFileTypeMapper.GetElementOrUnderlyingType(prop.Type); + if (IsComplexType(underlyingType) && underlyingType is INamedTypeSymbol namedType) + { + nestedComplexTypes.Add(namedType); } fieldNumber++; @@ -296,6 +469,12 @@ internal class ProtoFileGenerator _messagesBuilder.AppendLine("}"); _messagesBuilder.AppendLine(); + + // Now generate nested complex type messages + foreach (var nestedType in nestedComplexTypes) + { + GenerateComplexTypeMessage(nestedType); + } } private ITypeSymbol? GetResultType(INamedTypeSymbol commandOrQueryType) @@ -305,11 +484,8 @@ internal class ProtoFileGenerator ? "ICommandHandler" : "IQueryHandler"; - // Find all types in the compilation - var allTypes = _compilation.GetSymbolsWithName(_ => true, SymbolFilter.Type) - .OfType(); - - foreach (var type in allTypes) + // Find all types in the compilation and referenced assemblies + foreach (var type in GetAllTypes()) { // Check if this type implements the handler interface foreach (var @interface in type.AllInterfaces) @@ -372,10 +548,8 @@ internal class ProtoFileGenerator return new List(); var dynamicQueryTypes = new List(); - var allTypes = _compilation.GetSymbolsWithName(_ => true, SymbolFilter.Type) - .OfType(); - foreach (var type in allTypes) + foreach (var type in GetAllTypes()) { if (type.IsAbstract || type.IsStatic) continue; @@ -471,4 +645,205 @@ internal class ProtoFileGenerator return word + "es"; return word + "s"; } + + /// + /// Tracks an enum type for later generation + /// + private void TrackEnumType(INamedTypeSymbol enumType) + { + if (!_generatedEnums.Contains(enumType.Name) && !_pendingEnums.Any(e => e.Name == enumType.Name)) + { + _pendingEnums.Add(enumType); + } + } + + /// + /// Generates all pending enum definitions + /// + private void GeneratePendingEnums() + { + foreach (var enumType in _pendingEnums) + { + if (_generatedEnums.Contains(enumType.Name)) + continue; + + _generatedEnums.Add(enumType.Name); + + _enumsBuilder.AppendLine($"// {enumType.Name} enum"); + _enumsBuilder.AppendLine($"enum {enumType.Name} {{"); + + // Get all enum members + var members = enumType.GetMembers() + .OfType() + .Where(f => f.HasConstantValue) + .ToList(); + + foreach (var member in members) + { + var protoFieldName = $"{ProtoFileTypeMapper.ToSnakeCase(enumType.Name).ToUpperInvariant()}_{ProtoFileTypeMapper.ToSnakeCase(member.Name).ToUpperInvariant()}"; + var value = member.ConstantValue; + _enumsBuilder.AppendLine($" {protoFieldName} = {value};"); + } + + _enumsBuilder.AppendLine("}"); + _enumsBuilder.AppendLine(); + } + } + + /// + /// Discovers types marked with [StreamingNotification] attribute + /// + private List DiscoverNotifications() + { + var streamingNotificationAttribute = _compilation.GetTypeByMetadataName( + "Svrnty.CQRS.Notifications.Abstractions.StreamingNotificationAttribute"); + + if (streamingNotificationAttribute == null) + return new List(); + + var notifications = new List(); + + foreach (var type in GetAllTypes()) + { + if (type.IsAbstract || type.IsStatic) + continue; + + var attr = type.GetAttributes() + .FirstOrDefault(a => SymbolEqualityComparer.Default.Equals( + a.AttributeClass, streamingNotificationAttribute)); + + if (attr == null) + continue; + + // Extract SubscriptionKey from attribute + var subscriptionKeyArg = attr.NamedArguments + .FirstOrDefault(a => a.Key == "SubscriptionKey"); + var subscriptionKeyProp = subscriptionKeyArg.Value.Value as string; + + if (string.IsNullOrEmpty(subscriptionKeyProp)) + continue; + + // Get all properties of the notification type + var properties = ExtractNotificationProperties(type); + + // Find the subscription key property info + var keyPropInfo = properties.FirstOrDefault(p => p.Name == subscriptionKeyProp); + if (keyPropInfo == null) + continue; + + notifications.Add(new NotificationInfo + { + Name = type.Name, + FullyQualifiedName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + .Replace("global::", ""), + Namespace = type.ContainingNamespace?.ToDisplayString() ?? "", + SubscriptionKeyProperty = subscriptionKeyProp, + SubscriptionKeyInfo = keyPropInfo, + Properties = properties + }); + } + + return notifications; + } + + /// + /// Extracts property information from a notification type + /// + private List ExtractNotificationProperties(INamedTypeSymbol type) + { + var properties = new List(); + int fieldNumber = 1; + + foreach (var prop in type.GetMembers().OfType() + .Where(p => p.DeclaredAccessibility == Accessibility.Public)) + { + if (ProtoFileTypeMapper.IsUnsupportedType(prop.Type)) + continue; + + var protoType = ProtoFileTypeMapper.MapType(prop.Type, out _, out _); + var enumType = ProtoFileTypeMapper.GetEnumType(prop.Type); + + properties.Add(new Models.PropertyInfo + { + Name = prop.Name, + Type = prop.Type.Name, + FullyQualifiedType = prop.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + .Replace("global::", ""), + ProtoType = protoType, + FieldNumber = fieldNumber++, + IsEnum = enumType != null, + IsDecimal = prop.Type.SpecialType == SpecialType.System_Decimal || + prop.Type.ToDisplayString().Contains("decimal"), + IsDateTime = prop.Type.ToDisplayString().Contains("DateTime"), + IsNullable = prop.Type.NullableAnnotation == NullableAnnotation.Annotated || + (prop.Type is INamedTypeSymbol namedType && + namedType.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T) + }); + + if (enumType != null) + { + TrackEnumType(enumType); + } + } + + return properties; + } + + /// + /// Generates proto messages for a notification type + /// + private void GenerateNotificationMessages(NotificationInfo notification) + { + // Generate subscription request message (contains only the subscription key) + var requestMessageName = $"SubscribeTo{notification.Name}Request"; + if (!_generatedMessages.Contains(requestMessageName)) + { + _generatedMessages.Add(requestMessageName); + + _messagesBuilder.AppendLine($"// Subscription request for {notification.Name}"); + _messagesBuilder.AppendLine($"message {requestMessageName} {{"); + _messagesBuilder.AppendLine($" {notification.SubscriptionKeyInfo.ProtoType} {ProtoFileTypeMapper.ToSnakeCase(notification.SubscriptionKeyProperty)} = 1;"); + _messagesBuilder.AppendLine("}"); + _messagesBuilder.AppendLine(); + } + + // Generate the notification message itself + if (!_generatedMessages.Contains(notification.Name)) + { + _generatedMessages.Add(notification.Name); + + _messagesBuilder.AppendLine($"// {notification.Name} streaming notification"); + _messagesBuilder.AppendLine($"message {notification.Name} {{"); + + foreach (var prop in notification.Properties) + { + var protoType = ProtoFileTypeMapper.MapType( + _compilation.GetTypeByMetadataName(prop.FullyQualifiedType) ?? + GetTypeFromName(prop.FullyQualifiedType), + out var needsImport, out var importPath); + + if (needsImport && importPath != null) + { + _requiredImports.Add(importPath); + } + + var fieldName = ProtoFileTypeMapper.ToSnakeCase(prop.Name); + _messagesBuilder.AppendLine($" {prop.ProtoType} {fieldName} = {prop.FieldNumber};"); + } + + _messagesBuilder.AppendLine("}"); + _messagesBuilder.AppendLine(); + } + } + + /// + /// Gets a type symbol from a type name by searching all types + /// + private ITypeSymbol? GetTypeFromName(string fullTypeName) + { + // Try to find the type in all types + return GetAllTypes().FirstOrDefault(t => + t.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat).Replace("global::", "") == fullTypeName || + t.ToDisplayString() == fullTypeName); + } } diff --git a/Svrnty.CQRS.Grpc.Generators/ProtoFileSourceGenerator.cs b/Svrnty.CQRS.Grpc.Generators/ProtoFileSourceGenerator.cs index d240d4a..f4c2ae1 100644 --- a/Svrnty.CQRS.Grpc.Generators/ProtoFileSourceGenerator.cs +++ b/Svrnty.CQRS.Grpc.Generators/ProtoFileSourceGenerator.cs @@ -20,24 +20,25 @@ public class ProtoFileSourceGenerator : IIncrementalGenerator // Generate a placeholder - the actual proto will be generated in the source output }); - // Collect all command and query types - var commandsAndQueries = context.SyntaxProvider + // Collect type declarations to trigger generation + // We use any type declaration as a trigger since ProtoFileGenerator scans all assemblies + var typeDeclarations = context.SyntaxProvider .CreateSyntaxProvider( - predicate: static (s, _) => IsCommandOrQuery(s), + predicate: static (s, _) => s is TypeDeclarationSyntax, transform: static (ctx, _) => GetTypeSymbol(ctx)) .Where(static m => m is not null) .Collect(); // Combine with compilation to have access to it - var compilationAndTypes = context.CompilationProvider.Combine(commandsAndQueries); + var compilationAndTypes = context.CompilationProvider.Combine(typeDeclarations); // Generate proto file when commands/queries change context.RegisterSourceOutput(compilationAndTypes, (spc, source) => { var (compilation, types) = source; - if (types.IsDefaultOrEmpty) - return; + // Note: We no longer bail out early since ProtoFileGenerator now scans all referenced assemblies + // The types from source are just a trigger - the generator will find types from all assemblies try { @@ -102,15 +103,6 @@ public class ProtoFileSourceGenerator : IIncrementalGenerator }); } - private static bool IsCommandOrQuery(SyntaxNode node) - { - if (node is not TypeDeclarationSyntax typeDecl) - return false; - - var name = typeDecl.Identifier.Text; - return name.EndsWith("Command") || name.EndsWith("Query"); - } - private static INamedTypeSymbol? GetTypeSymbol(GeneratorSyntaxContext context) { var typeDecl = (TypeDeclarationSyntax)context.Node; diff --git a/Svrnty.CQRS.Grpc.Generators/ProtoTypeMapper.cs b/Svrnty.CQRS.Grpc.Generators/ProtoTypeMapper.cs index 38dbcb1..14cfaee 100644 --- a/Svrnty.CQRS.Grpc.Generators/ProtoTypeMapper.cs +++ b/Svrnty.CQRS.Grpc.Generators/ProtoTypeMapper.cs @@ -17,11 +17,8 @@ internal static class ProtoFileTypeMapper var fullTypeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); var typeName = typeSymbol.Name; - // Nullable types - unwrap - if (typeSymbol.NullableAnnotation == NullableAnnotation.Annotated && typeSymbol is INamedTypeSymbol namedType && namedType.TypeArguments.Length > 0) - { - return MapType(namedType.TypeArguments[0], out needsImport, out importPath); - } + // Note: NullableAnnotation.Annotated is for reference type nullability (List?, string?, etc.) + // We don't unwrap these - just use the underlying type. Nullable value types are handled later. // Basic types switch (typeName) @@ -75,17 +72,31 @@ internal static class ProtoFileTypeMapper return "string"; } - if (fullTypeName.Contains("System.Decimal")) + if (fullTypeName.Contains("System.Decimal") || typeName == "Decimal" || fullTypeName == "decimal") { // Decimal serialized as string (no native decimal in proto) return "string"; } + // Handle Nullable value types (e.g., int?, decimal?, enum?) + if (typeSymbol is INamedTypeSymbol nullableType && + nullableType.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T && + nullableType.TypeArguments.Length == 1) + { + // Unwrap the nullable and map the inner type + return MapType(nullableType.TypeArguments[0], out needsImport, out importPath); + } + // Collections if (typeSymbol is INamedTypeSymbol collectionType) { - // List, IEnumerable, Array, etc. - if (collectionType.TypeArguments.Length == 1) + // List, IEnumerable, Array, ICollection etc. (but not Nullable) + var typeName2 = collectionType.Name; + if (collectionType.TypeArguments.Length == 1 && + (typeName2.Contains("List") || typeName2.Contains("Collection") || + typeName2.Contains("Enumerable") || typeName2.Contains("Array") || + typeName2.Contains("Set") || typeName2.Contains("IList") || + typeName2.Contains("ICollection") || typeName2.Contains("IEnumerable"))) { var elementType = collectionType.TypeArguments[0]; var protoElementType = MapType(elementType, out needsImport, out importPath); @@ -188,4 +199,56 @@ internal static class ProtoFileTypeMapper return false; } + + /// + /// Gets the element type from a collection type, or returns the type itself if not a collection. + /// Also unwraps Nullable types. + /// + public static ITypeSymbol GetElementOrUnderlyingType(ITypeSymbol typeSymbol) + { + // Unwrap Nullable + if (typeSymbol is INamedTypeSymbol nullableType && + nullableType.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T && + nullableType.TypeArguments.Length == 1) + { + return GetElementOrUnderlyingType(nullableType.TypeArguments[0]); + } + + // Extract element type from collections + if (typeSymbol is INamedTypeSymbol collectionType && collectionType.TypeArguments.Length == 1) + { + var typeName = collectionType.Name; + if (typeName.Contains("List") || typeName.Contains("Collection") || + typeName.Contains("Enumerable") || typeName.Contains("Array") || + typeName.Contains("Set") || typeName.Contains("IList") || + typeName.Contains("ICollection") || typeName.Contains("IEnumerable")) + { + return GetElementOrUnderlyingType(collectionType.TypeArguments[0]); + } + } + + return typeSymbol; + } + + /// + /// Checks if the type is an enum (including nullable enums) + /// + public static bool IsEnumType(ITypeSymbol typeSymbol) + { + var underlying = GetElementOrUnderlyingType(typeSymbol); + return underlying.TypeKind == TypeKind.Enum; + } + + /// + /// Gets the enum type symbol if this is an enum or nullable enum, otherwise null + /// + public static INamedTypeSymbol? GetEnumType(ITypeSymbol typeSymbol) + { + var underlying = GetElementOrUnderlyingType(typeSymbol); + if (underlying.TypeKind == TypeKind.Enum && underlying is INamedTypeSymbol enumType) + { + return enumType; + } + return null; + } } diff --git a/Svrnty.CQRS.Grpc.Generators/WriteProtoFileTask.cs b/Svrnty.CQRS.Grpc.Generators/WriteProtoFileTask.cs index aeb9b39..70788ba 100644 --- a/Svrnty.CQRS.Grpc.Generators/WriteProtoFileTask.cs +++ b/Svrnty.CQRS.Grpc.Generators/WriteProtoFileTask.cs @@ -62,7 +62,27 @@ public class WriteProtoFileTask : Task Log.LogWarning( $"Generated proto file not found at {generatedFilePath}. " + "The proto file may not have been generated yet. This is normal on first build."); - return true; // Don't fail the build, just skip + + // Write a minimal placeholder proto file so Grpc.Tools doesn't fail + // The real content will be generated on the next build + var placeholderProto = @"syntax = ""proto3""; + +option csharp_namespace = ""Generated.Grpc""; + +package cqrs; + +// Placeholder proto file - will be regenerated on next build +"; + var placeholderOutputPath = Path.Combine(ProjectDirectory, OutputDirectory); + Directory.CreateDirectory(placeholderOutputPath); + var placeholderProtoFilePath = Path.Combine(placeholderOutputPath, ProtoFileName); + File.WriteAllText(placeholderProtoFilePath, placeholderProto); + + Log.LogMessage(MessageImportance.High, + $"Svrnty.CQRS.Grpc: Wrote placeholder proto file at {placeholderProtoFilePath}. " + + "Run build again to generate the actual proto content."); + + return true; } // Read the generated C# file diff --git a/Svrnty.CQRS.Grpc/Svrnty.CQRS.Grpc.csproj b/Svrnty.CQRS.Grpc/Svrnty.CQRS.Grpc.csproj index 671a621..1508f95 100644 --- a/Svrnty.CQRS.Grpc/Svrnty.CQRS.Grpc.csproj +++ b/Svrnty.CQRS.Grpc/Svrnty.CQRS.Grpc.csproj @@ -27,7 +27,7 @@ - + diff --git a/Svrnty.CQRS.Notifications.Abstractions/INotificationPublisher.cs b/Svrnty.CQRS.Notifications.Abstractions/INotificationPublisher.cs new file mode 100644 index 0000000..3355e6c --- /dev/null +++ b/Svrnty.CQRS.Notifications.Abstractions/INotificationPublisher.cs @@ -0,0 +1,18 @@ +namespace Svrnty.CQRS.Notifications.Abstractions; + +/// +/// Publishes notifications to all subscribed gRPC clients. +/// +public interface INotificationPublisher +{ + /// + /// Publish a notification to all subscribers matching the subscription key. + /// The subscription key is extracted from the notification based on the + /// property. + /// + /// The notification type marked with . + /// The notification to publish. + /// Cancellation token. + Task PublishAsync(TNotification notification, CancellationToken ct = default) + where TNotification : class; +} diff --git a/Svrnty.CQRS.Notifications.Abstractions/StreamingNotificationAttribute.cs b/Svrnty.CQRS.Notifications.Abstractions/StreamingNotificationAttribute.cs new file mode 100644 index 0000000..15bdb5c --- /dev/null +++ b/Svrnty.CQRS.Notifications.Abstractions/StreamingNotificationAttribute.cs @@ -0,0 +1,15 @@ +namespace Svrnty.CQRS.Notifications.Abstractions; + +/// +/// Marks a record as a streaming notification that can be subscribed to via gRPC. +/// The framework will auto-generate proto definitions and service implementations. +/// +[AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = false)] +public sealed class StreamingNotificationAttribute : Attribute +{ + /// + /// The property name used as the subscription key. + /// Subscribers filter notifications by this value. + /// + public required string SubscriptionKey { get; set; } +} diff --git a/Svrnty.CQRS.Notifications.Abstractions/Svrnty.CQRS.Notifications.Abstractions.csproj b/Svrnty.CQRS.Notifications.Abstractions/Svrnty.CQRS.Notifications.Abstractions.csproj new file mode 100644 index 0000000..658c2f0 --- /dev/null +++ b/Svrnty.CQRS.Notifications.Abstractions/Svrnty.CQRS.Notifications.Abstractions.csproj @@ -0,0 +1,29 @@ + + + net10.0 + true + 14 + enable + enable + + Svrnty + David Lebee, Mathias Beaulieu-Duncan + icon.png + README.md + https://git.openharbor.io/svrnty/dotnet-cqrs + git + true + MIT + + portable + true + true + true + snupkg + + + + + + + diff --git a/Svrnty.CQRS.Notifications.Grpc/NotificationPublisher.cs b/Svrnty.CQRS.Notifications.Grpc/NotificationPublisher.cs new file mode 100644 index 0000000..c9c8a96 --- /dev/null +++ b/Svrnty.CQRS.Notifications.Grpc/NotificationPublisher.cs @@ -0,0 +1,76 @@ +using System.Collections.Concurrent; +using System.Reflection; +using Microsoft.Extensions.Logging; +using Svrnty.CQRS.Notifications.Abstractions; + +namespace Svrnty.CQRS.Notifications.Grpc; + +/// +/// Publishes notifications to subscribed gRPC clients. +/// +public class NotificationPublisher : INotificationPublisher +{ + private readonly NotificationSubscriptionManager _manager; + private readonly ILogger _logger; + + // Cache subscription key property info per notification type + private static readonly ConcurrentDictionary _keyCache = new(); + + public NotificationPublisher( + NotificationSubscriptionManager manager, + ILogger logger) + { + _manager = manager; + _logger = logger; + } + + /// + public async Task PublishAsync(TNotification notification, CancellationToken ct = default) + where TNotification : class + { + ArgumentNullException.ThrowIfNull(notification); + + var keyInfo = GetSubscriptionKeyInfo(typeof(TNotification)); + var subscriptionKey = keyInfo.Property.GetValue(notification); + + if (subscriptionKey == null) + { + _logger.LogWarning( + "Subscription key {PropertyName} is null on {NotificationType}, skipping notification", + keyInfo.PropertyName, typeof(TNotification).Name); + return; + } + + _logger.LogDebug( + "Publishing {NotificationType} with subscription key {PropertyName}={KeyValue}", + typeof(TNotification).Name, keyInfo.PropertyName, subscriptionKey); + + await _manager.NotifyAsync(notification, subscriptionKey, ct); + } + + private static SubscriptionKeyInfo GetSubscriptionKeyInfo(Type type) + { + return _keyCache.GetOrAdd(type, t => + { + var attr = t.GetCustomAttribute(); + if (attr == null) + { + throw new InvalidOperationException( + $"Type {t.Name} is not marked with [{nameof(StreamingNotificationAttribute)}]. " + + $"Add the attribute with a SubscriptionKey to enable streaming notifications."); + } + + var property = t.GetProperty(attr.SubscriptionKey); + if (property == null) + { + throw new InvalidOperationException( + $"Property '{attr.SubscriptionKey}' specified in [{nameof(StreamingNotificationAttribute)}] " + + $"was not found on type {t.Name}."); + } + + return new SubscriptionKeyInfo(attr.SubscriptionKey, property); + }); + } + + private sealed record SubscriptionKeyInfo(string PropertyName, PropertyInfo Property); +} diff --git a/Svrnty.CQRS.Notifications.Grpc/NotificationSubscriptionManager.cs b/Svrnty.CQRS.Notifications.Grpc/NotificationSubscriptionManager.cs new file mode 100644 index 0000000..58c63aa --- /dev/null +++ b/Svrnty.CQRS.Notifications.Grpc/NotificationSubscriptionManager.cs @@ -0,0 +1,164 @@ +using System.Collections.Concurrent; +using Grpc.Core; +using Microsoft.Extensions.Logging; + +namespace Svrnty.CQRS.Notifications.Grpc; + +/// +/// Manages gRPC stream subscriptions for notifications. +/// Thread-safe singleton that tracks subscriptions and routes notifications to subscribers. +/// +public class NotificationSubscriptionManager +{ + private readonly ConcurrentDictionary<(string TypeName, string Key), ConcurrentBag> _subscriptions = new(); + private readonly ILogger _logger; + + public NotificationSubscriptionManager(ILogger logger) + { + _logger = logger; + } + + /// + /// Subscribe to notifications of a specific domain type with a mapper to convert to proto format. + /// + /// The domain notification type. + /// The proto message type. + /// The subscription key value (e.g., inventory ID). + /// The gRPC server stream writer. + /// Function to map domain notification to proto message. + /// A disposable that removes the subscription when disposed. + public IDisposable Subscribe( + object subscriptionKey, + IServerStreamWriter stream, + Func mapper) where TDomain : class + { + var key = (typeof(TDomain).FullName!, subscriptionKey.ToString()!); + var subscriber = new Subscriber(stream, mapper); + var bag = _subscriptions.GetOrAdd(key, _ => new ConcurrentBag()); + bag.Add(subscriber); + + _logger.LogInformation( + "Client subscribed to {NotificationType} with key {SubscriptionKey}. Total subscribers: {Count}", + typeof(TDomain).Name, subscriptionKey, bag.Count); + + return new SubscriptionHandle(() => Remove(key, subscriber)); + } + + /// + /// Notify all subscribers of a specific notification type and subscription key. + /// + internal async Task NotifyAsync(TDomain notification, object subscriptionKey, CancellationToken ct) where TDomain : class + { + var key = (typeof(TDomain).FullName!, subscriptionKey.ToString()!); + + if (!_subscriptions.TryGetValue(key, out var subscribers)) + { + _logger.LogDebug( + "No subscribers for {NotificationType} with key {SubscriptionKey}", + typeof(TDomain).Name, subscriptionKey); + return; + } + + var deadSubscribers = new List(); + + foreach (var sub in subscribers) + { + if (sub is INotifiable notifiable) + { + try + { + await notifiable.NotifyAsync(notification, ct); + } + catch (Exception ex) + { + _logger.LogWarning(ex, + "Failed to notify subscriber for {NotificationType}, marking for removal", + typeof(TDomain).Name); + deadSubscribers.Add(sub); + } + } + } + + // Clean up dead subscribers + foreach (var dead in deadSubscribers) + { + Remove(key, dead); + } + + _logger.LogDebug( + "Notified {Count} subscribers for {NotificationType} with key {SubscriptionKey}", + subscribers.Count - deadSubscribers.Count, typeof(TDomain).Name, subscriptionKey); + } + + private void Remove((string TypeName, string Key) key, object subscriber) + { + if (_subscriptions.TryGetValue(key, out var bag)) + { + // ConcurrentBag doesn't support removal, so we rebuild + var remaining = bag.Where(s => !ReferenceEquals(s, subscriber)).ToList(); + if (remaining.Count == 0) + { + _subscriptions.TryRemove(key, out _); + } + else + { + var newBag = new ConcurrentBag(remaining); + _subscriptions.TryUpdate(key, newBag, bag); + } + + _logger.LogInformation( + "Client unsubscribed from {NotificationType} with key {SubscriptionKey}", + key.TypeName.Split('.').Last(), key.Key); + } + } +} + +/// +/// Internal interface for type-erased notification delivery. +/// +internal interface INotifiable +{ + Task NotifyAsync(TDomain notification, CancellationToken ct); +} + +/// +/// Wraps a gRPC stream writer with a domain→proto mapper. +/// +internal sealed class Subscriber : INotifiable +{ + private readonly IServerStreamWriter _stream; + private readonly Func _mapper; + + public Subscriber(IServerStreamWriter stream, Func mapper) + { + _stream = stream; + _mapper = mapper; + } + + public async Task NotifyAsync(TDomain notification, CancellationToken ct) + { + var proto = _mapper(notification); + await _stream.WriteAsync(proto, ct); + } +} + +/// +/// Handle that removes a subscription when disposed. +/// +internal sealed class SubscriptionHandle : IDisposable +{ + private readonly Action _onDispose; + private bool _disposed; + + public SubscriptionHandle(Action onDispose) + { + _onDispose = onDispose; + } + + public void Dispose() + { + if (_disposed) return; + _disposed = true; + _onDispose(); + } +} diff --git a/Svrnty.CQRS.Notifications.Grpc/ServiceCollectionExtensions.cs b/Svrnty.CQRS.Notifications.Grpc/ServiceCollectionExtensions.cs new file mode 100644 index 0000000..ea9f4a2 --- /dev/null +++ b/Svrnty.CQRS.Notifications.Grpc/ServiceCollectionExtensions.cs @@ -0,0 +1,26 @@ +using Microsoft.Extensions.DependencyInjection; +using Svrnty.CQRS.Notifications.Abstractions; + +namespace Svrnty.CQRS.Notifications.Grpc; + +/// +/// Extension methods for registering streaming notification services. +/// +public static class ServiceCollectionExtensions +{ + /// + /// Adds gRPC streaming notification services to the service collection. + /// + /// The service collection. + /// The service collection for chaining. + public static IServiceCollection AddStreamingNotifications(this IServiceCollection services) + { + // Subscription manager is singleton - shared state for all subscriptions + services.AddSingleton(); + + // Publisher can be singleton since it only depends on the manager + services.AddSingleton(); + + return services; + } +} diff --git a/Svrnty.CQRS.Notifications.Grpc/Svrnty.CQRS.Notifications.Grpc.csproj b/Svrnty.CQRS.Notifications.Grpc/Svrnty.CQRS.Notifications.Grpc.csproj new file mode 100644 index 0000000..c536be8 --- /dev/null +++ b/Svrnty.CQRS.Notifications.Grpc/Svrnty.CQRS.Notifications.Grpc.csproj @@ -0,0 +1,39 @@ + + + net10.0 + false + 14 + enable + enable + + Svrnty + Mathias Beaulieu-Duncan + icon.png + README.md + https://git.openharbor.io/svrnty/dotnet-cqrs + git + true + MIT + + portable + true + true + true + snupkg + + + + + + + + + + + + + + + + + diff --git a/Svrnty.CQRS.Sagas.Abstractions/ISaga.cs b/Svrnty.CQRS.Sagas.Abstractions/ISaga.cs new file mode 100644 index 0000000..5e9d9be --- /dev/null +++ b/Svrnty.CQRS.Sagas.Abstractions/ISaga.cs @@ -0,0 +1,14 @@ +namespace Svrnty.CQRS.Sagas.Abstractions; + +/// +/// Defines a saga with its steps and compensation logic. +/// +/// The saga's data/context type. +public interface ISaga where TData : class, ISagaData, new() +{ + /// + /// Configures the saga steps using the fluent builder. + /// + /// The saga builder for defining steps. + void Configure(ISagaBuilder builder); +} diff --git a/Svrnty.CQRS.Sagas.Abstractions/ISagaBuilder.cs b/Svrnty.CQRS.Sagas.Abstractions/ISagaBuilder.cs new file mode 100644 index 0000000..285bfc1 --- /dev/null +++ b/Svrnty.CQRS.Sagas.Abstractions/ISagaBuilder.cs @@ -0,0 +1,173 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Svrnty.CQRS.Sagas.Abstractions; + +/// +/// Fluent builder for defining saga steps. +/// +/// The saga data type. +public interface ISagaBuilder where TData : class, ISagaData +{ + /// + /// Adds a local step that executes synchronously within the orchestrator process. + /// + /// Unique name for this step. + /// Builder for configuring the step. + ISagaStepBuilder Step(string name); + + /// + /// Adds a step that sends a command to a remote service via messaging. + /// + /// The command type to send. + /// Unique name for this step. + /// Builder for configuring the remote step. + ISagaRemoteStepBuilder SendCommand(string name) where TCommand : class; + + /// + /// Adds a step that sends a command and expects a specific result. + /// + /// The command type to send. + /// The expected result type. + /// Unique name for this step. + /// Builder for configuring the remote step. + ISagaRemoteStepBuilder SendCommand(string name) where TCommand : class; +} + +/// +/// Builder for configuring a local saga step. +/// +/// The saga data type. +public interface ISagaStepBuilder where TData : class, ISagaData +{ + /// + /// Defines the execution action for this step. + /// + /// The action to execute. + /// This builder for chaining. + ISagaStepBuilder Execute(Func action); + + /// + /// Defines the compensation action for this step. + /// + /// The compensation action to execute on rollback. + /// This builder for chaining. + ISagaStepBuilder Compensate(Func action); + + /// + /// Completes this step definition and returns to the saga builder. + /// + /// The saga builder for adding more steps. + ISagaBuilder Then(); +} + +/// +/// Builder for configuring a remote command saga step (no result). +/// +/// The saga data type. +/// The command type to send. +public interface ISagaRemoteStepBuilder + where TData : class, ISagaData + where TCommand : class +{ + /// + /// Defines how to build the command from saga data. + /// + /// Function to create the command. + /// This builder for chaining. + ISagaRemoteStepBuilder WithCommand(Func commandBuilder); + + /// + /// Defines what to do when the command completes successfully. + /// + /// Handler for the response. + /// This builder for chaining. + ISagaRemoteStepBuilder OnResponse(Func handler); + + /// + /// Defines the compensation command to send on rollback. + /// + /// The compensation command type. + /// Function to create the compensation command. + /// This builder for chaining. + ISagaRemoteStepBuilder Compensate( + Func compensationBuilder) where TCompensationCommand : class; + + /// + /// Sets a timeout for this step. + /// + /// The timeout duration. + /// This builder for chaining. + ISagaRemoteStepBuilder WithTimeout(TimeSpan timeout); + + /// + /// Configures retry behavior for this step. + /// + /// Maximum number of retries. + /// Delay between retries. + /// This builder for chaining. + ISagaRemoteStepBuilder WithRetry(int maxRetries, TimeSpan delay); + + /// + /// Completes this step definition and returns to the saga builder. + /// + /// The saga builder for adding more steps. + ISagaBuilder Then(); +} + +/// +/// Builder for configuring a remote command saga step with result. +/// +/// The saga data type. +/// The command type to send. +/// The expected result type. +public interface ISagaRemoteStepBuilder + where TData : class, ISagaData + where TCommand : class +{ + /// + /// Defines how to build the command from saga data. + /// + /// Function to create the command. + /// This builder for chaining. + ISagaRemoteStepBuilder WithCommand(Func commandBuilder); + + /// + /// Defines what to do when the command completes successfully with a result. + /// + /// Handler for the response with result. + /// This builder for chaining. + ISagaRemoteStepBuilder OnResponse( + Func handler); + + /// + /// Defines the compensation command to send on rollback. + /// + /// The compensation command type. + /// Function to create the compensation command. + /// This builder for chaining. + ISagaRemoteStepBuilder Compensate( + Func compensationBuilder) where TCompensationCommand : class; + + /// + /// Sets a timeout for this step. + /// + /// The timeout duration. + /// This builder for chaining. + ISagaRemoteStepBuilder WithTimeout(TimeSpan timeout); + + /// + /// Configures retry behavior for this step. + /// + /// Maximum number of retries. + /// Delay between retries. + /// This builder for chaining. + ISagaRemoteStepBuilder WithRetry(int maxRetries, TimeSpan delay); + + /// + /// Completes this step definition and returns to the saga builder. + /// + /// The saga builder for adding more steps. + ISagaBuilder Then(); +} diff --git a/Svrnty.CQRS.Sagas.Abstractions/ISagaContext.cs b/Svrnty.CQRS.Sagas.Abstractions/ISagaContext.cs new file mode 100644 index 0000000..a1add0b --- /dev/null +++ b/Svrnty.CQRS.Sagas.Abstractions/ISagaContext.cs @@ -0,0 +1,55 @@ +using System; +using System.Collections.Generic; + +namespace Svrnty.CQRS.Sagas.Abstractions; + +/// +/// Provides context information during saga step execution. +/// +public interface ISagaContext +{ + /// + /// Unique identifier for this saga instance. + /// + Guid SagaId { get; } + + /// + /// Correlation ID for tracing across services. + /// + Guid CorrelationId { get; } + + /// + /// The fully qualified type name of the saga. + /// + string SagaType { get; } + + /// + /// Index of the current step being executed. + /// + int CurrentStepIndex { get; } + + /// + /// Name of the current step being executed. + /// + string CurrentStepName { get; } + + /// + /// Results from completed steps, keyed by step name. + /// + IReadOnlyDictionary StepResults { get; } + + /// + /// Gets a result from a previous step. + /// + /// The expected result type. + /// The name of the step. + /// The result, or default if not found. + T? GetStepResult(string stepName); + + /// + /// Sets a result for the current step (available to subsequent steps). + /// + /// The result type. + /// The result value. + void SetStepResult(T result); +} diff --git a/Svrnty.CQRS.Sagas.Abstractions/ISagaData.cs b/Svrnty.CQRS.Sagas.Abstractions/ISagaData.cs new file mode 100644 index 0000000..c26c912 --- /dev/null +++ b/Svrnty.CQRS.Sagas.Abstractions/ISagaData.cs @@ -0,0 +1,14 @@ +using System; + +namespace Svrnty.CQRS.Sagas.Abstractions; + +/// +/// Marker interface for saga data. All saga data classes must implement this interface. +/// +public interface ISagaData +{ + /// + /// Correlation ID for tracing the saga across services. + /// + Guid CorrelationId { get; set; } +} diff --git a/Svrnty.CQRS.Sagas.Abstractions/ISagaOrchestrator.cs b/Svrnty.CQRS.Sagas.Abstractions/ISagaOrchestrator.cs new file mode 100644 index 0000000..4f57dc3 --- /dev/null +++ b/Svrnty.CQRS.Sagas.Abstractions/ISagaOrchestrator.cs @@ -0,0 +1,52 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Svrnty.CQRS.Sagas.Abstractions; + +/// +/// Orchestrates saga execution. +/// +public interface ISagaOrchestrator +{ + /// + /// Starts a new saga instance with a generated correlation ID. + /// + /// The saga type. + /// The saga data type. + /// The initial saga data. + /// Cancellation token. + /// The saga state. + Task StartAsync(TData initialData, CancellationToken cancellationToken = default) + where TSaga : ISaga + where TData : class, ISagaData, new(); + + /// + /// Starts a new saga instance with a specific correlation ID. + /// + /// The saga type. + /// The saga data type. + /// The initial saga data. + /// The correlation ID for tracing. + /// Cancellation token. + /// The saga state. + Task StartAsync(TData initialData, Guid correlationId, CancellationToken cancellationToken = default) + where TSaga : ISaga + where TData : class, ISagaData, new(); + + /// + /// Gets the current state of a saga by its ID. + /// + /// The saga instance ID. + /// Cancellation token. + /// The saga state, or null if not found. + Task GetStateAsync(Guid sagaId, CancellationToken cancellationToken = default); + + /// + /// Gets the current state of a saga by its correlation ID. + /// + /// The correlation ID. + /// Cancellation token. + /// The saga state, or null if not found. + Task GetStateByCorrelationIdAsync(Guid correlationId, CancellationToken cancellationToken = default); +} diff --git a/Svrnty.CQRS.Sagas.Abstractions/Messaging/ISagaMessageBus.cs b/Svrnty.CQRS.Sagas.Abstractions/Messaging/ISagaMessageBus.cs new file mode 100644 index 0000000..6486d67 --- /dev/null +++ b/Svrnty.CQRS.Sagas.Abstractions/Messaging/ISagaMessageBus.cs @@ -0,0 +1,44 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Svrnty.CQRS.Sagas.Abstractions.Messaging; + +/// +/// Abstraction for saga messaging transport. +/// +public interface ISagaMessageBus +{ + /// + /// Publishes a saga command message to the message bus. + /// + /// The message to publish. + /// Cancellation token. + Task PublishAsync(SagaMessage message, CancellationToken cancellationToken = default); + + /// + /// Publishes a saga step response to the message bus. + /// + /// The response to publish. + /// Cancellation token. + Task PublishResponseAsync(SagaStepResponse response, CancellationToken cancellationToken = default); + + /// + /// Subscribes to saga messages for a specific command type. + /// + /// The command type to subscribe to. + /// Handler that processes the message and returns a response. + /// Cancellation token. + Task SubscribeAsync( + Func> handler, + CancellationToken cancellationToken = default) where TCommand : class; + + /// + /// Subscribes to saga step responses. + /// + /// Handler that processes responses. + /// Cancellation token. + Task SubscribeToResponsesAsync( + Func handler, + CancellationToken cancellationToken = default); +} diff --git a/Svrnty.CQRS.Sagas.Abstractions/Messaging/SagaMessage.cs b/Svrnty.CQRS.Sagas.Abstractions/Messaging/SagaMessage.cs new file mode 100644 index 0000000..330f8dc --- /dev/null +++ b/Svrnty.CQRS.Sagas.Abstractions/Messaging/SagaMessage.cs @@ -0,0 +1,55 @@ +using System; +using System.Collections.Generic; + +namespace Svrnty.CQRS.Sagas.Abstractions.Messaging; + +/// +/// Message envelope for saga commands sent to remote services. +/// +public record SagaMessage +{ + /// + /// Unique identifier for this message. + /// + public Guid MessageId { get; init; } = Guid.NewGuid(); + + /// + /// The saga instance ID. + /// + public Guid SagaId { get; init; } + + /// + /// Correlation ID for tracing across services. + /// + public Guid CorrelationId { get; init; } + + /// + /// Name of the saga step that sent this message. + /// + public string StepName { get; init; } = string.Empty; + + /// + /// Fully qualified type name of the command. + /// + public string CommandType { get; init; } = string.Empty; + + /// + /// Serialized command payload (JSON). + /// + public string? Payload { get; init; } + + /// + /// When the message was created. + /// + public DateTimeOffset Timestamp { get; init; } = DateTimeOffset.UtcNow; + + /// + /// Additional headers for the message. + /// + public Dictionary Headers { get; init; } = new(); + + /// + /// Whether this is a compensation command. + /// + public bool IsCompensation { get; init; } +} diff --git a/Svrnty.CQRS.Sagas.Abstractions/Messaging/SagaStepResponse.cs b/Svrnty.CQRS.Sagas.Abstractions/Messaging/SagaStepResponse.cs new file mode 100644 index 0000000..c3142a6 --- /dev/null +++ b/Svrnty.CQRS.Sagas.Abstractions/Messaging/SagaStepResponse.cs @@ -0,0 +1,59 @@ +using System; + +namespace Svrnty.CQRS.Sagas.Abstractions.Messaging; + +/// +/// Response message from a saga step execution. +/// +public record SagaStepResponse +{ + /// + /// Unique identifier for this response. + /// + public Guid MessageId { get; init; } = Guid.NewGuid(); + + /// + /// The saga instance ID. + /// + public Guid SagaId { get; init; } + + /// + /// Correlation ID for tracing across services. + /// + public Guid CorrelationId { get; init; } + + /// + /// Name of the saga step that this response is for. + /// + public string StepName { get; init; } = string.Empty; + + /// + /// Whether the step executed successfully. + /// + public bool Success { get; init; } + + /// + /// Fully qualified type name of the result (if any). + /// + public string? ResultType { get; init; } + + /// + /// Serialized result payload (JSON). + /// + public string? ResultPayload { get; init; } + + /// + /// Error message if the step failed. + /// + public string? ErrorMessage { get; init; } + + /// + /// Stack trace if the step failed (for debugging). + /// + public string? StackTrace { get; init; } + + /// + /// When the response was created. + /// + public DateTimeOffset Timestamp { get; init; } = DateTimeOffset.UtcNow; +} diff --git a/Svrnty.CQRS.Sagas.Abstractions/Persistence/ISagaStateStore.cs b/Svrnty.CQRS.Sagas.Abstractions/Persistence/ISagaStateStore.cs new file mode 100644 index 0000000..b4f6d4e --- /dev/null +++ b/Svrnty.CQRS.Sagas.Abstractions/Persistence/ISagaStateStore.cs @@ -0,0 +1,59 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Svrnty.CQRS.Sagas.Abstractions.Persistence; + +/// +/// Abstraction for saga state persistence. +/// +public interface ISagaStateStore +{ + /// + /// Creates a new saga state entry. + /// + /// The saga state to create. + /// Cancellation token. + /// The created saga state. + Task CreateAsync(SagaState state, CancellationToken cancellationToken = default); + + /// + /// Gets a saga state by its ID. + /// + /// The saga instance ID. + /// Cancellation token. + /// The saga state, or null if not found. + Task GetByIdAsync(Guid sagaId, CancellationToken cancellationToken = default); + + /// + /// Gets a saga state by its correlation ID. + /// + /// The correlation ID. + /// Cancellation token. + /// The saga state, or null if not found. + Task GetByCorrelationIdAsync(Guid correlationId, CancellationToken cancellationToken = default); + + /// + /// Updates an existing saga state. + /// + /// The saga state to update. + /// Cancellation token. + /// The updated saga state. + Task UpdateAsync(SagaState state, CancellationToken cancellationToken = default); + + /// + /// Gets all pending (in progress) sagas. + /// + /// Cancellation token. + /// List of pending saga states. + Task> GetPendingSagasAsync(CancellationToken cancellationToken = default); + + /// + /// Gets all sagas with a specific status. + /// + /// The status to filter by. + /// Cancellation token. + /// List of saga states with the specified status. + Task> GetSagasByStatusAsync(SagaStatus status, CancellationToken cancellationToken = default); +} diff --git a/Svrnty.CQRS.Sagas.Abstractions/SagaState.cs b/Svrnty.CQRS.Sagas.Abstractions/SagaState.cs new file mode 100644 index 0000000..d60cd57 --- /dev/null +++ b/Svrnty.CQRS.Sagas.Abstractions/SagaState.cs @@ -0,0 +1,85 @@ +using System; +using System.Collections.Generic; + +namespace Svrnty.CQRS.Sagas.Abstractions; + +/// +/// Represents the persistent state of a saga instance. +/// +public class SagaState +{ + /// + /// Unique identifier for this saga instance. + /// + public Guid SagaId { get; set; } = Guid.NewGuid(); + + /// + /// The fully qualified type name of the saga. + /// + public string SagaType { get; set; } = string.Empty; + + /// + /// Correlation ID for tracing across services. + /// + public Guid CorrelationId { get; set; } + + /// + /// Current execution status. + /// + public SagaStatus Status { get; set; } = SagaStatus.NotStarted; + + /// + /// Index of the current step being executed. + /// + public int CurrentStepIndex { get; set; } + + /// + /// Name of the current step being executed. + /// + public string? CurrentStepName { get; set; } + + /// + /// Results from completed steps, keyed by step name. + /// + public Dictionary StepResults { get; set; } = new(); + + /// + /// Names of steps that have been completed. + /// + public List CompletedSteps { get; set; } = new(); + + /// + /// Errors that occurred during saga execution. + /// + public List Errors { get; set; } = new(); + + /// + /// Serialized saga data (JSON). + /// + public string? SerializedData { get; set; } + + /// + /// When the saga was created. + /// + public DateTimeOffset CreatedAt { get; set; } = DateTimeOffset.UtcNow; + + /// + /// When the saga was last updated. + /// + public DateTimeOffset? UpdatedAt { get; set; } + + /// + /// When the saga completed (successfully or compensated). + /// + public DateTimeOffset? CompletedAt { get; set; } +} + +/// +/// Represents an error that occurred during saga step execution. +/// +public record SagaStepError( + string StepName, + string ErrorMessage, + string? StackTrace, + DateTimeOffset OccurredAt +); diff --git a/Svrnty.CQRS.Sagas.Abstractions/SagaStatus.cs b/Svrnty.CQRS.Sagas.Abstractions/SagaStatus.cs new file mode 100644 index 0000000..3db90df --- /dev/null +++ b/Svrnty.CQRS.Sagas.Abstractions/SagaStatus.cs @@ -0,0 +1,37 @@ +namespace Svrnty.CQRS.Sagas.Abstractions; + +/// +/// Represents the execution state of a saga. +/// +public enum SagaStatus +{ + /// + /// Saga has not started execution. + /// + NotStarted, + + /// + /// Saga is currently executing steps. + /// + InProgress, + + /// + /// Saga completed successfully. + /// + Completed, + + /// + /// Saga failed and compensation has not been triggered. + /// + Failed, + + /// + /// Saga is currently executing compensation steps. + /// + Compensating, + + /// + /// Saga has been compensated (rolled back) successfully. + /// + Compensated +} diff --git a/Svrnty.CQRS.Sagas.Abstractions/Svrnty.CQRS.Sagas.Abstractions.csproj b/Svrnty.CQRS.Sagas.Abstractions/Svrnty.CQRS.Sagas.Abstractions.csproj new file mode 100644 index 0000000..f74d835 --- /dev/null +++ b/Svrnty.CQRS.Sagas.Abstractions/Svrnty.CQRS.Sagas.Abstractions.csproj @@ -0,0 +1,28 @@ + + + net10.0 + true + 14 + enable + + Svrnty + David Lebee, Mathias Beaulieu-Duncan + icon.png + README.md + https://git.openharbor.io/svrnty/dotnet-cqrs + git + true + MIT + + portable + true + true + true + snupkg + + + + + + + diff --git a/Svrnty.CQRS.Sagas.RabbitMQ/CqrsBuilderExtensions.cs b/Svrnty.CQRS.Sagas.RabbitMQ/CqrsBuilderExtensions.cs new file mode 100644 index 0000000..c0a6a92 --- /dev/null +++ b/Svrnty.CQRS.Sagas.RabbitMQ/CqrsBuilderExtensions.cs @@ -0,0 +1,60 @@ +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Svrnty.CQRS.Configuration; +using Svrnty.CQRS.Sagas.Abstractions.Messaging; + +namespace Svrnty.CQRS.Sagas.RabbitMQ; + +/// +/// Extensions for adding RabbitMQ saga transport to the CQRS pipeline. +/// +public static class CqrsBuilderExtensions +{ + /// + /// Uses RabbitMQ as the message transport for sagas. + /// + /// The CQRS builder. + /// Configuration action for RabbitMQ options. + /// The CQRS builder for chaining. + public static CqrsBuilder UseRabbitMq(this CqrsBuilder builder, Action configure) + { + var options = new RabbitMqSagaOptions(); + configure(options); + + builder.Services.Configure(opt => + { + opt.HostName = options.HostName; + opt.Port = options.Port; + opt.UserName = options.UserName; + opt.Password = options.Password; + opt.VirtualHost = options.VirtualHost; + opt.CommandExchange = options.CommandExchange; + opt.ResponseExchange = options.ResponseExchange; + opt.QueuePrefix = options.QueuePrefix; + opt.DurableQueues = options.DurableQueues; + opt.PrefetchCount = options.PrefetchCount; + opt.ConnectionRetryDelay = options.ConnectionRetryDelay; + opt.MaxConnectionRetries = options.MaxConnectionRetries; + }); + + // Replace the default message bus with RabbitMQ implementation + builder.Services.RemoveAll(); + builder.Services.AddSingleton(); + + // Add hosted service for connection management + builder.Services.AddHostedService(); + + return builder; + } + + /// + /// Uses RabbitMQ as the message transport for sagas with default options. + /// + /// The CQRS builder. + /// The CQRS builder for chaining. + public static CqrsBuilder UseRabbitMq(this CqrsBuilder builder) + { + return builder.UseRabbitMq(_ => { }); + } +} diff --git a/Svrnty.CQRS.Sagas.RabbitMQ/RabbitMqSagaHostedService.cs b/Svrnty.CQRS.Sagas.RabbitMQ/RabbitMqSagaHostedService.cs new file mode 100644 index 0000000..6cb5129 --- /dev/null +++ b/Svrnty.CQRS.Sagas.RabbitMQ/RabbitMqSagaHostedService.cs @@ -0,0 +1,88 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Svrnty.CQRS.Sagas.Abstractions; +using Svrnty.CQRS.Sagas.Abstractions.Messaging; + +namespace Svrnty.CQRS.Sagas.RabbitMQ; + +/// +/// Hosted service that manages RabbitMQ saga connections and subscriptions. +/// +public class RabbitMqSagaHostedService : BackgroundService +{ + private readonly IServiceProvider _serviceProvider; + private readonly ISagaMessageBus _messageBus; + private readonly ILogger _logger; + + /// + /// Creates a new RabbitMQ saga hosted service. + /// + public RabbitMqSagaHostedService( + IServiceProvider serviceProvider, + ISagaMessageBus messageBus, + ILogger logger) + { + _serviceProvider = serviceProvider; + _messageBus = messageBus; + _logger = logger; + } + + /// + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + _logger.LogInformation("Starting RabbitMQ saga hosted service"); + + try + { + // Subscribe to saga responses so the orchestrator can process them + await _messageBus.SubscribeToResponsesAsync( + async (response, ct) => + { + using var scope = _serviceProvider.CreateScope(); + var orchestrator = scope.ServiceProvider.GetRequiredService(); + + // The orchestrator needs to handle responses + // This is a simplified approach - in production you'd want to handle this more robustly + _logger.LogDebug( + "Received response for saga {SagaId}, step {StepName}, success: {Success}", + response.SagaId, response.StepName, response.Success); + + // For now, we just log the response + // The orchestrator's HandleResponseAsync method would be called here + // but it requires knowing the saga data type, which we don't have in this context + }, + stoppingToken); + + _logger.LogInformation("RabbitMQ saga hosted service started successfully"); + + // Keep the service running + await Task.Delay(Timeout.Infinite, stoppingToken); + } + catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested) + { + _logger.LogInformation("RabbitMQ saga hosted service is stopping"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error in RabbitMQ saga hosted service"); + throw; + } + } + + /// + public override async Task StopAsync(CancellationToken cancellationToken) + { + _logger.LogInformation("Stopping RabbitMQ saga hosted service"); + + if (_messageBus is IAsyncDisposable disposable) + { + await disposable.DisposeAsync(); + } + + await base.StopAsync(cancellationToken); + } +} diff --git a/Svrnty.CQRS.Sagas.RabbitMQ/RabbitMqSagaMessageBus.cs b/Svrnty.CQRS.Sagas.RabbitMQ/RabbitMqSagaMessageBus.cs new file mode 100644 index 0000000..e48aac2 --- /dev/null +++ b/Svrnty.CQRS.Sagas.RabbitMQ/RabbitMqSagaMessageBus.cs @@ -0,0 +1,335 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using RabbitMQ.Client; +using RabbitMQ.Client.Events; +using Svrnty.CQRS.Sagas.Abstractions.Messaging; + +namespace Svrnty.CQRS.Sagas.RabbitMQ; + +/// +/// RabbitMQ implementation of the saga message bus. +/// +public class RabbitMqSagaMessageBus : ISagaMessageBus, IAsyncDisposable +{ + private readonly RabbitMqSagaOptions _options; + private readonly ILogger _logger; + private IConnection? _connection; + private IChannel? _publishChannel; + private readonly ConcurrentDictionary _subscriptionChannels = new(); + private readonly SemaphoreSlim _connectionLock = new(1, 1); + private bool _disposed; + + /// + /// Creates a new RabbitMQ saga message bus. + /// + public RabbitMqSagaMessageBus( + IOptions options, + ILogger logger) + { + _options = options.Value; + _logger = logger; + } + + /// + public async Task PublishAsync(SagaMessage message, CancellationToken cancellationToken = default) + { + await EnsureConnectionAsync(cancellationToken); + + var routingKey = $"saga.command.{message.CommandType}"; + var body = JsonSerializer.SerializeToUtf8Bytes(message); + + var properties = new BasicProperties + { + MessageId = message.MessageId.ToString(), + CorrelationId = message.CorrelationId.ToString(), + ContentType = "application/json", + DeliveryMode = _options.DurableQueues ? DeliveryModes.Persistent : DeliveryModes.Transient, + Timestamp = new AmqpTimestamp(message.Timestamp.ToUnixTimeSeconds()), + Headers = new Dictionary + { + ["saga-id"] = message.SagaId.ToString(), + ["step-name"] = message.StepName, + ["is-compensation"] = message.IsCompensation.ToString() + } + }; + + await _publishChannel!.BasicPublishAsync( + exchange: _options.CommandExchange, + routingKey: routingKey, + mandatory: false, + basicProperties: properties, + body: body, + cancellationToken: cancellationToken); + + _logger.LogDebug( + "Published saga command {CommandType} for saga {SagaId}, step {StepName}", + message.CommandType, message.SagaId, message.StepName); + } + + /// + public async Task PublishResponseAsync(SagaStepResponse response, CancellationToken cancellationToken = default) + { + await EnsureConnectionAsync(cancellationToken); + + var routingKey = $"saga.response.{response.SagaId}"; + var body = JsonSerializer.SerializeToUtf8Bytes(response); + + var properties = new BasicProperties + { + MessageId = response.MessageId.ToString(), + CorrelationId = response.CorrelationId.ToString(), + ContentType = "application/json", + DeliveryMode = _options.DurableQueues ? DeliveryModes.Persistent : DeliveryModes.Transient, + Timestamp = new AmqpTimestamp(response.Timestamp.ToUnixTimeSeconds()), + Headers = new Dictionary + { + ["saga-id"] = response.SagaId.ToString(), + ["step-name"] = response.StepName, + ["success"] = response.Success.ToString() + } + }; + + await _publishChannel!.BasicPublishAsync( + exchange: _options.ResponseExchange, + routingKey: routingKey, + mandatory: false, + basicProperties: properties, + body: body, + cancellationToken: cancellationToken); + + _logger.LogDebug( + "Published saga response for saga {SagaId}, step {StepName}, success: {Success}", + response.SagaId, response.StepName, response.Success); + } + + /// + public async Task SubscribeAsync( + Func> handler, + CancellationToken cancellationToken = default) + where TCommand : class + { + await EnsureConnectionAsync(cancellationToken); + + var commandTypeName = typeof(TCommand).FullName!; + var queueName = $"{_options.QueuePrefix}.{SanitizeQueueName(commandTypeName)}"; + var routingKey = $"saga.command.{commandTypeName}"; + + var channel = await _connection!.CreateChannelAsync(cancellationToken: cancellationToken); + _subscriptionChannels[commandTypeName] = channel; + + // Declare queue + await channel.QueueDeclareAsync( + queue: queueName, + durable: _options.DurableQueues, + exclusive: false, + autoDelete: false, + cancellationToken: cancellationToken); + + // Bind to command exchange + await channel.QueueBindAsync( + queue: queueName, + exchange: _options.CommandExchange, + routingKey: routingKey, + cancellationToken: cancellationToken); + + await channel.BasicQosAsync(prefetchSize: 0, prefetchCount: _options.PrefetchCount, global: false, cancellationToken: cancellationToken); + + var consumer = new AsyncEventingBasicConsumer(channel); + consumer.ReceivedAsync += async (sender, ea) => + { + try + { + var messageJson = Encoding.UTF8.GetString(ea.Body.ToArray()); + var message = JsonSerializer.Deserialize(messageJson); + + if (message == null) + { + _logger.LogWarning("Received null saga message"); + await channel.BasicNackAsync(ea.DeliveryTag, false, false, cancellationToken); + return; + } + + var command = JsonSerializer.Deserialize(message.Payload!); + if (command == null) + { + _logger.LogWarning("Failed to deserialize command {CommandType}", commandTypeName); + await channel.BasicNackAsync(ea.DeliveryTag, false, false, cancellationToken); + return; + } + + var response = await handler(message, command, cancellationToken); + await PublishResponseAsync(response, cancellationToken); + await channel.BasicAckAsync(ea.DeliveryTag, false, cancellationToken); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error processing saga command {CommandType}", commandTypeName); + await channel.BasicNackAsync(ea.DeliveryTag, false, true, cancellationToken); + } + }; + + await channel.BasicConsumeAsync(queueName, false, consumer, cancellationToken); + + _logger.LogInformation( + "Subscribed to saga commands of type {CommandType} on queue {QueueName}", + commandTypeName, queueName); + } + + /// + public async Task SubscribeToResponsesAsync( + Func handler, + CancellationToken cancellationToken = default) + { + await EnsureConnectionAsync(cancellationToken); + + var queueName = $"{_options.QueuePrefix}.responses"; + var routingKey = "saga.response.#"; + + var channel = await _connection!.CreateChannelAsync(cancellationToken: cancellationToken); + _subscriptionChannels["responses"] = channel; + + // Declare queue + await channel.QueueDeclareAsync( + queue: queueName, + durable: _options.DurableQueues, + exclusive: false, + autoDelete: false, + cancellationToken: cancellationToken); + + // Bind to response exchange + await channel.QueueBindAsync( + queue: queueName, + exchange: _options.ResponseExchange, + routingKey: routingKey, + cancellationToken: cancellationToken); + + await channel.BasicQosAsync(prefetchSize: 0, prefetchCount: _options.PrefetchCount, global: false, cancellationToken: cancellationToken); + + var consumer = new AsyncEventingBasicConsumer(channel); + consumer.ReceivedAsync += async (sender, ea) => + { + try + { + var responseJson = Encoding.UTF8.GetString(ea.Body.ToArray()); + var response = JsonSerializer.Deserialize(responseJson); + + if (response == null) + { + _logger.LogWarning("Received null saga response"); + await channel.BasicNackAsync(ea.DeliveryTag, false, false, cancellationToken); + return; + } + + await handler(response, cancellationToken); + await channel.BasicAckAsync(ea.DeliveryTag, false, cancellationToken); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error processing saga response"); + await channel.BasicNackAsync(ea.DeliveryTag, false, true, cancellationToken); + } + }; + + await channel.BasicConsumeAsync(queueName, false, consumer, cancellationToken); + + _logger.LogInformation("Subscribed to saga responses on queue {QueueName}", queueName); + } + + private async Task EnsureConnectionAsync(CancellationToken cancellationToken) + { + if (_connection?.IsOpen == true && _publishChannel?.IsOpen == true) + { + return; + } + + await _connectionLock.WaitAsync(cancellationToken); + try + { + if (_connection?.IsOpen == true && _publishChannel?.IsOpen == true) + { + return; + } + + var factory = new ConnectionFactory + { + HostName = _options.HostName, + Port = _options.Port, + UserName = _options.UserName, + Password = _options.Password, + VirtualHost = _options.VirtualHost + }; + + _connection = await factory.CreateConnectionAsync(cancellationToken); + _publishChannel = await _connection.CreateChannelAsync(cancellationToken: cancellationToken); + + // Declare exchanges + await _publishChannel.ExchangeDeclareAsync( + exchange: _options.CommandExchange, + type: ExchangeType.Topic, + durable: _options.DurableQueues, + autoDelete: false, + cancellationToken: cancellationToken); + + await _publishChannel.ExchangeDeclareAsync( + exchange: _options.ResponseExchange, + type: ExchangeType.Topic, + durable: _options.DurableQueues, + autoDelete: false, + cancellationToken: cancellationToken); + + _logger.LogInformation( + "Connected to RabbitMQ at {Host}:{Port}", + _options.HostName, _options.Port); + } + finally + { + _connectionLock.Release(); + } + } + + private static string SanitizeQueueName(string name) + { + return name.Replace(".", "-").Replace("+", "-").ToLowerInvariant(); + } + + /// + public async ValueTask DisposeAsync() + { + if (_disposed) + { + return; + } + + _disposed = true; + + foreach (var channel in _subscriptionChannels.Values) + { + if (channel.IsOpen) + { + await channel.CloseAsync(); + } + channel.Dispose(); + } + + if (_publishChannel?.IsOpen == true) + { + await _publishChannel.CloseAsync(); + } + _publishChannel?.Dispose(); + + if (_connection?.IsOpen == true) + { + await _connection.CloseAsync(); + } + _connection?.Dispose(); + + _connectionLock.Dispose(); + } +} diff --git a/Svrnty.CQRS.Sagas.RabbitMQ/RabbitMqSagaOptions.cs b/Svrnty.CQRS.Sagas.RabbitMQ/RabbitMqSagaOptions.cs new file mode 100644 index 0000000..ffbde56 --- /dev/null +++ b/Svrnty.CQRS.Sagas.RabbitMQ/RabbitMqSagaOptions.cs @@ -0,0 +1,69 @@ +using System; + +namespace Svrnty.CQRS.Sagas.RabbitMQ; + +/// +/// Configuration options for RabbitMQ saga transport. +/// +public class RabbitMqSagaOptions +{ + /// + /// RabbitMQ host name (default: localhost). + /// + public string HostName { get; set; } = "localhost"; + + /// + /// RabbitMQ port (default: 5672). + /// + public int Port { get; set; } = 5672; + + /// + /// RabbitMQ user name (default: guest). + /// + public string UserName { get; set; } = "guest"; + + /// + /// RabbitMQ password (default: guest). + /// + public string Password { get; set; } = "guest"; + + /// + /// RabbitMQ virtual host (default: /). + /// + public string VirtualHost { get; set; } = "/"; + + /// + /// Exchange name for saga commands (default: svrnty.sagas.commands). + /// + public string CommandExchange { get; set; } = "svrnty.sagas.commands"; + + /// + /// Exchange name for saga responses (default: svrnty.sagas.responses). + /// + public string ResponseExchange { get; set; } = "svrnty.sagas.responses"; + + /// + /// Queue name prefix for this service (default: saga-service). + /// + public string QueuePrefix { get; set; } = "saga-service"; + + /// + /// Whether to use durable queues (default: true). + /// + public bool DurableQueues { get; set; } = true; + + /// + /// Prefetch count for consumers (default: 10). + /// + public ushort PrefetchCount { get; set; } = 10; + + /// + /// Connection retry delay (default: 5 seconds). + /// + public TimeSpan ConnectionRetryDelay { get; set; } = TimeSpan.FromSeconds(5); + + /// + /// Maximum connection retry attempts (default: 10). + /// + public int MaxConnectionRetries { get; set; } = 10; +} diff --git a/Svrnty.CQRS.Sagas.RabbitMQ/Svrnty.CQRS.Sagas.RabbitMQ.csproj b/Svrnty.CQRS.Sagas.RabbitMQ/Svrnty.CQRS.Sagas.RabbitMQ.csproj new file mode 100644 index 0000000..8cc41f8 --- /dev/null +++ b/Svrnty.CQRS.Sagas.RabbitMQ/Svrnty.CQRS.Sagas.RabbitMQ.csproj @@ -0,0 +1,38 @@ + + + net10.0 + false + 14 + enable + + Svrnty + David Lebee, Mathias Beaulieu-Duncan + icon.png + README.md + https://git.openharbor.io/svrnty/dotnet-cqrs + git + true + MIT + + portable + true + true + true + snupkg + + + + + + + + + + + + + + + + + diff --git a/Svrnty.CQRS.Sagas/Builders/LocalSagaStepBuilder.cs b/Svrnty.CQRS.Sagas/Builders/LocalSagaStepBuilder.cs new file mode 100644 index 0000000..ae73dca --- /dev/null +++ b/Svrnty.CQRS.Sagas/Builders/LocalSagaStepBuilder.cs @@ -0,0 +1,54 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Svrnty.CQRS.Sagas.Abstractions; + +namespace Svrnty.CQRS.Sagas.Builders; + +/// +/// Builder for configuring local saga steps. +/// +/// The saga data type. +public class LocalSagaStepBuilder : ISagaStepBuilder + where TData : class, ISagaData +{ + private readonly SagaBuilder _parent; + private readonly LocalSagaStepDefinition _definition; + + /// + /// Creates a new local step builder. + /// + /// The parent saga builder. + /// The step name. + /// The step order. + public LocalSagaStepBuilder(SagaBuilder parent, string name, int order) + { + _parent = parent; + _definition = new LocalSagaStepDefinition + { + Name = name, + Order = order + }; + } + + /// + public ISagaStepBuilder Execute(Func action) + { + _definition.ExecuteAction = action; + return this; + } + + /// + public ISagaStepBuilder Compensate(Func action) + { + _definition.CompensateAction = action; + return this; + } + + /// + public ISagaBuilder Then() + { + _parent.AddStep(_definition); + return _parent; + } +} diff --git a/Svrnty.CQRS.Sagas/Builders/RemoteSagaStepBuilder.cs b/Svrnty.CQRS.Sagas/Builders/RemoteSagaStepBuilder.cs new file mode 100644 index 0000000..d5efb35 --- /dev/null +++ b/Svrnty.CQRS.Sagas/Builders/RemoteSagaStepBuilder.cs @@ -0,0 +1,158 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Svrnty.CQRS.Sagas.Abstractions; + +namespace Svrnty.CQRS.Sagas.Builders; + +/// +/// Builder for configuring remote saga steps (without result). +/// +/// The saga data type. +/// The command type. +public class RemoteSagaStepBuilder : ISagaRemoteStepBuilder + where TData : class, ISagaData + where TCommand : class +{ + private readonly SagaBuilder _parent; + private readonly RemoteSagaStepDefinition _definition; + + /// + /// Creates a new remote step builder. + /// + /// The parent saga builder. + /// The step name. + /// The step order. + public RemoteSagaStepBuilder(SagaBuilder parent, string name, int order) + { + _parent = parent; + _definition = new RemoteSagaStepDefinition + { + Name = name, + Order = order + }; + } + + /// + public ISagaRemoteStepBuilder WithCommand(Func commandBuilder) + { + _definition.CommandBuilder = commandBuilder; + return this; + } + + /// + public ISagaRemoteStepBuilder OnResponse(Func handler) + { + _definition.ResponseHandler = handler; + return this; + } + + /// + public ISagaRemoteStepBuilder Compensate( + Func compensationBuilder) + where TCompensationCommand : class + { + _definition.CompensationCommandType = typeof(TCompensationCommand); + _definition.CompensationBuilder = (data, ctx) => compensationBuilder(data, ctx); + return this; + } + + /// + public ISagaRemoteStepBuilder WithTimeout(TimeSpan timeout) + { + _definition.Timeout = timeout; + return this; + } + + /// + public ISagaRemoteStepBuilder WithRetry(int maxRetries, TimeSpan delay) + { + _definition.MaxRetries = maxRetries; + _definition.RetryDelay = delay; + return this; + } + + /// + public ISagaBuilder Then() + { + _parent.AddStep(_definition); + return _parent; + } +} + +/// +/// Builder for configuring remote saga steps with result. +/// +/// The saga data type. +/// The command type. +/// The result type. +public class RemoteSagaStepBuilderWithResult : ISagaRemoteStepBuilder + where TData : class, ISagaData + where TCommand : class +{ + private readonly SagaBuilder _parent; + private readonly RemoteSagaStepDefinition _definition; + + /// + /// Creates a new remote step builder with result. + /// + /// The parent saga builder. + /// The step name. + /// The step order. + public RemoteSagaStepBuilderWithResult(SagaBuilder parent, string name, int order) + { + _parent = parent; + _definition = new RemoteSagaStepDefinition + { + Name = name, + Order = order + }; + } + + /// + public ISagaRemoteStepBuilder WithCommand(Func commandBuilder) + { + _definition.CommandBuilder = commandBuilder; + return this; + } + + /// + public ISagaRemoteStepBuilder OnResponse( + Func handler) + { + _definition.ResponseHandler = handler; + return this; + } + + /// + public ISagaRemoteStepBuilder Compensate( + Func compensationBuilder) + where TCompensationCommand : class + { + _definition.CompensationCommandType = typeof(TCompensationCommand); + _definition.CompensationBuilder = (data, ctx) => compensationBuilder(data, ctx); + return this; + } + + /// + public ISagaRemoteStepBuilder WithTimeout(TimeSpan timeout) + { + _definition.Timeout = timeout; + return this; + } + + /// + public ISagaRemoteStepBuilder WithRetry(int maxRetries, TimeSpan delay) + { + _definition.MaxRetries = maxRetries; + _definition.RetryDelay = delay; + return this; + } + + /// + public ISagaBuilder Then() + { + _parent.AddStep(_definition); + return _parent; + } +} diff --git a/Svrnty.CQRS.Sagas/Builders/SagaBuilder.cs b/Svrnty.CQRS.Sagas/Builders/SagaBuilder.cs new file mode 100644 index 0000000..21759ba --- /dev/null +++ b/Svrnty.CQRS.Sagas/Builders/SagaBuilder.cs @@ -0,0 +1,49 @@ +using System; +using System.Collections.Generic; +using Svrnty.CQRS.Sagas.Abstractions; + +namespace Svrnty.CQRS.Sagas.Builders; + +/// +/// Implementation of the saga builder for defining saga steps. +/// +/// The saga data type. +public class SagaBuilder : ISagaBuilder + where TData : class, ISagaData +{ + private readonly List _steps = new(); + + /// + /// Gets the defined steps. + /// + public IReadOnlyList Steps => _steps.AsReadOnly(); + + /// + public ISagaStepBuilder Step(string name) + { + return new LocalSagaStepBuilder(this, name, _steps.Count); + } + + /// + public ISagaRemoteStepBuilder SendCommand(string name) + where TCommand : class + { + return new RemoteSagaStepBuilder(this, name, _steps.Count); + } + + /// + public ISagaRemoteStepBuilder SendCommand(string name) + where TCommand : class + { + return new RemoteSagaStepBuilderWithResult(this, name, _steps.Count); + } + + /// + /// Adds a step definition to the builder. + /// + /// The step definition to add. + internal void AddStep(SagaStepDefinition step) + { + _steps.Add(step); + } +} diff --git a/Svrnty.CQRS.Sagas/Builders/SagaStepDefinition.cs b/Svrnty.CQRS.Sagas/Builders/SagaStepDefinition.cs new file mode 100644 index 0000000..05d97ff --- /dev/null +++ b/Svrnty.CQRS.Sagas/Builders/SagaStepDefinition.cs @@ -0,0 +1,149 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Svrnty.CQRS.Sagas.Abstractions; + +namespace Svrnty.CQRS.Sagas.Builders; + +/// +/// Base class for saga step definitions. +/// +public abstract class SagaStepDefinition +{ + /// + /// Unique name for this step. + /// + public string Name { get; set; } = string.Empty; + + /// + /// Order of the step in the saga. + /// + public int Order { get; set; } + + /// + /// Whether this step has a compensation action. + /// + public abstract bool HasCompensation { get; } + + /// + /// Whether this step is a remote step (sends a command). + /// + public abstract bool IsRemote { get; } + + /// + /// Timeout for this step. + /// + public TimeSpan? Timeout { get; set; } + + /// + /// Maximum number of retries. + /// + public int MaxRetries { get; set; } + + /// + /// Delay between retries. + /// + public TimeSpan RetryDelay { get; set; } = TimeSpan.FromSeconds(1); +} + +/// +/// Definition for a local saga step. +/// +/// The saga data type. +public class LocalSagaStepDefinition : SagaStepDefinition + where TData : class, ISagaData +{ + /// + /// The execution action. + /// + public Func? ExecuteAction { get; set; } + + /// + /// The compensation action. + /// + public Func? CompensateAction { get; set; } + + /// + public override bool HasCompensation => CompensateAction != null; + + /// + public override bool IsRemote => false; +} + +/// +/// Definition for a remote saga step. +/// +/// The saga data type. +/// The command type. +public class RemoteSagaStepDefinition : SagaStepDefinition + where TData : class, ISagaData + where TCommand : class +{ + /// + /// Function to build the command. + /// + public Func? CommandBuilder { get; set; } + + /// + /// Handler for successful response. + /// + public Func? ResponseHandler { get; set; } + + /// + /// Type of the compensation command. + /// + public Type? CompensationCommandType { get; set; } + + /// + /// Function to build the compensation command. + /// + public Func? CompensationBuilder { get; set; } + + /// + public override bool HasCompensation => CompensationBuilder != null; + + /// + public override bool IsRemote => true; +} + +/// +/// Definition for a remote saga step with result. +/// +/// The saga data type. +/// The command type. +/// The result type. +public class RemoteSagaStepDefinition : SagaStepDefinition + where TData : class, ISagaData + where TCommand : class +{ + /// + /// Function to build the command. + /// + public Func? CommandBuilder { get; set; } + + /// + /// Handler for successful response with result. + /// + public Func? ResponseHandler { get; set; } + + /// + /// Type of the compensation command. + /// + public Type? CompensationCommandType { get; set; } + + /// + /// Function to build the compensation command. + /// + public Func? CompensationBuilder { get; set; } + + /// + /// The expected result type. + /// + public Type ResultType => typeof(TResult); + + /// + public override bool HasCompensation => CompensationBuilder != null; + + /// + public override bool IsRemote => true; +} diff --git a/Svrnty.CQRS.Sagas/Configuration/SagaOptions.cs b/Svrnty.CQRS.Sagas/Configuration/SagaOptions.cs new file mode 100644 index 0000000..ba1435a --- /dev/null +++ b/Svrnty.CQRS.Sagas/Configuration/SagaOptions.cs @@ -0,0 +1,39 @@ +using System; + +namespace Svrnty.CQRS.Sagas.Configuration; + +/// +/// Configuration options for saga orchestration. +/// +public class SagaOptions +{ + /// + /// Default timeout for saga steps (default: 30 seconds). + /// + public TimeSpan DefaultStepTimeout { get; set; } = TimeSpan.FromSeconds(30); + + /// + /// Default number of retries for failed steps (default: 3). + /// + public int DefaultMaxRetries { get; set; } = 3; + + /// + /// Default delay between retries (default: 1 second). + /// + public TimeSpan DefaultRetryDelay { get; set; } = TimeSpan.FromSeconds(1); + + /// + /// Whether to automatically compensate on failure (default: true). + /// + public bool AutoCompensateOnFailure { get; set; } = true; + + /// + /// Interval for checking pending/stalled sagas (default: 1 minute). + /// + public TimeSpan StalledSagaCheckInterval { get; set; } = TimeSpan.FromMinutes(1); + + /// + /// Time after which a saga step is considered stalled (default: 5 minutes). + /// + public TimeSpan StepStalledTimeout { get; set; } = TimeSpan.FromMinutes(5); +} diff --git a/Svrnty.CQRS.Sagas/CqrsBuilderExtensions.cs b/Svrnty.CQRS.Sagas/CqrsBuilderExtensions.cs new file mode 100644 index 0000000..9be221f --- /dev/null +++ b/Svrnty.CQRS.Sagas/CqrsBuilderExtensions.cs @@ -0,0 +1,82 @@ +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Svrnty.CQRS.Configuration; +using Svrnty.CQRS.Sagas.Abstractions; +using Svrnty.CQRS.Sagas.Abstractions.Persistence; +using Svrnty.CQRS.Sagas.Configuration; +using Svrnty.CQRS.Sagas.Persistence; + +namespace Svrnty.CQRS.Sagas; + +/// +/// Extensions for adding saga support to the CQRS pipeline. +/// +public static class CqrsBuilderExtensions +{ + /// + /// Adds saga orchestration support to the CQRS pipeline. + /// + /// The CQRS builder. + /// Optional configuration action. + /// The CQRS builder for chaining. + public static CqrsBuilder AddSagas(this CqrsBuilder builder, Action? configure = null) + { + var options = new SagaOptions(); + configure?.Invoke(options); + + builder.Services.Configure(opt => + { + opt.DefaultStepTimeout = options.DefaultStepTimeout; + opt.DefaultMaxRetries = options.DefaultMaxRetries; + opt.DefaultRetryDelay = options.DefaultRetryDelay; + opt.AutoCompensateOnFailure = options.AutoCompensateOnFailure; + opt.StalledSagaCheckInterval = options.StalledSagaCheckInterval; + opt.StepStalledTimeout = options.StepStalledTimeout; + }); + + // Store configuration + builder.Configuration.SetConfiguration(options); + + // Register core saga services + builder.Services.TryAddSingleton(); + + // Register default in-memory state store if not already registered + builder.Services.TryAddSingleton(); + + return builder; + } + + /// + /// Registers a saga type with the CQRS pipeline. + /// + /// The saga type. + /// The saga data type. + /// The CQRS builder. + /// The CQRS builder for chaining. + public static CqrsBuilder AddSaga(this CqrsBuilder builder) + where TSaga : class, ISaga + where TData : class, ISagaData, new() + { + builder.Services.AddTransient(); + builder.Services.AddTransient, TSaga>(); + + return builder; + } + + /// + /// Uses a custom saga state store implementation. + /// + /// The state store implementation type. + /// The CQRS builder. + /// The CQRS builder for chaining. + public static CqrsBuilder UseSagaStateStore(this CqrsBuilder builder) + where TStore : class, ISagaStateStore + { + // Remove existing registration + var descriptor = new ServiceDescriptor(typeof(ISagaStateStore), typeof(TStore), ServiceLifetime.Singleton); + builder.Services.Replace(descriptor); + + return builder; + } +} diff --git a/Svrnty.CQRS.Sagas/Persistence/InMemorySagaStateStore.cs b/Svrnty.CQRS.Sagas/Persistence/InMemorySagaStateStore.cs new file mode 100644 index 0000000..8655118 --- /dev/null +++ b/Svrnty.CQRS.Sagas/Persistence/InMemorySagaStateStore.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Svrnty.CQRS.Sagas.Abstractions; +using Svrnty.CQRS.Sagas.Abstractions.Persistence; + +namespace Svrnty.CQRS.Sagas.Persistence; + +/// +/// In-memory saga state store for development and testing. +/// +public class InMemorySagaStateStore : ISagaStateStore +{ + private readonly ConcurrentDictionary _states = new(); + + /// + public Task CreateAsync(SagaState state, CancellationToken cancellationToken = default) + { + if (!_states.TryAdd(state.SagaId, state)) + { + throw new InvalidOperationException($"Saga with ID {state.SagaId} already exists."); + } + return Task.FromResult(state); + } + + /// + public Task GetByIdAsync(Guid sagaId, CancellationToken cancellationToken = default) + { + _states.TryGetValue(sagaId, out var state); + return Task.FromResult(state); + } + + /// + public Task GetByCorrelationIdAsync(Guid correlationId, CancellationToken cancellationToken = default) + { + var state = _states.Values.FirstOrDefault(s => s.CorrelationId == correlationId); + return Task.FromResult(state); + } + + /// + public Task UpdateAsync(SagaState state, CancellationToken cancellationToken = default) + { + state.UpdatedAt = DateTimeOffset.UtcNow; + _states[state.SagaId] = state; + return Task.FromResult(state); + } + + /// + public Task> GetPendingSagasAsync(CancellationToken cancellationToken = default) + { + var pending = _states.Values + .Where(s => s.Status == SagaStatus.InProgress || s.Status == SagaStatus.Compensating) + .ToList(); + return Task.FromResult>(pending); + } + + /// + public Task> GetSagasByStatusAsync(SagaStatus status, CancellationToken cancellationToken = default) + { + var sagas = _states.Values + .Where(s => s.Status == status) + .ToList(); + return Task.FromResult>(sagas); + } +} diff --git a/Svrnty.CQRS.Sagas/SagaContext.cs b/Svrnty.CQRS.Sagas/SagaContext.cs new file mode 100644 index 0000000..79b717b --- /dev/null +++ b/Svrnty.CQRS.Sagas/SagaContext.cs @@ -0,0 +1,56 @@ +using System; +using System.Collections.Generic; +using Svrnty.CQRS.Sagas.Abstractions; + +namespace Svrnty.CQRS.Sagas; + +/// +/// Implementation of saga context providing runtime information during step execution. +/// +public class SagaContext : ISagaContext +{ + private readonly SagaState _state; + + /// + /// Creates a new saga context from a saga state. + /// + /// The saga state. + public SagaContext(SagaState state) + { + _state = state ?? throw new ArgumentNullException(nameof(state)); + } + + /// + public Guid SagaId => _state.SagaId; + + /// + public Guid CorrelationId => _state.CorrelationId; + + /// + public string SagaType => _state.SagaType; + + /// + public int CurrentStepIndex => _state.CurrentStepIndex; + + /// + public string CurrentStepName => _state.CurrentStepName ?? string.Empty; + + /// + public IReadOnlyDictionary StepResults => _state.StepResults; + + /// + public T? GetStepResult(string stepName) + { + if (_state.StepResults.TryGetValue(stepName, out var value) && value is T result) + { + return result; + } + return default; + } + + /// + public void SetStepResult(T result) + { + _state.StepResults[CurrentStepName] = result; + } +} diff --git a/Svrnty.CQRS.Sagas/SagaOrchestrator.cs b/Svrnty.CQRS.Sagas/SagaOrchestrator.cs new file mode 100644 index 0000000..1672baf --- /dev/null +++ b/Svrnty.CQRS.Sagas/SagaOrchestrator.cs @@ -0,0 +1,429 @@ +using System; +using System.Linq; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Svrnty.CQRS.Sagas.Abstractions; +using Svrnty.CQRS.Sagas.Abstractions.Messaging; +using Svrnty.CQRS.Sagas.Abstractions.Persistence; +using Svrnty.CQRS.Sagas.Builders; +using Svrnty.CQRS.Sagas.Configuration; + +namespace Svrnty.CQRS.Sagas; + +/// +/// Implementation of saga orchestration. +/// +public class SagaOrchestrator : ISagaOrchestrator +{ + private readonly IServiceProvider _serviceProvider; + private readonly ISagaStateStore _stateStore; + private readonly ISagaMessageBus? _messageBus; + private readonly ILogger _logger; + private readonly SagaOptions _options; + + /// + /// Creates a new saga orchestrator. + /// + public SagaOrchestrator( + IServiceProvider serviceProvider, + ISagaStateStore stateStore, + IOptions options, + ILogger logger, + ISagaMessageBus? messageBus = null) + { + _serviceProvider = serviceProvider; + _stateStore = stateStore; + _messageBus = messageBus; + _logger = logger; + _options = options.Value; + } + + /// + public Task StartAsync(TData initialData, CancellationToken cancellationToken = default) + where TSaga : ISaga + where TData : class, ISagaData, new() + { + return StartAsync(initialData, Guid.NewGuid(), cancellationToken); + } + + /// + public async Task StartAsync( + TData initialData, + Guid correlationId, + CancellationToken cancellationToken = default) + where TSaga : ISaga + where TData : class, ISagaData, new() + { + initialData.CorrelationId = correlationId; + + // Get the saga instance and configure it + var saga = _serviceProvider.GetRequiredService(); + var builder = new SagaBuilder(); + saga.Configure(builder); + + var steps = builder.Steps; + if (steps.Count == 0) + { + throw new InvalidOperationException($"Saga {typeof(TSaga).Name} has no steps configured."); + } + + // Create initial state + var state = new SagaState + { + SagaType = typeof(TSaga).FullName!, + CorrelationId = correlationId, + Status = SagaStatus.InProgress, + CurrentStepIndex = 0, + CurrentStepName = steps[0].Name, + SerializedData = JsonSerializer.Serialize(initialData) + }; + + state = await _stateStore.CreateAsync(state, cancellationToken); + + _logger.LogInformation( + "Started saga {SagaType} with ID {SagaId} and CorrelationId {CorrelationId}", + state.SagaType, state.SagaId, state.CorrelationId); + + // Execute the first step + await ExecuteNextStepAsync(state, steps, initialData, cancellationToken); + + return state; + } + + /// + public Task GetStateAsync(Guid sagaId, CancellationToken cancellationToken = default) + { + return _stateStore.GetByIdAsync(sagaId, cancellationToken); + } + + /// + public Task GetStateByCorrelationIdAsync(Guid correlationId, CancellationToken cancellationToken = default) + { + return _stateStore.GetByCorrelationIdAsync(correlationId, cancellationToken); + } + + /// + /// Handles a response from a remote step. + /// + public async Task HandleResponseAsync( + SagaStepResponse response, + CancellationToken cancellationToken = default) + where TData : class, ISagaData, new() + { + var state = await _stateStore.GetByIdAsync(response.SagaId, cancellationToken); + if (state == null) + { + _logger.LogWarning("Received response for unknown saga {SagaId}", response.SagaId); + return; + } + + var data = JsonSerializer.Deserialize(state.SerializedData!); + if (data == null) + { + _logger.LogError("Failed to deserialize saga data for {SagaId}", response.SagaId); + return; + } + + // Get the saga definition + var sagaType = Type.GetType(state.SagaType); + if (sagaType == null) + { + _logger.LogError("Unknown saga type {SagaType}", state.SagaType); + return; + } + + var saga = _serviceProvider.GetService(sagaType) as ISaga; + if (saga == null) + { + _logger.LogError("Could not resolve saga {SagaType}", state.SagaType); + return; + } + + var builder = new SagaBuilder(); + saga.Configure(builder); + var steps = builder.Steps; + + if (response.Success) + { + _logger.LogInformation( + "Step {StepName} completed successfully for saga {SagaId}", + response.StepName, response.SagaId); + + state.CompletedSteps.Add(response.StepName); + state.CurrentStepIndex++; + + if (state.CurrentStepIndex >= steps.Count) + { + // Saga completed + state.Status = SagaStatus.Completed; + state.CompletedAt = DateTimeOffset.UtcNow; + await _stateStore.UpdateAsync(state, cancellationToken); + + _logger.LogInformation("Saga {SagaId} completed successfully", state.SagaId); + } + else + { + // Move to next step + state.CurrentStepName = steps[state.CurrentStepIndex].Name; + await _stateStore.UpdateAsync(state, cancellationToken); + await ExecuteNextStepAsync(state, steps, data, cancellationToken); + } + } + else + { + _logger.LogError( + "Step {StepName} failed for saga {SagaId}: {Error}", + response.StepName, response.SagaId, response.ErrorMessage); + + state.Errors.Add(new SagaStepError( + response.StepName, + response.ErrorMessage ?? "Unknown error", + response.StackTrace, + DateTimeOffset.UtcNow)); + + if (_options.AutoCompensateOnFailure) + { + await StartCompensationAsync(state, steps, data, cancellationToken); + } + else + { + state.Status = SagaStatus.Failed; + await _stateStore.UpdateAsync(state, cancellationToken); + } + } + } + + private async Task ExecuteNextStepAsync( + SagaState state, + System.Collections.Generic.IReadOnlyList steps, + TData data, + CancellationToken cancellationToken) + where TData : class, ISagaData + { + if (state.CurrentStepIndex >= steps.Count) + { + state.Status = SagaStatus.Completed; + state.CompletedAt = DateTimeOffset.UtcNow; + await _stateStore.UpdateAsync(state, cancellationToken); + return; + } + + var step = steps[state.CurrentStepIndex]; + var context = new SagaContext(state); + + _logger.LogDebug( + "Executing step {StepName} ({StepIndex}/{TotalSteps}) for saga {SagaId}", + step.Name, state.CurrentStepIndex + 1, steps.Count, state.SagaId); + + try + { + if (step.IsRemote) + { + await ExecuteRemoteStepAsync(state, step, data, context, cancellationToken); + } + else + { + await ExecuteLocalStepAsync(state, step, data, context, steps, cancellationToken); + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Error executing step {StepName} for saga {SagaId}", step.Name, state.SagaId); + + state.Errors.Add(new SagaStepError( + step.Name, + ex.Message, + ex.StackTrace, + DateTimeOffset.UtcNow)); + + if (_options.AutoCompensateOnFailure) + { + await StartCompensationAsync(state, steps, data, cancellationToken); + } + else + { + state.Status = SagaStatus.Failed; + await _stateStore.UpdateAsync(state, cancellationToken); + } + } + } + + private async Task ExecuteLocalStepAsync( + SagaState state, + SagaStepDefinition step, + TData data, + SagaContext context, + System.Collections.Generic.IReadOnlyList steps, + CancellationToken cancellationToken) + where TData : class, ISagaData + { + if (step is LocalSagaStepDefinition localStep && localStep.ExecuteAction != null) + { + await localStep.ExecuteAction(data, context, cancellationToken); + } + + // Local step completed, update state and continue + state.CompletedSteps.Add(step.Name); + state.SerializedData = JsonSerializer.Serialize(data); + state.CurrentStepIndex++; + + if (state.CurrentStepIndex < steps.Count) + { + state.CurrentStepName = steps[state.CurrentStepIndex].Name; + } + + await _stateStore.UpdateAsync(state, cancellationToken); + + // Continue to next step + await ExecuteNextStepAsync(state, steps, data, cancellationToken); + } + + private async Task ExecuteRemoteStepAsync( + SagaState state, + SagaStepDefinition step, + TData data, + SagaContext context, + CancellationToken cancellationToken) + where TData : class, ISagaData + { + if (_messageBus == null) + { + throw new InvalidOperationException( + "Remote saga steps require a message bus. Configure RabbitMQ or another transport."); + } + + object? command = null; + string commandType; + + // Get the command from the step definition + var stepType = step.GetType(); + var commandBuilderProp = stepType.GetProperty("CommandBuilder"); + if (commandBuilderProp?.GetValue(step) is Delegate commandBuilder) + { + command = commandBuilder.DynamicInvoke(data, context); + } + + if (command == null) + { + throw new InvalidOperationException($"Step {step.Name} did not produce a command."); + } + + commandType = command.GetType().FullName!; + + var message = new SagaMessage + { + SagaId = state.SagaId, + CorrelationId = state.CorrelationId, + StepName = step.Name, + CommandType = commandType, + Payload = JsonSerializer.Serialize(command, command.GetType()) + }; + + await _messageBus.PublishAsync(message, cancellationToken); + await _stateStore.UpdateAsync(state, cancellationToken); + + _logger.LogDebug( + "Published command {CommandType} for step {StepName} of saga {SagaId}", + commandType, step.Name, state.SagaId); + } + + private async Task StartCompensationAsync( + SagaState state, + System.Collections.Generic.IReadOnlyList steps, + TData data, + CancellationToken cancellationToken) + where TData : class, ISagaData + { + _logger.LogInformation("Starting compensation for saga {SagaId}", state.SagaId); + + state.Status = SagaStatus.Compensating; + await _stateStore.UpdateAsync(state, cancellationToken); + + // Execute compensation in reverse order + var context = new SagaContext(state); + var completedSteps = state.CompletedSteps.ToList(); + + for (var i = completedSteps.Count - 1; i >= 0; i--) + { + var stepName = completedSteps[i]; + var step = steps.FirstOrDefault(s => s.Name == stepName); + + if (step == null || !step.HasCompensation) + { + continue; + } + + _logger.LogDebug("Compensating step {StepName} for saga {SagaId}", stepName, state.SagaId); + + try + { + if (step.IsRemote) + { + await ExecuteRemoteCompensationAsync(state, step, data, context, cancellationToken); + } + else if (step is LocalSagaStepDefinition localStep && localStep.CompensateAction != null) + { + await localStep.CompensateAction(data, context, cancellationToken); + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Error during compensation of step {StepName} for saga {SagaId}", + stepName, state.SagaId); + // Continue with other compensations even if one fails + } + } + + state.Status = SagaStatus.Compensated; + state.CompletedAt = DateTimeOffset.UtcNow; + await _stateStore.UpdateAsync(state, cancellationToken); + + _logger.LogInformation("Saga {SagaId} compensation completed", state.SagaId); + } + + private async Task ExecuteRemoteCompensationAsync( + SagaState state, + SagaStepDefinition step, + TData data, + SagaContext context, + CancellationToken cancellationToken) + where TData : class, ISagaData + { + if (_messageBus == null) + { + return; + } + + var stepType = step.GetType(); + var compensationBuilderProp = stepType.GetProperty("CompensationBuilder"); + var compensationTypeProp = stepType.GetProperty("CompensationCommandType"); + + if (compensationBuilderProp?.GetValue(step) is Delegate compensationBuilder && + compensationTypeProp?.GetValue(step) is Type compensationType) + { + var compensationCommand = compensationBuilder.DynamicInvoke(data, context); + if (compensationCommand != null) + { + var message = new SagaMessage + { + SagaId = state.SagaId, + CorrelationId = state.CorrelationId, + StepName = step.Name, + CommandType = compensationType.FullName!, + Payload = JsonSerializer.Serialize(compensationCommand, compensationType), + IsCompensation = true + }; + + await _messageBus.PublishAsync(message, cancellationToken); + + _logger.LogDebug( + "Published compensation command {CommandType} for step {StepName} of saga {SagaId}", + compensationType.Name, step.Name, state.SagaId); + } + } + } +} diff --git a/Svrnty.CQRS.Sagas/Svrnty.CQRS.Sagas.csproj b/Svrnty.CQRS.Sagas/Svrnty.CQRS.Sagas.csproj new file mode 100644 index 0000000..774cac7 --- /dev/null +++ b/Svrnty.CQRS.Sagas/Svrnty.CQRS.Sagas.csproj @@ -0,0 +1,38 @@ + + + net10.0 + true + 14 + enable + + Svrnty + David Lebee, Mathias Beaulieu-Duncan + icon.png + README.md + https://git.openharbor.io/svrnty/dotnet-cqrs + git + true + MIT + + portable + true + true + true + snupkg + + + + + + + + + + + + + + + + + diff --git a/Svrnty.CQRS.sln b/Svrnty.CQRS.sln index fcbe35b..37176b1 100644 --- a/Svrnty.CQRS.sln +++ b/Svrnty.CQRS.sln @@ -31,6 +31,18 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Svrnty.Sample", "Svrnty.Sam EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Svrnty.CQRS.DynamicQuery.MinimalApi", "Svrnty.CQRS.DynamicQuery.MinimalApi\Svrnty.CQRS.DynamicQuery.MinimalApi.csproj", "{1D0E3388-5E4B-4C0E-B826-ACF256FF7C84}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Svrnty.CQRS.Sagas.Abstractions", "Svrnty.CQRS.Sagas.Abstractions\Svrnty.CQRS.Sagas.Abstractions.csproj", "{13B6608A-596B-495B-9C08-F9B3F0D1915A}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Svrnty.CQRS.Sagas", "Svrnty.CQRS.Sagas\Svrnty.CQRS.Sagas.csproj", "{8EC9D12F-C8CD-4187-A1ED-47365D1C6B61}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Svrnty.CQRS.Sagas.RabbitMQ", "Svrnty.CQRS.Sagas.RabbitMQ\Svrnty.CQRS.Sagas.RabbitMQ.csproj", "{2EA39D64-B4A8-4A74-A2E6-D8A8E8312B68}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Svrnty.CQRS.DynamicQuery.EntityFramework", "Svrnty.CQRS.DynamicQuery.EntityFramework\Svrnty.CQRS.DynamicQuery.EntityFramework.csproj", "{25456A0B-69AF-4251-B34D-2A3873CD8D80}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Svrnty.CQRS.Events.Abstractions", "Svrnty.CQRS.Events.Abstractions\Svrnty.CQRS.Events.Abstractions.csproj", "{7905A4BB-2462-4FFF-9A29-3E4769D20FFC}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Svrnty.CQRS.Events.RabbitMQ", "Svrnty.CQRS.Events.RabbitMQ\Svrnty.CQRS.Events.RabbitMQ.csproj", "{3C7412EF-13C2-41F3-9D4C-D2BEC4843C8C}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -173,6 +185,78 @@ Global {1D0E3388-5E4B-4C0E-B826-ACF256FF7C84}.Release|x64.Build.0 = Release|Any CPU {1D0E3388-5E4B-4C0E-B826-ACF256FF7C84}.Release|x86.ActiveCfg = Release|Any CPU {1D0E3388-5E4B-4C0E-B826-ACF256FF7C84}.Release|x86.Build.0 = Release|Any CPU + {13B6608A-596B-495B-9C08-F9B3F0D1915A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {13B6608A-596B-495B-9C08-F9B3F0D1915A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {13B6608A-596B-495B-9C08-F9B3F0D1915A}.Debug|x64.ActiveCfg = Debug|Any CPU + {13B6608A-596B-495B-9C08-F9B3F0D1915A}.Debug|x64.Build.0 = Debug|Any CPU + {13B6608A-596B-495B-9C08-F9B3F0D1915A}.Debug|x86.ActiveCfg = Debug|Any CPU + {13B6608A-596B-495B-9C08-F9B3F0D1915A}.Debug|x86.Build.0 = Debug|Any CPU + {13B6608A-596B-495B-9C08-F9B3F0D1915A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {13B6608A-596B-495B-9C08-F9B3F0D1915A}.Release|Any CPU.Build.0 = Release|Any CPU + {13B6608A-596B-495B-9C08-F9B3F0D1915A}.Release|x64.ActiveCfg = Release|Any CPU + {13B6608A-596B-495B-9C08-F9B3F0D1915A}.Release|x64.Build.0 = Release|Any CPU + {13B6608A-596B-495B-9C08-F9B3F0D1915A}.Release|x86.ActiveCfg = Release|Any CPU + {13B6608A-596B-495B-9C08-F9B3F0D1915A}.Release|x86.Build.0 = Release|Any CPU + {8EC9D12F-C8CD-4187-A1ED-47365D1C6B61}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8EC9D12F-C8CD-4187-A1ED-47365D1C6B61}.Debug|Any CPU.Build.0 = Debug|Any CPU + {8EC9D12F-C8CD-4187-A1ED-47365D1C6B61}.Debug|x64.ActiveCfg = Debug|Any CPU + {8EC9D12F-C8CD-4187-A1ED-47365D1C6B61}.Debug|x64.Build.0 = Debug|Any CPU + {8EC9D12F-C8CD-4187-A1ED-47365D1C6B61}.Debug|x86.ActiveCfg = Debug|Any CPU + {8EC9D12F-C8CD-4187-A1ED-47365D1C6B61}.Debug|x86.Build.0 = Debug|Any CPU + {8EC9D12F-C8CD-4187-A1ED-47365D1C6B61}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8EC9D12F-C8CD-4187-A1ED-47365D1C6B61}.Release|Any CPU.Build.0 = Release|Any CPU + {8EC9D12F-C8CD-4187-A1ED-47365D1C6B61}.Release|x64.ActiveCfg = Release|Any CPU + {8EC9D12F-C8CD-4187-A1ED-47365D1C6B61}.Release|x64.Build.0 = Release|Any CPU + {8EC9D12F-C8CD-4187-A1ED-47365D1C6B61}.Release|x86.ActiveCfg = Release|Any CPU + {8EC9D12F-C8CD-4187-A1ED-47365D1C6B61}.Release|x86.Build.0 = Release|Any CPU + {2EA39D64-B4A8-4A74-A2E6-D8A8E8312B68}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2EA39D64-B4A8-4A74-A2E6-D8A8E8312B68}.Debug|Any CPU.Build.0 = Debug|Any CPU + {2EA39D64-B4A8-4A74-A2E6-D8A8E8312B68}.Debug|x64.ActiveCfg = Debug|Any CPU + {2EA39D64-B4A8-4A74-A2E6-D8A8E8312B68}.Debug|x64.Build.0 = Debug|Any CPU + {2EA39D64-B4A8-4A74-A2E6-D8A8E8312B68}.Debug|x86.ActiveCfg = Debug|Any CPU + {2EA39D64-B4A8-4A74-A2E6-D8A8E8312B68}.Debug|x86.Build.0 = Debug|Any CPU + {2EA39D64-B4A8-4A74-A2E6-D8A8E8312B68}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2EA39D64-B4A8-4A74-A2E6-D8A8E8312B68}.Release|Any CPU.Build.0 = Release|Any CPU + {2EA39D64-B4A8-4A74-A2E6-D8A8E8312B68}.Release|x64.ActiveCfg = Release|Any CPU + {2EA39D64-B4A8-4A74-A2E6-D8A8E8312B68}.Release|x64.Build.0 = Release|Any CPU + {2EA39D64-B4A8-4A74-A2E6-D8A8E8312B68}.Release|x86.ActiveCfg = Release|Any CPU + {2EA39D64-B4A8-4A74-A2E6-D8A8E8312B68}.Release|x86.Build.0 = Release|Any CPU + {25456A0B-69AF-4251-B34D-2A3873CD8D80}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {25456A0B-69AF-4251-B34D-2A3873CD8D80}.Debug|Any CPU.Build.0 = Debug|Any CPU + {25456A0B-69AF-4251-B34D-2A3873CD8D80}.Debug|x64.ActiveCfg = Debug|Any CPU + {25456A0B-69AF-4251-B34D-2A3873CD8D80}.Debug|x64.Build.0 = Debug|Any CPU + {25456A0B-69AF-4251-B34D-2A3873CD8D80}.Debug|x86.ActiveCfg = Debug|Any CPU + {25456A0B-69AF-4251-B34D-2A3873CD8D80}.Debug|x86.Build.0 = Debug|Any CPU + {25456A0B-69AF-4251-B34D-2A3873CD8D80}.Release|Any CPU.ActiveCfg = Release|Any CPU + {25456A0B-69AF-4251-B34D-2A3873CD8D80}.Release|Any CPU.Build.0 = Release|Any CPU + {25456A0B-69AF-4251-B34D-2A3873CD8D80}.Release|x64.ActiveCfg = Release|Any CPU + {25456A0B-69AF-4251-B34D-2A3873CD8D80}.Release|x64.Build.0 = Release|Any CPU + {25456A0B-69AF-4251-B34D-2A3873CD8D80}.Release|x86.ActiveCfg = Release|Any CPU + {25456A0B-69AF-4251-B34D-2A3873CD8D80}.Release|x86.Build.0 = Release|Any CPU + {7905A4BB-2462-4FFF-9A29-3E4769D20FFC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7905A4BB-2462-4FFF-9A29-3E4769D20FFC}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7905A4BB-2462-4FFF-9A29-3E4769D20FFC}.Debug|x64.ActiveCfg = Debug|Any CPU + {7905A4BB-2462-4FFF-9A29-3E4769D20FFC}.Debug|x64.Build.0 = Debug|Any CPU + {7905A4BB-2462-4FFF-9A29-3E4769D20FFC}.Debug|x86.ActiveCfg = Debug|Any CPU + {7905A4BB-2462-4FFF-9A29-3E4769D20FFC}.Debug|x86.Build.0 = Debug|Any CPU + {7905A4BB-2462-4FFF-9A29-3E4769D20FFC}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7905A4BB-2462-4FFF-9A29-3E4769D20FFC}.Release|Any CPU.Build.0 = Release|Any CPU + {7905A4BB-2462-4FFF-9A29-3E4769D20FFC}.Release|x64.ActiveCfg = Release|Any CPU + {7905A4BB-2462-4FFF-9A29-3E4769D20FFC}.Release|x64.Build.0 = Release|Any CPU + {7905A4BB-2462-4FFF-9A29-3E4769D20FFC}.Release|x86.ActiveCfg = Release|Any CPU + {7905A4BB-2462-4FFF-9A29-3E4769D20FFC}.Release|x86.Build.0 = Release|Any CPU + {3C7412EF-13C2-41F3-9D4C-D2BEC4843C8C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3C7412EF-13C2-41F3-9D4C-D2BEC4843C8C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3C7412EF-13C2-41F3-9D4C-D2BEC4843C8C}.Debug|x64.ActiveCfg = Debug|Any CPU + {3C7412EF-13C2-41F3-9D4C-D2BEC4843C8C}.Debug|x64.Build.0 = Debug|Any CPU + {3C7412EF-13C2-41F3-9D4C-D2BEC4843C8C}.Debug|x86.ActiveCfg = Debug|Any CPU + {3C7412EF-13C2-41F3-9D4C-D2BEC4843C8C}.Debug|x86.Build.0 = Debug|Any CPU + {3C7412EF-13C2-41F3-9D4C-D2BEC4843C8C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3C7412EF-13C2-41F3-9D4C-D2BEC4843C8C}.Release|Any CPU.Build.0 = Release|Any CPU + {3C7412EF-13C2-41F3-9D4C-D2BEC4843C8C}.Release|x64.ActiveCfg = Release|Any CPU + {3C7412EF-13C2-41F3-9D4C-D2BEC4843C8C}.Release|x64.Build.0 = Release|Any CPU + {3C7412EF-13C2-41F3-9D4C-D2BEC4843C8C}.Release|x86.ActiveCfg = Release|Any CPU + {3C7412EF-13C2-41F3-9D4C-D2BEC4843C8C}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE