Skip to content

Commit e5758ad

Browse files
committed
refactor scala object deserialization (#657)
* refactor scala object deserialization * Update build.sbt * Update Classes.scala * refactor beanintrospector * Update build.sbt * Update CaseObjectDeserializerTest.scala Update ScalaObjectDeserializerModule.scala
1 parent 38562f5 commit e5758ad

File tree

4 files changed

+17
-23
lines changed

4 files changed

+17
-23
lines changed

src/main/scala/tools/jackson/module/scala/deser/ScalaObjectDeserializerModule.scala

+5-13
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,14 @@ import tools.jackson.module.scala.util.ClassW
1212
import scala.languageFeature.postfixOps
1313
import scala.util.control.NonFatal
1414

15-
private class ScalaObjectDeserializer(clazz: Class[_]) extends StdDeserializer[Any](classOf[Any]) {
16-
override def deserialize(p: JsonParser, ctxt: DeserializationContext): Any = {
17-
try {
18-
clazz.getField("MODULE$").get(null)
19-
} catch {
20-
case NonFatal(_) => null
21-
}
22-
}
15+
private class ScalaObjectDeserializer(value: Any) extends StdDeserializer[Any](classOf[Any]) {
16+
override def deserialize(p: JsonParser, ctxt: DeserializationContext): Any = value
2317
}
2418

2519
private class ScalaObjectDeserializerResolver(config: ScalaModule.Config) extends Deserializers.Base {
26-
override def findBeanDeserializer(javaType: JavaType, deserializationConfig: DeserializationConfig, beanDesc: BeanDescription): ValueDeserializer[_] = {
27-
val clazz = javaType.getRawClass
28-
if (hasDeserializerFor(deserializationConfig, clazz))
29-
new ScalaObjectDeserializer(clazz)
30-
else null
20+
override def findBeanDeserializer(javaType: JavaType, deserializationConfig: DeserializationConfig, beanDesc: BeanDescription): ValueDeserializer[_] = { ClassW(javaType.getRawClass).getModuleField.flatMap { field =>
21+
Option(field.get(null))
22+
}.map(new ScalaObjectDeserializer(_)).orNull
3123
}
3224

3325
override def hasDeserializerFor(deserializationConfig: DeserializationConfig, valueType: Class[_]): Boolean = {

src/main/scala/tools/jackson/module/scala/introspect/BeanIntrospector.scala

+2-3
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,11 @@ object BeanIntrospector {
184184
//create properties for all appropriate fields
185185
val fields = for {
186186
cls <- hierarchy
187-
scalaCaseObject = isScalaCaseObject(cls)
188-
isScalaObject = ClassW(cls).isScalaObject
189187
field <- cls.getDeclaredFields
188+
isScalaObject = ClassW(cls).isScalaObject || isScalaCaseObject(cls)
190189
name = maybePrivateName(field)
191190
if !name.contains('$')
192-
if (isScalaObject || scalaCaseObject || isAcceptableField(field))
191+
if isScalaObject || isAcceptableField(field)
193192
beanGetter = findBeanGetter(cls, name)
194193
beanSetter = findBeanSetter(cls, name)
195194
} yield PropertyDescriptor(nameOfField(field, name), findConstructorParam(hierarchy.head, name), Some(field), findGetter(cls, name), findSetter(cls, name), beanGetter, beanSetter)

src/main/scala/tools/jackson/module/scala/util/Classes.scala

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package tools.jackson.module.scala.util
22

3+
import java.lang.reflect.Field
34
import scala.annotation.tailrec
45
import scala.language.implicitConversions
56
import scala.reflect.{ScalaLongSignature, ScalaSignature}
@@ -25,9 +26,11 @@ trait ClassW extends PimpedType[Class[_]] {
2526
hasSigHelper(value)
2627
}
2728

28-
def isScalaObject: Boolean = {
29-
Try(value.getField("MODULE$")).isSuccess
30-
}
29+
def isScalaObject: Boolean = moduleField.isSuccess
30+
31+
def getModuleField: Option[Field] = moduleField.toOption
32+
33+
private lazy val moduleField: Try[Field] = Try(value.getField("MODULE$"))
3134
}
3235

3336
object ClassW {

src/test/scala/tools/jackson/module/scala/deser/CaseObjectDeserializerTest.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ class CaseObjectDeserializerTest extends DeserializerTest {
2525
val original = TestObject
2626
val json = mapper.writeValueAsString(original)
2727
val deserialized = mapper.readValue(json, TestObject.getClass)
28-
assert(deserialized == original)
28+
assert(deserialized === original)
2929
}
3030

3131
it should "deserialize Foo and not create a new instance" in {
3232
val mapper = newMapper
3333
val original = Foo
3434
val json = mapper.writeValueAsString(original)
3535
val deserialized = mapper.readValue(json, Foo.getClass)
36-
assert(deserialized == original)
36+
assert(deserialized === original)
3737
}
3838

3939
it should "deserialize Foo and not create a new instance (visibility settings)" in {
@@ -47,15 +47,15 @@ class CaseObjectDeserializerTest extends DeserializerTest {
4747
val original = Foo
4848
val json = mapper.writeValueAsString(original)
4949
val deserialized = mapper.readValue(json, Foo.getClass)
50-
assert(deserialized == original)
50+
assert(deserialized === original)
5151
}
5252

5353
"An ObjectMapper with ClassTagExtensions" should "deserialize a case object and not create a new instance" in {
5454
val mapper = newMapper :: ClassTagExtensions
5555
val original = TestObject
5656
val json = mapper.writeValueAsString(original)
5757
val deserialized = mapper.readValue[TestObject.type](json)
58-
assert(deserialized == original)
58+
assert(deserialized === original)
5959
}
6060

6161
"An ObjectMapper without ScalaObjectDeserializerModule" should "deserialize a case object but create a new instance" in {

0 commit comments

Comments
 (0)