Skip to content

Commit ea4856c

Browse files
committed
Rewrite tests for StreamObserverSessionClient to use CompletableFutures
* Request/response executors introduce more realistic async behaviour * Use of completable futures verifies the entire request lifecycle until onComplete/onError * Allows testing handshakes with more than one roundtrip (with attestation) even though we are not doing this yet. Change-Id: If14d5f7fd8397c30332f9068c124d656c38a93f8
1 parent c76bfce commit ea4856c

1 file changed

Lines changed: 98 additions & 60 deletions

File tree

java/src/test/java/com/google/oak/client/grpc/StreamObserverSessionClientTest.kt

Lines changed: 98 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,17 @@ import com.google.oak.session.v1.SessionResponse
2727
import com.google.protobuf.ByteString
2828
import io.grpc.stub.StreamObserver
2929
import java.nio.charset.StandardCharsets
30+
import java.util.concurrent.CompletableFuture
3031
import java.util.concurrent.ExecutorService
3132
import java.util.concurrent.Executors
33+
import java.util.concurrent.TimeUnit
3234
import javax.inject.Provider
3335
import kotlin.jvm.optionals.getOrNull
3436
import org.junit.Assert.assertEquals
3537
import org.junit.Test
3638
import org.junit.runner.RunWith
3739
import org.junit.runners.JUnit4
3840
import org.mockito.kotlin.any
39-
import org.mockito.kotlin.argumentCaptor
4041
import org.mockito.kotlin.mock
4142
import org.mockito.kotlin.never
4243
import org.mockito.kotlin.times
@@ -50,73 +51,103 @@ class StreamObserverSessionClientTest {
5051
OakClientSession.loadNativeLib()
5152
}
5253

53-
/** Executor to run service operations on a background thread. */
54-
private val executor: ExecutorService = Executors.newSingleThreadExecutor()
55-
5654
@Test
5755
fun client_startedSession_handshakesWithServer() {
5856
val client = StreamObserverSessionClient(unattestedConfigProvider())
5957
val serverConfig = unattestedConfig()
60-
val fakeService = FakeService(executor, serverConfig) { it }
58+
val fakeService = FakeService(serverConfig) { it }
6159

62-
val responseObserver = mock<OakSessionStreamObserver>()
60+
val done = CompletableFuture<Void>()
61+
val responseObserver =
62+
object : OakSessionStreamObserver {
63+
override fun onSessionOpen(clientRequests: StreamObserver<ByteString>) {
64+
done.complete(null)
65+
}
66+
67+
override fun onNext(response: ByteString) {}
68+
69+
override fun onError(t: Throwable) {}
70+
71+
override fun onCompleted() {}
72+
}
6373

6474
client.startSession(responseObserver) { fakeService.start(it) }
6575

66-
fakeService.await()
67-
verify(responseObserver).onSessionOpen(any())
76+
done.get(10, TimeUnit.SECONDS)
6877
}
6978

7079
@Test
7180
fun client_startedSession_getsServerAppResponse() {
7281
val client = StreamObserverSessionClient(unattestedConfigProvider())
7382
val serverConfig = unattestedConfig()
74-
val fakeService =
75-
FakeService(executor, serverConfig) { "PONG: ${it.toStringUtf8()}".toByteString() }
76-
77-
val responseObserver = mock<OakSessionStreamObserver>()
78-
val clientStreamCaptor = argumentCaptor<StreamObserver<ByteString>>()
79-
val responseCaptor = argumentCaptor<ByteString>()
83+
val fakeService = FakeService(serverConfig) { "PONG: ${it.toStringUtf8()}".toByteString() }
84+
85+
val responses = mutableListOf<ByteString>()
86+
val done = CompletableFuture<Void>()
87+
val responseObserver =
88+
object : OakSessionStreamObserver {
89+
lateinit var response: ByteString
90+
lateinit var clientRequests: StreamObserver<ByteString>
91+
92+
override fun onSessionOpen(clientRequests: StreamObserver<ByteString>) {
93+
this.clientRequests = clientRequests
94+
clientRequests.onNext("Hello World".toByteString())
95+
clientRequests.onCompleted()
96+
}
8097

81-
client.startSession(responseObserver) { fakeService.start(it) }
82-
fakeService.await()
98+
override fun onNext(response: ByteString) {
99+
responses.add(response)
100+
}
83101

84-
verify(responseObserver).onSessionOpen(clientStreamCaptor.capture())
85-
val clientStream = clientStreamCaptor.lastValue
102+
override fun onError(t: Throwable) {}
86103

87-
clientStream.onNext("Hello World".toByteString())
88-
fakeService.await()
104+
override fun onCompleted() {
105+
done.complete(null)
106+
}
107+
}
89108

90-
verify(responseObserver).onNext(responseCaptor.capture())
109+
client.startSession(responseObserver) { fakeService.start(it) }
110+
done.get(10, TimeUnit.SECONDS)
91111

92-
assertEquals(responseCaptor.lastValue.toStringUtf8(), "PONG: Hello World")
112+
assertEquals(responses.map { it.toStringUtf8() }, listOf("PONG: Hello World"))
93113
}
94114

95115
@Test
96116
fun client_startedSession_getsServerMultipleAppResponses() {
97117
val client = StreamObserverSessionClient(unattestedConfigProvider())
98118
val serverConfig = unattestedConfig()
99-
val fakeService =
100-
FakeService(executor, serverConfig) { "PONG: ${it.toStringUtf8()}".toByteString() }
119+
val fakeService = FakeService(serverConfig) { "PONG: ${it.toStringUtf8()}".toByteString() }
120+
121+
val responses = mutableListOf<ByteString>()
122+
val done = CompletableFuture<Void>()
123+
val responseObserver =
124+
object : OakSessionStreamObserver {
125+
lateinit var response: ByteString
126+
127+
override fun onSessionOpen(clientRequests: StreamObserver<ByteString>) {
128+
// Order will be preserved because the fake server executors are single-threaded.
129+
clientRequests.onNext("Hello World".toByteString())
130+
clientRequests.onNext("Hello World 2".toByteString())
131+
clientRequests.onNext("Hello World 3".toByteString())
132+
clientRequests.onCompleted()
133+
}
101134

102-
val responseObserver = mock<OakSessionStreamObserver>()
103-
val clientStreamCaptor = argumentCaptor<StreamObserver<ByteString>>()
104-
val responseCaptor = argumentCaptor<ByteString>()
135+
override fun onNext(response: ByteString) {
136+
responses.add(response)
137+
}
105138

106-
client.startSession(responseObserver) { fakeService.start(it) }
107-
fakeService.await()
139+
override fun onError(t: Throwable) {}
108140

109-
verify(responseObserver).onSessionOpen(clientStreamCaptor.capture())
110-
val clientStream = clientStreamCaptor.lastValue
141+
override fun onCompleted() {
142+
done.complete(null)
143+
}
144+
}
111145

112-
clientStream.onNext("Hello World".toByteString())
113-
clientStream.onNext("Hello World 2".toByteString())
114-
clientStream.onNext("Hello World 3".toByteString())
115-
fakeService.await()
146+
client.startSession(responseObserver) { fakeService.start(it) }
147+
done.get(10, TimeUnit.SECONDS)
116148

117-
verify(responseObserver, times(3)).onNext(responseCaptor.capture())
118149
assertEquals(
119-
responseCaptor.allValues.map { it.toStringUtf8() },
150+
responses.map { it.toStringUtf8() },
120151
listOf("PONG: Hello World", "PONG: Hello World 2", "PONG: Hello World 3"),
121152
)
122153
}
@@ -128,11 +159,11 @@ class StreamObserverSessionClientTest {
128159
val responseObserver = mock<OakSessionStreamObserver>()
129160

130161
val serverException = RuntimeException("Didn't connect")
131-
val mockToServer = mock<StreamObserver<SessionRequest>>()
132162

163+
val executor = Executors.newSingleThreadExecutor()
133164
client.startSession(responseObserver) {
134165
executor.execute { it.onError(serverException) }
135-
mockToServer
166+
mock<StreamObserver<SessionRequest>>()
136167
}
137168
executor.await()
138169

@@ -145,21 +176,28 @@ class StreamObserverSessionClientTest {
145176
val client = StreamObserverSessionClient(unattestedConfigProvider())
146177
val serverConfig = unattestedConfig()
147178
val fakeAppException = RuntimeException("oops")
148-
val fakeService = FakeService(executor, serverConfig) { throw fakeAppException }
179+
val fakeService = FakeService(serverConfig) { throw fakeAppException }
149180

150-
val responseObserver = mock<OakSessionStreamObserver>()
151-
val clientStreamCaptor = argumentCaptor<StreamObserver<ByteString>>()
181+
val responseFuture = CompletableFuture<Throwable>()
182+
val responseObserver =
183+
object : OakSessionStreamObserver {
152184

153-
client.startSession(responseObserver) { fakeService.start(it) }
154-
fakeService.await()
185+
override fun onSessionOpen(clientRequests: StreamObserver<ByteString>) {
186+
clientRequests.onNext("Big badaboom".toByteString())
187+
}
188+
189+
override fun onNext(response: ByteString) {}
155190

156-
verify(responseObserver).onSessionOpen(clientStreamCaptor.capture())
157-
val clientStream = clientStreamCaptor.lastValue
191+
override fun onError(t: Throwable) {
192+
responseFuture.complete(t)
193+
}
158194

159-
clientStream.onNext("Hello World".toByteString())
160-
fakeService.await()
195+
override fun onCompleted() {}
196+
}
161197

162-
verify(responseObserver).onError(fakeAppException)
198+
client.startSession(responseObserver) { fakeService.start(it) }
199+
val exception = responseFuture.get(10, TimeUnit.SECONDS)
200+
assertEquals(exception, fakeAppException)
163201
}
164202

165203
private fun unattestedConfig() =
@@ -174,15 +212,13 @@ class StreamObserverSessionClientTest {
174212
* provided application implementation function.
175213
*/
176214
class FakeService(
177-
val executor: ExecutorService,
178215
val sessionConfig: OakSessionConfigBuilder,
179216
val application: (ByteString) -> ByteString,
180217
) : StreamObserver<SessionRequest> {
181-
private lateinit var responses: StreamObserver<SessionResponse>
218+
private var requestExecutor = Executors.newSingleThreadExecutor()
219+
private var responseExecutor = Executors.newSingleThreadExecutor()
182220

183-
// Wait for the single-threaded service to finish by submitting a job and then
184-
// waiting on it.
185-
fun await() = executor.await()
221+
private lateinit var responses: StreamObserver<SessionResponse>
186222

187223
fun start(responses: StreamObserver<SessionResponse>): StreamObserver<SessionRequest> {
188224
this.responses = responses
@@ -192,34 +228,36 @@ class StreamObserverSessionClientTest {
192228
private val serverSession = OakServerSession(sessionConfig)
193229

194230
override fun onNext(request: SessionRequest) {
195-
executor.execute {
231+
requestExecutor.execute {
196232
if (serverSession.isOpen) {
197233
check(serverSession.putIncomingMessage(request))
198234
val decrypted = checkNotNull(serverSession.read().getOrNull()).plaintext
199235
val response =
200236
try {
201237
application(decrypted)
202238
} catch (e: Exception) {
203-
responses.onError(e)
239+
responseExecutor.execute { responses.onError(e) }
204240
return@execute
205241
}
206242
val pt = PlaintextMessage.newBuilder().setPlaintext(response).build()
207243
serverSession.write(pt)
208244
val encryptedResponse = checkNotNull(serverSession.outgoingMessage.getOrNull())
209-
responses.onNext(encryptedResponse)
245+
responseExecutor.execute { responses.onNext(encryptedResponse) }
210246
} else {
211247
check(serverSession.putIncomingMessage(request))
212-
serverSession.outgoingMessage.getOrNull()?.let { responses.onNext(it) }
248+
serverSession.outgoingMessage.getOrNull()?.let {
249+
responseExecutor.execute { responses.onNext(it) }
250+
}
213251
}
214252
}
215253
}
216254

217255
override fun onError(t: Throwable) {
218-
executor.execute { responses.onError(t) }
256+
requestExecutor.execute { responseExecutor.execute { responses.onError(t) } }
219257
}
220258

221259
override fun onCompleted() {
222-
executor.execute { responses.onCompleted() }
260+
requestExecutor.execute { responseExecutor.execute { responses.onCompleted() } }
223261
}
224262
}
225263

0 commit comments

Comments
 (0)