Skip to content

Commit

Permalink
[tuya] Improve Netty handlers (#564)
Browse files Browse the repository at this point in the history
* [tuya] Improve Netty handlers

Signed-off-by: Jan N. Klug <[email protected]>
  • Loading branch information
J-N-K committed Jan 30, 2024
1 parent a4c5f6c commit b9fc33c
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.AttributeKey;

/**
* The {@link TuyaDevice} handles the device connection
Expand All @@ -57,22 +58,27 @@
*/
@NonNullByDefault
public class TuyaDevice implements ChannelFutureListener {
public static final AttributeKey<String> DEVICE_ID_ATTR = AttributeKey.valueOf("deviceId");
public static final AttributeKey<ProtocolVersion> PROTOCOL_ATTR = AttributeKey.valueOf("protocol");
public static final AttributeKey<byte[]> SESSION_RANDOM_ATTR = AttributeKey.valueOf("sessionRandom");
public static final AttributeKey<byte[]> SESSION_KEY_ATTR = AttributeKey.valueOf("sessionKey");

private final Logger logger = LoggerFactory.getLogger(TuyaDevice.class);

private final Bootstrap bootstrap = new Bootstrap();
private final DeviceStatusListener deviceStatusListener;
private final String deviceId;
private final byte[] deviceKey;

private final String address;
private final ProtocolVersion protocolVersion;
private final KeyStore keyStore;
private @Nullable Channel channel;

public TuyaDevice(Gson gson, DeviceStatusListener deviceStatusListener, EventLoopGroup eventLoopGroup,
String deviceId, byte[] deviceKey, String address, String protocolVersion) {
this.address = address;
this.deviceId = deviceId;
this.keyStore = new KeyStore(deviceKey);
this.deviceKey = deviceKey;
this.deviceStatusListener = deviceStatusListener;
this.protocolVersion = ProtocolVersion.fromString(protocolVersion);
bootstrap.group(eventLoopGroup).channel(NioSocketChannel.class);
Expand All @@ -83,20 +89,17 @@ protected void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast("idleStateHandler",
new IdleStateHandler(TCP_CONNECTION_TIMEOUT, TCP_CONNECTION_HEARTBEAT_INTERVAL, 0));
pipeline.addLast("messageEncoder",
new TuyaEncoder(gson, deviceId, keyStore, TuyaDevice.this.protocolVersion));
pipeline.addLast("messageDecoder",
new TuyaDecoder(gson, deviceId, keyStore, TuyaDevice.this.protocolVersion));
pipeline.addLast("heartbeatHandler", new HeartbeatHandler(deviceId));
pipeline.addLast("deviceHandler", new TuyaMessageHandler(deviceId, keyStore, deviceStatusListener));
pipeline.addLast("userEventHandler", new UserEventHandler(deviceId));
pipeline.addLast("messageEncoder", new TuyaEncoder(gson));
pipeline.addLast("messageDecoder", new TuyaDecoder(gson));
pipeline.addLast("heartbeatHandler", new HeartbeatHandler());
pipeline.addLast("deviceHandler", new TuyaMessageHandler(deviceStatusListener));
pipeline.addLast("userEventHandler", new UserEventHandler());
}
});
connect();
}

public void connect() {
keyStore.reset(); // reset session key
bootstrap.connect(address, 6668).addListener(this);
}

Expand Down Expand Up @@ -147,12 +150,22 @@ public void dispose() {
public void operationComplete(@NonNullByDefault({}) ChannelFuture channelFuture) throws Exception {
if (channelFuture.isSuccess()) {
Channel channel = channelFuture.channel();
this.channel = channel;
channel.attr(DEVICE_ID_ATTR).set(deviceId);
channel.attr(PROTOCOL_ATTR).set(protocolVersion);
// session key is device key before negotiation
channel.attr(SESSION_KEY_ATTR).set(deviceKey);

if (protocolVersion == V3_4) {
byte[] sessionRandom = CryptoUtil.generateRandom(16);
channel.attr(SESSION_RANDOM_ATTR).set(sessionRandom);
this.channel = channel;

// handshake for session key required
MessageWrapper<?> m = new MessageWrapper<>(SESS_KEY_NEG_START, keyStore.getRandom());
MessageWrapper<?> m = new MessageWrapper<>(SESS_KEY_NEG_START, sessionRandom);
channel.writeAndFlush(m);
} else {
this.channel = channel;

// no handshake for 3.1/3.3
requestStatus();
}
Expand All @@ -164,37 +177,4 @@ public void operationComplete(@NonNullByDefault({}) ChannelFuture channelFuture)
deviceStatusListener.connectionStatus(false);
}
}

public static class KeyStore {
private final byte[] deviceKey;
private byte[] sessionKey;
private byte[] random;

public KeyStore(byte[] deviceKey) {
this.deviceKey = deviceKey;
this.sessionKey = deviceKey;
this.random = CryptoUtil.generateRandom(16).clone();
}

public void reset() {
this.sessionKey = this.deviceKey;
this.random = CryptoUtil.generateRandom(16).clone();
}

public byte[] getDeviceKey() {
return sessionKey;
}

public byte[] getSessionKey() {
return sessionKey;
}

public void setSessionKey(byte[] sessionKey) {
this.sessionKey = sessionKey;
}

public byte[] getRandom() {
return random;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,25 @@ private void activate() {
protected void initChannel(DatagramChannel ch) throws Exception {
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast("udpDecoder", new DatagramToByteBufDecoder());
pipeline.addLast("messageDecoder", new TuyaDecoder(gson, "udpListener",
new TuyaDevice.KeyStore(TUYA_UDP_KEY), ProtocolVersion.V3_1));
pipeline.addLast("messageDecoder", new TuyaDecoder(gson));
pipeline.addLast("discoveryHandler",
new DiscoveryMessageHandler(deviceInfos, deviceListeners));
pipeline.addLast("userEventHandler", new UserEventHandler("udpListener"));
pipeline.addLast("userEventHandler", new UserEventHandler());
}
});

ChannelFuture futureEncrypted = b.bind(6667).addListener(this).sync();
encryptedChannel = futureEncrypted.channel();
encryptedChannel.attr(TuyaDevice.DEVICE_ID_ATTR).set("udpListener");
encryptedChannel.attr(TuyaDevice.PROTOCOL_ATTR).set(ProtocolVersion.V3_1);
encryptedChannel.attr(TuyaDevice.SESSION_KEY_ATTR).set(TUYA_UDP_KEY);

ChannelFuture futureRaw = b.bind(6666).addListener(this).sync();
rawChannel = futureRaw.channel();
rawChannel.attr(TuyaDevice.DEVICE_ID_ATTR).set("udpListener");
rawChannel.attr(TuyaDevice.PROTOCOL_ATTR).set(ProtocolVersion.V3_1);
rawChannel.attr(TuyaDevice.SESSION_KEY_ATTR).set(TUYA_UDP_KEY);

} catch (InterruptedException e) {
throw new IllegalStateException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.slf4j.LoggerFactory;
import org.smarthomej.binding.tuya.internal.local.CommandType;
import org.smarthomej.binding.tuya.internal.local.MessageWrapper;
import org.smarthomej.binding.tuya.internal.local.TuyaDevice;

import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
Expand All @@ -37,18 +38,19 @@
@NonNullByDefault
public class HeartbeatHandler extends ChannelDuplexHandler {
private final Logger logger = LoggerFactory.getLogger(HeartbeatHandler.class);
private final String deviceId;
private int heartBeatMissed = 0;

public HeartbeatHandler(String deviceId) {
this.deviceId = deviceId;
}

@Override
public void userEventTriggered(@NonNullByDefault({}) ChannelHandlerContext ctx, @NonNullByDefault({}) Object evt)
throws Exception {
if (evt instanceof IdleStateEvent) {
IdleStateEvent e = (IdleStateEvent) evt;
if (!ctx.channel().hasAttr(TuyaDevice.DEVICE_ID_ATTR)) {
logger.warn("{}: Failed to retrieve deviceId from ChannelHandlerContext. This is a bug.",
Objects.requireNonNullElse(ctx.channel().remoteAddress(), ""));
return;
}
String deviceId = ctx.channel().attr(TuyaDevice.DEVICE_ID_ATTR).get();

if (evt instanceof IdleStateEvent e) {
if (IdleState.READER_IDLE.equals(e.state())) {
logger.warn("{}{}: Did not receive a message from for {} seconds. Connection seems to be dead.",
deviceId, Objects.requireNonNullElse(ctx.channel().remoteAddress(), ""),
Expand All @@ -75,8 +77,14 @@ public void userEventTriggered(@NonNullByDefault({}) ChannelHandlerContext ctx,
@Override
public void channelRead(@NonNullByDefault({}) ChannelHandlerContext ctx, @NonNullByDefault({}) Object msg)
throws Exception {
if (msg instanceof MessageWrapper<?>) {
MessageWrapper<?> m = (MessageWrapper<?>) msg;
if (!ctx.channel().hasAttr(TuyaDevice.DEVICE_ID_ATTR)) {
logger.warn("{}: Failed to retrieve deviceId from ChannelHandlerContext. This is a bug.",
Objects.requireNonNullElse(ctx.channel().remoteAddress(), ""));
return;
}
String deviceId = ctx.channel().attr(TuyaDevice.DEVICE_ID_ATTR).get();

if (msg instanceof MessageWrapper<?> m) {
if (CommandType.HEART_BEAT.equals(m.commandType)) {
logger.trace("{}{}: Received pong", deviceId,
Objects.requireNonNullElse(ctx.channel().remoteAddress(), ""));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.smarthomej.binding.tuya.internal.local.CommandType.UDP_NEW;
import static org.smarthomej.binding.tuya.internal.local.ProtocolVersion.V3_3;
import static org.smarthomej.binding.tuya.internal.local.ProtocolVersion.V3_4;
import static org.smarthomej.binding.tuya.internal.local.TuyaDevice.*;

import java.nio.ByteBuffer;
import java.util.Arrays;
Expand All @@ -35,7 +36,6 @@
import org.smarthomej.binding.tuya.internal.local.CommandType;
import org.smarthomej.binding.tuya.internal.local.MessageWrapper;
import org.smarthomej.binding.tuya.internal.local.ProtocolVersion;
import org.smarthomej.binding.tuya.internal.local.TuyaDevice;
import org.smarthomej.binding.tuya.internal.local.dto.DiscoveryMessage;
import org.smarthomej.binding.tuya.internal.local.dto.TcpStatusPayload;
import org.smarthomej.binding.tuya.internal.util.CryptoUtil;
Expand All @@ -58,16 +58,10 @@
public class TuyaDecoder extends ByteToMessageDecoder {
private final Logger logger = LoggerFactory.getLogger(TuyaDecoder.class);

private final TuyaDevice.KeyStore keyStore;
private final ProtocolVersion version;
private final Gson gson;
private final String deviceId;

public TuyaDecoder(Gson gson, String deviceId, TuyaDevice.KeyStore keyStore, ProtocolVersion version) {
public TuyaDecoder(Gson gson) {
this.gson = gson;
this.keyStore = keyStore;
this.version = version;
this.deviceId = deviceId;
}

@Override
Expand All @@ -78,6 +72,17 @@ public void decode(@NonNullByDefault({}) ChannelHandlerContext ctx, @NonNullByDe
return;
}

if (!ctx.channel().hasAttr(DEVICE_ID_ATTR) || !ctx.channel().hasAttr(PROTOCOL_ATTR)
|| !ctx.channel().hasAttr(SESSION_KEY_ATTR)) {
logger.warn(
"{}: Failed to retrieve deviceId, protocol or sessionKey from ChannelHandlerContext. This is a bug.",
Objects.requireNonNullElse(ctx.channel().remoteAddress(), ""));
return;
}
String deviceId = ctx.channel().attr(DEVICE_ID_ATTR).get();
ProtocolVersion protocol = ctx.channel().attr(PROTOCOL_ATTR).get();
byte[] sessionKey = ctx.channel().attr(SESSION_KEY_ATTR).get();

// we need to take a copy first so the buffer stays intact if we exit early
ByteBuf inCopy = in.copy();
byte[] bytes = new byte[inCopy.readableBytes()];
Expand Down Expand Up @@ -111,20 +116,20 @@ public void decode(@NonNullByDefault({}) ChannelHandlerContext ctx, @NonNullByDe
if ((returnCode & 0xffffff00) != 0) {
// rewind if no return code is present
buffer.position(buffer.position() - 4);
payload = version == V3_4 ? new byte[payloadLength - 32] : new byte[payloadLength - 8];
payload = protocol == V3_4 ? new byte[payloadLength - 32] : new byte[payloadLength - 8];
} else {
payload = version == V3_4 ? new byte[payloadLength - 32 - 8] : new byte[payloadLength - 8 - 4];
payload = protocol == V3_4 ? new byte[payloadLength - 32 - 8] : new byte[payloadLength - 8 - 4];
}

buffer.get(payload);

if (version == V3_4 && commandType != UDP && commandType != UDP_NEW) {
if (protocol == V3_4 && commandType != UDP && commandType != UDP_NEW) {
byte[] fullMessage = new byte[buffer.position()];
buffer.position(0);
buffer.get(fullMessage);
byte[] expectedHmac = new byte[32];
buffer.get(expectedHmac);
byte[] calculatedHmac = CryptoUtil.hmac(fullMessage, keyStore.getSessionKey());
byte[] calculatedHmac = CryptoUtil.hmac(fullMessage, sessionKey);
if (!Arrays.equals(expectedHmac, calculatedHmac)) {
logger.warn("{}{}: Checksum failed for message: calculated {}, found {}", deviceId,
Objects.requireNonNullElse(ctx.channel().remoteAddress(), ""),
Expand All @@ -150,8 +155,8 @@ public void decode(@NonNullByDefault({}) ChannelHandlerContext ctx, @NonNullByDe
return;
}

if (Arrays.equals(Arrays.copyOfRange(payload, 0, version.getBytes().length), version.getBytes())) {
if (version == V3_3) {
if (Arrays.equals(Arrays.copyOfRange(payload, 0, protocol.getBytes().length), protocol.getBytes())) {
if (protocol == V3_3) {
// Remove 3.3 header
payload = Arrays.copyOfRange(payload, 15, payload.length);
} else {
Expand All @@ -165,13 +170,14 @@ public void decode(@NonNullByDefault({}) ChannelHandlerContext ctx, @NonNullByDe
m = new MessageWrapper<>(commandType,
Objects.requireNonNull(gson.fromJson(new String(payload), DiscoveryMessage.class)));
} else {
byte[] decodedMessage = version == V3_4 ? CryptoUtil.decryptAesEcb(payload, keyStore.getSessionKey(), true)
: CryptoUtil.decryptAesEcb(payload, keyStore.getDeviceKey(), false);
byte[] decodedMessage = protocol == V3_4 ? CryptoUtil.decryptAesEcb(payload, sessionKey, true)
: CryptoUtil.decryptAesEcb(payload, sessionKey, false);
if (decodedMessage == null) {
return;
}
if (Arrays.equals(Arrays.copyOfRange(decodedMessage, 0, version.getBytes().length), version.getBytes())) {
if (version == V3_4) {

if (Arrays.equals(Arrays.copyOfRange(decodedMessage, 0, protocol.getBytes().length), protocol.getBytes())) {
if (protocol == V3_4) {
// Remove 3.4 header
decodedMessage = Arrays.copyOfRange(decodedMessage, 15, decodedMessage.length);
}
Expand Down
Loading

0 comments on commit b9fc33c

Please sign in to comment.