Skip to content

Commit

Permalink
Classifier (#1353)
Browse files Browse the repository at this point in the history
* the best version so far

* matches python mostly

* before major change, kinda works

* long one, all numpy stuff

* classifier v1

* Seems to be working

* style check

* removed some not used stuff

* add additional convert methods

* some test fixes

* rehaul

* changes from code review comments
  • Loading branch information
matakleo committed Jun 6, 2024
1 parent b5ef7fa commit 5cc4025
Show file tree
Hide file tree
Showing 7 changed files with 307 additions and 3 deletions.
1 change: 1 addition & 0 deletions cdm/core/src/main/java/ucar/nc2/constants/CDM.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class CDM {
public static final String RUNTIME_COORDINATE = "runtimeCoordinate";
public static final String STANDARDIZE = "standardize";
public static final String NORMALIZE = "normalize";
public static final String CLASSIFY = "classify";

// Special attributes

Expand Down
11 changes: 8 additions & 3 deletions cdm/core/src/main/java/ucar/nc2/dataset/NetcdfDataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,16 @@ public enum Enhance {
* If the enhanced data type is not {@code FLOAT} or {@code DOUBLE}, this has no effect.
*/
ApplyNormalizer,
/**
* Classify numeric values by sign into 0 or 1:
* {@code x < 0 --> 0} and {@code x >= 0 --> 1} (zero counts as non-negative).
*/
ApplyClassifier,
}

private static Set<Enhance> EnhanceAll =
Collections.unmodifiableSet(EnumSet.of(Enhance.ConvertEnums, Enhance.ConvertUnsigned, Enhance.ApplyScaleOffset,
Enhance.ConvertMissing, Enhance.CoordSystems, Enhance.ApplyStandardizer, Enhance.ApplyNormalizer));
private static Set<Enhance> EnhanceAll = Collections.unmodifiableSet(
EnumSet.of(Enhance.ConvertEnums, Enhance.ConvertUnsigned, Enhance.ApplyScaleOffset, Enhance.ConvertMissing,
Enhance.CoordSystems, Enhance.ApplyStandardizer, Enhance.ApplyNormalizer, Enhance.ApplyClassifier));
private static Set<Enhance> EnhanceNone = Collections.unmodifiableSet(EnumSet.noneOf(Enhance.class));
private static Set<Enhance> defaultEnhanceMode = EnhanceAll;

Expand Down
8 changes: 8 additions & 0 deletions cdm/core/src/main/java/ucar/nc2/dataset/VariableDS.java
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ Array convert(Array data, Set<NetcdfDataset.Enhance> enhancements) {
if (enhancements.contains(Enhance.ApplyNormalizer) && normalizer != null) {
toApply.add(normalizer);
}
if (enhancements.contains(Enhance.ApplyClassifier) && classifier != null) {
toApply.add(classifier);
}

double[] dataArray = (double[]) data.get1DJavaArray(DataType.DOUBLE);

Expand Down Expand Up @@ -865,6 +868,7 @@ public Array convert(Array in, boolean convertUnsigned, boolean applyScaleOffset
private ScaleOffset scaleOffset;
private Standardizer standardizer;
private Normalizer normalizer;
private Classifier classifier;
private ConvertMissing convertMissing;
private Set<Enhance> enhanceMode = EnumSet.noneOf(Enhance.class); // The set of enhancements that were made.

Expand Down Expand Up @@ -939,6 +943,10 @@ private void createEnhancements() {
if (normalizerAtt != null && this.enhanceMode.contains(Enhance.ApplyNormalizer) && dataType.isFloatingPoint()) {
this.normalizer = Normalizer.createFromVariable(this);
}
Attribute classifierAtt = findAttribute(CDM.CLASSIFY);
if (classifierAtt != null && this.enhanceMode.contains(Enhance.ApplyClassifier) && dataType.isNumeric()) {
this.classifier = Classifier.createFromVariable(this);
}
}

public Builder<?> toBuilder() {
Expand Down
70 changes: 70 additions & 0 deletions cdm/core/src/main/java/ucar/nc2/filter/Classifier.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package ucar.nc2.filter;

import java.io.IOException;
import ucar.ma2.Array;
import ucar.ma2.DataType;
import ucar.ma2.IndexIterator;
import ucar.nc2.dataset.VariableDS;

/**
 * Enhancement that classifies numeric values by sign.
 *
 * <p>
 * Values greater than or equal to zero are mapped to {@code 1}; negative values are mapped to {@code 0}.
 * {@code NaN} compares false against zero and therefore also maps to {@code 0}.
 *
 * <p>
 * Instances hold no state, so a single instance may be reused freely.
 */
public class Classifier implements Enhancement {

  /**
   * Creates the classifier to attach to a variable.
   *
   * <p>
   * The classification rule does not depend on the variable's data, so the variable is not read here.
   *
   * @param var the variable the classifier is created for; currently unused
   * @return a new classifier
   */
  public static Classifier createFromVariable(VariableDS var) {
    return emptyClassifier();
  }

  /**
   * Creates a classifier with no configuration.
   *
   * @return a new classifier
   */
  public static Classifier emptyClassifier() {
    return new Classifier();
  }

  /** Creates a classifier; there is nothing to configure. */
  public Classifier() {}

  /**
   * Classifies every element of {@code arr}.
   *
   * @param arr array whose elements are read as {@link Number}s
   * @return an array of the same length holding {@code 1} for elements {@code >= 0} and {@code 0}
   *         for negative or {@code NaN} elements
   */
  public int[] classifyDoubleArray(Array arr) {
    int[] classified = new int[(int) arr.getSize()];
    int i = 0;
    IndexIterator iter = arr.getIndexIterator();
    while (iter.hasNext()) {
      double value = ((Number) iter.getObjectNext()).doubleValue();
      // NaN fails the >= comparison inside classifyArray, so it lands in class 0.
      classified[i++] = classifyArray(value);
    }
    return classified;
  }

  /**
   * Classifies a single value.
   *
   * @param val the value to classify
   * @return {@code 1} if {@code val >= 0}, otherwise {@code 0} (including for {@code NaN})
   */
  public int classifyArray(double val) {
    return val >= 0 ? 1 : 0;
  }

  /**
   * Applies the classification to one value.
   *
   * @param val the value to classify
   * @return {@code 1.0} if {@code val >= 0}, otherwise {@code 0.0}
   */
  @Override
  public double convert(double val) {
    // Classify through this instance: a shared static instance may never have been initialised.
    return classifyArray(val);
  }
}


54 changes: 54 additions & 0 deletions cdm/core/src/test/data/ncml/enhance/testClassifier.ncml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright (c) 1998-2023 University Corporation for Atmospheric Research/Unidata
~ See LICENSE for license information.
-->

<netcdf xmlns="http://www.unidata.ucar.edu/namespaces/netcdf/ncml-2.2" enhance="all">
<!--
Fixture for the "classify" enhancement: nine 5-element variables (double, float and int;
all-positive, all-negative and mixed), each carrying an empty "classify" attribute.
TestEnhanceClassifier reads these and expects 1 1 1 1 1 for the *Positives variables,
0 0 0 0 0 for the *Negatives variables and 1 0 1 1 0 for the *Mix variables.
NOTE(review): presumably enhance="all" is what switches the classifier on; confirm against the NcML reader.
-->

<!-- double variables -->
<variable name="doublePositives" shape="5" type="double">
<attribute name="classify"/>
<values>1.0 2.0 3.0 4.0 5.0</values>
</variable>

<variable name="doubleNegatives" shape="5" type="double">
<attribute name="classify"/>
<values>-1.0 -2.0 -3.0 -4.0 -5.0</values>
</variable>

<!-- the mixed case includes 0.0, which the test expects to be classified as 1 -->
<variable name="doubleMix" shape="5" type="double">
<attribute name="classify"/>
<values>1.0 -2.0 0.0 4.0 -5.0</values>
</variable>

<!-- float variables -->
<variable name="floatPositives" shape="5" type="float">
<attribute name="classify"/>
<values>1.0 2.0 3.0 4.0 5.0</values>
</variable>

<variable name="floatNegatives" shape="5" type="float">
<attribute name="classify"/>
<values>-1.0 -2.0 -3.0 -4.0 -5.0</values>
</variable>
<variable name="floatMix" shape="5" type="float">
<attribute name="classify"/>
<values>1.0 -2.0 0.0 4.0 -5.0</values>
</variable>

<!-- int variables -->
<variable name="intPositives" shape="5" type="int">
<attribute name="classify"/>
<values>1 2 3 4 5</values>
</variable>

<!-- NOTE(review): the two int variables below list their values with a decimal point;
presumably the NcML reader converts them to int, but confirm. -->
<variable name="intNegatives" shape="5" type="int">
<attribute name="classify"/>
<values>-1.0 -2.0 -3.0 -4.0 -5.0</values>
</variable>
<variable name="intMix" shape="5" type="int">
<attribute name="classify"/>
<values>1.0 -2.0 0.0 4.0 -5.0</values>
</variable>



</netcdf>
51 changes: 51 additions & 0 deletions cdm/core/src/test/java/ucar/nc2/filter/TestClassifier.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package ucar.nc2.filter;

import static org.junit.Assert.*;
import org.junit.Test;
import ucar.ma2.Array;


/** Unit tests for {@link Classifier#classifyDoubleArray}. */
public class TestClassifier {

  /** Wraps {@code values} in an {@link Array}, classifies it and compares against {@code expected}. */
  private static void assertClassifiedAs(int[] expected, double[] values) {
    Classifier underTest = new Classifier();
    Array wrapped = Array.makeFromJavaArray(values);
    int[] actual = underTest.classifyDoubleArray(wrapped);
    assertArrayEquals(expected, actual);
  }

  /** Every positive double is classified as 1. */
  @Test
  public void testClassifyDoubleArray_AllPositive() {
    assertClassifiedAs(new int[] {1, 1, 1}, new double[] {1.1, 2.2, 3.3});
  }

  /** Every negative double is classified as 0. */
  @Test
  public void testClassifyDoubleArray_AllNegative() {
    assertClassifiedAs(new int[] {0, 0, 0}, new double[] {-1.1, -2.2, -3.3});
  }

  /** Positives and negatives are classified element by element. */
  @Test
  public void testClassifyDoubleArray_Mixed() {
    assertClassifiedAs(new int[] {0, 1, 0, 1}, new double[] {-1.1, 2.2, -3.3, 4.4});
  }

  /** Zero is treated as non-negative and classified as 1. */
  @Test
  public void testClassifyDoubleArray_WithZero() {
    assertClassifiedAs(new int[] {1, 0, 1, 1, 1, 1}, new double[] {0.0, -1.1, 1.1, 0.0, 0.0, 0.0});
  }
}
115 changes: 115 additions & 0 deletions cdm/core/src/test/java/ucar/nc2/ncml/TestEnhanceClassifier.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package ucar.nc2.ncml;

import static com.google.common.truth.Truth.assertThat;
import static ucar.ma2.MAMath.nearlyEquals;

import java.io.IOException;
import org.junit.Test;
import ucar.ma2.Array;
import ucar.ma2.DataType;
import ucar.nc2.NetcdfFile;
import ucar.nc2.Variable;
import ucar.nc2.dataset.NetcdfDatasets;
import ucar.unidata.util.test.TestDir;

/**
 * Tests the {@code classify} enhancement end to end on the variables defined in
 * {@code ncml/enhance/testClassifier.ncml}.
 */
public class TestEnhanceClassifier {

  private static final String dataDir = TestDir.cdmLocalTestDataDir + "ncml/enhance/";

  public static final int[] all_ones = {1, 1, 1, 1, 1};
  public static final Array DATA_all_ones = Array.makeFromJavaArray(all_ones);
  public static final int[] all_zeroes = {0, 0, 0, 0, 0};
  public static final Array DATA_all_zeroes = Array.makeFromJavaArray(all_zeroes);
  public static final int[] mixNumbers = {1, 0, 1, 1, 0};
  public static final Array DATA_mixNumbers = Array.makeFromJavaArray(mixNumbers);

  /** test on doubles, all positives, all negatives and a mixed array */
  @Test
  public void testEnhanceClassifier_doubles() throws IOException {
    try (NetcdfFile ncfile = NetcdfDatasets.openDataset(dataDir + "testClassifier.ncml", true, null)) {
      assertClassified(ncfile, "doublePositives", DataType.DOUBLE, DATA_all_ones);
      assertClassified(ncfile, "doubleNegatives", DataType.DOUBLE, DATA_all_zeroes);
      assertClassified(ncfile, "doubleMix", DataType.DOUBLE, DATA_mixNumbers);
    }
  }

  /** test on floats, all positives, all negatives and a mixed array */
  @Test
  public void testEnhanceClassifier_floats() throws IOException {
    try (NetcdfFile ncfile = NetcdfDatasets.openDataset(dataDir + "testClassifier.ncml", true, null)) {
      assertClassified(ncfile, "floatPositives", DataType.FLOAT, DATA_all_ones);
      assertClassified(ncfile, "floatNegatives", DataType.FLOAT, DATA_all_zeroes);
      assertClassified(ncfile, "floatMix", DataType.FLOAT, DATA_mixNumbers);
    }
  }

  /**
   * test on integers, all positives, all negatives and a mixed array; the variables keep their
   * {@code INT} data type and their values are expected to be classified just like the floating point cases
   */
  @Test
  public void testEnhanceClassifier_integers() throws IOException {
    try (NetcdfFile ncfile = NetcdfDatasets.openDataset(dataDir + "testClassifier.ncml", true, null)) {
      assertClassified(ncfile, "intPositives", DataType.INT, DATA_all_ones);
      assertClassified(ncfile, "intNegatives", DataType.INT, DATA_all_zeroes);
      assertClassified(ncfile, "intMix", DataType.INT, DATA_mixNumbers);
    }
  }

  /**
   * Asserts that the named variable exists, has the expected data type, carries the {@code classify}
   * attribute, and reads back as {@code expected}.
   *
   * @param ncfile the opened dataset
   * @param varName name of the variable to look up
   * @param expectedType data type the variable must report
   * @param expected values the variable must read back as, compared with {@code MAMath.nearlyEquals}
   * @throws IOException if reading the variable fails
   */
  private static void assertClassified(NetcdfFile ncfile, String varName, DataType expectedType, Array expected)
      throws IOException {
    Variable variable = ncfile.findVariable(varName);
    assertThat((Object) variable).isNotNull();
    assertThat(variable.getDataType()).isEqualTo(expectedType);
    assertThat(variable.attributes().hasAttribute("classify")).isTrue();
    Array data = variable.read();
    assertThat(nearlyEquals(data, expected)).isTrue();
  }
}

0 comments on commit 5cc4025

Please sign in to comment.