diff --git a/inngest-spring-boot-adapter/src/main/java/com/inngest/springboot/InngestController.java b/inngest-spring-boot-adapter/src/main/java/com/inngest/springboot/InngestController.java index d43fd355..b9b5fe83 100644 --- a/inngest-spring-boot-adapter/src/main/java/com/inngest/springboot/InngestController.java +++ b/inngest-spring-boot-adapter/src/main/java/com/inngest/springboot/InngestController.java @@ -3,6 +3,7 @@ import com.inngest.CommHandler; import com.inngest.CommResponse; import com.inngest.InngestEnv; +import com.inngest.InngestQueryParamKey; import com.inngest.signingkey.SignatureVerificationKt; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -48,7 +49,7 @@ public ResponseEntity put( if (this.serveOrigin != null && !this.serveOrigin.isEmpty()) { origin = this.serveOrigin; } - String response = commHandler.register(origin); + String response = commHandler.register(origin, request.getParameter(InngestQueryParamKey.SyncId.getValue())); return ResponseEntity.ok().headers(getHeaders()).body(response); } diff --git a/inngest-spring-boot-demo/build.gradle.kts b/inngest-spring-boot-demo/build.gradle.kts index dcd6a283..b0e584d2 100644 --- a/inngest-spring-boot-demo/build.gradle.kts +++ b/inngest-spring-boot-demo/build.gradle.kts @@ -26,6 +26,7 @@ dependencies { implementation("com.squareup.okhttp3:okhttp:4.12.0") testImplementation("org.springframework.boot:spring-boot-starter-test") + testImplementation("com.squareup.okhttp3:mockwebserver:4.12.0") if (JavaVersion.current().isJava11Compatible) { testImplementation("uk.org.webcompere:system-stubs-jupiter:2.1.6") diff --git a/inngest-spring-boot-demo/src/main/java/com/inngest/springbootdemo/DevServerComponent.java b/inngest-spring-boot-demo/src/main/java/com/inngest/springbootdemo/DevServerComponent.java index 272c0abf..b7083410 100644 --- a/inngest-spring-boot-demo/src/main/java/com/inngest/springbootdemo/DevServerComponent.java +++ b/inngest-spring-boot-demo/src/main/java/com/inngest/springbootdemo/DevServerComponent.java @@ -33,7 +33,7 @@ private void waitForStartup(CommHandler commHandler) throws Exception { try (Response response = httpClient.newCall(request).execute()) { if (response.code() == 200) { Thread.sleep(3000); - commHandler.register("http://localhost:8080"); + commHandler.register("http://localhost:8080", null); return; } } diff --git a/inngest-spring-boot-demo/src/test/java/com/inngest/springbootdemo/SyncRequestTest.java b/inngest-spring-boot-demo/src/test/java/com/inngest/springbootdemo/SyncRequestTest.java new file mode 100644 index 00000000..5cb49269 --- /dev/null +++ b/inngest-spring-boot-demo/src/test/java/com/inngest/springbootdemo/SyncRequestTest.java @@ -0,0 +1,136 @@ +package com.inngest.springbootdemo; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.inngest.*; +import com.inngest.springboot.InngestConfiguration; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.*; +import org.junit.jupiter.api.condition.EnabledIfSystemProperty; +import org.junit.jupiter.api.extension.ExtendWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.web.servlet.WebMvcTest; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Import; +import org.springframework.test.web.servlet.MockMvc; +import uk.org.webcompere.systemstubs.environment.EnvironmentVariables; +import uk.org.webcompere.systemstubs.jupiter.SystemStub; +import uk.org.webcompere.systemstubs.jupiter.SystemStubsExtension; + +import java.util.HashMap; + +import static org.junit.jupiter.api.Assertions.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; + + +@ExtendWith(SystemStubsExtension.class) +public class SyncRequestTest { + static class SyncInngestConfiguration extends InngestConfiguration { + protected HashMap functions() { + return new HashMap<>(); + } + + @Override + protected Inngest inngestClient() { + return new Inngest("spring_test_registration"); + } + + @Override + protected ServeConfig serve(Inngest client) { + return new ServeConfig(client); + } + + @Bean + protected CommHandler commHandler(@Autowired Inngest inngestClient) { + ServeConfig serveConfig = new ServeConfig(inngestClient); + return new CommHandler(functions(), inngestClient, serveConfig, SupportedFrameworkName.SpringBoot); + } + } + + @SystemStub + private static EnvironmentVariables environmentVariables; + + public static MockWebServer mockWebServer; + + @Import(SyncInngestConfiguration.class) + @WebMvcTest(DemoController.class) + @Nested + @EnabledIfSystemProperty(named = "test-group", matches = "unit-test") + @TestMethodOrder(MethodOrderer.OrderAnnotation.class) + class InnerSpringTest { + @Autowired + private MockMvc mockMvc; + + @BeforeEach + void BeforeEach() throws Exception { + mockWebServer = new MockWebServer(); + mockWebServer.start(); + + String serverUrl = mockWebServer.url("").toString(); + + // Remove the trailing slash from the serverUrl + serverUrl = serverUrl.substring(0, serverUrl.length() - 1); + + environmentVariables.set("INNGEST_API_BASE_URL", serverUrl); + } + + @AfterEach + void afterEach() throws Exception { + mockWebServer.shutdown(); + } + + private void assertThatPayloadDoesNotContainDeployId(RecordedRequest recordedRequest) throws Exception { + // The url in the sync payload should not contain the deployId. + // https://github.com/inngest/inngest/blob/main/docs/SDK_SPEC.md#432-syncing + String requestBody = recordedRequest.getBody().readUtf8(); + + ObjectMapper objectMapper = new ObjectMapper(); + JsonNode jsonNode = objectMapper.readTree(requestBody); + + String url = jsonNode.get("url").asText(); + assertFalse(url.contains("deployId")); + } + + @Test + public void shouldIncludeDeployIdInSyncRequestIfPresent() throws Exception { + mockWebServer.enqueue(new MockResponse().setBody("Success")); + mockWebServer.enqueue(new MockResponse().setBody("Success")); + mockWebServer.enqueue(new MockResponse().setBody("Success")); + + mockMvc.perform(put("/api/inngest") + .header("Host", "localhost:8080") + .param("deployId", "1")) + .andExpect(status().isOk()); + + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + + assertEquals("/fn/register", recordedRequest.getRequestUrl().encodedPath()); + assertEquals("1", recordedRequest.getRequestUrl().queryParameter("deployId")); + assertThatPayloadDoesNotContainDeployId(recordedRequest); + + mockMvc.perform(put("/api/inngest") + .header("Host", "localhost:8080")) + .andExpect(status().isOk()); + + recordedRequest = mockWebServer.takeRequest(); + + assertEquals("/fn/register", recordedRequest.getRequestUrl().encodedPath()); + assertNull(recordedRequest.getRequestUrl().queryParameter("deployId")); + assertThatPayloadDoesNotContainDeployId(recordedRequest); + + mockMvc.perform(put("/api/inngest") + .header("Host", "localhost:8080") + .param("deployId", "3")) + .andExpect(status().isOk()); + + recordedRequest = mockWebServer.takeRequest(); + + assertEquals("/fn/register", recordedRequest.getRequestUrl().encodedPath()); + assertEquals("3", recordedRequest.getRequestUrl().queryParameter("deployId")); + assertThatPayloadDoesNotContainDeployId(recordedRequest); + } + } +} diff --git a/inngest/src/main/kotlin/com/inngest/Comm.kt b/inngest/src/main/kotlin/com/inngest/Comm.kt index 0324302a..1dda3367 100644 --- a/inngest/src/main/kotlin/com/inngest/Comm.kt +++ b/inngest/src/main/kotlin/com/inngest/Comm.kt @@ -152,7 +152,10 @@ class CommHandler( return configs } - fun register(origin: String): String { + fun register( + origin: String, + syncId: String?, + ): String { val registrationUrl = "${config.baseUrl()}/fn/register" val requestPayload = getRegistrationRequestPayload(origin) @@ -166,7 +169,9 @@ class CommHandler( null } - val request = httpClient.build(registrationUrl, requestPayload, authorizationHeaderRequestConfig) + val queryParams = syncId?.let { mapOf(InngestQueryParamKey.SyncId.value to it) } ?: emptyMap() + + val request = httpClient.build(registrationUrl, requestPayload, queryParams, authorizationHeaderRequestConfig) httpClient.send(request) { response -> if (!response.isSuccessful) throw IOException("Unexpected code $response") diff --git a/inngest/src/main/kotlin/com/inngest/HttpClient.kt b/inngest/src/main/kotlin/com/inngest/HttpClient.kt index 37f93e77..ac953ca6 100644 --- a/inngest/src/main/kotlin/com/inngest/HttpClient.kt +++ b/inngest/src/main/kotlin/com/inngest/HttpClient.kt @@ -2,6 +2,7 @@ package com.inngest import com.beust.klaxon.Klaxon import okhttp3.Headers +import okhttp3.HttpUrl.Companion.toHttpUrl import okhttp3.MediaType.Companion.toMediaType import okhttp3.OkHttpClient import okhttp3.RequestBody.Companion.toRequestBody @@ -31,13 +32,18 @@ internal class HttpClient( fun build( url: String, payload: Any, + queryParams: Map? = null, config: RequestConfig? = null, ): okhttp3.Request { + val httpUrlBuilder = url.toHttpUrl().newBuilder() + queryParams?.forEach { (k, v) -> httpUrlBuilder.addQueryParameter(k, v) } + val jsonRequestBody = Klaxon() .fieldConverter(KlaxonDuration::class, durationConverter) .fieldConverter(KlaxonConcurrencyScope::class, concurrencyScopeConverter) .toJsonString(payload) + val body = jsonRequestBody.toRequestBody(jsonMediaType) val clientHeaders = clientConfig.headers ?: emptyMap() @@ -45,7 +51,7 @@ internal class HttpClient( return okhttp3.Request .Builder() - .url(url) + .url(httpUrlBuilder.build()) .post(body) .headers(toOkHttpHeaders(clientHeaders + requestHeaders)) .build() diff --git a/inngest/src/main/kotlin/com/inngest/InngestQueryParamKey.kt b/inngest/src/main/kotlin/com/inngest/InngestQueryParamKey.kt new file mode 100644 index 00000000..3c52705a --- /dev/null +++ b/inngest/src/main/kotlin/com/inngest/InngestQueryParamKey.kt @@ -0,0 +1,7 @@ +package com.inngest + +enum class InngestQueryParamKey( + val value: String, +) { + SyncId("deployId"), +} diff --git a/inngest/src/main/kotlin/com/inngest/ktor/Route.kt b/inngest/src/main/kotlin/com/inngest/ktor/Route.kt index 297d04a7..6d5fea3a 100644 --- a/inngest/src/main/kotlin/com/inngest/ktor/Route.kt +++ b/inngest/src/main/kotlin/com/inngest/ktor/Route.kt @@ -72,16 +72,20 @@ fun Route.serve( } put("") { + val syncId = call.request.queryParameters[InngestQueryParamKey.SyncId.value] + val origin = getOrigin(call) - val resp = comm.register(origin) + val resp = comm.register(origin, syncId) call.respond(HttpStatusCode.OK, resp) } } } +val HTTP_PORTS = listOf(80, 443) + fun getOrigin(call: ApplicationCall): String { var origin = String.format("%s://%s", call.request.origin.scheme, call.request.origin.serverHost) - if (call.request.origin.serverPort != 80 || call.request.origin.serverPort != 443) { + if (call.request.origin.serverPort !in HTTP_PORTS) { origin = String.format("%s:%s", origin, call.request.origin.serverPort) } return origin