using System;
using System.Linq;
using System.Reflection;
using Microsoft.Extensions.DependencyInjection;
using Svrnty.CQRS.Configuration;
namespace Svrnty.CQRS.Grpc;
/// <summary>
/// Extension methods for <see cref="CqrsBuilder"/> to add gRPC support.
/// </summary>
public static class CqrsBuilderExtensions
{
    /// <summary>
    /// Adds gRPC support to the CQRS pipeline.
    /// </summary>
    /// <remarks>
    /// Service registration and endpoint mapping are delegated to the generated
    /// <c>AddGrpcFromConfiguration(IServiceCollection)</c> and
    /// <c>MapGrpcFromConfiguration(IEndpointRouteBuilder)</c> extension methods. These are
    /// located at runtime by reflection over the assemblies loaded in the current
    /// <see cref="AppDomain"/>. If either method cannot be found, a warning is written to the
    /// console and that step is skipped.
    /// </remarks>
    /// <param name="builder">The CQRS builder.</param>
    /// <param name="configure">Optional configuration for gRPC endpoints.</param>
    /// <returns>The CQRS builder for method chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="builder"/> is <c>null</c>.</exception>
    public static CqrsBuilder AddGrpc(this CqrsBuilder builder, Action<GrpcCqrsOptions>? configure = null)
    {
        ArgumentNullException.ThrowIfNull(builder);

        var options = new GrpcCqrsOptions();
        configure?.Invoke(options);
        builder.Configuration.SetConfiguration(options);

        // Try to find and call the generated AddGrpcFromConfiguration method for service registration.
        var addGrpcMethod = FindExtensionMethod("AddGrpcFromConfiguration", typeof(IServiceCollection));
        if (addGrpcMethod is not null)
        {
            InvokeStatic(addGrpcMethod, builder.Services);
        }
        else
        {
            Console.WriteLine("Warning: AddGrpcFromConfiguration not found. gRPC services were not registered.");
            Console.WriteLine("Make sure your project has source generators enabled and references Svrnty.CQRS.Grpc.Generators.");
        }

        // Register a mapping callback for automatic endpoint mapping. The lookup is deferred
        // until the callback runs, so it sees the assemblies loaded at mapping time.
        builder.Configuration.AddMappingCallback(app =>
        {
            var mapGrpcMethod = FindExtensionMethod("MapGrpcFromConfiguration", typeof(Microsoft.AspNetCore.Routing.IEndpointRouteBuilder));
            if (mapGrpcMethod is not null)
            {
                InvokeStatic(mapGrpcMethod, app);
            }
            else
            {
                // Mirror the registration warning so a missing generator is not silently ignored.
                Console.WriteLine("Warning: MapGrpcFromConfiguration not found. gRPC endpoints were not mapped.");
            }
        });

        return builder;
    }

    /// <summary>
    /// Invokes a static single-argument method.
    /// </summary>
    /// <remarks>
    /// Uses <see cref="BindingFlags.DoNotWrapExceptions"/>. An exception thrown by the target
    /// therefore surfaces as itself rather than wrapped in a
    /// <see cref="TargetInvocationException"/>.
    /// </remarks>
    /// <param name="method">The static method to invoke.</param>
    /// <param name="argument">The single argument passed to <paramref name="method"/>.</param>
    private static void InvokeStatic(MethodInfo method, object? argument)
    {
        method.Invoke(null, BindingFlags.DoNotWrapExceptions, null, new object?[] { argument }, null);
    }

    /// <summary>
    /// Searches every assembly loaded in the current <see cref="AppDomain"/> for a public static
    /// method named <paramref name="methodName"/> that takes exactly one parameter of type
    /// <paramref name="parameterType"/>.
    /// </summary>
    /// <remarks>
    /// Only public, sealed, non-generic classes are inspected. A static class is both abstract
    /// and sealed, so it passes this filter. The first match found is returned.
    /// </remarks>
    /// <param name="methodName">Name of the method to find.</param>
    /// <param name="parameterType">Exact type of the method's single parameter.</param>
    /// <returns>The matching method, or <c>null</c> if none was found.</returns>
    private static MethodInfo? FindExtensionMethod(string methodName, Type parameterType)
    {
        Type[] parameterTypes = { parameterType };

        foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
        {
            foreach (var type in GetLoadableTypes(assembly))
            {
                if (!type.IsClass || !type.IsSealed || type.IsGenericType || !type.IsPublic)
                {
                    continue;
                }

                MethodInfo? method;
                try
                {
                    method = type.GetMethod(methodName,
                        BindingFlags.Static | BindingFlags.Public,
                        null,
                        parameterTypes,
                        null);
                }
                catch (Exception)
                {
                    // Deliberate best effort: a type whose members cannot be inspected
                    // (e.g. it references a missing dependency) is skipped. The remaining
                    // types in the assembly are still searched.
                    continue;
                }

                if (method is not null)
                {
                    return method;
                }
            }
        }

        return null;
    }

    /// <summary>
    /// Returns the types of <paramref name="assembly"/> that could be loaded.
    /// </summary>
    /// <remarks>
    /// When <see cref="Assembly.GetTypes"/> throws <see cref="ReflectionTypeLoadException"/>,
    /// the successfully loaded types are taken from <see cref="ReflectionTypeLoadException.Types"/>.
    /// The whole assembly is not discarded.
    /// </remarks>
    /// <param name="assembly">The assembly to inspect.</param>
    /// <returns>The loadable types, or an empty array if the assembly cannot be inspected at all.</returns>
    private static Type[] GetLoadableTypes(Assembly assembly)
    {
        try
        {
            return assembly.GetTypes();
        }
        catch (ReflectionTypeLoadException ex)
        {
            // Entries for types that failed to load are null; keep the rest.
            return ex.Types.OfType<Type>().ToArray();
        }
        catch (Exception)
        {
            // Deliberate best effort: skip assemblies that can't be inspected
            // (e.g. some dynamic assemblies throw NotSupportedException).
            return Type.EmptyTypes;
        }
    }
}