diff --git a/src/main/java/dev/zarr/zarrjava/v3/codec/core/ZstdCodec.java b/src/main/java/dev/zarr/zarrjava/v3/codec/core/ZstdCodec.java index fa85765..f042f11 100644 --- a/src/main/java/dev/zarr/zarrjava/v3/codec/core/ZstdCodec.java +++ b/src/main/java/dev/zarr/zarrjava/v3/codec/core/ZstdCodec.java @@ -2,89 +2,74 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.github.luben.zstd.ZstdInputStream; -import com.github.luben.zstd.ZstdOutputStream; +import com.github.luben.zstd.Zstd; +import com.github.luben.zstd.ZstdCompressCtx; import dev.zarr.zarrjava.ZarrException; -import dev.zarr.zarrjava.utils.Utils; import dev.zarr.zarrjava.v3.ArrayMetadata; import dev.zarr.zarrjava.v3.codec.BytesBytesCodec; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.nio.ByteBuffer; + import javax.annotation.Nonnull; +import java.nio.ByteBuffer; public class ZstdCodec extends BytesBytesCodec { - public final String name = "zstd"; - @Nonnull - public final Configuration configuration; + public final String name = "zstd"; + @Nonnull + public final Configuration configuration; - @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) - public ZstdCodec( - @Nonnull @JsonProperty(value = "configuration", required = true) Configuration configuration) { - this.configuration = configuration; - } - - private void copy(InputStream inputStream, OutputStream outputStream) throws IOException { - byte[] buffer = new byte[4096]; - int len; - while ((len = inputStream.read(buffer)) > 0) { - outputStream.write(buffer, 0, len); + @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) + public ZstdCodec( + @Nonnull @JsonProperty(value = "configuration", required = true) Configuration configuration) { + this.configuration = configuration; } - } - @Override - public ByteBuffer decode(ByteBuffer chunkBytes) - throws ZarrException { - try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); ZstdInputStream inputStream = new ZstdInputStream( - new ByteArrayInputStream(Utils.toArray(chunkBytes)))) { - copy(inputStream, outputStream); - inputStream.close(); - return ByteBuffer.wrap(outputStream.toByteArray()); - } catch (IOException ex) { - throw new ZarrException("Error in decoding zstd.", ex); + @Override + public ByteBuffer decode(ByteBuffer compressedBytes) throws ZarrException { + byte[] compressedArray = compressedBytes.array(); + + long originalSize = Zstd.decompressedSize(compressedArray); + if (originalSize == 0) { + throw new ZarrException("Failed to get decompressed size"); + } + + byte[] decompressed = Zstd.decompress(compressedArray, (int) originalSize); + return ByteBuffer.wrap(decompressed); } - } - @Override - public ByteBuffer encode(ByteBuffer chunkBytes) - throws ZarrException { - try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); ZstdOutputStream zstdStream = new ZstdOutputStream( - outputStream, configuration.level).setChecksum( - configuration.checksum)) { - zstdStream.write(Utils.toArray(chunkBytes)); - zstdStream.close(); - return ByteBuffer.wrap(outputStream.toByteArray()); - } catch (IOException ex) { - throw new ZarrException("Error in encoding zstd.", ex); + @Override + public ByteBuffer encode(ByteBuffer chunkBytes) throws ZarrException { + byte[] arr = chunkBytes.array(); + byte[] compressed; + try (ZstdCompressCtx ctx = new ZstdCompressCtx()) { + ctx.setLevel(configuration.level); + ctx.setChecksum(configuration.checksum); + compressed = ctx.compress(arr); + } + return ByteBuffer.wrap(compressed); } - } - @Override - public long computeEncodedSize(long inputByteLength, - ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException { - throw new ZarrException("Not implemented for Zstd codec."); - } + @Override + public long computeEncodedSize(long inputByteLength, + ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException { + throw new ZarrException("Not implemented for Zstd codec."); + } - public static final class Configuration { + public static final class Configuration { - public final int level; - public final boolean checksum; + public final int level; + public final boolean checksum; - @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) - public Configuration(@JsonProperty(value = "level", defaultValue = "5") int level, - @JsonProperty(value = "checksum", defaultValue = "true") boolean checksum) - throws ZarrException { - if (level < -131072 || level > 22) { - throw new ZarrException("'level' needs to be between -131072 and 22."); - } - this.level = level; - this.checksum = checksum; + @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) + public Configuration(@JsonProperty(value = "level", defaultValue = "5") int level, + @JsonProperty(value = "checksum", defaultValue = "true") boolean checksum) + throws ZarrException { + if (level < -131072 || level > 22) { + throw new ZarrException("'level' needs to be between -131072 and 22."); + } + this.level = level; + this.checksum = checksum; + } } - } } diff --git a/src/test/java/dev/zarr/zarrjava/ZarrTest.java b/src/test/java/dev/zarr/zarrjava/ZarrTest.java index ebf87cc..524c741 100644 --- a/src/test/java/dev/zarr/zarrjava/ZarrTest.java +++ b/src/test/java/dev/zarr/zarrjava/ZarrTest.java @@ -4,23 +4,26 @@ import com.amazonaws.auth.AnonymousAWSCredentials; import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.luben.zstd.Zstd; +import com.github.luben.zstd.ZstdCompressCtx; import dev.zarr.zarrjava.store.FilesystemStore; import dev.zarr.zarrjava.store.HttpStore; import dev.zarr.zarrjava.store.S3Store; import dev.zarr.zarrjava.store.StoreHandle; import dev.zarr.zarrjava.utils.MultiArrayUtils; import dev.zarr.zarrjava.v3.*; +import dev.zarr.zarrjava.v3.codec.CodecBuilder; import dev.zarr.zarrjava.v3.codec.core.TransposeCodec; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.ValueSource; +import ucar.ma2.MAMath; -import java.io.BufferedReader; -import java.io.File; -import java.io.IOException; -import java.io.InputStreamReader; +import java.io.*; +import java.nio.ByteBuffer; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -35,8 +38,7 @@ public class ZarrTest { final static Path TESTDATA = Paths.get("testdata"); final static Path TESTOUTPUT = Paths.get("testoutput"); - final static Path ZARRITA_WRITE_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/zarrita_write.py"); - final static Path ZARRITA_READ_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/zarrita_read.py"); + final static Path PYTHON_TEST_PATH = Paths.get("src/test/python-scripts/"); public static String pythonPath() { if (System.getProperty("os.name").startsWith("Windows")) { @@ -60,7 +62,7 @@ public static void clearTestoutputFolder() throws IOException { public void testReadFromZarrita(String codec) throws IOException, ZarrException, InterruptedException { String command = pythonPath(); - ProcessBuilder pb = new ProcessBuilder(command, ZARRITA_WRITE_PATH.toString(), codec, TESTOUTPUT.toString()); + ProcessBuilder pb = new ProcessBuilder(command, PYTHON_TEST_PATH.resolve("zarrita_write.py").toString(), codec, TESTOUTPUT.toString()); Process process = pb.start(); BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); @@ -91,10 +93,43 @@ public void testReadFromZarrita(String codec) throws IOException, ZarrException, Assertions.assertArrayEquals(expectedData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT)); } + @CsvSource({"0,true", "0,false", "5, true", "10, false"}) + @ParameterizedTest + public void testZstdLibrary(int clevel, boolean checksumFlag) throws IOException, InterruptedException { + //compress using ZstdCompressCtx + int number = 123456; + byte[] src = ByteBuffer.allocate(4).putInt(number).array(); + byte[] compressed; + try (ZstdCompressCtx ctx = new ZstdCompressCtx()) { + ctx.setLevel(clevel); + ctx.setChecksum(checksumFlag); + compressed = ctx.compress(src); + } + //decompress with Zstd.decompress + long originalSize = Zstd.decompressedSize(compressed); + byte[] decompressed = Zstd.decompress(compressed, (int) originalSize); + Assertions.assertEquals(number, ByteBuffer.wrap(decompressed).getInt()); + + //write compressed to file + String compressedDataPath = TESTOUTPUT.resolve("compressed" + clevel + checksumFlag + ".bin").toString(); + try (FileOutputStream fos = new FileOutputStream(compressedDataPath)) { + fos.write(compressed); + } + + //decompress in python + Process process = new ProcessBuilder( + pythonPath(), + PYTHON_TEST_PATH.resolve("zstd_decompress.py").toString(), + compressedDataPath, + Integer.toString(number) + ).start(); + int exitCode = process.waitFor(); + assert exitCode == 0; + } + //TODO: add crc32c - //Disabled "zstd": known issue @ParameterizedTest - @ValueSource(strings = {"blosc", "gzip", "bytes", "transpose", "sharding_start", "sharding_end"}) + @ValueSource(strings = {"blosc", "gzip", "zstd", "bytes", "transpose", "sharding_start", "sharding_end"}) public void testWriteToZarrita(String codec) throws IOException, ZarrException, InterruptedException { StoreHandle storeHandle = new FilesystemStore(TESTOUTPUT).resolve("write_to_zarrita", codec); ArrayMetadataBuilder builder = Array.metadataBuilder() @@ -105,10 +140,10 @@ public void testWriteToZarrita(String codec) throws IOException, ZarrException, switch (codec) { case "blosc": - builder = builder.withCodecs(c -> c.withBlosc()); + builder = builder.withCodecs(CodecBuilder::withBlosc); break; case "gzip": - builder = builder.withCodecs(c -> c.withGzip()); + builder = builder.withCodecs(CodecBuilder::withGzip); break; case "zstd": builder = builder.withCodecs(c -> c.withZstd(0)); @@ -140,7 +175,7 @@ public void testWriteToZarrita(String codec) throws IOException, ZarrException, String command = pythonPath(); - ProcessBuilder pb = new ProcessBuilder(command, ZARRITA_READ_PATH.toString(), codec, TESTOUTPUT.toString()); + ProcessBuilder pb = new ProcessBuilder(command, PYTHON_TEST_PATH.resolve("zarrita_read.py").toString(), codec, TESTOUTPUT.toString()); Process process = pb.start(); BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); @@ -161,7 +196,7 @@ public void testWriteToZarrita(String codec) throws IOException, ZarrException, @ParameterizedTest @ValueSource(strings = {"blosc", "gzip", "zstd", "bytes", "transpose", "sharding_start", "sharding_end"}) - public void testCodecsWriteRead(String codec) throws IOException, ZarrException, InterruptedException { + public void testCodecsWriteRead(String codec) throws IOException, ZarrException { int[] testData = new int[16 * 16 * 16]; Arrays.setAll(testData, p -> p); @@ -175,10 +210,10 @@ public void testCodecsWriteRead(String codec) throws IOException, ZarrException, switch (codec) { case "blosc": - builder = builder.withCodecs(c -> c.withBlosc()); + builder = builder.withCodecs(CodecBuilder::withBlosc); break; case "gzip": - builder = builder.withCodecs(c -> c.withGzip()); + builder = builder.withCodecs(CodecBuilder::withGzip); break; case "zstd": builder = builder.withCodecs(c -> c.withZstd(0)); @@ -216,8 +251,31 @@ public void testCodecsWriteRead(String codec) throws IOException, ZarrException, Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT)); } + @ParameterizedTest + @CsvSource({"0,true", "0,false", "5, true", "5, false"}) + public void testZstdCodecReadWrite(int clevel, boolean checksum) throws ZarrException, IOException { + int[] testData = new int[16 * 16 * 16]; + Arrays.setAll(testData, p -> p); + + StoreHandle storeHandle = new FilesystemStore(TESTOUTPUT).resolve("testZstdCodecReadWrite", "checksum_" + checksum, "clevel_" + clevel); + ArrayMetadataBuilder builder = Array.metadataBuilder() + .withShape(16, 16, 16) + .withDataType(DataType.UINT32) + .withChunkShape(2, 4, 8) + .withFillValue(0) + .withCodecs(c -> c.withZstd(clevel, checksum)); + Array writeArray = Array.create(storeHandle, builder.build()); + writeArray.write(ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{16, 16, 16}, testData)); + + Array readArray = Array.open(storeHandle); + ucar.ma2.Array result = readArray.read(); + + Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT)); + + } + @Test - public void testCodecTranspose() throws IOException, ZarrException, InterruptedException { + public void testTransposeCodec() throws ZarrException { ucar.ma2.Array testData = ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{2, 3, 3}, new int[]{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}); ucar.ma2.Array testDataTransposed120 = ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{3, 3, 2}, new int[]{ @@ -237,8 +295,8 @@ public void testCodecTranspose() throws IOException, ZarrException, InterruptedE transposeCodecWrongOrder2.setCoreArrayMetadata(metadata); transposeCodecWrongOrder3.setCoreArrayMetadata(metadata); - assert ucar.ma2.MAMath.equals(testDataTransposed120, transposeCodec.encode(testData)); - assert ucar.ma2.MAMath.equals(testData, transposeCodec.decode(testDataTransposed120)); + assert MAMath.equals(testDataTransposed120, transposeCodec.encode(testData)); + assert MAMath.equals(testData, transposeCodec.decode(testDataTransposed120)); assertThrows(ZarrException.class, () -> transposeCodecWrongOrder1.encode(testData)); assertThrows(ZarrException.class, () -> transposeCodecWrongOrder2.encode(testData)); assertThrows(ZarrException.class, () -> transposeCodecWrongOrder3.encode(testData)); @@ -295,8 +353,8 @@ public void testV3ShardingReadCutout() throws IOException, ZarrException { Array array = Array.open(new FilesystemStore(TESTDATA).resolve("l4_sample", "color", "1")); ucar.ma2.Array outArray = array.read(new long[]{0, 3073, 3073, 513}, new int[]{1, 64, 64, 64}); - assertEquals(outArray.getSize(), 64 * 64 * 64); - assertEquals(outArray.getByte(0), -98); + Assertions.assertEquals(outArray.getSize(), 64 * 64 * 64); + Assertions.assertEquals(outArray.getByte(0), -98); } @Test @@ -306,8 +364,8 @@ public void testV3Access() throws IOException, ZarrException { ucar.ma2.Array outArray = readArray.access().withOffset(0, 3073, 3073, 513) .withShape(1, 64, 64, 64) .read(); - assertEquals(outArray.getSize(), 64 * 64 * 64); - assertEquals(outArray.getByte(0), -98); + Assertions.assertEquals(outArray.getSize(), 64 * 64 * 64); + Assertions.assertEquals(outArray.getByte(0), -98); Array writeArray = Array.create( new FilesystemStore(TESTOUTPUT).resolve("l4_sample_2", "color", "1"), @@ -375,15 +433,15 @@ public void testV3ArrayMetadataBuilder() throws ZarrException { .withChunkShape(1, 1024, 1024, 1024) .withFillValue(0) .withCodecs( - c -> c.withSharding(new int[]{1, 32, 32, 32}, c1 -> c1.withBlosc())) + c -> c.withSharding(new int[]{1, 32, 32, 32}, CodecBuilder::withBlosc)) .build(); } @Test public void testV3FillValue() throws ZarrException { - assertEquals((int) ArrayMetadata.parseFillValue(0, DataType.UINT32), 0); - assertEquals((int) ArrayMetadata.parseFillValue("0x00010203", DataType.UINT32), 50462976); - assertEquals((byte) ArrayMetadata.parseFillValue("0b00000010", DataType.UINT8), 2); + Assertions.assertEquals((int) ArrayMetadata.parseFillValue(0, DataType.UINT32), 0); + Assertions.assertEquals((int) ArrayMetadata.parseFillValue("0x00010203", DataType.UINT32), 50462976); + Assertions.assertEquals((byte) ArrayMetadata.parseFillValue("0b00000010", DataType.UINT8), 2); assert Double.isNaN((double) ArrayMetadata.parseFillValue("NaN", DataType.FLOAT64)); assert Double.isInfinite((double) ArrayMetadata.parseFillValue("-Infinity", DataType.FLOAT64)); } @@ -401,18 +459,15 @@ public void testV3Group() throws IOException, ZarrException { ); array.write(new long[]{2, 2}, ucar.ma2.Array.factory(ucar.ma2.DataType.UBYTE, new int[]{8, 8})); - assertArrayEquals( - ((Array) ((Group) group.listAsArray()[0]).listAsArray()[0]).metadata.chunkShape(), - new int[]{5, 5}); + Assertions.assertArrayEquals(((Array) ((Group) group.listAsArray()[0]).listAsArray()[0]).metadata.chunkShape(), new int[]{5, 5}); } @Test - public void testV2() throws IOException, ZarrException { + public void testV2() throws IOException{ FilesystemStore fsStore = new FilesystemStore(""); HttpStore httpStore = new HttpStore("https://static.webknossos.org/data"); - System.out.println( - dev.zarr.zarrjava.v2.Array.open(httpStore.resolve("l4_sample", "color", "1"))); + System.out.println(dev.zarr.zarrjava.v2.Array.open(httpStore.resolve("l4_sample", "color", "1"))); } diff --git a/src/test/java/dev/zarr/zarrjava/zarrita_read.py b/src/test/python-scripts/zarrita_read.py similarity index 98% rename from src/test/java/dev/zarr/zarrjava/zarrita_read.py rename to src/test/python-scripts/zarrita_read.py index 07cb083..4eff03f 100644 --- a/src/test/java/dev/zarr/zarrjava/zarrita_read.py +++ b/src/test/python-scripts/zarrita_read.py @@ -10,7 +10,7 @@ elif codec_string == "gzip": codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.gzip_codec()] elif codec_string == "zstd": - codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.zstd_codec()] + codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.zstd_codec(checksum=True)] elif codec_string == "bytes": codec = [zarrita.codecs.bytes_codec()] elif codec_string == "transpose": diff --git a/src/test/java/dev/zarr/zarrjava/zarrita_write.py b/src/test/python-scripts/zarrita_write.py similarity index 95% rename from src/test/java/dev/zarr/zarrjava/zarrita_write.py rename to src/test/python-scripts/zarrita_write.py index 3a3fd3f..ae6611d 100644 --- a/src/test/java/dev/zarr/zarrjava/zarrita_write.py +++ b/src/test/python-scripts/zarrita_write.py @@ -14,7 +14,7 @@ elif codec_string == "bytes": codec = [zarrita.codecs.bytes_codec()] elif codec_string == "transpose": - codec = [zarrita.codecs.transpose_codec([0, 1]), zarrita.codecs.bytes_codec()] + codec = [zarrita.codecs.transpose_codec((0, 1)), zarrita.codecs.bytes_codec()] elif codec_string == "sharding_start": codec = [zarrita.codecs.sharding_codec(chunk_shape=(1, 2), codecs=[zarrita.codecs.bytes_codec()], index_location=zarrita.metadata.ShardingCodecIndexLocation.start)] elif codec_string == "sharding_end": diff --git a/src/test/python-scripts/zstd_decompress.py b/src/test/python-scripts/zstd_decompress.py new file mode 100644 index 0000000..0235fdd --- /dev/null +++ b/src/test/python-scripts/zstd_decompress.py @@ -0,0 +1,13 @@ +import sys + +import zstandard as zstd + +data_path = sys.argv[1] +expected = sys.argv[2] + +with open(data_path, "rb") as f: + compressed = f.read() + +decompressed = zstd.ZstdDecompressor().decompress(compressed) +number = int.from_bytes(decompressed, byteorder='big') +assert number == int(expected)