55import heapq
66import logging
77import threading
8- import time
98from abc import ABC , abstractmethod
109from collections import Counter
1110from concurrent .futures import Future , ThreadPoolExecutor
1211from dataclasses import dataclass
12+ from datetime import UTC , datetime
1313from enum import Enum
1414from typing import TYPE_CHECKING , Generic , Self , TypeVar
1515
@@ -258,7 +258,7 @@ def __init__(self, executable: Executable[CallableType]):
258258 self .executable = executable
259259 self ._status = BranchStatus .PENDING
260260 self ._future : Future | None = None
261- self ._suspend_until : float | None = None
261+ self ._suspend_until : datetime | None = None
262262 self ._result : ResultType = None # type: ignore[assignment]
263263 self ._is_result_set : bool = False
264264 self ._error : Exception | None = None
@@ -293,7 +293,7 @@ def error(self) -> Exception:
293293 return self ._error
294294
295295 @property
296- def suspend_until (self ) -> float | None :
296+ def suspend_until (self ) -> datetime | None :
297297 """Get suspend timestamp."""
298298 return self ._suspend_until
299299
@@ -308,7 +308,7 @@ def can_resume(self) -> bool:
308308 return self ._status is BranchStatus .SUSPENDED or (
309309 self ._status is BranchStatus .SUSPENDED_WITH_TIMEOUT
310310 and self ._suspend_until is not None
311- and time . time ( ) >= self ._suspend_until
311+ and datetime . now ( UTC ) >= self ._suspend_until
312312 )
313313
314314 @property
@@ -333,7 +333,7 @@ def suspend(self) -> None:
333333 self ._status = BranchStatus .SUSPENDED
334334 self ._suspend_until = None
335335
336- def suspend_with_timeout (self , timestamp : float ) -> None :
336+ def suspend_with_timeout (self , timestamp : datetime ) -> None :
337337 """Transition to SUSPENDED_WITH_TIMEOUT state."""
338338 self ._status = BranchStatus .SUSPENDED_WITH_TIMEOUT
339339 self ._suspend_until = timestamp
@@ -507,11 +507,11 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
507507 self .shutdown ()
508508
509509 def schedule_resume (
510- self , exe_state : ExecutableWithState , resume_time : float
510+ self , exe_state : ExecutableWithState , resume_time : datetime
511511 ) -> None :
512512 """Schedule a task to resume at the specified time."""
513513 with self ._lock :
514- heapq .heappush (self ._pending_resumes , (resume_time , exe_state ))
514+ heapq .heappush (self ._pending_resumes , (resume_time . timestamp () , exe_state ))
515515
516516 def shutdown (self ) -> None :
517517 """Shutdown the timer thread and cancel all pending resumes."""
@@ -534,7 +534,7 @@ def _timer_loop(self) -> None:
534534 self ._shutdown .wait (timeout = 0.1 )
535535 continue
536536
537- current_time = time . time ()
537+ current_time = datetime . now ( UTC ). timestamp ()
538538 if current_time >= next_resume_time :
539539 # Time to resume
540540 with self ._lock :
@@ -566,6 +566,7 @@ def __init__(
566566 sub_type_iteration : OperationSubType ,
567567 name_prefix : str ,
568568 serdes : SerDes | None ,
569+ item_serdes : SerDes | None = None ,
569570 summary_generator : SummaryGenerator | None = None ,
570571 ):
571572 """Initialize ConcurrentExecutor.
@@ -604,6 +605,7 @@ def __init__(
604605 )
605606 self .executables_with_state : list [ExecutableWithState ] = []
606607 self .serdes = serdes
608+ self .item_serdes = item_serdes
607609
608610 @abstractmethod
609611 def execute_item (
@@ -673,7 +675,7 @@ def on_done(future: Future) -> None:
673675
674676 def should_execution_suspend (self ) -> SuspendResult :
675677 """Check if execution should suspend."""
676- earliest_timestamp : float = float ( "inf" )
678+ earliest_timestamp : datetime | None = None
677679 indefinite_suspend_task : (
678680 ExecutableWithState [CallableType , ResultType ] | None
679681 ) = None
@@ -683,16 +685,16 @@ def should_execution_suspend(self) -> SuspendResult:
683685 # Exit here! Still have tasks that can make progress, don't suspend.
684686 return SuspendResult .do_not_suspend ()
685687 if exe_state .status is BranchStatus .SUSPENDED_WITH_TIMEOUT :
686- if (
687- exe_state . suspend_until
688- and exe_state .suspend_until < earliest_timestamp
688+ if exe_state . suspend_until and (
689+ earliest_timestamp is None
690+ or exe_state .suspend_until < earliest_timestamp
689691 ):
690692 earliest_timestamp = exe_state .suspend_until
691693 elif exe_state .status is BranchStatus .SUSPENDED :
692694 indefinite_suspend_task = exe_state
693695
694696 # All tasks are in final states and at least one of them is a suspend.
695- if earliest_timestamp != float ( "inf" ) :
697+ if earliest_timestamp is not None :
696698 return SuspendResult .suspend (
697699 TimedSuspendExecution (
698700 "All concurrent work complete or suspended pending retry." ,
@@ -797,7 +799,7 @@ def execute_in_child_context(child_context: DurableContext) -> ResultType:
797799 execute_in_child_context ,
798800 f"{ self .name_prefix } { executable .index } " ,
799801 ChildConfig (
800- serdes = self .serdes ,
802+ serdes = self .item_serdes or self . serdes ,
801803 sub_type = self .sub_type_iteration ,
802804 summary_generator = self .summary_generator ,
803805 ),
0 commit comments