Skip to content

Commit

Permalink
Merge pull request #4 from scalableminds/fix-zstd-codec
Browse files Browse the repository at this point in the history
Fix zstd codec
  • Loading branch information
brokkoli71 authored May 30, 2024
2 parents e51ac9d + 29541d5 commit 9733e07
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 99 deletions.
115 changes: 50 additions & 65 deletions src/main/java/dev/zarr/zarrjava/v3/codec/core/ZstdCodec.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,89 +2,74 @@

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.github.luben.zstd.ZstdInputStream;
import com.github.luben.zstd.ZstdOutputStream;
import com.github.luben.zstd.Zstd;
import com.github.luben.zstd.ZstdCompressCtx;
import dev.zarr.zarrjava.ZarrException;
import dev.zarr.zarrjava.utils.Utils;
import dev.zarr.zarrjava.v3.ArrayMetadata;
import dev.zarr.zarrjava.v3.codec.BytesBytesCodec;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;

import javax.annotation.Nonnull;
import java.nio.ByteBuffer;

public class ZstdCodec extends BytesBytesCodec {

public final String name = "zstd";
@Nonnull
public final Configuration configuration;
public final String name = "zstd";
@Nonnull
public final Configuration configuration;

@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
public ZstdCodec(
@Nonnull @JsonProperty(value = "configuration", required = true) Configuration configuration) {
this.configuration = configuration;
}

private void copy(InputStream inputStream, OutputStream outputStream) throws IOException {
byte[] buffer = new byte[4096];
int len;
while ((len = inputStream.read(buffer)) > 0) {
outputStream.write(buffer, 0, len);
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
public ZstdCodec(
@Nonnull @JsonProperty(value = "configuration", required = true) Configuration configuration) {
this.configuration = configuration;
}
}

@Override
public ByteBuffer decode(ByteBuffer chunkBytes)
throws ZarrException {
try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); ZstdInputStream inputStream = new ZstdInputStream(
new ByteArrayInputStream(Utils.toArray(chunkBytes)))) {
copy(inputStream, outputStream);
inputStream.close();
return ByteBuffer.wrap(outputStream.toByteArray());
} catch (IOException ex) {
throw new ZarrException("Error in decoding zstd.", ex);
@Override
public ByteBuffer decode(ByteBuffer compressedBytes) throws ZarrException {
byte[] compressedArray = compressedBytes.array();

long originalSize = Zstd.decompressedSize(compressedArray);
if (originalSize == 0) {
throw new ZarrException("Failed to get decompressed size");
}

byte[] decompressed = Zstd.decompress(compressedArray, (int) originalSize);
return ByteBuffer.wrap(decompressed);
}
}

@Override
public ByteBuffer encode(ByteBuffer chunkBytes)
throws ZarrException {
try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); ZstdOutputStream zstdStream = new ZstdOutputStream(
outputStream, configuration.level).setChecksum(
configuration.checksum)) {
zstdStream.write(Utils.toArray(chunkBytes));
zstdStream.close();
return ByteBuffer.wrap(outputStream.toByteArray());
} catch (IOException ex) {
throw new ZarrException("Error in encoding zstd.", ex);
@Override
public ByteBuffer encode(ByteBuffer chunkBytes) throws ZarrException {
byte[] arr = chunkBytes.array();
byte[] compressed;
try (ZstdCompressCtx ctx = new ZstdCompressCtx()) {
ctx.setLevel(configuration.level);
ctx.setChecksum(configuration.checksum);
compressed = ctx.compress(arr);
}
return ByteBuffer.wrap(compressed);
}
}

@Override
public long computeEncodedSize(long inputByteLength,
ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException {
throw new ZarrException("Not implemented for Zstd codec.");
}
@Override
public long computeEncodedSize(long inputByteLength,
ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException {
throw new ZarrException("Not implemented for Zstd codec.");
}

public static final class Configuration {
public static final class Configuration {

public final int level;
public final boolean checksum;
public final int level;
public final boolean checksum;

@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
public Configuration(@JsonProperty(value = "level", defaultValue = "5") int level,
@JsonProperty(value = "checksum", defaultValue = "true") boolean checksum)
throws ZarrException {
if (level < -131072 || level > 22) {
throw new ZarrException("'level' needs to be between -131072 and 22.");
}
this.level = level;
this.checksum = checksum;
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
public Configuration(@JsonProperty(value = "level", defaultValue = "5") int level,
@JsonProperty(value = "checksum", defaultValue = "true") boolean checksum)
throws ZarrException {
if (level < -131072 || level > 22) {
throw new ZarrException("'level' needs to be between -131072 and 22.");
}
this.level = level;
this.checksum = checksum;
}
}
}
}


119 changes: 87 additions & 32 deletions src/test/java/dev/zarr/zarrjava/ZarrTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,26 @@
import com.amazonaws.auth.AnonymousAWSCredentials;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.luben.zstd.Zstd;
import com.github.luben.zstd.ZstdCompressCtx;
import dev.zarr.zarrjava.store.FilesystemStore;
import dev.zarr.zarrjava.store.HttpStore;
import dev.zarr.zarrjava.store.S3Store;
import dev.zarr.zarrjava.store.StoreHandle;
import dev.zarr.zarrjava.utils.MultiArrayUtils;
import dev.zarr.zarrjava.v3.*;
import dev.zarr.zarrjava.v3.codec.CodecBuilder;
import dev.zarr.zarrjava.v3.codec.core.TransposeCodec;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import ucar.ma2.MAMath;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
Expand All @@ -35,8 +38,7 @@ public class ZarrTest {

final static Path TESTDATA = Paths.get("testdata");
final static Path TESTOUTPUT = Paths.get("testoutput");
final static Path ZARRITA_WRITE_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/zarrita_write.py");
final static Path ZARRITA_READ_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/zarrita_read.py");
final static Path PYTHON_TEST_PATH = Paths.get("src/test/python-scripts/");

public static String pythonPath() {
if (System.getProperty("os.name").startsWith("Windows")) {
Expand All @@ -60,7 +62,7 @@ public static void clearTestoutputFolder() throws IOException {
public void testReadFromZarrita(String codec) throws IOException, ZarrException, InterruptedException {

String command = pythonPath();
ProcessBuilder pb = new ProcessBuilder(command, ZARRITA_WRITE_PATH.toString(), codec, TESTOUTPUT.toString());
ProcessBuilder pb = new ProcessBuilder(command, PYTHON_TEST_PATH.resolve("zarrita_write.py").toString(), codec, TESTOUTPUT.toString());
Process process = pb.start();

BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
Expand Down Expand Up @@ -91,10 +93,43 @@ public void testReadFromZarrita(String codec) throws IOException, ZarrException,
Assertions.assertArrayEquals(expectedData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT));
}

@CsvSource({"0,true", "0,false", "5, true", "10, false"})
@ParameterizedTest
public void testZstdLibrary(int clevel, boolean checksumFlag) throws IOException, InterruptedException {
//compress using ZstdCompressCtx
int number = 123456;
byte[] src = ByteBuffer.allocate(4).putInt(number).array();
byte[] compressed;
try (ZstdCompressCtx ctx = new ZstdCompressCtx()) {
ctx.setLevel(clevel);
ctx.setChecksum(checksumFlag);
compressed = ctx.compress(src);
}
//decompress with Zstd.decompress
long originalSize = Zstd.decompressedSize(compressed);
byte[] decompressed = Zstd.decompress(compressed, (int) originalSize);
Assertions.assertEquals(number, ByteBuffer.wrap(decompressed).getInt());

//write compressed to file
String compressedDataPath = TESTOUTPUT.resolve("compressed" + clevel + checksumFlag + ".bin").toString();
try (FileOutputStream fos = new FileOutputStream(compressedDataPath)) {
fos.write(compressed);
}

//decompress in python
Process process = new ProcessBuilder(
pythonPath(),
PYTHON_TEST_PATH.resolve("zstd_decompress.py").toString(),
compressedDataPath,
Integer.toString(number)
).start();
int exitCode = process.waitFor();
assert exitCode == 0;
}

//TODO: add crc32c
//Disabled "zstd": known issue
@ParameterizedTest
@ValueSource(strings = {"blosc", "gzip", "bytes", "transpose", "sharding_start", "sharding_end"})
@ValueSource(strings = {"blosc", "gzip", "zstd", "bytes", "transpose", "sharding_start", "sharding_end"})
public void testWriteToZarrita(String codec) throws IOException, ZarrException, InterruptedException {
StoreHandle storeHandle = new FilesystemStore(TESTOUTPUT).resolve("write_to_zarrita", codec);
ArrayMetadataBuilder builder = Array.metadataBuilder()
Expand All @@ -105,10 +140,10 @@ public void testWriteToZarrita(String codec) throws IOException, ZarrException,

switch (codec) {
case "blosc":
builder = builder.withCodecs(c -> c.withBlosc());
builder = builder.withCodecs(CodecBuilder::withBlosc);
break;
case "gzip":
builder = builder.withCodecs(c -> c.withGzip());
builder = builder.withCodecs(CodecBuilder::withGzip);
break;
case "zstd":
builder = builder.withCodecs(c -> c.withZstd(0));
Expand Down Expand Up @@ -140,7 +175,7 @@ public void testWriteToZarrita(String codec) throws IOException, ZarrException,

String command = pythonPath();

ProcessBuilder pb = new ProcessBuilder(command, ZARRITA_READ_PATH.toString(), codec, TESTOUTPUT.toString());
ProcessBuilder pb = new ProcessBuilder(command, PYTHON_TEST_PATH.resolve("zarrita_read.py").toString(), codec, TESTOUTPUT.toString());
Process process = pb.start();

BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
Expand All @@ -161,7 +196,7 @@ public void testWriteToZarrita(String codec) throws IOException, ZarrException,

@ParameterizedTest
@ValueSource(strings = {"blosc", "gzip", "zstd", "bytes", "transpose", "sharding_start", "sharding_end"})
public void testCodecsWriteRead(String codec) throws IOException, ZarrException, InterruptedException {
public void testCodecsWriteRead(String codec) throws IOException, ZarrException {
int[] testData = new int[16 * 16 * 16];
Arrays.setAll(testData, p -> p);

Expand All @@ -175,10 +210,10 @@ public void testCodecsWriteRead(String codec) throws IOException, ZarrException,

switch (codec) {
case "blosc":
builder = builder.withCodecs(c -> c.withBlosc());
builder = builder.withCodecs(CodecBuilder::withBlosc);
break;
case "gzip":
builder = builder.withCodecs(c -> c.withGzip());
builder = builder.withCodecs(CodecBuilder::withGzip);
break;
case "zstd":
builder = builder.withCodecs(c -> c.withZstd(0));
Expand Down Expand Up @@ -216,8 +251,31 @@ public void testCodecsWriteRead(String codec) throws IOException, ZarrException,
Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT));
}

@ParameterizedTest
@CsvSource({"0,true", "0,false", "5, true", "5, false"})
public void testZstdCodecReadWrite(int clevel, boolean checksum) throws ZarrException, IOException {
int[] testData = new int[16 * 16 * 16];
Arrays.setAll(testData, p -> p);

StoreHandle storeHandle = new FilesystemStore(TESTOUTPUT).resolve("testZstdCodecReadWrite", "checksum_" + checksum, "clevel_" + clevel);
ArrayMetadataBuilder builder = Array.metadataBuilder()
.withShape(16, 16, 16)
.withDataType(DataType.UINT32)
.withChunkShape(2, 4, 8)
.withFillValue(0)
.withCodecs(c -> c.withZstd(clevel, checksum));
Array writeArray = Array.create(storeHandle, builder.build());
writeArray.write(ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{16, 16, 16}, testData));

Array readArray = Array.open(storeHandle);
ucar.ma2.Array result = readArray.read();

Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT));

}

@Test
public void testCodecTranspose() throws IOException, ZarrException, InterruptedException {
public void testTransposeCodec() throws ZarrException {
ucar.ma2.Array testData = ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{2, 3, 3}, new int[]{
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
ucar.ma2.Array testDataTransposed120 = ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{3, 3, 2}, new int[]{
Expand All @@ -237,8 +295,8 @@ public void testCodecTranspose() throws IOException, ZarrException, InterruptedE
transposeCodecWrongOrder2.setCoreArrayMetadata(metadata);
transposeCodecWrongOrder3.setCoreArrayMetadata(metadata);

assert ucar.ma2.MAMath.equals(testDataTransposed120, transposeCodec.encode(testData));
assert ucar.ma2.MAMath.equals(testData, transposeCodec.decode(testDataTransposed120));
assert MAMath.equals(testDataTransposed120, transposeCodec.encode(testData));
assert MAMath.equals(testData, transposeCodec.decode(testDataTransposed120));
assertThrows(ZarrException.class, () -> transposeCodecWrongOrder1.encode(testData));
assertThrows(ZarrException.class, () -> transposeCodecWrongOrder2.encode(testData));
assertThrows(ZarrException.class, () -> transposeCodecWrongOrder3.encode(testData));
Expand Down Expand Up @@ -295,8 +353,8 @@ public void testV3ShardingReadCutout() throws IOException, ZarrException {
Array array = Array.open(new FilesystemStore(TESTDATA).resolve("l4_sample", "color", "1"));

ucar.ma2.Array outArray = array.read(new long[]{0, 3073, 3073, 513}, new int[]{1, 64, 64, 64});
assertEquals(outArray.getSize(), 64 * 64 * 64);
assertEquals(outArray.getByte(0), -98);
Assertions.assertEquals(outArray.getSize(), 64 * 64 * 64);
Assertions.assertEquals(outArray.getByte(0), -98);
}

@Test
Expand All @@ -306,8 +364,8 @@ public void testV3Access() throws IOException, ZarrException {
ucar.ma2.Array outArray = readArray.access().withOffset(0, 3073, 3073, 513)
.withShape(1, 64, 64, 64)
.read();
assertEquals(outArray.getSize(), 64 * 64 * 64);
assertEquals(outArray.getByte(0), -98);
Assertions.assertEquals(outArray.getSize(), 64 * 64 * 64);
Assertions.assertEquals(outArray.getByte(0), -98);

Array writeArray = Array.create(
new FilesystemStore(TESTOUTPUT).resolve("l4_sample_2", "color", "1"),
Expand Down Expand Up @@ -375,15 +433,15 @@ public void testV3ArrayMetadataBuilder() throws ZarrException {
.withChunkShape(1, 1024, 1024, 1024)
.withFillValue(0)
.withCodecs(
c -> c.withSharding(new int[]{1, 32, 32, 32}, c1 -> c1.withBlosc()))
c -> c.withSharding(new int[]{1, 32, 32, 32}, CodecBuilder::withBlosc))
.build();
}

@Test
public void testV3FillValue() throws ZarrException {
assertEquals((int) ArrayMetadata.parseFillValue(0, DataType.UINT32), 0);
assertEquals((int) ArrayMetadata.parseFillValue("0x00010203", DataType.UINT32), 50462976);
assertEquals((byte) ArrayMetadata.parseFillValue("0b00000010", DataType.UINT8), 2);
Assertions.assertEquals((int) ArrayMetadata.parseFillValue(0, DataType.UINT32), 0);
Assertions.assertEquals((int) ArrayMetadata.parseFillValue("0x00010203", DataType.UINT32), 50462976);
Assertions.assertEquals((byte) ArrayMetadata.parseFillValue("0b00000010", DataType.UINT8), 2);
assert Double.isNaN((double) ArrayMetadata.parseFillValue("NaN", DataType.FLOAT64));
assert Double.isInfinite((double) ArrayMetadata.parseFillValue("-Infinity", DataType.FLOAT64));
}
Expand All @@ -401,18 +459,15 @@ public void testV3Group() throws IOException, ZarrException {
);
array.write(new long[]{2, 2}, ucar.ma2.Array.factory(ucar.ma2.DataType.UBYTE, new int[]{8, 8}));

assertArrayEquals(
((Array) ((Group) group.listAsArray()[0]).listAsArray()[0]).metadata.chunkShape(),
new int[]{5, 5});
Assertions.assertArrayEquals(((Array) ((Group) group.listAsArray()[0]).listAsArray()[0]).metadata.chunkShape(), new int[]{5, 5});
}

@Test
public void testV2() throws IOException, ZarrException {
public void testV2() throws IOException{
FilesystemStore fsStore = new FilesystemStore("");
HttpStore httpStore = new HttpStore("https://static.webknossos.org/data");

System.out.println(
dev.zarr.zarrjava.v2.Array.open(httpStore.resolve("l4_sample", "color", "1")));
System.out.println(dev.zarr.zarrjava.v2.Array.open(httpStore.resolve("l4_sample", "color", "1")));
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
elif codec_string == "gzip":
codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.gzip_codec()]
elif codec_string == "zstd":
codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.zstd_codec()]
codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.zstd_codec(checksum=True)]
elif codec_string == "bytes":
codec = [zarrita.codecs.bytes_codec()]
elif codec_string == "transpose":
Expand Down
Loading

0 comments on commit 9733e07

Please sign in to comment.