using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;

namespace Svrnty.CQRS.Grpc.Generators;

/// <summary>
/// Generates Protocol Buffer (.proto) files from C# Command and Query types.
/// </summary>
/// <remarks>
/// An instance accumulates state (generated message names, required imports and
/// message text) across the helper methods while <see cref="Generate"/> runs.
/// </remarks>
internal class ProtoFileGenerator
{
    private const string CommandSuffix = "Command";
    private const string QuerySuffix = "Query";

    private readonly Compilation _compilation;
    private readonly HashSet<string> _requiredImports = new HashSet<string>();
    private readonly HashSet<string> _generatedMessages = new HashSet<string>();
    private readonly StringBuilder _messagesBuilder = new StringBuilder();

    // Lazily materialized list of every named type in the compilation; reused by
    // GetResultType so the full symbol scan is not repeated per command/query.
    private List<INamedTypeSymbol>? _allTypes;

    /// <summary>
    /// Creates a generator that discovers types in the given compilation.
    /// </summary>
    /// <param name="compilation">Compilation whose symbols are scanned.</param>
    public ProtoFileGenerator(Compilation compilation)
    {
        _compilation = compilation;
    }

    /// <summary>
    /// Builds the full .proto text: header, optional imports, the command and
    /// query services, then all request/response/complex-type messages.
    /// </summary>
    /// <param name="packageName">Value written to the proto <c>package</c> statement.</param>
    /// <param name="csharpNamespace">Value written to <c>option csharp_namespace</c>.</param>
    /// <returns>The generated .proto file content.</returns>
    public string Generate(string packageName, string csharpNamespace)
    {
        var commands = DiscoverCommands();
        var queries = DiscoverQueries();

        var sb = new StringBuilder();

        // Header
        sb.AppendLine("syntax = \"proto3\";");
        sb.AppendLine();
        sb.AppendLine($"option csharp_namespace = \"{csharpNamespace}\";");
        sb.AppendLine();
        sb.AppendLine($"package {packageName};");
        sb.AppendLine();

        // Imports are only known after every message has been generated, so
        // remember where they belong and insert them at the end.
        var importsPlaceholder = sb.Length;

        AppendService(sb, "// Command service for CQRS operations", "CommandService", commands, CommandSuffix);
        AppendService(sb, "// Query service for CQRS operations", "QueryService", queries, QuerySuffix);

        // Generate messages for commands, then for queries.
        foreach (var type in commands.Concat(queries))
        {
            GenerateRequestMessage(type);
            GenerateResponseMessage(type);
        }

        // Append all generated messages
        sb.Append(_messagesBuilder);

        // Insert imports if any were needed
        if (_requiredImports.Count > 0)
        {
            var imports = new StringBuilder();
            foreach (var import in _requiredImports.OrderBy(i => i))
            {
                imports.AppendLine($"import \"{import}\";");
            }

            imports.AppendLine();
            sb.Insert(importsPlaceholder, imports.ToString());
        }

        return sb.ToString();
    }

    /// <summary>
    /// Writes one <c>service</c> block with an rpc per type; writes nothing when
    /// <paramref name="types"/> is empty.
    /// </summary>
    private void AppendService(
        StringBuilder sb,
        string headerComment,
        string serviceName,
        List<INamedTypeSymbol> types,
        string suffix)
    {
        if (types.Count == 0)
        {
            return;
        }

        sb.AppendLine(headerComment);
        sb.AppendLine($"service {serviceName} {{");

        foreach (var type in types)
        {
            // NOTE(review): Replace removes every occurrence of the suffix, not only
            // the trailing one (e.g. "CommandLogCommand" -> "Log"). Kept as-is because
            // other generators may derive method names the same way - confirm before changing.
            var methodName = type.Name.Replace(suffix, "");

            sb.AppendLine($" // {GetXmlDocSummary(type)}");
            sb.AppendLine($" rpc {methodName} ({type.Name}Request) returns ({type.Name}Response);");
            sb.AppendLine();
        }

        sb.AppendLine("}");
        sb.AppendLine();
    }

    /// <summary>Finds class/struct types whose name ends with "Command".</summary>
    private List<INamedTypeSymbol> DiscoverCommands()
    {
        return DiscoverTypesWithSuffix(CommandSuffix);
    }

    /// <summary>Finds class/struct types whose name ends with "Query".</summary>
    private List<INamedTypeSymbol> DiscoverQueries()
    {
        return DiscoverTypesWithSuffix(QuerySuffix);
    }

    /// <summary>
    /// Returns class/struct types in the compilation whose name ends with
    /// <paramref name="suffix"/> and that are not marked with GrpcIgnoreAttribute.
    /// </summary>
    private List<INamedTypeSymbol> DiscoverTypesWithSuffix(string suffix)
    {
        return _compilation
            .GetSymbolsWithName(name => name.EndsWith(suffix, StringComparison.Ordinal), SymbolFilter.Type)
            .OfType<INamedTypeSymbol>()
            .Where(t => !HasGrpcIgnoreAttribute(t))
            .Where(t => t.TypeKind == TypeKind.Class || t.TypeKind == TypeKind.Struct)
            .ToList();
    }

    /// <summary>
    /// True when the type carries an attribute whose class is named "GrpcIgnoreAttribute"
    /// (matched by simple name only).
    /// </summary>
    private bool HasGrpcIgnoreAttribute(INamedTypeSymbol type)
    {
        return type.GetAttributes().Any(attr => attr.AttributeClass?.Name == "GrpcIgnoreAttribute");
    }

    /// <summary>
    /// Emits the <c>{TypeName}Request</c> message (once per name) with one field per
    /// public property, followed by messages for any complex property types.
    /// </summary>
    private void GenerateRequestMessage(INamedTypeSymbol type)
    {
        var messageName = $"{type.Name}Request";

        // HashSet.Add returns false when the name is already present.
        if (!_generatedMessages.Add(messageName))
        {
            return;
        }

        _messagesBuilder.AppendLine($"// Request message for {type.Name}");
        _messagesBuilder.AppendLine($"message {messageName} {{");

        var complexTypes = AppendFields(type);

        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        // Emit referenced complex types only after the request message is closed,
        // so they are declared at file scope instead of nested inside this message.
        foreach (var complexType in complexTypes)
        {
            GenerateComplexTypeMessage(complexType);
        }
    }

    /// <summary>
    /// Emits the <c>{TypeName}Response</c> message (once per name). The message has a
    /// single <c>result</c> field when a handler result type is found, otherwise it is empty.
    /// </summary>
    private void GenerateResponseMessage(INamedTypeSymbol type)
    {
        var messageName = $"{type.Name}Response";

        if (!_generatedMessages.Add(messageName))
        {
            return;
        }

        _messagesBuilder.AppendLine($"// Response message for {type.Name}");
        _messagesBuilder.AppendLine($"message {messageName} {{");

        // Determine the result type from ICommandHandler or IQueryHandler
        var resultType = GetResultType(type);
        if (resultType != null)
        {
            var protoType = ProtoFileTypeMapper.MapType(resultType, out var needsImport, out var importPath);
            if (needsImport && importPath != null)
            {
                _requiredImports.Add(importPath);
            }

            _messagesBuilder.AppendLine($" {protoType} result = 1;");
        }

        // If no result type, leave message empty (void return)
        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        // Generate complex type message after closing the response message
        if (resultType != null && IsComplexType(resultType))
        {
            GenerateComplexTypeMessage(resultType as INamedTypeSymbol);
        }
    }

    /// <summary>
    /// Emits a message named after <paramref name="type"/> (once per name) and then,
    /// recursively, messages for the complex types of its public properties.
    /// Null types and types rooted in the System namespace are ignored.
    /// </summary>
    private void GenerateComplexTypeMessage(INamedTypeSymbol? type)
    {
        if (type == null || _generatedMessages.Contains(type.Name))
        {
            return;
        }

        // Don't generate messages for system types or primitives
        if (IsSystemType(type))
        {
            return;
        }

        // Registering the name before recursing also stops self-referencing types
        // from recursing forever.
        _generatedMessages.Add(type.Name);

        _messagesBuilder.AppendLine($"// {type.Name} entity");
        _messagesBuilder.AppendLine($"message {type.Name} {{");

        var complexTypes = AppendFields(type);

        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        // Nested complex types are written after this message is closed (file scope).
        foreach (var complexType in complexTypes)
        {
            GenerateComplexTypeMessage(complexType);
        }
    }

    /// <summary>
    /// Writes one proto field line per public property declared on <paramref name="type"/>
    /// into the message currently open in the messages builder. Unsupported property
    /// types produce a "Skipped" comment and do not consume a field number.
    /// </summary>
    /// <returns>
    /// The complex property types encountered, for the caller to generate once the
    /// current message has been closed.
    /// </returns>
    private List<INamedTypeSymbol> AppendFields(INamedTypeSymbol type)
    {
        var complexTypes = new List<INamedTypeSymbol>();

        var properties = type.GetMembers()
            .OfType<IPropertySymbol>()
            .Where(p => p.DeclaredAccessibility == Accessibility.Public);

        var fieldNumber = 1;
        foreach (var prop in properties)
        {
            if (ProtoFileTypeMapper.IsUnsupportedType(prop.Type))
            {
                // Skip unsupported types and add a comment
                _messagesBuilder.AppendLine($" // Skipped: {prop.Name} - unsupported type {prop.Type.Name}");
                continue;
            }

            var protoType = ProtoFileTypeMapper.MapType(prop.Type, out var needsImport, out var importPath);
            if (needsImport && importPath != null)
            {
                _requiredImports.Add(importPath);
            }

            var fieldName = ProtoFileTypeMapper.ToSnakeCase(prop.Name);
            _messagesBuilder.AppendLine($" {protoType} {fieldName} = {fieldNumber};");

            if (IsComplexType(prop.Type) && prop.Type is INamedTypeSymbol namedType)
            {
                complexTypes.Add(namedType);
            }

            fieldNumber++;
        }

        return complexTypes;
    }

    /// <summary>
    /// Looks for a type implementing ICommandHandler/IQueryHandler (matched by simple
    /// interface name) whose first type argument is the given command/query.
    /// </summary>
    /// <returns>
    /// The second type argument of the first matching interface when it has exactly two;
    /// otherwise <c>null</c> (no result type, or no handler found).
    /// </returns>
    private ITypeSymbol? GetResultType(INamedTypeSymbol commandOrQueryType)
    {
        var handlerInterfaceName = commandOrQueryType.Name.EndsWith(CommandSuffix, StringComparison.Ordinal)
            ? "ICommandHandler"
            : "IQueryHandler";

        foreach (var type in GetAllTypes())
        {
            foreach (var @interface in type.AllInterfaces)
            {
                if (@interface.Name != handlerInterfaceName || @interface.TypeArguments.Length < 1)
                {
                    continue;
                }

                // Check if the first type argument matches our command/query
                if (!SymbolEqualityComparer.Default.Equals(@interface.TypeArguments[0], commandOrQueryType))
                {
                    continue;
                }

                // Found the handler: two type arguments means (request, result);
                // any other arity is treated as having no result.
                return @interface.TypeArguments.Length == 2
                    ? @interface.TypeArguments[1]
                    : null;
            }
        }

        return null; // No handler found
    }

    /// <summary>Returns every named type in the compilation, scanning only on first use.</summary>
    private List<INamedTypeSymbol> GetAllTypes()
    {
        return _allTypes ??= _compilation
            .GetSymbolsWithName(_ => true, SymbolFilter.Type)
            .OfType<INamedTypeSymbol>()
            .ToList();
    }

    /// <summary>
    /// True for class/struct types that are not rooted in the System namespace.
    /// </summary>
    private bool IsComplexType(ITypeSymbol type)
    {
        if (type.TypeKind != TypeKind.Class && type.TypeKind != TypeKind.Struct)
        {
            return false;
        }

        return !IsSystemType(type);
    }

    /// <summary>
    /// True when the outermost namespace containing <paramref name="type"/> is exactly
    /// "System". Comparing the namespace segment (rather than searching the display
    /// string) avoids misclassifying user namespaces such as "EcoSystem.Models" or "Systems".
    /// </summary>
    private static bool IsSystemType(ITypeSymbol type)
    {
        var ns = type.ContainingNamespace;
        if (ns is null || ns.IsGlobalNamespace)
        {
            return false;
        }

        // Walk up to the top-level namespace (the one directly under the global namespace).
        while (ns.ContainingNamespace is { IsGlobalNamespace: false } parent)
        {
            ns = parent;
        }

        return ns.Name == "System";
    }

    /// <summary>
    /// Extracts the text of the first <c>&lt;summary&gt;</c> element from the type's XML
    /// documentation, collapsed to a single line so it is safe to place after a
    /// <c>//</c> comment marker. Falls back to "{TypeName} operation" when no usable
    /// summary exists.
    /// </summary>
    private string GetXmlDocSummary(INamedTypeSymbol type)
    {
        const string openTag = "<summary>";
        const string closeTag = "</summary>";

        var fallback = $"{type.Name} operation";

        var xml = type.GetDocumentationCommentXml();
        if (xml is null || xml.Length == 0)
        {
            return fallback;
        }

        // Simple extraction - could be enhanced
        var summaryStart = xml.IndexOf(openTag, StringComparison.Ordinal);
        if (summaryStart < 0)
        {
            return fallback;
        }

        var contentStart = summaryStart + openTag.Length;
        var summaryEnd = xml.IndexOf(closeTag, contentStart, StringComparison.Ordinal);
        if (summaryEnd < 0)
        {
            return fallback;
        }

        var raw = xml.Substring(contentStart, summaryEnd - contentStart);

        // Splitting on a null separator splits on whitespace; rejoining with single
        // spaces removes line breaks that would otherwise escape the proto comment.
        var summary = string.Join(" ", raw.Split((char[]?)null, StringSplitOptions.RemoveEmptyEntries));

        return summary.Length > 0 ? summary : fallback;
    }
}