Skip to content

Commit 24c410c

Browse files
authored
Merge pull request exo-explore#653 from exo-explore/tinyfixes
Tiny fixes
2 parents f6ed830 + e6b4f29 commit 24c410c

File tree

5 files changed

+63
-58
lines changed

5 files changed

+63
-58
lines changed

exo/download/new_shard_download.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog
105105
elapsed_time = time.time() - all_start_time
106106
all_speed = all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
107107
all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
108-
status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress"
108+
status = "complete" if all(p.status == "complete" for p in file_progress.values()) else "in_progress" if any(p.status == "in_progress" for p in file_progress.values()) else "not_started"
109109
return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)
110110

111111
async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
@@ -147,12 +147,12 @@ def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
147147
downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
148148
speed = downloaded_this_session / (time.time() - start_time)
149149
eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
150-
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "in_progress", start_time)
150+
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time)
151151
on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
152152
if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
153153
for file in filtered_file_list:
154154
downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
155-
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time())
155+
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
156156

157157
semaphore = asyncio.Semaphore(max_parallel_downloads)
158158
async def download_with_semaphore(file):

exo/inference/tinygrad/inference.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,13 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
6161

6262
return model
6363

64+
_executor = ThreadPoolExecutor(max_workers=1) # singleton so tinygrad always runs on the same thread
6465
class TinygradDynamicShardInferenceEngine(InferenceEngine):
6566
def __init__(self, shard_downloader: ShardDownloader):
6667
self.shard = None
6768
self.shard_downloader = shard_downloader
68-
self.executor = ThreadPoolExecutor(max_workers=1)
6969
self.states = OrderedDict()
70+
self.executor = _executor
7071

7172
def poll_state(self, x, request_id: str, max_states=2):
7273
if request_id not in self.states:
@@ -79,8 +80,8 @@ def poll_state(self, x, request_id: str, max_states=2):
7980
return {"start_pos": state.start, "cache": state.cache}
8081

8182
async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
82-
logits = x[:, -1, :]
8383
def sample_wrapper():
84+
logits = x[:, -1, :]
8485
return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
8586
return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
8687

@@ -112,9 +113,9 @@ def wrap_infer():
112113
state = self.poll_state(h, request_id)
113114
out = self.model.forward(h, **state)
114115
self.states[request_id].start += x.shape[1]
115-
return out.realize()
116+
return out.numpy()
116117
output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
117-
return output_data.numpy(), inference_state
118+
return output_data, inference_state
118119

119120
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
120121
def step(x, y, l):

exo/main.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -206,14 +206,16 @@ def preemptively_load_shard(request_id: str, opaque_status: str):
206206
traceback.print_exc()
207207
node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)
208208

209-
last_broadcast_time = 0
209+
last_events: dict[str, tuple[float, RepoProgressEvent]] = {}
210210
def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
211-
global last_broadcast_time
211+
global last_events
212212
current_time = time.time()
213213
if event.status == "not_started": return
214-
if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
215-
last_broadcast_time = current_time
216-
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
214+
last_event = last_events.get(shard.model_id)
215+
if last_event and last_event[1].status == "complete" and event.status == "complete": return
216+
if last_event and last_event[0] == event.status and current_time - last_event[0] < 0.2: return
217+
last_events[shard.model_id] = (current_time, event)
218+
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
217219
shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
218220

219221
async def run_model_cli(node: Node, model_name: str, prompt: str):

exo/viz/topology_viz.py

+47-45
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,16 @@ def _generate_prompt_output_layout(self) -> Panel:
8989
# Calculate available height for content
9090
panel_height = 15 # Fixed panel height
9191
available_lines = panel_height - 2 # Subtract 2 for panel borders
92-
lines_per_entry = available_lines // len(requests) if requests else 0
92+
lines_per_request = available_lines // len(requests) if requests else 0
9393

9494
for (prompt, output) in reversed(requests):
9595
prompt_icon, output_icon = "💬️", "🤖"
9696

97-
# Calculate max lines for prompt and output
98-
max_prompt_lines = max(3, lines_per_entry // 2) # Ensure at least 3 lines for prompt
99-
max_output_lines = lines_per_entry - max_prompt_lines - 1 # Remaining space minus spacing
97+
# Equal space allocation for prompt and output
98+
max_prompt_lines = lines_per_request // 2
99+
max_output_lines = lines_per_request - max_prompt_lines - 1 # -1 for spacing
100100

101-
# Process prompt with more generous line allocation
101+
# Process prompt
102102
prompt_lines = []
103103
for line in prompt.split('\n'):
104104
words = line.split()
@@ -118,53 +118,55 @@ def _generate_prompt_output_layout(self) -> Panel:
118118
if current_line:
119119
prompt_lines.append(' '.join(current_line))
120120

121-
# Show more prompt content and append ellipses to last line if needed
121+
# Truncate prompt if needed
122122
if len(prompt_lines) > max_prompt_lines:
123123
prompt_lines = prompt_lines[:max_prompt_lines]
124-
# Append ellipses to last line if there's room, otherwise truncate last line
125-
last_line = prompt_lines[-1]
126-
if len(last_line) + 4 <= max_width: # +4 for " ..."
127-
prompt_lines[-1] = last_line + " ..."
128-
else:
129-
prompt_lines[-1] = last_line[:max_width-4] + " ..."
124+
if prompt_lines:
125+
last_line = prompt_lines[-1]
126+
if len(last_line) + 4 <= max_width:
127+
prompt_lines[-1] = last_line + " ..."
128+
else:
129+
prompt_lines[-1] = last_line[:max_width-4] + " ..."
130130

131131
prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
132132
prompt_text.append('\n'.join(prompt_lines), style="white")
133+
content.append(prompt_text)
133134

134-
# Process output - same word-aware wrapping
135-
output_lines = []
136-
for line in output.split('\n'):
137-
words = line.split()
138-
current_line = []
139-
current_length = 0
140-
141-
for word in words:
142-
if current_length + len(word) + 1 <= max_width:
143-
current_line.append(word)
144-
current_length += len(word) + 1
145-
else:
146-
if current_line:
147-
output_lines.append(' '.join(current_line))
148-
current_line = [word]
149-
current_length = len(word)
150-
151-
if current_line:
152-
output_lines.append(' '.join(current_line))
153-
154-
if len(output_lines) > max_output_lines:
155-
output_lines = output_lines[:max_output_lines]
156-
last_line = output_lines[-1] if output_lines else None
157-
if last_line:
158-
if len(last_line) + 4 <= max_width:
159-
output_lines[-1] = last_line + " ..."
160-
else:
161-
output_lines[-1] = last_line[:max_width-4] + " ..."
162-
163-
output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
164-
output_text.append('\n'.join(output_lines), style="white")
135+
# Process output with similar word wrapping
136+
if output: # Only process output if it exists
137+
output_lines = []
138+
for line in output.split('\n'):
139+
words = line.split()
140+
current_line = []
141+
current_length = 0
142+
143+
for word in words:
144+
if current_length + len(word) + 1 <= max_width:
145+
current_line.append(word)
146+
current_length += len(word) + 1
147+
else:
148+
if current_line:
149+
output_lines.append(' '.join(current_line))
150+
current_line = [word]
151+
current_length = len(word)
152+
153+
if current_line:
154+
output_lines.append(' '.join(current_line))
155+
156+
# Truncate output if needed
157+
if len(output_lines) > max_output_lines:
158+
output_lines = output_lines[:max_output_lines]
159+
if output_lines:
160+
last_line = output_lines[-1]
161+
if len(last_line) + 4 <= max_width:
162+
output_lines[-1] = last_line + " ..."
163+
else:
164+
output_lines[-1] = last_line[:max_width-4] + " ..."
165+
166+
output_text = Text(f"{output_icon} ", style="bold bright_magenta")
167+
output_text.append('\n'.join(output_lines), style="white")
168+
content.append(output_text)
165169

166-
content.append(prompt_text)
167-
content.append(output_text)
168170
content.append(Text()) # Empty line between entries
169171

170172
return Panel(

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"transformers==4.46.3",
3030
"uuid==1.30",
3131
"uvloop==0.21.0",
32-
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
32+
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@ec120ce6b9ce8e4ff4b5692566a683ef240e8bc8",
3333
]
3434

3535
extras_require = {

0 commit comments

Comments (0)