dotnet-cqrs/Svrnty.CQRS.Grpc.Generators/ProtoFileGenerator.cs
Mathias Beaulieu-Duncan f76dbb1a97 fix: add Guid to string conversion in gRPC source generator
The MapToProtoModel function was silently failing when mapping Guid
properties to proto string fields, causing IDs to be empty in gRPC
responses. Added explicit Guid → string conversion handling.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-27 19:06:18 -05:00

852 lines
33 KiB
C#

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Svrnty.CQRS.Grpc.Generators.Models;
namespace Svrnty.CQRS.Grpc.Generators;
/// <summary>
/// Generates Protocol Buffer (.proto) files from C# Command, Query, and Notification types
/// </summary>
internal class ProtoFileGenerator
{
    private const string CommandHandlerMetadataName = "Svrnty.CQRS.Abstractions.ICommandHandler`1";
    private const string CommandHandlerWithResultMetadataName = "Svrnty.CQRS.Abstractions.ICommandHandler`2";
    private const string QueryHandlerMetadataName = "Svrnty.CQRS.Abstractions.IQueryHandler`2";

    // Whitespace used to collapse multi-line XML doc summaries onto a single proto comment line.
    private static readonly char[] SummaryWhitespace = { ' ', '\t', '\r', '\n' };

    private readonly Compilation _compilation;
    private readonly HashSet<string> _requiredImports = new HashSet<string>();
    private readonly HashSet<string> _generatedMessages = new HashSet<string>();
    private readonly HashSet<string> _generatedEnums = new HashSet<string>();
    private readonly List<INamedTypeSymbol> _pendingEnums = new List<INamedTypeSymbol>();
    private readonly StringBuilder _messagesBuilder = new StringBuilder();
    private readonly StringBuilder _enumsBuilder = new StringBuilder();
    private List<INamedTypeSymbol>? _allTypesCache;

    // Lazily-built maps from command/query type to its handler's result type (null = void handler).
    private Dictionary<ITypeSymbol, ITypeSymbol?>? _commandResultTypes;
    private Dictionary<ITypeSymbol, ITypeSymbol?>? _queryResultTypes;

    /// <summary>
    /// Gets the discovered notifications after Generate() is called.
    /// </summary>
    public List<NotificationInfo> DiscoveredNotifications { get; private set; } = new List<NotificationInfo>();

    public ProtoFileGenerator(Compilation compilation)
    {
        _compilation = compilation;
    }

    /// <summary>
    /// Gets all types (including nested types) from the compilation and all referenced assemblies.
    /// The result is computed once and cached for the lifetime of this generator.
    /// </summary>
    private IEnumerable<INamedTypeSymbol> GetAllTypes()
    {
        if (_allTypesCache != null)
            return _allTypesCache;

        _allTypesCache = new List<INamedTypeSymbol>();

        // Types declared in the current assembly
        CollectTypesFromNamespace(_compilation.Assembly.GlobalNamespace, _allTypesCache);

        // Types declared in every referenced assembly
        foreach (var reference in _compilation.References)
        {
            if (_compilation.GetAssemblyOrModuleSymbol(reference) is IAssemblySymbol assemblySymbol)
            {
                CollectTypesFromNamespace(assemblySymbol.GlobalNamespace, _allTypesCache);
            }
        }

        return _allTypesCache;
    }

    private static void CollectTypesFromNamespace(INamespaceSymbol ns, List<INamedTypeSymbol> types)
    {
        foreach (var type in ns.GetTypeMembers())
        {
            types.Add(type);
            CollectNestedTypes(type, types);
        }

        foreach (var nestedNs in ns.GetNamespaceMembers())
        {
            CollectTypesFromNamespace(nestedNs, types);
        }
    }

    private static void CollectNestedTypes(INamedTypeSymbol type, List<INamedTypeSymbol> types)
    {
        foreach (var nestedType in type.GetTypeMembers())
        {
            types.Add(nestedType);
            CollectNestedTypes(nestedType, types);
        }
    }

    /// <summary>
    /// Builds the complete .proto file text: header, services, enums, messages, and the imports
    /// required by any well-known types referenced along the way.
    /// </summary>
    /// <param name="packageName">Value written to the proto <c>package</c> statement.</param>
    /// <param name="csharpNamespace">Value written to <c>option csharp_namespace</c>.</param>
    public string Generate(string packageName, string csharpNamespace)
    {
        var commands = DiscoverCommands();
        var queries = DiscoverQueries();
        var dynamicQueries = DiscoverDynamicQueries();
        var notifications = DiscoverNotifications();
        DiscoveredNotifications = notifications;

        var sb = new StringBuilder();

        // Header
        sb.AppendLine("syntax = \"proto3\";");
        sb.AppendLine();
        sb.AppendLine($"option csharp_namespace = \"{csharpNamespace}\";");
        sb.AppendLine();
        sb.AppendLine($"package {packageName};");
        sb.AppendLine();

        // Imports are only known once all messages are generated; remember where to insert them.
        var importsPlaceholder = sb.Length;

        // NOTE: rpc names are derived with the same conventions as the server-side generator; keep in sync.
        AppendService(sb, "Command service for CQRS operations", "CommandService", commands.Select(command =>
        {
            var methodName = command.Name.Replace("Command", "");
            return (Comment: GetXmlDocSummary(command),
                    Rpc: $"rpc {methodName} ({command.Name}Request) returns ({command.Name}Response);");
        }));

        AppendService(sb, "Query service for CQRS operations", "QueryService", queries.Select(query =>
        {
            var methodName = query.Name.Replace("Query", "");
            return (Comment: GetXmlDocSummary(query),
                    Rpc: $"rpc {methodName} ({query.Name}Request) returns ({query.Name}Response);");
        }));

        AppendService(sb, "DynamicQuery service for CQRS operations", "DynamicQueryService", dynamicQueries.Select(dq =>
        {
            var pluralName = Pluralize(dq.Name);
            return (Comment: $"Dynamic query for {dq.Name}",
                    Rpc: $"rpc Query{pluralName} (DynamicQuery{pluralName}Request) returns (DynamicQuery{pluralName}Response);");
        }));

        // Notification service uses server streaming
        AppendService(sb, "NotificationService for real-time streaming notifications", "NotificationService", notifications.Select(notification =>
            (Comment: $"Subscribe to {notification.Name} notifications",
             Rpc: $"rpc SubscribeTo{notification.Name} (SubscribeTo{notification.Name}Request) returns (stream {notification.Name});")));

        foreach (var command in commands)
        {
            GenerateRequestMessage(command);
            GenerateResponseMessage(command, isCommand: true);
        }

        foreach (var query in queries)
        {
            GenerateRequestMessage(query);
            GenerateResponseMessage(query, isCommand: false);
        }

        foreach (var dq in dynamicQueries)
        {
            GenerateDynamicQueryMessages(dq);
        }

        foreach (var notification in notifications)
        {
            GenerateNotificationMessages(notification);
        }

        GeneratePendingEnums();

        // Enums first, then messages
        sb.Append(_enumsBuilder);
        sb.Append(_messagesBuilder);

        if (_requiredImports.Count > 0)
        {
            var imports = new StringBuilder();
            // Ordinal ordering keeps the generated file identical regardless of the build machine's culture.
            foreach (var import in _requiredImports.OrderBy(i => i, StringComparer.Ordinal))
            {
                imports.AppendLine($"import \"{import}\";");
            }
            imports.AppendLine();
            sb.Insert(importsPlaceholder, imports.ToString());
        }

        return sb.ToString();
    }

    /// <summary>
    /// Appends a proto <c>service</c> block with one commented rpc per entry; emits nothing when there are no entries.
    /// </summary>
    private static void AppendService(
        StringBuilder sb,
        string description,
        string serviceName,
        IEnumerable<(string Comment, string Rpc)> rpcs)
    {
        var entries = rpcs.ToList();
        if (entries.Count == 0)
            return;

        sb.AppendLine($"// {description}");
        sb.AppendLine($"service {serviceName} {{");
        foreach (var (comment, rpc) in entries)
        {
            sb.AppendLine($" // {comment}");
            sb.AppendLine($" {rpc}");
            sb.AppendLine();
        }
        sb.AppendLine("}");
        sb.AppendLine();
    }

    /// <summary>
    /// Finds every command type handled by a concrete ICommandHandler&lt;T&gt; / ICommandHandler&lt;T, TResult&gt;
    /// implementation, excluding types marked with [GrpcIgnore].
    /// </summary>
    private List<INamedTypeSymbol> DiscoverCommands()
    {
        var commandHandlerInterface = _compilation.GetTypeByMetadataName(CommandHandlerMetadataName);
        var commandHandlerWithResultInterface = _compilation.GetTypeByMetadataName(CommandHandlerWithResultMetadataName);
        if (commandHandlerInterface == null && commandHandlerWithResultInterface == null)
            return new List<INamedTypeSymbol>();

        var registeredCommands = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);
        foreach (var type in GetAllTypes())
        {
            if (type.IsAbstract || type.IsStatic)
                continue;

            foreach (var iface in type.AllInterfaces)
            {
                if (!iface.IsGenericType)
                    continue;

                var definition = iface.OriginalDefinition;
                var isCommandHandler =
                    (commandHandlerInterface != null && SymbolEqualityComparer.Default.Equals(definition, commandHandlerInterface)) ||
                    (commandHandlerWithResultInterface != null && SymbolEqualityComparer.Default.Equals(definition, commandHandlerWithResultInterface));

                if (isCommandHandler &&
                    iface.TypeArguments[0] is INamedTypeSymbol commandType &&
                    !HasGrpcIgnoreAttribute(commandType))
                {
                    registeredCommands.Add(commandType);
                }
            }
        }

        return registeredCommands.ToList();
    }

    /// <summary>
    /// Finds every query type handled by a concrete IQueryHandler&lt;T, TResult&gt; implementation, excluding
    /// [GrpcIgnore] types and dynamic queries (which get their own service).
    /// </summary>
    private List<INamedTypeSymbol> DiscoverQueries()
    {
        var queryHandlerInterface = _compilation.GetTypeByMetadataName(QueryHandlerMetadataName);
        var dynamicQueryInterface2 = _compilation.GetTypeByMetadataName("Svrnty.CQRS.DynamicQuery.Abstractions.IDynamicQuery`2");
        var dynamicQueryInterface3 = _compilation.GetTypeByMetadataName("Svrnty.CQRS.DynamicQuery.Abstractions.IDynamicQuery`3");
        if (queryHandlerInterface == null)
            return new List<INamedTypeSymbol>();

        var registeredQueries = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);
        foreach (var type in GetAllTypes())
        {
            if (type.IsAbstract || type.IsStatic)
                continue;

            foreach (var iface in type.AllInterfaces)
            {
                if (!iface.IsGenericType || !SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, queryHandlerInterface))
                    continue;

                if (!(iface.TypeArguments[0] is INamedTypeSymbol queryType) || HasGrpcIgnoreAttribute(queryType))
                    continue;

                // Dynamic queries are exposed through DynamicQueryService instead
                var isDynamicQuery = queryType.IsGenericType &&
                    ((dynamicQueryInterface2 != null && SymbolEqualityComparer.Default.Equals(queryType.OriginalDefinition, dynamicQueryInterface2)) ||
                     (dynamicQueryInterface3 != null && SymbolEqualityComparer.Default.Equals(queryType.OriginalDefinition, dynamicQueryInterface3)));
                if (isDynamicQuery)
                    continue;

                registeredQueries.Add(queryType);
            }
        }

        return registeredQueries.ToList();
    }

    private static bool HasGrpcIgnoreAttribute(INamedTypeSymbol type)
    {
        return type.GetAttributes().Any(attr =>
            attr.AttributeClass?.Name == "GrpcIgnoreAttribute");
    }

    /// <summary>
    /// Maps a C# type to its proto type name and records any import the mapping requires.
    /// </summary>
    private string MapTypeAndRegisterImport(ITypeSymbol type)
    {
        var protoType = ProtoFileTypeMapper.MapType(type, out var needsImport, out var importPath);
        if (needsImport && importPath != null)
        {
            _requiredImports.Add(importPath);
        }
        return protoType;
    }

    /// <summary>
    /// Queues any enum referenced by <paramref name="type"/> for generation, and adds the
    /// (element/underlying) complex type to <paramref name="complexTypes"/> so its message can be emitted later.
    /// </summary>
    private void TrackReferencedTypes(ITypeSymbol type, List<INamedTypeSymbol> complexTypes)
    {
        var enumType = ProtoFileTypeMapper.GetEnumType(type);
        if (enumType != null)
        {
            TrackEnumType(enumType);
        }

        // Unwrap collections/nullables so List<Dto> still produces a Dto message
        var underlyingType = ProtoFileTypeMapper.GetElementOrUnderlyingType(type);
        if (IsComplexType(underlyingType) && underlyingType is INamedTypeSymbol namedType)
        {
            complexTypes.Add(namedType);
        }
    }

    /// <summary>
    /// Appends one proto field per public property of <paramref name="type"/> to the message being built.
    /// Unsupported property types are emitted as a "Skipped" comment and do not consume a field number.
    /// </summary>
    /// <returns>Complex types referenced by the fields, to be generated after the current message is closed.</returns>
    private List<INamedTypeSymbol> AppendPropertyFields(INamedTypeSymbol type)
    {
        var nestedComplexTypes = new List<INamedTypeSymbol>();
        var properties = type.GetMembers()
            .OfType<IPropertySymbol>()
            .Where(p => p.DeclaredAccessibility == Accessibility.Public);

        var fieldNumber = 1;
        foreach (var prop in properties)
        {
            if (ProtoFileTypeMapper.IsUnsupportedType(prop.Type))
            {
                _messagesBuilder.AppendLine($" // Skipped: {prop.Name} - unsupported type {prop.Type.Name}");
                continue;
            }

            var protoType = MapTypeAndRegisterImport(prop.Type);
            var fieldName = ProtoFileTypeMapper.ToSnakeCase(prop.Name);
            _messagesBuilder.AppendLine($" {protoType} {fieldName} = {fieldNumber};");
            fieldNumber++;

            TrackReferencedTypes(prop.Type, nestedComplexTypes);
        }

        return nestedComplexTypes;
    }

    private void GenerateRequestMessage(INamedTypeSymbol type)
    {
        var messageName = $"{type.Name}Request";
        if (!_generatedMessages.Add(messageName))
            return;

        _messagesBuilder.AppendLine($"// Request message for {type.Name}");
        _messagesBuilder.AppendLine($"message {messageName} {{");
        var nestedComplexTypes = AppendPropertyFields(type);
        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        // Nested messages are emitted after this one is closed (proto does not need them inline)
        foreach (var nestedType in nestedComplexTypes)
        {
            GenerateComplexTypeMessage(nestedType);
        }
    }

    /// <summary>
    /// Emits the response message for a command/query: a single <c>result</c> field holding the handler's
    /// result type, or an empty message for void handlers / when no handler is found.
    /// </summary>
    /// <param name="isCommand">True when <paramref name="type"/> was discovered via ICommandHandler, false for IQueryHandler.</param>
    private void GenerateResponseMessage(INamedTypeSymbol type, bool isCommand)
    {
        var messageName = $"{type.Name}Response";
        if (!_generatedMessages.Add(messageName))
            return;

        _messagesBuilder.AppendLine($"// Response message for {type.Name}");
        _messagesBuilder.AppendLine($"message {messageName} {{");

        var resultType = GetResultType(type, isCommand);
        if (resultType != null)
        {
            var protoType = MapTypeAndRegisterImport(resultType);
            _messagesBuilder.AppendLine($" {protoType} result = 1;");
        }

        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        if (resultType == null)
            return;

        // Same treatment as request fields: enum results and collection element types need definitions too
        var referencedComplexTypes = new List<INamedTypeSymbol>();
        TrackReferencedTypes(resultType, referencedComplexTypes);
        foreach (var complexType in referencedComplexTypes)
        {
            GenerateComplexTypeMessage(complexType);
        }
    }

    /// <summary>
    /// Emits a message for a user-defined class/struct (once per type name), then recursively
    /// for the complex types its properties reference. System types are skipped.
    /// </summary>
    private void GenerateComplexTypeMessage(INamedTypeSymbol? type)
    {
        if (type == null || _generatedMessages.Contains(type.Name))
            return;

        // Don't generate messages for system types or primitives
        if (IsInSystemNamespace(type))
            return;

        _generatedMessages.Add(type.Name);
        _messagesBuilder.AppendLine($"// {type.Name} entity");
        _messagesBuilder.AppendLine($"message {type.Name} {{");
        var nestedComplexTypes = AppendPropertyFields(type);
        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        foreach (var nestedType in nestedComplexTypes)
        {
            GenerateComplexTypeMessage(nestedType);
        }
    }

    /// <summary>
    /// Returns the TResult of the handler registered for <paramref name="commandOrQueryType"/>, or null when
    /// the handler is void (ICommandHandler&lt;T&gt;) or no handler exists.
    /// </summary>
    private ITypeSymbol? GetResultType(INamedTypeSymbol commandOrQueryType, bool isCommand)
    {
        // Resolve by the actual Svrnty handler interfaces rather than by the type's name suffix,
        // so commands not named "*Command" still get their result type.
        var resultTypes = isCommand
            ? (_commandResultTypes ??= BuildResultTypeMap(CommandHandlerMetadataName, CommandHandlerWithResultMetadataName))
            : (_queryResultTypes ??= BuildResultTypeMap(QueryHandlerMetadataName));

        return resultTypes.TryGetValue(commandOrQueryType, out var resultType) ? resultType : null;
    }

    /// <summary>
    /// Scans all types once and maps each handled request type to its handler's result type
    /// (null for single-argument handlers). The first handler encountered for a request type wins.
    /// </summary>
    private Dictionary<ITypeSymbol, ITypeSymbol?> BuildResultTypeMap(params string[] handlerMetadataNames)
    {
        var resultTypes = new Dictionary<ITypeSymbol, ITypeSymbol?>(SymbolEqualityComparer.Default);
        var handlerDefinitions = handlerMetadataNames
            .Select(name => _compilation.GetTypeByMetadataName(name))
            .Where(symbol => symbol != null)
            .ToList();
        if (handlerDefinitions.Count == 0)
            return resultTypes;

        foreach (var type in GetAllTypes())
        {
            foreach (var iface in type.AllInterfaces)
            {
                if (!iface.IsGenericType)
                    continue;

                var definition = iface.OriginalDefinition;
                if (!handlerDefinitions.Any(d => SymbolEqualityComparer.Default.Equals(definition, d)))
                    continue;

                var requestType = iface.TypeArguments[0];
                if (!resultTypes.ContainsKey(requestType))
                {
                    resultTypes[requestType] = iface.TypeArguments.Length == 2 ? iface.TypeArguments[1] : null;
                }
            }
        }

        return resultTypes;
    }

    /// <summary>
    /// True for user-defined classes/structs, i.e. anything that is not declared under the root
    /// <c>System</c> namespace (primitives, string, collections, Nullable, ...).
    /// </summary>
    private static bool IsComplexType(ITypeSymbol type)
    {
        if (type.TypeKind != TypeKind.Class && type.TypeKind != TypeKind.Struct)
            return false;

        return !IsInSystemNamespace(type);
    }

    /// <summary>
    /// True when the outermost namespace containing <paramref name="type"/> is exactly <c>System</c>.
    /// A substring test would wrongly match user namespaces such as <c>FileSystem.Models</c>.
    /// </summary>
    private static bool IsInSystemNamespace(ITypeSymbol type)
    {
        string? rootNamespace = null;
        for (var ns = type.ContainingNamespace; ns != null && !ns.IsGlobalNamespace; ns = ns.ContainingNamespace)
        {
            rootNamespace = ns.Name;
        }
        return rootNamespace == "System";
    }

    /// <summary>
    /// Extracts the &lt;summary&gt; text of a type's XML docs as a single line, falling back to
    /// "{TypeName} operation" when there is no usable summary.
    /// </summary>
    private static string GetXmlDocSummary(INamedTypeSymbol type)
    {
        var fallback = $"{type.Name} operation";
        var xml = type.GetDocumentationCommentXml();
        if (xml is null || xml.Length == 0)
            return fallback;

        const string openTag = "<summary>";
        const string closeTag = "</summary>";
        var summaryStart = xml.IndexOf(openTag, StringComparison.Ordinal);
        var summaryEnd = xml.IndexOf(closeTag, StringComparison.Ordinal);
        if (summaryStart < 0 || summaryEnd <= summaryStart)
            return fallback;

        var contentStart = summaryStart + openTag.Length;
        // The summary is written after a proto "//" comment marker, which only covers one line,
        // so multi-line summaries must be collapsed or the remaining lines become invalid proto.
        var words = xml.Substring(contentStart, summaryEnd - contentStart)
            .Split(SummaryWhitespace, StringSplitOptions.RemoveEmptyEntries);

        return words.Length == 0 ? fallback : string.Join(" ", words);
    }

    /// <summary>
    /// Finds entity types exposed through IQueryableProvider&lt;TEntity&gt; implementations, in discovery order.
    /// </summary>
    private List<INamedTypeSymbol> DiscoverDynamicQueries()
    {
        var queryableProviderInterface = _compilation.GetTypeByMetadataName("Svrnty.CQRS.DynamicQuery.Abstractions.IQueryableProvider`1");
        if (queryableProviderInterface == null)
            return new List<INamedTypeSymbol>();

        var dynamicQueryTypes = new List<INamedTypeSymbol>();
        var seen = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);
        foreach (var type in GetAllTypes())
        {
            if (type.IsAbstract || type.IsStatic)
                continue;

            foreach (var iface in type.AllInterfaces)
            {
                if (iface.IsGenericType &&
                    SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, queryableProviderInterface) &&
                    iface.TypeArguments[0] is INamedTypeSymbol entityType &&
                    seen.Add(entityType))
                {
                    dynamicQueryTypes.Add(entityType);
                }
            }
        }

        return dynamicQueryTypes;
    }

    private void GenerateDynamicQueryMessages(INamedTypeSymbol entityType)
    {
        var pluralName = Pluralize(entityType.Name);
        var requestMessageName = $"DynamicQuery{pluralName}Request";
        var responseMessageName = $"DynamicQuery{pluralName}Response";

        // Common filter/sort/group/aggregate types (only generate once)
        if (_generatedMessages.Add("DynamicQueryFilter"))
        {
            _messagesBuilder.AppendLine("// Dynamic query filter with AND/OR support");
            _messagesBuilder.AppendLine("message DynamicQueryFilter {");
            _messagesBuilder.AppendLine(" string path = 1;");
            _messagesBuilder.AppendLine(" int32 type = 2; // PoweredSoft.DynamicQuery.Core.FilterType");
            _messagesBuilder.AppendLine(" string value = 3;");
            _messagesBuilder.AppendLine(" repeated DynamicQueryFilter and = 4;");
            _messagesBuilder.AppendLine(" repeated DynamicQueryFilter or = 5;");
            _messagesBuilder.AppendLine("}");
            _messagesBuilder.AppendLine();
            _messagesBuilder.AppendLine("// Dynamic query sort");
            _messagesBuilder.AppendLine("message DynamicQuerySort {");
            _messagesBuilder.AppendLine(" string path = 1;");
            _messagesBuilder.AppendLine(" bool ascending = 2;");
            _messagesBuilder.AppendLine("}");
            _messagesBuilder.AppendLine();
            _messagesBuilder.AppendLine("// Dynamic query group");
            _messagesBuilder.AppendLine("message DynamicQueryGroup {");
            _messagesBuilder.AppendLine(" string path = 1;");
            _messagesBuilder.AppendLine("}");
            _messagesBuilder.AppendLine();
            _messagesBuilder.AppendLine("// Dynamic query aggregate");
            _messagesBuilder.AppendLine("message DynamicQueryAggregate {");
            _messagesBuilder.AppendLine(" string path = 1;");
            _messagesBuilder.AppendLine(" int32 type = 2; // PoweredSoft.DynamicQuery.Core.AggregateType");
            _messagesBuilder.AppendLine("}");
            _messagesBuilder.AppendLine();
        }

        // Request message
        _messagesBuilder.AppendLine($"// Dynamic query request for {entityType.Name}");
        _messagesBuilder.AppendLine($"message {requestMessageName} {{");
        _messagesBuilder.AppendLine(" int32 page = 1;");
        _messagesBuilder.AppendLine(" int32 page_size = 2;");
        _messagesBuilder.AppendLine(" repeated DynamicQueryFilter filters = 3;");
        _messagesBuilder.AppendLine(" repeated DynamicQuerySort sorts = 4;");
        _messagesBuilder.AppendLine(" repeated DynamicQueryGroup groups = 5;");
        _messagesBuilder.AppendLine(" repeated DynamicQueryAggregate aggregates = 6;");
        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        // Response message
        _messagesBuilder.AppendLine($"// Dynamic query response for {entityType.Name}");
        _messagesBuilder.AppendLine($"message {responseMessageName} {{");
        _messagesBuilder.AppendLine($" repeated {entityType.Name} data = 1;");
        _messagesBuilder.AppendLine(" int64 total_records = 2;");
        _messagesBuilder.AppendLine(" int32 number_of_pages = 3;");
        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        // Generate entity message if not already generated
        GenerateComplexTypeMessage(entityType);
    }

    /// <summary>
    /// Naive English pluralization used for dynamic-query rpc/message names
    /// (consonant+y → ies; s/x/z/ch/sh → +es; otherwise +s).
    /// </summary>
    private static string Pluralize(string word)
    {
        if (word.EndsWith("y") && word.Length > 1 && !"aeiou".Contains(word[word.Length - 2].ToString()))
            return word.Substring(0, word.Length - 1) + "ies";
        if (word.EndsWith("s") || word.EndsWith("x") || word.EndsWith("z") || word.EndsWith("ch") || word.EndsWith("sh"))
            return word + "es";
        return word + "s";
    }

    /// <summary>
    /// Tracks an enum type for later generation (deduplicated by simple name).
    /// </summary>
    private void TrackEnumType(INamedTypeSymbol enumType)
    {
        if (!_generatedEnums.Contains(enumType.Name) && !_pendingEnums.Any(e => e.Name == enumType.Name))
        {
            _pendingEnums.Add(enumType);
        }
    }

    /// <summary>
    /// Generates all pending enum definitions. Values are prefixed with the enum name in
    /// UPPER_SNAKE_CASE because proto enum values share the package scope.
    /// </summary>
    private void GeneratePendingEnums()
    {
        foreach (var enumType in _pendingEnums)
        {
            if (!_generatedEnums.Add(enumType.Name))
                continue;

            var prefix = ProtoFileTypeMapper.ToSnakeCase(enumType.Name).ToUpperInvariant();

            // proto3 requires the first value to be zero. OrderBy is stable, so apart from moving
            // zero-valued members to the front the C# declaration order is kept.
            var values = enumType.GetMembers()
                .OfType<IFieldSymbol>()
                .Where(f => f.HasConstantValue)
                .Select(f => (
                    Name: prefix + "_" + ProtoFileTypeMapper.ToSnakeCase(f.Name).ToUpperInvariant(),
                    Value: Convert.ToString(f.ConstantValue, CultureInfo.InvariantCulture)))
                .OrderBy(v => v.Value == "0" ? 0 : 1)
                .ToList();

            _enumsBuilder.AppendLine($"// {enumType.Name} enum");
            _enumsBuilder.AppendLine($"enum {enumType.Name} {{");

            // protoc rejects duplicate numeric values unless aliasing is explicitly enabled
            if (values.Select(v => v.Value).Distinct().Count() != values.Count)
            {
                _enumsBuilder.AppendLine(" option allow_alias = true;");
            }

            // No zero-valued C# member (or an empty enum): synthesize the mandatory zero value
            if (values.Count == 0 || values[0].Value != "0")
            {
                var placeholder = prefix + "_UNSPECIFIED";
                var suffix = 1;
                while (values.Any(v => v.Name == placeholder))
                {
                    placeholder = prefix + "_UNSPECIFIED" + suffix.ToString(CultureInfo.InvariantCulture);
                    suffix++;
                }
                _enumsBuilder.AppendLine($" {placeholder} = 0;");
            }

            foreach (var (name, value) in values)
            {
                _enumsBuilder.AppendLine($" {name} = {value};");
            }

            _enumsBuilder.AppendLine("}");
            _enumsBuilder.AppendLine();
        }
    }

    /// <summary>
    /// Discovers types marked with [StreamingNotification] attribute whose SubscriptionKey names
    /// one of their mappable public properties.
    /// </summary>
    private List<NotificationInfo> DiscoverNotifications()
    {
        var streamingNotificationAttribute = _compilation.GetTypeByMetadataName(
            "Svrnty.CQRS.Notifications.Abstractions.StreamingNotificationAttribute");
        if (streamingNotificationAttribute == null)
            return new List<NotificationInfo>();

        var notifications = new List<NotificationInfo>();
        foreach (var type in GetAllTypes())
        {
            if (type.IsAbstract || type.IsStatic)
                continue;

            var attr = type.GetAttributes()
                .FirstOrDefault(a => SymbolEqualityComparer.Default.Equals(
                    a.AttributeClass, streamingNotificationAttribute));
            if (attr == null)
                continue;

            // Extract SubscriptionKey from attribute
            var subscriptionKeyArg = attr.NamedArguments
                .FirstOrDefault(a => a.Key == "SubscriptionKey");
            var subscriptionKeyProp = subscriptionKeyArg.Value.Value as string;
            if (subscriptionKeyProp is null || subscriptionKeyProp.Length == 0)
                continue;

            var properties = ExtractNotificationProperties(type);

            // The subscription key must be one of the mapped properties
            var keyPropInfo = properties.FirstOrDefault(p => p.Name == subscriptionKeyProp);
            if (keyPropInfo == null)
                continue;

            notifications.Add(new NotificationInfo
            {
                Name = type.Name,
                FullyQualifiedName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
                    .Replace("global::", ""),
                Namespace = type.ContainingNamespace?.ToDisplayString() ?? "",
                SubscriptionKeyProperty = subscriptionKeyProp,
                SubscriptionKeyInfo = keyPropInfo,
                Properties = properties
            });
        }

        return notifications;
    }

    /// <summary>
    /// Extracts property information from a notification type. Also registers the proto imports and
    /// enum definitions those properties need, using the real property symbols.
    /// </summary>
    private List<Models.PropertyInfo> ExtractNotificationProperties(INamedTypeSymbol type)
    {
        var properties = new List<Models.PropertyInfo>();
        int fieldNumber = 1;
        foreach (var prop in type.GetMembers().OfType<IPropertySymbol>()
            .Where(p => p.DeclaredAccessibility == Accessibility.Public))
        {
            if (ProtoFileTypeMapper.IsUnsupportedType(prop.Type))
                continue;

            var protoType = MapTypeAndRegisterImport(prop.Type);
            var enumType = ProtoFileTypeMapper.GetEnumType(prop.Type);
            properties.Add(new Models.PropertyInfo
            {
                Name = prop.Name,
                Type = prop.Type.Name,
                FullyQualifiedType = prop.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
                    .Replace("global::", ""),
                ProtoType = protoType,
                FieldNumber = fieldNumber++,
                IsEnum = enumType != null,
                IsDecimal = prop.Type.SpecialType == SpecialType.System_Decimal ||
                    prop.Type.ToDisplayString().Contains("decimal"),
                IsDateTime = prop.Type.ToDisplayString().Contains("DateTime"),
                IsNullable = prop.Type.NullableAnnotation == NullableAnnotation.Annotated ||
                    (prop.Type is INamedTypeSymbol namedType &&
                     namedType.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T)
            });

            if (enumType != null)
            {
                TrackEnumType(enumType);
            }
        }

        return properties;
    }

    /// <summary>
    /// Generates proto messages for a notification type: the subscription request (holding only the
    /// subscription key) and the streamed notification message itself. Imports were already
    /// registered when the properties were extracted.
    /// </summary>
    private void GenerateNotificationMessages(NotificationInfo notification)
    {
        var requestMessageName = $"SubscribeTo{notification.Name}Request";
        if (_generatedMessages.Add(requestMessageName))
        {
            _messagesBuilder.AppendLine($"// Subscription request for {notification.Name}");
            _messagesBuilder.AppendLine($"message {requestMessageName} {{");
            _messagesBuilder.AppendLine($" {notification.SubscriptionKeyInfo.ProtoType} {ProtoFileTypeMapper.ToSnakeCase(notification.SubscriptionKeyProperty)} = 1;");
            _messagesBuilder.AppendLine("}");
            _messagesBuilder.AppendLine();
        }

        if (_generatedMessages.Add(notification.Name))
        {
            _messagesBuilder.AppendLine($"// {notification.Name} streaming notification");
            _messagesBuilder.AppendLine($"message {notification.Name} {{");
            foreach (var prop in notification.Properties)
            {
                var fieldName = ProtoFileTypeMapper.ToSnakeCase(prop.Name);
                _messagesBuilder.AppendLine($" {prop.ProtoType} {fieldName} = {prop.FieldNumber};");
            }
            _messagesBuilder.AppendLine("}");
            _messagesBuilder.AppendLine();
        }
    }
}