Skip to content

Commit b5e079c

Browse files
authored
Fix docker URL in xpk (#54) (#55)
I forgot to replace the hardcoded `gcr.io/tpu-pytorch/llama3:latest` image with the autogenerated Docker URL. Unfortunately, in order to return a URL from buildpush, we have to convert it from Bash to Python.
1 parent 6113639 commit b5e079c

File tree

3 files changed

+85
-38
lines changed

3 files changed

+85
-38
lines changed

torchprime/launcher/buildpush.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#!/usr/bin/env python3
2+
3+
import datetime
4+
import getpass
5+
import grp
6+
import os
7+
import random
8+
import string
9+
import subprocess
10+
from pathlib import Path
11+
12+
import click
13+
14+
15+
def buildpush(
    torchprime_project_id,
    torchprime_docker_url=None,
    torchprime_docker_tag=None,
) -> str:
    """Build the torchprime Docker image, push it, and return its URL.

    The build context is the directory two levels above this script, and the
    Dockerfile is the one located next to this script.

    Args:
        torchprime_project_id: Project ID used to form the default image URL
            ``gcr.io/<project>/torchprime-<user>:<tag>``.
        torchprime_docker_url: Full image URL to push to. When falsy, the
            default URL is used instead.
        torchprime_docker_tag: Name given to the locally built image, also
            used as the tag of the default URL. When falsy, a
            ``YYYYmmdd-HHMMSS-xxxx`` value (``xxxx`` being four random
            lowercase letters) is generated.

    Returns:
        The URL the image was pushed to.

    Raises:
        SystemExit: With the failing command's return code if any of the
            docker build/tag/push commands exits with a non-zero status.
    """
    import shlex  # Local import: only needed to quote shell arguments here.

    # Locate the Dockerfile and build context relative to this script.
    script_dir = Path(os.path.realpath(__file__)).parent
    cwd = os.getcwd()
    # os.path.relpath keeps the echoed commands short like Path.relative_to
    # did, but it also works when the current directory is not an ancestor of
    # the repository (relative_to raises ValueError in that case).
    context_dir = os.path.relpath(script_dir.parent.parent, cwd)
    docker_file = os.path.relpath(script_dir / "Dockerfile", cwd)

    # Use sudo unless the user is listed as a member of the 'docker' group.
    user = getpass.getuser()
    groups_for_user = [g.gr_name for g in grp.getgrall() if user in g.gr_mem]
    docker = "docker" if "docker" in groups_for_user else "sudo docker"

    # Default tag: date/time string plus 4 random lowercase letters.
    datetime_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    random_chars = "".join(random.choices(string.ascii_lowercase, k=4))
    docker_tag = torchprime_docker_tag or f"{datetime_str}-{random_chars}"

    # An explicitly supplied URL wins over the generated one.
    docker_url = (
        torchprime_docker_url
        or f"gcr.io/{torchprime_project_id}/torchprime-{user}:{docker_tag}"
    )

    print()
    print(f"Will build a docker image and upload to: {docker_url}")
    print()

    # The commands are run through a shell, so quote every interpolated value;
    # otherwise a path, tag, or URL containing whitespace or shell
    # metacharacters would split or alter the command.
    tag_arg = shlex.quote(docker_tag)
    url_arg = shlex.quote(docker_url)
    context_arg = shlex.quote(context_dir)
    docker_file_arg = shlex.quote(docker_file)

    # Build, tag, and push Docker image
    try:
        _run(
            f"{docker} build --network=host --progress=auto "
            f"-t {tag_arg} {context_arg} -f {docker_file_arg}",
        )
        _run(f"{docker} tag {tag_arg} {url_arg}")
        _run(f"{docker} push {url_arg}")
    except subprocess.CalledProcessError as e:
        print(f"Error running command: {e}")
        # Raise SystemExit directly: the exit() builtin is injected by the
        # site module and is not guaranteed to be available.
        raise SystemExit(e.returncode) from e

    return docker_url
61+
62+
63+
def _run(command):
    """Echo ``command`` and then execute it through the shell.

    Args:
        command: The shell command line to run, as a single string.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    # Show the exact command line before running it.
    click.echo(command)
    # check_call raises CalledProcessError on a non-zero exit status.
    subprocess.check_call(command, shell=True)
70+
71+
72+
if __name__ == "__main__":
    # Script entry point: all settings come from the environment. Only the
    # project ID has a fallback value; the URL and tag are passed as None
    # when their variables are unset.
    env = os.environ
    buildpush(
        env.get("TORCHPRIME_PROJECT_ID", "tpu-pytorch"),
        env.get("TORCHPRIME_DOCKER_URL"),
        env.get("TORCHPRIME_DOCKER_TAG"),
    )

torchprime/launcher/buildpush.sh

Lines changed: 0 additions & 35 deletions
This file was deleted.

torchprime/launcher/cli.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from watchdog.events import FileSystemEventHandler
2121
from watchdog.observers import Observer
2222

23+
from torchprime.launcher.buildpush import buildpush
24+
2325

2426
@dataclass_json
2527
@dataclass
@@ -186,8 +188,7 @@ def run(args):
186188
docker_project = config.docker_project
187189
if docker_project is None:
188190
docker_project = config.project
189-
os.environ["TORCHPRIME_PROJECT_ID"] = docker_project
190-
assert os.system(Path(__file__).parent / "buildpush.sh") == 0
191+
docker_url = buildpush(docker_project)
191192

192193
# Submit xpk workload
193194
datetime_str = datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -212,7 +213,7 @@ def run(args):
212213
"--cluster",
213214
config.cluster,
214215
"--docker-image",
215-
"gcr.io/tpu-pytorch/llama3:latest",
216+
docker_url,
216217
"--workload",
217218
f"{os.environ['USER']}-xpk-{config.tpu_type}-{config.num_slices}-{datetime_str}",
218219
"--tpu-type",

0 commit comments

Comments
 (0)