diff --git a/Dockerfile b/Dockerfile index 5dbaa35a17..537ea4c78c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,13 +1,13 @@ # Version values referenced from https://hub.docker.com/_/microsoft-dotnet-aspnet -FROM mcr.microsoft.com/dotnet/sdk:6.0-cbl-mariner2.0. AS build +FROM mcr.microsoft.com/dotnet/sdk:8.0 AS build WORKDIR /src COPY [".", "./"] -RUN dotnet build "./src/Service/Azure.DataApiBuilder.Service.csproj" -c Docker -o /out -r linux-x64 +RUN dotnet build "./src/Service/Azure.DataApiBuilder.Service.csproj" -f net8.0 -o /out -r linux-x64 --self-contained -FROM mcr.microsoft.com/dotnet/aspnet:6.0-cbl-mariner2.0 AS runtime +FROM mcr.microsoft.com/dotnet/aspnet:8.0 AS runtime COPY --from=build /out /App WORKDIR /App diff --git a/src/Config/ObjectModel/AuthenticationOptions.cs b/src/Config/ObjectModel/AuthenticationOptions.cs index 189540fbe6..6750d6e807 100644 --- a/src/Config/ObjectModel/AuthenticationOptions.cs +++ b/src/Config/ObjectModel/AuthenticationOptions.cs @@ -17,6 +17,7 @@ public record AuthenticationOptions(string Provider = nameof(EasyAuthType.Static public const string CLIENT_PRINCIPAL_HEADER = "X-MS-CLIENT-PRINCIPAL"; public const string NAME_CLAIM_TYPE = "name"; public const string ROLE_CLAIM_TYPE = "roles"; + public const string ORIGINAL_ROLE_CLAIM_TYPE = "original_roles"; /// /// Returns whether the configured Provider matches an diff --git a/src/Core/Authorization/AuthorizationResolver.cs b/src/Core/Authorization/AuthorizationResolver.cs index 64785de703..f368eed5f5 100644 --- a/src/Core/Authorization/AuthorizationResolver.cs +++ b/src/Core/Authorization/AuthorizationResolver.cs @@ -604,9 +604,14 @@ public static Dictionary> GetAllAuthenticatedUserClaims(Http // into a list and storing that in resolvedClaims using the claimType as the key. foreach (Claim claim in identity.Claims) { - // 'roles' claim has already been processed. + // 'roles' claim has already been processed. But we preserve the original 'roles' claim if (claim.Type.Equals(AuthenticationOptions.ROLE_CLAIM_TYPE)) { + if(!resolvedClaims.TryAdd(AuthenticationOptions.ORIGINAL_ROLE_CLAIM_TYPE, new List() { claim })) + { + resolvedClaims[AuthenticationOptions.ORIGINAL_ROLE_CLAIM_TYPE].Add(claim); + } + continue; } diff --git a/src/Core/Parsers/RequestParser.cs b/src/Core/Parsers/RequestParser.cs index 9a3a602329..86aed37d06 100644 --- a/src/Core/Parsers/RequestParser.cs +++ b/src/Core/Parsers/RequestParser.cs @@ -30,7 +30,7 @@ public class RequestParser /// /// Prefix used for specifying limit in the query string of the URL. /// - public const string FIRST_URL = "$first"; + public const string FIRST_URL = "$top"; /// /// Prefix used for specifying paging in the query string of the URL. /// diff --git a/src/Core/Resolvers/DWSqlQueryBuilder.cs b/src/Core/Resolvers/DWSqlQueryBuilder.cs index e1768a97df..780c477364 100644 --- a/src/Core/Resolvers/DWSqlQueryBuilder.cs +++ b/src/Core/Resolvers/DWSqlQueryBuilder.cs @@ -46,8 +46,21 @@ public string Build(SqlQueryStructure structure) /// private string BuildAsJson(SqlQueryStructure structure, bool subQueryStructure = false) { + string subQueryAlias = "CountQuery"; + + string countSql = $" CROSS JOIN ( {BuildSqlCountQuery(structure)} ) {subQueryAlias}"; + + //Add a new column to the structure + structure.Columns.Add(new LabelledColumn("", subQueryAlias, "RecordCount", "RecordCount", subQueryAlias)); + + //Add a subquery 'a' ti the structure + structure.JoinQueries.Add(subQueryAlias, structure); + string columns = GenerateColumnsAsJson(structure, subQueryStructure); - string fromSql = $"{BuildSqlQuery(structure)}"; + + structure.JoinQueries.Remove(subQueryAlias); + + string fromSql = $"{BuildSqlQuery(structure, countSql)}"; string query = $"SELECT {columns}" + $" FROM ({fromSql}) AS {QuoteIdentifier(structure.SourceAlias)}"; return query; @@ -64,7 +77,7 @@ private string BuildAsJson(SqlQueryStructure structure, bool subQueryStructure = /// FROM dbo_books AS[table0] /// OUTER APPLY(SubQuery generated by recursive call to build function, will create the _subq tables) /// - private string BuildSqlQuery(SqlQueryStructure structure) + private string BuildSqlQuery(SqlQueryStructure structure, string? subQuery) { string dataIdent = QuoteIdentifier(SqlQueryStructure.DATA_IDENT); StringBuilder fromSql = new(); @@ -87,11 +100,38 @@ private string BuildSqlQuery(SqlQueryStructure structure) string query = $"SELECT TOP {structure.Limit()} {columns}" + $" FROM {fromSql}" + + $" {subQuery}" + $" WHERE {predicates}" + orderBy; return query; } + private string BuildSqlCountQuery(SqlQueryStructure structure) + { + string dataIdent = QuoteIdentifier(SqlQueryStructure.DATA_IDENT); + StringBuilder fromSql = new(); + + fromSql.Append($"{QuoteIdentifier(structure.DatabaseObject.SchemaName)}.{QuoteIdentifier(structure.DatabaseObject.Name)} " + + $"AS {QuoteIdentifier($"{structure.SourceAlias}")}{Build(structure.Joins)}"); + + fromSql.Append(string.Join( + "", + structure.JoinQueries.Select( + x => $" OUTER APPLY ({BuildAsJson(x.Value, true)}) AS {QuoteIdentifier(x.Key)}({dataIdent})"))); + + string predicates = JoinPredicateStrings( + structure.GetDbPolicyForOperation(EntityActionOperation.Read), + structure.FilterPredicates, + Build(structure.Predicates), + Build(structure.PaginationMetadata.PaginationPredicate)); + + string query = $"SELECT cast(count(1) as varchar(50)) as RecordCount " + + $" FROM {fromSql}" + + $" WHERE {predicates}"; + + return query; + } + private static string GenerateColumnsAsJson(SqlQueryStructure structure, bool subQueryStructure = false) { string columns; diff --git a/src/Core/Resolvers/MsSqlQueryExecutor.cs b/src/Core/Resolvers/MsSqlQueryExecutor.cs index 96f82cfa25..1ac61f7dfb 100644 --- a/src/Core/Resolvers/MsSqlQueryExecutor.cs +++ b/src/Core/Resolvers/MsSqlQueryExecutor.cs @@ -219,7 +219,7 @@ public override string GetSessionParamsQuery(HttpContext? httpContext, IDictiona string paramName = $"{SESSION_PARAM_NAME}{counter.Next()}"; parameters.Add(paramName, new(claimValue)); // Append statement to set read only param value - can be set only once for a connection. - string statementToSetReadOnlyParam = "EXEC sp_set_session_context " + $"'{claimType}', " + paramName + ", @read_only = 1;"; + string statementToSetReadOnlyParam = "EXEC sp_set_session_context " + $"'{claimType}', " + paramName + ", @read_only = 0;"; sessionMapQuery = sessionMapQuery.Append(statementToSetReadOnlyParam); } diff --git a/src/Core/Resolvers/SqlResponseHelpers.cs b/src/Core/Resolvers/SqlResponseHelpers.cs index 7701d662d3..9171a3a322 100644 --- a/src/Core/Resolvers/SqlResponseHelpers.cs +++ b/src/Core/Resolvers/SqlResponseHelpers.cs @@ -51,6 +51,13 @@ public static OkObjectResult FormatFindResult( ? DetermineExtraFieldsInResponse(findOperationResponse, context.FieldsToBeReturned) : DetermineExtraFieldsInResponse(findOperationResponse.EnumerateArray().First(), context.FieldsToBeReturned); + //Remove RecordCOunt from extraFieldsInResponse if present + /* + if (extraFieldsInResponse.Contains("RecordCount")) + { + extraFieldsInResponse.Remove("RecordCount"); + } + */ uint defaultPageSize = runtimeConfig.DefaultPageSize(); uint maxPageSize = runtimeConfig.MaxPageSize(); @@ -113,6 +120,16 @@ public static OkObjectResult FormatFindResult( queryStringParameters: context!.ParsedQueryString, after); + //Get the element RecordCount from the first element of the array + JsonElement recordCountElement = rootEnumerated[0].GetProperty("RecordCount"); + string jsonRecordCount = JsonSerializer.Serialize(new[] + { + new + { + recordCount = @$"{rootEnumerated[0].GetProperty("RecordCount")}" + } + }); + // When there are extra fields present, they are removed before returning the response. if (extraFieldsInResponse.Count > 0) { @@ -120,6 +137,7 @@ public static OkObjectResult FormatFindResult( } rootEnumerated.Add(nextLink); + rootEnumerated.Add(JsonSerializer.Deserialize(jsonRecordCount)); return OkResponse(JsonSerializer.SerializeToElement(rootEnumerated)); } @@ -218,13 +236,16 @@ public static OkObjectResult OkResponse(JsonElement jsonResult) // we strip the "[" and "]" and then save the nextLink element // into a dictionary with a key of "nextLink" and a value that // represents the nextLink data we require. - string nextLinkJsonString = JsonSerializer.Serialize(resultEnumerated[resultEnumerated.Count - 1]); + string nextLinkJsonString = JsonSerializer.Serialize(resultEnumerated[resultEnumerated.Count - 2]); + string recordCountJsonString = JsonSerializer.Serialize(resultEnumerated[resultEnumerated.Count - 1]); Dictionary nextLink = JsonSerializer.Deserialize>(nextLinkJsonString[1..^1])!; - IEnumerable value = resultEnumerated.Take(resultEnumerated.Count - 1); + Dictionary recordCount = JsonSerializer.Deserialize>(recordCountJsonString[1..^1])!; + IEnumerable value = resultEnumerated.Take(resultEnumerated.Count - 2); return new OkObjectResult(new { value = value, - @nextLink = nextLink["nextLink"] + @nextLink = nextLink["nextLink"], + @recordCount = recordCount["recordCount"] }); } diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index b33dd9cc9e..1c758176ae 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -20,7 +20,7 @@ - + diff --git a/src/Service.Tests/Authorization/AuthorizationResolverUnitTests.cs b/src/Service.Tests/Authorization/AuthorizationResolverUnitTests.cs index 3c7c31a8ca..733ec15b24 100644 --- a/src/Service.Tests/Authorization/AuthorizationResolverUnitTests.cs +++ b/src/Service.Tests/Authorization/AuthorizationResolverUnitTests.cs @@ -1293,7 +1293,8 @@ public void UniqueClaimsResolvedForDbPolicy_SessionCtx_Usage() new("sub", "Aa_0RISCzzZ-abC1De2fGHIjKLMNo123pQ4rStUVWXY"), new("oid", "55296aad-ea7f-4c44-9a4c-bb1e8d43a005"), new(AuthenticationOptions.ROLE_CLAIM_TYPE, TEST_ROLE), - new(AuthenticationOptions.ROLE_CLAIM_TYPE, "ROLE2") + new(AuthenticationOptions.ROLE_CLAIM_TYPE, "ROLE2"), + new(AuthenticationOptions.ROLE_CLAIM_TYPE, "ROLE3") }; //Add identity object to the Mock context object. @@ -1315,6 +1316,7 @@ public void UniqueClaimsResolvedForDbPolicy_SessionCtx_Usage() Assert.AreEqual(expected: "Aa_0RISCzzZ-abC1De2fGHIjKLMNo123pQ4rStUVWXY", actual: claimsInRequestContext["sub"], message: "Expected the sub claim to be present."); Assert.AreEqual(expected: "55296aad-ea7f-4c44-9a4c-bb1e8d43a005", actual: claimsInRequestContext["oid"], message: "Expected the oid claim to be present."); Assert.AreEqual(claimsInRequestContext[AuthenticationOptions.ROLE_CLAIM_TYPE], actual: TEST_ROLE, message: "The roles claim should have the value:" + TEST_ROLE); + Assert.AreEqual(expected: "[\"" + TEST_ROLE + "\",\"ROLE2\",\"ROLE3\"]", actual: claimsInRequestContext[AuthenticationOptions.ORIGINAL_ROLE_CLAIM_TYPE], message: "Original roles should be preserved in a new context"); } /// @@ -1365,7 +1367,7 @@ public void ValidateUnauthenticatedUserClaimsAreNotResolvedWhenProcessingUserCla Dictionary resolvedClaims = AuthorizationResolver.GetProcessedUserClaims(context.Object); // Assert - Assert.AreEqual(expected: authenticatedUserclaims.Count, actual: resolvedClaims.Count, message: "Only two claims should be present."); + Assert.AreEqual(expected: authenticatedUserclaims.Count + 1, actual: resolvedClaims.Count, message: "Only " + (authenticatedUserclaims.Count + 1) + " claims should be present."); Assert.AreEqual(expected: "openid", actual: resolvedClaims["scp"], message: "Unexpected scp claim returned."); bool didResolveUnauthenticatedRoleClaim = resolvedClaims[AuthenticationOptions.ROLE_CLAIM_TYPE] == "Don't_Parse_This_Role";