Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void packetSent(Session session, Packet packet) {
public void connected(ConnectedEvent event) {
log.info("CLIENT Connected");

event.getSession().enableEncryption(((TestProtocol) event.getSession().getPacketProtocol()).getEncryption());
event.getSession().setEncryption(((TestProtocol) event.getSession().getPacketProtocol()).getEncryption());
event.getSession().send(new PingPacket("hello"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public void serverClosed(ServerClosedEvent event) {
public void sessionAdded(SessionAddedEvent event) {
log.info("SERVER Session Added: {}:{}", event.getSession().getHost(), event.getSession().getPort());
((TestProtocol) event.getSession().getPacketProtocol()).setSecretKey(this.key);
event.getSession().enableEncryption(((TestProtocol) event.getSession().getPacketProtocol()).getEncryption());
event.getSession().setEncryption(((TestProtocol) event.getSession().getPacketProtocol()).getEncryption());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import org.geysermc.mcprotocollib.network.codec.PacketDefinition;
import org.geysermc.mcprotocollib.network.codec.PacketSerializer;
import org.geysermc.mcprotocollib.network.crypt.AESEncryption;
import org.geysermc.mcprotocollib.network.crypt.PacketEncryption;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;
import org.geysermc.mcprotocollib.network.packet.DefaultPacketHeader;
import org.geysermc.mcprotocollib.network.packet.PacketHeader;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
Expand All @@ -23,7 +23,7 @@ public class TestProtocol extends PacketProtocol {
private static final Logger log = LoggerFactory.getLogger(TestProtocol.class);
private final PacketHeader header = new DefaultPacketHeader();
private final PacketRegistry registry = new PacketRegistry();
private AESEncryption encrypt;
private EncryptionConfig encrypt;

@SuppressWarnings("unused")
public TestProtocol() {
Expand Down Expand Up @@ -51,7 +51,7 @@ public PingPacket deserialize(ByteBuf buf, PacketCodecHelper helper, PacketDefin
});

try {
this.encrypt = new AESEncryption(key);
this.encrypt = new EncryptionConfig(new AESEncryption(key));
} catch (GeneralSecurityException e) {
log.error("Failed to create encryption", e);
}
Expand All @@ -67,7 +67,7 @@ public PacketHeader getPacketHeader() {
return this.header;
}

public PacketEncryption getEncryption() {
public EncryptionConfig getEncryption() {
return this.encrypt;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package org.geysermc.mcprotocollib.network;

import io.netty.util.AttributeKey;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;

public class NetworkConstants {
public static final AttributeKey<CompressionConfig> COMPRESSION_ATTRIBUTE_KEY = AttributeKey.valueOf("compression_threshold");
public static final AttributeKey<EncryptionConfig> ENCRYPTION_ATTRIBUTE_KEY = AttributeKey.valueOf("encryption");
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.geysermc.mcprotocollib.network.codec.PacketCodecHelper;
import org.geysermc.mcprotocollib.network.crypt.PacketEncryption;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;
import org.geysermc.mcprotocollib.network.event.session.SessionEvent;
import org.geysermc.mcprotocollib.network.event.session.SessionListener;
import org.geysermc.mcprotocollib.network.packet.Packet;
Expand Down Expand Up @@ -183,26 +184,21 @@ public interface Session {
void callPacketSent(Packet packet);

/**
* Gets the compression packet length threshold for this session (-1 = disabled).
* Sets the compression config for this session.
*
* @return This session's compression threshold.
* @param compressionConfig the compression to compress with,
* or null to disable compression
*/
int getCompressionThreshold();
void setCompression(CompressionConfig compressionConfig);

/**
* Sets the compression packet length threshold for this session (-1 = disabled).
* Sets encryption for this session.
*
* @param threshold The new compression threshold.
* @param validateDecompression whether to validate that the decompression fits within size checks.
*/
void setCompressionThreshold(int threshold, boolean validateDecompression);

/**
* Enables encryption for this session.
* @param encryptionConfig the encryption to encrypt with,
* or null to disable encryption
*
* @param encryption the encryption to encrypt with
*/
void enableEncryption(PacketEncryption encryption);
void setEncryption(EncryptionConfig encryptionConfig);

/**
* Returns true if the session is connected.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package org.geysermc.mcprotocollib.network.compression;

public record CompressionConfig(int threshold, PacketCompression compression, boolean validateDecompression) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package org.geysermc.mcprotocollib.network.crypt;

public record EncryptionConfig(PacketEncryption encryption) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ public void initChannel(Channel channel) {
pipeline.addLast("read-timeout", new ReadTimeoutHandler(getFlag(BuiltinFlags.READ_TIMEOUT, 30)));
pipeline.addLast("write-timeout", new WriteTimeoutHandler(getFlag(BuiltinFlags.WRITE_TIMEOUT, 0)));

pipeline.addLast("encryption", new TcpPacketEncryptor());
pipeline.addLast("sizer", new TcpPacketSizer(protocol.getPacketHeader(), getCodecHelper()));
pipeline.addLast("compression", new TcpPacketCompression(getCodecHelper()));

pipeline.addLast("codec", new TcpPacketCodec(TcpClientSession.this, true));
pipeline.addLast("manager", TcpClientSession.this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,60 +4,71 @@
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.MessageToMessageCodec;
import org.geysermc.mcprotocollib.network.Session;
import org.geysermc.mcprotocollib.network.compression.PacketCompression;
import lombok.RequiredArgsConstructor;
import org.geysermc.mcprotocollib.network.NetworkConstants;
import org.geysermc.mcprotocollib.network.codec.PacketCodecHelper;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;

import java.util.List;

@RequiredArgsConstructor
public class TcpPacketCompression extends MessageToMessageCodec<ByteBuf, ByteBuf> {
private static final int MAX_UNCOMPRESSED_SIZE = 8 * 1024 * 1024; // 8MiB

private final Session session;
private final PacketCompression compression;
private final boolean validateDecompression;

public TcpPacketCompression(Session session, PacketCompression compression, boolean validateDecompression) {
this.session = session;
this.compression = compression;
this.validateDecompression = validateDecompression;
}
private final PacketCodecHelper helper;

@Override
public void handlerRemoved(ChannelHandlerContext ctx) {
this.compression.close();
CompressionConfig config = ctx.channel().attr(NetworkConstants.COMPRESSION_ATTRIBUTE_KEY).get();
if (config == null) {
return;
}

config.compression().close();
}

@Override
public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {
CompressionConfig config = ctx.channel().attr(NetworkConstants.COMPRESSION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(msg.retain());
return;
}

int uncompressed = msg.readableBytes();
if (uncompressed > MAX_UNCOMPRESSED_SIZE) {
throw new IllegalArgumentException("Packet too big (is " + uncompressed + ", should be less than " + MAX_UNCOMPRESSED_SIZE + ")");
}

ByteBuf outBuf = ctx.alloc().directBuffer(uncompressed);
if (uncompressed < this.session.getCompressionThreshold()) {
if (uncompressed < config.threshold()) {
// Under the threshold, there is nothing to do.
this.session.getCodecHelper().writeVarInt(outBuf, 0);
this.helper.writeVarInt(outBuf, 0);
outBuf.writeBytes(msg);
} else {
this.session.getCodecHelper().writeVarInt(outBuf, uncompressed);
compression.deflate(msg, outBuf);
this.helper.writeVarInt(outBuf, uncompressed);
config.compression().deflate(msg, outBuf);
}

out.add(outBuf);
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
int claimedUncompressedSize = this.session.getCodecHelper().readVarInt(in);
CompressionConfig config = ctx.channel().attr(NetworkConstants.COMPRESSION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(in.retain());
return;
}

int claimedUncompressedSize = this.helper.readVarInt(in);
if (claimedUncompressedSize == 0) {
out.add(in.retain());
return;
}

if (validateDecompression) {
if (claimedUncompressedSize < this.session.getCompressionThreshold()) {
throw new DecoderException("Badly compressed packet - size of " + claimedUncompressedSize + " is below server threshold of " + this.session.getCompressionThreshold());
if (config.validateDecompression()) {
if (claimedUncompressedSize < config.threshold()) {
throw new DecoderException("Badly compressed packet - size of " + claimedUncompressedSize + " is below server threshold of " + config.threshold());
}

if (claimedUncompressedSize > MAX_UNCOMPRESSED_SIZE) {
Expand All @@ -67,7 +78,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {

ByteBuf uncompressed = ctx.alloc().directBuffer(claimedUncompressedSize);
try {
compression.inflate(in, uncompressed, claimedUncompressedSize);
config.compression().inflate(in, uncompressed, claimedUncompressedSize);
out.add(uncompressed);
} catch (Exception e) {
uncompressed.release();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,27 @@
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.MessageToMessageCodec;
import org.geysermc.mcprotocollib.network.crypt.PacketEncryption;
import org.geysermc.mcprotocollib.network.NetworkConstants;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;

import java.util.List;

public class TcpPacketEncryptor extends MessageToMessageCodec<ByteBuf, ByteBuf> {
private final PacketEncryption encryption;

public TcpPacketEncryptor(PacketEncryption encryption) {
this.encryption = encryption;
}

@Override
public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {
EncryptionConfig config = ctx.channel().attr(NetworkConstants.ENCRYPTION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(msg.retain());
return;
}

ByteBuf heapBuf = this.ensureHeapBuffer(ctx.alloc(), msg);

int inBytes = heapBuf.readableBytes();
int baseOffset = heapBuf.arrayOffset() + heapBuf.readerIndex();

try {
encryption.encrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
config.encryption().encrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
out.add(heapBuf);
} catch (Exception e) {
heapBuf.release();
Expand All @@ -35,13 +36,19 @@ public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
EncryptionConfig config = ctx.channel().attr(NetworkConstants.ENCRYPTION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(in.retain());
return;
}

ByteBuf heapBuf = this.ensureHeapBuffer(ctx.alloc(), in).slice();

int inBytes = heapBuf.readableBytes();
int baseOffset = heapBuf.arrayOffset() + heapBuf.readerIndex();

try {
encryption.decrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
config.encryption().decrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
out.add(heapBuf);
} catch (Exception e) {
heapBuf.release();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ public void initChannel(Channel channel) {
pipeline.addLast("read-timeout", new ReadTimeoutHandler(session.getFlag(BuiltinFlags.READ_TIMEOUT, 30)));
pipeline.addLast("write-timeout", new WriteTimeoutHandler(session.getFlag(BuiltinFlags.WRITE_TIMEOUT, 0)));

pipeline.addLast("encryption", new TcpPacketEncryptor());
pipeline.addLast("sizer", new TcpPacketSizer(protocol.getPacketHeader(), session.getCodecHelper()));
pipeline.addLast("compression", new TcpPacketCompression(session.getCodecHelper()));

pipeline.addLast("codec", new TcpPacketCodec(session, false));
pipeline.addLast("manager", session);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.geysermc.mcprotocollib.network.Flag;
import org.geysermc.mcprotocollib.network.NetworkConstants;
import org.geysermc.mcprotocollib.network.Session;
import org.geysermc.mcprotocollib.network.compression.ZlibCompression;
import org.geysermc.mcprotocollib.network.crypt.PacketEncryption;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;
import org.geysermc.mcprotocollib.network.event.session.ConnectedEvent;
import org.geysermc.mcprotocollib.network.event.session.DisconnectedEvent;
import org.geysermc.mcprotocollib.network.event.session.DisconnectingEvent;
Expand Down Expand Up @@ -47,8 +48,6 @@ public abstract class TcpSession extends SimpleChannelInboundHandler<Packet> imp
private final PacketProtocol protocol;
private final EventLoop eventLoop = createEventLoop();

private int compressionThreshold = -1;

private final Map<String, Object> flags = new HashMap<>();
private final List<SessionListener> listeners = new CopyOnWriteArrayList<>();

Expand Down Expand Up @@ -188,31 +187,21 @@ public void callPacketSent(Packet packet) {
}

@Override
public int getCompressionThreshold() {
return this.compressionThreshold;
}

@Override
public void setCompressionThreshold(int threshold, boolean validateDecompression) {
this.compressionThreshold = threshold;
if (this.channel != null) {
if (this.compressionThreshold >= 0) {
if (this.channel.pipeline().get("compression") == null) {
this.channel.pipeline().addBefore("codec", "compression",
new TcpPacketCompression(this, new ZlibCompression(), validateDecompression));
}
} else if (this.channel.pipeline().get("compression") != null) {
this.channel.pipeline().remove("compression");
}
public void setCompression(CompressionConfig compressionConfig) {
if (this.channel == null) {
throw new IllegalStateException("You need to be connected to set the compression!");
}

channel.attr(NetworkConstants.COMPRESSION_ATTRIBUTE_KEY).set(compressionConfig);
}

@Override
public void enableEncryption(PacketEncryption encryption) {
public void setEncryption(EncryptionConfig encryptionConfig) {
if (channel == null) {
throw new IllegalStateException("Connect the client before initializing encryption!");
throw new IllegalStateException("You need to connect to enable encryption!");
}
channel.pipeline().addBefore("sizer", "encryption", new TcpPacketEncryptor(encryption));

channel.attr(NetworkConstants.ENCRYPTION_ATTRIBUTE_KEY).set(encryptionConfig);
}

@Override
Expand Down Expand Up @@ -267,7 +256,7 @@ public void disconnect(@NonNull Component reason, @Nullable Throwable cause) {
// daemon threads and their interaction with the runtime.
PACKET_EVENT_LOOP = new DefaultEventLoopGroup(new DefaultThreadFactory(this.getClass(), true));
Runtime.getRuntime().addShutdownHook(new Thread(
() -> PACKET_EVENT_LOOP.shutdownGracefully(SHUTDOWN_QUIET_PERIOD_MS, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)));
() -> PACKET_EVENT_LOOP.shutdownGracefully(SHUTDOWN_QUIET_PERIOD_MS, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)));
}
return PACKET_EVENT_LOOP.next();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import org.geysermc.mcprotocollib.auth.GameProfile;
import org.geysermc.mcprotocollib.auth.SessionService;
import org.geysermc.mcprotocollib.network.Session;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;
import org.geysermc.mcprotocollib.network.compression.ZlibCompression;
import org.geysermc.mcprotocollib.network.event.session.ConnectedEvent;
import org.geysermc.mcprotocollib.network.event.session.SessionAdapter;
import org.geysermc.mcprotocollib.network.packet.Packet;
Expand Down Expand Up @@ -91,13 +93,14 @@ public void packetReceived(Session session, Packet packet) {
}

session.send(new ServerboundKeyPacket(helloPacket.getPublicKey(), key, helloPacket.getChallenge()));
session.enableEncryption(protocol.enableEncryption(key));
session.setEncryption(protocol.enableEncryption(key));
} else if (packet instanceof ClientboundGameProfilePacket) {
session.send(new ServerboundLoginAcknowledgedPacket());
} else if (packet instanceof ClientboundLoginDisconnectPacket loginDisconnectPacket) {
session.disconnect(loginDisconnectPacket.getReason());
} else if (packet instanceof ClientboundLoginCompressionPacket loginCompressionPacket) {
session.setCompressionThreshold(loginCompressionPacket.getThreshold(), false);
session.setCompression(loginCompressionPacket.getThreshold() >= 0 ?
new CompressionConfig(loginCompressionPacket.getThreshold(), new ZlibCompression(), false) : null);
}
} else if (protocol.getState() == ProtocolState.STATUS) {
if (packet instanceof ClientboundStatusResponsePacket statusResponsePacket) {
Expand Down
Loading
Loading