Skip to content

Commit 9733e07

Browse files
authored
Merge pull request #4 from scalableminds/fix-zstd-codec
Fix zstd codec
2 parents e51ac9d + 29541d5 commit 9733e07

File tree

5 files changed

+152
-99
lines changed

5 files changed

+152
-99
lines changed

src/main/java/dev/zarr/zarrjava/v3/codec/core/ZstdCodec.java

+50-65
Original file line numberDiff line numberDiff line change
@@ -2,89 +2,74 @@
22

33
import com.fasterxml.jackson.annotation.JsonCreator;
44
import com.fasterxml.jackson.annotation.JsonProperty;
5-
import com.github.luben.zstd.ZstdInputStream;
6-
import com.github.luben.zstd.ZstdOutputStream;
5+
import com.github.luben.zstd.Zstd;
6+
import com.github.luben.zstd.ZstdCompressCtx;
77
import dev.zarr.zarrjava.ZarrException;
8-
import dev.zarr.zarrjava.utils.Utils;
98
import dev.zarr.zarrjava.v3.ArrayMetadata;
109
import dev.zarr.zarrjava.v3.codec.BytesBytesCodec;
11-
import java.io.ByteArrayInputStream;
12-
import java.io.ByteArrayOutputStream;
13-
import java.io.IOException;
14-
import java.io.InputStream;
15-
import java.io.OutputStream;
16-
import java.nio.ByteBuffer;
10+
1711
import javax.annotation.Nonnull;
12+
import java.nio.ByteBuffer;
1813

1914
public class ZstdCodec extends BytesBytesCodec {
2015

21-
public final String name = "zstd";
22-
@Nonnull
23-
public final Configuration configuration;
16+
public final String name = "zstd";
17+
@Nonnull
18+
public final Configuration configuration;
2419

25-
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
26-
public ZstdCodec(
27-
@Nonnull @JsonProperty(value = "configuration", required = true) Configuration configuration) {
28-
this.configuration = configuration;
29-
}
30-
31-
private void copy(InputStream inputStream, OutputStream outputStream) throws IOException {
32-
byte[] buffer = new byte[4096];
33-
int len;
34-
while ((len = inputStream.read(buffer)) > 0) {
35-
outputStream.write(buffer, 0, len);
20+
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
21+
public ZstdCodec(
22+
@Nonnull @JsonProperty(value = "configuration", required = true) Configuration configuration) {
23+
this.configuration = configuration;
3624
}
37-
}
3825

39-
@Override
40-
public ByteBuffer decode(ByteBuffer chunkBytes)
41-
throws ZarrException {
42-
try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); ZstdInputStream inputStream = new ZstdInputStream(
43-
new ByteArrayInputStream(Utils.toArray(chunkBytes)))) {
44-
copy(inputStream, outputStream);
45-
inputStream.close();
46-
return ByteBuffer.wrap(outputStream.toByteArray());
47-
} catch (IOException ex) {
48-
throw new ZarrException("Error in decoding zstd.", ex);
26+
@Override
27+
public ByteBuffer decode(ByteBuffer compressedBytes) throws ZarrException {
28+
byte[] compressedArray = compressedBytes.array();
29+
30+
long originalSize = Zstd.decompressedSize(compressedArray);
31+
if (originalSize == 0) {
32+
throw new ZarrException("Failed to get decompressed size");
33+
}
34+
35+
byte[] decompressed = Zstd.decompress(compressedArray, (int) originalSize);
36+
return ByteBuffer.wrap(decompressed);
4937
}
50-
}
5138

52-
@Override
53-
public ByteBuffer encode(ByteBuffer chunkBytes)
54-
throws ZarrException {
55-
try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); ZstdOutputStream zstdStream = new ZstdOutputStream(
56-
outputStream, configuration.level).setChecksum(
57-
configuration.checksum)) {
58-
zstdStream.write(Utils.toArray(chunkBytes));
59-
zstdStream.close();
60-
return ByteBuffer.wrap(outputStream.toByteArray());
61-
} catch (IOException ex) {
62-
throw new ZarrException("Error in encoding zstd.", ex);
39+
@Override
40+
public ByteBuffer encode(ByteBuffer chunkBytes) throws ZarrException {
41+
byte[] arr = chunkBytes.array();
42+
byte[] compressed;
43+
try (ZstdCompressCtx ctx = new ZstdCompressCtx()) {
44+
ctx.setLevel(configuration.level);
45+
ctx.setChecksum(configuration.checksum);
46+
compressed = ctx.compress(arr);
47+
}
48+
return ByteBuffer.wrap(compressed);
6349
}
64-
}
6550

66-
@Override
67-
public long computeEncodedSize(long inputByteLength,
68-
ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException {
69-
throw new ZarrException("Not implemented for Zstd codec.");
70-
}
51+
@Override
52+
public long computeEncodedSize(long inputByteLength,
53+
ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException {
54+
throw new ZarrException("Not implemented for Zstd codec.");
55+
}
7156

72-
public static final class Configuration {
57+
public static final class Configuration {
7358

74-
public final int level;
75-
public final boolean checksum;
59+
public final int level;
60+
public final boolean checksum;
7661

77-
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
78-
public Configuration(@JsonProperty(value = "level", defaultValue = "5") int level,
79-
@JsonProperty(value = "checksum", defaultValue = "true") boolean checksum)
80-
throws ZarrException {
81-
if (level < -131072 || level > 22) {
82-
throw new ZarrException("'level' needs to be between -131072 and 22.");
83-
}
84-
this.level = level;
85-
this.checksum = checksum;
62+
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
63+
public Configuration(@JsonProperty(value = "level", defaultValue = "5") int level,
64+
@JsonProperty(value = "checksum", defaultValue = "true") boolean checksum)
65+
throws ZarrException {
66+
if (level < -131072 || level > 22) {
67+
throw new ZarrException("'level' needs to be between -131072 and 22.");
68+
}
69+
this.level = level;
70+
this.checksum = checksum;
71+
}
8672
}
87-
}
8873
}
8974

9075

src/test/java/dev/zarr/zarrjava/ZarrTest.java

+87-32
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,26 @@
44
import com.amazonaws.auth.AnonymousAWSCredentials;
55
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
66
import com.fasterxml.jackson.databind.ObjectMapper;
7+
import com.github.luben.zstd.Zstd;
8+
import com.github.luben.zstd.ZstdCompressCtx;
79
import dev.zarr.zarrjava.store.FilesystemStore;
810
import dev.zarr.zarrjava.store.HttpStore;
911
import dev.zarr.zarrjava.store.S3Store;
1012
import dev.zarr.zarrjava.store.StoreHandle;
1113
import dev.zarr.zarrjava.utils.MultiArrayUtils;
1214
import dev.zarr.zarrjava.v3.*;
15+
import dev.zarr.zarrjava.v3.codec.CodecBuilder;
1316
import dev.zarr.zarrjava.v3.codec.core.TransposeCodec;
1417
import org.junit.jupiter.api.Assertions;
1518
import org.junit.jupiter.api.BeforeAll;
1619
import org.junit.jupiter.api.Test;
1720
import org.junit.jupiter.params.ParameterizedTest;
21+
import org.junit.jupiter.params.provider.CsvSource;
1822
import org.junit.jupiter.params.provider.ValueSource;
23+
import ucar.ma2.MAMath;
1924

20-
import java.io.BufferedReader;
21-
import java.io.File;
22-
import java.io.IOException;
23-
import java.io.InputStreamReader;
25+
import java.io.*;
26+
import java.nio.ByteBuffer;
2427
import java.nio.file.Files;
2528
import java.nio.file.Path;
2629
import java.nio.file.Paths;
@@ -35,8 +38,7 @@ public class ZarrTest {
3538

3639
final static Path TESTDATA = Paths.get("testdata");
3740
final static Path TESTOUTPUT = Paths.get("testoutput");
38-
final static Path ZARRITA_WRITE_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/zarrita_write.py");
39-
final static Path ZARRITA_READ_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/zarrita_read.py");
41+
final static Path PYTHON_TEST_PATH = Paths.get("src/test/python-scripts/");
4042

4143
public static String pythonPath() {
4244
if (System.getProperty("os.name").startsWith("Windows")) {
@@ -60,7 +62,7 @@ public static void clearTestoutputFolder() throws IOException {
6062
public void testReadFromZarrita(String codec) throws IOException, ZarrException, InterruptedException {
6163

6264
String command = pythonPath();
63-
ProcessBuilder pb = new ProcessBuilder(command, ZARRITA_WRITE_PATH.toString(), codec, TESTOUTPUT.toString());
65+
ProcessBuilder pb = new ProcessBuilder(command, PYTHON_TEST_PATH.resolve("zarrita_write.py").toString(), codec, TESTOUTPUT.toString());
6466
Process process = pb.start();
6567

6668
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
@@ -91,10 +93,43 @@ public void testReadFromZarrita(String codec) throws IOException, ZarrException,
9193
Assertions.assertArrayEquals(expectedData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT));
9294
}
9395

96+
@CsvSource({"0,true", "0,false", "5, true", "10, false"})
97+
@ParameterizedTest
98+
public void testZstdLibrary(int clevel, boolean checksumFlag) throws IOException, InterruptedException {
99+
//compress using ZstdCompressCtx
100+
int number = 123456;
101+
byte[] src = ByteBuffer.allocate(4).putInt(number).array();
102+
byte[] compressed;
103+
try (ZstdCompressCtx ctx = new ZstdCompressCtx()) {
104+
ctx.setLevel(clevel);
105+
ctx.setChecksum(checksumFlag);
106+
compressed = ctx.compress(src);
107+
}
108+
//decompress with Zstd.decompress
109+
long originalSize = Zstd.decompressedSize(compressed);
110+
byte[] decompressed = Zstd.decompress(compressed, (int) originalSize);
111+
Assertions.assertEquals(number, ByteBuffer.wrap(decompressed).getInt());
112+
113+
//write compressed to file
114+
String compressedDataPath = TESTOUTPUT.resolve("compressed" + clevel + checksumFlag + ".bin").toString();
115+
try (FileOutputStream fos = new FileOutputStream(compressedDataPath)) {
116+
fos.write(compressed);
117+
}
118+
119+
//decompress in python
120+
Process process = new ProcessBuilder(
121+
pythonPath(),
122+
PYTHON_TEST_PATH.resolve("zstd_decompress.py").toString(),
123+
compressedDataPath,
124+
Integer.toString(number)
125+
).start();
126+
int exitCode = process.waitFor();
127+
assert exitCode == 0;
128+
}
129+
94130
//TODO: add crc32c
95-
//Disabled "zstd": known issue
96131
@ParameterizedTest
97-
@ValueSource(strings = {"blosc", "gzip", "bytes", "transpose", "sharding_start", "sharding_end"})
132+
@ValueSource(strings = {"blosc", "gzip", "zstd", "bytes", "transpose", "sharding_start", "sharding_end"})
98133
public void testWriteToZarrita(String codec) throws IOException, ZarrException, InterruptedException {
99134
StoreHandle storeHandle = new FilesystemStore(TESTOUTPUT).resolve("write_to_zarrita", codec);
100135
ArrayMetadataBuilder builder = Array.metadataBuilder()
@@ -105,10 +140,10 @@ public void testWriteToZarrita(String codec) throws IOException, ZarrException,
105140

106141
switch (codec) {
107142
case "blosc":
108-
builder = builder.withCodecs(c -> c.withBlosc());
143+
builder = builder.withCodecs(CodecBuilder::withBlosc);
109144
break;
110145
case "gzip":
111-
builder = builder.withCodecs(c -> c.withGzip());
146+
builder = builder.withCodecs(CodecBuilder::withGzip);
112147
break;
113148
case "zstd":
114149
builder = builder.withCodecs(c -> c.withZstd(0));
@@ -140,7 +175,7 @@ public void testWriteToZarrita(String codec) throws IOException, ZarrException,
140175

141176
String command = pythonPath();
142177

143-
ProcessBuilder pb = new ProcessBuilder(command, ZARRITA_READ_PATH.toString(), codec, TESTOUTPUT.toString());
178+
ProcessBuilder pb = new ProcessBuilder(command, PYTHON_TEST_PATH.resolve("zarrita_read.py").toString(), codec, TESTOUTPUT.toString());
144179
Process process = pb.start();
145180

146181
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
@@ -161,7 +196,7 @@ public void testWriteToZarrita(String codec) throws IOException, ZarrException,
161196

162197
@ParameterizedTest
163198
@ValueSource(strings = {"blosc", "gzip", "zstd", "bytes", "transpose", "sharding_start", "sharding_end"})
164-
public void testCodecsWriteRead(String codec) throws IOException, ZarrException, InterruptedException {
199+
public void testCodecsWriteRead(String codec) throws IOException, ZarrException {
165200
int[] testData = new int[16 * 16 * 16];
166201
Arrays.setAll(testData, p -> p);
167202

@@ -175,10 +210,10 @@ public void testCodecsWriteRead(String codec) throws IOException, ZarrException,
175210

176211
switch (codec) {
177212
case "blosc":
178-
builder = builder.withCodecs(c -> c.withBlosc());
213+
builder = builder.withCodecs(CodecBuilder::withBlosc);
179214
break;
180215
case "gzip":
181-
builder = builder.withCodecs(c -> c.withGzip());
216+
builder = builder.withCodecs(CodecBuilder::withGzip);
182217
break;
183218
case "zstd":
184219
builder = builder.withCodecs(c -> c.withZstd(0));
@@ -216,8 +251,31 @@ public void testCodecsWriteRead(String codec) throws IOException, ZarrException,
216251
Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT));
217252
}
218253

254+
@ParameterizedTest
255+
@CsvSource({"0,true", "0,false", "5, true", "5, false"})
256+
public void testZstdCodecReadWrite(int clevel, boolean checksum) throws ZarrException, IOException {
257+
int[] testData = new int[16 * 16 * 16];
258+
Arrays.setAll(testData, p -> p);
259+
260+
StoreHandle storeHandle = new FilesystemStore(TESTOUTPUT).resolve("testZstdCodecReadWrite", "checksum_" + checksum, "clevel_" + clevel);
261+
ArrayMetadataBuilder builder = Array.metadataBuilder()
262+
.withShape(16, 16, 16)
263+
.withDataType(DataType.UINT32)
264+
.withChunkShape(2, 4, 8)
265+
.withFillValue(0)
266+
.withCodecs(c -> c.withZstd(clevel, checksum));
267+
Array writeArray = Array.create(storeHandle, builder.build());
268+
writeArray.write(ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{16, 16, 16}, testData));
269+
270+
Array readArray = Array.open(storeHandle);
271+
ucar.ma2.Array result = readArray.read();
272+
273+
Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT));
274+
275+
}
276+
219277
@Test
220-
public void testCodecTranspose() throws IOException, ZarrException, InterruptedException {
278+
public void testTransposeCodec() throws ZarrException {
221279
ucar.ma2.Array testData = ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{2, 3, 3}, new int[]{
222280
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
223281
ucar.ma2.Array testDataTransposed120 = ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{3, 3, 2}, new int[]{
@@ -237,8 +295,8 @@ public void testCodecTranspose() throws IOException, ZarrException, InterruptedE
237295
transposeCodecWrongOrder2.setCoreArrayMetadata(metadata);
238296
transposeCodecWrongOrder3.setCoreArrayMetadata(metadata);
239297

240-
assert ucar.ma2.MAMath.equals(testDataTransposed120, transposeCodec.encode(testData));
241-
assert ucar.ma2.MAMath.equals(testData, transposeCodec.decode(testDataTransposed120));
298+
assert MAMath.equals(testDataTransposed120, transposeCodec.encode(testData));
299+
assert MAMath.equals(testData, transposeCodec.decode(testDataTransposed120));
242300
assertThrows(ZarrException.class, () -> transposeCodecWrongOrder1.encode(testData));
243301
assertThrows(ZarrException.class, () -> transposeCodecWrongOrder2.encode(testData));
244302
assertThrows(ZarrException.class, () -> transposeCodecWrongOrder3.encode(testData));
@@ -295,8 +353,8 @@ public void testV3ShardingReadCutout() throws IOException, ZarrException {
295353
Array array = Array.open(new FilesystemStore(TESTDATA).resolve("l4_sample", "color", "1"));
296354

297355
ucar.ma2.Array outArray = array.read(new long[]{0, 3073, 3073, 513}, new int[]{1, 64, 64, 64});
298-
assertEquals(outArray.getSize(), 64 * 64 * 64);
299-
assertEquals(outArray.getByte(0), -98);
356+
Assertions.assertEquals(outArray.getSize(), 64 * 64 * 64);
357+
Assertions.assertEquals(outArray.getByte(0), -98);
300358
}
301359

302360
@Test
@@ -306,8 +364,8 @@ public void testV3Access() throws IOException, ZarrException {
306364
ucar.ma2.Array outArray = readArray.access().withOffset(0, 3073, 3073, 513)
307365
.withShape(1, 64, 64, 64)
308366
.read();
309-
assertEquals(outArray.getSize(), 64 * 64 * 64);
310-
assertEquals(outArray.getByte(0), -98);
367+
Assertions.assertEquals(outArray.getSize(), 64 * 64 * 64);
368+
Assertions.assertEquals(outArray.getByte(0), -98);
311369

312370
Array writeArray = Array.create(
313371
new FilesystemStore(TESTOUTPUT).resolve("l4_sample_2", "color", "1"),
@@ -375,15 +433,15 @@ public void testV3ArrayMetadataBuilder() throws ZarrException {
375433
.withChunkShape(1, 1024, 1024, 1024)
376434
.withFillValue(0)
377435
.withCodecs(
378-
c -> c.withSharding(new int[]{1, 32, 32, 32}, c1 -> c1.withBlosc()))
436+
c -> c.withSharding(new int[]{1, 32, 32, 32}, CodecBuilder::withBlosc))
379437
.build();
380438
}
381439

382440
@Test
383441
public void testV3FillValue() throws ZarrException {
384-
assertEquals((int) ArrayMetadata.parseFillValue(0, DataType.UINT32), 0);
385-
assertEquals((int) ArrayMetadata.parseFillValue("0x00010203", DataType.UINT32), 50462976);
386-
assertEquals((byte) ArrayMetadata.parseFillValue("0b00000010", DataType.UINT8), 2);
442+
Assertions.assertEquals((int) ArrayMetadata.parseFillValue(0, DataType.UINT32), 0);
443+
Assertions.assertEquals((int) ArrayMetadata.parseFillValue("0x00010203", DataType.UINT32), 50462976);
444+
Assertions.assertEquals((byte) ArrayMetadata.parseFillValue("0b00000010", DataType.UINT8), 2);
387445
assert Double.isNaN((double) ArrayMetadata.parseFillValue("NaN", DataType.FLOAT64));
388446
assert Double.isInfinite((double) ArrayMetadata.parseFillValue("-Infinity", DataType.FLOAT64));
389447
}
@@ -401,18 +459,15 @@ public void testV3Group() throws IOException, ZarrException {
401459
);
402460
array.write(new long[]{2, 2}, ucar.ma2.Array.factory(ucar.ma2.DataType.UBYTE, new int[]{8, 8}));
403461

404-
assertArrayEquals(
405-
((Array) ((Group) group.listAsArray()[0]).listAsArray()[0]).metadata.chunkShape(),
406-
new int[]{5, 5});
462+
Assertions.assertArrayEquals(((Array) ((Group) group.listAsArray()[0]).listAsArray()[0]).metadata.chunkShape(), new int[]{5, 5});
407463
}
408464

409465
@Test
410-
public void testV2() throws IOException, ZarrException {
466+
public void testV2() throws IOException{
411467
FilesystemStore fsStore = new FilesystemStore("");
412468
HttpStore httpStore = new HttpStore("https://static.webknossos.org/data");
413469

414-
System.out.println(
415-
dev.zarr.zarrjava.v2.Array.open(httpStore.resolve("l4_sample", "color", "1")));
470+
System.out.println(dev.zarr.zarrjava.v2.Array.open(httpStore.resolve("l4_sample", "color", "1")));
416471
}
417472

418473

src/test/java/dev/zarr/zarrjava/zarrita_read.py renamed to src/test/python-scripts/zarrita_read.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
elif codec_string == "gzip":
1111
codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.gzip_codec()]
1212
elif codec_string == "zstd":
13-
codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.zstd_codec()]
13+
codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.zstd_codec(checksum=True)]
1414
elif codec_string == "bytes":
1515
codec = [zarrita.codecs.bytes_codec()]
1616
elif codec_string == "transpose":

0 commit comments

Comments
 (0)