Skip to content

Commit

Permalink
Support proxy protocol for gRPC and Remoting server.
Browse files Browse the repository at this point in the history
  • Loading branch information
shuangxi.dsx committed Jun 27, 2023
1 parent a82b991 commit c0b5837
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,13 @@ private void replaceEventWithMessage(HAProxyMessage msg) {
builder.set(AttributesConstants.PROXY_PROTOCOL_ADDR, msg.sourceAddress());
}
if (msg.sourcePort() > 0) {
builder.set(AttributesConstants.PROXY_PROTOCOL_PORT, msg.sourcePort());
builder.set(AttributesConstants.PROXY_PROTOCOL_PORT, String.valueOf(msg.sourcePort()));
}
if (StringUtils.isNotBlank(msg.destinationAddress())) {
builder.set(AttributesConstants.PROXY_PROTOCOL_SERVER_ADDR, msg.destinationAddress());
}
if (msg.destinationPort() > 0) {
builder.set(AttributesConstants.PROXY_PROTOCOL_SERVER_PORT, msg.destinationPort());
builder.set(AttributesConstants.PROXY_PROTOCOL_SERVER_PORT, String.valueOf(msg.destinationPort()));
}
if (CollectionUtils.isNotEmpty(msg.tlvs())) {
msg.tlvs().forEach(tlv -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ public class AttributesConstants {
public static final Attributes.Key<String> PROXY_PROTOCOL_ADDR =
Attributes.Key.create(HAProxyConstants.PROXY_PROTOCOL_ADDR);

public static final Attributes.Key<Integer> PROXY_PROTOCOL_PORT =
public static final Attributes.Key<String> PROXY_PROTOCOL_PORT =
Attributes.Key.create(HAProxyConstants.PROXY_PROTOCOL_PORT);

public static final Attributes.Key<String> PROXY_PROTOCOL_SERVER_ADDR =
Attributes.Key.create(HAProxyConstants.PROXY_PROTOCOL_SERVER_ADDR);

public static final Attributes.Key<Integer> PROXY_PROTOCOL_SERVER_PORT =
public static final Attributes.Key<String> PROXY_PROTOCOL_SERVER_PORT =
Attributes.Key.create(HAProxyConstants.PROXY_PROTOCOL_SERVER_PORT);
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ private String parseSocketAddress(SocketAddress socketAddress) {

private String getProxyProtocolAddress(Attributes attributes) {
String proxyProtocolAddr = attributes.get(AttributesConstants.PROXY_PROTOCOL_ADDR);
Integer proxyProtocolPort = attributes.get(AttributesConstants.PROXY_PROTOCOL_PORT);
if (StringUtils.isBlank(proxyProtocolAddr) || proxyProtocolPort == null) {
String proxyProtocolPort = attributes.get(AttributesConstants.PROXY_PROTOCOL_PORT);
if (StringUtils.isBlank(proxyProtocolAddr) || StringUtils.isEmpty(proxyProtocolPort)) {
return null;
}
return proxyProtocolAddr + ":" + proxyProtocolPort;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
import io.netty.handler.timeout.IdleStateHandler;
import java.io.IOException;
import java.security.cert.CertificateException;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.logging.org.slf4j.Logger;
import org.apache.rocketmq.logging.org.slf4j.LoggerFactory;
Expand All @@ -37,6 +35,9 @@
import org.apache.rocketmq.remoting.netty.NettyServerConfig;
import org.apache.rocketmq.remoting.netty.TlsSystemConfig;

import java.io.IOException;
import java.security.cert.CertificateException;

/**
* support remoting and http2 protocol at one port
*/
Expand Down Expand Up @@ -79,6 +80,7 @@ public void loadSslContext() {
protected ChannelPipeline configChannel(SocketChannel ch) {
return ch.pipeline()
.addLast(this.getDefaultEventExecutorGroup(), HA_PROXY_DECODER, new HAProxyMessageDecoder())
.addLast(this.getDefaultEventExecutorGroup(), HA_PROXY_HANDLER, new HAProxyMessageHandler())
.addLast(this.getDefaultEventExecutorGroup(), HANDSHAKE_HANDLER_NAME, this.getHandshakeHandler())
.addLast(this.getDefaultEventExecutorGroup(),
new IdleStateHandler(0, 0, nettyServerConfig.getServerChannelMaxIdleTimeSeconds()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,8 @@
import io.netty.channel.ChannelFutureListener;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import org.apache.rocketmq.common.constant.HAProxyConstants;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.common.utils.NetworkUtil;
import org.apache.rocketmq.logging.org.slf4j.Logger;
Expand All @@ -43,13 +37,25 @@
import org.apache.rocketmq.remoting.protocol.RequestCode;
import org.apache.rocketmq.remoting.protocol.ResponseCode;

import java.io.IOException;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.HashMap;
import java.util.Map;

public class RemotingHelper {
public static final String DEFAULT_CHARSET = "UTF-8";
public static final String DEFAULT_CIDR_ALL = "0.0.0.0/0";

private static final Logger log = LoggerFactory.getLogger(LoggerName.ROCKETMQ_REMOTING_NAME);
private static final AttributeKey<String> REMOTE_ADDR_KEY = AttributeKey.valueOf("RemoteAddr");

private static final AttributeKey<String> PROXY_PROTOCOL_ADDR = AttributeKey.valueOf(HAProxyConstants.PROXY_PROTOCOL_ADDR);
private static final AttributeKey<String> PROXY_PROTOCOL_PORT = AttributeKey.valueOf(HAProxyConstants.PROXY_PROTOCOL_PORT);

public static final AttributeKey<String> CLIENT_ID_KEY = AttributeKey.valueOf("ClientId");

public static final AttributeKey<Integer> VERSION_KEY = AttributeKey.valueOf("Version");
Expand Down Expand Up @@ -203,19 +209,35 @@ public static String parseChannelRemoteAddr(final Channel channel) {
if (null == channel) {
return "";
}
String addr = getProxyProtocolAddress(channel);
if (StringUtils.isNotBlank(addr)) {
return addr;
}
Attribute<String> att = channel.attr(REMOTE_ADDR_KEY);
if (att == null) {
// mocked in unit test
return parseChannelRemoteAddr0(channel);
}
String addr = att.get();
addr = att.get();
if (addr == null) {
addr = parseChannelRemoteAddr0(channel);
att.set(addr);
}
return addr;
}

private static String getProxyProtocolAddress(Channel channel) {
if (!channel.hasAttr(PROXY_PROTOCOL_ADDR)) {
return null;
}
String proxyProtocolAddr = getAttributeValue(PROXY_PROTOCOL_ADDR, channel);
String proxyProtocolPort = getAttributeValue(PROXY_PROTOCOL_PORT, channel);
if (StringUtils.isBlank(proxyProtocolAddr) || proxyProtocolPort == null) {
return null;
}
return proxyProtocolAddr + ":" + proxyProtocolPort;
}

private static String parseChannelRemoteAddr0(final Channel channel) {
SocketAddress remote = channel.remoteAddress();
final String addr = remote != null ? remote.toString() : "";
Expand Down Expand Up @@ -255,7 +277,7 @@ public static String parseSocketAddressAddr(SocketAddress socketAddress) {
return "";
}

public static int parseSocketAddressPort(SocketAddress socketAddress) {
public static Integer parseSocketAddressPort(SocketAddress socketAddress) {
if (socketAddress instanceof InetSocketAddress) {
return ((InetSocketAddress) socketAddress).getPort();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.rocketmq.remoting.netty;


import io.netty.util.AttributeKey;
import org.apache.rocketmq.common.constant.HAProxyConstants;

public class AttributesConstants {

public static final AttributeKey<String> PROXY_PROTOCOL_ADDR =
AttributeKey.valueOf(HAProxyConstants.PROXY_PROTOCOL_ADDR);

public static final AttributeKey<String> PROXY_PROTOCOL_PORT =
AttributeKey.valueOf(HAProxyConstants.PROXY_PROTOCOL_PORT);

public static final AttributeKey<String> PROXY_PROTOCOL_SERVER_ADDR =
AttributeKey.valueOf(HAProxyConstants.PROXY_PROTOCOL_SERVER_ADDR);

public static final AttributeKey<String> PROXY_PROTOCOL_SERVER_PORT =
AttributeKey.valueOf(HAProxyConstants.PROXY_PROTOCOL_SERVER_PORT);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,8 @@

import com.google.common.base.Stopwatch;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.util.AttributeKey;
import io.netty.util.CharsetUtil;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.rocketmq.common.constant.HAProxyConstants;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.logging.org.slf4j.Logger;
import org.apache.rocketmq.logging.org.slf4j.LoggerFactory;
Expand All @@ -43,16 +36,6 @@ public NettyDecoder() {
super(FRAME_MAX_LENGTH, 0, 4, 0, 4);
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof HAProxyMessage) {
HAProxyMessage message = (HAProxyMessage) msg;
this.addProxyProtocolHeader(message, ctx.channel());
} else {
super.channelRead(ctx, msg);
}
}

@Override
public Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
ByteBuf frame = null;
Expand All @@ -76,31 +59,4 @@ public Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {

return null;
}

/**
* The definition of key refers to the implementation of nginx
* <a href="https://nginx.org/en/docs/http/ngx_http_core_module.html#var_proxy_protocol_addr">ngx_http_core_module</a>
* @param msg
* @param channel
*/
private void addProxyProtocolHeader(HAProxyMessage msg, Channel channel) {
if (StringUtils.isNotBlank(msg.sourceAddress())) {
channel.attr(AttributeKey.valueOf(HAProxyConstants.PROXY_PROTOCOL_ADDR)).set(msg.sourceAddress());
}
if (msg.sourcePort() > 0) {
channel.attr(AttributeKey.valueOf(HAProxyConstants.PROXY_PROTOCOL_PORT)).set(msg.sourcePort());
}
if (StringUtils.isNotBlank(msg.destinationAddress())) {
channel.attr(AttributeKey.valueOf(HAProxyConstants.PROXY_PROTOCOL_SERVER_ADDR)).set(msg.destinationAddress());
}
if (msg.destinationPort() > 0) {
channel.attr(AttributeKey.valueOf(HAProxyConstants.PROXY_PROTOCOL_SERVER_PORT)).set(msg.destinationPort());
}
if (CollectionUtils.isNotEmpty(msg.tlvs())) {
msg.tlvs().forEach(tlv ->
channel.attr(AttributeKey.valueOf(HAProxyConstants.PROXY_PROTOCOL_TLV_PREFIX
+ String.format("%02x", tlv.typeByteValue())))
.set(tlv.content().toString(CharsetUtil.UTF_8)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
Expand All @@ -36,28 +37,22 @@
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.AttributeKey;
import io.netty.util.CharsetUtil;
import io.netty.util.HashedWheelTimer;
import io.netty.util.Timeout;
import io.netty.util.TimerTask;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.cert.CertificateException;
import java.util.NoSuchElementException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.rocketmq.common.Pair;
import org.apache.rocketmq.common.ThreadFactoryImpl;
import org.apache.rocketmq.common.constant.HAProxyConstants;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.common.utils.NetworkUtil;
import org.apache.rocketmq.logging.org.slf4j.Logger;
Expand All @@ -72,6 +67,19 @@
import org.apache.rocketmq.remoting.exception.RemotingTooMuchRequestException;
import org.apache.rocketmq.remoting.protocol.RemotingCommand;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.cert.CertificateException;
import java.util.NoSuchElementException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

@SuppressWarnings("NullableProblems")
public class NettyRemotingServer extends NettyRemotingAbstract implements RemotingServer {
private static final Logger log = LoggerFactory.getLogger(LoggerName.ROCKETMQ_REMOTING_NAME);
Expand All @@ -97,6 +105,7 @@ public class NettyRemotingServer extends NettyRemotingAbstract implements Remoti
private final ConcurrentMap<Integer/*Port*/, NettyRemotingAbstract> remotingServerTable = new ConcurrentHashMap<>();

public static final String HA_PROXY_DECODER = "HAProxyDecoder";
public static final String HA_PROXY_HANDLER = "HAProxyHandler";
public static final String HANDSHAKE_HANDLER_NAME = "handshakeHandler";
public static final String TLS_HANDLER_NAME = "sslHandler";
public static final String FILE_REGION_ENCODER_NAME = "fileRegionEncoder";
Expand Down Expand Up @@ -252,7 +261,8 @@ public void run(Timeout timeout) {
*/
protected ChannelPipeline configChannel(SocketChannel ch) {
return ch.pipeline()
.addLast(this.getDefaultEventExecutorGroup(), HA_PROXY_DECODER, new HAProxyMessageDecoder())
.addLast(defaultEventExecutorGroup, HA_PROXY_DECODER, new HAProxyMessageDecoder())
.addLast(defaultEventExecutorGroup, HA_PROXY_HANDLER, new HAProxyMessageHandler())
.addLast(defaultEventExecutorGroup, HANDSHAKE_HANDLER_NAME, handshakeHandler)
.addLast(defaultEventExecutorGroup,
encoder,
Expand Down Expand Up @@ -709,4 +719,46 @@ public ExecutorService getCallbackExecutor() {
return NettyRemotingServer.this.getCallbackExecutor();
}
}

public static class HAProxyMessageHandler extends ChannelInboundHandlerAdapter {

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof HAProxyMessage) {
fillChannelWithMessage((HAProxyMessage) msg, ctx.channel());
} else {
super.channelRead(ctx, msg);
}
ctx.pipeline().remove(this);
}

/**
* The definition of key refers to the implementation of nginx
* <a href="https://nginx.org/en/docs/http/ngx_http_core_module.html#var_proxy_protocol_addr">ngx_http_core_module</a>
* @param msg
* @param channel
*/
private void fillChannelWithMessage(HAProxyMessage msg, Channel channel) {
if (StringUtils.isNotBlank(msg.sourceAddress())) {
channel.attr(AttributesConstants.PROXY_PROTOCOL_ADDR).set(msg.sourceAddress());
}
if (msg.sourcePort() > 0) {
channel.attr(AttributesConstants.PROXY_PROTOCOL_PORT).set(String.valueOf(msg.sourcePort()));
}
if (StringUtils.isNotBlank(msg.destinationAddress())) {
channel.attr(AttributesConstants.PROXY_PROTOCOL_SERVER_ADDR).set(msg.destinationAddress());
}
if (msg.destinationPort() > 0) {
channel.attr(AttributesConstants.PROXY_PROTOCOL_SERVER_PORT).set(String.valueOf(msg.destinationPort()));
}
if (CollectionUtils.isNotEmpty(msg.tlvs())) {
msg.tlvs().forEach(tlv -> {
AttributeKey<String> key = AttributeKey.valueOf(HAProxyConstants.PROXY_PROTOCOL_TLV_PREFIX
+ String.format("%02x", tlv.typeByteValue()));
String value = StringUtils.trim(tlv.content().toString(CharsetUtil.UTF_8));
channel.attr(key).set(value);
});
}
}
}
}

0 comments on commit c0b5837

Please sign in to comment.