From bff1d69f241d0158d84d21fb55343ccf613a95c5 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang <100308595+nvidianz@users.noreply.github.com> Date: Mon, 9 Sep 2024 17:04:19 -0400 Subject: [PATCH] Fixing memoryview error (#2929) * Fixed dup seq 0 bug * Formatting errors --- nvflare/fuel/f3/streaming/blob_streamer.py | 14 +++--- nvflare/fuel/f3/streaming/byte_receiver.py | 55 ++++++++++++---------- 2 files changed, 38 insertions(+), 31 deletions(-) diff --git a/nvflare/fuel/f3/streaming/blob_streamer.py b/nvflare/fuel/f3/streaming/blob_streamer.py index 282d8e8050..506ecd1692 100644 --- a/nvflare/fuel/f3/streaming/blob_streamer.py +++ b/nvflare/fuel/f3/streaming/blob_streamer.py @@ -77,6 +77,9 @@ def __init__(self, future: StreamFuture, stream: Stream): else: self.buffer = FastBuffer() + def __str__(self): + return f"Blob[SID:{self.future.get_stream_id()} Sizeļ¼š{self.size}]" + class BlobHandler: def __init__(self, blob_cb: Callable): @@ -113,23 +116,22 @@ def _read_stream(blob_task: BlobTask): if blob_task.pre_allocated: remaining = len(blob_task.buffer) - buf_size if length > remaining: - log.error(f"Buffer overrun: {remaining=} {length=} {buf_size=}") + log.error(f"{blob_task} Buffer overrun: {remaining=} {length=} {buf_size=}") if remaining > 0: blob_task.buffer[buf_size : buf_size + remaining] = buf[0:remaining] + break else: blob_task.buffer[buf_size : buf_size + length] = buf else: blob_task.buffer.append(buf) except Exception as ex: - log.error(f"memory view error: {ex} Debug info: {length=} {buf_size=} {type(buf)=}") + log.error(f"{blob_task} memoryview error: {ex} Debug info: {length=} {buf_size=} {type(buf)=}") raise ex buf_size += length if blob_task.size and blob_task.size != buf_size: - log.warning( - f"Stream {blob_task.future.get_stream_id()} size doesn't match: " f"{blob_task.size} <> {buf_size}" - ) + log.warning(f"Stream {blob_task} Size doesn't match: " f"{blob_task.size} <> {buf_size}") if blob_task.pre_allocated: result = blob_task.buffer @@ -138,7 +140,7 @@ def _read_stream(blob_task: BlobTask): blob_task.future.set_result(result) except Exception as ex: - log.error(f"Stream {blob_task.future.get_stream_id()} read error: {ex}") + log.error(f"Stream {blob_task} Read error: {ex}") log.error(secure_format_traceback()) blob_task.future.set_exception(ex) diff --git a/nvflare/fuel/f3/streaming/byte_receiver.py b/nvflare/fuel/f3/streaming/byte_receiver.py index a9309bc030..f2aff8d0aa 100644 --- a/nvflare/fuel/f3/streaming/byte_receiver.py +++ b/nvflare/fuel/f3/streaming/byte_receiver.py @@ -71,7 +71,7 @@ def __init__(self, sid: int, origin: str): self.last_chunk_received = False def __str__(self): - return f"Rx[SID:{self.sid} from {self.origin} for {self.channel}/{self.topic}]" + return f"Rx[SID:{self.sid} from {self.origin} for {self.channel}/{self.topic} Size: {self.size}]" class RxStream(Stream): @@ -98,9 +98,7 @@ def read(self, chunk_size: int) -> bytes: # Block if buffers are empty if count > 0: - log.warning(f"Read block is unblocked multiple times: {count}") - - self.task.waiter.clear() + log.warning(f"{self.task} Read block is unblocked multiple times: {count}") if not self.task.waiter.wait(self.timeout): error = StreamError(f"{self.task} read timed out after {self.timeout} seconds") @@ -117,6 +115,7 @@ def _read_chunk(self, chunk_size: int) -> Tuple[int, Optional[BytesAlike]]: if self.task.eos: return RESULT_EOS, None else: + self.task.waiter.clear() return RESULT_WAIT, None last_chunk, buf = self.task.buffers.popleft() @@ -239,33 +238,39 @@ def _data_handler(self, message: Message): self.stop_task(task, StreamError(f"Received error from {origin}: {error}"), notify=False) return - if seq == 0: - # Handle new stream - task.channel = message.get_header(StreamHeaderKey.CHANNEL) - task.topic = message.get_header(StreamHeaderKey.TOPIC) - task.headers = message.headers + with task.task_lock: + if seq == 0: + # Handle new stream + task.channel = message.get_header(StreamHeaderKey.CHANNEL) + task.topic = message.get_header(StreamHeaderKey.TOPIC) + task.headers = message.headers + + # GRPC may re-send the same request, causing seq 0 delivered more than once + if task.stream_future: + log.warning(f"{task} Received duplicate chunk 0, ignored") + return - task.stream_future = StreamFuture(sid, message.headers) - task.size = message.get_header(StreamHeaderKey.SIZE, 0) - task.stream_future.set_size(task.size) + task.stream_future = StreamFuture(sid, message.headers) + task.size = message.get_header(StreamHeaderKey.SIZE, 0) + task.stream_future.set_size(task.size) - # Invoke callback - callback = self.registry.find(task.channel, task.topic) - if not callback: - self.stop_task(task, StreamError(f"No callback is registered for {task.channel}/{task.topic}")) - return + # Invoke callback + callback = self.registry.find(task.channel, task.topic) + if not callback: + self.stop_task(task, StreamError(f"No callback is registered for {task.channel}/{task.topic}")) + return - self.received_stream_counter_pool.increment( - category=stream_stats_category(task.channel, task.topic, "stream"), counter_name=COUNTER_NAME_RECEIVED - ) + self.received_stream_counter_pool.increment( + category=stream_stats_category(task.channel, task.topic, "stream"), + counter_name=COUNTER_NAME_RECEIVED, + ) - self.received_stream_size_pool.record_value( - category=stream_stats_category(task.channel, task.topic, "stream"), value=task.size / ONE_MB - ) + self.received_stream_size_pool.record_value( + category=stream_stats_category(task.channel, task.topic, "stream"), value=task.size / ONE_MB + ) - stream_thread_pool.submit(self._callback_wrapper, task, callback) + stream_thread_pool.submit(self._callback_wrapper, task, callback) - with task.task_lock: data_type = message.get_header(StreamHeaderKey.DATA_TYPE) last_chunk = data_type == StreamDataType.FINAL if last_chunk: