Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix zstd codec #4

Merged
merged 5 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 50 additions & 65 deletions src/main/java/dev/zarr/zarrjava/v3/codec/core/ZstdCodec.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,89 +2,74 @@

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.github.luben.zstd.ZstdInputStream;
import com.github.luben.zstd.ZstdOutputStream;
import com.github.luben.zstd.Zstd;
import com.github.luben.zstd.ZstdCompressCtx;
import dev.zarr.zarrjava.ZarrException;
import dev.zarr.zarrjava.utils.Utils;
import dev.zarr.zarrjava.v3.ArrayMetadata;
import dev.zarr.zarrjava.v3.codec.BytesBytesCodec;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;

import javax.annotation.Nonnull;
import java.nio.ByteBuffer;

public class ZstdCodec extends BytesBytesCodec {

public final String name = "zstd";
@Nonnull
public final Configuration configuration;
public final String name = "zstd";
@Nonnull
public final Configuration configuration;

@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
public ZstdCodec(
@Nonnull @JsonProperty(value = "configuration", required = true) Configuration configuration) {
this.configuration = configuration;
}

private void copy(InputStream inputStream, OutputStream outputStream) throws IOException {
byte[] buffer = new byte[4096];
int len;
while ((len = inputStream.read(buffer)) > 0) {
outputStream.write(buffer, 0, len);
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
public ZstdCodec(
@Nonnull @JsonProperty(value = "configuration", required = true) Configuration configuration) {
this.configuration = configuration;
}
}

@Override
public ByteBuffer decode(ByteBuffer chunkBytes)
throws ZarrException {
try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); ZstdInputStream inputStream = new ZstdInputStream(
new ByteArrayInputStream(Utils.toArray(chunkBytes)))) {
copy(inputStream, outputStream);
inputStream.close();
return ByteBuffer.wrap(outputStream.toByteArray());
} catch (IOException ex) {
throw new ZarrException("Error in decoding zstd.", ex);
@Override
public ByteBuffer decode(ByteBuffer compressedBytes) throws ZarrException {
byte[] compressedArray = compressedBytes.array();

long originalSize = Zstd.decompressedSize(compressedArray);
if (originalSize == 0) {
throw new ZarrException("Failed to get decompressed size");
}

byte[] decompressed = Zstd.decompress(compressedArray, (int) originalSize);
return ByteBuffer.wrap(decompressed);
}
}

@Override
public ByteBuffer encode(ByteBuffer chunkBytes)
throws ZarrException {
try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); ZstdOutputStream zstdStream = new ZstdOutputStream(
outputStream, configuration.level).setChecksum(
configuration.checksum)) {
zstdStream.write(Utils.toArray(chunkBytes));
zstdStream.close();
return ByteBuffer.wrap(outputStream.toByteArray());
} catch (IOException ex) {
throw new ZarrException("Error in encoding zstd.", ex);
@Override
public ByteBuffer encode(ByteBuffer chunkBytes) throws ZarrException {
byte[] arr = chunkBytes.array();
byte[] compressed;
try (ZstdCompressCtx ctx = new ZstdCompressCtx()) {
ctx.setLevel(configuration.level);
ctx.setChecksum(configuration.checksum);
compressed = ctx.compress(arr);
}
return ByteBuffer.wrap(compressed);
}
}

@Override
public long computeEncodedSize(long inputByteLength,
ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException {
throw new ZarrException("Not implemented for Zstd codec.");
}
@Override
public long computeEncodedSize(long inputByteLength,
ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException {
throw new ZarrException("Not implemented for Zstd codec.");
}

public static final class Configuration {
public static final class Configuration {

public final int level;
public final boolean checksum;
public final int level;
public final boolean checksum;

@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
public Configuration(@JsonProperty(value = "level", defaultValue = "5") int level,
@JsonProperty(value = "checksum", defaultValue = "true") boolean checksum)
throws ZarrException {
if (level < -131072 || level > 22) {
throw new ZarrException("'level' needs to be between -131072 and 22.");
}
this.level = level;
this.checksum = checksum;
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
public Configuration(@JsonProperty(value = "level", defaultValue = "5") int level,
@JsonProperty(value = "checksum", defaultValue = "true") boolean checksum)
throws ZarrException {
if (level < -131072 || level > 22) {
throw new ZarrException("'level' needs to be between -131072 and 22.");
}
this.level = level;
this.checksum = checksum;
}
}
}
}


119 changes: 87 additions & 32 deletions src/test/java/dev/zarr/zarrjava/ZarrTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,26 @@
import com.amazonaws.auth.AnonymousAWSCredentials;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.luben.zstd.Zstd;
import com.github.luben.zstd.ZstdCompressCtx;
import dev.zarr.zarrjava.store.FilesystemStore;
import dev.zarr.zarrjava.store.HttpStore;
import dev.zarr.zarrjava.store.S3Store;
import dev.zarr.zarrjava.store.StoreHandle;
import dev.zarr.zarrjava.utils.MultiArrayUtils;
import dev.zarr.zarrjava.v3.*;
import dev.zarr.zarrjava.v3.codec.CodecBuilder;
import dev.zarr.zarrjava.v3.codec.core.TransposeCodec;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import ucar.ma2.MAMath;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
Expand All @@ -35,8 +38,7 @@ public class ZarrTest {

final static Path TESTDATA = Paths.get("testdata");
final static Path TESTOUTPUT = Paths.get("testoutput");
final static Path ZARRITA_WRITE_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/zarrita_write.py");
final static Path ZARRITA_READ_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/zarrita_read.py");
final static Path PYTHON_TEST_PATH = Paths.get("src/test/python-scripts/");

public static String pythonPath() {
if (System.getProperty("os.name").startsWith("Windows")) {
Expand All @@ -60,7 +62,7 @@ public static void clearTestoutputFolder() throws IOException {
public void testReadFromZarrita(String codec) throws IOException, ZarrException, InterruptedException {

String command = pythonPath();
ProcessBuilder pb = new ProcessBuilder(command, ZARRITA_WRITE_PATH.toString(), codec, TESTOUTPUT.toString());
ProcessBuilder pb = new ProcessBuilder(command, PYTHON_TEST_PATH.resolve("zarrita_write.py").toString(), codec, TESTOUTPUT.toString());
Process process = pb.start();

BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
Expand Down Expand Up @@ -91,10 +93,43 @@ public void testReadFromZarrita(String codec) throws IOException, ZarrException,
Assertions.assertArrayEquals(expectedData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT));
}

@CsvSource({"0,true", "0,false", "5, true", "10, false"})
@ParameterizedTest
public void testZstdLibrary(int clevel, boolean checksumFlag) throws IOException, InterruptedException {
//compress using ZstdCompressCtx
int number = 123456;
byte[] src = ByteBuffer.allocate(4).putInt(number).array();
byte[] compressed;
try (ZstdCompressCtx ctx = new ZstdCompressCtx()) {
ctx.setLevel(clevel);
ctx.setChecksum(checksumFlag);
compressed = ctx.compress(src);
}
//decompress with Zstd.decompress
long originalSize = Zstd.decompressedSize(compressed);
byte[] decompressed = Zstd.decompress(compressed, (int) originalSize);
Assertions.assertEquals(number, ByteBuffer.wrap(decompressed).getInt());

//write compressed to file
String compressedDataPath = TESTOUTPUT.resolve("compressed" + clevel + checksumFlag + ".bin").toString();
try (FileOutputStream fos = new FileOutputStream(compressedDataPath)) {
fos.write(compressed);
}

//decompress in python
Process process = new ProcessBuilder(
pythonPath(),
PYTHON_TEST_PATH.resolve("zstd_decompress.py").toString(),
compressedDataPath,
Integer.toString(number)
).start();
int exitCode = process.waitFor();
assert exitCode == 0;
}

//TODO: add crc32c
//Disabled "zstd": known issue
@ParameterizedTest
@ValueSource(strings = {"blosc", "gzip", "bytes", "transpose", "sharding_start", "sharding_end"})
@ValueSource(strings = {"blosc", "gzip", "zstd", "bytes", "transpose", "sharding_start", "sharding_end"})
public void testWriteToZarrita(String codec) throws IOException, ZarrException, InterruptedException {
StoreHandle storeHandle = new FilesystemStore(TESTOUTPUT).resolve("write_to_zarrita", codec);
ArrayMetadataBuilder builder = Array.metadataBuilder()
Expand All @@ -105,10 +140,10 @@ public void testWriteToZarrita(String codec) throws IOException, ZarrException,

switch (codec) {
case "blosc":
builder = builder.withCodecs(c -> c.withBlosc());
builder = builder.withCodecs(CodecBuilder::withBlosc);
break;
case "gzip":
builder = builder.withCodecs(c -> c.withGzip());
builder = builder.withCodecs(CodecBuilder::withGzip);
break;
case "zstd":
builder = builder.withCodecs(c -> c.withZstd(0));
Expand Down Expand Up @@ -140,7 +175,7 @@ public void testWriteToZarrita(String codec) throws IOException, ZarrException,

String command = pythonPath();

ProcessBuilder pb = new ProcessBuilder(command, ZARRITA_READ_PATH.toString(), codec, TESTOUTPUT.toString());
ProcessBuilder pb = new ProcessBuilder(command, PYTHON_TEST_PATH.resolve("zarrita_read.py").toString(), codec, TESTOUTPUT.toString());
Process process = pb.start();

BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
Expand All @@ -161,7 +196,7 @@ public void testWriteToZarrita(String codec) throws IOException, ZarrException,

@ParameterizedTest
@ValueSource(strings = {"blosc", "gzip", "zstd", "bytes", "transpose", "sharding_start", "sharding_end"})
public void testCodecsWriteRead(String codec) throws IOException, ZarrException, InterruptedException {
public void testCodecsWriteRead(String codec) throws IOException, ZarrException {
int[] testData = new int[16 * 16 * 16];
Arrays.setAll(testData, p -> p);

Expand All @@ -175,10 +210,10 @@ public void testCodecsWriteRead(String codec) throws IOException, ZarrException,

switch (codec) {
case "blosc":
builder = builder.withCodecs(c -> c.withBlosc());
builder = builder.withCodecs(CodecBuilder::withBlosc);
break;
case "gzip":
builder = builder.withCodecs(c -> c.withGzip());
builder = builder.withCodecs(CodecBuilder::withGzip);
break;
case "zstd":
builder = builder.withCodecs(c -> c.withZstd(0));
Expand Down Expand Up @@ -216,8 +251,31 @@ public void testCodecsWriteRead(String codec) throws IOException, ZarrException,
Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT));
}

@ParameterizedTest
@CsvSource({"0,true", "0,false", "5, true", "5, false"})
public void testZstdCodecReadWrite(int clevel, boolean checksum) throws ZarrException, IOException {
int[] testData = new int[16 * 16 * 16];
Arrays.setAll(testData, p -> p);

StoreHandle storeHandle = new FilesystemStore(TESTOUTPUT).resolve("testZstdCodecReadWrite", "checksum_" + checksum, "clevel_" + clevel);
ArrayMetadataBuilder builder = Array.metadataBuilder()
.withShape(16, 16, 16)
.withDataType(DataType.UINT32)
.withChunkShape(2, 4, 8)
.withFillValue(0)
.withCodecs(c -> c.withZstd(clevel, checksum));
Array writeArray = Array.create(storeHandle, builder.build());
writeArray.write(ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{16, 16, 16}, testData));

Array readArray = Array.open(storeHandle);
ucar.ma2.Array result = readArray.read();

Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT));

}

@Test
public void testCodecTranspose() throws IOException, ZarrException, InterruptedException {
public void testTransposeCodec() throws ZarrException {
ucar.ma2.Array testData = ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{2, 3, 3}, new int[]{
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
ucar.ma2.Array testDataTransposed120 = ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{3, 3, 2}, new int[]{
Expand All @@ -237,8 +295,8 @@ public void testCodecTranspose() throws IOException, ZarrException, InterruptedE
transposeCodecWrongOrder2.setCoreArrayMetadata(metadata);
transposeCodecWrongOrder3.setCoreArrayMetadata(metadata);

assert ucar.ma2.MAMath.equals(testDataTransposed120, transposeCodec.encode(testData));
assert ucar.ma2.MAMath.equals(testData, transposeCodec.decode(testDataTransposed120));
assert MAMath.equals(testDataTransposed120, transposeCodec.encode(testData));
assert MAMath.equals(testData, transposeCodec.decode(testDataTransposed120));
assertThrows(ZarrException.class, () -> transposeCodecWrongOrder1.encode(testData));
assertThrows(ZarrException.class, () -> transposeCodecWrongOrder2.encode(testData));
assertThrows(ZarrException.class, () -> transposeCodecWrongOrder3.encode(testData));
Expand Down Expand Up @@ -295,8 +353,8 @@ public void testV3ShardingReadCutout() throws IOException, ZarrException {
Array array = Array.open(new FilesystemStore(TESTDATA).resolve("l4_sample", "color", "1"));

ucar.ma2.Array outArray = array.read(new long[]{0, 3073, 3073, 513}, new int[]{1, 64, 64, 64});
assertEquals(outArray.getSize(), 64 * 64 * 64);
assertEquals(outArray.getByte(0), -98);
Assertions.assertEquals(outArray.getSize(), 64 * 64 * 64);
Assertions.assertEquals(outArray.getByte(0), -98);
}

@Test
Expand All @@ -306,8 +364,8 @@ public void testV3Access() throws IOException, ZarrException {
ucar.ma2.Array outArray = readArray.access().withOffset(0, 3073, 3073, 513)
.withShape(1, 64, 64, 64)
.read();
assertEquals(outArray.getSize(), 64 * 64 * 64);
assertEquals(outArray.getByte(0), -98);
Assertions.assertEquals(outArray.getSize(), 64 * 64 * 64);
Assertions.assertEquals(outArray.getByte(0), -98);

Array writeArray = Array.create(
new FilesystemStore(TESTOUTPUT).resolve("l4_sample_2", "color", "1"),
Expand Down Expand Up @@ -375,15 +433,15 @@ public void testV3ArrayMetadataBuilder() throws ZarrException {
.withChunkShape(1, 1024, 1024, 1024)
.withFillValue(0)
.withCodecs(
c -> c.withSharding(new int[]{1, 32, 32, 32}, c1 -> c1.withBlosc()))
c -> c.withSharding(new int[]{1, 32, 32, 32}, CodecBuilder::withBlosc))
.build();
}

@Test
public void testV3FillValue() throws ZarrException {
assertEquals((int) ArrayMetadata.parseFillValue(0, DataType.UINT32), 0);
assertEquals((int) ArrayMetadata.parseFillValue("0x00010203", DataType.UINT32), 50462976);
assertEquals((byte) ArrayMetadata.parseFillValue("0b00000010", DataType.UINT8), 2);
Assertions.assertEquals((int) ArrayMetadata.parseFillValue(0, DataType.UINT32), 0);
Assertions.assertEquals((int) ArrayMetadata.parseFillValue("0x00010203", DataType.UINT32), 50462976);
Assertions.assertEquals((byte) ArrayMetadata.parseFillValue("0b00000010", DataType.UINT8), 2);
assert Double.isNaN((double) ArrayMetadata.parseFillValue("NaN", DataType.FLOAT64));
assert Double.isInfinite((double) ArrayMetadata.parseFillValue("-Infinity", DataType.FLOAT64));
}
Expand All @@ -401,18 +459,15 @@ public void testV3Group() throws IOException, ZarrException {
);
array.write(new long[]{2, 2}, ucar.ma2.Array.factory(ucar.ma2.DataType.UBYTE, new int[]{8, 8}));

assertArrayEquals(
((Array) ((Group) group.listAsArray()[0]).listAsArray()[0]).metadata.chunkShape(),
new int[]{5, 5});
Assertions.assertArrayEquals(((Array) ((Group) group.listAsArray()[0]).listAsArray()[0]).metadata.chunkShape(), new int[]{5, 5});
}

@Test
public void testV2() throws IOException, ZarrException {
public void testV2() throws IOException{
FilesystemStore fsStore = new FilesystemStore("");
HttpStore httpStore = new HttpStore("https://static.webknossos.org/data");

System.out.println(
dev.zarr.zarrjava.v2.Array.open(httpStore.resolve("l4_sample", "color", "1")));
System.out.println(dev.zarr.zarrjava.v2.Array.open(httpStore.resolve("l4_sample", "color", "1")));
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
elif codec_string == "gzip":
codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.gzip_codec()]
elif codec_string == "zstd":
codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.zstd_codec()]
codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.zstd_codec(checksum=True)]
elif codec_string == "bytes":
codec = [zarrita.codecs.bytes_codec()]
elif codec_string == "transpose":
Expand Down
Loading
Loading