Skip to content
This repository has been archived by the owner. It is now read-only.

Commit 23c9332

Browse files
Chickerfomkin
authored andcommitted
fix processing fields of type like Option[Seq[A]]
1 parent 569a57d commit 23c9332

File tree

3 files changed

+145
-63
lines changed

3 files changed

+145
-63
lines changed

core/src/main/scala/zhukov/SizeMeter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,6 @@ object SizeMeter {
4141
CodedOutputStream.computeRawVarint32Size(len) + len
4242
})
4343

44-
implicit def iterable[A, Col[_] <: Iterable[A]](implicit sm: SizeMeter[A]): SizeMeter[Col[A]] =
44+
implicit def iterable[A, Col[A] <: Iterable[A]](implicit sm: SizeMeter[A]): SizeMeter[Col[A]] =
4545
SizeMeter(xs => sm.measureValues(xs.toIterable))
4646
}

derivation/src/main/scala/zhukov/derivation/ZhukovDerivationMacro.scala

Lines changed: 98 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class ZhukovDerivationMacro(val c: blackbox.Context) {
1616
private val marshallerCache = TrieMap.empty[Type, Tree]
1717
private val unmarshallerCache = TrieMap.empty[Type, Tree]
1818
private val sizeMeterCache = TrieMap.empty[Type, Tree]
19+
private val iterableType = typeOf[Iterable[_]]
1920

2021
def unmarshallerImpl[T: WeakTypeTag]: Tree = {
2122
val T = weakTypeTag[T].tpe
@@ -210,13 +211,12 @@ class ZhukovDerivationMacro(val c: blackbox.Context) {
210211
case _ => EmptyTree
211212
}
212213

213-
private def commonMarshaller(T: Type, x: Field): Tree = x match {
214-
case Field(i, nameOpt, _, _, _, Some(tpe), _, false) =>
215-
val name = nameOpt.fold[Tree](Ident(TermName("_v")))(s => q"_value.$s")
214+
private def commonMarshaller(T: Type, x: Field): Tree = {
215+
def writer0(tpe: Type, indexField: Int, name: Tree) = {
216216
inferMarshallerWireType(tpe) match {
217217
case VarInt | Fixed32 | Fixed64 => // Packed
218218
q"""
219-
_stream.writeTag($i, ${WireFormat.WIRETYPE_LENGTH_DELIMITED})
219+
_stream.writeTag($indexField, ${WireFormat.WIRETYPE_LENGTH_DELIMITED})
220220
_stream.writeRawVarint32(implicitly[zhukov.SizeMeter[$tpe]].measureValues($name))
221221
val _i = $name.iterator
222222
while (_i.hasNext)
@@ -226,7 +226,7 @@ class ZhukovDerivationMacro(val c: blackbox.Context) {
226226
q"""
227227
val _i = $name.iterator
228228
while (_i.hasNext) {
229-
_stream.writeTag($i, ${WireFormat.WIRETYPE_LENGTH_DELIMITED})
229+
_stream.writeTag($indexField, ${WireFormat.WIRETYPE_LENGTH_DELIMITED})
230230
implicitly[zhukov.Marshaller[$tpe]].write(_stream, _i.next())
231231
}
232232
"""
@@ -235,102 +235,134 @@ class ZhukovDerivationMacro(val c: blackbox.Context) {
235235
val _i = $name.iterator
236236
while (_i.hasNext) {
237237
val _v = _i.next()
238-
_stream.writeTag($i, ${WireFormat.WIRETYPE_LENGTH_DELIMITED})
238+
_stream.writeTag($indexField, ${WireFormat.WIRETYPE_LENGTH_DELIMITED})
239239
_stream.writeRawVarint32(implicitly[zhukov.SizeMeter[$tpe]].measure(_v))
240240
implicitly[zhukov.Marshaller[$tpe]].write(_stream, _v)
241241
}
242242
"""
243243
}
244-
case Field(i, nameOpt, _, _, tpe, maybeRepTpe, _, isOption) =>
245-
def writer(tpe: Type, name: Tree) = inferMarshallerWireType(tpe) match {
246-
case LengthDelimited =>
247-
q"""
248-
_stream.writeTag($i, ${WireFormat.WIRETYPE_LENGTH_DELIMITED})
244+
}
245+
def writer(tpe: Type, indexField: Int, name: Tree) = inferMarshallerWireType(tpe) match {
246+
case LengthDelimited =>
247+
q"""
248+
_stream.writeTag($indexField, ${WireFormat.WIRETYPE_LENGTH_DELIMITED})
249249
_stream.writeRawVarint32(implicitly[zhukov.SizeMeter[$tpe]].measure($name))
250250
implicitly[zhukov.Marshaller[$tpe]].write(_stream, $name)
251251
"""
252-
case wireType =>
253-
q"""
254-
_stream.writeTag($i, ${wireType.value})
252+
case wireType =>
253+
q"""
254+
_stream.writeTag($indexField, ${wireType.value})
255255
implicitly[zhukov.Marshaller[$tpe]].write(_stream, $name)
256256
"""
257-
}
258-
val name = nameOpt.fold[Tree](Ident(TermName("_v")))(s => q"_value.$s")
259-
if (isOption) {
257+
}
258+
259+
x match {
260+
case Field(i, nameOpt, _, _, tpe, Some(repTpe), _, true) if repTpe <:< iterableType =>
261+
val concreteType = repTpe.typeArgs.head
262+
val name = nameOpt.fold[Tree](Ident(TermName("_v")))(s => q"_value.$s")
263+
260264
q"""
261265
$name match {
262-
case Some(__extracted) => ${writer(maybeRepTpe.get, Ident(TermName("__extracted")))}
266+
case Some(__extracted) => ${writer0(concreteType, i, Ident(TermName("__extracted")))}
263267
case None => ()
264268
}
265269
"""
266-
} else writer(tpe, name)
267-
case _ => EmptyTree
270+
case Field(i, nameOpt, _, _, _, Some(repTpe), _, false) =>
271+
val name = nameOpt.fold[Tree](Ident(TermName("_v")))(s => q"_value.$s")
272+
writer0(repTpe, i, name)
273+
case Field(i, nameOpt, _, _, tpe, maybeRepTpe, _, isOption) =>
274+
val name = nameOpt.fold[Tree](Ident(TermName("_v")))(s => q"_value.$s")
275+
if (isOption && maybeRepTpe.isDefined) {
276+
q"""
277+
$name match {
278+
case Some(__extracted) => ${writer(maybeRepTpe.get, i, Ident(TermName("__extracted")))}
279+
case None => ()
280+
}
281+
"""
282+
} else writer(tpe, i, name)
283+
case _ => EmptyTree
284+
}
268285
}
269286

270287
private val mapSymbol = c.typecheck(tq"Map[_, _]")
271288
.tpe
272289
.typeSymbol
273290

274291
private def commonUnmarshaller(T: Type, fields: List[Field]): Tree = {
292+
def writer(tpe: Type, indexField: Int, name: TermName): List[Tree] = {
293+
val wireType = inferUnmarshallerWireType(tpe)
294+
val tag = WireFormat.makeTag(indexField, wireType.value)
295+
val singleRead = q"implicitly[zhukov.Unmarshaller[$tpe]].read(_stream)"
296+
wireType match {
297+
case VarInt | Fixed32 | Fixed64 => // Packed
298+
val repTag = WireFormat.makeTag(indexField, WireFormat.WIRETYPE_LENGTH_DELIMITED)
299+
val `case` =
300+
cq"""
301+
$repTag =>
302+
val _length = _stream.readRawVarint32()
303+
val _oldLimit = _stream.pushLimit(_length)
304+
while (_stream.getBytesUntilLimit > 0)
305+
$name += $singleRead
306+
_stream.popLimit(_oldLimit)
307+
"""
308+
List(`case`, cq"$tag => $name += $singleRead")
309+
case Coded =>
310+
List(cq"$tag => $name += $singleRead")
311+
case LengthDelimited =>
312+
List(
313+
cq"""$tag =>
314+
val _length = _stream.readRawVarint32()
315+
val _oldLimit = _stream.pushLimit(_length)
316+
$name += $singleRead
317+
_stream.checkLastTagWas(0)
318+
_stream.popLimit(_oldLimit)
319+
""")
320+
}
321+
}
275322
val vars = fields.groupBy(_.varName).mapValues(_.head).collect {
276323
case (name, Field(_, _, _, Some(default), repTpe, Some(tpe), None, false)) =>
277324
q"var $name = ${repTpe.typeSymbol.companion}.newBuilder[..${tpe.typeArgs}] ++= $default"
278-
case (name, Field(_, _, _, Some(default), repTpe, Some(_), None, true)) =>
279-
q"var $name:$repTpe = $default"
325+
case (name, Field(_, _, _, Some(_), _, Some(repTpe), None, true)) if (repTpe <:< iterableType) =>
326+
q"var $name = ${repTpe.typeSymbol.companion}.newBuilder[..${repTpe.typeArgs}]"
327+
case (name, Field(_, _, _, Some(default), tpe, Some(_), None, true)) =>
328+
q"var $name:$tpe = $default"
280329
case (name, Field(_, _, _, Some(default), _, None, None, _)) =>
281330
q"var $name = $default"
282331
case (name, Field(_, _, _, None, _, None, Some(parent), _)) =>
283332
q"var $name:$parent = null"
284333
}
285334
val cases = fields.flatMap { x =>
286335
val tpe = x.repTpe.getOrElse(x.tpe)
287-
val wireType = inferUnmarshallerWireType(tpe)
288-
val tag = WireFormat.makeTag(x.index, wireType.value)
289-
val singleRead = q"implicitly[zhukov.Unmarshaller[$tpe]].read(_stream)"
290-
if (x.repTpe.isEmpty || x.isOption) {
291-
val read =
292-
if (x.isOption) q"Some($singleRead)"
293-
else singleRead
294-
wireType match {
295-
case LengthDelimited =>
296-
List(
297-
cq"""
336+
337+
if (tpe <:< iterableType) {
338+
val concreteType = tpe.typeArgs.head
339+
writer(concreteType, x.index, x.varName)
340+
} else {
341+
342+
val wireType = inferUnmarshallerWireType(tpe)
343+
val tag = WireFormat.makeTag(x.index, wireType.value)
344+
val singleRead = q"implicitly[zhukov.Unmarshaller[$tpe]].read(_stream)"
345+
if (x.repTpe.isEmpty || x.isOption) {
346+
val read =
347+
if (x.isOption) q"Some($singleRead)"
348+
else singleRead
349+
wireType match {
350+
case LengthDelimited =>
351+
List(
352+
cq"""
298353
$tag =>
299354
val _length = _stream.readRawVarint32()
300355
val _oldLimit = _stream.pushLimit(_length)
301356
${x.varName} = $read
302357
_stream.checkLastTagWas(0)
303358
_stream.popLimit(_oldLimit)
304359
"""
305-
)
306-
case _ =>
307-
List(cq"$tag => ${x.varName} = $read")
308-
}
309-
} else {
310-
wireType match {
311-
case VarInt | Fixed32 | Fixed64 => // Packed
312-
val repTag = WireFormat.makeTag(x.index, WireFormat.WIRETYPE_LENGTH_DELIMITED)
313-
val `case` =
314-
cq"""
315-
$repTag =>
316-
val _length = _stream.readRawVarint32()
317-
val _oldLimit = _stream.pushLimit(_length)
318-
while (_stream.getBytesUntilLimit > 0)
319-
${x.varName} += $singleRead
320-
_stream.popLimit(_oldLimit)
321-
"""
322-
List(`case`, cq"$tag => ${x.varName} += $singleRead")
323-
case Coded =>
324-
List(cq"$tag => ${x.varName} += $singleRead")
325-
case LengthDelimited =>
326-
List(
327-
cq"""$tag =>
328-
val _length = _stream.readRawVarint32()
329-
val _oldLimit = _stream.pushLimit(_length)
330-
${x.varName} += $singleRead
331-
_stream.checkLastTagWas(0)
332-
_stream.popLimit(_oldLimit)
333-
""")
360+
)
361+
case _ =>
362+
List(cq"$tag => ${x.varName} = $read")
363+
}
364+
} else {
365+
writer(tpe, x.index, x.varName)
334366
}
335367
}
336368
}
@@ -400,6 +432,10 @@ class ZhukovDerivationMacro(val c: blackbox.Context) {
400432
q"$originalName = $varName.result()"
401433
case Field(_, Some(originalName), varName, _, _, None, _, false) =>
402434
q"$originalName = $varName"
435+
case Field(_, Some(originalName), varName, _, _, Some(repTpe), _, true) if (repTpe <:< iterableType) =>
436+
q"$originalName = if ($varName.result().isEmpty) None else Some($varName.result())"
437+
case Field(_, Some(originalName), varName, _, _, Some(_), _, true) =>
438+
q"$originalName = $varName"
403439
case Field(_, Some(originalName), varName, _, _, _, _, true) =>
404440
q"$originalName = $varName"
405441
}

derivation/src/test/scala/CompareWithScalapbTest.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,27 @@ object CompareWithScalapbTest extends SimpleTestSuite {
7272
implicit val u = unmarshaller[WrapperForMessageWithVarIntOption1]
7373
}
7474

75+
case class MessageWithLengthDelimOption(maybeSeq: Option[Seq[Int]] = None)
76+
object MessageWithLengthDelimOption {
77+
implicit val m = marshaller[MessageWithLengthDelimOption]
78+
implicit val u = unmarshaller[MessageWithLengthDelimOption]
79+
implicit val s = sizeMeter[MessageWithLengthDelimOption]
80+
implicit val d = Default[MessageWithLengthDelimOption](MessageWithLengthDelimOption(Some(Seq.empty)))
81+
}
82+
83+
case class WrapperForMessageWithLengthDelimOption(m: MessageWithLengthDelimOption)
84+
85+
object WrapperForMessageWithLengthDelimOption {
86+
implicit val m = marshaller[WrapperForMessageWithLengthDelimOption]
87+
implicit val u = unmarshaller[WrapperForMessageWithLengthDelimOption]
88+
}
89+
90+
case class CompositeMessage(msg: Option[Seq[MessageWithLengthDelimOption]])
91+
object CompositeMessage {
92+
implicit val m = marshaller[CompositeMessage]
93+
implicit val u = unmarshaller[CompositeMessage]
94+
}
95+
7596
sealed trait Expr2
7697

7798
object Expr2 {
@@ -313,4 +334,29 @@ object CompareWithScalapbTest extends SimpleTestSuite {
313334
val res = Unmarshaller[WrapperForMessageWithVarIntOption1].read(bytes)
314335
assert(message == res)
315336
}
337+
338+
test("Messages which contain some value as the option of sequence of elements") {
339+
val m1 = MessageWithLengthDelimOption(Some(Seq(1,2,3)))
340+
val message = WrapperForMessageWithLengthDelimOption(m1)
341+
val bytes = Marshaller[WrapperForMessageWithLengthDelimOption].write(message)
342+
val res = Unmarshaller[WrapperForMessageWithLengthDelimOption].read(bytes)
343+
assert(message == res)
344+
}
345+
346+
test("Messages which contain None value as the option of sequence of elements") {
347+
val m1 = MessageWithLengthDelimOption(None)
348+
val message = WrapperForMessageWithLengthDelimOption(m1)
349+
val bytes = Marshaller[WrapperForMessageWithLengthDelimOption].write(message)
350+
val res = Unmarshaller[WrapperForMessageWithLengthDelimOption].read(bytes)
351+
assert(message == res)
352+
}
353+
354+
test("Composite message") {
355+
val m1 = MessageWithLengthDelimOption(None)
356+
val m2 = MessageWithLengthDelimOption(Some(Seq(1,2,3)))
357+
val message = CompositeMessage(Some(Seq(m1, m2)))
358+
val bytes = Marshaller[CompositeMessage].write(message)
359+
val res = Unmarshaller[CompositeMessage].read(bytes)
360+
assert(message == res)
361+
}
316362
}

0 commit comments

Comments
 (0)