@@ -27,16 +27,17 @@ import com.google.oak.session.v1.SessionResponse
2727import com.google.protobuf.ByteString
2828import io.grpc.stub.StreamObserver
2929import java.nio.charset.StandardCharsets
30+ import java.util.concurrent.CompletableFuture
3031import java.util.concurrent.ExecutorService
3132import java.util.concurrent.Executors
33+ import java.util.concurrent.TimeUnit
3234import javax.inject.Provider
3335import kotlin.jvm.optionals.getOrNull
3436import org.junit.Assert.assertEquals
3537import org.junit.Test
3638import org.junit.runner.RunWith
3739import org.junit.runners.JUnit4
3840import org.mockito.kotlin.any
39- import org.mockito.kotlin.argumentCaptor
4041import org.mockito.kotlin.mock
4142import org.mockito.kotlin.never
4243import org.mockito.kotlin.times
@@ -50,73 +51,103 @@ class StreamObserverSessionClientTest {
5051 OakClientSession .loadNativeLib()
5152 }
5253
53- /* * Executor to run service operations on a background thread. */
54- private val executor: ExecutorService = Executors .newSingleThreadExecutor()
55-
5654 @Test
5755 fun client_startedSession_handshakesWithServer () {
5856 val client = StreamObserverSessionClient (unattestedConfigProvider())
5957 val serverConfig = unattestedConfig()
60- val fakeService = FakeService (executor, serverConfig) { it }
58+ val fakeService = FakeService (serverConfig) { it }
6159
62- val responseObserver = mock<OakSessionStreamObserver >()
60+ val done = CompletableFuture <Void >()
61+ val responseObserver =
62+ object : OakSessionStreamObserver {
63+ override fun onSessionOpen (clientRequests : StreamObserver <ByteString >) {
64+ done.complete(null )
65+ }
66+
67+ override fun onNext (response : ByteString ) {}
68+
69+ override fun onError (t : Throwable ) {}
70+
71+ override fun onCompleted () {}
72+ }
6373
6474 client.startSession(responseObserver) { fakeService.start(it) }
6575
66- fakeService.await()
67- verify(responseObserver).onSessionOpen(any())
76+ done.get(10 , TimeUnit .SECONDS )
6877 }
6978
7079 @Test
7180 fun client_startedSession_getsServerAppResponse () {
7281 val client = StreamObserverSessionClient (unattestedConfigProvider())
7382 val serverConfig = unattestedConfig()
74- val fakeService =
75- FakeService (executor, serverConfig) { " PONG: ${it.toStringUtf8()} " .toByteString() }
76-
77- val responseObserver = mock<OakSessionStreamObserver >()
78- val clientStreamCaptor = argumentCaptor<StreamObserver <ByteString >>()
79- val responseCaptor = argumentCaptor<ByteString >()
83+ val fakeService = FakeService (serverConfig) { " PONG: ${it.toStringUtf8()} " .toByteString() }
84+
85+ val responses = mutableListOf<ByteString >()
86+ val done = CompletableFuture <Void >()
87+ val responseObserver =
88+ object : OakSessionStreamObserver {
89+ lateinit var response: ByteString
90+ lateinit var clientRequests: StreamObserver <ByteString >
91+
92+ override fun onSessionOpen (clientRequests : StreamObserver <ByteString >) {
93+ this .clientRequests = clientRequests
94+ clientRequests.onNext(" Hello World" .toByteString())
95+ clientRequests.onCompleted()
96+ }
8097
81- client.startSession(responseObserver) { fakeService.start(it) }
82- fakeService.await()
98+ override fun onNext (response : ByteString ) {
99+ responses.add(response)
100+ }
83101
84- verify(responseObserver).onSessionOpen(clientStreamCaptor.capture())
85- val clientStream = clientStreamCaptor.lastValue
102+ override fun onError (t : Throwable ) {}
86103
87- clientStream.onNext(" Hello World" .toByteString())
88- fakeService.await()
104+ override fun onCompleted () {
105+ done.complete(null )
106+ }
107+ }
89108
90- verify(responseObserver).onNext(responseCaptor.capture())
109+ client.startSession(responseObserver) { fakeService.start(it) }
110+ done.get(10 , TimeUnit .SECONDS )
91111
92- assertEquals(responseCaptor.lastValue .toStringUtf8(), " PONG: Hello World" )
112+ assertEquals(responses.map { it .toStringUtf8() }, listOf ( " PONG: Hello World" ) )
93113 }
94114
95115 @Test
96116 fun client_startedSession_getsServerMultipleAppResponses () {
97117 val client = StreamObserverSessionClient (unattestedConfigProvider())
98118 val serverConfig = unattestedConfig()
99- val fakeService =
100- FakeService (executor, serverConfig) { " PONG: ${it.toStringUtf8()} " .toByteString() }
119+ val fakeService = FakeService (serverConfig) { " PONG: ${it.toStringUtf8()} " .toByteString() }
120+
121+ val responses = mutableListOf<ByteString >()
122+ val done = CompletableFuture <Void >()
123+ val responseObserver =
124+ object : OakSessionStreamObserver {
125+ lateinit var response: ByteString
126+
127+ override fun onSessionOpen (clientRequests : StreamObserver <ByteString >) {
128+ // Order will be preserved because the fake server executors are single-threaded.
129+ clientRequests.onNext(" Hello World" .toByteString())
130+ clientRequests.onNext(" Hello World 2" .toByteString())
131+ clientRequests.onNext(" Hello World 3" .toByteString())
132+ clientRequests.onCompleted()
133+ }
101134
102- val responseObserver = mock< OakSessionStreamObserver >()
103- val clientStreamCaptor = argumentCaptor< StreamObserver < ByteString >>( )
104- val responseCaptor = argumentCaptor< ByteString >()
135+ override fun onNext ( response : ByteString ) {
136+ responses.add(response )
137+ }
105138
106- client.startSession(responseObserver) { fakeService.start(it) }
107- fakeService.await()
139+ override fun onError (t : Throwable ) {}
108140
109- verify(responseObserver).onSessionOpen(clientStreamCaptor.capture())
110- val clientStream = clientStreamCaptor.lastValue
141+ override fun onCompleted () {
142+ done.complete(null )
143+ }
144+ }
111145
112- clientStream.onNext(" Hello World" .toByteString())
113- clientStream.onNext(" Hello World 2" .toByteString())
114- clientStream.onNext(" Hello World 3" .toByteString())
115- fakeService.await()
146+ client.startSession(responseObserver) { fakeService.start(it) }
147+ done.get(10 , TimeUnit .SECONDS )
116148
117- verify(responseObserver, times(3 )).onNext(responseCaptor.capture())
118149 assertEquals(
119- responseCaptor.allValues .map { it.toStringUtf8() },
150+ responses .map { it.toStringUtf8() },
120151 listOf (" PONG: Hello World" , " PONG: Hello World 2" , " PONG: Hello World 3" ),
121152 )
122153 }
@@ -128,11 +159,11 @@ class StreamObserverSessionClientTest {
128159 val responseObserver = mock<OakSessionStreamObserver >()
129160
130161 val serverException = RuntimeException (" Didn't connect" )
131- val mockToServer = mock<StreamObserver <SessionRequest >>()
132162
163+ val executor = Executors .newSingleThreadExecutor()
133164 client.startSession(responseObserver) {
134165 executor.execute { it.onError(serverException) }
135- mockToServer
166+ mock< StreamObserver < SessionRequest >>()
136167 }
137168 executor.await()
138169
@@ -145,21 +176,28 @@ class StreamObserverSessionClientTest {
145176 val client = StreamObserverSessionClient (unattestedConfigProvider())
146177 val serverConfig = unattestedConfig()
147178 val fakeAppException = RuntimeException (" oops" )
148- val fakeService = FakeService (executor, serverConfig) { throw fakeAppException }
179+ val fakeService = FakeService (serverConfig) { throw fakeAppException }
149180
150- val responseObserver = mock<OakSessionStreamObserver >()
151- val clientStreamCaptor = argumentCaptor<StreamObserver <ByteString >>()
181+ val responseFuture = CompletableFuture <Throwable >()
182+ val responseObserver =
183+ object : OakSessionStreamObserver {
152184
153- client.startSession(responseObserver) { fakeService.start(it) }
154- fakeService.await()
185+ override fun onSessionOpen (clientRequests : StreamObserver <ByteString >) {
186+ clientRequests.onNext(" Big badaboom" .toByteString())
187+ }
188+
189+ override fun onNext (response : ByteString ) {}
155190
156- verify(responseObserver).onSessionOpen(clientStreamCaptor.capture())
157- val clientStream = clientStreamCaptor.lastValue
191+ override fun onError (t : Throwable ) {
192+ responseFuture.complete(t)
193+ }
158194
159- clientStream.onNext( " Hello World " .toByteString())
160- fakeService.await()
195+ override fun onCompleted () {}
196+ }
161197
162- verify(responseObserver).onError(fakeAppException)
198+ client.startSession(responseObserver) { fakeService.start(it) }
199+ val exception = responseFuture.get(10 , TimeUnit .SECONDS )
200+ assertEquals(exception, fakeAppException)
163201 }
164202
165203 private fun unattestedConfig () =
@@ -174,15 +212,13 @@ class StreamObserverSessionClientTest {
174212 * provided application implementation function.
175213 */
176214 class FakeService (
177- val executor : ExecutorService ,
178215 val sessionConfig : OakSessionConfigBuilder ,
179216 val application : (ByteString ) -> ByteString ,
180217 ) : StreamObserver<SessionRequest> {
181- private lateinit var responses: StreamObserver <SessionResponse >
218+ private var requestExecutor = Executors .newSingleThreadExecutor()
219+ private var responseExecutor = Executors .newSingleThreadExecutor()
182220
183- // Wait for the single-threaded service to finish by submitting a job and then
184- // waiting on it.
185- fun await () = executor.await()
221+ private lateinit var responses: StreamObserver <SessionResponse >
186222
187223 fun start (responses : StreamObserver <SessionResponse >): StreamObserver <SessionRequest > {
188224 this .responses = responses
@@ -192,34 +228,36 @@ class StreamObserverSessionClientTest {
192228 private val serverSession = OakServerSession (sessionConfig)
193229
194230 override fun onNext (request : SessionRequest ) {
195- executor .execute {
231+ requestExecutor .execute {
196232 if (serverSession.isOpen) {
197233 check(serverSession.putIncomingMessage(request))
198234 val decrypted = checkNotNull(serverSession.read().getOrNull()).plaintext
199235 val response =
200236 try {
201237 application(decrypted)
202238 } catch (e: Exception ) {
203- responses.onError(e)
239+ responseExecutor.execute { responses.onError(e) }
204240 return @execute
205241 }
206242 val pt = PlaintextMessage .newBuilder().setPlaintext(response).build()
207243 serverSession.write(pt)
208244 val encryptedResponse = checkNotNull(serverSession.outgoingMessage.getOrNull())
209- responses.onNext(encryptedResponse)
245+ responseExecutor.execute { responses.onNext(encryptedResponse) }
210246 } else {
211247 check(serverSession.putIncomingMessage(request))
212- serverSession.outgoingMessage.getOrNull()?.let { responses.onNext(it) }
248+ serverSession.outgoingMessage.getOrNull()?.let {
249+ responseExecutor.execute { responses.onNext(it) }
250+ }
213251 }
214252 }
215253 }
216254
217255 override fun onError (t : Throwable ) {
218- executor .execute { responses.onError(t) }
256+ requestExecutor .execute { responseExecutor.execute { responses.onError(t) } }
219257 }
220258
221259 override fun onCompleted () {
222- executor .execute { responses.onCompleted() }
260+ requestExecutor .execute { responseExecutor.execute { responses.onCompleted() } }
223261 }
224262 }
225263
0 commit comments