diff --git a/Svrnty.CQRS.Abstractions/Security/AuthorizationCheckContext.cs b/Svrnty.CQRS.Abstractions/Security/AuthorizationCheckContext.cs new file mode 100644 index 0000000..81073c9 --- /dev/null +++ b/Svrnty.CQRS.Abstractions/Security/AuthorizationCheckContext.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; + +namespace Svrnty.CQRS.Abstractions.Security; + +/// +/// Shared shape for command and query authorization-check contexts. Checks +/// receive the request type, the materialized (and validated) request instance, +/// a scoped , and a free-form +/// dictionary that lets checks in the same pipeline pass signals to each other +/// (e.g. a future mobile-attestation check stamping "mobile_attested" for the +/// Altcha check to read). +/// +public abstract class AuthorizationCheckContext +{ + public required IServiceProvider Services { get; init; } + + public IDictionary Items { get; } = new Dictionary(); +} + +public sealed class CommandAuthorizationCheckContext : AuthorizationCheckContext +{ + public required Type CommandType { get; init; } + + public required object Command { get; init; } +} + +public sealed class QueryAuthorizationCheckContext : AuthorizationCheckContext +{ + public required Type QueryType { get; init; } + + public required object Query { get; init; } +} diff --git a/Svrnty.CQRS.Abstractions/Security/ICommandAuthorizationCheck.cs b/Svrnty.CQRS.Abstractions/Security/ICommandAuthorizationCheck.cs new file mode 100644 index 0000000..692ae42 --- /dev/null +++ b/Svrnty.CQRS.Abstractions/Security/ICommandAuthorizationCheck.cs @@ -0,0 +1,27 @@ +using System.Threading; +using System.Threading.Tasks; + +namespace Svrnty.CQRS.Abstractions.Security; + +/// +/// Cross-cutting authorization check that runs alongside (not in place of) the +/// consumer's . Multiple +/// implementations may be registered; the framework resolves them as +/// IEnumerable<ICommandAuthorizationCheck> and runs each in +/// registration order. AND semantics — any non- +/// short-circuits the pipeline. +/// +/// +/// Use this seam for self-applying, attribute-driven checks shipped by +/// framework modules (proof-of-work, mobile attestation, rate-limit gates, +/// IP allow-lists). The check is responsible for inspecting +/// attributes and +/// no-op'ing (return ) when it +/// doesn't apply. +/// +public interface ICommandAuthorizationCheck +{ + Task CheckAsync( + CommandAuthorizationCheckContext context, + CancellationToken cancellationToken = default); +} diff --git a/Svrnty.CQRS.Abstractions/Security/IQueryAuthorizationCheck.cs b/Svrnty.CQRS.Abstractions/Security/IQueryAuthorizationCheck.cs new file mode 100644 index 0000000..1d17a7a --- /dev/null +++ b/Svrnty.CQRS.Abstractions/Security/IQueryAuthorizationCheck.cs @@ -0,0 +1,15 @@ +using System.Threading; +using System.Threading.Tasks; + +namespace Svrnty.CQRS.Abstractions.Security; + +/// +/// Query-side counterpart to . See +/// that interface's remarks for usage. +/// +public interface IQueryAuthorizationCheck +{ + Task CheckAsync( + QueryAuthorizationCheckContext context, + CancellationToken cancellationToken = default); +} diff --git a/Svrnty.CQRS.Grpc.Generators/GrpcGenerator.cs b/Svrnty.CQRS.Grpc.Generators/GrpcGenerator.cs index 49731eb..811512a 100644 --- a/Svrnty.CQRS.Grpc.Generators/GrpcGenerator.cs +++ b/Svrnty.CQRS.Grpc.Generators/GrpcGenerator.cs @@ -2376,6 +2376,26 @@ public class GrpcGenerator : IIncrementalGenerator sb.AppendLine(" }"); sb.AppendLine(" }"); sb.AppendLine(); + sb.AppendLine(" // Authorization checks (cross-cutting; see ICommandAuthorizationCheck)"); + sb.AppendLine(" var commandChecks = serviceProvider.GetServices();"); + sb.AppendLine(" if (commandChecks != null)"); + sb.AppendLine(" {"); + sb.AppendLine(" var checkContext = new CommandAuthorizationCheckContext"); + sb.AppendLine(" {"); + sb.AppendLine($" CommandType = typeof({command.FullyQualifiedName}),"); + sb.AppendLine(" Command = command,"); + sb.AppendLine(" Services = serviceProvider"); + sb.AppendLine(" };"); + sb.AppendLine(" foreach (var check in commandChecks)"); + sb.AppendLine(" {"); + sb.AppendLine(" var checkResult = await check.CheckAsync(checkContext, context.CancellationToken);"); + sb.AppendLine(" if (checkResult == AuthorizationResult.Unauthorized)"); + sb.AppendLine(" throw new RpcException(new global::Grpc.Core.Status(global::Grpc.Core.StatusCode.Unauthenticated, \"Unauthorized\"));"); + sb.AppendLine(" if (checkResult == AuthorizationResult.Forbidden)"); + sb.AppendLine(" throw new RpcException(new global::Grpc.Core.Status(global::Grpc.Core.StatusCode.PermissionDenied, \"Forbidden\"));"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(); sb.AppendLine($" var handler = serviceProvider.GetRequiredService<{command.HandlerInterfaceName}>();"); if (command.HasResult) @@ -2493,6 +2513,27 @@ public class GrpcGenerator : IIncrementalGenerator sb.AppendLine(assignment); } sb.AppendLine(" };"); + sb.AppendLine(); + sb.AppendLine(" // Authorization checks (cross-cutting; see IQueryAuthorizationCheck)"); + sb.AppendLine(" var queryChecks = serviceProvider.GetServices();"); + sb.AppendLine(" if (queryChecks != null)"); + sb.AppendLine(" {"); + sb.AppendLine(" var checkContext = new QueryAuthorizationCheckContext"); + sb.AppendLine(" {"); + sb.AppendLine($" QueryType = typeof({query.FullyQualifiedName}),"); + sb.AppendLine(" Query = query,"); + sb.AppendLine(" Services = serviceProvider"); + sb.AppendLine(" };"); + sb.AppendLine(" foreach (var check in queryChecks)"); + sb.AppendLine(" {"); + sb.AppendLine(" var checkResult = await check.CheckAsync(checkContext, context.CancellationToken);"); + sb.AppendLine(" if (checkResult == AuthorizationResult.Unauthorized)"); + sb.AppendLine(" throw new RpcException(new global::Grpc.Core.Status(global::Grpc.Core.StatusCode.Unauthenticated, \"Unauthorized\"));"); + sb.AppendLine(" if (checkResult == AuthorizationResult.Forbidden)"); + sb.AppendLine(" throw new RpcException(new global::Grpc.Core.Status(global::Grpc.Core.StatusCode.PermissionDenied, \"Forbidden\"));"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(); sb.AppendLine(" var result = await handler.HandleAsync(query, context.CancellationToken);"); // Generate response with mapping if complex type @@ -2828,6 +2869,26 @@ public class GrpcGenerator : IIncrementalGenerator sb.AppendLine(" Aggregates = ConvertAggregates(request.Aggregates) ?? new()"); sb.AppendLine(" };"); sb.AppendLine(); + sb.AppendLine(" // Authorization checks (cross-cutting; see IQueryAuthorizationCheck)"); + sb.AppendLine(" var queryChecks = serviceProvider.GetServices();"); + sb.AppendLine(" if (queryChecks != null)"); + sb.AppendLine(" {"); + sb.AppendLine(" var checkContext = new QueryAuthorizationCheckContext"); + sb.AppendLine(" {"); + sb.AppendLine($" QueryType = typeof({dynamicQuery.QueryInterfaceName}),"); + sb.AppendLine(" Query = query,"); + sb.AppendLine(" Services = serviceProvider"); + sb.AppendLine(" };"); + sb.AppendLine(" foreach (var check in queryChecks)"); + sb.AppendLine(" {"); + sb.AppendLine(" var checkResult = await check.CheckAsync(checkContext, context.CancellationToken);"); + sb.AppendLine(" if (checkResult == AuthorizationResult.Unauthorized)"); + sb.AppendLine(" throw new RpcException(new global::Grpc.Core.Status(global::Grpc.Core.StatusCode.Unauthenticated, \"Unauthorized\"));"); + sb.AppendLine(" if (checkResult == AuthorizationResult.Forbidden)"); + sb.AppendLine(" throw new RpcException(new global::Grpc.Core.Status(global::Grpc.Core.StatusCode.PermissionDenied, \"Forbidden\"));"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(); // Get the handler and execute sb.AppendLine($" var handler = serviceProvider.GetRequiredService>>();"); diff --git a/Svrnty.CQRS.MinimalApi/EndpointRouteBuilderExtensions.cs b/Svrnty.CQRS.MinimalApi/EndpointRouteBuilderExtensions.cs index fe5e3f0..c7252df 100644 --- a/Svrnty.CQRS.MinimalApi/EndpointRouteBuilderExtensions.cs +++ b/Svrnty.CQRS.MinimalApi/EndpointRouteBuilderExtensions.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Linq; using System.Reflection; using System.Threading; @@ -16,6 +17,64 @@ namespace Svrnty.CQRS.MinimalApi; public static class EndpointRouteBuilderExtensions { + private static async Task RunCommandChecksAsync( + IServiceProvider serviceProvider, + Type commandType, + object command, + CancellationToken cancellationToken) + { + var checks = serviceProvider.GetServices().ToList(); + if (checks.Count == 0) + return null; + + var context = new CommandAuthorizationCheckContext + { + CommandType = commandType, + Command = command, + Services = serviceProvider + }; + + foreach (var check in checks) + { + var result = await check.CheckAsync(context, cancellationToken); + if (result == AuthorizationResult.Forbidden) + return Results.StatusCode(403); + if (result == AuthorizationResult.Unauthorized) + return Results.Unauthorized(); + } + + return null; + } + + private static async Task RunQueryChecksAsync( + IServiceProvider serviceProvider, + Type queryType, + object query, + CancellationToken cancellationToken) + { + var checks = serviceProvider.GetServices().ToList(); + if (checks.Count == 0) + return null; + + var context = new QueryAuthorizationCheckContext + { + QueryType = queryType, + Query = query, + Services = serviceProvider + }; + + foreach (var check in checks) + { + var result = await check.CheckAsync(context, cancellationToken); + if (result == AuthorizationResult.Forbidden) + return Results.StatusCode(403); + if (result == AuthorizationResult.Unauthorized) + return Results.Unauthorized(); + } + + return null; + } + public static IEndpointRouteBuilder MapSvrntyQueries(this IEndpointRouteBuilder endpoints, string routePrefix = "api/query") { var queryDiscovery = endpoints.ServiceProvider.GetRequiredService(); @@ -63,6 +122,10 @@ public static class EndpointRouteBuilderExtensions if (query == null || !queryMeta.QueryType.IsInstanceOfType(query)) return Results.BadRequest("Invalid query payload"); + var checkResult = await RunQueryChecksAsync(serviceProvider, queryMeta.QueryType, query, cancellationToken); + if (checkResult != null) + return checkResult; + var handler = serviceProvider.GetRequiredService(handlerType); var handleMethod = handlerType.GetMethod("HandleAsync"); if (handleMethod == null) @@ -128,6 +191,10 @@ public static class EndpointRouteBuilderExtensions } } + var checkResult = await RunQueryChecksAsync(serviceProvider, queryMeta.QueryType, query, cancellationToken); + if (checkResult != null) + return checkResult; + var handler = serviceProvider.GetRequiredService(handlerType); var handleMethod = handlerType.GetMethod("HandleAsync"); if (handleMethod == null) @@ -198,6 +265,10 @@ public static class EndpointRouteBuilderExtensions if (command == null || !commandMeta.CommandType.IsInstanceOfType(command)) return Results.BadRequest("Invalid command payload"); + var checkResult = await RunCommandChecksAsync(serviceProvider, commandMeta.CommandType, command, cancellationToken); + if (checkResult != null) + return checkResult; + var handler = serviceProvider.GetRequiredService(handlerType); var handleMethod = handlerType.GetMethod("HandleAsync"); if (handleMethod == null) @@ -240,6 +311,10 @@ public static class EndpointRouteBuilderExtensions if (command == null || !commandMeta.CommandType.IsInstanceOfType(command)) return Results.BadRequest("Invalid command payload"); + var checkResult = await RunCommandChecksAsync(serviceProvider, commandMeta.CommandType, command, cancellationToken); + if (checkResult != null) + return checkResult; + var handler = serviceProvider.GetRequiredService(handlerType); var handleMethod = handlerType.GetMethod("HandleAsync"); if (handleMethod == null)