dotnet-cqrs/Svrnty.CQRS.Grpc.Generators/ProtoFileGenerator.cs

850 lines
33 KiB
C#

using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Svrnty.CQRS.Grpc.Generators.Models;
namespace Svrnty.CQRS.Grpc.Generators;
/// <summary>
/// Generates Protocol Buffer (.proto) files from C# Command, Query, and Notification types
/// </summary>
internal class ProtoFileGenerator
{
    private readonly Compilation _compilation;
    private readonly HashSet<string> _requiredImports = new HashSet<string>();
    private readonly HashSet<string> _generatedMessages = new HashSet<string>();
    private readonly HashSet<string> _generatedEnums = new HashSet<string>();
    private readonly List<INamedTypeSymbol> _pendingEnums = new List<INamedTypeSymbol>();
    private readonly StringBuilder _messagesBuilder = new StringBuilder();
    private readonly StringBuilder _enumsBuilder = new StringBuilder();

    // Notification type symbols keyed by NotificationInfo.FullyQualifiedName. Message generation uses
    // them to inspect the real property types (imports, enums, nested messages) instead of
    // re-resolving type names from strings, which fails for nullable and generic types.
    private readonly Dictionary<string, INamedTypeSymbol> _notificationSymbols = new Dictionary<string, INamedTypeSymbol>();

    private List<INamedTypeSymbol>? _allTypesCache;

    /// <summary>
    /// Gets the discovered notifications after Generate() is called.
    /// </summary>
    public List<NotificationInfo> DiscoveredNotifications { get; private set; } = new List<NotificationInfo>();

    public ProtoFileGenerator(Compilation compilation)
    {
        _compilation = compilation;
    }

    /// <summary>
    /// Gets all types from the compilation and all referenced assemblies.
    /// The result is computed once and cached for the lifetime of this generator.
    /// </summary>
    private IEnumerable<INamedTypeSymbol> GetAllTypes()
    {
        if (_allTypesCache != null)
            return _allTypesCache;

        _allTypesCache = new List<INamedTypeSymbol>();

        // Get types from the current assembly
        CollectTypesFromNamespace(_compilation.Assembly.GlobalNamespace, _allTypesCache);

        // Get types from all referenced assemblies
        foreach (var reference in _compilation.References)
        {
            if (_compilation.GetAssemblyOrModuleSymbol(reference) is IAssemblySymbol assemblySymbol)
            {
                CollectTypesFromNamespace(assemblySymbol.GlobalNamespace, _allTypesCache);
            }
        }

        return _allTypesCache;
    }

    /// <summary>
    /// Recursively adds every type (including nested types) declared in <paramref name="ns"/>
    /// and its child namespaces to <paramref name="types"/>.
    /// </summary>
    private static void CollectTypesFromNamespace(INamespaceSymbol ns, List<INamedTypeSymbol> types)
    {
        foreach (var type in ns.GetTypeMembers())
        {
            types.Add(type);
            CollectNestedTypes(type, types);
        }

        foreach (var nestedNs in ns.GetNamespaceMembers())
        {
            CollectTypesFromNamespace(nestedNs, types);
        }
    }

    /// <summary>
    /// Recursively adds all types nested inside <paramref name="type"/> to <paramref name="types"/>.
    /// </summary>
    private static void CollectNestedTypes(INamedTypeSymbol type, List<INamedTypeSymbol> types)
    {
        foreach (var nestedType in type.GetTypeMembers())
        {
            types.Add(nestedType);
            CollectNestedTypes(nestedType, types);
        }
    }

    /// <summary>
    /// Discovers commands, queries, dynamic queries and streaming notifications in the compilation
    /// and renders a complete proto3 file describing them.
    /// </summary>
    /// <param name="packageName">Value written to the proto <c>package</c> statement.</param>
    /// <param name="csharpNamespace">Value written to <c>option csharp_namespace</c>.</param>
    /// <returns>The .proto file contents.</returns>
    public string Generate(string packageName, string csharpNamespace)
    {
        // Start from a clean slate so repeated calls each produce a complete file; otherwise the
        // de-duplication sets left over from a previous run would suppress every message and enum.
        ResetGenerationState();

        var commands = DiscoverCommands();
        var queries = DiscoverQueries();
        var dynamicQueries = DiscoverDynamicQueries();
        var notifications = DiscoverNotifications();
        DiscoveredNotifications = notifications;

        var sb = new StringBuilder();

        // Header
        sb.AppendLine("syntax = \"proto3\";");
        sb.AppendLine();
        sb.AppendLine($"option csharp_namespace = \"{csharpNamespace}\";");
        sb.AppendLine();
        sb.AppendLine($"package {packageName};");
        sb.AppendLine();

        // Imports are only known after all messages are generated; remember where they belong.
        var importsPlaceholder = sb.Length;

        // Command Service
        if (commands.Count > 0)
        {
            sb.AppendLine("// Command service for CQRS operations");
            sb.AppendLine("service CommandService {");
            foreach (var command in commands)
            {
                // NOTE(review): Replace strips every "Command" occurrence, not only the suffix
                // (e.g. "CommandCenterCommand" -> "Center"). Kept as-is because the generated
                // service implementation must agree on these RPC names.
                var methodName = command.Name.Replace("Command", "");
                var requestType = $"{command.Name}Request";
                var responseType = $"{command.Name}Response";
                sb.AppendLine($" // {GetXmlDocSummary(command)}");
                sb.AppendLine($" rpc {methodName} ({requestType}) returns ({responseType});");
                sb.AppendLine();
            }
            sb.AppendLine("}");
            sb.AppendLine();
        }

        // Query Service
        if (queries.Count > 0)
        {
            sb.AppendLine("// Query service for CQRS operations");
            sb.AppendLine("service QueryService {");
            foreach (var query in queries)
            {
                // NOTE(review): same whole-string Replace caveat as for commands above.
                var methodName = query.Name.Replace("Query", "");
                var requestType = $"{query.Name}Request";
                var responseType = $"{query.Name}Response";
                sb.AppendLine($" // {GetXmlDocSummary(query)}");
                sb.AppendLine($" rpc {methodName} ({requestType}) returns ({responseType});");
                sb.AppendLine();
            }
            sb.AppendLine("}");
            sb.AppendLine();
        }

        // DynamicQuery Service
        if (dynamicQueries.Count > 0)
        {
            sb.AppendLine("// DynamicQuery service for CQRS operations");
            sb.AppendLine("service DynamicQueryService {");
            foreach (var dq in dynamicQueries)
            {
                var entityName = dq.Name;
                var pluralName = Pluralize(entityName);
                var methodName = $"Query{pluralName}";
                var requestType = $"DynamicQuery{pluralName}Request";
                var responseType = $"DynamicQuery{pluralName}Response";
                sb.AppendLine($" // Dynamic query for {entityName}");
                sb.AppendLine($" rpc {methodName} ({requestType}) returns ({responseType});");
                sb.AppendLine();
            }
            sb.AppendLine("}");
            sb.AppendLine();
        }

        // Notification Service (server streaming)
        if (notifications.Count > 0)
        {
            sb.AppendLine("// NotificationService for real-time streaming notifications");
            sb.AppendLine("service NotificationService {");
            foreach (var notification in notifications)
            {
                var methodName = $"SubscribeTo{notification.Name}";
                var requestType = $"SubscribeTo{notification.Name}Request";
                sb.AppendLine($" // Subscribe to {notification.Name} notifications");
                sb.AppendLine($" rpc {methodName} ({requestType}) returns (stream {notification.Name});");
                sb.AppendLine();
            }
            sb.AppendLine("}");
            sb.AppendLine();
        }

        // Generate messages for commands
        foreach (var command in commands)
        {
            GenerateRequestMessage(command);
            GenerateResponseMessage(command);
        }

        // Generate messages for queries
        foreach (var query in queries)
        {
            GenerateRequestMessage(query);
            GenerateResponseMessage(query);
        }

        // Generate messages for dynamic queries
        foreach (var dq in dynamicQueries)
        {
            GenerateDynamicQueryMessages(dq);
        }

        // Generate messages for notifications
        foreach (var notification in notifications)
        {
            GenerateNotificationMessages(notification);
        }

        // Enums are collected while messages are generated, so render them last.
        GeneratePendingEnums();

        // Append all generated enums first, then messages
        sb.Append(_enumsBuilder);
        sb.Append(_messagesBuilder);

        // Insert imports if any were needed
        if (_requiredImports.Count > 0)
        {
            var imports = new StringBuilder();
            foreach (var import in _requiredImports.OrderBy(i => i))
            {
                imports.AppendLine($"import \"{import}\";");
            }
            imports.AppendLine();
            sb.Insert(importsPlaceholder, imports.ToString());
        }

        return sb.ToString();
    }

    /// <summary>
    /// Clears all per-run state accumulated by a previous <see cref="Generate"/> call.
    /// The type cache is kept because the compilation does not change.
    /// </summary>
    private void ResetGenerationState()
    {
        _requiredImports.Clear();
        _generatedMessages.Clear();
        _generatedEnums.Clear();
        _pendingEnums.Clear();
        _messagesBuilder.Clear();
        _enumsBuilder.Clear();
        _notificationSymbols.Clear();
    }

    /// <summary>
    /// Finds command types that have a non-abstract handler implementing
    /// <c>ICommandHandler&lt;T&gt;</c> or <c>ICommandHandler&lt;T, TResult&gt;</c>, excluding [GrpcIgnore] types.
    /// </summary>
    private List<INamedTypeSymbol> DiscoverCommands()
    {
        // First, find all command handlers to know which commands are actually registered
        var commandHandlerInterface = _compilation.GetTypeByMetadataName("Svrnty.CQRS.Abstractions.ICommandHandler`1");
        var commandHandlerWithResultInterface = _compilation.GetTypeByMetadataName("Svrnty.CQRS.Abstractions.ICommandHandler`2");

        if (commandHandlerInterface == null && commandHandlerWithResultInterface == null)
            return new List<INamedTypeSymbol>();

        var registeredCommands = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);

        foreach (var type in GetAllTypes())
        {
            if (type.IsAbstract || type.IsStatic)
                continue;

            foreach (var iface in type.AllInterfaces)
            {
                if (!iface.IsGenericType)
                    continue;

                var isCommandHandler =
                    (commandHandlerInterface != null && SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, commandHandlerInterface)) ||
                    (commandHandlerWithResultInterface != null && SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, commandHandlerWithResultInterface));

                if (isCommandHandler &&
                    iface.TypeArguments[0] is INamedTypeSymbol commandType &&
                    !HasGrpcIgnoreAttribute(commandType))
                {
                    registeredCommands.Add(commandType);
                }
            }
        }

        return registeredCommands.ToList();
    }

    /// <summary>
    /// Finds query types that have a non-abstract handler implementing <c>IQueryHandler&lt;T, TResult&gt;</c>,
    /// excluding [GrpcIgnore] types and dynamic queries (which get their own service).
    /// </summary>
    private List<INamedTypeSymbol> DiscoverQueries()
    {
        // First, find all query handlers to know which queries are actually registered
        var queryHandlerInterface = _compilation.GetTypeByMetadataName("Svrnty.CQRS.Abstractions.IQueryHandler`2");
        var dynamicQueryInterface2 = _compilation.GetTypeByMetadataName("Svrnty.CQRS.DynamicQuery.Abstractions.IDynamicQuery`2");
        var dynamicQueryInterface3 = _compilation.GetTypeByMetadataName("Svrnty.CQRS.DynamicQuery.Abstractions.IDynamicQuery`3");

        if (queryHandlerInterface == null)
            return new List<INamedTypeSymbol>();

        var registeredQueries = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);

        foreach (var type in GetAllTypes())
        {
            if (type.IsAbstract || type.IsStatic)
                continue;

            foreach (var iface in type.AllInterfaces)
            {
                if (!iface.IsGenericType || !SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, queryHandlerInterface))
                    continue;

                if (!(iface.TypeArguments[0] is INamedTypeSymbol queryType) || HasGrpcIgnoreAttribute(queryType))
                    continue;

                // Skip dynamic queries - they're handled separately
                if (queryType.IsGenericType &&
                    ((dynamicQueryInterface2 != null && SymbolEqualityComparer.Default.Equals(queryType.OriginalDefinition, dynamicQueryInterface2)) ||
                     (dynamicQueryInterface3 != null && SymbolEqualityComparer.Default.Equals(queryType.OriginalDefinition, dynamicQueryInterface3))))
                {
                    continue;
                }

                registeredQueries.Add(queryType);
            }
        }

        return registeredQueries.ToList();
    }

    /// <summary>
    /// Returns true when <paramref name="type"/> carries an attribute whose class is named
    /// <c>GrpcIgnoreAttribute</c> (matched by simple name, any namespace).
    /// </summary>
    private bool HasGrpcIgnoreAttribute(INamedTypeSymbol type)
    {
        return type.GetAttributes().Any(attr =>
            attr.AttributeClass?.Name == "GrpcIgnoreAttribute");
    }

    /// <summary>
    /// Emits the <c>{Type}Request</c> message for a command or query, with one field per public
    /// instance property, then emits messages for any user-defined types those fields reference.
    /// </summary>
    private void GenerateRequestMessage(INamedTypeSymbol type)
    {
        var messageName = $"{type.Name}Request";
        if (!_generatedMessages.Add(messageName))
            return;

        _messagesBuilder.AppendLine($"// Request message for {type.Name}");
        _messagesBuilder.AppendLine($"message {messageName} {{");
        var nestedComplexTypes = AppendPropertyFields(type);
        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        // Nested messages are emitted only after this message is closed.
        GenerateNestedMessages(nestedComplexTypes);
    }

    /// <summary>
    /// Emits the <c>{Type}Response</c> message. It has a single <c>result = 1</c> field typed after the
    /// handler's <c>TResult</c>, or is empty for void command handlers / when no handler is found.
    /// </summary>
    private void GenerateResponseMessage(INamedTypeSymbol type)
    {
        var messageName = $"{type.Name}Response";
        if (!_generatedMessages.Add(messageName))
            return;

        _messagesBuilder.AppendLine($"// Response message for {type.Name}");
        _messagesBuilder.AppendLine($"message {messageName} {{");

        var nestedComplexTypes = new List<INamedTypeSymbol>();

        // Determine the result type from ICommandHandler<T, TResult> or IQueryHandler<T, TResult>
        var resultType = GetResultType(type);
        if (resultType != null)
        {
            // Tracking goes through the shared helper so collection results (List<Dto>) and enum
            // results get their element message / enum definition emitted, like request fields do.
            var protoType = MapAndTrackType(resultType, nestedComplexTypes);
            _messagesBuilder.AppendLine($" {protoType} result = 1;");
        }
        // If no result type, leave message empty (void return)

        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        GenerateNestedMessages(nestedComplexTypes);
    }

    /// <summary>
    /// Emits a message named after <paramref name="type"/> (once per simple name) describing its public
    /// instance properties, then recursively emits messages for referenced user-defined types.
    /// Types in the <c>System</c> namespace tree are never emitted.
    /// </summary>
    private void GenerateComplexTypeMessage(INamedTypeSymbol? type)
    {
        if (type == null || _generatedMessages.Contains(type.Name))
            return;

        // Don't generate messages for system types or primitives
        if (IsSystemNamespace(type.ContainingNamespace))
            return;

        // Registered before recursing so self-referencing types terminate.
        _generatedMessages.Add(type.Name);

        _messagesBuilder.AppendLine($"// {type.Name} entity");
        _messagesBuilder.AppendLine($"message {type.Name} {{");
        var nestedComplexTypes = AppendPropertyFields(type);
        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        GenerateNestedMessages(nestedComplexTypes);
    }

    /// <summary>
    /// Appends one proto field per public instance property of <paramref name="type"/> to the message
    /// currently being written, numbering fields from 1 in declaration order. Properties with
    /// unsupported types are written as a "Skipped" comment and do not consume a field number.
    /// </summary>
    /// <returns>User-defined types referenced by the fields, to be emitted once the message is closed.</returns>
    private List<INamedTypeSymbol> AppendPropertyFields(INamedTypeSymbol type)
    {
        var nestedComplexTypes = new List<INamedTypeSymbol>();
        var fieldNumber = 1;

        foreach (var prop in GetPublicProperties(type))
        {
            if (ProtoFileTypeMapper.IsUnsupportedType(prop.Type))
            {
                // Skip unsupported types and add a comment
                _messagesBuilder.AppendLine($" // Skipped: {prop.Name} - unsupported type {prop.Type.Name}");
                continue;
            }

            var protoType = MapAndTrackType(prop.Type, nestedComplexTypes);
            var fieldName = ProtoFileTypeMapper.ToSnakeCase(prop.Name);
            _messagesBuilder.AppendLine($" {protoType} {fieldName} = {fieldNumber};");
            fieldNumber++;
        }

        return nestedComplexTypes;
    }

    /// <summary>
    /// Maps <paramref name="type"/> to its proto type name and records what the mapping depends on:
    /// the required import (if any), the enum to emit (if any), and the user-defined message type
    /// (element type for collections) that must be emitted later, appended to <paramref name="nestedComplexTypes"/>.
    /// </summary>
    private string MapAndTrackType(ITypeSymbol type, List<INamedTypeSymbol> nestedComplexTypes)
    {
        var protoType = ProtoFileTypeMapper.MapType(type, out var needsImport, out var importPath);
        if (needsImport && importPath != null)
        {
            _requiredImports.Add(importPath);
        }

        // Track enums for later generation
        var enumType = ProtoFileTypeMapper.GetEnumType(type);
        if (enumType != null)
        {
            TrackEnumType(enumType);
        }

        // Use GetElementOrUnderlyingType to extract element type from collections
        var underlyingType = ProtoFileTypeMapper.GetElementOrUnderlyingType(type);
        if (IsComplexType(underlyingType) && underlyingType is INamedTypeSymbol namedType)
        {
            nestedComplexTypes.Add(namedType);
        }

        return protoType;
    }

    /// <summary>
    /// Emits messages for each collected nested type (duplicates are ignored by the callee).
    /// </summary>
    private void GenerateNestedMessages(List<INamedTypeSymbol> nestedComplexTypes)
    {
        foreach (var nestedType in nestedComplexTypes)
        {
            GenerateComplexTypeMessage(nestedType);
        }
    }

    /// <summary>
    /// Returns the public, non-static, non-indexer properties declared directly on <paramref name="type"/>.
    /// Indexers would produce invalid field names and static members are not instance data.
    /// </summary>
    private static IEnumerable<IPropertySymbol> GetPublicProperties(INamedTypeSymbol type)
    {
        return type.GetMembers()
            .OfType<IPropertySymbol>()
            .Where(p => p.DeclaredAccessibility == Accessibility.Public && !p.IsStatic && !p.IsIndexer);
    }

    /// <summary>
    /// Finds the handler interface whose first type argument is <paramref name="commandOrQueryType"/>
    /// and returns its <c>TResult</c>, or null for <c>ICommandHandler&lt;T&gt;</c> or when no handler exists.
    /// </summary>
    private ITypeSymbol? GetResultType(INamedTypeSymbol commandOrQueryType)
    {
        // Match the same handler interfaces that discovery uses, regardless of the type's name
        // (a command not named "*Command" must still resolve its ICommandHandler result).
        var handlerDefinitions = new[]
            {
                "Svrnty.CQRS.Abstractions.ICommandHandler`1",
                "Svrnty.CQRS.Abstractions.ICommandHandler`2",
                "Svrnty.CQRS.Abstractions.IQueryHandler`2",
            }
            .Select(metadataName => _compilation.GetTypeByMetadataName(metadataName))
            .Where(symbol => symbol != null)
            .ToList();

        if (handlerDefinitions.Count == 0)
            return null;

        foreach (var type in GetAllTypes())
        {
            foreach (var @interface in type.AllInterfaces)
            {
                if (!@interface.IsGenericType || @interface.TypeArguments.Length == 0)
                    continue;

                if (!handlerDefinitions.Any(definition => SymbolEqualityComparer.Default.Equals(@interface.OriginalDefinition, definition)))
                    continue;

                if (!SymbolEqualityComparer.Default.Equals(@interface.TypeArguments[0], commandOrQueryType))
                    continue;

                // Found the handler: TResult if present, otherwise a void command (ICommandHandler<T>).
                return @interface.TypeArguments.Length == 2 ? @interface.TypeArguments[1] : null;
            }
        }

        return null; // No handler found
    }

    /// <summary>
    /// Returns true for classes/structs that are candidates for their own proto message: anything whose
    /// fully-qualified display name does not reference a type from the <c>System</c> namespace tree.
    /// </summary>
    private bool IsComplexType(ITypeSymbol type)
    {
        // Check if it's a user-defined class/struct (not a primitive or system type)
        if (type.TypeKind != TypeKind.Class && type.TypeKind != TypeKind.Struct)
            return false;

        // Anchoring on "global::" avoids rejecting user namespaces such as "Acme.FileSystem.Models".
        var fullName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
        return !fullName.Contains("global::System.");
    }

    /// <summary>
    /// Returns true when <paramref name="ns"/> is <c>System</c> or one of its child namespaces.
    /// </summary>
    private static bool IsSystemNamespace(INamespaceSymbol? ns)
    {
        if (ns == null || ns.IsGlobalNamespace)
            return false;

        var name = ns.ToDisplayString();
        return name == "System" || name.StartsWith("System.", System.StringComparison.Ordinal);
    }

    /// <summary>
    /// Extracts the XML-doc &lt;summary&gt; of <paramref name="type"/> as a single line suitable for a
    /// <c>//</c> proto comment, or "<c>{Name} operation</c>" when there is no usable summary.
    /// </summary>
    private string GetXmlDocSummary(INamedTypeSymbol type)
    {
        var fallback = $"{type.Name} operation";

        var xml = type.GetDocumentationCommentXml();
        if (string.IsNullOrEmpty(xml))
            return fallback;

        const string openTag = "<summary>";
        const string closeTag = "</summary>";

        // xml is guaranteed non-null after IsNullOrEmpty check above
        var summaryStart = xml!.IndexOf(openTag, System.StringComparison.Ordinal);
        var summaryEnd = xml.IndexOf(closeTag, System.StringComparison.Ordinal);
        if (summaryStart < 0 || summaryEnd <= summaryStart)
            return fallback;

        var contentStart = summaryStart + openTag.Length;
        var summary = xml.Substring(contentStart, summaryEnd - contentStart);

        // Multi-line summaries must be collapsed: the caller only prefixes the first line with "//",
        // so any further line would otherwise end up as invalid proto syntax.
        var singleLine = string.Join(" ", summary.Split(new char[0], System.StringSplitOptions.RemoveEmptyEntries));
        return singleLine.Length > 0 ? singleLine : fallback;
    }

    /// <summary>
    /// Finds entity types exposed through non-abstract <c>IQueryableProvider&lt;TEntity&gt;</c>
    /// implementations, in discovery order, without duplicates.
    /// </summary>
    private List<INamedTypeSymbol> DiscoverDynamicQueries()
    {
        // Find IQueryableProvider<T> implementations
        var queryableProviderInterface = _compilation.GetTypeByMetadataName("Svrnty.CQRS.DynamicQuery.Abstractions.IQueryableProvider`1");
        if (queryableProviderInterface == null)
            return new List<INamedTypeSymbol>();

        var dynamicQueryTypes = new List<INamedTypeSymbol>();
        var seen = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);

        foreach (var type in GetAllTypes())
        {
            if (type.IsAbstract || type.IsStatic)
                continue;

            foreach (var iface in type.AllInterfaces)
            {
                if (!iface.IsGenericType || !SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, queryableProviderInterface))
                    continue;

                // Extract the entity type from IQueryableProvider<TEntity>
                if (iface.TypeArguments[0] is INamedTypeSymbol entityType && seen.Add(entityType))
                {
                    dynamicQueryTypes.Add(entityType);
                }
            }
        }

        return dynamicQueryTypes;
    }

    /// <summary>
    /// Emits the request/response messages for one dynamic-query entity, the shared
    /// filter/sort/group/aggregate messages (once per file), and the entity message itself.
    /// </summary>
    private void GenerateDynamicQueryMessages(INamedTypeSymbol entityType)
    {
        var pluralName = Pluralize(entityType.Name);
        var requestMessageName = $"DynamicQuery{pluralName}Request";
        var responseMessageName = $"DynamicQuery{pluralName}Response";

        // Common filter/sort/group/aggregate types (only generate once)
        if (_generatedMessages.Add("DynamicQueryFilter"))
        {
            _messagesBuilder.AppendLine("// Dynamic query filter with AND/OR support");
            _messagesBuilder.AppendLine("message DynamicQueryFilter {");
            _messagesBuilder.AppendLine(" string path = 1;");
            _messagesBuilder.AppendLine(" int32 type = 2; // PoweredSoft.DynamicQuery.Core.FilterType");
            _messagesBuilder.AppendLine(" string value = 3;");
            _messagesBuilder.AppendLine(" repeated DynamicQueryFilter and = 4;");
            _messagesBuilder.AppendLine(" repeated DynamicQueryFilter or = 5;");
            _messagesBuilder.AppendLine("}");
            _messagesBuilder.AppendLine();

            _messagesBuilder.AppendLine("// Dynamic query sort");
            _messagesBuilder.AppendLine("message DynamicQuerySort {");
            _messagesBuilder.AppendLine(" string path = 1;");
            _messagesBuilder.AppendLine(" bool ascending = 2;");
            _messagesBuilder.AppendLine("}");
            _messagesBuilder.AppendLine();

            _messagesBuilder.AppendLine("// Dynamic query group");
            _messagesBuilder.AppendLine("message DynamicQueryGroup {");
            _messagesBuilder.AppendLine(" string path = 1;");
            _messagesBuilder.AppendLine("}");
            _messagesBuilder.AppendLine();

            _messagesBuilder.AppendLine("// Dynamic query aggregate");
            _messagesBuilder.AppendLine("message DynamicQueryAggregate {");
            _messagesBuilder.AppendLine(" string path = 1;");
            _messagesBuilder.AppendLine(" int32 type = 2; // PoweredSoft.DynamicQuery.Core.AggregateType");
            _messagesBuilder.AppendLine("}");
            _messagesBuilder.AppendLine();
        }

        // Request message
        _messagesBuilder.AppendLine($"// Dynamic query request for {entityType.Name}");
        _messagesBuilder.AppendLine($"message {requestMessageName} {{");
        _messagesBuilder.AppendLine(" int32 page = 1;");
        _messagesBuilder.AppendLine(" int32 page_size = 2;");
        _messagesBuilder.AppendLine(" repeated DynamicQueryFilter filters = 3;");
        _messagesBuilder.AppendLine(" repeated DynamicQuerySort sorts = 4;");
        _messagesBuilder.AppendLine(" repeated DynamicQueryGroup groups = 5;");
        _messagesBuilder.AppendLine(" repeated DynamicQueryAggregate aggregates = 6;");
        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        // Response message
        _messagesBuilder.AppendLine($"// Dynamic query response for {entityType.Name}");
        _messagesBuilder.AppendLine($"message {responseMessageName} {{");
        _messagesBuilder.AppendLine($" repeated {entityType.Name} data = 1;");
        _messagesBuilder.AppendLine(" int64 total_records = 2;");
        _messagesBuilder.AppendLine(" int32 number_of_pages = 3;");
        _messagesBuilder.AppendLine("}");
        _messagesBuilder.AppendLine();

        // Generate entity message if not already generated
        GenerateComplexTypeMessage(entityType);
    }

    /// <summary>
    /// Naive English pluralization used to build dynamic-query RPC and message names.
    /// Kept byte-for-byte in behavior so generated names stay stable.
    /// </summary>
    private static string Pluralize(string word)
    {
        if (word.EndsWith("y") && word.Length > 1 && !"aeiou".Contains(word[word.Length - 2].ToString()))
            return word.Substring(0, word.Length - 1) + "ies";
        if (word.EndsWith("s") || word.EndsWith("x") || word.EndsWith("z") || word.EndsWith("ch") || word.EndsWith("sh"))
            return word + "es";
        return word + "s";
    }

    /// <summary>
    /// Tracks an enum type for later generation (de-duplicated by simple name).
    /// </summary>
    private void TrackEnumType(INamedTypeSymbol enumType)
    {
        if (!_generatedEnums.Contains(enumType.Name) && !_pendingEnums.Any(e => e.Name == enumType.Name))
        {
            _pendingEnums.Add(enumType);
        }
    }

    /// <summary>
    /// Generates all pending enum definitions. Value names are prefixed with the enum name in
    /// UPPER_SNAKE_CASE because proto enum values share the package scope.
    /// </summary>
    private void GeneratePendingEnums()
    {
        foreach (var enumType in _pendingEnums)
        {
            if (!_generatedEnums.Add(enumType.Name))
                continue;

            var prefix = ProtoFileTypeMapper.ToSnakeCase(enumType.Name).ToUpperInvariant();

            // Name -> formatted constant value, in declaration order.
            var entries = enumType.GetMembers()
                .OfType<IFieldSymbol>()
                .Where(f => f.HasConstantValue)
                .Select(f => new KeyValuePair<string, string>(
                    $"{prefix}_{ProtoFileTypeMapper.ToSnakeCase(f.Name).ToUpperInvariant()}",
                    FormatEnumValue(f.ConstantValue)))
                .ToList();

            // proto3 requires the first enumerator to be zero: move the zero-valued member to the
            // front, or synthesize an UNSPECIFIED entry when the C# enum has no zero value.
            var zeroIndex = entries.FindIndex(e => e.Value == "0");
            if (zeroIndex > 0)
            {
                var zeroEntry = entries[zeroIndex];
                entries.RemoveAt(zeroIndex);
                entries.Insert(0, zeroEntry);
            }
            else if (zeroIndex < 0)
            {
                var unspecifiedName = $"{prefix}_UNSPECIFIED";
                while (entries.Any(e => e.Key == unspecifiedName))
                {
                    unspecifiedName += "_";
                }
                entries.Insert(0, new KeyValuePair<string, string>(unspecifiedName, "0"));
            }

            _enumsBuilder.AppendLine($"// {enumType.Name} enum");
            _enumsBuilder.AppendLine($"enum {enumType.Name} {{");

            // protoc rejects duplicate numbers unless aliasing is enabled, and also rejects
            // allow_alias when no values are shared, so emit it only when needed.
            if (entries.Select(e => e.Value).Distinct().Count() < entries.Count)
            {
                _enumsBuilder.AppendLine(" option allow_alias = true;");
            }

            foreach (var entry in entries)
            {
                _enumsBuilder.AppendLine($" {entry.Key} = {entry.Value};");
            }

            _enumsBuilder.AppendLine("}");
            _enumsBuilder.AppendLine();
        }
    }

    /// <summary>
    /// Formats a boxed enum constant with the invariant culture (proto numbers must not use
    /// culture-specific negative signs).
    /// </summary>
    private static string FormatEnumValue(object? value)
    {
        return System.Convert.ToString(value, System.Globalization.CultureInfo.InvariantCulture) ?? string.Empty;
    }

    /// <summary>
    /// Discovers types marked with [StreamingNotification] attribute whose SubscriptionKey names
    /// one of their supported public properties.
    /// </summary>
    private List<NotificationInfo> DiscoverNotifications()
    {
        var streamingNotificationAttribute = _compilation.GetTypeByMetadataName(
            "Svrnty.CQRS.Notifications.Abstractions.StreamingNotificationAttribute");

        if (streamingNotificationAttribute == null)
            return new List<NotificationInfo>();

        var notifications = new List<NotificationInfo>();

        foreach (var type in GetAllTypes())
        {
            if (type.IsAbstract || type.IsStatic)
                continue;

            var attr = type.GetAttributes()
                .FirstOrDefault(a => SymbolEqualityComparer.Default.Equals(
                    a.AttributeClass, streamingNotificationAttribute));

            if (attr == null)
                continue;

            // Extract SubscriptionKey from attribute
            var subscriptionKeyArg = attr.NamedArguments
                .FirstOrDefault(a => a.Key == "SubscriptionKey");
            var subscriptionKeyProp = subscriptionKeyArg.Value.Value as string;

            if (string.IsNullOrEmpty(subscriptionKeyProp))
                continue;

            // Get all properties of the notification type
            var properties = ExtractNotificationProperties(type);

            // Find the subscription key property info
            var keyPropInfo = properties.FirstOrDefault(p => p.Name == subscriptionKeyProp);
            if (keyPropInfo == null)
                continue;

            var info = new NotificationInfo
            {
                Name = type.Name,
                FullyQualifiedName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
                    .Replace("global::", ""),
                Namespace = type.ContainingNamespace?.ToDisplayString() ?? "",
                SubscriptionKeyProperty = subscriptionKeyProp,
                SubscriptionKeyInfo = keyPropInfo,
                Properties = properties
            };

            notifications.Add(info);

            // Remember the symbol so message generation can inspect real property types.
            _notificationSymbols[info.FullyQualifiedName] = type;
        }

        return notifications;
    }

    /// <summary>
    /// Extracts property information from a notification type. Field numbers start at 1 and are
    /// assigned only to supported properties, in declaration order.
    /// </summary>
    private List<Models.PropertyInfo> ExtractNotificationProperties(INamedTypeSymbol type)
    {
        var properties = new List<Models.PropertyInfo>();
        int fieldNumber = 1;

        foreach (var prop in GetPublicProperties(type))
        {
            if (ProtoFileTypeMapper.IsUnsupportedType(prop.Type))
                continue;

            var protoType = ProtoFileTypeMapper.MapType(prop.Type, out _, out _);
            var enumType = ProtoFileTypeMapper.GetEnumType(prop.Type);

            properties.Add(new Models.PropertyInfo
            {
                Name = prop.Name,
                Type = prop.Type.Name,
                FullyQualifiedType = prop.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
                    .Replace("global::", ""),
                ProtoType = protoType,
                FieldNumber = fieldNumber++,
                IsEnum = enumType != null,
                IsDecimal = prop.Type.SpecialType == SpecialType.System_Decimal ||
                            prop.Type.ToDisplayString().Contains("decimal"),
                IsDateTime = prop.Type.ToDisplayString().Contains("DateTime"),
                IsNullable = prop.Type.NullableAnnotation == NullableAnnotation.Annotated ||
                             (prop.Type is INamedTypeSymbol namedType &&
                              namedType.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T)
            });

            if (enumType != null)
            {
                TrackEnumType(enumType);
            }
        }

        return properties;
    }

    /// <summary>
    /// Generates proto messages for a notification type: the <c>SubscribeTo{Name}Request</c> message
    /// (subscription key only), the notification message itself, and any user-defined types its
    /// fields reference.
    /// </summary>
    private void GenerateNotificationMessages(NotificationInfo notification)
    {
        _notificationSymbols.TryGetValue(notification.FullyQualifiedName, out var notificationSymbol);
        var nestedComplexTypes = new List<INamedTypeSymbol>();

        // Generate subscription request message (contains only the subscription key)
        var requestMessageName = $"SubscribeTo{notification.Name}Request";
        if (_generatedMessages.Add(requestMessageName))
        {
            TrackNotificationPropertyType(notificationSymbol, notification.SubscriptionKeyProperty, nestedComplexTypes);

            _messagesBuilder.AppendLine($"// Subscription request for {notification.Name}");
            _messagesBuilder.AppendLine($"message {requestMessageName} {{");
            _messagesBuilder.AppendLine($" {notification.SubscriptionKeyInfo.ProtoType} {ProtoFileTypeMapper.ToSnakeCase(notification.SubscriptionKeyProperty)} = 1;");
            _messagesBuilder.AppendLine("}");
            _messagesBuilder.AppendLine();
        }

        // Generate the notification message itself
        if (_generatedMessages.Add(notification.Name))
        {
            _messagesBuilder.AppendLine($"// {notification.Name} streaming notification");
            _messagesBuilder.AppendLine($"message {notification.Name} {{");

            foreach (var prop in notification.Properties)
            {
                TrackNotificationPropertyType(notificationSymbol, prop.Name, nestedComplexTypes);

                var fieldName = ProtoFileTypeMapper.ToSnakeCase(prop.Name);
                _messagesBuilder.AppendLine($" {prop.ProtoType} {fieldName} = {prop.FieldNumber};");
            }

            _messagesBuilder.AppendLine("}");
            _messagesBuilder.AppendLine();
        }

        GenerateNestedMessages(nestedComplexTypes);
    }

    /// <summary>
    /// Records imports, enums and nested message types required by the property named
    /// <paramref name="propertyName"/> on <paramref name="notificationType"/>. Does nothing when
    /// the symbol or property cannot be found.
    /// </summary>
    private void TrackNotificationPropertyType(INamedTypeSymbol? notificationType, string propertyName, List<INamedTypeSymbol> nestedComplexTypes)
    {
        var property = notificationType?.GetMembers(propertyName).OfType<IPropertySymbol>().FirstOrDefault();
        if (property != null)
        {
            MapAndTrackType(property.Type, nestedComplexTypes);
        }
    }
}