From d598bfb63f4e8ab441c14fa6693a8c12929d57c4 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Fri, 13 Dec 2024 09:42:20 +0100 Subject: [PATCH] Translate to NULLIF Closes #31682 --- EFCore.sln.DotSettings | 2 + .../Query/SqlExpressionFactory.cs | 52 ++++++++++++++++ .../MiscellaneousTranslationsCosmosTest.cs | 60 +++++++++++++++++++ .../MiscellaneousTranslationsTestBase.cs | 34 +++++++++++ .../MiscellaneousTranslationsSqlServerTest.cs | 52 ++++++++++++++++ .../MiscellaneousTranslationsSqliteTest.cs | 52 ++++++++++++++++ 6 files changed, 252 insertions(+) diff --git a/EFCore.sln.DotSettings b/EFCore.sln.DotSettings index b8b7f701254..1203f411873 100644 --- a/EFCore.sln.DotSettings +++ b/EFCore.sln.DotSettings @@ -356,7 +356,9 @@ The .NET Foundation licenses this file to you under the MIT license. True True True + True True + True True True True diff --git a/src/EFCore.Relational/Query/SqlExpressionFactory.cs b/src/EFCore.Relational/Query/SqlExpressionFactory.cs index 064e97cbe74..c5ef316c269 100644 --- a/src/EFCore.Relational/Query/SqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/SqlExpressionFactory.cs @@ -825,6 +825,31 @@ public virtual SqlExpression Case( elseResult = lastCase.ElseResult; } + // Optimize: + // a == b ? null : a -> NULLIF(a, b) + // a != b ? a : null -> NULLIF(a, b) + if (operand is null + && typeMappedWhenClauses is + [ + { + Test: SqlBinaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual } binary, + Result: var result + } + ]) + { + switch (binary.OperatorType) + { + case ExpressionType.Equal + when result is SqlConstantExpression { Value: null } + && elseResult is not null + && TryTranslateToNullIf(elseResult, out var nullIfTranslation): + case ExpressionType.NotEqual + when elseResult is null or SqlConstantExpression { Value: null } + && TryTranslateToNullIf(result, out nullIfTranslation): + return nullIfTranslation; + } + } + return existingExpression is CaseExpression expr && operand == expr.Operand && typeMappedWhenClauses.SequenceEqual(expr.WhenClauses) @@ -837,6 +862,33 @@ bool IsSkipped(CaseWhenClause clause) bool IsMatched(CaseWhenClause clause) => operand is null && clause.Test is SqlConstantExpression { Value: true }; + + bool TryTranslateToNullIf(SqlExpression conditionalResult, [NotNullWhen(true)] out SqlExpression? nullIfTranslation) + { + var (left, right) = (binary.Left, binary.Right); + + // If one of sides of the equality is equal to the result of the conditional - a == b ? null : a - convert to + // NULLIF(a, b). + // Specifically refrain from doing so for when the other side is a null constant, as that would transform a == null ? null : a + // to NULLIF(a, NULL), which we don't want. + + if (left.Equals(conditionalResult) && right is not SqlConstantExpression { Value: null }) + { + nullIfTranslation = Function( + "NULLIF", [left, right], true, [false, false], left.Type, left.TypeMapping); + return true; + } + + if (right.Equals(conditionalResult) && left is not SqlConstantExpression { Value: null }) + { + nullIfTranslation = Function( + "NULLIF", [right, left], true, [false, false], right.Type, right.TypeMapping); + return true; + } + + nullIfTranslation = null; + return false; + } } /// diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/Translations/MiscellaneousTranslationsCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/Translations/MiscellaneousTranslationsCosmosTest.cs index 5efbf266cda..7ec7edfa2a1 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/Translations/MiscellaneousTranslationsCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/Translations/MiscellaneousTranslationsCosmosTest.cs @@ -175,6 +175,66 @@ public override async Task TimeSpan_Compare_to_simple_zero(bool async, bool comp #endregion Compare + #region Uncoalescing conditional / NullIf + + public override Task Uncoalescing_conditional_with_equality_left(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Uncoalescing_conditional_with_equality_left(a); + + AssertSql( + """ +SELECT VALUE c +FROM root c +WHERE (((c["Int"] = 9) ? null : c["Int"]) > 1) +"""); + }); + + public override Task Uncoalescing_conditional_with_equality_right(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Uncoalescing_conditional_with_equality_right(a); + + AssertSql( + """ +SELECT VALUE c +FROM root c +WHERE (((9 = c["Int"]) ? null : c["Int"]) > 1) +"""); + }); + + public override Task Uncoalescing_conditional_with_unequality_left(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Uncoalescing_conditional_with_unequality_left(a); + + AssertSql( + """ +SELECT VALUE c +FROM root c +WHERE (((c["Int"] != 9) ? c["Int"] : null) > 1) +"""); + }); + + public override Task Uncoalescing_conditional_with_inequality_right(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Uncoalescing_conditional_with_inequality_right(a); + + AssertSql( + """ +SELECT VALUE c +FROM root c +WHERE (((9 != c["Int"]) ? c["Int"] : null) > 1) +"""); + }); + + #endregion Uncoalescing conditional / NullIf + [ConditionalFact] public virtual void Check_all_tests_overridden() => TestHelpers.AssertAllMethodsOverridden(GetType()); diff --git a/test/EFCore.Specification.Tests/Query/Translations/MiscellaneousTranslationsTestBase.cs b/test/EFCore.Specification.Tests/Query/Translations/MiscellaneousTranslationsTestBase.cs index d3b98d701f3..f27110d747a 100644 --- a/test/EFCore.Specification.Tests/Query/Translations/MiscellaneousTranslationsTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/Translations/MiscellaneousTranslationsTestBase.cs @@ -429,4 +429,38 @@ await AssertQuery( } #endregion + + #region Uncoalescing conditional + + // In relational providers, x == a ? null : x is translated to SQL NULLIF + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Uncoalescing_conditional_with_equality_left(bool async) + => AssertQuery( + async, + cs => cs.Set().Where(x => (x.Int == 9 ? null : x.Int) > 1)); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Uncoalescing_conditional_with_equality_right(bool async) + => AssertQuery( + async, + cs => cs.Set().Where(x => (9 == x.Int ? null : x.Int) > 1)); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Uncoalescing_conditional_with_unequality_left(bool async) + => AssertQuery( + async, + cs => cs.Set().Where(x => (x.Int != 9 ? x.Int : null) > 1)); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Uncoalescing_conditional_with_inequality_right(bool async) + => AssertQuery( + async, + cs => cs.Set().Where(x => (9 != x.Int ? x.Int : null) > 1)); + + #endregion Uncoalescing conditional } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/Translations/MiscellaneousTranslationsSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/Translations/MiscellaneousTranslationsSqlServerTest.cs index aff4ff2712d..9bfc4726093 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/Translations/MiscellaneousTranslationsSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/Translations/MiscellaneousTranslationsSqlServerTest.cs @@ -803,6 +803,58 @@ FROM [BasicTypesEntities] AS [b] #endregion Compare + #region Uncoalescing conditional / NullIf + + public override async Task Uncoalescing_conditional_with_equality_left(bool async) + { + await base.Uncoalescing_conditional_with_equality_left(async); + + AssertSql( + """ +SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan] +FROM [BasicTypesEntities] AS [b] +WHERE NULLIF([b].[Int], 9) > 1 +"""); + } + + public override async Task Uncoalescing_conditional_with_equality_right(bool async) + { + await base.Uncoalescing_conditional_with_equality_right(async); + + AssertSql( + """ +SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan] +FROM [BasicTypesEntities] AS [b] +WHERE NULLIF([b].[Int], 9) > 1 +"""); + } + + public override async Task Uncoalescing_conditional_with_unequality_left(bool async) + { + await base.Uncoalescing_conditional_with_unequality_left(async); + + AssertSql( + """ +SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan] +FROM [BasicTypesEntities] AS [b] +WHERE NULLIF([b].[Int], 9) > 1 +"""); + } + + public override async Task Uncoalescing_conditional_with_inequality_right(bool async) + { + await base.Uncoalescing_conditional_with_inequality_right(async); + + AssertSql( + """ +SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan] +FROM [BasicTypesEntities] AS [b] +WHERE NULLIF([b].[Int], 9) > 1 +"""); + } + + #endregion Uncoalescing conditional / NullIf + [ConditionalFact] public virtual void Check_all_tests_overridden() => TestHelpers.AssertAllMethodsOverridden(GetType()); diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/Translations/MiscellaneousTranslationsSqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/Translations/MiscellaneousTranslationsSqliteTest.cs index 87b01cf25e1..9655f550e13 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/Translations/MiscellaneousTranslationsSqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/Translations/MiscellaneousTranslationsSqliteTest.cs @@ -274,6 +274,58 @@ public override async Task TimeSpan_Compare_to_simple_zero(bool async, bool comp #endregion Compare + #region Uncoalescing conditional / NullIf + + public override async Task Uncoalescing_conditional_with_equality_left(bool async) + { + await base.Uncoalescing_conditional_with_equality_left(async); + + AssertSql( + """ +SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan" +FROM "BasicTypesEntities" AS "b" +WHERE NULLIF("b"."Int", 9) > 1 +"""); + } + + public override async Task Uncoalescing_conditional_with_equality_right(bool async) + { + await base.Uncoalescing_conditional_with_equality_right(async); + + AssertSql( + """ +SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan" +FROM "BasicTypesEntities" AS "b" +WHERE NULLIF("b"."Int", 9) > 1 +"""); + } + + public override async Task Uncoalescing_conditional_with_unequality_left(bool async) + { + await base.Uncoalescing_conditional_with_unequality_left(async); + + AssertSql( + """ +SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan" +FROM "BasicTypesEntities" AS "b" +WHERE NULLIF("b"."Int", 9) > 1 +"""); + } + + public override async Task Uncoalescing_conditional_with_inequality_right(bool async) + { + await base.Uncoalescing_conditional_with_inequality_right(async); + + AssertSql( + """ +SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan" +FROM "BasicTypesEntities" AS "b" +WHERE NULLIF("b"."Int", 9) > 1 +"""); + } + + #endregion Uncoalescing conditional / NullIf + [ConditionalFact] public virtual void Check_all_tests_overridden() => TestHelpers.AssertAllMethodsOverridden(GetType());