From 76fef20ff6fb64e42f8f0e1c5098fb9bce3ead6d Mon Sep 17 00:00:00 2001 From: hengyunabc Date: Wed, 2 Aug 2023 16:00:58 +0800 Subject: [PATCH] SimpleHttpResponse adds deserialization whitelist --- tunnel-common/pom.xml | 12 +++- .../tunnel/common/SimpleHttpResponse.java | 37 +++++------- .../tunnel/common/SimpleHttpResponseTest.java | 59 +++++++++++++++++++ 3 files changed, 85 insertions(+), 23 deletions(-) create mode 100644 tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java diff --git a/tunnel-common/pom.xml b/tunnel-common/pom.xml index 8a80f02708..b998028336 100644 --- a/tunnel-common/pom.xml +++ b/tunnel-common/pom.xml @@ -12,8 +12,16 @@ https://github.com/alibaba/arthas - - + + junit + junit + test + + + org.assertj + assertj-core + test + diff --git a/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java b/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java index bdaacc20c8..302216b33b 100644 --- a/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java +++ b/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java @@ -3,11 +3,14 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.ObjectInput; +import java.io.InvalidClassException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; import java.io.Serializable; +import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; /** @@ -16,9 +19,11 @@ * */ public class SimpleHttpResponse implements Serializable { - private static final long serialVersionUID = 1L; + private static final List whitelist = Arrays.asList(byte[].class.getName(), String.class.getName(), + Map.class.getName(), HashMap.class.getName(), SimpleHttpResponse.class.getName()); + private int status = 200; private Map headers = new HashMap(); @@ -55,35 +60,25 @@ public void setStatus(int status) { public static byte[] toBytes(SimpleHttpResponse response) throws IOException { ByteArrayOutputStream bos = new ByteArrayOutputStream(); - ObjectOutputStream out = null; - try { - out = new ObjectOutputStream(bos); + try (ObjectOutputStream out = new ObjectOutputStream(bos)) { out.writeObject(response); out.flush(); return bos.toByteArray(); - } finally { - try { - bos.close(); - } catch (IOException ex) { - // ignore close exception - } } } public static SimpleHttpResponse fromBytes(byte[] bytes) throws IOException, ClassNotFoundException { ByteArrayInputStream bis = new ByteArrayInputStream(bytes); - ObjectInput in = null; - try { - in = new ObjectInputStream(bis); - return (SimpleHttpResponse) in.readObject(); - } finally { - try { - if (in != null) { - in.close(); + try (ObjectInputStream in = new ObjectInputStream(bis) { + protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { + if (!whitelist.contains(desc.getName())) { + throw new InvalidClassException("Unauthorized deserialization attempt", desc.getName()); } - } catch (IOException ex) { - // ignore close exception + return super.resolveClass(desc); } + }) { + return (SimpleHttpResponse) in.readObject(); } } + } diff --git a/tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java b/tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java new file mode 100644 index 0000000000..477b7203aa --- /dev/null +++ b/tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java @@ -0,0 +1,59 @@ +package com.alibaba.arthas.tunnel.common; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InvalidClassException; +import java.io.ObjectOutputStream; +import java.util.Date; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; + +public class SimpleHttpResponseTest { + + @Test + public void testSerialization() throws IOException, ClassNotFoundException { + SimpleHttpResponse response = new SimpleHttpResponse(); + response.setStatus(200); + + Map headers = new HashMap(); + headers.put("Content-Type", "text/plain"); + response.setHeaders(headers); + + String content = "Hello, world!"; + response.setContent(content.getBytes()); + + byte[] bytes = SimpleHttpResponse.toBytes(response); + + SimpleHttpResponse deserializedResponse = SimpleHttpResponse.fromBytes(bytes); + + assertEquals(response.getStatus(), deserializedResponse.getStatus()); + assertEquals(response.getHeaders(), deserializedResponse.getHeaders()); + assertArrayEquals(response.getContent(), deserializedResponse.getContent()); + } + + private static byte[] toBytes(Object object) throws IOException { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try (ObjectOutputStream out = new ObjectOutputStream(bos)) { + out.writeObject(object); + out.flush(); + return bos.toByteArray(); + } + } + + @Test(expected = InvalidClassException.class) + public void testDeserializationWithUnauthorizedClass() throws IOException, ClassNotFoundException { + Date date = new Date(); + + byte[] bytes = toBytes(date); + + // Try to deserialize the object with an unauthorized class + // This should throw an InvalidClassException + SimpleHttpResponse.fromBytes(bytes); + } + +}