|
| 1 | +# Colocated Python |
| 2 | + |
| 3 | +## Purpose |
| 4 | + |
| 5 | +This package provides the Colocated Python Sidecar implementation. It describes how to build the sidecar container image with custom Python dependencies. This sidecar container runs on the TPU workers and facilitates remote Python code execution, enabling seamless integration between the user code and Python-based tasks on the TPU worker. |
| 6 | + |
| 7 | +**Why use Colocated Python?** |
| 8 | + |
| 9 | +Colocated Python enables users to execute code that runs explicitly on a specified set of TPU VMs using simple annotations and sharding information. This increases throughput on data or I/O intensive tasks like data loading (as implemented in [MaxText's RemoteIterator class](https://github.com/AI-Hypercomputer/maxtext/blob/391a5a788d85cae8942334b042fdabdbd549af51/MaxText/multihost_dataloading.py#L175)). |
| 10 | + |
| 11 | +## Examples |
| 12 | + |
| 13 | +### Simple (No User Dependencies) |
| 14 | + |
| 15 | +The following small example is modified from [JAX](https://github.com/jax-ml/jax/blob/f4c727abb3989048f49e3d9a4bf2e4052969974b/tests/colocated_python_test.py#L78-L89) with no additional user dependencies installed. It shows how you can use the JAX Colocated Python API to create a file on the specified TPU worker. |
| 16 | + |
| 17 | +```python |
| 18 | +import jax |
| 19 | +from jax.experimental import colocated_python |
| 20 | +from jax.experimental.colocated_python import serialization |
| 21 | + |
| 22 | +@colocated_python.colocated_python |
| 23 | +def create_a_file(dummy): |
| 24 | + """ |
| 25 | + Creates a simple file on the TPU worker. |
| 26 | + """ |
| 27 | + filename = "my_new_file.txt" |
| 28 | + content_to_write = f"This is written on TPU worker {jax.process_id}" |
| 29 | + |
| 30 | + try: |
| 31 | + with open(filename, 'w', encoding='utf-8') as file: |
| 32 | + file.write(content_to_write) |
| 33 | + print(f"Content written to '{filename}'.") |
| 34 | + |
| 35 | + print(f"File '{filename}' created and closed.") |
| 36 | + except IOError as e: |
| 37 | + print(f"An error occurred: {e}") |
| 38 | + |
| 39 | + return dummy |
| 40 | + |
| 41 | +devices = jax.devices() |
| 42 | +dummy_array = np.array(1) |
| 43 | +dummy_array = jax.device_put(dummy_array, devices[0]) |
| 44 | + |
| 45 | +out = create_a_file(dummy_array) |
| 46 | +``` |
| 47 | + |
| 48 | +### Medium (With User Dependencies) |
| 49 | + |
| 50 | +What if you want to add your own dependencies to do more advanced logic? |
| 51 | + |
| 52 | +The following is a simple line chart of the first 5 primes in matplotlib that is saved locally to the TPU worker. |
| 53 | + |
| 54 | +```python |
| 55 | +import jax |
| 56 | +import numpy as np |
| 57 | +from jax.experimental import colocated_python |
| 58 | +from jax.experimental.colocated_python import serialization |
| 59 | + |
| 60 | +# User added dependency |
| 61 | +import matplotlib.pyplot as plt |
| 62 | + |
| 63 | +@colocated_python.colocated_python |
| 64 | +def create_and_save_primes_plot(dummy): |
| 65 | + """ |
| 66 | + Creates a simple matplotlib line plot and saves it as a PNG image |
| 67 | + on the TPU worker. |
| 68 | + """ |
| 69 | + worker_id = jax.process_id() |
| 70 | + plot_filename = f"simple_line_plot_worker_{worker_id}.png" |
| 71 | + |
| 72 | + # Sample data for the plot |
| 73 | + x_data = np.array([1, 2, 3, 4, 5]) |
| 74 | + y_data = np.array([2, 3, 5, 7, 11]) |
| 75 | + |
| 76 | + try: |
| 77 | + # Create the line plot |
| 78 | + plt.figure(figsize=(6, 4)) |
| 79 | + plt.plot(x_data, y_data, marker='o', linestyle='-') |
| 80 | + |
| 81 | + # Add labels and title |
| 82 | + plt.xlabel("Nth Prime") |
| 83 | + plt.ylabel("Primes") |
| 84 | + plt.title(f"Simple Plot from TPU Worker {worker_id}") |
| 85 | + plt.grid(True) |
| 86 | + |
| 87 | + # Save the plot to the specified file |
| 88 | + plt.savefig(plot_filename) |
| 89 | + print(f"Plot successfully saved to '{plot_filename}' on worker {worker_id}.") |
| 90 | + |
| 91 | + plt.close() |
| 92 | + except Exception as e: |
| 93 | + print(f"An error occurred on worker {worker_id} while creating/saving the plot: {e}") |
| 94 | + |
| 95 | + return dummy |
| 96 | + |
| 97 | +devices = jax.devices() |
| 98 | +dummy_np_array = np.array(1, dtype=np.float32) |
| 99 | +dummy_device_array = jax.device_put(dummy_np_array, devices[0]) |
| 100 | +out = create_and_save_plot(dummy_device_array) |
| 101 | +``` |
| 102 | + |
| 103 | +### Advanced Usage (With User Dependencies and Control Flow Logic) |
| 104 | + |
| 105 | +For more advanced usage (such as data loading), check out [MaxText's RemoteIterator class](https://github.com/AI-Hypercomputer/maxtext/blob/391a5a788d85cae8942334b042fdabdbd549af51/MaxText/multihost_dataloading.py#L175). |
| 106 | + |
| 107 | +See Installation and Usage for instructions on how to use MaxText out of the box with this feature. |
| 108 | + |
| 109 | +### Verification |
| 110 | + |
| 111 | +To verify files were created, SSH into one of the TPU workers using the following command and check that the file was created. |
| 112 | + |
| 113 | +`kubectl exec -it <pod_name> -- /bin/sh -c "cat my_new_file.txt"` |
| 114 | + |
| 115 | +Logs can also be verified by tailing the pod. |
| 116 | + |
| 117 | +`kubectl logs -f <pod_name>` |
| 118 | + |
| 119 | +## Installation and Usage |
| 120 | + |
| 121 | +Follow these steps to set up, build, and deploy your application with the Colocated Python sidecar. |
| 122 | + |
| 123 | +**Prerequisites** |
| 124 | + |
| 125 | +Ensure [Docker](https://docs.docker.com/engine/install/) is installed on your system along with [gcloud](https://cloud.google.com/sdk/docs/install). Ensure you are authenticated into gcloud and Docker is configured for your region. For Google Artifact Registry, you typically run a command like this (replace `REGION` with the region of your repository, e.g., `us-east5`): |
| 126 | + |
| 127 | +```bash |
| 128 | +gcloud auth login |
| 129 | +gcloud auth configure-docker REGION-docker.pkg.dev |
| 130 | +``` |
| 131 | + |
| 132 | +**1. Clone the Repository** |
| 133 | + |
| 134 | +Get the necessary code and scripts. |
| 135 | + |
| 136 | +```bash |
| 137 | +git clone https://github.com/AI-Hypercomputer/pathways-utils.git |
| 138 | +cd pathways-utils/sidecar/python |
| 139 | +``` |
| 140 | + |
| 141 | +**2. Prepare Sidecar Dependencies** |
| 142 | + |
| 143 | +Update the file named `requirements.txt`. List all the additional Python packages you need specifically for the sidecar environment, one package per line. |
| 144 | + |
| 145 | +These dependencies may be the same as your main workload's dependencies. |
| 146 | + |
| 147 | +``` |
| 148 | +# Example requirements.txt |
| 149 | +jax>=0.5.1 |
| 150 | +tensorflow-datasets |
| 151 | +tiktoken |
| 152 | +grain-nightly>=0.0.10 |
| 153 | +``` |
| 154 | + |
| 155 | +**3. Build the Colocated Python Sidecar Image and upload it to Artifact Registry** |
| 156 | + |
| 157 | +Use the provided Dockerfile to create the sidecar image. This image will contain the required dependencies in your `requirements.txt`. Also specify the image location to upload to in Artifact Registry |
| 158 | + |
| 159 | +```bash |
| 160 | +export PROJECT_ID=<your_project_id> |
| 161 | +export LOCAL_IMAGE_NAME=my-colocated-python-server |
| 162 | +export JAX_VERSION=0.5.3 |
| 163 | + |
| 164 | +docker build --build-arg JAX_VERSION=${JAX_VERSION} -t ${LOCAL_IMAGE_NAME} . |
| 165 | +``` |
| 166 | + |
| 167 | +Now you can upload the image to Google Artifact Registry. If you do not have an Artifact Registry repository, please follow the instructions [here](https://cloud.google.com/artifact-registry/docs/repositories/create-repos) to create one. |
| 168 | + |
| 169 | +```bash |
| 170 | +export REGION=us # Your Region |
| 171 | +export ARTIFACT_REGISTRY_REPO=YOUR_ARTIFACT_REGISTRY_REPO |
| 172 | +export EXPORTED_IMAGE_LOCATION=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY_REPO}/my-colocated-python:latest |
| 173 | + |
| 174 | +docker tag ${LOCAL_IMAGE_NAME} ${EXPORTED_IMAGE_LOCATION} |
| 175 | +docker push ${EXPORTED_IMAGE_LOCATION} |
| 176 | + |
| 177 | +# Delete the local image as it's no longer needed. |
| 178 | +docker image rm ${LOCAL_IMAGE_NAME} |
| 179 | +``` |
| 180 | + |
| 181 | +**4. Update Deployment Configuration** |
| 182 | + |
| 183 | +***Simple Example*** |
| 184 | + |
| 185 | +Modify your Kubernetes deployment YAML file to use your colocated python sidecar image. This assumes you are using the [pathways-job](https://github.com/google/pathways-job) api. |
| 186 | + |
| 187 | +For example, if using 2 v4-16 TPUs, use the following yaml. This example is modified from [pathways-job](https://github.com/google/pathways-job/blob/main/config/samples/colocated_python_example_pathwaysjob.yaml). |
| 188 | + |
| 189 | +If you do not have an existing GCS Bucket, instructions to create one are [here](https://cloud.google.com/storage/docs/creating-buckets). |
| 190 | + |
| 191 | +```yaml |
| 192 | +apiVersion: pathways-job.pathways.domain/v1 |
| 193 | +kind: PathwaysJob |
| 194 | +metadata: |
| 195 | + name: pathways-colocated |
| 196 | +spec: |
| 197 | + maxRestarts: 0 |
| 198 | + customComponents: |
| 199 | + - componentType: colocated_python_sidecar |
| 200 | + image: <location of your colocated python sidecar server image> |
| 201 | + workers: |
| 202 | + - type: ct4p-hightpu-4t |
| 203 | + topology: 2x2x2 |
| 204 | + numSlices: 2 |
| 205 | + pathwaysDir: "gs://<test-bucket>/tmp" # This bucket needs to be created in advance. |
| 206 | + controller: |
| 207 | + # Pod template for training, default mode. |
| 208 | + deploymentMode: default |
| 209 | + mainContainerName: main |
| 210 | + template: # UserPodTemplate |
| 211 | + spec: |
| 212 | + containers: |
| 213 | + - name: main |
| 214 | + image: python:3.12 |
| 215 | + imagePullPolicy: Always |
| 216 | + command: |
| 217 | + - /bin/sh |
| 218 | + - -c |
| 219 | + - | |
| 220 | + pip install --upgrade pip |
| 221 | + pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ |
| 222 | + pip install pathwaysutils |
| 223 | + python -c "import jax; import pathwaysutils; print(\"Number of JAX devices is\", len(jax.devices()))" |
| 224 | +``` |
| 225 | +
|
| 226 | +***MaxText Reference Example*** |
| 227 | +
|
| 228 | +If using MaxText, to turn on the data loading optimization that uses Colocated Python feature. |
| 229 | +
|
| 230 | +```python |
| 231 | +colocated_python_data_input=True |
| 232 | +``` |
| 233 | + |
| 234 | +**5. Deploy the Application** |
| 235 | + |
| 236 | +Apply the updated deployment configuration to your Kubernetes cluster: |
| 237 | + |
| 238 | +```bash |
| 239 | +kubectl apply -f path/to/your/deployment.yaml |
| 240 | +``` |
| 241 | + |
| 242 | +This will create the necessary pods with your application, pathways head, and the Colocated Python sidecar containers. |
| 243 | + |
| 244 | +## The Sharp Bits 🔪 |
| 245 | + |
| 246 | +**User Dependency Conflicts** |
| 247 | + |
| 248 | +Colocated Python relies on specific internal dependencies, including JAX. Refer to the provided `server_requirements.txt` for the complete list of required dependencies. Using a different dependency version than the one provided in `server_requirements.txt` will cause the Colocated Python image build to fail. |
| 249 | + |
0 commit comments