Skip to content

Commit aa71ee8

Browse files
committed
Add derivation of Schema for union types (closes ghostdogpr#1926)
1 parent f5b5e29 commit aa71ee8

File tree

4 files changed

+119
-1
lines changed

4 files changed

+119
-1
lines changed

build.sbt

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Global / onChangedBuildSource := ReloadOnSourceChanges
3939

4040
inThisBuild(
4141
List(
42-
scalaVersion := scala213,
42+
scalaVersion := scala3,
4343
crossScalaVersions := allScala,
4444
organization := "com.github.ghostdogpr",
4545
homepage := Some(url("https://github.com/ghostdogpr/caliban")),

core/src/main/scala-3/caliban/schema/SchemaDerivation.scala

+2
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ trait SchemaDerivation[R] extends CommonSchemaDerivation {
123123

124124
inline def genDebug[R, A]: Schema[R, A] = PrintDerived(derived[R, A])
125125

126+
inline def unionType[T]: Schema[R, T] = ${ TypeUnionDerivation.typeUnionSchema[R, T] }
127+
126128
final lazy val auto = new AutoSchemaDerivation[Any] {}
127129

128130
final class SemiAuto[A](impl: Schema[R, A]) extends Schema[R, A] {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package caliban.schema
2+
3+
import caliban.introspection.adt.__Type
4+
5+
import scala.quoted.*
6+
7+
object TypeUnionDerivation {
8+
inline def derived[R, T]: Schema[R, T] = ${ typeUnionSchema[R, T] }
9+
10+
def typeUnionSchema[R: Type, T: Type](using quotes: Quotes): Expr[Schema[R, T]] = {
11+
import quotes.reflect.*
12+
13+
class TypeAndSchema[A](val typeRef: String, val schema: Expr[Schema[R, A]], val tpe: Type[A])
14+
15+
def rec[A](using tpe: Type[A]): List[TypeAndSchema[?]] =
16+
TypeRepr.of(using tpe).dealias match {
17+
case OrType(l, r) =>
18+
rec(using l.asType.asInstanceOf[Type[Any]]) ++ rec(using r.asType.asInstanceOf[Type[Any]])
19+
case otherRepr =>
20+
val otherString: String = otherRepr.show
21+
val expr: TypeAndSchema[A] =
22+
Expr.summon[Schema[R, A]] match {
23+
case Some(foundSchema) =>
24+
TypeAndSchema[A](otherString, foundSchema, otherRepr.asType.asInstanceOf[Type[A]])
25+
case None =>
26+
quotes.reflect.report.errorAndAbort(s"Couldn't resolve Schema[Any, $otherString]")
27+
}
28+
29+
List(expr)
30+
}
31+
32+
val typeAndSchemas: List[TypeAndSchema[?]] = rec[T]
33+
34+
val schemaByTypeNameList: Expr[List[(String, Schema[R, Any])]] = Expr.ofList(
35+
typeAndSchemas.map { case (tas: TypeAndSchema[a]) =>
36+
given Type[a] = tas.tpe
37+
'{ (${ Expr(tas.typeRef) }, ${ tas.schema }.asInstanceOf[Schema[R, Any]]) }
38+
}
39+
)
40+
val name = TypeRepr.of[T].show
41+
42+
if (name.contains("|")) {
43+
report.error(s"You must explicitly add type parameter to derive Schema for a union type in order to capture the name of the type alias")
44+
}
45+
46+
val ret = '{
47+
val schemaByName: Map[String, Schema[R, Any]] = ${ schemaByTypeNameList }.toMap
48+
new Schema[R, T] {
49+
50+
def resolve(value: T): Step[R] = {
51+
var ret: Step[R] = null
52+
${
53+
Expr.block(
54+
typeAndSchemas.map { case (tas: TypeAndSchema[a]) =>
55+
given Type[a] = tas.tpe
56+
'{ if value.isInstanceOf[a] then ret = schemaByName(${ Expr(tas.typeRef) }).resolve(value) }
57+
},
58+
'{ require(ret != null, s"no schema for ${value}") }
59+
)
60+
}
61+
ret
62+
}
63+
64+
def toType(isInput: Boolean, isSubscription: Boolean): __Type =
65+
Types.makeUnion(Some(${ Expr(name) }), None, schemaByName.values.map(_.toType_(isInput, isSubscription)).toList)
66+
}
67+
}
68+
// quotes.reflect.report.warning(ret.show)
69+
ret
70+
}
71+
}

core/src/test/scala-3/caliban/schema/Scala3DerivesSpec.scala

+45
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,51 @@ object Scala3DerivesSpec extends ZIOSpecDefault {
273273
data1 == """{"enum2String":"ENUM1"}""",
274274
data2 == """{"enum2String":"ENUM2"}"""
275275
)
276+
},
277+
test("union type") {
278+
final case class Foo(value: String) derives Schema.SemiAuto
279+
final case class Bar(foo: Int) derives Schema.SemiAuto
280+
type Payload = Foo | Bar
281+
282+
given Schema[Any, Payload] = Schema.unionType[Payload]
283+
284+
final case class QueryInput(isFoo: Boolean) derives ArgBuilder, Schema.SemiAuto
285+
final case class Query(testQuery: QueryInput => zio.UIO[Payload]) derives Schema.SemiAuto
286+
287+
val gql = graphQL(RootResolver(Query(i => ZIO.succeed(if (i.isFoo) Foo("foo") else Bar(1)))))
288+
289+
assertTrue(
290+
gql.render ==
291+
"""schema {
292+
| query: Query
293+
|}
294+
|
295+
|union Payload = Foo | Bar
296+
|
297+
|type Bar {
298+
| foo: Int!
299+
|}
300+
|
301+
|type Foo {
302+
| value: String!
303+
|}
304+
|
305+
|type Query {
306+
| testQuery(isFoo: Boolean!): Payload!
307+
|}""".stripMargin
308+
)
309+
310+
val interpreter = gql.interpreterUnsafe
311+
312+
for {
313+
res1 <- interpreter.execute("{ testQuery(isFoo: true){ ... on Foo { value } } }")
314+
res2 <- interpreter.execute("{ testQuery(isFoo: false){ ... on Bar { foo } } }")
315+
data1 = res1.data.toString
316+
data2 = res2.data.toString
317+
} yield assertTrue(
318+
data1 == """{"testQuery":{"value":"foo"}}""",
319+
data2 == """{"testQuery":{"foo":1}}"""
320+
)
276321
}
277322
)
278323
}

0 commit comments

Comments
 (0)