Skip to content

Commit 81b763f

Browse files
authored
Adding back a file that copybara deleted unintentionally. (#103)
1 parent 77ba076 commit 81b763f

File tree

1 file changed

+249
-0
lines changed

1 file changed

+249
-0
lines changed

pathwaysutils/sidecar/README.md

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
# Colocated Python
2+
3+
## Purpose
4+
5+
This package provides the Colocated Python Sidecar implementation. It describes how to build the sidecar container image with custom Python dependencies. This sidecar container runs on the TPU workers and facilitates remote Python code execution, enabling seamless integration between the user code and Python-based tasks on the TPU worker.
6+
7+
**Why use Colocated Python?**
8+
9+
Colocated Python enables users to execute code that runs explicitly on a specified set of TPU VMs using simple annotations and sharding information. This increases throughput on data or I/O intensive tasks like data loading (as implemented in [MaxText's RemoteIterator class](https://github.com/AI-Hypercomputer/maxtext/blob/391a5a788d85cae8942334b042fdabdbd549af51/MaxText/multihost_dataloading.py#L175)).
10+
11+
## Examples
12+
13+
### Simple (No User Dependencies)
14+
15+
The following small example is modified from [JAX](https://github.com/jax-ml/jax/blob/f4c727abb3989048f49e3d9a4bf2e4052969974b/tests/colocated_python_test.py#L78-L89) with no additional user dependencies installed. It shows how you can use the JAX Colocated Python API to create a file on the specified TPU worker.
16+
17+
```python
18+
import jax
19+
from jax.experimental import colocated_python
20+
from jax.experimental.colocated_python import serialization
21+
22+
@colocated_python.colocated_python
23+
def create_a_file(dummy):
24+
"""
25+
Creates a simple file on the TPU worker.
26+
"""
27+
filename = "my_new_file.txt"
28+
content_to_write = f"This is written on TPU worker {jax.process_id}"
29+
30+
try:
31+
with open(filename, 'w', encoding='utf-8') as file:
32+
file.write(content_to_write)
33+
print(f"Content written to '{filename}'.")
34+
35+
print(f"File '{filename}' created and closed.")
36+
except IOError as e:
37+
print(f"An error occurred: {e}")
38+
39+
return dummy
40+
41+
devices = jax.devices()
42+
dummy_array = np.array(1)
43+
dummy_array = jax.device_put(dummy_array, devices[0])
44+
45+
out = create_a_file(dummy_array)
46+
```
47+
48+
### Medium (With User Dependencies)
49+
50+
What if you want to add your own dependencies to do more advanced logic?
51+
52+
The following is a simple line chart of the first 5 primes in matplotlib that is saved locally to the TPU worker.
53+
54+
```python
55+
import jax
56+
import numpy as np
57+
from jax.experimental import colocated_python
58+
from jax.experimental.colocated_python import serialization
59+
60+
# User added dependency
61+
import matplotlib.pyplot as plt
62+
63+
@colocated_python.colocated_python
64+
def create_and_save_primes_plot(dummy):
65+
"""
66+
Creates a simple matplotlib line plot and saves it as a PNG image
67+
on the TPU worker.
68+
"""
69+
worker_id = jax.process_id()
70+
plot_filename = f"simple_line_plot_worker_{worker_id}.png"
71+
72+
# Sample data for the plot
73+
x_data = np.array([1, 2, 3, 4, 5])
74+
y_data = np.array([2, 3, 5, 7, 11])
75+
76+
try:
77+
# Create the line plot
78+
plt.figure(figsize=(6, 4))
79+
plt.plot(x_data, y_data, marker='o', linestyle='-')
80+
81+
# Add labels and title
82+
plt.xlabel("Nth Prime")
83+
plt.ylabel("Primes")
84+
plt.title(f"Simple Plot from TPU Worker {worker_id}")
85+
plt.grid(True)
86+
87+
# Save the plot to the specified file
88+
plt.savefig(plot_filename)
89+
print(f"Plot successfully saved to '{plot_filename}' on worker {worker_id}.")
90+
91+
plt.close()
92+
except Exception as e:
93+
print(f"An error occurred on worker {worker_id} while creating/saving the plot: {e}")
94+
95+
return dummy
96+
97+
devices = jax.devices()
98+
dummy_np_array = np.array(1, dtype=np.float32)
99+
dummy_device_array = jax.device_put(dummy_np_array, devices[0])
100+
out = create_and_save_plot(dummy_device_array)
101+
```
102+
103+
### Advanced Usage (With User Dependencies and Control Flow Logic)
104+
105+
For more advanced usage (such as data loading), check out [MaxText's RemoteIterator class](https://github.com/AI-Hypercomputer/maxtext/blob/391a5a788d85cae8942334b042fdabdbd549af51/MaxText/multihost_dataloading.py#L175).
106+
107+
See Installation and Usage for instructions on how to use MaxText out of the box with this feature.
108+
109+
### Verification
110+
111+
To verify files were created, SSH into one of the TPU workers using the following command and check that the file was created.
112+
113+
`kubectl exec -it <pod_name> -- /bin/sh -c "cat my_new_file.txt"`
114+
115+
Logs can also be verified by tailing the pod.
116+
117+
`kubectl logs -f <pod_name>`
118+
119+
## Installation and Usage
120+
121+
Follow these steps to set up, build, and deploy your application with the Colocated Python sidecar.
122+
123+
**Prerequisites**
124+
125+
Ensure [Docker](https://docs.docker.com/engine/install/) is installed on your system along with [gcloud](https://cloud.google.com/sdk/docs/install). Ensure you are authenticated into gcloud and Docker is configured for your region. For Google Artifact Registry, you typically run a command like this (replace `REGION` with the region of your repository, e.g., `us-east5`):
126+
127+
```bash
128+
gcloud auth login
129+
gcloud auth configure-docker REGION-docker.pkg.dev
130+
```
131+
132+
**1. Clone the Repository**
133+
134+
Get the necessary code and scripts.
135+
136+
```bash
137+
git clone https://github.com/AI-Hypercomputer/pathways-utils.git
138+
cd pathways-utils/sidecar/python
139+
```
140+
141+
**2. Prepare Sidecar Dependencies**
142+
143+
Update the file named `requirements.txt`. List all the additional Python packages you need specifically for the sidecar environment, one package per line.
144+
145+
These dependencies may be the same as your main workload's dependencies.
146+
147+
```
148+
# Example requirements.txt
149+
jax>=0.5.1
150+
tensorflow-datasets
151+
tiktoken
152+
grain-nightly>=0.0.10
153+
```
154+
155+
**3. Build the Colocated Python Sidecar Image and upload it to Artifact Registry**
156+
157+
Use the provided Dockerfile to create the sidecar image. This image will contain the required dependencies in your `requirements.txt`. Also specify the image location to upload to in Artifact Registry
158+
159+
```bash
160+
export PROJECT_ID=<your_project_id>
161+
export LOCAL_IMAGE_NAME=my-colocated-python-server
162+
export JAX_VERSION=0.5.3
163+
164+
docker build --build-arg JAX_VERSION=${JAX_VERSION} -t ${LOCAL_IMAGE_NAME} .
165+
```
166+
167+
Now you can upload the image to Google Artifact Registry. If you do not have an Artifact Registry repository, please follow the instructions [here](https://cloud.google.com/artifact-registry/docs/repositories/create-repos) to create one.
168+
169+
```bash
170+
export REGION=us # Your Region
171+
export ARTIFACT_REGISTRY_REPO=YOUR_ARTIFACT_REGISTRY_REPO
172+
export EXPORTED_IMAGE_LOCATION=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY_REPO}/my-colocated-python:latest
173+
174+
docker tag ${LOCAL_IMAGE_NAME} ${EXPORTED_IMAGE_LOCATION}
175+
docker push ${EXPORTED_IMAGE_LOCATION}
176+
177+
# Delete the local image as it's no longer needed.
178+
docker image rm ${LOCAL_IMAGE_NAME}
179+
```
180+
181+
**4. Update Deployment Configuration**
182+
183+
***Simple Example***
184+
185+
Modify your Kubernetes deployment YAML file to use your colocated python sidecar image. This assumes you are using the [pathways-job](https://github.com/google/pathways-job) api.
186+
187+
For example, if using 2 v4-16 TPUs, use the following yaml. This example is modified from [pathways-job](https://github.com/google/pathways-job/blob/main/config/samples/colocated_python_example_pathwaysjob.yaml).
188+
189+
If you do not have an existing GCS Bucket, instructions to create one are [here](https://cloud.google.com/storage/docs/creating-buckets).
190+
191+
```yaml
192+
apiVersion: pathways-job.pathways.domain/v1
193+
kind: PathwaysJob
194+
metadata:
195+
name: pathways-colocated
196+
spec:
197+
maxRestarts: 0
198+
customComponents:
199+
- componentType: colocated_python_sidecar
200+
image: <location of your colocated python sidecar server image>
201+
workers:
202+
- type: ct4p-hightpu-4t
203+
topology: 2x2x2
204+
numSlices: 2
205+
pathwaysDir: "gs://<test-bucket>/tmp" # This bucket needs to be created in advance.
206+
controller:
207+
# Pod template for training, default mode.
208+
deploymentMode: default
209+
mainContainerName: main
210+
template: # UserPodTemplate
211+
spec:
212+
containers:
213+
- name: main
214+
image: python:3.12
215+
imagePullPolicy: Always
216+
command:
217+
- /bin/sh
218+
- -c
219+
- |
220+
pip install --upgrade pip
221+
pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
222+
pip install pathwaysutils
223+
python -c "import jax; import pathwaysutils; print(\"Number of JAX devices is\", len(jax.devices()))"
224+
```
225+
226+
***MaxText Reference Example***
227+
228+
If using MaxText, to turn on the data loading optimization that uses Colocated Python feature.
229+
230+
```python
231+
colocated_python_data_input=True
232+
```
233+
234+
**5. Deploy the Application**
235+
236+
Apply the updated deployment configuration to your Kubernetes cluster:
237+
238+
```bash
239+
kubectl apply -f path/to/your/deployment.yaml
240+
```
241+
242+
This will create the necessary pods with your application, pathways head, and the Colocated Python sidecar containers.
243+
244+
## The Sharp Bits 🔪
245+
246+
**User Dependency Conflicts**
247+
248+
Colocated Python relies on specific internal dependencies, including JAX. Refer to the provided `server_requirements.txt` for the complete list of required dependencies. Using a different dependency version than the one provided in `server_requirements.txt` will cause the Colocated Python image build to fail.
249+

0 commit comments

Comments
 (0)