wip proto file generation

This commit is contained in:
Mathias Beaulieu-Duncan 2025-11-02 11:22:28 -05:00
parent 6735261f21
commit ccfaa35c1d
Signed by: mathias
GPG Key ID: 1C16CF05BAF9162D
7 changed files with 695 additions and 7 deletions

View File

@ -0,0 +1,338 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
namespace Svrnty.CQRS.Grpc.Generators;
/// <summary>
/// Generates Protocol Buffer (.proto) files from C# Command and Query types
/// </summary>
internal class ProtoFileGenerator
{
private readonly Compilation _compilation;
private readonly HashSet<string> _requiredImports = new HashSet<string>();
private readonly HashSet<string> _generatedMessages = new HashSet<string>();
private readonly StringBuilder _messagesBuilder = new StringBuilder();
public ProtoFileGenerator(Compilation compilation)
{
_compilation = compilation;
}
public string Generate(string packageName, string csharpNamespace)
{
var commands = DiscoverCommands();
var queries = DiscoverQueries();
var sb = new StringBuilder();
// Header
sb.AppendLine("syntax = \"proto3\";");
sb.AppendLine();
sb.AppendLine($"option csharp_namespace = \"{csharpNamespace}\";");
sb.AppendLine();
sb.AppendLine($"package {packageName};");
sb.AppendLine();
// Imports (will be added later if needed)
var importsPlaceholder = sb.Length;
// Command Service
if (commands.Any())
{
sb.AppendLine("// Command service for CQRS operations");
sb.AppendLine("service CommandService {");
foreach (var command in commands)
{
var methodName = command.Name.Replace("Command", "");
var requestType = $"{command.Name}Request";
var responseType = $"{command.Name}Response";
sb.AppendLine($" // {GetXmlDocSummary(command)}");
sb.AppendLine($" rpc {methodName} ({requestType}) returns ({responseType});");
sb.AppendLine();
}
sb.AppendLine("}");
sb.AppendLine();
}
// Query Service
if (queries.Any())
{
sb.AppendLine("// Query service for CQRS operations");
sb.AppendLine("service QueryService {");
foreach (var query in queries)
{
var methodName = query.Name.Replace("Query", "");
var requestType = $"{query.Name}Request";
var responseType = $"{query.Name}Response";
sb.AppendLine($" // {GetXmlDocSummary(query)}");
sb.AppendLine($" rpc {methodName} ({requestType}) returns ({responseType});");
sb.AppendLine();
}
sb.AppendLine("}");
sb.AppendLine();
}
// Generate messages for commands
foreach (var command in commands)
{
GenerateRequestMessage(command);
GenerateResponseMessage(command);
}
// Generate messages for queries
foreach (var query in queries)
{
GenerateRequestMessage(query);
GenerateResponseMessage(query);
}
// Append all generated messages
sb.Append(_messagesBuilder);
// Insert imports if any were needed
if (_requiredImports.Any())
{
var imports = new StringBuilder();
foreach (var import in _requiredImports.OrderBy(i => i))
{
imports.AppendLine($"import \"{import}\";");
}
imports.AppendLine();
sb.Insert(importsPlaceholder, imports.ToString());
}
return sb.ToString();
}
private List<INamedTypeSymbol> DiscoverCommands()
{
return _compilation.GetSymbolsWithName(
name => name.EndsWith("Command"),
SymbolFilter.Type)
.OfType<INamedTypeSymbol>()
.Where(t => !HasGrpcIgnoreAttribute(t))
.Where(t => t.TypeKind == TypeKind.Class || t.TypeKind == TypeKind.Struct)
.ToList();
}
private List<INamedTypeSymbol> DiscoverQueries()
{
return _compilation.GetSymbolsWithName(
name => name.EndsWith("Query"),
SymbolFilter.Type)
.OfType<INamedTypeSymbol>()
.Where(t => !HasGrpcIgnoreAttribute(t))
.Where(t => t.TypeKind == TypeKind.Class || t.TypeKind == TypeKind.Struct)
.ToList();
}
private bool HasGrpcIgnoreAttribute(INamedTypeSymbol type)
{
return type.GetAttributes().Any(attr =>
attr.AttributeClass?.Name == "GrpcIgnoreAttribute");
}
private void GenerateRequestMessage(INamedTypeSymbol type)
{
var messageName = $"{type.Name}Request";
if (_generatedMessages.Contains(messageName))
return;
_generatedMessages.Add(messageName);
_messagesBuilder.AppendLine($"// Request message for {type.Name}");
_messagesBuilder.AppendLine($"message {messageName} {{");
var properties = type.GetMembers()
.OfType<IPropertySymbol>()
.Where(p => p.DeclaredAccessibility == Accessibility.Public)
.ToList();
int fieldNumber = 1;
foreach (var prop in properties)
{
if (ProtoFileTypeMapper.IsUnsupportedType(prop.Type))
{
// Skip unsupported types and add a comment
_messagesBuilder.AppendLine($" // Skipped: {prop.Name} - unsupported type {prop.Type.Name}");
continue;
}
var protoType = ProtoFileTypeMapper.MapType(prop.Type, out var needsImport, out var importPath);
if (needsImport && importPath != null)
{
_requiredImports.Add(importPath);
}
var fieldName = ProtoFileTypeMapper.ToSnakeCase(prop.Name);
_messagesBuilder.AppendLine($" {protoType} {fieldName} = {fieldNumber};");
// If this is a complex type, generate its message too
if (IsComplexType(prop.Type))
{
GenerateComplexTypeMessage(prop.Type as INamedTypeSymbol);
}
fieldNumber++;
}
_messagesBuilder.AppendLine("}");
_messagesBuilder.AppendLine();
}
private void GenerateResponseMessage(INamedTypeSymbol type)
{
var messageName = $"{type.Name}Response";
if (_generatedMessages.Contains(messageName))
return;
_generatedMessages.Add(messageName);
_messagesBuilder.AppendLine($"// Response message for {type.Name}");
_messagesBuilder.AppendLine($"message {messageName} {{");
// Determine the result type from ICommandHandler<T, TResult> or IQueryHandler<T, TResult>
var resultType = GetResultType(type);
if (resultType != null)
{
var protoType = ProtoFileTypeMapper.MapType(resultType, out var needsImport, out var importPath);
if (needsImport && importPath != null)
{
_requiredImports.Add(importPath);
}
_messagesBuilder.AppendLine($" {protoType} result = 1;");
}
// If no result type, leave message empty (void return)
_messagesBuilder.AppendLine("}");
_messagesBuilder.AppendLine();
// Generate complex type message after closing the response message
if (resultType != null && IsComplexType(resultType))
{
GenerateComplexTypeMessage(resultType as INamedTypeSymbol);
}
}
private void GenerateComplexTypeMessage(INamedTypeSymbol? type)
{
if (type == null || _generatedMessages.Contains(type.Name))
return;
// Don't generate messages for system types or primitives
if (type.ContainingNamespace?.ToString().StartsWith("System") == true)
return;
_generatedMessages.Add(type.Name);
_messagesBuilder.AppendLine($"// {type.Name} entity");
_messagesBuilder.AppendLine($"message {type.Name} {{");
var properties = type.GetMembers()
.OfType<IPropertySymbol>()
.Where(p => p.DeclaredAccessibility == Accessibility.Public)
.ToList();
int fieldNumber = 1;
foreach (var prop in properties)
{
if (ProtoFileTypeMapper.IsUnsupportedType(prop.Type))
{
_messagesBuilder.AppendLine($" // Skipped: {prop.Name} - unsupported type {prop.Type.Name}");
continue;
}
var protoType = ProtoFileTypeMapper.MapType(prop.Type, out var needsImport, out var importPath);
if (needsImport && importPath != null)
{
_requiredImports.Add(importPath);
}
var fieldName = ProtoFileTypeMapper.ToSnakeCase(prop.Name);
_messagesBuilder.AppendLine($" {protoType} {fieldName} = {fieldNumber};");
// Recursively generate nested complex types
if (IsComplexType(prop.Type))
{
GenerateComplexTypeMessage(prop.Type as INamedTypeSymbol);
}
fieldNumber++;
}
_messagesBuilder.AppendLine("}");
_messagesBuilder.AppendLine();
}
private ITypeSymbol? GetResultType(INamedTypeSymbol commandOrQueryType)
{
// Scan for handler classes that implement ICommandHandler<T, TResult> or IQueryHandler<T, TResult>
var handlerInterfaceName = commandOrQueryType.Name.EndsWith("Command")
? "ICommandHandler"
: "IQueryHandler";
// Find all types in the compilation
var allTypes = _compilation.GetSymbolsWithName(_ => true, SymbolFilter.Type)
.OfType<INamedTypeSymbol>();
foreach (var type in allTypes)
{
// Check if this type implements the handler interface
foreach (var @interface in type.AllInterfaces)
{
if (@interface.Name == handlerInterfaceName && @interface.TypeArguments.Length >= 1)
{
// Check if the first type argument matches our command/query
var firstArg = @interface.TypeArguments[0];
if (SymbolEqualityComparer.Default.Equals(firstArg, commandOrQueryType))
{
// Found the handler! Return the result type (second type argument) if it exists
if (@interface.TypeArguments.Length == 2)
{
return @interface.TypeArguments[1];
}
// If only one type argument, it's a void command (ICommandHandler<T>)
return null;
}
}
}
}
return null; // No handler found
}
private bool IsComplexType(ITypeSymbol type)
{
// Check if it's a user-defined class/struct (not a primitive or system type)
if (type.TypeKind != TypeKind.Class && type.TypeKind != TypeKind.Struct)
return false;
var fullName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
return !fullName.Contains("System.");
}
private string GetXmlDocSummary(INamedTypeSymbol type)
{
var xml = type.GetDocumentationCommentXml();
if (string.IsNullOrEmpty(xml))
return $"{type.Name} operation";
// Simple extraction - could be enhanced
// xml is guaranteed non-null after IsNullOrEmpty check above
var summaryStart = xml!.IndexOf("<summary>");
var summaryEnd = xml.IndexOf("</summary>");
if (summaryStart >= 0 && summaryEnd > summaryStart)
{
var summary = xml.Substring(summaryStart + 9, summaryEnd - summaryStart - 9).Trim();
return summary;
}
return $"{type.Name} operation";
}
}

View File

@ -0,0 +1,131 @@
using System;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace Svrnty.CQRS.Grpc.Generators;
/// <summary>
/// Incremental source generator that generates .proto files from C# commands and queries
/// </summary>
[Generator]
public class ProtoFileSourceGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Register a post-initialization output to generate the proto file
context.RegisterPostInitializationOutput(ctx =>
{
// Generate a placeholder - the actual proto will be generated in the source output
});
// Collect all command and query types
var commandsAndQueries = context.SyntaxProvider
.CreateSyntaxProvider(
predicate: static (s, _) => IsCommandOrQuery(s),
transform: static (ctx, _) => GetTypeSymbol(ctx))
.Where(static m => m is not null)
.Collect();
// Combine with compilation to have access to it
var compilationAndTypes = context.CompilationProvider.Combine(commandsAndQueries);
// Generate proto file when commands/queries change
context.RegisterSourceOutput(compilationAndTypes, (spc, source) =>
{
var (compilation, types) = source;
if (types.IsDefaultOrEmpty)
return;
try
{
// Get build properties for configuration
var packageName = GetBuildProperty(spc, "RootNamespace") ?? "cqrs";
var csharpNamespace = GetBuildProperty(spc, "RootNamespace") ?? "Generated.Grpc";
// Generate the proto file content
var generator = new ProtoFileGenerator(compilation);
var protoContent = generator.Generate(packageName, csharpNamespace);
// Output as an embedded resource that can be extracted
var protoFileName = "cqrs_services.proto";
// Generate a C# class that contains the proto content
// This allows build tools to extract it if needed
var csContent = $$"""
// <auto-generated />
#nullable enable
namespace Svrnty.CQRS.Grpc.Generated
{
/// <summary>
/// Contains the auto-generated Protocol Buffer definition
/// </summary>
internal static class GeneratedProtoFile
{
public const string FileName = "{{protoFileName}}";
public const string Content = @"{{protoContent.Replace("\"", "\"\"")}}";
}
}
""";
spc.AddSource("GeneratedProtoFile.g.cs", csContent);
// Report that we generated the proto content
var descriptor = new DiagnosticDescriptor(
"CQRSGRPC002",
"Proto file generated",
"Generated proto file content in GeneratedProtoFile class",
"Svrnty.CQRS.Grpc",
DiagnosticSeverity.Info,
isEnabledByDefault: true);
spc.ReportDiagnostic(Diagnostic.Create(descriptor, Location.None));
}
catch (Exception ex)
{
// Report diagnostic if generation fails
var descriptor = new DiagnosticDescriptor(
"CQRSGRPC001",
"Proto file generation failed",
"Failed to generate proto file: {0}",
"Svrnty.CQRS.Grpc",
DiagnosticSeverity.Warning,
isEnabledByDefault: true);
spc.ReportDiagnostic(Diagnostic.Create(descriptor, Location.None, ex.Message));
}
});
}
private static bool IsCommandOrQuery(SyntaxNode node)
{
if (node is not TypeDeclarationSyntax typeDecl)
return false;
var name = typeDecl.Identifier.Text;
return name.EndsWith("Command") || name.EndsWith("Query");
}
private static INamedTypeSymbol? GetTypeSymbol(GeneratorSyntaxContext context)
{
var typeDecl = (TypeDeclarationSyntax)context.Node;
var symbol = context.SemanticModel.GetDeclaredSymbol(typeDecl) as INamedTypeSymbol;
// Skip if it has GrpcIgnore attribute
if (symbol?.GetAttributes().Any(a => a.AttributeClass?.Name == "GrpcIgnoreAttribute") == true)
return null;
return symbol;
}
private static string? GetBuildProperty(SourceProductionContext context, string propertyName)
{
// Try to get build properties from the compilation options
// This is a simplified approach - in practice, you might need analyzer config
return null; // Will use defaults
}
}

View File

@ -0,0 +1,191 @@
using System;
using Microsoft.CodeAnalysis;
namespace Svrnty.CQRS.Grpc.Generators;
/// <summary>
/// Maps C# types to Protocol Buffer types for proto file generation
/// </summary>
internal static class ProtoFileTypeMapper
{
public static string MapType(ITypeSymbol typeSymbol, out bool needsImport, out string? importPath)
{
needsImport = false;
importPath = null;
// Handle special name (fully qualified name)
var fullTypeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
var typeName = typeSymbol.Name;
// Nullable types - unwrap
if (typeSymbol.NullableAnnotation == NullableAnnotation.Annotated && typeSymbol is INamedTypeSymbol namedType && namedType.TypeArguments.Length > 0)
{
return MapType(namedType.TypeArguments[0], out needsImport, out importPath);
}
// Basic types
switch (typeName)
{
case "String":
return "string";
case "Int32":
return "int32";
case "UInt32":
return "uint32";
case "Int64":
return "int64";
case "UInt64":
return "uint64";
case "Int16":
return "int32"; // Proto has no int16
case "UInt16":
return "uint32"; // Proto has no uint16
case "Byte":
return "uint32"; // Proto has no byte
case "SByte":
return "int32"; // Proto has no sbyte
case "Boolean":
return "bool";
case "Single":
return "float";
case "Double":
return "double";
case "Byte[]":
return "bytes";
}
// Special types that need imports
if (fullTypeName.Contains("System.DateTime"))
{
needsImport = true;
importPath = "google/protobuf/timestamp.proto";
return "google.protobuf.Timestamp";
}
if (fullTypeName.Contains("System.TimeSpan"))
{
needsImport = true;
importPath = "google/protobuf/duration.proto";
return "google.protobuf.Duration";
}
if (fullTypeName.Contains("System.Guid"))
{
// Guid serialized as string
return "string";
}
if (fullTypeName.Contains("System.Decimal"))
{
// Decimal serialized as string (no native decimal in proto)
return "string";
}
// Collections
if (typeSymbol is INamedTypeSymbol collectionType)
{
// List, IEnumerable, Array, etc.
if (collectionType.TypeArguments.Length == 1)
{
var elementType = collectionType.TypeArguments[0];
var protoElementType = MapType(elementType, out needsImport, out importPath);
return $"repeated {protoElementType}";
}
// Dictionary<K, V>
if (collectionType.TypeArguments.Length == 2 &&
(typeName.Contains("Dictionary") || typeName.Contains("IDictionary")))
{
var keyType = MapType(collectionType.TypeArguments[0], out var keyNeedsImport, out var keyImportPath);
var valueType = MapType(collectionType.TypeArguments[1], out var valueNeedsImport, out var valueImportPath);
// Set import flags if either key or value needs imports
if (keyNeedsImport)
{
needsImport = true;
importPath = keyImportPath;
}
if (valueNeedsImport)
{
needsImport = true;
importPath = valueImportPath; // Note: This only captures last import, may need improvement
}
return $"map<{keyType}, {valueType}>";
}
}
// Enums
if (typeSymbol.TypeKind == TypeKind.Enum)
{
return typeName; // Use the enum name directly
}
// Complex types (classes/records) become message types
if (typeSymbol.TypeKind == TypeKind.Class || typeSymbol.TypeKind == TypeKind.Struct)
{
return typeName; // Reference the message type by name
}
// Fallback
return "string"; // Default to string for unknown types
}
/// <summary>
/// Converts C# PascalCase property name to proto snake_case field name
/// </summary>
public static string ToSnakeCase(string pascalCase)
{
if (string.IsNullOrEmpty(pascalCase))
return pascalCase;
var result = new System.Text.StringBuilder();
result.Append(char.ToLowerInvariant(pascalCase[0]));
for (int i = 1; i < pascalCase.Length; i++)
{
var c = pascalCase[i];
if (char.IsUpper(c))
{
// Handle sequences of uppercase letters (e.g., "APIKey" -> "api_key")
if (i + 1 < pascalCase.Length && char.IsUpper(pascalCase[i + 1]))
{
result.Append(char.ToLowerInvariant(c));
}
else
{
result.Append('_');
result.Append(char.ToLowerInvariant(c));
}
}
else
{
result.Append(c);
}
}
return result.ToString();
}
/// <summary>
/// Checks if a type should be skipped/ignored for proto generation
/// </summary>
public static bool IsUnsupportedType(ITypeSymbol typeSymbol)
{
var fullTypeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
// Skip these types - they should trigger a warning/error
if (fullTypeName.Contains("System.IO.Stream") ||
fullTypeName.Contains("System.Threading.CancellationToken") ||
fullTypeName.Contains("System.Threading.Tasks.Task") ||
fullTypeName.Contains("System.Collections.Generic.IAsyncEnumerable") ||
fullTypeName.Contains("System.Func") ||
fullTypeName.Contains("System.Action") ||
fullTypeName.Contains("System.Delegate"))
{
return true;
}
return false;
}
}

View File

@ -29,11 +29,17 @@
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.11.0" PrivateAssets="all" /> <PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="5.0.0-2.final" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.4" PrivateAssets="all" /> <PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.11.0" PrivateAssets="all" />
<PackageReference Include="Microsoft.Build.Utilities.Core" Version="17.0.0" PrivateAssets="all" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<!-- Package as analyzer -->
<None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" /> <None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" />
<!-- Also package as build task -->
<None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="build" Visible="false" />
<None Include="build\Svrnty.CQRS.Grpc.Generators.targets" Pack="true" PackagePath="build" />
</ItemGroup> </ItemGroup>
</Project> </Project>

View File

@ -0,0 +1,22 @@
<Project>
<PropertyGroup>
<!-- Set default values for proto generation -->
<GenerateProtoFile Condition="'$(GenerateProtoFile)' == ''">true</GenerateProtoFile>
<ProtoOutputDirectory Condition="'$(ProtoOutputDirectory)' == ''">$(MSBuildProjectDirectory)\Protos</ProtoOutputDirectory>
<GeneratedProtoFileName Condition="'$(GeneratedProtoFileName)' == ''">cqrs_services.proto</GeneratedProtoFileName>
</PropertyGroup>
<Target Name="SvrntyGenerateProtoInfo" BeforeTargets="CoreCompile">
<Message Text="Svrnty.CQRS.Grpc.Generators: Proto file will be auto-generated to $(ProtoOutputDirectory)\$(GeneratedProtoFileName)" Importance="normal" />
</Target>
<!-- This target ensures the Protos directory exists before the generator runs -->
<Target Name="EnsureProtosDirectory" BeforeTargets="CoreCompile">
<MakeDir Directories="$(ProtoOutputDirectory)" Condition="!Exists('$(ProtoOutputDirectory)')" />
</Target>
<!-- Set environment variable so the source generator can find the project directory -->
<PropertyGroup>
<MSBuildProjectDirectory>$(MSBuildProjectDirectory)</MSBuildProjectDirectory>
</PropertyGroup>
</Project>

View File

@ -13,13 +13,13 @@
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="Grpc.AspNetCore" Version="2.70.0" /> <PackageReference Include="Grpc.AspNetCore" Version="2.71.0" />
<PackageReference Include="Grpc.AspNetCore.Server.Reflection" Version="2.71.0" /> <PackageReference Include="Grpc.AspNetCore.Server.Reflection" Version="2.71.0" />
<PackageReference Include="Grpc.Tools" Version="2.70.0"> <PackageReference Include="Grpc.Tools" Version="2.76.0">
<PrivateAssets>all</PrivateAssets> <PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference> </PackageReference>
<PackageReference Include="Grpc.StatusProto" Version="1.70.0" /> <PackageReference Include="Grpc.StatusProto" Version="2.71.0" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="9.0.6" /> <PackageReference Include="Swashbuckle.AspNetCore" Version="9.0.6" />
</ItemGroup> </ItemGroup>

View File

@ -27,8 +27,8 @@
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="Grpc.AspNetCore" Version="2.70.0" /> <PackageReference Include="Grpc.AspNetCore" Version="2.71.0" />
<PackageReference Include="Grpc.Tools" Version="2.70.0"> <PackageReference Include="Grpc.Tools" Version="2.76.0">
<PrivateAssets>all</PrivateAssets> <PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference> </PackageReference>