Skip to content

Commit

Permalink
NIFI-11177: Add defensive code for null values for Iceberg
Browse files Browse the repository at this point in the history
  • Loading branch information
mattyb149 committed Sep 22, 2023
1 parent 9b591a2 commit 43a4834
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ public static Object convertType(final Object value, final DataType dataType, fi

public static UUID toUUID(Object value) {
if (value == null) {
throw new IllegalTypeConversionException("Null values cannot be converted to a UUID");
return null;
}

if (value instanceof UUID) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.sql.Time;
import java.sql.Timestamp;
import java.time.LocalDateTime;
import java.time.LocalTime;
Expand Down Expand Up @@ -91,7 +92,8 @@ public TimeConverter(final String format) {

@Override
public LocalTime convert(Object data) {
return DataTypeUtils.toTime(data, () -> DataTypeUtils.getDateFormat(timeFormat), null).toLocalTime();
Time time = DataTypeUtils.toTime(data, () -> DataTypeUtils.getDateFormat(timeFormat), null);
return time == null ? null : time.toLocalTime();
}
}

Expand All @@ -106,7 +108,7 @@ public TimestampConverter(final DataType dataType) {
@Override
public LocalDateTime convert(Object data) {
final Timestamp convertedTimestamp = DataTypeUtils.toTimestamp(data, () -> DataTypeUtils.getDateFormat(dataType.getFormat()), null);
return convertedTimestamp.toLocalDateTime();
return convertedTimestamp == null ? null : convertedTimestamp.toLocalDateTime();
}
}

Expand All @@ -121,14 +123,17 @@ public TimestampWithTimezoneConverter(final DataType dataType) {
@Override
public OffsetDateTime convert(Object data) {
final Timestamp convertedTimestamp = DataTypeUtils.toTimestamp(data, () -> DataTypeUtils.getDateFormat(dataType.getFormat()), null);
return OffsetDateTime.ofInstant(convertedTimestamp.toInstant(), ZoneId.of("UTC"));
return convertedTimestamp == null ? null : OffsetDateTime.ofInstant(convertedTimestamp.toInstant(), ZoneId.of("UTC"));
}
}

static class UUIDtoByteArrayConverter extends DataConverter<Object, byte[]> {

@Override
public byte[] convert(Object data) {
if (data == null) {
return null;
}
final UUID uuid = DataTypeUtils.toUUID(data);
ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]);
byteBuffer.putLong(uuid.getMostSignificantBits());
Expand All @@ -147,6 +152,9 @@ static class FixedConverter extends DataConverter<Byte[], byte[]> {

@Override
public byte[] convert(Byte[] data) {
if (data == null) {
return null;
}
Validate.isTrue(data.length == length, String.format("Cannot write byte array of length %s as fixed[%s]", data.length, length));
return ArrayUtils.toPrimitive(data);
}
Expand All @@ -156,6 +164,9 @@ static class BinaryConverter extends DataConverter<Byte[], ByteBuffer> {

@Override
public ByteBuffer convert(Byte[] data) {
if (data == null) {
return null;
}
return ByteBuffer.wrap(ArrayUtils.toPrimitive(data));
}
}
Expand All @@ -171,6 +182,9 @@ static class BigDecimalConverter extends DataConverter<Object, BigDecimal> {

@Override
public BigDecimal convert(Object data) {
if (data == null) {
return null;
}
if (data instanceof BigDecimal) {
BigDecimal bigDecimal = (BigDecimal) data;
Validate.isTrue(bigDecimal.scale() == scale, "Cannot write value as decimal(%s,%s), wrong scale %s for value: %s", precision, scale, bigDecimal.scale(), data);
Expand All @@ -194,6 +208,9 @@ static class ArrayConverter<S, T> extends DataConverter<S[], List<T>> {
@Override
@SuppressWarnings("unchecked")
public List<T> convert(S[] data) {
if (data == null) {
return null;
}
final int numElements = data.length;
final List<T> result = new ArrayList<>(numElements);
for (int i = 0; i < numElements; i += 1) {
Expand All @@ -219,6 +236,9 @@ static class MapConverter<SK, SV, TK, TV> extends DataConverter<Map<SK, SV>, Map
@Override
@SuppressWarnings("unchecked")
public Map<TK, TV> convert(Map<SK, SV> data) {
if (data == null) {
return null;
}
final int mapSize = data.size();
final Object[] keyArray = data.keySet().toArray();
final Object[] valueArray = data.values().toArray();
Expand Down Expand Up @@ -253,6 +273,9 @@ static class RecordConverter extends DataConverter<Record, GenericRecord> {

@Override
public GenericRecord convert(Record data) {
if (data == null) {
return null;
}
final GenericRecord record = GenericRecord.create(schema);

for (DataConverter<?, ?> converter : converters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.condition.OS.WINDOWS;
Expand Down Expand Up @@ -471,6 +472,33 @@ public void testPrimitives(FileFormat format) throws IOException {
} else {
assertEquals(UUID.fromString("0000-00-00-00-000000"), resultRecord.get(13, UUID.class));
}

// Test null values
for (String fieldName : record.getRawFieldNames()) {
record.setValue(fieldName, null);
}

genericRecord = recordConverter.convert(record);

writeTo(format, PRIMITIVES_SCHEMA, genericRecord, tempFile);

results = readFrom(format, PRIMITIVES_SCHEMA, tempFile.toInputFile());

assertEquals(results.size(), 1);
resultRecord = results.get(0);
assertNull(resultRecord.get(0, String.class));
assertNull(resultRecord.get(1, Integer.class));
assertNull(resultRecord.get(2, Float.class));
assertNull(resultRecord.get(3, Long.class));
assertNull(resultRecord.get(4, Double.class));
assertNull(resultRecord.get(5, BigDecimal.class));
assertNull(resultRecord.get(6, Boolean.class));
assertNull(resultRecord.get(7));
assertNull(resultRecord.get(8));
assertNull(resultRecord.get(9, LocalDate.class));
assertNull(resultRecord.get(10, LocalTime.class));
assertNull(resultRecord.get(11, OffsetDateTime.class));
assertNull(resultRecord.get(14, Integer.class));
}

@DisabledOnOs(WINDOWS)
Expand Down

0 comments on commit 43a4834

Please sign in to comment.