Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix transpose codec #1

Merged
merged 8 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ jobs:
strategy:
matrix:
os: [ ubuntu, windows, macos ]
fail-fast: false
runs-on: ${{ matrix.os }}-latest
defaults:
run:
Expand All @@ -22,7 +23,7 @@ jobs:
- name: Set up JDK
uses: actions/setup-java@v3
with:
java-version: '8'
java-version: '22'
distribution: 'temurin'
cache: maven

Expand Down
9 changes: 3 additions & 6 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,6 @@
<version>4.13.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.zarr.zarrjava</groupId>
<artifactId>zarr-java</artifactId>
<version>0.0.1-SNAPSHOT</version>
<scope>test</scope>
</dependency>
</dependencies>

<repositories>
Expand All @@ -84,4 +78,7 @@
</repository>
</repositories>

<build>
<testSourceDirectory>src/test/java/dev/zarr/zarrjava</testSourceDirectory>
</build>
</project>
20 changes: 20 additions & 0 deletions src/main/java/dev/zarr/zarrjava/utils/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,24 @@ public static <T> T[] concatArrays(T[] array1, T[]... arrays) {
}
return result;
}

/**
 * Checks whether {@code array} contains each of the values {@code 0 .. array.length - 1}
 * exactly once.
 *
 * <p>An empty array is not considered a permutation. The input array is not modified.
 *
 * @param array the candidate permutation
 * @return {@code true} if {@code array} is a non-empty permutation of {@code 0 .. n-1}
 */
public static boolean isPermutation(int[] array) {
  final int n = array.length;
  if (n == 0) {
    return false;
  }
  // seen[v] records that value v has already occurred; any repeat or
  // out-of-range value disqualifies the array.
  final boolean[] seen = new boolean[n];
  for (int value : array) {
    if (value < 0 || value >= n || seen[value]) {
      return false;
    }
    seen[value] = true;
  }
  return true;
}

/**
 * Computes the inverse of a permutation, i.e. the array {@code inverse} such that
 * {@code inverse[origin[i]] == i} for every index {@code i}.
 *
 * <p>The input is validated explicitly rather than with {@code assert}: assertions are
 * disabled by default, in which case an invalid input would either silently yield a wrong
 * result (duplicate entries) or fail with a bare {@link ArrayIndexOutOfBoundsException}.
 * The validation applies the same rule as {@code isPermutation}: the array must be
 * non-empty and contain each of {@code 0 .. n-1} exactly once.
 *
 * @param origin a non-empty permutation of {@code 0 .. origin.length - 1}; not modified
 * @return a newly allocated array holding the inverse permutation
 * @throws IllegalArgumentException if {@code origin} is not such a permutation
 */
public static int[] inversePermutation(int[] origin) {
  final int n = origin.length;
  if (n == 0) {
    throw new IllegalArgumentException("Not a permutation: array is empty.");
  }
  final int[] inverse = new int[n];
  // assigned[v] guards against a value occurring twice, which would overwrite a slot.
  final boolean[] assigned = new boolean[n];
  for (int i = 0; i < n; i++) {
    final int target = origin[i];
    if (target < 0 || target >= n || assigned[target]) {
      throw new IllegalArgumentException(
          "Not a permutation of 0.." + (n - 1) + ": " + Arrays.toString(origin));
    }
    assigned[target] = true;
    inverse[target] = i;
  }
  return inverse;
}
}
9 changes: 4 additions & 5 deletions src/main/java/dev/zarr/zarrjava/v3/Array.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ protected Array(StoreHandle storeHandle, ArrayMetadata arrayMetadata)
throws IOException, ZarrException {
super(storeHandle);
this.metadata = arrayMetadata;
this.codecPipeline = new CodecPipeline(arrayMetadata.codecs);
this.codecPipeline = new CodecPipeline(arrayMetadata.codecs, arrayMetadata.coreArrayMetadata);
}

/**
Expand Down Expand Up @@ -171,8 +171,7 @@ public ucar.ma2.Array read(final long[] offset, final int[] shape) throws ZarrEx

if (codecPipeline.supportsPartialDecode()) {
final ucar.ma2.Array chunkArray = codecPipeline.decodePartial(chunkHandle,
Utils.toLongArray(chunkProjection.chunkOffset), chunkProjection.shape,
metadata.coreArrayMetadata);
Utils.toLongArray(chunkProjection.chunkOffset), chunkProjection.shape);
MultiArrayUtils.copyRegion(chunkArray, new int[metadata.ndim()], outputArray,
chunkProjection.outOffset, chunkProjection.shape
);
Expand Down Expand Up @@ -223,7 +222,7 @@ public ucar.ma2.Array readChunk(long[] chunkCoords)
return metadata.allocateFillValueChunk();
}

return codecPipeline.decode(chunkBytes, metadata.coreArrayMetadata);
return codecPipeline.decode(chunkBytes);
}

/**
Expand Down Expand Up @@ -299,7 +298,7 @@ public void writeChunk(long[] chunkCoords, ucar.ma2.Array chunkArray) throws Zar
if (MultiArrayUtils.allValuesEqual(chunkArray, metadata.parsedFillValue)) {
chunkHandle.delete();
} else {
ByteBuffer chunkBytes = codecPipeline.encode(chunkArray, metadata.coreArrayMetadata);
ByteBuffer chunkBytes = codecPipeline.encode(chunkArray);
chunkHandle.set(chunkBytes);
}
}
Expand Down
7 changes: 3 additions & 4 deletions src/main/java/dev/zarr/zarrjava/v3/codec/ArrayArrayCodec.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package dev.zarr.zarrjava.v3.codec;

import dev.zarr.zarrjava.ZarrException;
import dev.zarr.zarrjava.v3.ArrayMetadata.CoreArrayMetadata;
import ucar.ma2.Array;

public interface ArrayArrayCodec extends Codec {
public abstract class ArrayArrayCodec extends Codec {

Array encode(Array chunkArray, CoreArrayMetadata arrayMetadata)
protected abstract Array encode(Array chunkArray)
throws ZarrException;

Array decode(Array chunkArray, CoreArrayMetadata arrayMetadata)
protected abstract Array decode(Array chunkArray)
throws ZarrException;

}
17 changes: 9 additions & 8 deletions src/main/java/dev/zarr/zarrjava/v3/codec/ArrayBytesCodec.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,24 @@

import dev.zarr.zarrjava.ZarrException;
import dev.zarr.zarrjava.store.StoreHandle;
import dev.zarr.zarrjava.v3.ArrayMetadata.CoreArrayMetadata;
import java.nio.ByteBuffer;
import ucar.ma2.Array;

public interface ArrayBytesCodec extends Codec {
public abstract class ArrayBytesCodec extends Codec {

ByteBuffer encode(Array chunkArray, CoreArrayMetadata arrayMetadata)
protected abstract ByteBuffer encode(Array chunkArray)
throws ZarrException;

Array decode(ByteBuffer chunkBytes, CoreArrayMetadata arrayMetadata)
protected abstract Array decode(ByteBuffer chunkBytes)
throws ZarrException;

interface WithPartialDecode extends ArrayBytesCodec {
public abstract static class WithPartialDecode extends ArrayBytesCodec {

Array decodePartial(
StoreHandle handle, long[] offset, int[] shape,
CoreArrayMetadata arrayMetadata
public abstract Array decode(ByteBuffer shardBytes) throws ZarrException;
public abstract ByteBuffer encode(Array shardArray) throws ZarrException;

protected abstract Array decodePartial(
StoreHandle handle, long[] offset, int[] shape
) throws ZarrException;
}
}
Expand Down
10 changes: 4 additions & 6 deletions src/main/java/dev/zarr/zarrjava/v3/codec/BytesBytesCodec.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package dev.zarr.zarrjava.v3.codec;

import dev.zarr.zarrjava.ZarrException;
import dev.zarr.zarrjava.v3.ArrayMetadata.CoreArrayMetadata;

import java.nio.ByteBuffer;

public interface BytesBytesCodec extends Codec {
public abstract class BytesBytesCodec extends Codec {

ByteBuffer encode(ByteBuffer chunkBytes, CoreArrayMetadata arrayMetadata)
throws ZarrException;
protected abstract ByteBuffer encode(ByteBuffer chunkBytes) throws ZarrException;

ByteBuffer decode(ByteBuffer chunkBytes, CoreArrayMetadata arrayMetadata)
throws ZarrException;
public abstract ByteBuffer decode(ByteBuffer chunkBytes) throws ZarrException;

}
19 changes: 16 additions & 3 deletions src/main/java/dev/zarr/zarrjava/v3/codec/Codec.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,22 @@
import dev.zarr.zarrjava.v3.ArrayMetadata;

@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "name")
public interface Codec {
public abstract class Codec {

long computeEncodedSize(long inputByteLength, ArrayMetadata.CoreArrayMetadata arrayMetadata)
throws ZarrException;
protected ArrayMetadata.CoreArrayMetadata arrayMetadata;

/**
 * Returns the core array metadata stored on this codec.
 *
 * <p>The {@code CodecPipeline} constructor passes the returned value to the next codec
 * in the chain, so this is the metadata as seen after this codec. This base
 * implementation returns the stored metadata unchanged.
 *
 * @return the metadata previously stored via {@code setCoreArrayMetadata}
 * @throws ZarrException if no metadata has been set on this codec yet
 */
protected ArrayMetadata.CoreArrayMetadata resolveArrayMetadata() throws ZarrException {
  if (this.arrayMetadata != null) {
    return this.arrayMetadata;
  }
  throw new ZarrException("arrayMetadata needs to get set in for every codec");
}

protected abstract long computeEncodedSize(long inputByteLength, ArrayMetadata.CoreArrayMetadata arrayMetadata)
throws ZarrException;

/**
 * Stores the core array metadata this codec operates on.
 *
 * <p>The {@code CodecPipeline} constructor calls this on every codec, immediately before
 * calling {@code resolveArrayMetadata()} on it.
 *
 * @param arrayMetadata the metadata describing this codec's input
 * @throws ZarrException never thrown by this implementation; presumably declared so that
 *     overriding codecs can validate the metadata (TODO confirm against subclasses)
 */
public void setCoreArrayMetadata(ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException{
this.arrayMetadata = arrayMetadata;
}
}

8 changes: 2 additions & 6 deletions src/main/java/dev/zarr/zarrjava/v3/codec/CodecBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,9 @@ public CodecBuilder withBlosc() {
return withBlosc("zstd");
}

public CodecBuilder withTranspose(String order) {
try {
public CodecBuilder withTranspose(int[] order) {
codecs.add(new TransposeCodec(new TransposeCodec.Configuration(order)));
} catch (ZarrException e) {
throw new RuntimeException(e);
}
return this;
return this;
}

public CodecBuilder withBytes(Endian endian) {
Expand Down
29 changes: 16 additions & 13 deletions src/main/java/dev/zarr/zarrjava/v3/codec/CodecPipeline.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@ public class CodecPipeline {

@Nonnull
final Codec[] codecs;
public final CoreArrayMetadata arrayMetadata;

public CodecPipeline(@Nonnull Codec[] codecs) throws ZarrException {
public CodecPipeline(@Nonnull Codec[] codecs, CoreArrayMetadata arrayMetadata) throws ZarrException {
this.arrayMetadata = arrayMetadata;
long arrayBytesCodecCount = Arrays.stream(codecs).filter(c -> c instanceof ArrayBytesCodec)
.count();
if (arrayBytesCodecCount != 1) {
throw new ZarrException(
"Exactly 1 ArrayBytesCodec is required. Found " + arrayBytesCodecCount + ".");
}
Codec prevCodec = null;
CoreArrayMetadata codecArrayMetadata = arrayMetadata;
for (Codec codec : codecs) {
if (prevCodec != null) {
if (codec instanceof ArrayBytesCodec && prevCodec instanceof ArrayBytesCodec) {
Expand All @@ -44,6 +47,8 @@ public CodecPipeline(@Nonnull Codec[] codecs) throws ZarrException {
prevCodec.getClass() + "'.");
}
}
codec.setCoreArrayMetadata(codecArrayMetadata);
codecArrayMetadata = codec.resolveArrayMetadata();
prevCodec = codec;
}

Expand Down Expand Up @@ -79,15 +84,14 @@ public boolean supportsPartialDecode() {
@Nonnull
public Array decodePartial(
@Nonnull StoreHandle storeHandle,
long[] offset, int[] shape,
@Nonnull CoreArrayMetadata arrayMetadata
long[] offset, int[] shape
) throws ZarrException {
if (!supportsPartialDecode()) {
throw new ZarrException(
"Partial decode is not supported for these codecs. " + Arrays.toString(codecs));
}
Array chunkArray = ((ArrayBytesCodec.WithPartialDecode) getArrayBytesCodec()).decodePartial(
storeHandle, offset, shape, arrayMetadata);
storeHandle, offset, shape);
if (chunkArray == null) {
throw new ZarrException("chunkArray is null. This is likely a bug in one of the codecs.");
}
Expand All @@ -96,8 +100,7 @@ public Array decodePartial(

@Nonnull
public Array decode(
@Nonnull ByteBuffer chunkBytes,
@Nonnull CoreArrayMetadata arrayMetadata
@Nonnull ByteBuffer chunkBytes
) throws ZarrException {
if (chunkBytes == null) {
throw new ZarrException("chunkBytes is null. Ohh nooo.");
Expand All @@ -106,23 +109,23 @@ public Array decode(
BytesBytesCodec[] bytesBytesCodecs = getBytesBytesCodecs();
for (int i = bytesBytesCodecs.length - 1; i >= 0; --i) {
BytesBytesCodec codec = bytesBytesCodecs[i];
chunkBytes = codec.decode(chunkBytes, arrayMetadata);
chunkBytes = codec.decode(chunkBytes);
}

if (chunkBytes == null) {
throw new ZarrException(
"chunkBytes is null. This is likely a bug in one of the codecs. " + Arrays.toString(
getBytesBytesCodecs()));
}
Array chunkArray = getArrayBytesCodec().decode(chunkBytes, arrayMetadata);
Array chunkArray = getArrayBytesCodec().decode(chunkBytes);
if (chunkArray == null) {
throw new ZarrException("chunkArray is null. This is likely a bug in one of the codecs.");
}

ArrayArrayCodec[] arrayArrayCodecs = getArrayArrayCodecs();
for (int i = arrayArrayCodecs.length - 1; i >= 0; --i) {
ArrayArrayCodec codec = arrayArrayCodecs[i];
chunkArray = codec.decode(chunkArray, arrayMetadata);
chunkArray = codec.decode(chunkArray);
}

if (chunkArray == null) {
Expand All @@ -133,16 +136,16 @@ public Array decode(

@Nonnull
public ByteBuffer encode(
@Nonnull Array chunkArray, @Nonnull CoreArrayMetadata arrayMetadata
@Nonnull Array chunkArray
) throws ZarrException {
for (ArrayArrayCodec codec : getArrayArrayCodecs()) {
chunkArray = codec.encode(chunkArray, arrayMetadata);
chunkArray = codec.encode(chunkArray);
}

ByteBuffer chunkBytes = getArrayBytesCodec().encode(chunkArray, arrayMetadata);
ByteBuffer chunkBytes = getArrayBytesCodec().encode(chunkArray);

for (BytesBytesCodec codec : getBytesBytesCodecs()) {
chunkBytes = codec.encode(chunkBytes, arrayMetadata);
chunkBytes = codec.encode(chunkBytes);
}
return chunkBytes;
}
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/dev/zarr/zarrjava/v3/codec/core/BloscCodec.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import java.nio.ByteBuffer;
import javax.annotation.Nonnull;

public class BloscCodec implements BytesBytesCodec {
public class BloscCodec extends BytesBytesCodec {

public final String name = "blosc";
@Nonnull
Expand All @@ -33,7 +33,7 @@ public BloscCodec(
}

@Override
public ByteBuffer decode(ByteBuffer chunkBytes, ArrayMetadata.CoreArrayMetadata arrayMetadata)
public ByteBuffer decode(ByteBuffer chunkBytes)
throws ZarrException {
try {
return ByteBuffer.wrap(Blosc.decompress(Utils.toArray(chunkBytes)));
Expand All @@ -43,7 +43,7 @@ public ByteBuffer decode(ByteBuffer chunkBytes, ArrayMetadata.CoreArrayMetadata
}

@Override
public ByteBuffer encode(ByteBuffer chunkBytes, ArrayMetadata.CoreArrayMetadata arrayMetadata)
public ByteBuffer encode(ByteBuffer chunkBytes)
throws ZarrException {
try {
return ByteBuffer.wrap(
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/dev/zarr/zarrjava/v3/codec/core/BytesCodec.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import javax.annotation.Nonnull;
import ucar.ma2.Array;

public class BytesCodec implements ArrayBytesCodec {
public class BytesCodec extends ArrayBytesCodec {

public final String name = "bytes";
@Nonnull
Expand All @@ -29,14 +29,14 @@ public BytesCodec(Endian endian) {
}

@Override
public Array decode(ByteBuffer chunkBytes, ArrayMetadata.CoreArrayMetadata arrayMetadata) {
/**
 * Interprets the raw chunk bytes as an array, using the data type and chunk shape from
 * this codec's stored {@code arrayMetadata}.
 *
 * <p>Note that this sets the byte order on the passed-in buffer itself, not on a copy.
 *
 * @param chunkBytes the encoded chunk; its byte order is set to the configured endianness
 * @return the array built by {@code Array.factory} from {@code chunkBytes}
 */
public Array decode(ByteBuffer chunkBytes) {
// Apply the configured endianness before the buffer is handed to Array.factory.
chunkBytes.order(configuration.endian.getByteOrder());
return Array.factory(arrayMetadata.dataType.getMA2DataType(), arrayMetadata.chunkShape,
chunkBytes);
}

@Override
public ByteBuffer encode(Array chunkArray, ArrayMetadata.CoreArrayMetadata arrayMetadata) {
/**
 * Serializes the chunk array to bytes in the configured byte order.
 *
 * @param chunkArray the chunk to encode
 * @return the result of {@code chunkArray.getDataAsByteBuffer} for the configured endianness
 */
public ByteBuffer encode(Array chunkArray) {
return chunkArray.getDataAsByteBuffer(configuration.endian.getByteOrder());
}

Expand Down Expand Up @@ -72,7 +72,7 @@ public ByteOrder getByteOrder() {
}
}

public static final class Configuration {
public static final class Configuration{

@Nonnull
public final BytesCodec.Endian endian;
Expand Down
Loading
Loading