dotnet-cqrs/Svrnty.CQRS.Grpc.Generators/ProtoFileGenerator.cs

339 lines
12 KiB
C#

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
namespace Svrnty.CQRS.Grpc.Generators;
/// <summary>
/// Generates Protocol Buffer (.proto) files from C# Command and Query types
/// </summary>
/// <remarks>
/// Commands and queries are discovered by type-name suffix ("Command" / "Query") among the
/// classes and structs declared in the compilation. Each discovered type yields one rpc entry
/// plus a "...Request" and a "...Response" message; user-defined property and result types
/// get their own top-level messages.
/// </remarks>
internal class ProtoFileGenerator
{
    private const string CommandSuffix = "Command";
    private const string QuerySuffix = "Query";

    // Characters treated as whitespace when flattening an XML doc summary onto one line.
    private static readonly char[] WhitespaceSeparators = { ' ', '\t', '\r', '\n' };

    private readonly Compilation _compilation;
    private readonly HashSet<string> _requiredImports = new HashSet<string>();
    private readonly HashSet<string> _generatedMessages = new HashSet<string>();
    private readonly StringBuilder _messagesBuilder = new StringBuilder();

    // Every named type declared in the compilation; filled on first use by GetResultType so the
    // whole-compilation scan runs once instead of once per command/query.
    private List<INamedTypeSymbol>? _allTypes;

    /// <summary>
    /// Creates a generator that reads command, query and handler types from <paramref name="compilation"/>.
    /// </summary>
    /// <param name="compilation">The compilation whose source-declared types are scanned.</param>
    public ProtoFileGenerator(Compilation compilation)
    {
        _compilation = compilation;
    }

    /// <summary>
    /// Builds the complete .proto file text: header, imports, services and messages.
    /// </summary>
    /// <param name="packageName">Value written to the proto <c>package</c> statement.</param>
    /// <param name="csharpNamespace">Value written to the <c>csharp_namespace</c> option.</param>
    /// <returns>The generated .proto file content.</returns>
    public string Generate(string packageName, string csharpNamespace)
    {
        var commands = DiscoverCommands();
        var queries = DiscoverQueries();

        var sb = new StringBuilder();

        // Header
        sb.AppendLine("syntax = \"proto3\";");
        sb.AppendLine();
        sb.AppendLine($"option csharp_namespace = \"{csharpNamespace}\";");
        sb.AppendLine();
        sb.AppendLine($"package {packageName};");
        sb.AppendLine();

        // The set of imports is only known after all messages were generated, so remember
        // where they belong and insert them at the end.
        var importsPlaceholder = sb.Length;

        AppendService(sb, "// Command service for CQRS operations", "CommandService", CommandSuffix, commands);
        AppendService(sb, "// Query service for CQRS operations", "QueryService", QuerySuffix, queries);

        // Messages for commands first, then queries (same order as the services above).
        foreach (var type in commands.Concat(queries))
        {
            GenerateRequestMessage(type);
            GenerateResponseMessage(type);
        }

        sb.Append(_messagesBuilder);

        if (_requiredImports.Count > 0)
        {
            var imports = new StringBuilder();

            // Ordinal ordering keeps the output independent of the machine's culture.
            foreach (var import in _requiredImports.OrderBy(i => i, StringComparer.Ordinal))
            {
                imports.AppendLine($"import \"{import}\";");
            }

            imports.AppendLine();
            sb.Insert(importsPlaceholder, imports.ToString());
        }

        return sb.ToString();
    }

    /// <summary>
    /// Appends one <c>service</c> section with an rpc entry per type; writes nothing when
    /// <paramref name="types"/> is empty.
    /// </summary>
    private void AppendService(
        StringBuilder sb,
        string headerComment,
        string serviceName,
        string suffix,
        List<INamedTypeSymbol> types)
    {
        if (types.Count == 0)
            return;

        sb.AppendLine(headerComment);
        sb.AppendLine($"service {serviceName} {{");

        foreach (var type in types)
        {
            // NOTE(review): Replace strips every occurrence of the suffix, not only the trailing
            // one (e.g. "CommandLogCommand" becomes "Log"). Kept unchanged because code outside
            // this file may derive rpc names the same way - confirm before changing.
            var methodName = type.Name.Replace(suffix, "");

            sb.AppendLine($" // {GetXmlDocSummary(type)}");
            sb.AppendLine($" rpc {methodName} ({type.Name}Request) returns ({type.Name}Response);");
            sb.AppendLine();
        }

        sb.AppendLine("}");
        sb.AppendLine();
    }

    /// <summary>
    /// Finds classes and structs whose name ends with "Command" and that are not marked with
    /// GrpcIgnoreAttribute.
    /// </summary>
    private List<INamedTypeSymbol> DiscoverCommands()
    {
        return DiscoverTypesWithSuffix(CommandSuffix);
    }

    /// <summary>
    /// Finds classes and structs whose name ends with "Query" and that are not marked with
    /// GrpcIgnoreAttribute.
    /// </summary>
    private List<INamedTypeSymbol> DiscoverQueries()
    {
        return DiscoverTypesWithSuffix(QuerySuffix);
    }

    /// <summary>
    /// Shared discovery logic: source-declared classes/structs whose name ends with
    /// <paramref name="suffix"/> (ordinal comparison), excluding ignored types.
    /// </summary>
    private List<INamedTypeSymbol> DiscoverTypesWithSuffix(string suffix)
    {
        return _compilation.GetSymbolsWithName(
                name => name.EndsWith(suffix, StringComparison.Ordinal),
                SymbolFilter.Type)
            .OfType<INamedTypeSymbol>()
            .Where(t => !HasGrpcIgnoreAttribute(t))
            .Where(t => t.TypeKind == TypeKind.Class || t.TypeKind == TypeKind.Struct)
            .ToList();
    }

    /// <summary>
    /// Returns true when the type carries an attribute whose class is named "GrpcIgnoreAttribute".
    /// Only the simple name is compared; the attribute's namespace is not checked.
    /// </summary>
    private bool HasGrpcIgnoreAttribute(INamedTypeSymbol type)
    {
        return type.GetAttributes().Any(attr =>
            attr.AttributeClass?.Name == "GrpcIgnoreAttribute");
    }

    /// <summary>
    /// Emits the "{TypeName}Request" message for a command/query, once per message name.
    /// </summary>
    private void GenerateRequestMessage(INamedTypeSymbol type)
    {
        var messageName = $"{type.Name}Request";

        // Add returns false when the name is already present, i.e. the message was emitted before.
        if (!_generatedMessages.Add(messageName))
            return;

        AppendMessage($"// Request message for {type.Name}", messageName, type);
    }

    /// <summary>
    /// Emits the "{TypeName}Response" message for a command/query, once per message name.
    /// The message has a single <c>result</c> field when a handler with a result type is found,
    /// and is empty otherwise.
    /// </summary>
    private void GenerateResponseMessage(INamedTypeSymbol type)
    {
        var messageName = $"{type.Name}Response";
        if (!_generatedMessages.Add(messageName))
            return;

        _messagesBuilder.AppendLine($"// Response message for {type.Name}");
        _messagesBuilder.AppendLine($"message {messageName} {{");

        // Determine the result type from ICommandHandler<T, TResult> or IQueryHandler<T, TResult>
        var resultType = GetResultType(type);
        if (resultType != null)
        {
            var protoType = ProtoFileTypeMapper.MapType(resultType, out var needsImport, out var importPath);
            if (needsImport && importPath != null)
            {
                _requiredImports.Add(importPath);
            }

            _messagesBuilder.AppendLine($" {protoType} result = 1;");
        }

        // If no result type, leave message empty (void return)
        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        // Generate complex type message after closing the response message
        if (resultType != null && IsComplexType(resultType))
        {
            GenerateComplexTypeMessage(resultType as INamedTypeSymbol);
        }
    }

    /// <summary>
    /// Emits a message named after a user-defined type, once per type name. Null types and
    /// types declared under the System namespace are ignored.
    /// </summary>
    private void GenerateComplexTypeMessage(INamedTypeSymbol? type)
    {
        // Don't generate messages for system types or primitives
        if (type == null || IsInSystemNamespace(type.ContainingNamespace))
            return;

        // Registering the name before emitting the fields also stops self-referencing or
        // mutually-referencing types from recursing forever.
        if (!_generatedMessages.Add(type.Name))
            return;

        AppendMessage($"// {type.Name} entity", type.Name, type);
    }

    /// <summary>
    /// Writes one message whose fields are the public properties declared on
    /// <paramref name="type"/>, numbered from 1 in member order, then emits messages for any
    /// complex property types.
    /// </summary>
    /// <remarks>
    /// Complex property types are emitted only after this message's closing brace. Emitting
    /// them while the message is still open would nest them inside it, while
    /// <see cref="_generatedMessages"/> records them as if they were top-level, leaving later
    /// references from other messages unresolved.
    /// </remarks>
    private void AppendMessage(string headerComment, string messageName, INamedTypeSymbol type)
    {
        _messagesBuilder.AppendLine(headerComment);
        _messagesBuilder.AppendLine($"message {messageName} {{");

        var properties = type.GetMembers()
            .OfType<IPropertySymbol>()
            .Where(p => p.DeclaredAccessibility == Accessibility.Public);

        var pendingComplexTypes = new List<INamedTypeSymbol>();
        var fieldNumber = 1;

        foreach (var prop in properties)
        {
            if (ProtoFileTypeMapper.IsUnsupportedType(prop.Type))
            {
                // Skip unsupported types and add a comment; no field number is consumed.
                _messagesBuilder.AppendLine($" // Skipped: {prop.Name} - unsupported type {prop.Type.Name}");
                continue;
            }

            var protoType = ProtoFileTypeMapper.MapType(prop.Type, out var needsImport, out var importPath);
            if (needsImport && importPath != null)
            {
                _requiredImports.Add(importPath);
            }

            var fieldName = ProtoFileTypeMapper.ToSnakeCase(prop.Name);
            _messagesBuilder.AppendLine($" {protoType} {fieldName} = {fieldNumber};");
            fieldNumber++;

            if (IsComplexType(prop.Type) && prop.Type is INamedTypeSymbol complexType)
            {
                pendingComplexTypes.Add(complexType);
            }
        }

        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        foreach (var complexType in pendingComplexTypes)
        {
            GenerateComplexTypeMessage(complexType);
        }
    }

    /// <summary>
    /// Looks for a source-declared type implementing ICommandHandler or IQueryHandler whose
    /// first type argument is <paramref name="commandOrQueryType"/>.
    /// </summary>
    /// <returns>
    /// The second type argument of the matching handler interface, or null when the interface
    /// has no second type argument or no handler is found.
    /// </returns>
    private ITypeSymbol? GetResultType(INamedTypeSymbol commandOrQueryType)
    {
        var handlerInterfaceName = commandOrQueryType.Name.EndsWith(CommandSuffix, StringComparison.Ordinal)
            ? "ICommandHandler"
            : "IQueryHandler";

        // Enumerate the compilation's types once and reuse the list for every lookup.
        _allTypes ??= _compilation.GetSymbolsWithName(_ => true, SymbolFilter.Type)
            .OfType<INamedTypeSymbol>()
            .ToList();

        foreach (var candidate in _allTypes)
        {
            foreach (var handlerInterface in candidate.AllInterfaces)
            {
                if (handlerInterface.Name != handlerInterfaceName || handlerInterface.TypeArguments.Length < 1)
                    continue;

                // The first type argument must be exactly our command/query.
                if (!SymbolEqualityComparer.Default.Equals(handlerInterface.TypeArguments[0], commandOrQueryType))
                    continue;

                // Two type arguments: the second is the result. Otherwise it is a handler
                // without a result (e.g. ICommandHandler<T>).
                return handlerInterface.TypeArguments.Length == 2
                    ? handlerInterface.TypeArguments[1]
                    : null;
            }
        }

        return null; // No handler found
    }

    /// <summary>
    /// Returns true for classes and structs that are not declared under the System namespace,
    /// i.e. types that need their own generated message.
    /// </summary>
    private bool IsComplexType(ITypeSymbol type)
    {
        if (type.TypeKind != TypeKind.Class && type.TypeKind != TypeKind.Struct)
            return false;

        // Decide by the declaring namespace rather than by searching the display string for
        // "System.": a substring test also hits user namespaces such as "Acme.FileSystem" and
        // user generics that merely have a System type argument.
        return !IsInSystemNamespace(type.ContainingNamespace);
    }

    /// <summary>
    /// Returns true when the outermost named namespace containing <paramref name="ns"/> (or
    /// <paramref name="ns"/> itself) is exactly "System". Null and the global namespace give false.
    /// </summary>
    private static bool IsInSystemNamespace(INamespaceSymbol? ns)
    {
        if (ns is null || ns.IsGlobalNamespace)
            return false;

        // Walk up until the parent is the global namespace.
        while (ns.ContainingNamespace is { IsGlobalNamespace: false } parent)
        {
            ns = parent;
        }

        return ns.Name == "System";
    }

    /// <summary>
    /// Extracts the text between the first &lt;summary&gt; tag and the following
    /// &lt;/summary&gt; tag of the type's XML documentation, flattened to a single line.
    /// </summary>
    /// <returns>
    /// The summary with every run of whitespace collapsed to one space, or
    /// "{TypeName} operation" when there is no documentation, no summary element, or the
    /// summary is blank.
    /// </returns>
    private string GetXmlDocSummary(INamedTypeSymbol type)
    {
        const string openTag = "<summary>";
        const string closeTag = "</summary>";

        var fallback = $"{type.Name} operation";

        var xml = type.GetDocumentationCommentXml();
        if (xml is null || xml.Length == 0)
            return fallback;

        // Simple extraction - could be enhanced (inner XML tags and entities are kept verbatim).
        var openIndex = xml.IndexOf(openTag, StringComparison.Ordinal);
        if (openIndex < 0)
            return fallback;

        var contentStart = openIndex + openTag.Length;
        var closeIndex = xml.IndexOf(closeTag, contentStart, StringComparison.Ordinal);
        if (closeIndex < 0)
            return fallback;

        // The caller writes the summary after "//" on one line, so embedded line breaks would
        // push the remaining text outside the comment and corrupt the .proto file.
        var words = xml.Substring(contentStart, closeIndex - contentStart)
            .Split(WhitespaceSeparators, StringSplitOptions.RemoveEmptyEntries);

        return words.Length == 0 ? fallback : string.Join(" ", words);
    }
}