Skip to content

Commit

Permalink
Merge pull request #1 from scalableminds/fix-transpose-codec
Browse files Browse the repository at this point in the history
Fix transpose codec
  • Loading branch information
brokkoli71 authored May 29, 2024
2 parents 0f50859 + 30bb41d commit d2d3e04
Show file tree
Hide file tree
Showing 21 changed files with 255 additions and 160 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ jobs:
strategy:
matrix:
os: [ ubuntu, windows, macos ]
fail-fast: false
runs-on: ${{ matrix.os }}-latest
defaults:
run:
Expand All @@ -22,7 +23,7 @@ jobs:
- name: Set up JDK
uses: actions/setup-java@v3
with:
java-version: '8'
java-version: '22'
distribution: 'temurin'
cache: maven

Expand Down
9 changes: 3 additions & 6 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,6 @@
<version>4.13.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.zarr.zarrjava</groupId>
<artifactId>zarr-java</artifactId>
<version>0.0.1-SNAPSHOT</version>
<scope>test</scope>
</dependency>
</dependencies>

<repositories>
Expand All @@ -84,4 +78,7 @@
</repository>
</repositories>

<build>
<testSourceDirectory>src/test/java/dev/zarr/zarrjava</testSourceDirectory>
</build>
</project>
20 changes: 20 additions & 0 deletions src/main/java/dev/zarr/zarrjava/utils/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,24 @@ public static <T> T[] concatArrays(T[] array1, T[]... arrays) {
}
return result;
}

public static boolean isPermutation(int[] array) {
if (array.length==0){
return false;
}
int[] arange = new int[array.length];
Arrays.setAll(arange, i -> i);
int[] orderSorted = array.clone();
Arrays.sort(orderSorted);
return Arrays.equals(orderSorted, arange);
}

public static int[] inversePermutation(int[] origin){
assert isPermutation(origin);
int[] inverse = new int[origin.length];
for (int i = 0; i < origin.length; i++) {
inverse[origin[i]] = i;
}
return inverse;
}
}
9 changes: 4 additions & 5 deletions src/main/java/dev/zarr/zarrjava/v3/Array.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ protected Array(StoreHandle storeHandle, ArrayMetadata arrayMetadata)
throws IOException, ZarrException {
super(storeHandle);
this.metadata = arrayMetadata;
this.codecPipeline = new CodecPipeline(arrayMetadata.codecs);
this.codecPipeline = new CodecPipeline(arrayMetadata.codecs, arrayMetadata.coreArrayMetadata);
}

/**
Expand Down Expand Up @@ -171,8 +171,7 @@ public ucar.ma2.Array read(final long[] offset, final int[] shape) throws ZarrEx

if (codecPipeline.supportsPartialDecode()) {
final ucar.ma2.Array chunkArray = codecPipeline.decodePartial(chunkHandle,
Utils.toLongArray(chunkProjection.chunkOffset), chunkProjection.shape,
metadata.coreArrayMetadata);
Utils.toLongArray(chunkProjection.chunkOffset), chunkProjection.shape);
MultiArrayUtils.copyRegion(chunkArray, new int[metadata.ndim()], outputArray,
chunkProjection.outOffset, chunkProjection.shape
);
Expand Down Expand Up @@ -223,7 +222,7 @@ public ucar.ma2.Array readChunk(long[] chunkCoords)
return metadata.allocateFillValueChunk();
}

return codecPipeline.decode(chunkBytes, metadata.coreArrayMetadata);
return codecPipeline.decode(chunkBytes);
}

/**
Expand Down Expand Up @@ -299,7 +298,7 @@ public void writeChunk(long[] chunkCoords, ucar.ma2.Array chunkArray) throws Zar
if (MultiArrayUtils.allValuesEqual(chunkArray, metadata.parsedFillValue)) {
chunkHandle.delete();
} else {
ByteBuffer chunkBytes = codecPipeline.encode(chunkArray, metadata.coreArrayMetadata);
ByteBuffer chunkBytes = codecPipeline.encode(chunkArray);
chunkHandle.set(chunkBytes);
}
}
Expand Down
7 changes: 3 additions & 4 deletions src/main/java/dev/zarr/zarrjava/v3/codec/ArrayArrayCodec.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package dev.zarr.zarrjava.v3.codec;

import dev.zarr.zarrjava.ZarrException;
import dev.zarr.zarrjava.v3.ArrayMetadata.CoreArrayMetadata;
import ucar.ma2.Array;

public interface ArrayArrayCodec extends Codec {
public abstract class ArrayArrayCodec extends Codec {

Array encode(Array chunkArray, CoreArrayMetadata arrayMetadata)
protected abstract Array encode(Array chunkArray)
throws ZarrException;

Array decode(Array chunkArray, CoreArrayMetadata arrayMetadata)
protected abstract Array decode(Array chunkArray)
throws ZarrException;

}
17 changes: 9 additions & 8 deletions src/main/java/dev/zarr/zarrjava/v3/codec/ArrayBytesCodec.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,24 @@

import dev.zarr.zarrjava.ZarrException;
import dev.zarr.zarrjava.store.StoreHandle;
import dev.zarr.zarrjava.v3.ArrayMetadata.CoreArrayMetadata;
import java.nio.ByteBuffer;
import ucar.ma2.Array;

public interface ArrayBytesCodec extends Codec {
public abstract class ArrayBytesCodec extends Codec {

ByteBuffer encode(Array chunkArray, CoreArrayMetadata arrayMetadata)
protected abstract ByteBuffer encode(Array chunkArray)
throws ZarrException;

Array decode(ByteBuffer chunkBytes, CoreArrayMetadata arrayMetadata)
protected abstract Array decode(ByteBuffer chunkBytes)
throws ZarrException;

interface WithPartialDecode extends ArrayBytesCodec {
public abstract static class WithPartialDecode extends ArrayBytesCodec {

Array decodePartial(
StoreHandle handle, long[] offset, int[] shape,
CoreArrayMetadata arrayMetadata
public abstract Array decode(ByteBuffer shardBytes) throws ZarrException;
public abstract ByteBuffer encode(Array shardArray) throws ZarrException;

protected abstract Array decodePartial(
StoreHandle handle, long[] offset, int[] shape
) throws ZarrException;
}
}
Expand Down
10 changes: 4 additions & 6 deletions src/main/java/dev/zarr/zarrjava/v3/codec/BytesBytesCodec.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package dev.zarr.zarrjava.v3.codec;

import dev.zarr.zarrjava.ZarrException;
import dev.zarr.zarrjava.v3.ArrayMetadata.CoreArrayMetadata;

import java.nio.ByteBuffer;

public interface BytesBytesCodec extends Codec {
public abstract class BytesBytesCodec extends Codec {

ByteBuffer encode(ByteBuffer chunkBytes, CoreArrayMetadata arrayMetadata)
throws ZarrException;
protected abstract ByteBuffer encode(ByteBuffer chunkBytes) throws ZarrException;

ByteBuffer decode(ByteBuffer chunkBytes, CoreArrayMetadata arrayMetadata)
throws ZarrException;
public abstract ByteBuffer decode(ByteBuffer chunkBytes) throws ZarrException;

}
19 changes: 16 additions & 3 deletions src/main/java/dev/zarr/zarrjava/v3/codec/Codec.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,22 @@
import dev.zarr.zarrjava.v3.ArrayMetadata;

@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "name")
public interface Codec {
public abstract class Codec {

long computeEncodedSize(long inputByteLength, ArrayMetadata.CoreArrayMetadata arrayMetadata)
throws ZarrException;
protected ArrayMetadata.CoreArrayMetadata arrayMetadata;

protected ArrayMetadata.CoreArrayMetadata resolveArrayMetadata() throws ZarrException {
if (arrayMetadata == null) {
throw new ZarrException("arrayMetadata needs to get set in for every codec");
}
return this.arrayMetadata;
}

protected abstract long computeEncodedSize(long inputByteLength, ArrayMetadata.CoreArrayMetadata arrayMetadata)
throws ZarrException;

public void setCoreArrayMetadata(ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException{
this.arrayMetadata = arrayMetadata;
}
}

8 changes: 2 additions & 6 deletions src/main/java/dev/zarr/zarrjava/v3/codec/CodecBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,9 @@ public CodecBuilder withBlosc() {
return withBlosc("zstd");
}

public CodecBuilder withTranspose(String order) {
try {
public CodecBuilder withTranspose(int[] order) {
codecs.add(new TransposeCodec(new TransposeCodec.Configuration(order)));
} catch (ZarrException e) {
throw new RuntimeException(e);
}
return this;
return this;
}

public CodecBuilder withBytes(Endian endian) {
Expand Down
29 changes: 16 additions & 13 deletions src/main/java/dev/zarr/zarrjava/v3/codec/CodecPipeline.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@ public class CodecPipeline {

@Nonnull
final Codec[] codecs;
public final CoreArrayMetadata arrayMetadata;

public CodecPipeline(@Nonnull Codec[] codecs) throws ZarrException {
public CodecPipeline(@Nonnull Codec[] codecs, CoreArrayMetadata arrayMetadata) throws ZarrException {
this.arrayMetadata = arrayMetadata;
long arrayBytesCodecCount = Arrays.stream(codecs).filter(c -> c instanceof ArrayBytesCodec)
.count();
if (arrayBytesCodecCount != 1) {
throw new ZarrException(
"Exactly 1 ArrayBytesCodec is required. Found " + arrayBytesCodecCount + ".");
}
Codec prevCodec = null;
CoreArrayMetadata codecArrayMetadata = arrayMetadata;
for (Codec codec : codecs) {
if (prevCodec != null) {
if (codec instanceof ArrayBytesCodec && prevCodec instanceof ArrayBytesCodec) {
Expand All @@ -44,6 +47,8 @@ public CodecPipeline(@Nonnull Codec[] codecs) throws ZarrException {
prevCodec.getClass() + "'.");
}
}
codec.setCoreArrayMetadata(codecArrayMetadata);
codecArrayMetadata = codec.resolveArrayMetadata();
prevCodec = codec;
}

Expand Down Expand Up @@ -79,15 +84,14 @@ public boolean supportsPartialDecode() {
@Nonnull
public Array decodePartial(
@Nonnull StoreHandle storeHandle,
long[] offset, int[] shape,
@Nonnull CoreArrayMetadata arrayMetadata
long[] offset, int[] shape
) throws ZarrException {
if (!supportsPartialDecode()) {
throw new ZarrException(
"Partial decode is not supported for these codecs. " + Arrays.toString(codecs));
}
Array chunkArray = ((ArrayBytesCodec.WithPartialDecode) getArrayBytesCodec()).decodePartial(
storeHandle, offset, shape, arrayMetadata);
storeHandle, offset, shape);
if (chunkArray == null) {
throw new ZarrException("chunkArray is null. This is likely a bug in one of the codecs.");
}
Expand All @@ -96,8 +100,7 @@ public Array decodePartial(

@Nonnull
public Array decode(
@Nonnull ByteBuffer chunkBytes,
@Nonnull CoreArrayMetadata arrayMetadata
@Nonnull ByteBuffer chunkBytes
) throws ZarrException {
if (chunkBytes == null) {
throw new ZarrException("chunkBytes is null. Ohh nooo.");
Expand All @@ -106,23 +109,23 @@ public Array decode(
BytesBytesCodec[] bytesBytesCodecs = getBytesBytesCodecs();
for (int i = bytesBytesCodecs.length - 1; i >= 0; --i) {
BytesBytesCodec codec = bytesBytesCodecs[i];
chunkBytes = codec.decode(chunkBytes, arrayMetadata);
chunkBytes = codec.decode(chunkBytes);
}

if (chunkBytes == null) {
throw new ZarrException(
"chunkBytes is null. This is likely a bug in one of the codecs. " + Arrays.toString(
getBytesBytesCodecs()));
}
Array chunkArray = getArrayBytesCodec().decode(chunkBytes, arrayMetadata);
Array chunkArray = getArrayBytesCodec().decode(chunkBytes);
if (chunkArray == null) {
throw new ZarrException("chunkArray is null. This is likely a bug in one of the codecs.");
}

ArrayArrayCodec[] arrayArrayCodecs = getArrayArrayCodecs();
for (int i = arrayArrayCodecs.length - 1; i >= 0; --i) {
ArrayArrayCodec codec = arrayArrayCodecs[i];
chunkArray = codec.decode(chunkArray, arrayMetadata);
chunkArray = codec.decode(chunkArray);
}

if (chunkArray == null) {
Expand All @@ -133,16 +136,16 @@ public Array decode(

@Nonnull
public ByteBuffer encode(
@Nonnull Array chunkArray, @Nonnull CoreArrayMetadata arrayMetadata
@Nonnull Array chunkArray
) throws ZarrException {
for (ArrayArrayCodec codec : getArrayArrayCodecs()) {
chunkArray = codec.encode(chunkArray, arrayMetadata);
chunkArray = codec.encode(chunkArray);
}

ByteBuffer chunkBytes = getArrayBytesCodec().encode(chunkArray, arrayMetadata);
ByteBuffer chunkBytes = getArrayBytesCodec().encode(chunkArray);

for (BytesBytesCodec codec : getBytesBytesCodecs()) {
chunkBytes = codec.encode(chunkBytes, arrayMetadata);
chunkBytes = codec.encode(chunkBytes);
}
return chunkBytes;
}
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/dev/zarr/zarrjava/v3/codec/core/BloscCodec.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import java.nio.ByteBuffer;
import javax.annotation.Nonnull;

public class BloscCodec implements BytesBytesCodec {
public class BloscCodec extends BytesBytesCodec {

public final String name = "blosc";
@Nonnull
Expand All @@ -33,7 +33,7 @@ public BloscCodec(
}

@Override
public ByteBuffer decode(ByteBuffer chunkBytes, ArrayMetadata.CoreArrayMetadata arrayMetadata)
public ByteBuffer decode(ByteBuffer chunkBytes)
throws ZarrException {
try {
return ByteBuffer.wrap(Blosc.decompress(Utils.toArray(chunkBytes)));
Expand All @@ -43,7 +43,7 @@ public ByteBuffer decode(ByteBuffer chunkBytes, ArrayMetadata.CoreArrayMetadata
}

@Override
public ByteBuffer encode(ByteBuffer chunkBytes, ArrayMetadata.CoreArrayMetadata arrayMetadata)
public ByteBuffer encode(ByteBuffer chunkBytes)
throws ZarrException {
try {
return ByteBuffer.wrap(
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/dev/zarr/zarrjava/v3/codec/core/BytesCodec.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import javax.annotation.Nonnull;
import ucar.ma2.Array;

public class BytesCodec implements ArrayBytesCodec {
public class BytesCodec extends ArrayBytesCodec {

public final String name = "bytes";
@Nonnull
Expand All @@ -29,14 +29,14 @@ public BytesCodec(Endian endian) {
}

@Override
public Array decode(ByteBuffer chunkBytes, ArrayMetadata.CoreArrayMetadata arrayMetadata) {
public Array decode(ByteBuffer chunkBytes) {
chunkBytes.order(configuration.endian.getByteOrder());
return Array.factory(arrayMetadata.dataType.getMA2DataType(), arrayMetadata.chunkShape,
chunkBytes);
}

@Override
public ByteBuffer encode(Array chunkArray, ArrayMetadata.CoreArrayMetadata arrayMetadata) {
public ByteBuffer encode(Array chunkArray) {
return chunkArray.getDataAsByteBuffer(configuration.endian.getByteOrder());
}

Expand Down Expand Up @@ -72,7 +72,7 @@ public ByteOrder getByteOrder() {
}
}

public static final class Configuration {
public static final class Configuration{

@Nonnull
public final BytesCodec.Endian endian;
Expand Down
Loading

0 comments on commit d2d3e04

Please sign in to comment.