Skip to content

Commit

Permalink
Support google.api.HttpBody in gRPC/JSON transcoding (#5400)
Browse files Browse the repository at this point in the history
Motivation:

#5311 

Trying to make this work.

Modifications:

- Parse media type and content dynamically if method descriptor is a
`google.api.HttpBody`.

Result:

- Closes #5311

Need to do some testing, I couldn't fully understand how to do this. The
way I am doing it is a bit tricky bit I guess it should work

---------

Co-authored-by: jrhee17 <[email protected]>
  • Loading branch information
Dogacel and jrhee17 authored Jul 25, 2024
1 parent cc400f3 commit c7b2a44
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,7 @@ protected void frameAndServe(
RequestHeaders grpcHeaders,
HttpData content,
CompletableFuture<HttpResponse> res,
@Nullable Function<HttpData, HttpData> responseBodyConverter,
MediaType responseContentType) {
@Nullable Function<AggregatedHttpResponse, AggregatedHttpResponse> responseConverter) {
final HttpRequest grpcRequest;
ctx.setAttr(IS_UNFRAMED_GRPC, true);
try (ArmeriaMessageFramer framer = new ArmeriaMessageFramer(
Expand Down Expand Up @@ -177,7 +176,7 @@ protected void frameAndServe(
res.completeExceptionally(t);
} else {
deframeAndRespond(ctx, framedResponse, res, unframedGrpcErrorHandler,
responseBodyConverter, responseContentType);
responseConverter);
}
}
return null;
Expand All @@ -189,8 +188,8 @@ static void deframeAndRespond(ServiceRequestContext ctx,
AggregatedHttpResponse grpcResponse,
CompletableFuture<HttpResponse> res,
UnframedGrpcErrorHandler unframedGrpcErrorHandler,
@Nullable Function<HttpData, HttpData> responseBodyConverter,
MediaType responseContentType) {
@Nullable
Function<AggregatedHttpResponse, AggregatedHttpResponse> responseConverter) {
final HttpHeaders trailers = !grpcResponse.trailers().isEmpty() ?
grpcResponse.trailers() : grpcResponse.headers();
final String grpcStatusCode = trailers.get(GrpcHeaderNames.GRPC_STATUS);
Expand Down Expand Up @@ -226,19 +225,19 @@ static void deframeAndRespond(ServiceRequestContext ctx,

final ResponseHeadersBuilder unframedHeaders = grpcResponse.headers().toBuilder();
unframedHeaders.set(GrpcHeaderNames.GRPC_STATUS, grpcStatusCode); // grpcStatusCode is 0 which is OK.
unframedHeaders.contentType(responseContentType);

final ArmeriaMessageDeframer deframer = new ArmeriaMessageDeframer(
// Max outbound message size is handled by the GrpcService, so we don't need to set it here.
Integer.MAX_VALUE);
final Subscriber<DeframedMessage> subscriber = singleSubscriber(
unframedHeaders, res, responseConverter);
grpcResponse.toHttpResponse().decode(deframer, ctx.alloc())
.subscribe(singleSubscriber(unframedHeaders, res, responseBodyConverter), ctx.eventLoop(),
SubscriptionOption.WITH_POOLED_OBJECTS);
.subscribe(subscriber, ctx.eventLoop(), SubscriptionOption.WITH_POOLED_OBJECTS);
}

static Subscriber<DeframedMessage> singleSubscriber(
ResponseHeadersBuilder unframedHeaders, CompletableFuture<HttpResponse> res,
@Nullable Function<HttpData, HttpData> responseBodyConverter) {
@Nullable Function<AggregatedHttpResponse, AggregatedHttpResponse> responseConverter) {
return new Subscriber<DeframedMessage>() {

@Override
Expand All @@ -249,12 +248,19 @@ public void onSubscribe(Subscription subscription) {
@Override
public void onNext(DeframedMessage message) {
// We know that we don't support compression, so this is always a ByteBuf.
HttpData unframedContent = HttpData.wrap(message.buf());
if (responseBodyConverter != null) {
unframedContent = responseBodyConverter.apply(unframedContent);
final HttpData unframedContent = HttpData.wrap(message.buf());
unframedHeaders.contentType(MediaType.JSON_UTF_8);

final AggregatedHttpResponse existingResponse = AggregatedHttpResponse.of(
unframedHeaders.build(),
unframedContent);

if (responseConverter != null) {
final AggregatedHttpResponse convertedResponse = responseConverter.apply(existingResponse);
res.complete(convertedResponse.toHttpResponse());
} else {
res.complete(existingResponse.toHttpResponse());
}
unframedHeaders.contentLength(unframedContent.length());
res.complete(HttpResponse.of(unframedHeaders.build(), unframedContent));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import java.io.IOException;
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
Expand All @@ -45,6 +46,7 @@
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.api.AnnotationsProto;
import com.google.api.HttpBody;
import com.google.api.HttpRule;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.CaseFormat;
Expand Down Expand Up @@ -78,7 +80,9 @@
import com.google.protobuf.Value;

import com.linecorp.armeria.common.AggregatedHttpRequest;
import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
Expand All @@ -87,6 +91,7 @@
import com.linecorp.armeria.common.QueryParams;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.RequestHeadersBuilder;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.grpc.GrpcSerializationFormats;
import com.linecorp.armeria.common.grpc.protocol.GrpcHeaderNames;
Expand All @@ -107,6 +112,7 @@
import com.linecorp.armeria.server.grpc.HttpJsonTranscodingPathParser.VariablePathSegment;
import com.linecorp.armeria.server.grpc.HttpJsonTranscodingPathParser.VerbPathSegment;
import com.linecorp.armeria.server.grpc.HttpJsonTranscodingService.PathVariable.ValueDefinition.Type;
import com.linecorp.armeria.unsafe.PooledObjects;

import io.grpc.MethodDescriptor.MethodType;
import io.grpc.ServerMethodDefinition;
Expand Down Expand Up @@ -195,11 +201,11 @@ static GrpcService of(GrpcService delegate, HttpJsonTranscodingOptions httpJsonT
= toRouteAndPathVariables(additionalHttpRule);
if (additionalRouteAndVariables != null) {
specs.put(additionalRouteAndVariables.getKey(),
new TranscodingSpec(order++, additionalHttpRule, methodDefinition,
serviceDesc, methodDesc, originalFields,
camelCaseFields,
additionalRouteAndVariables.getValue(),
responseBody));
new TranscodingSpec(order++, additionalHttpRule, methodDefinition,
serviceDesc, methodDesc, originalFields,
camelCaseFields,
additionalRouteAndVariables.getValue(),
responseBody));
}
}
}
Expand Down Expand Up @@ -435,7 +441,7 @@ private static String getResponseBody(List<FieldDescriptor> topLevelFields,
if (StringUtil.isNullOrEmpty(responseBody)) {
return null;
}
for (FieldDescriptor fieldDescriptor: topLevelFields) {
for (FieldDescriptor fieldDescriptor : topLevelFields) {
if (fieldDescriptor.getName().equals(responseBody)) {
return responseBody;
}
Expand All @@ -444,41 +450,93 @@ private static String getResponseBody(List<FieldDescriptor> topLevelFields,
}

@Nullable
private static Function<HttpData, HttpData> generateResponseBodyConverter(TranscodingSpec spec) {
@Nullable final String responseBody = spec.responseBody;
private static Function<AggregatedHttpResponse, AggregatedHttpResponse> generateResponseConverter(
TranscodingSpec spec) {
// Ignore the spec if the method is HttpBody. The response body is already in the correct format.
if (HttpBody.getDescriptor().equals(spec.methodDescriptor.getOutputType())) {
return httpResponse -> {
final HttpData data = httpResponse.content();
final JsonNode jsonNode = extractHttpBody(data);

// Failed to parse the JSON body, return the original response.
if (jsonNode == null) {
return httpResponse;
}

PooledObjects.close(data);

// The data field is base64 encoded.
// https://protobuf.dev/programming-guides/proto3/#json
final String httpBody = jsonNode.get("data").asText();
final byte[] httpBodyBytes = Base64.getDecoder().decode(httpBody);

final ResponseHeaders newHeaders = httpResponse.headers().withMutations(builder -> {
final JsonNode contentType = jsonNode.get("contentType");

if (contentType != null && contentType.isTextual()) {
builder.set(HttpHeaderNames.CONTENT_TYPE, contentType.textValue());
} else {
builder.remove(HttpHeaderNames.CONTENT_TYPE);
}
});

return AggregatedHttpResponse.of(newHeaders, HttpData.wrap(httpBodyBytes));
};
}

@Nullable
final String responseBody = spec.responseBody;
if (responseBody == null) {
return null;
} else {
return httpData -> {
try (HttpData data = httpData) {
final byte[] array = data.array();
try {
final JsonNode jsonNode = mapper.readValue(array, JsonNode.class);
// we try to convert lower snake case response body to camel case
final String lowerCamelCaseResponseBody =
CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, responseBody);
final Iterator<Entry<String, JsonNode>> fields = jsonNode.fields();
while (fields.hasNext()) {
final Entry<String, JsonNode> entry = fields.next();
final String fieldName = entry.getKey();
final JsonNode responseBodyJsonNode = entry.getValue();
// try to match field name and response body
// 1. by default the marshaller would use lowerCamelCase in json field
// 2. when the marshaller use original name in .proto file when serializing messages
if (fieldName.equals(lowerCamelCaseResponseBody) ||
fieldName.equals(responseBody)) {
final byte[] bytes = mapper.writeValueAsBytes(responseBodyJsonNode);
return HttpData.wrap(bytes);
}
}
return HttpData.ofUtf8("null");
} catch (IOException e) {
logger.warn("Unexpected exception while extracting responseBody '{}' from {}",
responseBody, data, e);
return HttpData.wrap(array);
}
}

return httpResponse -> {
try (HttpData data = httpResponse.content()) {
final HttpData convertedData = convertHttpDataForResponseBody(responseBody, data);
return AggregatedHttpResponse.of(httpResponse.headers(), convertedData);
}
};
}

@Nullable
private static JsonNode extractHttpBody(HttpData data) {
final byte[] array = data.array();

try {
return mapper.readValue(array, JsonNode.class);
} catch (IOException e) {
logger.warn("Unexpected exception while parsing HttpBody from {}", data, e);
return null;
}
}

private static HttpData convertHttpDataForResponseBody(String responseBody, HttpData data) {
final byte[] array = data.array();
try {
final JsonNode jsonNode = mapper.readValue(array, JsonNode.class);

// we try to convert lower snake case response body to camel case
final String lowerCamelCaseResponseBody =
CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, responseBody);
final Iterator<Entry<String, JsonNode>> fields = jsonNode.fields();
while (fields.hasNext()) {
final Entry<String, JsonNode> entry = fields.next();
final String fieldName = entry.getKey();
final JsonNode responseBodyJsonNode = entry.getValue();
// try to match field name and response body
// 1. by default the marshaller would use lowerCamelCase in json field
// 2. when the marshaller use original name in .proto file when serializing messages
if (fieldName.equals(lowerCamelCaseResponseBody) ||
fieldName.equals(responseBody)) {
final byte[] bytes = mapper.writeValueAsBytes(responseBodyJsonNode);
return HttpData.wrap(bytes);
}
};
}
return HttpData.ofUtf8("null");
} catch (IOException e) {
logger.warn("Unexpected exception while extracting responseBody '{}' from {}",
responseBody, data, e);
return HttpData.wrap(array);
}
}

Expand Down Expand Up @@ -582,10 +640,20 @@ private HttpResponse serve0(ServiceRequestContext ctx, HttpRequest req,
} else {
try {
ctx.setAttr(FramedGrpcService.RESOLVED_GRPC_METHOD, spec.method);
// Set JSON media type (https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_filters/grpc_json_transcoder_filter#sending-arbitrary-content)
final HttpData requestContent;

// https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_filters/grpc_json_transcoder_filter#sending-arbitrary-content
if (HttpBody.getDescriptor().equals(spec.methodDescriptor.getInputType())) {
// Convert the HTTP request to a JSON representation of HttpBody.
requestContent = convertToHttpBody(clientRequest);
} else {
// Convert the HTTP request to gRPC JSON.
requestContent = convertToJson(ctx, clientRequest, spec);
}

frameAndServe(unwrap(), ctx, grpcHeaders.build(),
convertToJson(ctx, clientRequest, spec), responseFuture,
generateResponseBodyConverter(spec), MediaType.JSON_UTF_8);
requestContent, responseFuture,
generateResponseConverter(spec));
} catch (IllegalArgumentException iae) {
responseFuture.completeExceptionally(
HttpStatusException.of(HttpStatus.BAD_REQUEST, iae));
Expand All @@ -599,6 +667,29 @@ private HttpResponse serve0(ServiceRequestContext ctx, HttpRequest req,
return HttpResponse.of(responseFuture);
}

private static HttpData convertToHttpBody(AggregatedHttpRequest request) throws IOException {
final ObjectNode body = mapper.createObjectNode();

try (HttpData content = request.content()) {
final MediaType contentType;

@Nullable
final MediaType requestContentType = request.contentType();
if (requestContentType != null) {
contentType = requestContentType;
} else {
contentType = MediaType.OCTET_STREAM;
}

body.put("content_type", contentType.toString());
// Jackson converts byte array to base64 string. gRPC transcoding spec also returns base64 string.
// https://protobuf.dev/programming-guides/proto3/#json
body.put("data", content.array());

return HttpData.wrap(mapper.writeValueAsBytes(body));
}
}

/**
* Converts the HTTP request to gRPC JSON with the {@link TranscodingSpec}.
*/
Expand All @@ -625,7 +716,7 @@ private HttpData convertToJson(ServiceRequestContext ctx,
root = mapper.createObjectNode();
} else {
throw new IllegalArgumentException("Unexpected JSON: " +
body + ", (expected: ObjectNode or null).");
body + ", (expected: ObjectNode or null).");
}
return setParametersAndWriteJson(root, ctx, spec);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.AggregationOptions;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.RequestHeadersBuilder;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.grpc.GrpcSerializationFormats;
import com.linecorp.armeria.common.grpc.protocol.GrpcHeaderNames;
Expand Down Expand Up @@ -153,8 +156,15 @@ public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exc
if (t != null) {
responseFuture.completeExceptionally(t);
} else {
// Add the content type to the response headers.
final Function<AggregatedHttpResponse, AggregatedHttpResponse> responseConverter =
response -> {
final ResponseHeaders headers = response.headers().withMutations(
builder -> builder.contentType(contentType));
return AggregatedHttpResponse.of(headers, response.content());
};
frameAndServe(unwrap(), ctx, grpcHeaders.build(), clientRequest.content(),
responseFuture, null, contentType);
responseFuture, responseConverter);
}
}
return null;
Expand Down
Loading

0 comments on commit c7b2a44

Please sign in to comment.