Skip to content

Commit

Permalink
update: Migration from jprotobuf serialization to protobuf:protoc; mi…
Browse files Browse the repository at this point in the history
…grating the business logic in http2Handler to a custom executorGroup; added http2 headers in grpcRequest; support for handling resetStream
  • Loading branch information
fengye404 committed Oct 14, 2024
1 parent e594712 commit 45b6027
Show file tree
Hide file tree
Showing 27 changed files with 202 additions and 191 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http2.Http2FrameCodecBuilder;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import io.netty.util.concurrent.EventExecutorGroup;

import java.lang.invoke.MethodHandles;

Expand All @@ -38,7 +40,7 @@ public ArthasGrpcServer(int port, String grpcServicePackageName) {

public void start() {
EventLoopGroup bossGroup = new NioEventLoopGroup(1);
EventLoopGroup workerGroup = new NioEventLoopGroup();
EventLoopGroup workerGroup = new NioEventLoopGroup(10);

GrpcDispatcher grpcDispatcher = new GrpcDispatcher();
grpcDispatcher.loadGrpcService(grpcServicePackageName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* @description: ErrorRes
*/
@ProtobufClass
@Deprecated
public class ErrorRes {
private String errorMsg;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package com.taobao.arthas.grpc.server.handler;


import com.taobao.arthas.grpc.server.handler.annotation.GrpcMethod;
import com.taobao.arthas.grpc.server.handler.annotation.GrpcService;
import com.taobao.arthas.grpc.server.protobuf.ProtobufCodec;
import com.taobao.arthas.grpc.server.protobuf.ProtobufProxy;
import com.taobao.arthas.grpc.server.utils.ByteUtil;
import com.taobao.arthas.grpc.server.utils.ReflectUtil;

import java.lang.invoke.MethodHandle;
Expand All @@ -22,11 +22,17 @@
*/
public class GrpcDispatcher {

private static final String DEFAULT_GRPC_SERVICE_PACKAGE_NAME = "com.taobao.arthas.grpc.server.service.impl";
public static final String DEFAULT_GRPC_SERVICE_PACKAGE_NAME = "com.taobao.arthas.grpc.server.service.impl";

public static Map<String, MethodHandle> grpcMethodInvokeMap = new HashMap<>();

public static Map<String, MethodHandle> requestParseFromMap = new HashMap<>();
public static Map<String, MethodHandle> requestToByteArrayMap = new HashMap<>();

private Map<String, MethodHandle> grpcMethodInvokeMap = new HashMap<>();
public static Map<String, MethodHandle> responseParseFromMap = new HashMap<>();
public static Map<String, MethodHandle> responseToByteArrayMap = new HashMap<>();

private Map<String, Boolean> grpcMethodStreamMap = new HashMap<>();
public static Map<String, Boolean> grpcMethodStreamMap = new HashMap<>();

public void loadGrpcService(String grpcServicePackageName) {
List<Class<?>> classes = ReflectUtil.findClasses(Optional.ofNullable(grpcServicePackageName).orElse(DEFAULT_GRPC_SERVICE_PACKAGE_NAME));
Expand All @@ -43,10 +49,20 @@ public void loadGrpcService(String grpcServicePackageName) {
for (Method method : declaredMethods) {
if (method.isAnnotationPresent(GrpcMethod.class)) {
GrpcMethod grpcMethod = method.getAnnotation(GrpcMethod.class);
MethodHandle methodHandle = lookup.unreflect(method);
MethodHandle grpcInvoke = lookup.unreflect(method);
Class<?> requestClass = grpcInvoke.type().parameterType(1);
Class<?> responseClass = grpcInvoke.type().returnType();
MethodHandle requestParseFrom = lookup.findStatic(requestClass, "parseFrom", MethodType.methodType(requestClass, byte[].class));
MethodHandle responseParseFrom = lookup.findStatic(responseClass, "parseFrom", MethodType.methodType(responseClass, byte[].class));
MethodHandle requestToByteArray = lookup.findVirtual(requestClass, "toByteArray", MethodType.methodType(byte[].class));
MethodHandle responseToByteArray = lookup.findVirtual(responseClass, "toByteArray", MethodType.methodType(byte[].class));
String grpcMethodKey = generateGrpcMethodKey(grpcService.value(), grpcMethod.value());
grpcMethodInvokeMap.put(grpcMethodKey, methodHandle.bindTo(instance));
grpcMethodInvokeMap.put(grpcMethodKey, grpcInvoke.bindTo(instance));
grpcMethodStreamMap.put(grpcMethodKey, grpcMethod.stream());
requestParseFromMap.put(grpcMethodKey, requestParseFrom);
responseParseFromMap.put(grpcMethodKey, responseParseFrom);
requestToByteArrayMap.put(grpcMethodKey, requestToByteArray);
responseToByteArrayMap.put(grpcMethodKey, responseToByteArray);
}
}
} catch (Exception e) {
Expand All @@ -56,28 +72,23 @@ public void loadGrpcService(String grpcServicePackageName) {
}
}

private String generateGrpcMethodKey(String serviceName, String methodName) {
return serviceName + "." + methodName;
}

public GrpcResponse execute(String serviceName, String methodName, Object arg) throws Throwable {
MethodHandle methodHandle = grpcMethodInvokeMap.get(generateGrpcMethodKey(serviceName, methodName));
MethodType type = grpcMethodInvokeMap.get(generateGrpcMethodKey(serviceName, methodName)).type();
Object execute = methodHandle.invoke(arg);
public GrpcResponse execute(String service, String method, byte[] arg) throws Throwable {
MethodHandle methodHandle = grpcMethodInvokeMap.get(generateGrpcMethodKey(service, method));
MethodType type = grpcMethodInvokeMap.get(generateGrpcMethodKey(service, method)).type();
Object req = requestParseFromMap.get(generateGrpcMethodKey(service, method)).invoke(arg);
Object execute = methodHandle.invoke(req);
GrpcResponse grpcResponse = new GrpcResponse();
grpcResponse.setClazz(type.returnType());
grpcResponse.setService(service);
grpcResponse.setMethod(method);
grpcResponse.writeResponseData(execute);
return grpcResponse;
}

public GrpcResponse execute(GrpcRequest request) throws Throwable {
String service = request.getService();
String method = request.getMethod();
// protobuf 规范只能有单入参
request.setClazz(getRequestClass(request.getService(), request.getMethod()));
ProtobufCodec protobufCodec = ProtobufProxy.getCodecCacheSide(request.getClazz());
Object decode = protobufCodec.decode(request.readData());
return this.execute(service, method, decode);
return this.execute(service, method, request.readData());
}

/**
Expand All @@ -87,12 +98,16 @@ public GrpcResponse execute(GrpcRequest request) throws Throwable {
* @param methodName
* @return
*/
public Class<?> getRequestClass(String serviceName, String methodName) {
public static Class<?> getRequestClass(String serviceName, String methodName) {
//protobuf 规范只能有单入参
return Optional.ofNullable(grpcMethodInvokeMap.get(generateGrpcMethodKey(serviceName, methodName))).orElseThrow(() -> new RuntimeException("The specified grpc method does not exist")).type().parameterArray()[0];
}

public void checkGrpcStream(GrpcRequest request) {
public static String generateGrpcMethodKey(String serviceName, String methodName) {
return serviceName + "." + methodName;
}

public static void checkGrpcStream(GrpcRequest request) {
request.setStream(
Optional.ofNullable(grpcMethodStreamMap.get(generateGrpcMethodKey(request.getService(), request.getMethod())))
.orElse(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.taobao.arthas.grpc.server.utils.ByteUtil;
import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.http2.Http2Headers;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand Down Expand Up @@ -56,6 +57,11 @@ public class GrpcRequest {
*/
private boolean streamFirstData;

/**
* http2 headers
*/
private Http2Headers headers;


public GrpcRequest(Integer streamId, String path, String method) {
this.streamId = streamId;
Expand Down Expand Up @@ -158,4 +164,12 @@ public boolean isStreamFirstData() {
public void setStreamFirstData(boolean streamFirstData) {
this.streamFirstData = streamFirstData;
}

public Http2Headers getHeaders() {
return headers;
}

public void setHeaders(Http2Headers headers) {
this.headers = headers;
}
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package com.taobao.arthas.grpc.server.handler;

import com.taobao.arthas.grpc.server.protobuf.ProtobufCodec;
import com.taobao.arthas.grpc.server.protobuf.ProtobufProxy;

import arthas.grpc.common.ArthasGrpc;
import com.taobao.arthas.grpc.server.utils.ByteUtil;
import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.Http2Headers;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

Expand All @@ -20,6 +19,16 @@ public class GrpcResponse {

private Map<String, String> headers;

/**
* 请求的 service
*/
private String service;

/**
* 请求的 method
*/
private String method;

/**
* 二进制数据
*/
Expand Down Expand Up @@ -52,12 +61,15 @@ public ByteBuf getResponseData() {
}

public void writeResponseData(Object response) {
ProtobufCodec codec = ProtobufProxy.getCodecCacheSide(clazz);
byte[] encode = null;
try {
encode = codec.encode(response);
} catch (IOException e) {
throw new RuntimeException("ProtobufCodec encode error");
if (ArthasGrpc.ErrorRes.class.equals(clazz)) {
encode = ((ArthasGrpc.ErrorRes) response).toByteArray();
} else {
encode = (byte[]) GrpcDispatcher.responseToByteArrayMap.get(GrpcDispatcher.generateGrpcMethodKey(service, method)).invoke(response);
}
} catch (Throwable e) {
throw new RuntimeException(e);
}
this.byteData = ByteUtil.newByteBuf();
this.byteData.writeBoolean(false);
Expand All @@ -68,4 +80,20 @@ public void writeResponseData(Object response) {
public void setClazz(Class<?> clazz) {
this.clazz = clazz;
}

public String getService() {
return service;
}

public void setService(String service) {
this.service = service;
}

public String getMethod() {
return method;
}

public void setMethod(String method) {
this.method = method;
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package com.taobao.arthas.grpc.server.handler;


import arthas.grpc.common.ArthasGrpc;
import com.alibaba.arthas.deps.org.slf4j.Logger;
import com.alibaba.arthas.deps.org.slf4j.LoggerFactory;
import com.taobao.arthas.grpc.server.protobuf.ProtobufCodec;
import com.taobao.arthas.grpc.server.protobuf.ProtobufProxy;
import com.taobao.arthas.grpc.server.utils.ByteUtil;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.http2.*;
import io.netty.util.concurrent.EventExecutorGroup;

import java.io.*;
import java.lang.invoke.MethodHandles;
Expand All @@ -25,6 +26,8 @@ public class Http2Handler extends SimpleChannelInboundHandler<Http2Frame> {

private GrpcDispatcher grpcDispatcher;

private final EventExecutorGroup executorGroup = new NioEventLoopGroup();

/**
* 暂存收到的所有请求的数据
*/
Expand All @@ -47,6 +50,8 @@ protected void channelRead0(ChannelHandlerContext ctx, Http2Frame frame) throws
handleGrpcRequest((Http2HeadersFrame) frame, ctx);
} else if (frame instanceof Http2DataFrame) {
handleGrpcData((Http2DataFrame) frame, ctx);
} else if (frame instanceof Http2ResetFrame) {
handleResetStream((Http2ResetFrame) frame, ctx);
}
}

Expand All @@ -62,63 +67,71 @@ private void handleGrpcRequest(Http2HeadersFrame headersFrame, ChannelHandlerCon
// 去掉前面的斜杠,然后按斜杠分割
String[] parts = path.substring(1).split("/");
GrpcRequest grpcRequest = new GrpcRequest(headersFrame.stream().id(), parts[0], parts[1]);
grpcDispatcher.checkGrpcStream(grpcRequest);
grpcRequest.setHeaders(headersFrame.headers());
GrpcDispatcher.checkGrpcStream(grpcRequest);
dataBuffer.put(id, grpcRequest);
System.out.println("Received headers: " + headersFrame.headers());
}

private void handleGrpcData(Http2DataFrame dataFrame, ChannelHandlerContext ctx) throws IOException {
GrpcRequest grpcRequest = dataBuffer.get(dataFrame.stream().id());
grpcRequest.writeData(dataFrame.content());

if (grpcRequest.isStream()) {
// 流式调用,即刻响应
try {
GrpcResponse response = new GrpcResponse();
byte[] bytes = grpcRequest.readData();
while (bytes != null) {
ProtobufCodec protobufCodec = ProtobufProxy.getCodecCacheSide(grpcDispatcher.getRequestClass(grpcRequest.getService(), grpcRequest.getMethod()));
Object decode = protobufCodec.decode(bytes);
response = grpcDispatcher.execute(grpcRequest.getService(), grpcRequest.getMethod(), decode);

// 针对第一个响应发送 header
if (grpcRequest.isStreamFirstData()) {
ctx.writeAndFlush(new DefaultHttp2HeadersFrame(response.getEndHeader()).stream(dataFrame.stream()));
grpcRequest.setStreamFirstData(false);
}
ctx.writeAndFlush(new DefaultHttp2DataFrame(response.getResponseData()).stream(dataFrame.stream()));
ByteBuf content = dataFrame.content();
grpcRequest.writeData(content);

bytes = grpcRequest.readData();
}
executorGroup.execute(() -> {
if (grpcRequest.isStream()) {
// 流式调用,即刻响应
try {
GrpcResponse response = new GrpcResponse();
byte[] bytes = grpcRequest.readData();
while (bytes != null) {
response = grpcDispatcher.execute(grpcRequest.getService(), grpcRequest.getMethod(), bytes);

// 针对第一个响应发送 header
if (grpcRequest.isStreamFirstData()) {
ctx.writeAndFlush(new DefaultHttp2HeadersFrame(response.getEndHeader()).stream(dataFrame.stream()));
grpcRequest.setStreamFirstData(false);
}
ctx.writeAndFlush(new DefaultHttp2DataFrame(response.getResponseData()).stream(dataFrame.stream()));

bytes = grpcRequest.readData();
}

grpcRequest.clearData();
grpcRequest.clearData();

if (dataFrame.isEndStream()) {
ctx.writeAndFlush(new DefaultHttp2HeadersFrame(response.getEndStreamHeader(), true).stream(dataFrame.stream()));
}
} catch (Throwable e) {
processError(ctx, e, dataFrame.stream());
}
} else {
// 非流式调用,等到 endStream 再响应
if (dataFrame.isEndStream()) {
try {
GrpcResponse response = grpcDispatcher.execute(grpcRequest);
ctx.writeAndFlush(new DefaultHttp2HeadersFrame(response.getEndHeader()).stream(dataFrame.stream()));
ctx.writeAndFlush(new DefaultHttp2DataFrame(response.getResponseData()).stream(dataFrame.stream()));
ctx.writeAndFlush(new DefaultHttp2HeadersFrame(response.getEndStreamHeader(), true).stream(dataFrame.stream()));
if (dataFrame.isEndStream()) {
ctx.writeAndFlush(new DefaultHttp2HeadersFrame(response.getEndStreamHeader(), true).stream(dataFrame.stream()));
}
} catch (Throwable e) {
processError(ctx, e, dataFrame.stream());
}
} else {
// 非流式调用,等到 endStream 再响应
if (dataFrame.isEndStream()) {
try {
GrpcResponse response = grpcDispatcher.execute(grpcRequest);
ctx.writeAndFlush(new DefaultHttp2HeadersFrame(response.getEndHeader()).stream(dataFrame.stream()));
ctx.writeAndFlush(new DefaultHttp2DataFrame(response.getResponseData()).stream(dataFrame.stream()));
ctx.writeAndFlush(new DefaultHttp2HeadersFrame(response.getEndStreamHeader(), true).stream(dataFrame.stream()));
} catch (Throwable e) {
processError(ctx, e, dataFrame.stream());
}
}
}
}
});
}

private void handleResetStream(Http2ResetFrame resetFrame, ChannelHandlerContext ctx) {
int id = resetFrame.stream().id();
System.out.println("handleResetStream");
dataBuffer.remove(id);
}

private void processError(ChannelHandlerContext ctx, Throwable e, Http2FrameStream stream) {
GrpcResponse response = new GrpcResponse();
ErrorRes errorRes = new ErrorRes();
errorRes.setErrorMsg(e.getMessage());
response.setClazz(ErrorRes.class);
ArthasGrpc.ErrorRes.Builder builder = ArthasGrpc.ErrorRes.newBuilder();
ArthasGrpc.ErrorRes errorRes = builder.setErrorMsg(e.getMessage()).build();
response.setClazz(ArthasGrpc.ErrorRes.class);
response.writeResponseData(errorRes);
ctx.writeAndFlush(new DefaultHttp2HeadersFrame(response.getEndHeader()).stream(stream));
ctx.writeAndFlush(new DefaultHttp2DataFrame(response.getResponseData()).stream(stream));
Expand Down
Loading

0 comments on commit 45b6027

Please sign in to comment.