Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions src/aws_durable_functions_sdk_python/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Generic, TypeVar
Expand All @@ -18,6 +17,8 @@
from concurrent.futures import Future

from aws_durable_functions_sdk_python.lambda_service import OperationSubType
from aws_durable_functions_sdk_python.serdes import SerDes


Numeric = int | float # deliberately leaving off complex

Expand Down Expand Up @@ -82,16 +83,6 @@ class ParallelConfig:
serdes: SerDes | None = None


class SerDes(ABC, Generic[T]):
@abstractmethod
def serialize(self, value: T) -> str:
pass

@abstractmethod
def deserialize(self, data: str) -> T:
pass


class StepSemantics(Enum):
AT_MOST_ONCE_PER_RETRY = "AT_MOST_ONCE_PER_RETRY"
AT_LEAST_ONCE_PER_RETRY = "AT_LEAST_ONCE_PER_RETRY"
Expand Down
22 changes: 15 additions & 7 deletions src/aws_durable_functions_sdk_python/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeVar

Expand All @@ -10,7 +9,6 @@
ChildConfig,
MapConfig,
ParallelConfig,
SerDes,
StepConfig,
WaitForCallbackConfig,
WaitForConditionConfig,
Expand Down Expand Up @@ -39,6 +37,7 @@
from aws_durable_functions_sdk_python.operation.wait_for_condition import (
wait_for_condition_handler,
)
from aws_durable_functions_sdk_python.serdes import SerDes, deserialize
from aws_durable_functions_sdk_python.state import ExecutionState # noqa: TCH001
from aws_durable_functions_sdk_python.threading import OrderedCounter
from aws_durable_functions_sdk_python.types import (
Expand Down Expand Up @@ -103,12 +102,12 @@ def __init__(
callback_id: str,
operation_id: str,
state: ExecutionState,
serdes: SerDes | None = None,
serdes: SerDes[T] | None = None,
):
self.callback_id: str = callback_id
self.operation_id: str = operation_id
self.state: ExecutionState = state
self.serdes: SerDes | None = serdes
self.serdes: SerDes[T] | None = serdes

def result(self) -> T | None:
"""Return the result of the future. Will block until result is available.
Expand All @@ -132,11 +131,15 @@ def result(self) -> T | None:
checkpointed_result.raise_callable_error()

if checkpointed_result.is_succeeded():
# TODO: serdes
if checkpointed_result.result is None:
return None # type: ignore

return json.loads(checkpointed_result.result)
return deserialize(
serdes=self.serdes,
data=checkpointed_result.result,
operation_id=self.operation_id,
durable_execution_arn=self.state.durable_execution_arn,
)

msg = "Callback must be started before you can await the result."
raise FatalError(msg)
Expand Down Expand Up @@ -270,6 +273,8 @@ def create_callback(
Return:
Callback future. Use result() on this future to wait for the callback resuilt.
"""
if not config:
config = CallbackConfig()
operation_id: str = self._create_step_id()
callback_id: str = create_callback_handler(
state=self.state,
Expand All @@ -280,7 +285,10 @@ def create_callback(
)

return Callback(
callback_id=callback_id, operation_id=operation_id, state=self.state
callback_id=callback_id,
operation_id=operation_id,
state=self.state,
serdes=config.serdes,
)

def map(
Expand Down
17 changes: 13 additions & 4 deletions src/aws_durable_functions_sdk_python/operation/child.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, TypeVar

Expand All @@ -13,6 +12,7 @@
OperationSubType,
OperationUpdate,
)
from aws_durable_functions_sdk_python.serdes import deserialize, serialize

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -50,8 +50,12 @@ def child_handler(
)
if checkpointed_result.result is None:
return None # type: ignore
return json.loads(checkpointed_result.result)

return deserialize(
serdes=config.serdes,
data=checkpointed_result.result,
operation_id=operation_identifier.operation_id,
durable_execution_arn=state.durable_execution_arn,
)
if checkpointed_result.is_failed():
checkpointed_result.raise_callable_error()
sub_type = (
Expand All @@ -67,7 +71,12 @@ def child_handler(

try:
raw_result: T = func()
serialized_result: str = json.dumps(raw_result)
serialized_result: str = serialize(
serdes=config.serdes,
value=raw_result,
operation_id=operation_identifier.operation_id,
durable_execution_arn=state.durable_execution_arn,
)

success_operation = OperationUpdate.create_context_succeed(
identifier=operation_identifier,
Expand Down
17 changes: 13 additions & 4 deletions src/aws_durable_functions_sdk_python/operation/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import json
import logging
import time
from typing import TYPE_CHECKING, TypeVar
Expand All @@ -20,6 +19,7 @@
from aws_durable_functions_sdk_python.lambda_service import ErrorObject, OperationUpdate
from aws_durable_functions_sdk_python.logger import Logger, LogInfo
from aws_durable_functions_sdk_python.retries import RetryPresets
from aws_durable_functions_sdk_python.serdes import deserialize, serialize
from aws_durable_functions_sdk_python.types import StepContext

if TYPE_CHECKING:
Expand Down Expand Up @@ -59,11 +59,15 @@ def step_handler(
operation_identifier.operation_id,
operation_identifier.name,
)
# TODO: serdes
if checkpointed_result.result is None:
return None # type: ignore

return json.loads(checkpointed_result.result)
return deserialize(
serdes=config.serdes,
data=checkpointed_result.result,
operation_id=operation_identifier.operation_id,
durable_execution_arn=state.durable_execution_arn,
)

if checkpointed_result.is_failed():
# have to throw the exact same error on replay as the checkpointed failure
Expand Down Expand Up @@ -107,7 +111,12 @@ def step_handler(
try:
# this is the actual code provided by the caller to execute durably inside the step
raw_result: T = func(step_context)
serialized_result: str = json.dumps(raw_result)
serialized_result: str = serialize(
serdes=config.serdes,
value=raw_result,
operation_id=operation_identifier.operation_id,
durable_execution_arn=state.durable_execution_arn,
)

success_operation: OperationUpdate = OperationUpdate.create_step_succeed(
identifier=operation_identifier,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import json
import logging
import time
from typing import TYPE_CHECKING, TypeVar
Expand All @@ -13,6 +12,7 @@
)
from aws_durable_functions_sdk_python.lambda_service import ErrorObject, OperationUpdate
from aws_durable_functions_sdk_python.logger import LogInfo
from aws_durable_functions_sdk_python.serdes import deserialize, serialize
from aws_durable_functions_sdk_python.types import WaitForConditionCheckContext

if TYPE_CHECKING:
Expand All @@ -26,6 +26,7 @@
from aws_durable_functions_sdk_python.logger import Logger
from aws_durable_functions_sdk_python.state import ExecutionState


T = TypeVar("T")

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -57,10 +58,14 @@ def wait_for_condition_handler(
operation_identifier.operation_id,
operation_identifier.name,
)
# TODO: use serdes from config
if checkpointed_result.result is None:
return None # type: ignore
return json.loads(checkpointed_result.result)
return deserialize(
serdes=config.serdes,
data=checkpointed_result.result,
operation_id=operation_identifier.operation_id,
durable_execution_arn=state.durable_execution_arn,
)

if checkpointed_result.is_failed():
checkpointed_result.raise_callable_error()
Expand All @@ -69,9 +74,13 @@ def wait_for_condition_handler(
if checkpointed_result.is_started_or_ready():
# This is a retry - get state from previous checkpoint
if checkpointed_result.result:
# TODO: serdes here
try:
current_state = json.loads(checkpointed_result.result)
current_state = deserialize(
serdes=config.serdes,
data=checkpointed_result.result,
operation_id=operation_identifier.operation_id,
durable_execution_arn=state.durable_execution_arn,
)
except Exception:
# default to initial state if there's an error getting checkpointed state
logger.exception(
Expand Down Expand Up @@ -117,8 +126,12 @@ def wait_for_condition_handler(
# Check if condition is met with the wait strategy
decision: WaitForConditionDecision = config.wait_strategy(new_state, attempt)

# TODO: SerDes here
serialized_state = json.dumps(new_state)
serialized_state = serialize(
serdes=config.serdes,
value=new_state,
operation_id=operation_identifier.operation_id,
durable_execution_arn=state.durable_execution_arn,
)

logger.debug(
"wait_for_condition check completed: %s, name: %s, attempt: %s",
Expand Down
74 changes: 74 additions & 0 deletions src/aws_durable_functions_sdk_python/serdes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Serialization and deserialization"""

import json
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

from aws_durable_functions_sdk_python.exceptions import FatalError

logger = logging.getLogger(__name__)

T = TypeVar("T")


@dataclass(frozen=True)
class SerDesContext:
operation_id: str
durable_execution_arn: str


class SerDes(ABC, Generic[T]):
@abstractmethod
def serialize(self, value: T, serdes_context: SerDesContext) -> str:
pass

@abstractmethod
def deserialize(self, data: str, serdes_context: SerDesContext) -> T:
pass


class JsonSerDes(SerDes[T]):
def serialize(self, value: T, _: SerDesContext) -> str:
return json.dumps(value)

def deserialize(self, data: str, _: SerDesContext) -> T:
return json.loads(data)


_DEFAULT_JSON_SERDES: SerDes = JsonSerDes()


def serialize(
serdes: SerDes[T] | None, value: T, operation_id: str, durable_execution_arn: str
) -> str:
serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
if serdes is None:
serdes = _DEFAULT_JSON_SERDES
try:
return serdes.serialize(value, serdes_context)
except Exception as e:
logger.exception(
"⚠️ Serialization failed for id: %s",
operation_id,
)
msg = f"Serialization failed for id: {operation_id}, error: {e}."
raise FatalError(msg) from e


def deserialize(
serdes: SerDes[T] | None, data: str, operation_id: str, durable_execution_arn: str
) -> T:
serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
if serdes is None:
serdes = _DEFAULT_JSON_SERDES
try:
return serdes.deserialize(data, serdes_context)
except Exception as e:
logger.exception(
"⚠️ Deserialization failed for id: %s",
operation_id,
)
msg = f"Deserialization failed for id: {operation_id}"
raise FatalError(msg) from e
Loading