Commit 26ebbfe

Support macOS (#477)

This PR makes both clients and servers work on macOS. Specifically, it:

- Follows learning-at-home/hivemind#586 to run a macOS-compatible `p2pd` binary (both x86-64 and ARM64 are supported)
- Fixes forking issues and tests on macOS, Python 3.10+
- Introduces basic support for serving model blocks on Apple M1/M2 GPUs (torch.mps)
- Increases the maximum number of open files by default (the default limit is often too low on Linux and is very small on macOS)

1 parent 75e516a commit 26ebbfe

File tree: 10 files changed, +118 −61 lines

.github/workflows/run-tests.yaml

Lines changed: 18 additions & 20 deletions
@@ -7,20 +7,21 @@ on:
 
 jobs:
   run-tests:
-    runs-on: ubuntu-latest
     strategy:
       matrix:
         include:
-          - { model: 'bigscience/bloom-560m', python-version: '3.8' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.9' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.10' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.11' }
-          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' }
-          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' }
+          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
       fail-fast: false
+    runs-on: ${{ matrix.os }}-latest
     timeout-minutes: 15
     steps:
       - name: Increase swap space
+        if: ${{ matrix.os == 'ubuntu' }}
        uses: pierotofy/set-swap-space@master
        with:
          swap-size-gb: 10
@@ -47,12 +48,7 @@ jobs:
          export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
          export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"

-          # [Step 1] Watch free RAM (lack of RAM is a common issue in CI)
-
-          bash -c 'while true; do free -h && sleep 30s; done' &
-          RAM_WATCH_PID=$!
-
-          # [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
+          # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)

          python -m petals.cli.run_dht \
            --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
@@ -61,7 +57,7 @@ jobs:
          export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
          # ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs

-          sleep 5  # wait for DHT init
+          until [ -s bootstrap.log ]; do sleep 5; done  # wait for DHT init

          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
            --mean_balance_check_period 10 \
@@ -95,11 +91,15 @@ jobs:
          sleep 30  # wait for servers to eval throughput, download layers, and rebalance
          kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID  # ensure all peers survived init

-          # [Step 3] Run PyTest
+          # [Step 2] Run PyTest
+
+          # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
+          export no_proxy=*
+          export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES

          pytest tests --durations=0 --durations-min=1.0 -v

-          # [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
+          # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)

          python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
            --seq_len 3
@@ -110,9 +110,7 @@ jobs:
          python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
            --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm

-          # [Step 5] Clean up
-
-          kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID  # ensure all peers survived tests
+          # [Step 4] Clean up

-          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID
+          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
          echo "Done!"

README.md

Lines changed: 12 additions & 4 deletions
@@ -51,20 +51,28 @@ python -m petals.cli.run_server petals-team/StableBeluga2
 
 🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki.
 
-🐋 **Any OS + Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
+🐋 **Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
 
 ```bash
 sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \
     learningathome/petals:main \
     python -m petals.cli.run_server --port 31330 petals-team/StableBeluga2
 ```
 
+🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands:
+
+```bash
+brew install python
+python3 -m pip install git+https://github.com/bigscience-workshop/petals
+python3 -m petals.cli.run_server petals-team/StableBeluga2
+```
+
 <p align="center">
-    📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (using multiple GPUs, starting on boot, etc.)
-    &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
-    💬 &nbsp;<b><a href="https://discord.gg/X7DgtxgMhc">Ask for help in Discord</a></b>
+    📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (how to use multiple GPUs, start the server on boot, etc.)
 </p>
 
+💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
+
 🦙 **Want to host Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add `--token YOUR_TOKEN_HERE` to the `python -m petals.cli.run_server` command.
 
 🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
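
For the new macOS instructions, it can be useful to confirm that the installed PyTorch build actually exposes the MPS backend before starting a server. This is a generic PyTorch check, not part of the README:

```python
# Quick sanity check (illustrative, not from the README): verify that this PyTorch
# build can use the Apple GPU via the MPS backend before starting a server.
import torch

if torch.backends.mps.is_available():
    x = torch.ones(3, device="mps")  # allocate a small tensor on the Apple GPU
    print("MPS is available:", (x * 2).sum().item())
else:
    # Either the machine has no Apple GPU or PyTorch was built without MPS support
    print("MPS not available; the server will fall back to CPU")
```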

setup.cfg

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@ classifiers =
     Programming Language :: Python :: 3.8
     Programming Language :: Python :: 3.9
     Programming Language :: Python :: 3.10
+    Programming Language :: Python :: 3.11
     Topic :: Scientific/Engineering
     Topic :: Scientific/Engineering :: Mathematics
     Topic :: Scientific/Engineering :: Artificial Intelligence
@@ -39,7 +40,7 @@ install_requires =
     transformers>=4.32.0,<5.0.0  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
-    hivemind==1.1.9
+    hivemind @ git+https://github.com/learning-at-home/hivemind
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2

src/petals/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -1,7 +1,13 @@
 import os
+import platform
 
 os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
 
+if platform.system() == "Darwin":
+    # Necessary for forks to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
+    os.environ.setdefault("no_proxy", "*")
+    os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES")
+
 import hivemind
 import transformers
 from packaging import version
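
These two environment variables work around macOS's Objective-C fork-safety checks, which can abort a forked child after the parent has initialized certain system frameworks; they must be set before any process is forked, which is why they sit at the very top of `petals/__init__.py`. A standalone sketch of the same idea (a hypothetical script, not code from the repository):

```python
# Standalone sketch (assumption: illustrative, not from the repository). Without these
# variables, a fork-based worker on macOS may abort with an Objective-C fork-safety error.
import multiprocessing as mp
import os
import platform

if platform.system() == "Darwin":
    os.environ.setdefault("no_proxy", "*")
    os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES")

def worker() -> None:
    print(f"forked child {os.getpid()} survived")

if __name__ == "__main__":
    ctx = mp.get_context("fork")  # Petals relies on fork-based subprocesses
    p = ctx.Process(target=worker, daemon=True)
    p.start()
    p.join()
```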

src/petals/cli/run_server.py

Lines changed: 14 additions & 6 deletions
@@ -1,8 +1,10 @@
 import argparse
+import logging
 
 import configargparse
+import torch
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.limits import increase_file_limit
+from hivemind.utils import limits
 from hivemind.utils.logging import get_logger
 from humanfriendly import parse_size
 
@@ -127,9 +129,9 @@ def main():
     group.add_argument('--new_swarm', action='store_true',
                        help='Start a new private swarm (i.e., do not connect to any initial peers)')
 
-    parser.add_argument('--increase_file_limit', action='store_true',
-                        help='On *nix, this will increase the max number of processes '
-                             'a server can spawn before hitting "Too many open files"; Use at your own risk.')
+    parser.add_argument('--increase_file_limit', type=int, default=4096,
+                        help='On *nix, increase the max number of files a server can open '
+                             'before hitting "Too many open files" (set to zero to keep the system limit)')
     parser.add_argument('--stats_report_interval', type=int, required=False,
                         help='Interval between two reports of batch processing performance statistics')
 
@@ -185,8 +187,10 @@ def main():
 
     args["startup_timeout"] = args.pop("daemon_startup_timeout")
 
-    if args.pop("increase_file_limit"):
-        increase_file_limit()
+    file_limit = args.pop("increase_file_limit")
+    if file_limit:
+        limits.logger.setLevel(logging.WARNING)
+        limits.increase_file_limit(file_limit, file_limit)
 
     compression_type = args.pop("compression").upper()
     compression = getattr(CompressionType, compression_type)
@@ -207,6 +211,10 @@ def main():
 
     validate_version()
 
+    if not torch.backends.openmp.is_available():
+        # Necessary to prevent the server from freezing after forks
+        torch.set_num_threads(1)
+
     server = Server(
         **args,
         host_maddrs=host_maddrs,
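
`--increase_file_limit` now takes a number of files (default 4096) and forwards it to hivemind's `limits.increase_file_limit`; a busy server holds a file descriptor per open connection, so the stock limits run out quickly, especially on macOS. As a rough sketch of the underlying mechanism only (this is not hivemind's implementation), the soft limit can be raised with the standard `resource` module; unlike the call above, this version leaves the hard limit untouched, so it never needs extra privileges:

```python
# Sketch of the underlying *nix mechanism (not hivemind's actual implementation):
# raise the soft RLIMIT_NOFILE limit so the process can keep more files/sockets open.
import resource

def increase_open_file_limit(target: int = 4096) -> None:
    """Raise the soft open-file limit to `target`, capped by the current hard limit."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if hard != resource.RLIM_INFINITY:
        target = min(target, hard)  # an unprivileged process cannot exceed the hard cap
    if target > soft:
        resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
```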

src/petals/server/reachability.py

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ async def _serve_with_probe():
         protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)
 
         ready.set_result(True)
-        logger.info("Reachability service started")
+        logger.debug("Reachability service started")
 
         async with protocol.serve(common_p2p):
             await protocol._stop.wait()

src/petals/server/server.py

Lines changed: 22 additions & 3 deletions
@@ -9,7 +9,9 @@
 from typing import Dict, List, Optional, Sequence, Union
 
 import hivemind
+import psutil
 import torch
+import torch.mps
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
@@ -154,13 +156,25 @@ def __init__(
         self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
 
         if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
         device = torch.device(device)
         if device.type == "cuda" and device.index is None:
             device = torch.device(device.type, index=0)
         self.device = device
 
         torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
+        if device.type == "cpu" and torch_dtype == torch.float16:
+            raise ValueError(
+                f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16"
+            )
+        if device.type == "mps" and torch_dtype == torch.bfloat16:
+            logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead")
+            torch_dtype = torch.float16
         self.torch_dtype = torch_dtype
 
         if tensor_parallel_devices is None:
@@ -253,13 +267,14 @@ def __init__(
         self.stop = threading.Event()
 
     def _choose_num_blocks(self) -> int:
-        assert self.device.type == "cuda", (
+        assert self.device.type in ("cuda", "mps"), (
             "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
             "CPU-only servers in the public swarm are discouraged since they are much slower"
         )
         num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
 
         if num_devices > 1:
+            assert self.device.type == "cuda", f"Tensor parallelism is not supported on {self.device.type.upper()}"
             memory_per_device = tuple(
                 torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
             )
@@ -270,8 +285,10 @@ def _choose_num_blocks(self) -> int:
                 "Please launch individual servers on each GPU or set --num_blocks manually to "
                 "override this exception."
             )
-        else:
+        elif self.device.type == "cuda":
             total_memory = torch.cuda.get_device_properties(self.device).total_memory
+        else:
+            total_memory = psutil.virtual_memory().total
 
         gib = 1024**3
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
@@ -373,6 +390,8 @@ def _clean_memory_and_fds(self):
                 f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
                 f"{reserved_vram / gib:.1f} GiB reserved memory"
             )
+        elif self.device.type == "mps":
+            torch.mps.empty_cache()
 
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
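
`_choose_num_blocks` now also runs on `mps`. There is no CUDA-style per-device memory query for Apple GPUs, but since Apple silicon uses unified memory shared with the CPU, total system RAM from `psutil` is a reasonable budget. A condensed sketch of that selection logic (the function name is illustrative, not the server's internal API):

```python
# Condensed sketch (illustrative helper name, not the server's internals) of how the
# memory budget is picked per backend after this change.
import psutil
import torch

def total_device_memory(device: torch.device) -> int:
    if device.type == "cuda":
        # Dedicated GPU: ask CUDA for the card's total memory
        return torch.cuda.get_device_properties(device).total_memory
    # MPS (unified memory) and CPU: fall back to total system RAM
    return psutil.virtual_memory().total
```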

src/petals/server/throughput.py

Lines changed: 12 additions & 6 deletions
@@ -9,6 +9,7 @@
 from typing import Dict, Optional, Sequence, Union
 
 import torch
+import torch.mps
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
@@ -207,14 +208,12 @@ def measure_compute_rps(
     elapsed = 0
     dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
     _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
-    if device.type == "cuda":
-        torch.cuda.synchronize(device)
+    synchronize(device)
 
     start_time = time.perf_counter()
-    for step in range(n_steps):
+    for _ in range(n_steps):
         _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
-    if device.type == "cuda":
-        torch.cuda.synchronize(device)
+    synchronize(device)
     elapsed = time.perf_counter() - start_time
     device_rps = n_steps * n_tokens / elapsed
 
@@ -230,8 +229,15 @@ def measure_compute_rps(
     return device_rps
 
 
+def synchronize(device: torch.device):
+    if device.type == "cuda":
+        torch.cuda.synchronize(device)
+    elif device.type == "mps":
+        torch.mps.synchronize()
+
+
 def get_device_name(device: torch.device) -> str:
-    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"
+    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else device.type.upper()
 
 
 def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
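
The new `synchronize` helper matters because both CUDA and MPS execute kernels asynchronously: without a device synchronization, `time.perf_counter()` would mostly measure how long it takes to enqueue the work. A self-contained sketch of the same timing pattern (not the benchmark itself; the matmul workload and sizes are placeholders):

```python
# Standalone sketch (not the benchmark itself) of the timing pattern used in
# measure_compute_rps: synchronize before and after the timed loop so asynchronous
# CUDA/MPS kernels are actually finished when the clock is read.
import time
import torch
import torch.mps  # explicit import, as in throughput.py above

def synchronize(device: torch.device) -> None:
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "mps":
        torch.mps.synchronize()

def timed_matmul_rps(device: torch.device, size: int = 1024, n_steps: int = 10) -> float:
    x = torch.randn(size, size, device=device)
    y = x @ x  # warm-up step, excluded from timing
    synchronize(device)
    start = time.perf_counter()
    for _ in range(n_steps):
        y = x @ x
    synchronize(device)
    return n_steps / (time.perf_counter() - start)
```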

tests/test_cache.py

Lines changed: 2 additions & 2 deletions
@@ -118,7 +118,7 @@ async def _allocate_af():
         allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f))  # klogs the cache
         await allocate_f_task
 
-    alloc_process1 = mp.Process(target=lambda: asyncio.run(_allocate_af()), daemon=True)
+    alloc_process1 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_af()), daemon=True)
     alloc_process1.start()
 
     async def _allocate_bcde():
@@ -128,7 +128,7 @@ async def _allocate_bcde():
         allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e))  # doesn't fit
         await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)
 
-    alloc_process2 = mp.Process(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
+    alloc_process2 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
     alloc_process2.start()
     assert cache.current_size_bytes == 0
     alloc_event.set()
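
`mp.Process` is replaced with `mp.context.ForkProcess` because Python 3.8+ defaults to the "spawn" start method on macOS, while this test relies on fork semantics: the child must inherit the already-constructed cache object. A minimal illustration of the difference (not the test itself):

```python
# Minimal illustration (not the test itself) of why the explicit ForkProcess matters:
# a forked child inherits the parent's runtime state, while a spawned child re-imports
# the module and starts from scratch. On macOS, Python 3.8+ defaults to "spawn".
import multiprocessing as mp
import os

STATE = {"initialized": False}  # module-level default

def child() -> None:
    print(f"pid={os.getpid()} sees STATE={STATE}")

if __name__ == "__main__":
    STATE["initialized"] = True  # runtime mutation, visible only to forked children

    fork_child = mp.context.ForkProcess(target=child, daemon=True)  # Unix-only
    fork_child.start()
    fork_child.join()  # prints {'initialized': True}

    spawn_child = mp.get_context("spawn").Process(target=child, daemon=True)
    spawn_child.start()
    spawn_child.join()  # prints {'initialized': False}
```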
