feat: leverage ten's log
wangyoucao577 authored Nov 16, 2024
1 parent a9655f2 commit 8dc90f1
Showing 5 changed files with 38 additions and 64 deletions.
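In short: this commit removes the extensions' standalone loggers (Go's log/slog in interrupt_detector, and the message_collector/src/log.py helper in Python) in favor of the log methods on the TEN environment handle: tenEnv.LogDebug/LogInfo/LogWarn in Go, and ten_env.log_debug/log_info/log_warn/log_error in Python. The Python diff also picks up some incidental PEP 8 reflowing (line-length wraps and blank-line spacing).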
15 changes: 4 additions & 11 deletions agents/ten_packages/extension/interrupt_detector/extension.go
@@ -12,7 +12,6 @@ package extension
 
 import (
 	"fmt"
-	"log/slog"
 
 	"ten_framework/ten"
 )
@@ -24,10 +23,6 @@ const (
 	cmdNameFlush = "flush"
 )
 
-var (
-	logTag = slog.String("extension", "INTERRUPT_DETECTOR_EXTENSION")
-)
-
 type interruptDetectorExtension struct {
 	ten.DefaultExtension
 }
@@ -47,29 +42,27 @@ func (p *interruptDetectorExtension) OnData(
 ) {
 	text, err := data.GetPropertyString(textDataTextField)
 	if err != nil {
-		slog.Warn(fmt.Sprintf("OnData GetProperty %s error: %v", textDataTextField, err), logTag)
+		tenEnv.LogWarn(fmt.Sprintf("OnData GetProperty %s error: %v", textDataTextField, err))
 		return
 	}
 
 	final, err := data.GetPropertyBool(textDataFinalField)
 	if err != nil {
-		slog.Warn(fmt.Sprintf("OnData GetProperty %s error: %v", textDataFinalField, err), logTag)
+		tenEnv.LogWarn(fmt.Sprintf("OnData GetProperty %s error: %v", textDataFinalField, err))
 		return
 	}
 
-	slog.Debug(fmt.Sprintf("OnData %s: %s %s: %t", textDataTextField, text, textDataFinalField, final), logTag)
+	tenEnv.LogDebug(fmt.Sprintf("OnData %s: %s %s: %t", textDataTextField, text, textDataFinalField, final))
 
 	if final || len(text) >= 2 {
 		flushCmd, _ := ten.NewCmd(cmdNameFlush)
 		tenEnv.SendCmd(flushCmd, nil)
 
-		slog.Info(fmt.Sprintf("sent cmd: %s", cmdNameFlush), logTag)
+		tenEnv.LogInfo(fmt.Sprintf("sent cmd: %s", cmdNameFlush))
 	}
 }
 
 func init() {
-	slog.Info("interrupt_detector extension init", logTag)
-
 	// Register addon
 	ten.RegisterAddonAsExtension(
 		"interrupt_detector",
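Net effect in this file: the log/slog import, the package-level logTag attribute, and the init-time log line are all gone. The tenEnv handle passed into OnData already carries the extension context, so each call site shrinks to a single tenEnv.Log* call.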
3 changes: 0 additions & 3 deletions agents/ten_packages/extension/message_collector/__init__.py
@@ -6,6 +6,3 @@
 #
 #
 from .src import addon
-from .src.log import logger
-
-logger.info("message_collector extension loaded")
6 changes: 3 additions & 3 deletions agents/ten_packages/extension/message_collector/src/addon.py
@@ -17,6 +17,6 @@ class MessageCollectorExtensionAddon(Addon):
 
     def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None:
         from .extension import MessageCollectorExtension
-        from .log import logger
-        logger.info("MessageCollectorExtensionAddon on_create_instance")
-        ten_env.on_create_instance_done(MessageCollectorExtension(name), context)
+        ten_env.log_info("on_create_instance")
+        ten_env.on_create_instance_done(
+            MessageCollectorExtension(name), context)
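Note that ten_env.log_info is already usable inside on_create_instance, so the addon no longer needs to import its own logger just to trace instance creation.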
56 changes: 31 additions & 25 deletions agents/ten_packages/extension/message_collector/src/extension.py
@@ -21,7 +21,6 @@
     Data,
 )
 import asyncio
-from .log import logger
 
 MAX_SIZE = 800  # 1 KB limit
 OVERHEAD_ESTIMATE = 200  # Estimate for the overhead of metadata in the JSON
@@ -37,38 +36,40 @@
 cached_text_map = {}
 MAX_CHUNK_SIZE_BYTES = 1024
 
+
 def _text_to_base64_chunks(text: str, msg_id: str) -> list:
     # Ensure msg_id does not exceed 36 characters
     if len(msg_id) > 36:
         raise ValueError("msg_id cannot exceed 36 characters.")
 
     # Convert text to bytearray
     byte_array = bytearray(text, 'utf-8')
 
     # Encode the bytearray into base64
     base64_encoded = base64.b64encode(byte_array).decode('utf-8')
 
     # Initialize list to hold the final chunks
     chunks = []
 
     # We'll split the base64 string dynamically based on the final byte size
     part_index = 0
     total_parts = None  # We'll calculate total parts once we know how many chunks we create
 
     # Process the base64-encoded content in chunks
     current_position = 0
     total_length = len(base64_encoded)
 
     while current_position < total_length:
         part_index += 1
 
         # Start guessing the chunk size by limiting the base64 content part
         estimated_chunk_size = MAX_CHUNK_SIZE_BYTES  # We'll reduce this dynamically
         content_chunk = ""
         count = 0
         while True:
             # Create the content part of the chunk
-            content_chunk = base64_encoded[current_position:current_position + estimated_chunk_size]
+            content_chunk = base64_encoded[current_position:
+                                           current_position + estimated_chunk_size]
 
             # Format the chunk
             formatted_chunk = f"{msg_id}|{part_index}|{total_parts if total_parts else '???'}|{content_chunk}"
@@ -81,11 +82,12 @@ def _text_to_base64_chunks(text: str, msg_id: str) -> list:
             estimated_chunk_size -= 100  # Reduce content size gradually
             count += 1
 
-        logger.debug(f"chunk estimate guess: {count}")
+        # logger.debug(f"chunk estimate guess: {count}")
 
         # Add the current chunk to the list
         chunks.append(formatted_chunk)
-        current_position += estimated_chunk_size  # Move to the next part of the content
+        # Move to the next part of the content
+        current_position += estimated_chunk_size
 
     # Now that we know the total number of parts, update the chunks with correct total_parts
     total_parts = len(chunks)
@@ -95,19 +97,21 @@ def _text_to_base64_chunks(text: str, msg_id: str) -> list:
 
     return updated_chunks
 
+
 class MessageCollectorExtension(Extension):
     # Create the queue for message processing
     queue = asyncio.Queue()
 
     def on_init(self, ten_env: TenEnv) -> None:
-        logger.info("MessageCollectorExtension on_init")
+        ten_env.log_info("on_init")
         ten_env.on_init_done()
 
     def on_start(self, ten_env: TenEnv) -> None:
-        logger.info("MessageCollectorExtension on_start")
+        ten_env.log_info("on_start")
 
         # TODO: read properties, initialize resources
         self.loop = asyncio.new_event_loop()
+
         def start_loop():
             asyncio.set_event_loop(self.loop)
             self.loop.run_forever()
@@ -118,19 +122,19 @@ def start_loop():
         ten_env.on_start_done()
 
     def on_stop(self, ten_env: TenEnv) -> None:
-        logger.info("MessageCollectorExtension on_stop")
+        ten_env.log_info("on_stop")
 
         # TODO: clean up resources
 
         ten_env.on_stop_done()
 
     def on_deinit(self, ten_env: TenEnv) -> None:
-        logger.info("MessageCollectorExtension on_deinit")
+        ten_env.log_info("on_deinit")
         ten_env.on_deinit_done()
 
     def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
         cmd_name = cmd.get_name()
-        logger.info("on_cmd name {}".format(cmd_name))
+        ten_env.log_info("on_cmd name {}".format(cmd_name))
 
         # TODO: process cmd
 
@@ -145,7 +149,7 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None:
         example:
         {"name": "text_data", "properties": {"text": "hello", "is_final": true, "stream_id": 123, "end_of_segment": true}}
         """
-        logger.debug(f"on_data")
+        # ten_env.log_debug(f"on_data")
 
         text = ""
         final = True
@@ -155,7 +159,7 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None:
         try:
             text = data.get_property_string(TEXT_DATA_TEXT_FIELD)
         except Exception as e:
-            logger.exception(
+            ten_env.log_error(
                 f"on_data get_property_string {TEXT_DATA_TEXT_FIELD} error: {e}"
             )
 
@@ -170,13 +174,14 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None:
             pass
 
         try:
-            end_of_segment = data.get_property_bool(TEXT_DATA_END_OF_SEGMENT_FIELD)
+            end_of_segment = data.get_property_bool(
+                TEXT_DATA_END_OF_SEGMENT_FIELD)
         except Exception as e:
-            logger.warning(
+            ten_env.log_warn(
                 f"on_data get_property_bool {TEXT_DATA_END_OF_SEGMENT_FIELD} error: {e}"
             )
 
-        logger.debug(
+        ten_env.log_info(
             f"on_data {TEXT_DATA_TEXT_FIELD}: {text} {TEXT_DATA_FINAL_FIELD}: {final} {TEXT_DATA_STREAM_ID_FIELD}: {stream_id} {TEXT_DATA_END_OF_SEGMENT_FIELD}: {end_of_segment}"
         )
 
@@ -207,12 +212,14 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None:
         }
 
         try:
-            chunks = _text_to_base64_chunks(json.dumps(base_msg_data), message_id)
+            chunks = _text_to_base64_chunks(
+                json.dumps(base_msg_data), message_id)
             for chunk in chunks:
-                asyncio.run_coroutine_threadsafe(self._queue_message(chunk), self.loop)
+                asyncio.run_coroutine_threadsafe(
+                    self._queue_message(chunk), self.loop)
 
         except Exception as e:
-            logger.warning(f"on_data new_data error: {e}")
+            ten_env.log_warn(f"on_data new_data error: {e}")
         return
 
     def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
@@ -223,7 +230,6 @@ def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None:
         # TODO: process image frame
         pass
 
-
     async def _queue_message(self, data: str):
         await self.queue.put(data)
 
@@ -237,4 +243,4 @@ async def _process_queue(self, ten_env: TenEnv):
             ten_data.set_property_buf("data", data.encode())
             ten_env.send_data(ten_data)
             self.queue.task_done()
-            await asyncio.sleep(0.04)
\ No newline at end of file
+            await asyncio.sleep(0.04)
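The wire format produced by _text_to_base64_chunks above is msg_id|part_index|total_parts|base64_payload, with part_index starting at 1 and total_parts backfilled once the chunk count is known. A minimal receiver-side sketch of reassembly follows; reassemble_chunks is a hypothetical helper, not part of this commit or of the message_collector package.

import base64
import json


def reassemble_chunks(chunks: list) -> dict:
    """Rebuild the original JSON message from msg_id|part|total|payload chunks."""
    parts = {}
    total_parts = 0
    for chunk in chunks:
        # Base64 text never contains '|', so splitting on the first
        # three '|' characters cleanly separates header and payload.
        _msg_id, index, total, payload = chunk.split("|", 3)
        parts[int(index)] = payload
        total_parts = int(total)
    if len(parts) != total_parts:
        raise ValueError("incomplete message: missing chunks")
    # part_index is 1-based in the producer above.
    b64 = "".join(parts[i] for i in range(1, total_parts + 1))
    return json.loads(base64.b64decode(b64))

Because ordering is recovered from part_index, chunks may arrive in any order; a real client would also group incoming chunks by msg_id before reassembling.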
22 changes: 0 additions & 22 deletions agents/ten_packages/extension/message_collector/src/log.py

This file was deleted.
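The deleted src/log.py defined the module-level logger that the removed imports above pulled in; with every call site migrated to ten_env logging, nothing references it anymore.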
