Skip to content

Commit

Permalink
Support try-catch for both invoke and run (#65)
Browse files Browse the repository at this point in the history
* Support try-catch for both `invoke` and `run`

If an exception is thrown in the `invoke` or `run` function, it will now
be caught and returned as a `StepError` exception. This will allow the
error to be propagated to the caller and handled appropriately.

Transforming the errors into a `StepError` exception is following the
Inngest SDK spec:
https://github.com/inngest/inngest/blob/main/docs/SDK_SPEC.md#522-memoizing-a-step
Inngest JS SDK behavior:
https://github.com/inngest/inngest-js/blob/4f91d9c302592ecc2228914469dd057ae148005b/packages/inngest/src/components/execution/v1.ts#L437-L443
Inngest documentation:
https://www.inngest.com/docs/features/inngest-functions/error-retries/inngest-errors#step-errors

* Implement two example functions demonstrating try-catching

* Add tests for try-catching in `run` and `invoke` functions

* Remove the @BeforeAll method from `StepErrorsIntegrationTest`

* Extract Step status codes set into a const
  • Loading branch information
KiKoS0 authored Sep 1, 2024
1 parent a9274a6 commit 54b58d7
Show file tree
Hide file tree
Showing 14 changed files with 272 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.inngest.springbootdemo.testfunctions;

import com.inngest.*;
import org.jetbrains.annotations.NotNull;

import java.util.LinkedHashMap;

public class InvokeFailureFunction extends InngestFunction {

@NotNull
@Override
public InngestFunctionConfigBuilder config(InngestFunctionConfigBuilder builder) {
return builder
.id("invoke-failure-fn")
.name("Invoke Function")
.triggerEvent("test/invoke.failure");
}

@Override
public String execute(FunctionContext ctx, Step step) {
try {
step.invoke(
"failing-function",
"spring_test_demo",
"non-retriable-fn",
new LinkedHashMap<String,
String>(),
null,
Object.class);
} catch (StepError e) {
return e.getMessage();
}

return "An error should have been thrown and this message should not be returned";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public InngestFunctionConfigBuilder config(InngestFunctionConfigBuilder builder)
@Override
public String execute(FunctionContext ctx, Step step) {
step.run("fail-step", () -> {
throw new NonRetriableError("something fatally went wrong");
throw new NonRetriableError("Something fatally went wrong");
}, String.class);

return "Success";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.inngest.springbootdemo.testfunctions;

import com.inngest.*;
import org.jetbrains.annotations.NotNull;

class CustomException extends RuntimeException {
public CustomException(String message) {
super(message);
}
}

public class TryCatchRunFunction extends InngestFunction {

@NotNull
@Override
public InngestFunctionConfigBuilder config(InngestFunctionConfigBuilder builder) {
return builder
.id("try-catch-run-fn")
.name("Try catch run")
.triggerEvent("test/try.catch.run")
.retries(0);
}

@Override
public String execute(FunctionContext ctx, Step step) {
try {
step.run("fail-step", () -> {
throw new CustomException("Something fatally went wrong");
}, String.class);
} catch (StepError e) {
return e.getMessage();
}

return "An error should have been thrown and this message should not be returned";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ protected HashMap<String, InngestFunction> functions() {
addInngestFunction(functions, new NonRetriableErrorFunction());
addInngestFunction(functions, new RetriableErrorFunction());
addInngestFunction(functions, new ZeroRetriesFunction());
addInngestFunction(functions, new InvokeFailureFunction());
addInngestFunction(functions, new TryCatchRunFunction());

return functions;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void testNonRetriableShouldFail() throws Exception {
assertNotNull(run.getEnded_at());
assert output.get("name").contains("NonRetriableError");
assert output.get("stack").contains("NonRetriableErrorFunction.lambda$execute");
assertEquals(output.get("message"), "something fatally went wrong");
assertEquals(output.get("message"), "Something fatally went wrong");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.inngest.springbootdemo;

import com.inngest.Inngest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;
import org.springframework.beans.factory.annotation.Autowired;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

@IntegrationTest
@Execution(ExecutionMode.CONCURRENT)
class StepErrorsIntegrationTest {
@Autowired
private DevServerComponent devServer;

static int sleepTime = 5000;

@Autowired
private Inngest client;

@Test
void testShouldCatchStepErrorWhenInvokeThrows() throws Exception {
String eventId = InngestFunctionTestHelpers.sendEvent(client, "test/invoke.failure").first();

Thread.sleep(sleepTime);

RunEntry<Object> run = devServer.runsByEvent(eventId).first();
String output = (String) run.getOutput();

assertEquals("Completed", run.getStatus() );
assertNotNull(run.getEnded_at());

assertEquals("Something fatally went wrong", output);
}

@Test
void testShouldCatchStepErrorWhenRunThrows() throws Exception {
String eventId = InngestFunctionTestHelpers.sendEvent(client, "test/try.catch.run").first();

Thread.sleep(sleepTime);

RunEntry<Object> run = devServer.runsByEvent(eventId).first();
String output = (String) run.getOutput();

assertEquals("Completed", run.getStatus());
assertNotNull(run.getEnded_at());

assertEquals("Something fatally went wrong", output);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ fun Application.module() {
RestoreFromGlacier(),
ProcessUserSignup(),
TranscodeVideo(),
ImageFromPrompt(),
PushToSlackChannel(),
),
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.inngest.testserver

import com.inngest.*

class ImageFromPrompt : InngestFunction() {
override fun config(builder: InngestFunctionConfigBuilder): InngestFunctionConfigBuilder =
builder
.id("ImageFromPrompt")
.name("Image from Prompt")
.triggerEvent("media/prompt.created")

override fun execute(
ctx: FunctionContext,
step: Step,
): String {
val imageURL =
try {
step.run("generate-image-dall-e") {
// Call the DALL-E model to generate an image
throw Exception("Failed to generate image")

"example.com/image-dall-e.jpg"
}
} catch (e: StepError) {
// Fall back to a different image generation model
step.run("generate-image-midjourney") {
// Call the MidJourney model to generate an image
"example.com/image-midjourney.jpg"
}
}

try {
step.invoke<Map<String, Any>>(
"push-to-slack-channel",
"ktor-dev",
"PushToSlackChannel",
mapOf("image" to imageURL),
null,
)
} catch (e: StepError) {
// Pushing to Slack is not critical, so we can ignore the error, log it
// or handle it in some other way.
}

return imageURL
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.inngest.testserver

import com.inngest.*

class PushToSlackChannel : InngestFunction() {
override fun config(builder: InngestFunctionConfigBuilder): InngestFunctionConfigBuilder =
builder
.id("PushToSlackChannel")
.name("Push to Slack Channel")
.triggerEvent("media/image.generated")

override fun execute(
ctx: FunctionContext,
step: Step,
): String =
step.run("push-to-slack-channel") {
// Call Slack API to push the image to a channel
throw NonRetriableError("Failed to push image to Slack channel ${ctx.event.data["image"]}")

"Image pushed to Slack channel"
}
}
7 changes: 5 additions & 2 deletions inngest/src/main/kotlin/com/inngest/Comm.kt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ data class CommError(
val __serialized: Boolean = true,
)

private val stepTerminalStatusCodes = setOf(ResultStatusCode.StepComplete, ResultStatusCode.StepError)

class CommHandler(
functions: Map<String, InngestFunction>,
val client: Inngest,
Expand Down Expand Up @@ -81,7 +83,7 @@ class CommHandler(

val result = function.call(ctx = ctx, client = client, requestBody)
var body: Any? = null
if (result.statusCode == ResultStatusCode.StepComplete || result is StepOptions) {
if (result.statusCode in stepTerminalStatusCodes || result is StepOptions) {
body = listOf(result)
}
if (result is StepResult && result.statusCode == ResultStatusCode.FunctionComplete) {
Expand All @@ -94,7 +96,8 @@ class CommHandler(
)
} catch (e: Exception) {
val retryDecision = RetryDecision.fromException(e)
val statusCode = if (retryDecision.shouldRetry) ResultStatusCode.RetriableError else ResultStatusCode.NonRetriableError
val statusCode =
if (retryDecision.shouldRetry) ResultStatusCode.RetriableError else ResultStatusCode.NonRetriableError

val err =
CommError(
Expand Down
11 changes: 11 additions & 0 deletions inngest/src/main/kotlin/com/inngest/Function.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import java.util.function.BiFunction
// TODO - Add an abstraction layer between the Function call response and the comm handler response
enum class OpCode {
StepRun,
StepError,
Sleep,
StepStateFailed, // TODO
Step,
Expand All @@ -21,6 +22,7 @@ enum class ResultStatusCode(
val message: String,
) {
StepComplete(206, "Step Complete"),
StepError(206, "Step Error"),
FunctionComplete(200, "Function Complete"),
NonRetriableError(400, "Bad Request"),
RetriableError(500, "Function Error"),
Expand All @@ -40,6 +42,7 @@ data class StepResult(
override val op: OpCode,
override val statusCode: ResultStatusCode,
val data: Any? = null,
val error: Exception? = null,
) : StepOp(id, name, op, statusCode)

data class StepOptions(
Expand Down Expand Up @@ -202,6 +205,14 @@ internal open class InternalInngestFunction(
}
},
)
} catch (e: StepInterruptErrorException) {
return StepResult(
id = e.hashedId,
name = e.id,
op = OpCode.StepError,
statusCode = ResultStatusCode.StepError,
error = e.error,
)
} catch (e: StepInterruptException) {
// NOTE - Currently this error could be caught in the user's own function
// that wraps a
Expand Down
4 changes: 2 additions & 2 deletions inngest/src/main/kotlin/com/inngest/State.kt
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ class State(
val dataNode = stepResult.get(fieldName)
return mapper.treeToValue(dataNode, type)
} else if (stepResult.has("error")) {
// TODO - Parse the error and throw it
return null
val error = mapper.treeToValue(stepResult.get("error"), StepError::class.java)
throw error
}
// NOTE - Sleep steps will be stored as null
// TODO - Investigate if sendEvents stores null as well.
Expand Down
38 changes: 35 additions & 3 deletions inngest/src/main/kotlin/com/inngest/Step.kt
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ class StepInterruptWaitForEventException(
val ifExpression: String?,
) : StepInterruptException(id, hashedId, null)

class StepInterruptErrorException(
id: String,
hashedId: String,
val error: Exception,
) : StepInterruptException(id, hashedId, null)

class Step(
private val state: State,
val client: Inngest,
Expand All @@ -80,6 +86,8 @@ class Step(
*
* @param id unique step id for memoization
* @param fn the function to run
*
* @exception StepError if the function throws an [Exception].
*/
inline fun <reified T> run(
id: String,
Expand All @@ -100,15 +108,34 @@ class Step(
}
} catch (e: StateNotFound) {
// If there is no existing result, run the lambda
val data = fn()
throw StepInterruptException(id, hashedId, data)
executeStep(id, hashedId, fn)
} catch (e: StepError) {
throw e
}
// TODO - Catch Step Error here and throw it when error parsing is added to getState

// TODO - handle invalidly stored step types properly
throw Exception("step state incorrect type")
}

private fun <T> executeStep(
id: String,
hashedId: String,
fn: () -> T,
) {
try {
val data = fn()
throw StepInterruptException(id, hashedId, data)
} catch (exception: Exception) {
when (exception) {
is RetryAfterError,
is NonRetriableError,
-> throw exception

else -> throw StepInterruptErrorException(id, hashedId, exception)
}
}
}

/**
* Invoke another Inngest function as a step
*
Expand All @@ -118,6 +145,8 @@ class Step(
* @param data the data to pass within `event.data` to the function
* @param timeout an optional timeout for the invoked function. If the invoked function does
* not finish within this time, the invoked function will be marked as failed.
*
* @exception StepError if the invoked function fails.
*/
inline fun <reified T> invoke(
id: String,
Expand All @@ -143,7 +172,10 @@ class Step(
}
} catch (e: StateNotFound) {
throw StepInterruptInvokeException(id, hashedId, appId, fnId, data, timeout)
} catch (e: StepError) {
throw e
}

// TODO - handle invalidly stored step types properly
throw Exception("step state incorrect type")
}
Expand Down
Loading

0 comments on commit 54b58d7

Please sign in to comment.