Skip to content
This repository was archived by the owner on Nov 20, 2019. It is now read-only.

Commit 802fe85

Browse files
authored
RowSerializer: Safe guard against double types with Decimal schema tags (#799)
1 parent 8fcd5d6 commit 802fe85

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

Diff for: common/src/main/scala/com/stratio/crossdata/common/serializers/RowSerializer.scala

+2
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ case class RowSerializer(providedSchema: StructType) extends Serializer[Row] {
9595
case (DoubleType, v: Double) => JDouble(v)
9696
case (LongType, v: Long) => JInt(v)
9797
case (_: DecimalType, v: Decimal) => JDecimal(v.toBigDecimal)
98+
case (_: DecimalType, v: Double) => JDecimal(BigDecimal(v))
99+
case (_: DecimalType, v: Float) => JDecimal(BigDecimal(v))
98100
case (ByteType, v: Byte) => JInt(v.toInt)
99101
case (BinaryType, v: Array[Byte]) => JString(new String(v))
100102
case (BooleanType, v: Boolean) => JBool(v)

Diff for: common/src/test/scala/com/stratio/crossdata/common/serializers/RowSerializerSpec.scala

+20-1
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,16 @@ import org.apache.spark.sql.Row
2020
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
2121
import org.apache.spark.sql.types._
2222
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
23+
import org.json4s.Extraction
24+
import org.json4s.jackson.JsonMethods.{compact, parse, render}
2325
import org.junit.runner.RunWith
26+
import org.scalatest.Inside
2427
import org.scalatest.junit.JUnitRunner
2528

2629
import scala.collection.mutable.WrappedArray
2730

2831
@RunWith(classOf[JUnitRunner])
29-
class RowSerializerSpec extends XDSerializationTest[Row] with CrossdataCommonSerializer {
32+
class RowSerializerSpec extends XDSerializationTest[Row] with CrossdataCommonSerializer with Inside {
3033

3134
lazy val schema = StructType(List(
3235
StructField("int",IntegerType,true),
@@ -112,4 +115,20 @@ class RowSerializerSpec extends XDSerializationTest[Row] with CrossdataCommonSer
112115
TestCase("marshall & unmarshall a row with schema", rowWithSchema)
113116
)
114117

118+
it should " be able to recover Double values when their schema type is misleading" in {
119+
120+
val schema = StructType(List(StructField("decimaldouble", DecimalType(10,1),true)))
121+
val row = Row.fromSeq(Array(32.1))
122+
123+
val formats = json4sJacksonFormats + new RowSerializer(schema)
124+
125+
val serialized = compact(render(Extraction.decompose(row)(formats)))
126+
val extracted = parse(serialized, false).extract[Row](formats, implicitly[Manifest[Row]])
127+
128+
inside(extracted) {
129+
case r: Row => r.get(0) shouldBe Decimal(32.1)
130+
}
131+
132+
}
133+
115134
}

0 commit comments

Comments
 (0)