Skip to content

Commit

Permalink
fix logger in http monitor
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Oct 1, 2024
1 parent c36c631 commit 2fcb1ba
Showing 1 changed file with 20 additions and 27 deletions.
47 changes: 20 additions & 27 deletions src/zeroband/utils/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import importlib
from zeroband.utils.logging import get_logger

logger = get_logger()


class Monitor(Protocol):
def __init__(self, project, config): ...
Expand All @@ -23,11 +21,13 @@ class HttpMonitor:

def __init__(self, config, *args, **kwargs):
self.data = []
self.batch_size = config['monitor']['batch_size'] or 10
self.base_url = config['monitor']['base_url']
self.auth_token = config['monitor']['auth_token']
self.batch_size = config["monitor"]["batch_size"] or 10
self.base_url = config["monitor"]["base_url"]
self.auth_token = config["monitor"]["auth_token"]

self._logger = get_logger()

self.run_id = config.get('run_id', None)
self.run_id = config.get("run_id", None)
if self.run_id is None:
raise ValueError("run_id must be set for HttpMonitor")

Expand All @@ -43,13 +43,11 @@ def _remove_duplicates(self):

def set_stage(self, stage: str):
import time

# add a new log entry with the stage name
self.data.append({
"stage": stage,
"time": time.time()
})
self._handle_send_batch(flush=True) # it's useful to have the most up-to-date stage broadcasted

self.data.append({"stage": stage, "time": time.time()})
self._handle_send_batch(flush=True) # it's useful to have the most up-to-date stage broadcasted

def log(self, data: dict[str, Any]):
# Lowercase the keys in the data dictionary
lowercased_data = {k.lower(): v for k, v in data.items()}
Expand All @@ -60,21 +58,17 @@ def log(self, data: dict[str, Any]):
def _handle_send_batch(self, flush: bool = False):
if len(self.data) >= self.batch_size or flush:
import asyncio

# do this in a separate thread to not affect training loop
asyncio.create_task(self._send_batch())

async def _send_batch(self):
import aiohttp

self._remove_duplicates()
batch = self.data[:self.batch_size]
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.auth_token}"
}
payload = {
"logs": batch
}
batch = self.data[: self.batch_size]
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.auth_token}"}
payload = {"logs": batch}
api = f"{self.base_url}/training_runs/{self.run_id}/logs"

try:
Expand All @@ -83,28 +77,27 @@ async def _send_batch(self):
if response is not None:
await response.raise_for_status()
else:
logger.error("Received None response from server")
self._logger.error("Received None response from server")
pass

except Exception as e:
logger.error(f"Error sending batch to server: {str(e)}")
self._logger.error(f"Error sending batch to server: {str(e)}")
pass

self.data = self.data[self.batch_size:]
self.data = self.data[self.batch_size :]
return True

def _finish(self):
import requests
headers = {
"Content-Type": "application/json"
}

headers = {"Content-Type": "application/json"}
api = f"{self.base_url}/training_runs/{self.run_id}/finish"
try:
response = requests.post(api, headers=headers)
response.raise_for_status()
return True
except requests.RequestException as e:
logger.debug(f"Failed to send finish signal to http monitor: {e}")
self._logger.debug(f"Failed to send finish signal to http monitor: {e}")
return False

def finish(self):
Expand Down

0 comments on commit 2fcb1ba

Please sign in to comment.