Skip to content

Commit

Permalink
Add support for complex type discriminator.
Browse files Browse the repository at this point in the history
Breaking change: IDiscriminatorPropertySetConvention.ProcessDiscriminatorPropertySet signature changed

Part of #31376
  • Loading branch information
AndriySvyryd committed Mar 7, 2025
1 parent 34ee81a commit 459741f
Show file tree
Hide file tree
Showing 67 changed files with 2,107 additions and 762 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ private static void ProcessEntityType(IConventionEntityTypeBuilder entityTypeBui

/// <inheritdoc />
public override void ProcessDiscriminatorPropertySet(
IConventionEntityTypeBuilder entityTypeBuilder,
IConventionTypeBaseBuilder structuralTypeBuilder,
string? name,
IConventionContext<string> context)
{
var entityType = entityTypeBuilder.Metadata;
if (entityType.IsDocumentRoot())
if (structuralTypeBuilder.Metadata is not IConventionEntityType entityType
|| entityType.IsDocumentRoot())
{
base.ProcessDiscriminatorPropertySet(entityTypeBuilder, name, context);
base.ProcessDiscriminatorPropertySet(structuralTypeBuilder, name, context);
}
}

Expand Down
21 changes: 10 additions & 11 deletions src/EFCore.Cosmos/Metadata/Conventions/CosmosJsonIdConvention.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,16 +174,10 @@ private void ProcessEntityType(IConventionEntityType entityType, IConventionCont
}

// Don't chain, because each of these could return null if the property has been explicitly configured with some other value.
computedIdPropertyBuilder = computedIdPropertyBuilder.ToJsonProperty(IdPropertyJsonName)
?? computedIdPropertyBuilder;

computedIdPropertyBuilder = computedIdPropertyBuilder.IsRequired(true)
?? computedIdPropertyBuilder;

computedIdPropertyBuilder = computedIdPropertyBuilder.HasValueGeneratorFactory(typeof(IdValueGeneratorFactory))
?? computedIdPropertyBuilder;

computedIdPropertyBuilder.ToJsonProperty(IdPropertyJsonName);
computedIdPropertyBuilder.HasValueGeneratorFactory(typeof(IdValueGeneratorFactory));
computedIdPropertyBuilder.AfterSave(PropertySaveBehavior.Throw);
computedIdPropertyBuilder.IsRequired(true);
}

/// <inheritdoc />
Expand Down Expand Up @@ -327,8 +321,13 @@ public virtual void ProcessModelAnnotationChanged(

/// <inheritdoc />
public virtual void ProcessDiscriminatorPropertySet(
IConventionEntityTypeBuilder entityTypeBuilder,
IConventionTypeBaseBuilder structuralTypeBuilder,
string? name,
IConventionContext<string?> context)
=> ProcessEntityType(entityTypeBuilder.Metadata, context);
{
if (structuralTypeBuilder is IConventionEntityTypeBuilder entityTypeBuilder)
{
ProcessEntityType(entityTypeBuilder.Metadata, context);
}
}
}
55 changes: 53 additions & 2 deletions src/EFCore.Design/Migrations/Design/CSharpSnapshotGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -628,9 +628,60 @@ protected virtual void GenerateComplexPropertyAnnotations(
IComplexProperty property,
IndentedStringBuilder stringBuilder)
{
var discriminatorProperty = property.ComplexType.FindDiscriminatorProperty();
if (discriminatorProperty != null)
{
stringBuilder
.AppendLine()
.Append(propertyBuilderName)
.Append('.')
.Append("HasDiscriminator");

if (discriminatorProperty.DeclaringType == property.ComplexType
&& discriminatorProperty.Name != "Discriminator")
{
var propertyClrType = FindValueConverter(discriminatorProperty)?.ProviderClrType
.MakeNullable(discriminatorProperty.IsNullable)
?? discriminatorProperty.ClrType;
stringBuilder
.Append('<')
.Append(Code.Reference(propertyClrType))
.Append(">(")
.Append(Code.Literal(discriminatorProperty.Name))
.Append(')');
}
else
{
stringBuilder
.Append("()");
}

var discriminatorValue = property.ComplexType.GetDiscriminatorValue();
if (discriminatorValue != null)
{
if (discriminatorProperty != null)
{
var valueConverter = FindValueConverter(discriminatorProperty);
if (valueConverter != null)
{
discriminatorValue = valueConverter.ConvertToProvider(discriminatorValue);
}
}

stringBuilder
.Append('.')
.Append("HasValue")
.Append('(')
.Append(Code.UnknownLiteral(discriminatorValue))
.Append(')');
}

stringBuilder.AppendLine(";");
}

var propertyAnnotations = Dependencies.AnnotationCodeGenerator
.FilterIgnoredAnnotations(property.GetAnnotations())
.ToDictionary(a => a.Name, a => a);
.FilterIgnoredAnnotations(property.GetAnnotations())
.ToDictionary(a => a.Name, a => a);

var typeAnnotations = Dependencies.AnnotationCodeGenerator
.FilterIgnoredAnnotations(property.ComplexType.GetAnnotations())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -925,14 +925,6 @@ private void Create(IEntityType entityType, CSharpRuntimeAnnotationCodeGenerator
.Append(_code.Literal(entityType.HasSharedClrType));
}

var discriminatorProperty = entityType.GetDiscriminatorPropertyName();
if (discriminatorProperty != null)
{
mainBuilder.AppendLine(",")
.Append("discriminatorProperty: ")
.Append(_code.Literal(discriminatorProperty));
}

var changeTrackingStrategy = entityType.GetChangeTrackingStrategy();
if (changeTrackingStrategy != ChangeTrackingStrategy.Snapshot)
{
Expand All @@ -959,6 +951,14 @@ private void Create(IEntityType entityType, CSharpRuntimeAnnotationCodeGenerator
.Append(_code.Literal(true));
}

var discriminatorProperty = entityType.GetDiscriminatorPropertyName();
if (discriminatorProperty != null)
{
mainBuilder.AppendLine(",")
.Append("discriminatorProperty: ")
.Append(_code.Literal(discriminatorProperty));
}

var discriminatorValue = entityType.GetDiscriminatorValue();
if (discriminatorValue != null)
{
Expand Down Expand Up @@ -2182,6 +2182,24 @@ private void CreateComplexProperty(
.Append(_code.Literal(true));
}

var discriminatorPropertyName = complexType.GetDiscriminatorPropertyName();
if (discriminatorPropertyName != null)
{
mainBuilder.AppendLine(",")
.Append("discriminatorProperty: ")
.Append(_code.Literal(discriminatorPropertyName));
}

var discriminatorValue = complexType.GetDiscriminatorValue();
if (discriminatorValue != null)
{
AddNamespace(discriminatorValue.GetType(), parameters.Namespaces);

mainBuilder.AppendLine(",")
.Append("discriminatorValue: ")
.Append(_code.UnknownLiteral(discriminatorValue));
}

mainBuilder.AppendLine(",")
.Append("propertyCount: ")
.Append(_code.Literal(complexType.GetDeclaredProperties().Count()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1957,6 +1957,11 @@ protected override void ValidateInheritanceMapping(
var discriminatorValues = new Dictionary<string, IEntityType>();
foreach (var derivedType in derivedTypes)
{
foreach (var complexProperty in derivedType.GetDeclaredComplexProperties())
{
ValidateDiscriminatorValues(complexProperty.ComplexType);
}

var discriminatorValue = derivedType.GetDiscriminatorValue();
if (!derivedType.ClrType.IsInstantiable()
|| discriminatorValue is null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public virtual void ProcessModelFinalizing(IConventionModelBuilder modelBuilder,
&& !discriminatorProperty.IsForeignKey())
{
var maxDiscriminatorValueLength =
entityType.GetDerivedTypesInclusive().Select(e => ((string)e.GetDiscriminatorValue()!).Length).Max();
entityType.GetDerivedTypesInclusive().Select(e => (e.GetDiscriminatorValue() as string)?.Length ?? 0).Max();

var previous = 1;
var current = 1;
Expand Down
70 changes: 68 additions & 2 deletions src/EFCore/Infrastructure/ModelValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -640,12 +640,16 @@ protected virtual void ValidateInheritanceMapping(
/// <param name="rootEntityType">The entity type to validate.</param>
protected virtual void ValidateDiscriminatorValues(IEntityType rootEntityType)
{
var derivedTypes = rootEntityType.GetDerivedTypesInclusive().ToList();
var derivedTypes = rootEntityType.GetDerivedTypesInclusive();
var discriminatorProperty = rootEntityType.FindDiscriminatorProperty();
if (discriminatorProperty == null)
{
if (derivedTypes.Count == 1)
if (!derivedTypes.Skip(1).Any())
{
foreach (var complexProperty in rootEntityType.GetDeclaredComplexProperties())
{
ValidateDiscriminatorValues(complexProperty.ComplexType);
}
return;
}

Expand All @@ -654,6 +658,68 @@ protected virtual void ValidateDiscriminatorValues(IEntityType rootEntityType)
}

var discriminatorValues = new Dictionary<object, IEntityType>(discriminatorProperty.GetKeyValueComparer());
foreach (var derivedType in derivedTypes)
{
foreach (var complexProperty in derivedType.GetDeclaredComplexProperties())
{
ValidateDiscriminatorValues(complexProperty.ComplexType);
}

if (!derivedType.ClrType.IsInstantiable())
{
continue;
}

var discriminatorValue = derivedType[CoreAnnotationNames.DiscriminatorValue];
if (discriminatorValue == null)
{
throw new InvalidOperationException(
CoreStrings.NoDiscriminatorValue(derivedType.DisplayName()));
}

if (!discriminatorProperty.ClrType.IsInstanceOfType(discriminatorValue))
{
throw new InvalidOperationException(
CoreStrings.DiscriminatorValueIncompatible(
discriminatorValue, derivedType.DisplayName(), discriminatorProperty.ClrType.DisplayName()));
}

if (discriminatorValues.TryGetValue(discriminatorValue, out var duplicateEntityType))
{
throw new InvalidOperationException(
CoreStrings.DuplicateDiscriminatorValue(
derivedType.DisplayName(), discriminatorValue, duplicateEntityType.DisplayName()));
}

discriminatorValues[discriminatorValue] = derivedType;
}
}

/// <summary>
/// Validates the discriminator and values for the given complex type and nested ones.
/// </summary>
/// <param name="complexType">The entity type to validate.</param>
protected virtual void ValidateDiscriminatorValues(IComplexType complexType)
{
foreach (var complexProperty in complexType.GetComplexProperties())
{
ValidateDiscriminatorValues(complexProperty.ComplexType);
}

var derivedTypes = complexType.GetDerivedTypesInclusive();
var discriminatorProperty = complexType.FindDiscriminatorProperty();
if (discriminatorProperty == null)
{
if (!derivedTypes.Skip(1).Any())
{
return;
}

throw new InvalidOperationException(
CoreStrings.NoDiscriminatorProperty(complexType.DisplayName()));
}

var discriminatorValues = new Dictionary<object, IComplexType>(discriminatorProperty.GetKeyValueComparer());

foreach (var derivedType in derivedTypes)
{
Expand Down
Loading

0 comments on commit 459741f

Please sign in to comment.