Compare commits

..

2 Commits

Author SHA1 Message Date
46e739eead fix: handle nullable numeric types in gRPC request mapping
When a C# property is nullable (int?, long?, etc.), protobuf3 defaults
to 0 when the field is not set. The generator now treats 0 as null
for nullable numeric properties, allowing proper optional field semantics.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-25 11:30:01 -05:00
179b06374d fix: improve gRPC source generator type mapping and property naming
- Add ProtoPropertyName to PropertyInfo for correct proto property naming
- Fix ToPascalCaseHelper to match Grpc.Tools naming (e.g., value_per100g → ValuePer100G)
- Add IsResultNullable and ResultTypeWithoutNullable to QueryInfo
- Fix IsPrimitiveType to correctly handle nullable complex types
- Add GetCollectionElementType helper (excludes strings from collection detection)
- Use AddRange pattern for repeated/collection fields in proto messages
- Add explicit Analyzer reference in props for reliable source generator loading
- Handle null cases in single complex type response mapping
- Fix collection properties in complex results with proper nested type mapping

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-25 10:47:05 -05:00
10 changed files with 708 additions and 1330 deletions

View File

@ -23,6 +23,7 @@ public static class EndpointRouteBuilderExtensions
public static IEndpointRouteBuilder MapSvrntyDynamicQueries(this IEndpointRouteBuilder endpoints, string routePrefix = "api/query") public static IEndpointRouteBuilder MapSvrntyDynamicQueries(this IEndpointRouteBuilder endpoints, string routePrefix = "api/query")
{ {
var queryDiscovery = endpoints.ServiceProvider.GetRequiredService<IQueryDiscovery>(); var queryDiscovery = endpoints.ServiceProvider.GetRequiredService<IQueryDiscovery>();
var authorizationService = endpoints.ServiceProvider.GetService<IQueryAuthorizationService>();
foreach (var queryMeta in queryDiscovery.GetQueries()) foreach (var queryMeta in queryDiscovery.GetQueries())
{ {
@ -42,14 +43,14 @@ public static class EndpointRouteBuilderExtensions
if (dynamicQueryMeta.ParamsType == null) if (dynamicQueryMeta.ParamsType == null)
{ {
// DynamicQuery<TSource, TDestination> // DynamicQuery<TSource, TDestination>
MapDynamicQueryPost(endpoints, route, dynamicQueryMeta); MapDynamicQueryPost(endpoints, route, dynamicQueryMeta, authorizationService);
MapDynamicQueryGet(endpoints, route, dynamicQueryMeta); MapDynamicQueryGet(endpoints, route, dynamicQueryMeta, authorizationService);
} }
else else
{ {
// DynamicQuery<TSource, TDestination, TParams> // DynamicQuery<TSource, TDestination, TParams>
MapDynamicQueryWithParamsPost(endpoints, route, dynamicQueryMeta); MapDynamicQueryWithParamsPost(endpoints, route, dynamicQueryMeta, authorizationService);
MapDynamicQueryWithParamsGet(endpoints, route, dynamicQueryMeta); MapDynamicQueryWithParamsGet(endpoints, route, dynamicQueryMeta, authorizationService);
} }
} }
@ -59,7 +60,8 @@ public static class EndpointRouteBuilderExtensions
private static void MapDynamicQueryPost( private static void MapDynamicQueryPost(
IEndpointRouteBuilder endpoints, IEndpointRouteBuilder endpoints,
string route, string route,
DynamicQueryMeta dynamicQueryMeta) DynamicQueryMeta dynamicQueryMeta,
IQueryAuthorizationService? authorizationService)
{ {
var sourceType = dynamicQueryMeta.SourceType; var sourceType = dynamicQueryMeta.SourceType;
var destinationType = dynamicQueryMeta.DestinationType; var destinationType = dynamicQueryMeta.DestinationType;
@ -73,7 +75,7 @@ public static class EndpointRouteBuilderExtensions
.GetMethod(nameof(MapDynamicQueryPostTyped), BindingFlags.NonPublic | BindingFlags.Static)! .GetMethod(nameof(MapDynamicQueryPostTyped), BindingFlags.NonPublic | BindingFlags.Static)!
.MakeGenericMethod(sourceType, destinationType); .MakeGenericMethod(sourceType, destinationType);
var endpoint = (RouteHandlerBuilder)mapPostMethod.Invoke(null, [endpoints, route, queryType, handlerType])!; var endpoint = (RouteHandlerBuilder)mapPostMethod.Invoke(null, [endpoints, route, queryType, handlerType, authorizationService])!;
endpoint endpoint
.WithName($"DynamicQuery_{dynamicQueryMeta.LowerCamelCaseName}_Post") .WithName($"DynamicQuery_{dynamicQueryMeta.LowerCamelCaseName}_Post")
@ -89,7 +91,8 @@ public static class EndpointRouteBuilderExtensions
IEndpointRouteBuilder endpoints, IEndpointRouteBuilder endpoints,
string route, string route,
Type queryType, Type queryType,
Type handlerType) Type handlerType,
IQueryAuthorizationService? authorizationService)
where TSource : class where TSource : class
where TDestination : class where TDestination : class
{ {
@ -99,7 +102,6 @@ public static class EndpointRouteBuilderExtensions
IServiceProvider serviceProvider, IServiceProvider serviceProvider,
CancellationToken cancellationToken) => CancellationToken cancellationToken) =>
{ {
var authorizationService = serviceProvider.GetService<IQueryAuthorizationService>();
if (authorizationService != null) if (authorizationService != null)
{ {
var authorizationResult = await authorizationService.IsAllowedAsync(queryType, cancellationToken); var authorizationResult = await authorizationService.IsAllowedAsync(queryType, cancellationToken);
@ -127,7 +129,8 @@ public static class EndpointRouteBuilderExtensions
private static void MapDynamicQueryGet( private static void MapDynamicQueryGet(
IEndpointRouteBuilder endpoints, IEndpointRouteBuilder endpoints,
string route, string route,
DynamicQueryMeta dynamicQueryMeta) DynamicQueryMeta dynamicQueryMeta,
IQueryAuthorizationService? authorizationService)
{ {
var sourceType = dynamicQueryMeta.SourceType; var sourceType = dynamicQueryMeta.SourceType;
var destinationType = dynamicQueryMeta.DestinationType; var destinationType = dynamicQueryMeta.DestinationType;
@ -138,7 +141,6 @@ public static class EndpointRouteBuilderExtensions
endpoints.MapGet(route, async (HttpContext context, IServiceProvider serviceProvider, CancellationToken cancellationToken) => endpoints.MapGet(route, async (HttpContext context, IServiceProvider serviceProvider, CancellationToken cancellationToken) =>
{ {
var authorizationService = serviceProvider.GetService<IQueryAuthorizationService>();
if (authorizationService != null) if (authorizationService != null)
{ {
var authorizationResult = await authorizationService.IsAllowedAsync(queryType, cancellationToken); var authorizationResult = await authorizationService.IsAllowedAsync(queryType, cancellationToken);
@ -197,7 +199,8 @@ public static class EndpointRouteBuilderExtensions
private static void MapDynamicQueryWithParamsPost( private static void MapDynamicQueryWithParamsPost(
IEndpointRouteBuilder endpoints, IEndpointRouteBuilder endpoints,
string route, string route,
DynamicQueryMeta dynamicQueryMeta) DynamicQueryMeta dynamicQueryMeta,
IQueryAuthorizationService? authorizationService)
{ {
var sourceType = dynamicQueryMeta.SourceType; var sourceType = dynamicQueryMeta.SourceType;
var destinationType = dynamicQueryMeta.DestinationType; var destinationType = dynamicQueryMeta.DestinationType;
@ -211,7 +214,7 @@ public static class EndpointRouteBuilderExtensions
.GetMethod(nameof(MapDynamicQueryWithParamsPostTyped), BindingFlags.NonPublic | BindingFlags.Static)! .GetMethod(nameof(MapDynamicQueryWithParamsPostTyped), BindingFlags.NonPublic | BindingFlags.Static)!
.MakeGenericMethod(sourceType, destinationType, paramsType); .MakeGenericMethod(sourceType, destinationType, paramsType);
var endpoint = (RouteHandlerBuilder)mapPostMethod.Invoke(null, [endpoints, route, queryType, handlerType])!; var endpoint = (RouteHandlerBuilder)mapPostMethod.Invoke(null, [endpoints, route, queryType, handlerType, authorizationService])!;
endpoint endpoint
.WithName($"DynamicQuery_{dynamicQueryMeta.LowerCamelCaseName}_WithParams_Post") .WithName($"DynamicQuery_{dynamicQueryMeta.LowerCamelCaseName}_WithParams_Post")
@ -227,7 +230,8 @@ public static class EndpointRouteBuilderExtensions
IEndpointRouteBuilder endpoints, IEndpointRouteBuilder endpoints,
string route, string route,
Type queryType, Type queryType,
Type handlerType) Type handlerType,
IQueryAuthorizationService? authorizationService)
where TSource : class where TSource : class
where TDestination : class where TDestination : class
where TParams : class where TParams : class
@ -238,7 +242,6 @@ public static class EndpointRouteBuilderExtensions
IServiceProvider serviceProvider, IServiceProvider serviceProvider,
CancellationToken cancellationToken) => CancellationToken cancellationToken) =>
{ {
var authorizationService = serviceProvider.GetService<IQueryAuthorizationService>();
if (authorizationService != null) if (authorizationService != null)
{ {
var authorizationResult = await authorizationService.IsAllowedAsync(queryType, cancellationToken); var authorizationResult = await authorizationService.IsAllowedAsync(queryType, cancellationToken);
@ -266,7 +269,8 @@ public static class EndpointRouteBuilderExtensions
private static void MapDynamicQueryWithParamsGet( private static void MapDynamicQueryWithParamsGet(
IEndpointRouteBuilder endpoints, IEndpointRouteBuilder endpoints,
string route, string route,
DynamicQueryMeta dynamicQueryMeta) DynamicQueryMeta dynamicQueryMeta,
IQueryAuthorizationService? authorizationService)
{ {
var sourceType = dynamicQueryMeta.SourceType; var sourceType = dynamicQueryMeta.SourceType;
var destinationType = dynamicQueryMeta.DestinationType; var destinationType = dynamicQueryMeta.DestinationType;
@ -278,7 +282,6 @@ public static class EndpointRouteBuilderExtensions
endpoints.MapGet(route, async (HttpContext context, IServiceProvider serviceProvider, CancellationToken cancellationToken) => endpoints.MapGet(route, async (HttpContext context, IServiceProvider serviceProvider, CancellationToken cancellationToken) =>
{ {
var authorizationService = serviceProvider.GetService<IQueryAuthorizationService>();
if (authorizationService != null) if (authorizationService != null)
{ {
var authorizationResult = await authorizationService.IsAllowedAsync(queryType, cancellationToken); var authorizationResult = await authorizationService.IsAllowedAsync(queryType, cancellationToken);

View File

@ -35,11 +35,7 @@ public abstract class DynamicQueryHandlerBase<TSource, TDestination>
protected virtual Task<IQueryable<TSource>> GetQueryableAsync(IDynamicQuery query, CancellationToken cancellationToken = default) protected virtual Task<IQueryable<TSource>> GetQueryableAsync(IDynamicQuery query, CancellationToken cancellationToken = default)
{ {
if (_queryableProviders.Any()) if (_queryableProviders.Any())
{ return _queryableProviders.ElementAt(0).GetQueryableAsync(query, cancellationToken);
// Use Last() to prefer closed generic registrations (overrides) over open generic (default)
// Registration order: open generic first, closed generic (override) last
return _queryableProviders.Last().GetQueryableAsync(query, cancellationToken);
}
throw new Exception($"You must provide a QueryableProvider<TSource> for {typeof(TSource).Name}"); throw new Exception($"You must provide a QueryableProvider<TSource> for {typeof(TSource).Name}");
} }

File diff suppressed because it is too large Load Diff

View File

@ -49,12 +49,6 @@ namespace Svrnty.CQRS.Grpc.Generators.Helpers
isRepeated = false; isRepeated = false;
isOptional = false; isOptional = false;
// Handle byte[] as bytes proto type (NOT repeated uint32)
if (csharpType == "System.Byte[]" || csharpType == "byte[]" || csharpType == "Byte[]")
{
return "bytes";
}
// Handle arrays // Handle arrays
if (csharpType.EndsWith("[]")) if (csharpType.EndsWith("[]"))
{ {

View File

@ -31,6 +31,11 @@ namespace Svrnty.CQRS.Grpc.Generators.Models
public class PropertyInfo public class PropertyInfo
{ {
public string Name { get; set; } public string Name { get; set; }
/// <summary>
/// The property name as generated by Grpc.Tools from the proto field name.
/// This may differ from the C# property name due to casing differences.
/// </summary>
public string ProtoPropertyName { get; set; }
public string Type { get; set; } public string Type { get; set; }
public string FullyQualifiedType { get; set; } public string FullyQualifiedType { get; set; }
public string ProtoType { get; set; } public string ProtoType { get; set; }
@ -44,21 +49,14 @@ namespace Svrnty.CQRS.Grpc.Generators.Models
public bool IsNullable { get; set; } public bool IsNullable { get; set; }
public bool IsDecimal { get; set; } public bool IsDecimal { get; set; }
public bool IsDateTime { get; set; } public bool IsDateTime { get; set; }
public bool IsDateTimeOffset { get; set; }
public bool IsGuid { get; set; }
public bool IsJsonElement { get; set; }
public bool IsBinaryType { get; set; } // Stream, byte[], MemoryStream
public bool IsStream { get; set; } // Specifically Stream types (not byte[])
public bool IsReadOnly { get; set; } // Read-only/computed properties should be skipped
public bool IsValueTypeCollection { get; set; } // Value types that implement IList<T> (like NpgsqlPolygon)
public string? ElementType { get; set; } public string? ElementType { get; set; }
public bool IsElementComplexType { get; set; } public bool IsElementComplexType { get; set; }
public bool IsElementGuid { get; set; }
public List<PropertyInfo>? ElementNestedProperties { get; set; } public List<PropertyInfo>? ElementNestedProperties { get; set; }
public PropertyInfo() public PropertyInfo()
{ {
Name = string.Empty; Name = string.Empty;
ProtoPropertyName = string.Empty;
Type = string.Empty; Type = string.Empty;
FullyQualifiedType = string.Empty; FullyQualifiedType = string.Empty;
ProtoType = string.Empty; ProtoType = string.Empty;
@ -69,15 +67,7 @@ namespace Svrnty.CQRS.Grpc.Generators.Models
IsNullable = false; IsNullable = false;
IsDecimal = false; IsDecimal = false;
IsDateTime = false; IsDateTime = false;
IsDateTimeOffset = false;
IsGuid = false;
IsJsonElement = false;
IsBinaryType = false;
IsStream = false;
IsReadOnly = false;
IsValueTypeCollection = false;
IsElementComplexType = false; IsElementComplexType = false;
IsElementGuid = false;
} }
} }
} }

View File

@ -13,6 +13,26 @@ namespace Svrnty.CQRS.Grpc.Generators.Models
public string HandlerInterfaceName { get; set; } public string HandlerInterfaceName { get; set; }
public List<PropertyInfo> ResultProperties { get; set; } public List<PropertyInfo> ResultProperties { get; set; }
public bool IsResultPrimitiveType { get; set; } public bool IsResultPrimitiveType { get; set; }
/// <summary>
/// True if the result type is a collection (List, IEnumerable, etc.)
/// </summary>
public bool IsResultCollection { get; set; }
/// <summary>
/// The element type name if IsResultCollection is true
/// </summary>
public string ResultElementType { get; set; }
/// <summary>
/// The fully qualified element type name if IsResultCollection is true
/// </summary>
public string ResultElementTypeFullyQualified { get; set; }
/// <summary>
/// True if the result type is nullable (ends with ? or is Nullable<T>)
/// </summary>
public bool IsResultNullable { get; set; }
/// <summary>
/// The result type name without the nullable annotation (e.g., "CnfFoodDetailItem" instead of "CnfFoodDetailItem?")
/// </summary>
public string ResultTypeWithoutNullable { get; set; }
public QueryInfo() public QueryInfo()
{ {
@ -25,6 +45,11 @@ namespace Svrnty.CQRS.Grpc.Generators.Models
HandlerInterfaceName = string.Empty; HandlerInterfaceName = string.Empty;
ResultProperties = new List<PropertyInfo>(); ResultProperties = new List<PropertyInfo>();
IsResultPrimitiveType = false; IsResultPrimitiveType = false;
IsResultCollection = false;
ResultElementType = string.Empty;
ResultElementTypeFullyQualified = string.Empty;
IsResultNullable = false;
ResultTypeWithoutNullable = string.Empty;
} }
} }
} }

View File

@ -320,9 +320,7 @@ internal class ProtoFileGenerator
var properties = type.GetMembers() var properties = type.GetMembers()
.OfType<IPropertySymbol>() .OfType<IPropertySymbol>()
.Where(p => p.DeclaredAccessibility == Accessibility.Public && .Where(p => p.DeclaredAccessibility == Accessibility.Public)
!p.IsIndexer &&
!ProtoFileTypeMapper.IsCollectionInternalProperty(p.Name))
.ToList(); .ToList();
// Collect nested complex types to generate after closing this message // Collect nested complex types to generate after closing this message
@ -405,9 +403,14 @@ internal class ProtoFileGenerator
_messagesBuilder.AppendLine(); _messagesBuilder.AppendLine();
// Generate complex type message after closing the response message // Generate complex type message after closing the response message
if (resultType != null && IsComplexType(resultType)) // Use GetElementOrUnderlyingType to extract element type from collections (e.g., List<CnfFoodItem> -> CnfFoodItem)
if (resultType != null)
{ {
GenerateComplexTypeMessage(resultType as INamedTypeSymbol); var underlyingType = ProtoFileTypeMapper.GetElementOrUnderlyingType(resultType);
if (IsComplexType(underlyingType) && underlyingType is INamedTypeSymbol namedType)
{
GenerateComplexTypeMessage(namedType);
}
} }
} }
@ -425,38 +428,14 @@ internal class ProtoFileGenerator
_messagesBuilder.AppendLine($"// {type.Name} entity"); _messagesBuilder.AppendLine($"// {type.Name} entity");
_messagesBuilder.AppendLine($"message {type.Name} {{"); _messagesBuilder.AppendLine($"message {type.Name} {{");
// Collect nested complex types to generate after closing this message
var nestedComplexTypes = new List<INamedTypeSymbol>();
// Check if this type is a collection (implements IList<T>, ICollection<T>, etc.)
var collectionElementType = ProtoFileTypeMapper.GetCollectionElementTypeByInterface(type);
if (collectionElementType != null)
{
// This type is a collection - generate a single repeated field for items
var protoElementType = ProtoFileTypeMapper.MapType(collectionElementType, out var needsImport, out var importPath);
if (needsImport && importPath != null)
{
_requiredImports.Add(importPath);
}
_messagesBuilder.AppendLine($" repeated {protoElementType} items = 1;");
// Track the element type for nested generation
if (IsComplexType(collectionElementType) && collectionElementType is INamedTypeSymbol elementNamedType)
{
nestedComplexTypes.Add(elementNamedType);
}
}
else
{
// Not a collection - generate properties as usual
var properties = type.GetMembers() var properties = type.GetMembers()
.OfType<IPropertySymbol>() .OfType<IPropertySymbol>()
.Where(p => p.DeclaredAccessibility == Accessibility.Public && .Where(p => p.DeclaredAccessibility == Accessibility.Public)
!p.IsIndexer &&
!ProtoFileTypeMapper.IsCollectionInternalProperty(p.Name))
.ToList(); .ToList();
// Collect nested complex types to generate after closing this message
var nestedComplexTypes = new List<INamedTypeSymbol>();
int fieldNumber = 1; int fieldNumber = 1;
foreach (var prop in properties) foreach (var prop in properties)
{ {
@ -492,7 +471,6 @@ internal class ProtoFileGenerator
fieldNumber++; fieldNumber++;
} }
}
_messagesBuilder.AppendLine("}"); _messagesBuilder.AppendLine("}");
_messagesBuilder.AppendLine(); _messagesBuilder.AppendLine();
@ -764,7 +742,7 @@ internal class ProtoFileGenerator
FullyQualifiedName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) FullyQualifiedName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
.Replace("global::", ""), .Replace("global::", ""),
Namespace = type.ContainingNamespace?.ToDisplayString() ?? "", Namespace = type.ContainingNamespace?.ToDisplayString() ?? "",
SubscriptionKeyProperty = subscriptionKeyProp!, // Already validated as non-null above SubscriptionKeyProperty = subscriptionKeyProp,
SubscriptionKeyInfo = keyPropInfo, SubscriptionKeyInfo = keyPropInfo,
Properties = properties Properties = properties
}); });
@ -782,9 +760,7 @@ internal class ProtoFileGenerator
int fieldNumber = 1; int fieldNumber = 1;
foreach (var prop in type.GetMembers().OfType<IPropertySymbol>() foreach (var prop in type.GetMembers().OfType<IPropertySymbol>()
.Where(p => p.DeclaredAccessibility == Accessibility.Public && .Where(p => p.DeclaredAccessibility == Accessibility.Public))
!p.IsIndexer &&
!ProtoFileTypeMapper.IsCollectionInternalProperty(p.Name)))
{ {
if (ProtoFileTypeMapper.IsUnsupportedType(prop.Type)) if (ProtoFileTypeMapper.IsUnsupportedType(prop.Type))
continue; continue;
@ -846,17 +822,15 @@ internal class ProtoFileGenerator
foreach (var prop in notification.Properties) foreach (var prop in notification.Properties)
{ {
var typeSymbol = _compilation.GetTypeByMetadataName(prop.FullyQualifiedType) ?? var protoType = ProtoFileTypeMapper.MapType(
GetTypeFromName(prop.FullyQualifiedType); _compilation.GetTypeByMetadataName(prop.FullyQualifiedType) ??
GetTypeFromName(prop.FullyQualifiedType),
out var needsImport, out var importPath);
if (typeSymbol != null)
{
ProtoFileTypeMapper.MapType(typeSymbol, out var needsImport, out var importPath);
if (needsImport && importPath != null) if (needsImport && importPath != null)
{ {
_requiredImports.Add(importPath); _requiredImports.Add(importPath);
} }
}
var fieldName = ProtoFileTypeMapper.ToSnakeCase(prop.Name); var fieldName = ProtoFileTypeMapper.ToSnakeCase(prop.Name);
_messagesBuilder.AppendLine($" {prop.ProtoType} {fieldName} = {prop.FieldNumber};"); _messagesBuilder.AppendLine($" {prop.ProtoType} {fieldName} = {prop.FieldNumber};");

View File

@ -20,68 +20,6 @@ internal static class ProtoFileTypeMapper
// Note: NullableAnnotation.Annotated is for reference type nullability (List<T>?, string?, etc.) // Note: NullableAnnotation.Annotated is for reference type nullability (List<T>?, string?, etc.)
// We don't unwrap these - just use the underlying type. Nullable<T> value types are handled later. // We don't unwrap these - just use the underlying type. Nullable<T> value types are handled later.
// Handle Nullable<T> value types (e.g., int?, decimal?, enum?) FIRST
if (typeSymbol is INamedTypeSymbol nullableType &&
nullableType.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T &&
nullableType.TypeArguments.Length == 1)
{
// Unwrap the nullable and map the inner type
return MapType(nullableType.TypeArguments[0], out needsImport, out importPath);
}
// Handle collections BEFORE basic type checks (to avoid matching List<Guid> as Guid)
if (typeSymbol is INamedTypeSymbol collectionType)
{
// List, IEnumerable, Array, ICollection etc. (but not Nullable<T>)
var collectionTypeName = collectionType.Name;
if (collectionType.TypeArguments.Length == 1 &&
(collectionTypeName.Contains("List") || collectionTypeName.Contains("Collection") ||
collectionTypeName.Contains("Enumerable") || collectionTypeName.Contains("Array") ||
collectionTypeName.Contains("Set") || collectionTypeName.Contains("IList") ||
collectionTypeName.Contains("ICollection") || collectionTypeName.Contains("IEnumerable")))
{
var elementType = collectionType.TypeArguments[0];
var protoElementType = MapType(elementType, out needsImport, out importPath);
return $"repeated {protoElementType}";
}
// Dictionary<K, V>
if (collectionType.TypeArguments.Length == 2 &&
(collectionTypeName.Contains("Dictionary") || collectionTypeName.Contains("IDictionary")))
{
var keyType = MapType(collectionType.TypeArguments[0], out var keyNeedsImport, out var keyImportPath);
var valueType = MapType(collectionType.TypeArguments[1], out var valueNeedsImport, out var valueImportPath);
// Set import flags if either key or value needs imports
if (keyNeedsImport)
{
needsImport = true;
importPath = keyImportPath;
}
if (valueNeedsImport)
{
needsImport = true;
importPath = valueImportPath; // Note: This only captures last import, may need improvement
}
return $"map<{keyType}, {valueType}>";
}
}
// Handle byte[] array type (check before switch since it's an array)
if (typeSymbol is IArrayTypeSymbol arrayType && arrayType.ElementType.SpecialType == SpecialType.System_Byte)
{
return "bytes";
}
// Handle Stream types -> bytes
if (fullTypeName.Contains("System.IO.Stream") ||
fullTypeName.Contains("System.IO.MemoryStream") ||
fullTypeName.Contains("System.IO.FileStream"))
{
return "bytes";
}
// Basic types // Basic types
switch (typeName) switch (typeName)
{ {
@ -111,35 +49,81 @@ internal static class ProtoFileTypeMapper
return "double"; return "double";
case "Byte[]": case "Byte[]":
return "bytes"; return "bytes";
case "Stream": }
case "MemoryStream":
case "FileStream": // Special types that need imports
return "bytes"; if (fullTypeName.Contains("System.DateTime"))
case "Guid": {
// Guid serialized as string
return "string";
case "Decimal":
// Decimal serialized as string (no native decimal in proto)
return "string";
case "DateTime":
case "DateTimeOffset":
needsImport = true; needsImport = true;
importPath = "google/protobuf/timestamp.proto"; importPath = "google/protobuf/timestamp.proto";
return "google.protobuf.Timestamp"; return "google.protobuf.Timestamp";
case "DateOnly": }
// DateOnly serialized as string (YYYY-MM-DD format)
return "string"; if (fullTypeName.Contains("System.TimeSpan"))
case "TimeOnly": {
// TimeOnly serialized as string (HH:mm:ss format)
return "string";
case "TimeSpan":
needsImport = true; needsImport = true;
importPath = "google/protobuf/duration.proto"; importPath = "google/protobuf/duration.proto";
return "google.protobuf.Duration"; return "google.protobuf.Duration";
case "JsonElement": }
if (fullTypeName.Contains("System.Guid"))
{
// Guid serialized as string
return "string";
}
if (fullTypeName.Contains("System.Decimal") || typeName == "Decimal" || fullTypeName == "decimal")
{
// Decimal serialized as string (no native decimal in proto)
return "string";
}
// Handle Nullable<T> value types (e.g., int?, decimal?, enum?)
if (typeSymbol is INamedTypeSymbol nullableType &&
nullableType.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T &&
nullableType.TypeArguments.Length == 1)
{
// Unwrap the nullable and map the inner type
return MapType(nullableType.TypeArguments[0], out needsImport, out importPath);
}
// Collections
if (typeSymbol is INamedTypeSymbol collectionType)
{
// List, IEnumerable, Array, ICollection etc. (but not Nullable<T>)
var typeName2 = collectionType.Name;
if (collectionType.TypeArguments.Length == 1 &&
(typeName2.Contains("List") || typeName2.Contains("Collection") ||
typeName2.Contains("Enumerable") || typeName2.Contains("Array") ||
typeName2.Contains("Set") || typeName2.Contains("IList") ||
typeName2.Contains("ICollection") || typeName2.Contains("IEnumerable")))
{
var elementType = collectionType.TypeArguments[0];
var protoElementType = MapType(elementType, out needsImport, out importPath);
return $"repeated {protoElementType}";
}
// Dictionary<K, V>
if (collectionType.TypeArguments.Length == 2 &&
(typeName.Contains("Dictionary") || typeName.Contains("IDictionary")))
{
var keyType = MapType(collectionType.TypeArguments[0], out var keyNeedsImport, out var keyImportPath);
var valueType = MapType(collectionType.TypeArguments[1], out var valueNeedsImport, out var valueImportPath);
// Set import flags if either key or value needs imports
if (keyNeedsImport)
{
needsImport = true; needsImport = true;
importPath = "google/protobuf/struct.proto"; importPath = keyImportPath;
return "google.protobuf.Struct"; }
if (valueNeedsImport)
{
needsImport = true;
importPath = valueImportPath; // Note: This only captures last import, may need improvement
}
return $"map<{keyType}, {valueType}>";
}
} }
// Enums // Enums
@ -159,10 +143,7 @@ internal static class ProtoFileTypeMapper
} }
/// <summary> /// <summary>
/// Converts C# PascalCase property name to proto snake_case field name. /// Converts C# PascalCase property name to proto snake_case field name
/// Uses simple conversion: add underscore before each uppercase letter (except first).
/// This matches protobuf's C# codegen expectations for PascalCase conversion.
/// Example: TotalADeduire -> total_a_deduire -> TotalADeduire (in generated C#)
/// </summary> /// </summary>
public static string ToSnakeCase(string pascalCase) public static string ToSnakeCase(string pascalCase)
{ {
@ -176,10 +157,18 @@ internal static class ProtoFileTypeMapper
{ {
var c = pascalCase[i]; var c = pascalCase[i];
if (char.IsUpper(c)) if (char.IsUpper(c))
{
// Handle sequences of uppercase letters (e.g., "APIKey" -> "api_key")
if (i + 1 < pascalCase.Length && char.IsUpper(pascalCase[i + 1]))
{
result.Append(char.ToLowerInvariant(c));
}
else
{ {
result.Append('_'); result.Append('_');
result.Append(char.ToLowerInvariant(c)); result.Append(char.ToLowerInvariant(c));
} }
}
else else
{ {
result.Append(c); result.Append(c);
@ -197,8 +186,8 @@ internal static class ProtoFileTypeMapper
var fullTypeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); var fullTypeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
// Skip these types - they should trigger a warning/error // Skip these types - they should trigger a warning/error
// Note: Stream types are now supported (mapped to bytes) if (fullTypeName.Contains("System.IO.Stream") ||
if (fullTypeName.Contains("System.Threading.CancellationToken") || fullTypeName.Contains("System.Threading.CancellationToken") ||
fullTypeName.Contains("System.Threading.Tasks.Task") || fullTypeName.Contains("System.Threading.Tasks.Task") ||
fullTypeName.Contains("System.Collections.Generic.IAsyncEnumerable") || fullTypeName.Contains("System.Collections.Generic.IAsyncEnumerable") ||
fullTypeName.Contains("System.Func") || fullTypeName.Contains("System.Func") ||
@ -211,31 +200,6 @@ internal static class ProtoFileTypeMapper
return false; return false;
} }
/// <summary>
/// Checks if a type is a Stream or byte array type (for special ByteString handling)
/// </summary>
public static bool IsBinaryType(ITypeSymbol typeSymbol)
{
var fullTypeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
// Check for byte[]
if (typeSymbol is IArrayTypeSymbol arrayType && arrayType.ElementType.SpecialType == SpecialType.System_Byte)
{
return true;
}
// Check for Stream types
if (fullTypeName.Contains("System.IO.Stream") ||
fullTypeName.Contains("System.IO.MemoryStream") ||
fullTypeName.Contains("System.IO.FileStream"))
{
return true;
}
var typeName = typeSymbol.Name;
return typeName == "Stream" || typeName == "MemoryStream" || typeName == "FileStream";
}
/// <summary> /// <summary>
/// Gets the element type from a collection type, or returns the type itself if not a collection. /// Gets the element type from a collection type, or returns the type itself if not a collection.
/// Also unwraps Nullable types. /// Also unwraps Nullable types.
@ -287,97 +251,4 @@ internal static class ProtoFileTypeMapper
} }
return null; return null;
} }
/// <summary>
/// Checks if a type is a collection by checking if it implements IList{T}, ICollection{T}, or IEnumerable{T}
/// This handles types like NpgsqlPolygon that implement IList{NpgsqlPoint} but aren't named "List"
/// </summary>
public static bool IsCollectionTypeByInterface(ITypeSymbol typeSymbol)
{
if (typeSymbol is not INamedTypeSymbol namedType)
return false;
// Skip string (implements IEnumerable<char>)
if (namedType.SpecialType == SpecialType.System_String)
return false;
// Check all interfaces for IList<T>, ICollection<T>, or IEnumerable<T>
foreach (var iface in namedType.AllInterfaces)
{
if (iface.IsGenericType && iface.TypeArguments.Length == 1)
{
var ifaceName = iface.OriginalDefinition.ToDisplayString();
if (ifaceName == "System.Collections.Generic.IList<T>" ||
ifaceName == "System.Collections.Generic.ICollection<T>" ||
ifaceName == "System.Collections.Generic.IEnumerable<T>" ||
ifaceName == "System.Collections.Generic.IReadOnlyList<T>" ||
ifaceName == "System.Collections.Generic.IReadOnlyCollection<T>")
{
return true;
}
}
}
return false;
}
/// <summary>
/// Gets the element type from a collection that implements IList{T}, ICollection{T}, or IEnumerable{T}
/// Returns null if the type is not a collection
/// </summary>
public static ITypeSymbol? GetCollectionElementTypeByInterface(ITypeSymbol typeSymbol)
{
if (typeSymbol is not INamedTypeSymbol namedType)
return null;
// Skip string
if (namedType.SpecialType == SpecialType.System_String)
return null;
// Prefer IList<T> over ICollection<T> over IEnumerable<T>
ITypeSymbol? elementType = null;
int priority = 0;
foreach (var iface in namedType.AllInterfaces)
{
if (iface.IsGenericType && iface.TypeArguments.Length == 1)
{
var ifaceName = iface.OriginalDefinition.ToDisplayString();
int currentPriority = 0;
if (ifaceName == "System.Collections.Generic.IList<T>" ||
ifaceName == "System.Collections.Generic.IReadOnlyList<T>")
currentPriority = 3;
else if (ifaceName == "System.Collections.Generic.ICollection<T>" ||
ifaceName == "System.Collections.Generic.IReadOnlyCollection<T>")
currentPriority = 2;
else if (ifaceName == "System.Collections.Generic.IEnumerable<T>")
currentPriority = 1;
if (currentPriority > priority)
{
priority = currentPriority;
elementType = iface.TypeArguments[0];
}
}
}
return elementType;
}
/// <summary>
/// Collection-internal properties that should be skipped when generating proto messages
/// </summary>
private static readonly System.Collections.Generic.HashSet<string> CollectionInternalProperties = new()
{
"Count", "Capacity", "IsReadOnly", "IsSynchronized", "SyncRoot", "Keys", "Values"
};
/// <summary>
/// Checks if a property name is a collection-internal property that should be skipped
/// </summary>
public static bool IsCollectionInternalProperty(string propertyName)
{
return CollectionInternalProperties.Contains(propertyName);
}
} }

View File

@ -2,5 +2,15 @@
<PropertyGroup> <PropertyGroup>
<!-- Marker to indicate Svrnty.CQRS.Grpc.Generators is referenced --> <!-- Marker to indicate Svrnty.CQRS.Grpc.Generators is referenced -->
<SvrntyCqrsGrpcGeneratorsVersion>$(SvrntyCqrsGrpcGeneratorsVersion)</SvrntyCqrsGrpcGeneratorsVersion> <SvrntyCqrsGrpcGeneratorsVersion>$(SvrntyCqrsGrpcGeneratorsVersion)</SvrntyCqrsGrpcGeneratorsVersion>
<!-- Path resolution for both NuGet package and project reference -->
<_SvrntyCqrsGrpcGeneratorsPath Condition="Exists('$(MSBuildThisFileDirectory)..\analyzers\dotnet\cs\Svrnty.CQRS.Grpc.Generators.dll')">$(MSBuildThisFileDirectory)..\analyzers\dotnet\cs\Svrnty.CQRS.Grpc.Generators.dll</_SvrntyCqrsGrpcGeneratorsPath>
<_SvrntyCqrsGrpcGeneratorsPath Condition="'$(_SvrntyCqrsGrpcGeneratorsPath)' == '' AND Exists('$(MSBuildThisFileDirectory)..\bin\Debug\netstandard2.0\Svrnty.CQRS.Grpc.Generators.dll')">$(MSBuildThisFileDirectory)..\bin\Debug\netstandard2.0\Svrnty.CQRS.Grpc.Generators.dll</_SvrntyCqrsGrpcGeneratorsPath>
<_SvrntyCqrsGrpcGeneratorsPath Condition="'$(_SvrntyCqrsGrpcGeneratorsPath)' == '' AND Exists('$(MSBuildThisFileDirectory)..\bin\Release\netstandard2.0\Svrnty.CQRS.Grpc.Generators.dll')">$(MSBuildThisFileDirectory)..\bin\Release\netstandard2.0\Svrnty.CQRS.Grpc.Generators.dll</_SvrntyCqrsGrpcGeneratorsPath>
</PropertyGroup> </PropertyGroup>
<!-- Explicitly add the generator to the Analyzer ItemGroup -->
<ItemGroup Condition="'$(_SvrntyCqrsGrpcGeneratorsPath)' != ''">
<Analyzer Include="$(_SvrntyCqrsGrpcGeneratorsPath)" />
</ItemGroup>
</Project> </Project>

View File

@ -19,6 +19,7 @@ public static class EndpointRouteBuilderExtensions
public static IEndpointRouteBuilder MapSvrntyQueries(this IEndpointRouteBuilder endpoints, string routePrefix = "api/query") public static IEndpointRouteBuilder MapSvrntyQueries(this IEndpointRouteBuilder endpoints, string routePrefix = "api/query")
{ {
var queryDiscovery = endpoints.ServiceProvider.GetRequiredService<IQueryDiscovery>(); var queryDiscovery = endpoints.ServiceProvider.GetRequiredService<IQueryDiscovery>();
var authorizationService = endpoints.ServiceProvider.GetService<IQueryAuthorizationService>();
foreach (var queryMeta in queryDiscovery.GetQueries()) foreach (var queryMeta in queryDiscovery.GetQueries())
{ {
@ -32,8 +33,8 @@ public static class EndpointRouteBuilderExtensions
var route = $"{routePrefix}/{queryMeta.LowerCamelCaseName}"; var route = $"{routePrefix}/{queryMeta.LowerCamelCaseName}";
MapQueryPost(endpoints, route, queryMeta); MapQueryPost(endpoints, route, queryMeta, authorizationService);
MapQueryGet(endpoints, route, queryMeta); MapQueryGet(endpoints, route, queryMeta, authorizationService);
} }
return endpoints; return endpoints;
@ -42,13 +43,13 @@ public static class EndpointRouteBuilderExtensions
private static void MapQueryPost( private static void MapQueryPost(
IEndpointRouteBuilder endpoints, IEndpointRouteBuilder endpoints,
string route, string route,
IQueryMeta queryMeta) IQueryMeta queryMeta,
IQueryAuthorizationService? authorizationService)
{ {
var handlerType = typeof(IQueryHandler<,>).MakeGenericType(queryMeta.QueryType, queryMeta.QueryResultType); var handlerType = typeof(IQueryHandler<,>).MakeGenericType(queryMeta.QueryType, queryMeta.QueryResultType);
endpoints.MapPost(route, async (HttpContext context, IServiceProvider serviceProvider, CancellationToken cancellationToken) => endpoints.MapPost(route, async (HttpContext context, IServiceProvider serviceProvider, CancellationToken cancellationToken) =>
{ {
var authorizationService = serviceProvider.GetService<IQueryAuthorizationService>();
if (authorizationService != null) if (authorizationService != null)
{ {
var authorizationResult = await authorizationService.IsAllowedAsync(queryMeta.QueryType, cancellationToken); var authorizationResult = await authorizationService.IsAllowedAsync(queryMeta.QueryType, cancellationToken);
@ -89,13 +90,13 @@ public static class EndpointRouteBuilderExtensions
private static void MapQueryGet( private static void MapQueryGet(
IEndpointRouteBuilder endpoints, IEndpointRouteBuilder endpoints,
string route, string route,
IQueryMeta queryMeta) IQueryMeta queryMeta,
IQueryAuthorizationService? authorizationService)
{ {
var handlerType = typeof(IQueryHandler<,>).MakeGenericType(queryMeta.QueryType, queryMeta.QueryResultType); var handlerType = typeof(IQueryHandler<,>).MakeGenericType(queryMeta.QueryType, queryMeta.QueryResultType);
endpoints.MapGet(route, async (HttpContext context, IServiceProvider serviceProvider, CancellationToken cancellationToken) => endpoints.MapGet(route, async (HttpContext context, IServiceProvider serviceProvider, CancellationToken cancellationToken) =>
{ {
var authorizationService = serviceProvider.GetService<IQueryAuthorizationService>();
if (authorizationService != null) if (authorizationService != null)
{ {
var authorizationResult = await authorizationService.IsAllowedAsync(queryMeta.QueryType, cancellationToken); var authorizationResult = await authorizationService.IsAllowedAsync(queryMeta.QueryType, cancellationToken);
@ -152,6 +153,7 @@ public static class EndpointRouteBuilderExtensions
public static IEndpointRouteBuilder MapSvrntyCommands(this IEndpointRouteBuilder endpoints, string routePrefix = "api/command") public static IEndpointRouteBuilder MapSvrntyCommands(this IEndpointRouteBuilder endpoints, string routePrefix = "api/command")
{ {
var commandDiscovery = endpoints.ServiceProvider.GetRequiredService<ICommandDiscovery>(); var commandDiscovery = endpoints.ServiceProvider.GetRequiredService<ICommandDiscovery>();
var authorizationService = endpoints.ServiceProvider.GetService<ICommandAuthorizationService>();
foreach (var commandMeta in commandDiscovery.GetCommands()) foreach (var commandMeta in commandDiscovery.GetCommands())
{ {
@ -163,11 +165,11 @@ public static class EndpointRouteBuilderExtensions
if (commandMeta.CommandResultType == null) if (commandMeta.CommandResultType == null)
{ {
MapCommandWithoutResult(endpoints, route, commandMeta); MapCommandWithoutResult(endpoints, route, commandMeta, authorizationService);
} }
else else
{ {
MapCommandWithResult(endpoints, route, commandMeta); MapCommandWithResult(endpoints, route, commandMeta, authorizationService);
} }
} }
@ -177,13 +179,13 @@ public static class EndpointRouteBuilderExtensions
private static void MapCommandWithoutResult( private static void MapCommandWithoutResult(
IEndpointRouteBuilder endpoints, IEndpointRouteBuilder endpoints,
string route, string route,
ICommandMeta commandMeta) ICommandMeta commandMeta,
ICommandAuthorizationService? authorizationService)
{ {
var handlerType = typeof(ICommandHandler<>).MakeGenericType(commandMeta.CommandType); var handlerType = typeof(ICommandHandler<>).MakeGenericType(commandMeta.CommandType);
endpoints.MapPost(route, async (HttpContext context, IServiceProvider serviceProvider, CancellationToken cancellationToken) => endpoints.MapPost(route, async (HttpContext context, IServiceProvider serviceProvider, CancellationToken cancellationToken) =>
{ {
var authorizationService = serviceProvider.GetService<ICommandAuthorizationService>();
if (authorizationService != null) if (authorizationService != null)
{ {
var authorizationResult = await authorizationService.IsAllowedAsync(commandMeta.CommandType, cancellationToken); var authorizationResult = await authorizationService.IsAllowedAsync(commandMeta.CommandType, cancellationToken);
@ -219,13 +221,13 @@ public static class EndpointRouteBuilderExtensions
private static void MapCommandWithResult( private static void MapCommandWithResult(
IEndpointRouteBuilder endpoints, IEndpointRouteBuilder endpoints,
string route, string route,
ICommandMeta commandMeta) ICommandMeta commandMeta,
ICommandAuthorizationService? authorizationService)
{ {
var handlerType = typeof(ICommandHandler<,>).MakeGenericType(commandMeta.CommandType, commandMeta.CommandResultType!); var handlerType = typeof(ICommandHandler<,>).MakeGenericType(commandMeta.CommandType, commandMeta.CommandResultType!);
endpoints.MapPost(route, async (HttpContext context, IServiceProvider serviceProvider, CancellationToken cancellationToken) => endpoints.MapPost(route, async (HttpContext context, IServiceProvider serviceProvider, CancellationToken cancellationToken) =>
{ {
var authorizationService = serviceProvider.GetService<ICommandAuthorizationService>();
if (authorizationService != null) if (authorizationService != null)
{ {
var authorizationResult = await authorizationService.IsAllowedAsync(commandMeta.CommandType, cancellationToken); var authorizationResult = await authorizationService.IsAllowedAsync(commandMeta.CommandType, cancellationToken);