diff --git a/.github/workflows/check-style.yml b/.github/workflows/check-style.yml
index 29a0f82c5..0c8d03ea0 100644
--- a/.github/workflows/check-style.yml
+++ b/.github/workflows/check-style.yml
@@ -9,7 +9,7 @@ jobs:
black:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- uses: psf/black@stable
with:
options: "--check --diff"
@@ -17,10 +17,19 @@ jobs:
isort:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v2
- - uses: actions/setup-python@v2
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v3
with:
python-version: 3.8
- uses: isort/isort-action@master
with:
isortVersion: "5.10.1"
+
+ codespell:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - uses: codespell-project/actions-codespell@v1
+ with:
+ only_warn: 1
+ ignore_words_list: ibrary,nd
diff --git a/.github/workflows/push-docker-image.yml b/.github/workflows/push-docker-image.yml
index 5b0d02a8b..cf65d3b5b 100644
--- a/.github/workflows/push-docker-image.yml
+++ b/.github/workflows/push-docker-image.yml
@@ -14,7 +14,7 @@ jobs:
steps:
- name: Checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
- name: Docker meta
id: meta
diff --git a/.github/workflows/run-benchmarks.yml b/.github/workflows/run-benchmarks.yml
index 3498a241c..3073f6c50 100644
--- a/.github/workflows/run-benchmarks.yml
+++ b/.github/workflows/run-benchmarks.yml
@@ -11,13 +11,13 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- name: Set up Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v3
with:
python-version: 3.9
- name: Cache dependencies
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.cache/pip
key: Key-v1-3.9-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
@@ -26,6 +26,9 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-dev.txt
+ - name: Build bitsandbytes
+ run: |
+ pip install bitsandbytes==0.41.1
- name: Build hivemind
run: |
pip install .
diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml
index 7cbd02799..feaeb0142 100644
--- a/.github/workflows/run-tests.yml
+++ b/.github/workflows/run-tests.yml
@@ -11,16 +11,16 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: [ 3.7, 3.8, 3.9 ]
+ python-version: [ '3.7', '3.8', '3.9', '3.10', '3.11' ]
timeout-minutes: 15
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- name: Set up Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Cache dependencies
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.cache/pip
key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
@@ -29,6 +29,9 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-dev.txt
+ - name: Build bitsandbytes
+ run: |
+ pip install bitsandbytes==0.41.1
- name: Build hivemind
run: |
pip install .
@@ -41,23 +44,23 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- uses: actions/setup-go@v3
with:
- go-version: '1.16'
+ go-version: '1.20.11'
check-latest: true
- name: Set up Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v3
with:
python-version: '3.8'
- name: Cache dependencies
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.cache/pip
key: Key-v1-3.8-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
+ python -m pip install --upgrade pip setuptools wheel
pip install -r requirements.txt
pip install -r requirements-dev.txt
- name: Build hivemind
@@ -73,27 +76,30 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- name: Set up Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v3
with:
python-version: '3.8'
- name: Cache dependencies
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.cache/pip
key: Key-v1-3.8-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
+ python -m pip install --upgrade pip setuptools wheel
pip install -r requirements.txt
pip install -r requirements-dev.txt
+ - name: Build bitsandbytes
+ run: |
+ pip install bitsandbytes==0.41.1
- name: Build hivemind
run: |
pip install -e . --no-use-pep517
- name: Test
run: |
export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
- pytest --cov hivemind -v tests
+ pytest --cov hivemind --cov-config=pyproject.toml -v tests
- name: Upload coverage to Codecov
- uses: codecov/codecov-action@v1
+ uses: codecov/codecov-action@v3
diff --git a/.readthedocs.yml b/.readthedocs.yml
index da3f55db0..a65b37e6f 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -4,9 +4,13 @@ sphinx:
fail_on_warning: true
python:
- version: 3.7
install:
- requirements: requirements.txt
- requirements: requirements-docs.txt
- method: pip
path: .
+
+build:
+ os: ubuntu-22.04
+ tools:
+ python: "3.7"
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index c3fc3f343..5dc3a5223 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -38,7 +38,8 @@ with the following rules:
cannot be longer than 119 characters.
* We use [black](https://github.com/psf/black) for code formatting and [isort](https://github.com/PyCQA/isort) for
import sorting. Before submitting a PR, make sure to install and run `black .` and `isort .` in the root of the
- repository.
+ repository. Also, you may want to check your code for typos by running `codespell --skip=".git"`, though there
+ might be false positives.
* We highly encourage the use of [typing](https://docs.python.org/3/library/typing.html) where applicable.
* Use `get_logger` from `hivemind.utils.logging` to log any information instead of `print`ing directly to standard
output/error streams.
diff --git a/Dockerfile b/Dockerfile
index 850165c17..c556c1a95 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -9,11 +9,14 @@ RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment
# Install packages
RUN apt-get update && apt-get install -y --no-install-recommends --force-yes \
build-essential \
+ curl \
wget \
git \
vim \
&& apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/*
+RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
+ENV PATH="/root/.cargo/bin:${PATH}"
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O install_miniconda.sh && \
bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh
ENV PATH="/opt/conda/bin:${PATH}"
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 000000000..e39af2dab
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include requirements*
\ No newline at end of file
diff --git a/README.md b/README.md
index 54c00a6cd..476c2eeeb 100644
--- a/README.md
+++ b/README.md
@@ -12,10 +12,6 @@ large model on hundreds of computers from different universities, companies, and

-## Live Demo
-
-Check out our NeurIPS 2021 demonstration ["Training Transformers Together"](https://training-transformers-together.github.io/) to see hivemind in action, join an ongoing collaborative experiment, and learn more about the technologies behind it!
-
## Key Features
* Distributed training without a master node: Distributed Hash Table allows connecting computers in a decentralized
@@ -28,12 +24,24 @@ Check out our NeurIPS 2021 demonstration ["Training Transformers Together"](http
Decentralized Mixture-of-Experts ([paper](https://arxiv.org/abs/2002.04013)).
To learn more about the ideas behind this library,
-see the [full list](https://github.com/learning-at-home/hivemind/tree/refer-to-discord-in-docs#citation) of our papers below.
+see the [full list](#citation) of our papers below.
+
+## Example Use Cases
+
+This section lists projects that leverage hivemind for decentralized training.
+If you have successfully trained a model or created a downstream repository with the help of our library,
+feel free to submit a pull request that adds your project to this list.
+
+* **Petals** ([webpage](https://petals.dev), [code](https://github.com/bigscience-workshop/petals)) — a decentralized platform for inference and fine-tuning of 100B+ language models.
+* **Training Transformers Together** ([webpage](https://training-transformers-together.github.io/), [code](https://github.com/learning-at-home/dalle-hivemind)) — a NeurIPS 2021 demonstration that trained a collaborative text-to-image Transformer model.
+* **CALM** ([webpage](https://huggingface.co/CALM), [code](https://github.com/NCAI-Research/CALM)) — a masked language model trained on a combination of Arabic datasets.
+* **sahajBERT** ([blog post](https://huggingface.co/blog/collaborative-training), [code](https://github.com/tanmoyio/sahajbert)) — a collaboratively pretrained ALBERT-xlarge for the Bengali language.
+* **HivemindStrategy** ([docs](https://lightning.ai/docs/pytorch/stable/advanced/third_party/hivemind.html?highlight=hivemindstrategy)) for PyTorch Lightning allows adapting your existing pipelines to training over slow network with unreliable peers.
## Installation
Before installing, make sure that your environment has Python 3.7+
-and [PyTorch](https://pytorch.org/get-started/locally/#start-locally) 1.6.0 or newer. They can be installed either
+and [PyTorch](https://pytorch.org/get-started/locally/#start-locally) 1.9.0 or newer. They can be installed either
natively or with [Anaconda](https://www.anaconda.com/products/individual).
You can get [the latest release](https://pypi.org/project/hivemind) with pip or build hivemind from source.
@@ -46,6 +54,10 @@ If your versions of Python and PyTorch match the requirements, you can install h
pip install hivemind
```
+Also, if you want to use blockwise 8-bit compression from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
+during data transfer, you can install it with `pip install hivemind[bitsandbytes]`.
+After that, you can use the `BlockwiseQuantization` class in [hivemind.compression](./hivemind/compression)
+
### From source
To install hivemind from source, simply run the following:
@@ -69,7 +81,8 @@ of [Go toolchain](https://golang.org/doc/install) (1.15 or 1.16 are supported).
- __Linux__ is the default OS for which hivemind is developed and tested. We recommend Ubuntu 18.04+ (64-bit), but
other 64-bit distros should work as well. Legacy 32-bit is not recommended.
-- __macOS 10.x__ can run hivemind using [Docker](https://docs.docker.com/desktop/mac/install/).
+- __macOS__ is partially supported.
+ If you have issues, you can run hivemind using [Docker](https://docs.docker.com/desktop/mac/install/) instead.
We recommend using [our Docker image](https://hub.docker.com/r/learningathome/hivemind).
- __Windows 10+ (experimental)__ can run hivemind
using [WSL](https://docs.microsoft.com/ru-ru/windows/wsl/install-win10). You can configure WSL to use GPU by
@@ -111,10 +124,10 @@ If you found hivemind or its underlying algorithms useful for your research, ple
```bibtex
@misc{hivemind,
- author = {Learning{@}home team},
title = {{H}ivemind: a {L}ibrary for {D}ecentralized {D}eep {L}earning},
+ author = {Learning{@}home team},
year = 2020,
- howpublished = {\url{https://github.com/learning-at-home/hivemind}},
+ howpublished = {\url{https://github.com/learning-at-home/hivemind}}
}
```
@@ -124,15 +137,12 @@ at [mryab/learning-at-home](https://github.com/mryab/learning-at-home)):
```bibtex
@inproceedings{ryabinin2020crowdsourced,
+ title = {Towards Crowdsourced Training of Large Neural Networks using Decentralized Mixture-of-Experts},
author = {Ryabinin, Max and Gusev, Anton},
+ year = 2020,
booktitle = {Advances in Neural Information Processing Systems},
- editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
- pages = {3659--3672},
- publisher = {Curran Associates, Inc.},
- title = {Towards Crowdsourced Training of Large Neural Networks using Decentralized Mixture-of-Experts},
- url = {https://proceedings.neurips.cc/paper/2020/file/25ddc0f8c9d3e22e03d3076f98d83cb2-Paper.pdf},
- volume = {33},
- year = {2020}
+ volume = 33,
+ url = {https://proceedings.neurips.cc/paper/2020/file/25ddc0f8c9d3e22e03d3076f98d83cb2-Paper.pdf}
}
```
@@ -142,39 +152,53 @@ at [mryab/learning-at-home](https://github.com/mryab/learning-at-home)):
["Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices"](https://arxiv.org/abs/2103.03239)
```bibtex
-@misc{ryabinin2021moshpit,
+@inproceedings{ryabinin2021moshpit,
title = {Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices},
- author = {Max Ryabinin and Eduard Gorbunov and Vsevolod Plokhotnyuk and Gennady Pekhimenko},
- year = {2021},
- eprint = {2103.03239},
- archivePrefix = {arXiv},
- primaryClass = {cs.LG}
+ author = {Ryabinin, Max and Gorbunov, Eduard and Plokhotnyuk, Vsevolod and Pekhimenko, Gennady},
+ year = 2021,
+ booktitle = {Advances in Neural Information Processing Systems},
+ volume = 34,
+ url = {https://proceedings.neurips.cc/paper/2021/file/97275a23ca44226c9964043c8462be96-Paper.pdf}
}
```
["Distributed Deep Learning in Open Collaborations"](https://arxiv.org/abs/2106.10207)
```bibtex
-@misc{diskin2021distributed,
- title = {Distributed Deep Learning in Open Collaborations},
- author = {Michael Diskin and Alexey Bukhtiyarov and Max Ryabinin and Lucile Saulnier and Quentin Lhoest and Anton Sinitsin and Dmitry Popov and Dmitry Pyrkin and Maxim Kashirin and Alexander Borzunov and Albert Villanova del Moral and Denis Mazur and Ilia Kobelev and Yacine Jernite and Thomas Wolf and Gennady Pekhimenko},
- year = {2021},
- eprint = {2106.10207},
- archivePrefix = {arXiv},
- primaryClass = {cs.LG}
+@inproceedings{diskin2021distributed,
+ title = {Distributed Deep Learning In Open Collaborations},
+ author = {Michael Diskin and Alexey Bukhtiyarov and Max Ryabinin and Lucile Saulnier and Quentin Lhoest and Anton Sinitsin and Dmitry Popov and Dmitriy Pyrkin and Maxim Kashirin and Alexander Borzunov and Albert Villanova del Moral and Denis Mazur and Ilia Kobelev and Yacine Jernite and Thomas Wolf and Gennady Pekhimenko},
+ year = 2021,
+ booktitle = {Advances in Neural Information Processing Systems},
+ url = {https://openreview.net/forum?id=FYHktcK-7v}
}
```
["Secure Distributed Training at Scale"](https://arxiv.org/abs/2106.11257)
```bibtex
-@misc{gorbunov2021secure,
+@inproceedings{gorbunov2022secure,
title = {Secure Distributed Training at Scale},
- author = {Eduard Gorbunov and Alexander Borzunov and Michael Diskin and Max Ryabinin},
- year = {2021},
- eprint = {2106.11257},
- archivePrefix = {arXiv},
- primaryClass = {cs.LG}
+ author = {Gorbunov, Eduard and Borzunov, Alexander and Diskin, Michael and Ryabinin, Max},
+ year = 2022,
+ month = {17--23 Jul},
+ booktitle = {Proceedings of the 39th International Conference on Machine Learning},
+ series = {Proceedings of Machine Learning Research},
+ volume = 162,
+ url = {https://proceedings.mlr.press/v162/gorbunov22a.html}
+}
+```
+
+["Training Transformers Together"](https://arxiv.org/abs/2207.03481)
+
+```bibtex
+@misc{borzunov2022training,
+ title = {Training Transformers Together},
+ author = {Alexander Borzunov and Max Ryabinin and Tim Dettmers and Quentin Lhoest and Lucile Saulnier and Michael Diskin and Yacine Jernite and Thomas Wolf},
+ year = 2022,
+ eprint = {2207.03481},
+ archiveprefix = {arXiv},
+ primaryclass = {cs.LG}
}
```
diff --git a/benchmarks/benchmark_dht.py b/benchmarks/benchmark_dht.py
index 5544b762d..db446259f 100644
--- a/benchmarks/benchmark_dht.py
+++ b/benchmarks/benchmark_dht.py
@@ -20,7 +20,7 @@ class NodeKiller:
"""Auxiliary class that kills dht nodes over a pre-defined schedule"""
def __init__(self, shutdown_peers: list, shutdown_timestamps: list):
- self.shutdown_peers = set(shutdown_peers)
+ self.shutdown_peers = shutdown_peers
self.shutdown_timestamps = shutdown_timestamps
self.current_iter = 0
self.timestamp_iter = 0
@@ -51,7 +51,7 @@ async def store_and_get_task(
latest: bool,
node_killer: NodeKiller,
) -> Tuple[list, list, list, list, int, int]:
- """Iteratively choose random peers to store data onto the dht, then retreive with another random subset of peers"""
+ """Iteratively choose random peers to store data onto the dht, then retrieve with another random subset of peers"""
total_stores = total_gets = 0
successful_stores = []
diff --git a/benchmarks/benchmark_tensor_compression.py b/benchmarks/benchmark_tensor_compression.py
index 1e8d7f053..3316f460c 100644
--- a/benchmarks/benchmark_tensor_compression.py
+++ b/benchmarks/benchmark_tensor_compression.py
@@ -11,26 +11,37 @@
logger = get_logger(__name__)
-def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
+def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> [float, float, int]:
t = time.time()
- deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
- return time.time() - t
+ serialized = serialize_torch_tensor(tensor, compression_type)
+ result = deserialize_torch_tensor(serialized)
+ return time.time() - t, (tensor - result).square().mean(), serialized.ByteSize()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument("--size", type=int, default=10000000, required=False)
+ parser.add_argument("--size", type=int, default=10_000_000, required=False)
parser.add_argument("--seed", type=int, default=7348, required=False)
parser.add_argument("--num_iters", type=int, default=30, required=False)
args = parser.parse_args()
torch.manual_seed(args.seed)
- X = torch.randn(args.size)
+ X = torch.randn(args.size, dtype=torch.float32)
for name, compression_type in CompressionType.items():
- tm = 0
+ total_time = 0
+ compression_error = 0
+ total_size = 0
for i in range(args.num_iters):
- tm += benchmark_compression(X, compression_type)
- tm /= args.num_iters
- logger.info(f"Compression type: {name}, time: {tm}")
+ iter_time, iter_distortion, size = benchmark_compression(X, compression_type)
+ total_time += iter_time
+ compression_error += iter_distortion
+ total_size += size
+ total_time /= args.num_iters
+ compression_error /= args.num_iters
+ total_size /= args.num_iters
+ logger.info(
+ f"Compression type: {name}, time: {total_time:.5f}, compression error: {compression_error:.5f}, "
+ f"size: {int(total_size):d}"
+ )
diff --git a/docs/modules/optim.rst b/docs/modules/optim.rst
index 641cd2ddd..49a02d84f 100644
--- a/docs/modules/optim.rst
+++ b/docs/modules/optim.rst
@@ -5,7 +5,7 @@
This module contains decentralized optimizers that wrap your regular PyTorch Optimizer to train with peers.
Depending on the exact configuration, Optimizer may perform large synchronous updates equivalent,
- or perform asynchrnous local updates and average model parameters.
+ or perform asynchronous local updates and average model parameters.
diff --git a/docs/user/contributing.md b/docs/user/contributing.md
index 094416ae7..04ad1ad01 100644
--- a/docs/user/contributing.md
+++ b/docs/user/contributing.md
@@ -2,7 +2,7 @@
This section describes the ways to contribute to the hivemind library. For technical details of developing this library
and getting towards merging your code in the master branch, read
-the [guidelines](https://github.com/learning-at-home/hivemind/blob/master/CONTRIBUTING.md) in our GitHub repository. In
+the [guidelines](https://github.com/learning-at-home/hivemind/blob/master/CONTRIBUTING.md#) in our GitHub repository. In
any case, please follow the [Contributor Covenant](https://www.contributor-covenant.org/version/2/0/code_of_conduct/)
code of conduct when discussing the library and the changes with other community members.
diff --git a/docs/user/dht.md b/docs/user/dht.md
index bc38806b1..4d48039ce 100644
--- a/docs/user/dht.md
+++ b/docs/user/dht.md
@@ -119,7 +119,7 @@ dht = hivemind.DHT(
], start=True)
```
-Thats it, now the two DHT nodes are connected. If you connect additional peers to the network, you only need to specify
+That's it, now the two DHT nodes are connected. If you connect additional peers to the network, you only need to specify
one (or a subset) of peers as `initial_peers`.
In case your peer operates behind a restrictive firewall, you may find it beneficial to set `client_mode=True`. In this
case, the DHT instance will access others, but it will not announce that other peers can connect to it.
diff --git a/docs/user/moe.md b/docs/user/moe.md
index 176405054..f6bcc1fd8 100644
--- a/docs/user/moe.md
+++ b/docs/user/moe.md
@@ -45,7 +45,7 @@ hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 5 --expert_patte
This server serves 5 feedforward experts with ReLU and LayerNorm
(see
-architecture [here](https://github.com/learning-at-home/hivemind/blob/master/hivemind/server/layers/__init__.py#L7-L21))
+architecture [here](https://github.com/learning-at-home/hivemind/blob/master/hivemind/moe/server/layers/common.py#L19))
. In order to connect to this server, you should copy its address from console outputs:
```shell
[...][INFO][moe.server.create:156] Running DHT node on ['ADDRESS_WILL_BE_PRINTED_HERE']
diff --git a/docs/user/quickstart.md b/docs/user/quickstart.md
index eddf2747b..3142b841a 100644
--- a/docs/user/quickstart.md
+++ b/docs/user/quickstart.md
@@ -22,8 +22,10 @@ We assume that you are already familiar with the official [CIFAR-10 example](htt
from the PyTorch website.
We build on top of the official example to spin up distributed training of a two-layer neural network by averaging weights.
-For simplicity, this tutorial will use two non-GPU peers running on the same machine. If you get to the end of this
-tutorial, we'll give you an example of actual distributed training of Transformers ;)
+For simplicity, this tutorial will use two non-GPU peers running on the same machine. If you try to run this example on two
+separate machines with different IPs, this example will not work. To read more about how to perform training on more
+than one machine check out [DHT - Running Across the Internet](https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet).
+If you get to the end of this tutorial, we'll give you an example of actual distributed training of Transformers ;)
For now, let's run our first training peer:
```python
@@ -187,4 +189,3 @@ If you want to learn more about each individual component,
- Learn the underlying math behind hivemind.Optimizer in [Diskin et al., (2021)](https://arxiv.org/abs/2106.10207),
[Li et al. (2020)](https://arxiv.org/abs/2005.00124) and [Ryabinin et al. (2021)](https://arxiv.org/abs/2103.03239).
- Read about setting up Mixture-of-Experts training in [this guide](https://learning-at-home.readthedocs.io/en/latest/user/moe.html),
-
diff --git a/examples/albert/arguments.py b/examples/albert/arguments.py
index df8c52422..f99cd4a19 100644
--- a/examples/albert/arguments.py
+++ b/examples/albert/arguments.py
@@ -113,7 +113,7 @@ class DatasetArguments:
)
tokenizer_path: Optional[str] = field(default="data/tokenizer", metadata={"help": "Path to the tokenizer"})
config_path: Optional[str] = field(
- default="https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
+ default="albert-large-v2",
metadata={"help": "Path to the model config"},
)
cache_dir: Optional[str] = field(default="data", metadata={"help": "Path to the cache"})
diff --git a/examples/albert/requirements.txt b/examples/albert/requirements.txt
index 49796e82b..2c618f62b 100644
--- a/examples/albert/requirements.txt
+++ b/examples/albert/requirements.txt
@@ -1,5 +1,5 @@
-transformers==4.6.0
-datasets==1.5.0
+transformers~=4.6
+datasets~=1.5
torch_optimizer==0.1.0
wandb==0.10.26
sentencepiece
diff --git a/examples/albert/run_trainer.py b/examples/albert/run_trainer.py
index 7fa550a92..9e9445cf8 100755
--- a/examples/albert/run_trainer.py
+++ b/examples/albert/run_trainer.py
@@ -18,6 +18,7 @@
from transformers.trainer_utils import is_main_process
from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
+from hivemind.optim.state_averager import LRSchedulerBase
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from hivemind.utils.networking import log_visible_maddrs
@@ -33,8 +34,6 @@
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
-
def setup_transformers_logging(process_rank: int):
if is_main_process(process_rank):
diff --git a/examples/albert/run_training_monitor.py b/examples/albert/run_training_monitor.py
index d0b8654dd..921a849e1 100755
--- a/examples/albert/run_training_monitor.py
+++ b/examples/albert/run_training_monitor.py
@@ -46,7 +46,7 @@ class TrainingMonitorArguments(BaseTrainingArguments):
default=5, metadata={"help": "Frequency (in steps) of fetching and saving state from peers"}
)
model_config_path: str = field(
- default="https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
+ default="albert-large-v2",
metadata={"help": "Path to the model config"},
)
repo_path: Optional[str] = field(
diff --git a/hivemind/__init__.py b/hivemind/__init__.py
index f74a640a7..f373643f0 100644
--- a/hivemind/__init__.py
+++ b/hivemind/__init__.py
@@ -13,4 +13,4 @@
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
from hivemind.utils import *
-__version__ = "1.1.0dev0"
+__version__ = "1.1.10.post2"
diff --git a/hivemind/averaging/averager.py b/hivemind/averaging/averager.py
index f83108730..cd9415667 100644
--- a/hivemind/averaging/averager.py
+++ b/hivemind/averaging/averager.py
@@ -8,6 +8,7 @@
import multiprocessing as mp
import os
import random
+import signal
import threading
import weakref
from dataclasses import asdict
@@ -24,8 +25,7 @@
from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
from hivemind.compression import CompressionBase, CompressionInfo, NoCompression, deserialize_torch_tensor
from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
+from hivemind.p2p import P2P, P2PContext, P2PDaemonError, P2PHandlerError, PeerID, ServicerBase
from hivemind.proto import averaging_pb2
from hivemind.utils import MPFuture, TensorDescriptor, get_logger
from hivemind.utils.asyncio import (
@@ -62,7 +62,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
:param min_matchmaking_time: when looking for group, wait for requests for at least this many seconds
:param compression: optionally compress tensors with this compression algorithm before running all-reduce
:param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
- :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be comressed
+ :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
:param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
:param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
@@ -330,6 +330,7 @@ def run_in_background(self, await_ready: bool = True, timeout: Optional[float] =
Starts averager in a background process. if await_ready, this method will wait until background dht
is ready to process incoming requests or for :timeout: seconds max.
"""
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
self.start()
if await_ready:
self.wait_until_ready(timeout)
@@ -350,6 +351,9 @@ def shutdown(self) -> None:
logger.exception("Averager shutdown has no effect: the process is already not alive")
async def _shutdown(self, timeout: Optional[float]) -> None:
+ if not self.client_mode:
+ await self.remove_p2p_handlers(self._p2p, namespace=self.prefix)
+
remaining_tasks = set()
for group in self._running_groups.values():
remaining_tasks.update(group.finalize(cancel=True))
@@ -372,7 +376,7 @@ def step(
"""
Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
- :param gather: optionally send this informaton to all peers in the next group and gather it from every groupmate
+ :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
(this operation is known as all-gather). The gathered data will be available as the output of this function.
:param scheduled_time: when matchmaking, assume that all-reduce will begin at this moment.
By default, schedule all-reduce current time plus min_matchmaking_time seconds
@@ -469,8 +473,7 @@ async def find_peers_or_notify_cancel():
asyncio.CancelledError,
asyncio.InvalidStateError,
P2PHandlerError,
- DispatchFailure,
- ControlFailure,
+ P2PDaemonError,
) as e:
if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
if not step.cancelled():
@@ -648,7 +651,7 @@ async def rpc_download_state(
def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor], Sequence[CompressionInfo]]:
"""
- Get current state and send it to a peer. executed in the host process. Meant to be overriden.
+ Get current state and send it to a peer. executed in the host process. Meant to be overridden.
:returns: a tuple of (small metadata, sequence of torch tensors)
:note: metadata must be seriablizable with self.serializer (default = MSGPackSerializer)
"""
diff --git a/hivemind/averaging/matchmaking.py b/hivemind/averaging/matchmaking.py
index 6e5a690e5..1bf4e7c47 100644
--- a/hivemind/averaging/matchmaking.py
+++ b/hivemind/averaging/matchmaking.py
@@ -13,8 +13,7 @@
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
+from hivemind.p2p import P2P, P2PContext, P2PDaemonError, P2PHandlerError, PeerID, ServicerBase
from hivemind.proto import averaging_pb2
from hivemind.utils import DHTExpiration, TimedStorage, get_dht_time, get_logger, timed_storage
from hivemind.utils.asyncio import anext, cancel_and_wait
@@ -239,7 +238,7 @@ async def _request_join_group(self, leader: PeerID) -> Optional[GroupInfo]:
except asyncio.TimeoutError:
logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
return None
- except (P2PHandlerError, ControlFailure, DispatchFailure, StopAsyncIteration) as e:
+ except (P2PDaemonError, P2PHandlerError, StopAsyncIteration) as e:
logger.debug(f"{self} - failed to request potential leader {leader}:", exc_info=True)
return None
@@ -285,7 +284,7 @@ async def rpc_join_group(
# wait for the group to be assembled or disbanded
timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
await asyncio.wait(
- {self.assembled_group, self.was_accepted_to_group.wait()},
+ {self.assembled_group, asyncio.create_task(self.was_accepted_to_group.wait())},
return_when=asyncio.FIRST_COMPLETED,
timeout=timeout,
)
@@ -481,7 +480,11 @@ async def pop_next_leader(self) -> PeerID:
self.peer_id.to_bytes(),
):
await asyncio.wait(
- {self.update_finished.wait(), self.declared_expiration.wait()}, return_when=asyncio.FIRST_COMPLETED
+ {
+ asyncio.create_task(self.update_finished.wait()),
+ asyncio.create_task(self.declared_expiration.wait()),
+ },
+ return_when=asyncio.FIRST_COMPLETED,
)
self.declared_expiration.clear()
if self.update_finished.is_set():
@@ -512,7 +515,7 @@ async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None
self.update_finished.set()
await asyncio.wait(
- {self.running.wait(), self.update_triggered.wait()},
+ {asyncio.create_task(self.running.wait()), asyncio.create_task(self.update_triggered.wait())},
return_when=asyncio.ALL_COMPLETED,
timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
)
diff --git a/hivemind/averaging/partition.py b/hivemind/averaging/partition.py
index b15bbe2e3..39a354e8a 100644
--- a/hivemind/averaging/partition.py
+++ b/hivemind/averaging/partition.py
@@ -26,7 +26,7 @@ class TensorPartContainer:
:param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
:param compression: optionally compress tensors with this compression algorithm before sending them to peers
:param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
- :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be comressed
+ :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
:param return_deltas: if True, output tensors are differences (aggregated tensor - local tensor)
:param prefetch: when compressing, pre-compute this many compressed tensors in background
"""
@@ -224,7 +224,10 @@ async def accumulate_part(
while part_index > self.current_part_index:
# wait for previous parts to finish processing ...
- await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
+ await asyncio.wait(
+ {self.current_part_future, asyncio.create_task(self.finished.wait())},
+ return_when=asyncio.FIRST_COMPLETED,
+ )
if self.finished.is_set():
raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")
diff --git a/hivemind/compression/__init__.py b/hivemind/compression/__init__.py
index 77ccf8f42..77168d62a 100644
--- a/hivemind/compression/__init__.py
+++ b/hivemind/compression/__init__.py
@@ -5,7 +5,7 @@
from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
-from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.compression.quantization import BlockwiseQuantization, Quantile8BitQuantization, Uniform8BitQuantization
from hivemind.compression.serialization import (
deserialize_tensor_stream,
deserialize_torch_tensor,
diff --git a/hivemind/compression/base.py b/hivemind/compression/base.py
index 727ece5f6..956616bd3 100644
--- a/hivemind/compression/base.py
+++ b/hivemind/compression/base.py
@@ -1,4 +1,5 @@
import dataclasses
+import os
import warnings
from abc import ABC, abstractmethod
from enum import Enum, auto
@@ -11,8 +12,9 @@
from hivemind.utils.tensor_descr import TensorDescriptor
# While converting read-only NumPy arrays into PyTorch tensors, we don't make extra copies for efficiency
-warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
+warnings.filterwarnings("ignore", message="The given NumPy array is not writable", category=UserWarning)
+USE_LEGACY_BFLOAT16 = bool(int(os.environ.get("USE_LEGACY_BFLOAT16", 1)))
Key = Any
@@ -53,7 +55,7 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: b
"""
Applies compression algorithm to a tensor based on their meta-parameters
- :param tensor: a pytorch tensor to compress; depending on the applicaiton, it is a full tensor or a part
+ :param tensor: a pytorch tensor to compress; depending on the application, it is a full tensor or a part
:param info: meta-information about the tensor; if partitioning is used, this still describes the full tensor
:param allow_inplace: if True, compression can (but doesn't have to) to modify tensor in-place for efficiency
:returns: a protobuf message that encodes the tensor
@@ -80,18 +82,41 @@ class NoCompression(CompressionBase):
compression_type = runtime_pb2.CompressionType.NONE
def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
- array = tensor.detach().numpy()
+ requires_grad = tensor.requires_grad
+ tensor = tensor.detach()
+ shape = tensor.shape
+ dtype_name = str(tensor.dtype).replace("torch.", "")
+ raw_data = tensor
+ if tensor.dtype == torch.bfloat16:
+ if USE_LEGACY_BFLOAT16: # legacy mode: convert to fp32
+ raw_data = tensor.to(torch.float32)
+ else: # efficient mode: send bfloat16 data directly
+ # reinterpret_cast to an arbitrary 2-byte type supported by numpy
+ raw_data = tensor.view(torch.int16)
+
return runtime_pb2.Tensor(
compression=self.compression_type,
- buffer=array.tobytes(),
- size=array.shape,
- dtype=array.dtype.name,
- requires_grad=tensor.requires_grad,
+ buffer=raw_data.numpy().tobytes(),
+ size=shape,
+ dtype=dtype_name,
+ requires_grad=requires_grad,
)
def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
- array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
- return torch.as_tensor(array).reshape(tuple(serialized_tensor.size))
+ shape = torch.Size(serialized_tensor.size)
+ if serialized_tensor.dtype == "bfloat16":
+ numel = shape.numel()
+ if numel > 0 and len(serialized_tensor.buffer) // numel == 4:
+ array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
+ tensor = torch.as_tensor(array, dtype=torch.bfloat16)
+ else:
+ array = np.frombuffer(serialized_tensor.buffer, dtype=np.int16)
+ # reinterpret_cast from an arbitrary 2-byte type supported by numpy
+ tensor = torch.as_tensor(array).view(torch.bfloat16)
+ else:
+ array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
+ tensor = torch.as_tensor(array)
+ return tensor.reshape(shape)
def estimate_compression_ratio(self, info: CompressionInfo) -> float:
return 1.0
diff --git a/hivemind/compression/floating.py b/hivemind/compression/floating.py
index 2bda9c399..73c37522a 100644
--- a/hivemind/compression/floating.py
+++ b/hivemind/compression/floating.py
@@ -12,22 +12,29 @@ class Float16Compression(CompressionBase):
FP16_MIN, FP16_MAX = torch.finfo(torch.float16).min, torch.finfo(torch.float16).max
def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+ if not torch.is_floating_point(tensor) or tensor.dtype == torch.bfloat16:
+ raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")
+ requires_grad = tensor.requires_grad
+ tensor = tensor.detach().cpu()
dtype_name = tensor.numpy().dtype.name
- tensor = tensor.detach().cpu().float()
- tensor = tensor if allow_inplace else tensor.clone()
+ tensor = tensor.to(torch.float32, copy=not allow_inplace)
tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16)
return runtime_pb2.Tensor(
compression=self.compression_type,
buffer=tensor.numpy().tobytes(),
size=tensor.shape,
dtype=dtype_name,
- requires_grad=tensor.requires_grad,
+ requires_grad=requires_grad,
)
def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
original_dtype = np.dtype(serialized_tensor.dtype)
array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
- return torch.as_tensor(np.asarray(array, dtype=original_dtype)).reshape(tuple(serialized_tensor.size))
+ return (
+ torch.as_tensor(np.asarray(array, dtype=original_dtype))
+ .reshape(tuple(serialized_tensor.size))
+ .requires_grad_(serialized_tensor.requires_grad)
+ )
def estimate_compression_ratio(self, info: CompressionInfo) -> float:
return 16.0 / get_num_bits(info.descriptor.dtype)
@@ -41,9 +48,12 @@ class ScaledFloat16Compression(Float16Compression):
FP32_EPS = torch.finfo(torch.float32).eps
def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+ if not torch.is_floating_point(tensor) or tensor.dtype == torch.bfloat16:
+ raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")
+ requires_grad = tensor.requires_grad
+ tensor = tensor.detach().cpu()
dtype_name = tensor.numpy().dtype.name
- tensor = tensor.detach().cpu().float()
- tensor = tensor if allow_inplace else tensor.clone()
+ tensor = tensor.to(dtype=torch.float32, copy=not allow_inplace)
means = torch.mean(tensor, dim=-1, keepdim=True)
tensor.sub_(means)
stds = tensor.norm(dim=-1, keepdim=True) / math.sqrt(tensor.shape[-1])
@@ -58,7 +68,7 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: b
buffer=data,
size=tensor.shape,
dtype=dtype_name,
- requires_grad=tensor.requires_grad,
+ requires_grad=requires_grad,
)
def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
@@ -77,7 +87,8 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
tensor = torch.as_tensor(np.asarray(array, dtype=serialized_tensor.dtype)).reshape(
list(serialized_tensor.size)
)
- return tensor.mul_(stds).add_(means)
+ dtype = getattr(torch, serialized_tensor.dtype)
+ return tensor.mul_(stds).add_(means).to(dtype).requires_grad_(serialized_tensor.requires_grad)
def get_num_bits(dtype: torch.dtype) -> int:
diff --git a/hivemind/compression/quantization.py b/hivemind/compression/quantization.py
index f7584bf8b..257d09bca 100644
--- a/hivemind/compression/quantization.py
+++ b/hivemind/compression/quantization.py
@@ -1,5 +1,6 @@
import math
import os
+import warnings
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple
@@ -10,6 +11,8 @@
from hivemind.compression.base import CompressionBase, CompressionInfo
from hivemind.proto import runtime_pb2
+warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
+
EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128)))
@@ -22,12 +25,14 @@ def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[n
...
def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+ if not torch.is_floating_point(tensor):
+ raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")
quantized, codebook = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
return runtime_pb2.Tensor(
compression=self.compression_type,
buffer=b"".join((np.int64(len(codebook)).tobytes(), codebook.tobytes(), quantized.tobytes())),
size=tensor.shape,
- dtype=tensor.numpy().dtype.name,
+ dtype=tensor.data.numpy().dtype.name if tensor.dtype != torch.bfloat16 else "bfloat16",
requires_grad=tensor.requires_grad,
)
@@ -36,8 +41,8 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
codebook = np.frombuffer(serialized_tensor.buffer, offset=8, count=codebook_size, dtype=self.codebook_dtype)
quantized = np.frombuffer(serialized_tensor.buffer, offset=8 + codebook.nbytes, dtype=self.indices_dtype)
quantized = torch.as_tensor(quantized, dtype=torch.int64).reshape(tuple(serialized_tensor.size))
- codebook = torch.as_tensor(np.asarray(codebook, dtype=serialized_tensor.dtype))
- return codebook[quantized]
+ codebook = torch.as_tensor(codebook).to(dtype=getattr(torch, serialized_tensor.dtype))
+ return codebook[quantized].requires_grad_(serialized_tensor.requires_grad)
def estimate_compression_ratio(self, info: CompressionInfo) -> float:
return self.n_bits / torch.finfo(info.descriptor.dtype).bits
@@ -56,8 +61,10 @@ class Uniform8BitQuantization(Quantization):
compression_type = runtime_pb2.UNIFORM_8BIT
def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+ assert torch.is_floating_point(tensor)
offset = self.n_bins // 2
shift = tensor.mean()
+ tensor = tensor.to(dtype=torch.float32, copy=not allow_inplace)
centered_tensor = tensor.sub_(shift) if allow_inplace else tensor - shift
std_unbiased = centered_tensor.norm() / math.sqrt(centered_tensor.numel() - 1)
scale = self.RANGE_IN_SIGMAS * std_unbiased / self.n_bins
@@ -112,3 +119,73 @@ def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_siz
for job in jobs:
job.result()
return np.quantile(partition_quantiles, quantiles)
+
+
+BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly.
+Please install it with `pip install bitsandbytes`
+or using the instruction from https://github.com/TimDettmers/bitsandbytes."""
+
+
+class BlockwiseQuantization(Quantization):
+ compression_type = runtime_pb2.BLOCKWISE_8BIT
+ codebook_dtype, indices_dtype = np.float32, np.uint8
+ EXTRA_PARAMS = (4096, False, torch.float32, None, None)
+
+ def quantize(
+ self, tensor: torch.Tensor, allow_inplace: bool = False
+ ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+ try:
+ # This runs actual import only on the 1st call, copies references after that
+ from bitsandbytes.functional import quantize_blockwise
+ except ImportError:
+ raise ImportError(BNB_MISSING_MESSAGE)
+
+ quantized, (absmax, codebook, *extra_params) = quantize_blockwise(tensor, blocksize=4096, nested=False)
+ assert tuple(extra_params) == self.EXTRA_PARAMS # blocksize, nested, dtype, offset, state2
+ return quantized.numpy(), (absmax.numpy(), codebook.numpy())
+
+ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+ requires_grad = tensor.requires_grad
+ tensor = tensor.detach()
+ dtype_name = str(tensor.dtype).replace("torch.", "")
+ tensor = tensor.to(torch.float32)
+
+ quantized, (absmax, codebook) = self.quantize(tensor, allow_inplace=allow_inplace)
+
+ serialized_data = (
+ np.int64(len(absmax)).tobytes(),
+ np.int64(len(codebook)).tobytes(),
+ absmax.tobytes(),
+ codebook.tobytes(),
+ quantized.tobytes(),
+ )
+
+ return runtime_pb2.Tensor(
+ buffer=b"".join(serialized_data),
+ size=tensor.shape,
+ requires_grad=requires_grad,
+ dtype=dtype_name,
+ compression=self.compression_type,
+ )
+
+ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+ try:
+ from bitsandbytes.functional import dequantize_blockwise
+ except ImportError:
+ raise ImportError(BNB_MISSING_MESSAGE)
+
+ absmax_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
+ codebook_size = int(np.frombuffer(serialized_tensor.buffer, offset=8, count=1, dtype=np.int64))
+ absmax = np.frombuffer(serialized_tensor.buffer, offset=16, count=absmax_size, dtype=self.codebook_dtype)
+ codebook = np.frombuffer(
+ serialized_tensor.buffer, offset=16 + absmax.nbytes, count=codebook_size, dtype=self.codebook_dtype
+ )
+ quantized = np.frombuffer(
+ serialized_tensor.buffer, offset=16 + absmax.nbytes + codebook.nbytes, dtype=self.indices_dtype
+ )
+
+ absmax = torch.as_tensor(absmax)
+ codebook = torch.as_tensor(codebook)
+ quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
+ result = dequantize_blockwise(quantized, (absmax, codebook, *self.EXTRA_PARAMS))
+ return result.to(getattr(torch, serialized_tensor.dtype)).requires_grad_(serialized_tensor.requires_grad)
diff --git a/hivemind/compression/serialization.py b/hivemind/compression/serialization.py
index 849e54fb2..07b1e9378 100644
--- a/hivemind/compression/serialization.py
+++ b/hivemind/compression/serialization.py
@@ -6,21 +6,22 @@
from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression
from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
-from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.compression.quantization import BlockwiseQuantization, Quantile8BitQuantization, Uniform8BitQuantization
from hivemind.proto import runtime_pb2
from hivemind.utils.streaming import combine_from_streaming
-BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
+_BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
NONE=NoCompression(),
FLOAT16=Float16Compression(),
MEANSTD_16BIT=ScaledFloat16Compression(),
QUANTILE_8BIT=Quantile8BitQuantization(),
UNIFORM_8BIT=Uniform8BitQuantization(),
+ BLOCKWISE_8BIT=BlockwiseQuantization(),
)
for key in runtime_pb2.CompressionType.keys():
- assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer"
- actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
+ assert key in _BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer"
+ actual_compression_type = _BASE_COMPRESSION_TYPES[key].compression_type
assert (
runtime_pb2.CompressionType.Name(actual_compression_type) == key
), f"Compression strategy for {key} has inconsistent type"
@@ -35,14 +36,14 @@ def serialize_torch_tensor(
) -> runtime_pb2.Tensor:
"""Serialize a given tensor into a protobuf message using the specified compression strategy"""
assert tensor.device == torch.device("cpu")
- compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
+ compression = _BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
info = info or CompressionInfo.from_tensor(tensor, **kwargs)
return compression.compress(tensor, info, allow_inplace)
def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
"""Restore a pytorch tensor from a protobuf message"""
- compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
+ compression = _BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
diff --git a/hivemind/dht/dht.py b/hivemind/dht/dht.py
index d6dea05cd..85b371d1c 100644
--- a/hivemind/dht/dht.py
+++ b/hivemind/dht/dht.py
@@ -3,6 +3,7 @@
import asyncio
import multiprocessing as mp
import os
+import signal
from functools import partial
from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
@@ -19,7 +20,7 @@
ReturnType = TypeVar("ReturnType")
-class DHT(mp.Process):
+class DHT(mp.context.ForkProcess):
"""
A high-level interface to a hivemind DHT that runs a single DHT node in a background process.
* hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
@@ -94,6 +95,9 @@ def run(self) -> None:
loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
async def _run():
+ # Set SIG_IGN handler to SIGINT
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
+
try:
if self._daemon_listen_maddr is not None:
replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
diff --git a/hivemind/dht/node.py b/hivemind/dht/node.py
index d046702c1..ee56da11a 100644
--- a/hivemind/dht/node.py
+++ b/hivemind/dht/node.py
@@ -254,7 +254,7 @@ async def create(
await asyncio.wait(
[
asyncio.create_task(self.find_nearest_nodes([self.node_id])),
- asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time),
+ asyncio.create_task(asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)),
],
return_when=asyncio.FIRST_COMPLETED,
)
@@ -271,6 +271,7 @@ def __init__(self, *, _initialized_with_create=False):
async def shutdown(self):
"""Process existing requests, close all connections and stop the server"""
self.is_alive = False
+ await self.protocol.shutdown()
if self._should_shutdown_p2p:
await self.p2p.shutdown()
@@ -585,7 +586,7 @@ async def get_many_by_id(
If min_expiration_time=float('inf'), this method will find a value with _latest_ expiration
:param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
:param num_workers: override for default num_workers, see traverse_dht num_workers param
- :param return_futures: if True, immediately return asyncio.Future for every before interacting with the nework.
+ :param return_futures: if True, immediately return asyncio.Future for every before interacting with the network.
The algorithm will populate these futures with (value, expiration) when it finds the corresponding key
Note: canceling a future will stop search for the corresponding key
:param _is_refresh: internal flag, set to True by an internal cache refresher (if enabled)
diff --git a/hivemind/dht/protocol.py b/hivemind/dht/protocol.py
index e383fc521..708b8c690 100644
--- a/hivemind/dht/protocol.py
+++ b/hivemind/dht/protocol.py
@@ -70,7 +70,7 @@ async def create(
self.record_validator = record_validator
self.authorizer = authorizer
- if not client_mode:
+ if not self.client_mode:
await self.add_p2p_handlers(self.p2p, AuthRPCWrapper(self, AuthRole.SERVICER, self.authorizer))
self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes())
@@ -79,6 +79,10 @@ async def create(
self.node_info = dht_pb2.NodeInfo()
return self
+ async def shutdown(self) -> None:
+ if not self.client_mode:
+ await self.remove_p2p_handlers(self.p2p)
+
def __init__(self, *, _initialized_with_create=False):
"""Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
assert _initialized_with_create, "Please use DHTProtocol.create coroutine to spawn new protocol instances"
@@ -117,7 +121,7 @@ async def call_ping(self, peer: PeerID, validate: bool = False, strict: bool = T
f"Peer {peer} can't access this node. " f"Probably, libp2p has failed to bypass the firewall"
)
- if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
+ if response.dht_time != dht_pb2.PingResponse.DESCRIPTOR.fields_by_name["dht_time"].default_value:
if (
response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS
or response.dht_time > time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS
diff --git a/hivemind/dht/routing.py b/hivemind/dht/routing.py
index 95778b1ad..1e2e0184a 100644
--- a/hivemind/dht/routing.py
+++ b/hivemind/dht/routing.py
@@ -1,4 +1,4 @@
-""" Utlity data structures to represent DHT nodes (peers), data keys, and routing tables. """
+""" Utility data structures to represent DHT nodes (peers), data keys, and routing tables. """
from __future__ import annotations
import hashlib
diff --git a/hivemind/dht/traverse.py b/hivemind/dht/traverse.py
index 502caa656..db8c01241 100644
--- a/hivemind/dht/traverse.py
+++ b/hivemind/dht/traverse.py
@@ -209,7 +209,9 @@ async def worker():
# get nearest neighbors (over network) and update search heaps. Abort if search finishes early
get_neighbors_task = asyncio.create_task(get_neighbors(chosen_peer, queries_to_call))
pending_tasks.add(get_neighbors_task)
- await asyncio.wait([get_neighbors_task, search_finished_event.wait()], return_when=asyncio.FIRST_COMPLETED)
+ await_finished_task = asyncio.create_task(search_finished_event.wait())
+ await asyncio.wait([get_neighbors_task, await_finished_task], return_when=asyncio.FIRST_COMPLETED)
+ del await_finished_task
if search_finished_event.is_set():
break # other worker triggered finish_search, we exit immediately
pending_tasks.remove(get_neighbors_task)
diff --git a/hivemind/hivemind_cli/run_dht.py b/hivemind/hivemind_cli/run_dht.py
index d8a5bab4f..d72dbd22b 100644
--- a/hivemind/hivemind_cli/run_dht.py
+++ b/hivemind/hivemind_cli/run_dht.py
@@ -1,5 +1,6 @@
import time
from argparse import ArgumentParser
+from secrets import token_hex
from hivemind.dht import DHT, DHTNode
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -18,6 +19,9 @@ async def report_status(dht: DHT, node: DHTNode):
logger.info(f"Local storage contains {len(node.protocol.storage)} keys")
logger.debug(f"Local storage contents: {node.protocol.storage}")
+ # Contact peers and keep the routing table healthy (remove stale PeerIDs)
+ await node.get(f"heartbeat_{token_hex(16)}", latest=True)
+
def main():
parser = ArgumentParser()
@@ -51,6 +55,17 @@ def main():
help="Path to a private key file. If defined, makes the peer ID deterministic. "
"If the file does not exist, writes a new private key to this file.",
)
+ parser.add_argument(
+ "--no_relay",
+ action="store_false",
+ dest="use_relay",
+ help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
+ )
+ parser.add_argument(
+ "--use_auto_relay",
+ action="store_true",
+ help="Look for libp2p relays to become reachable if we are behind NAT/firewall",
+ )
parser.add_argument(
"--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"
)
@@ -64,12 +79,19 @@ def main():
announce_maddrs=args.announce_maddrs,
use_ipfs=args.use_ipfs,
identity_path=args.identity_path,
+ use_relay=args.use_relay,
+ use_auto_relay=args.use_auto_relay,
)
log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)
- while True:
- dht.run_coroutine(report_status, return_future=False)
- time.sleep(args.refresh_period)
+ try:
+ while True:
+ dht.run_coroutine(report_status, return_future=False)
+ time.sleep(args.refresh_period)
+ except KeyboardInterrupt:
+ logger.info("Caught KeyboardInterrupt, shutting down")
+ finally:
+ dht.shutdown()
if __name__ == "__main__":
diff --git a/hivemind/hivemind_cli/run_server.py b/hivemind/hivemind_cli/run_server.py
index 078702c8e..1c6bc9a09 100644
--- a/hivemind/hivemind_cli/run_server.py
+++ b/hivemind/hivemind_cli/run_server.py
@@ -35,6 +35,17 @@ def main():
help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
parser.add_argument('--announce_maddrs', type=list, nargs='+', default=None, required=False,
help='Visible multiaddrs the host announces for external connections from other p2p instances')
+ parser.add_argument(
+ "--no_relay",
+ action="store_false",
+ dest="use_relay",
+ help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
+ )
+ parser.add_argument(
+ "--use_auto_relay",
+ action="store_true",
+ help="Look for libp2p relays to become reachable if we are behind NAT/firewall",
+ )
parser.add_argument('--num_handlers', type=int, default=None, required=False,
help='server will use this many processes to handle incoming requests')
@@ -97,8 +108,6 @@ def main():
server.join()
except KeyboardInterrupt:
logger.info("Caught KeyboardInterrupt, shutting down")
- finally:
- server.shutdown()
if __name__ == "__main__":
diff --git a/hivemind/moe/client/beam_search.py b/hivemind/moe/client/beam_search.py
index 4baef06a7..f8b357c7b 100644
--- a/hivemind/moe/client/beam_search.py
+++ b/hivemind/moe/client/beam_search.py
@@ -171,7 +171,7 @@ def get_active_successors(
) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
"""
:param prefixes: a list of prefix for which to find active successor uids
- :param grid_size: if specified, only return successors if ther are in range [0, grid_size)
+ :param grid_size: if specified, only return successors if they are in range [0, grid_size)
:param return_future: if False (default), find and return successors. Otherwise return MPFuture and fill later.
:returns: for every expert, return a dict{active_next_coordinate: (matching_expert_uid, matching_endpoint)}
:note: if a prefix is not found, get_active_successors will return an empty dictionary for that prefix
diff --git a/hivemind/moe/client/expert.py b/hivemind/moe/client/expert.py
index 477f977ec..c7ad84a91 100644
--- a/hivemind/moe/client/expert.py
+++ b/hivemind/moe/client/expert.py
@@ -13,7 +13,7 @@
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.p2p import P2P, PeerID, StubBase
-from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
from hivemind.utils.mpfuture import MPFuture
@@ -152,7 +152,7 @@ async def expert_backward(
size = 0
for t in inputs_and_grads:
size += t.element_size() * t.nelement()
- if size > DEFAULT_MAX_MSG_SIZE:
+ if size > MAX_UNARY_PAYLOAD_SIZE:
return await _backward_stream(uid, serialized_tensors, stub)
else:
return await _backward_unary(uid, serialized_tensors, stub)
@@ -185,7 +185,7 @@ async def expert_forward(
size = 0
for t in inputs:
size += t.element_size() * t.nelement()
- if size > DEFAULT_MAX_MSG_SIZE:
+ if size > MAX_UNARY_PAYLOAD_SIZE:
return await _forward_stream(uid, serialized_tensors, stub)
else:
return await _forward_unary(uid, serialized_tensors, stub)
diff --git a/hivemind/moe/client/remote_expert_worker.py b/hivemind/moe/client/remote_expert_worker.py
index 07012a91a..53b3c41b6 100644
--- a/hivemind/moe/client/remote_expert_worker.py
+++ b/hivemind/moe/client/remote_expert_worker.py
@@ -1,6 +1,6 @@
+import asyncio
import os
from concurrent.futures import Future
-from queue import Queue
from threading import Thread
from typing import Awaitable, Optional
@@ -10,39 +10,27 @@
class RemoteExpertWorker:
"""Local thread for managing async tasks related to RemoteExpert"""
- _task_queue: Queue = Queue()
- _event_thread: Optional[Thread] = None
- _pid: int = -1
+ _event_thread = None
+ _event_loop_fut = None
+ _pid = None
@classmethod
- def _run(cls):
- loop = switch_to_uvloop()
-
- async def receive_tasks():
- while True:
- cor, future = cls._task_queue.get()
- try:
- result = await cor
- except Exception as e:
- future.set_exception(e)
- continue
- if not future.cancelled():
- future.set_result(result)
-
- loop.run_until_complete(receive_tasks())
+ def _run_event_loop(cls):
+ try:
+ loop = switch_to_uvloop()
+ cls._event_loop_fut.set_result(loop)
+ except Exception as e:
+ cls._event_loop_fut.set_exception(e)
+ loop.run_forever()
@classmethod
def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
if cls._event_thread is None or cls._pid != os.getpid():
cls._pid = os.getpid()
- cls._event_thread = Thread(target=cls._run, daemon=True)
+ cls._event_loop_fut = Future()
+ cls._event_thread = Thread(target=cls._run_event_loop, daemon=True)
cls._event_thread.start()
- future = Future()
- cls._task_queue.put((coro, future))
-
- if return_future:
- return future
-
- result = future.result()
- return result
+ loop = cls._event_loop_fut.result()
+ future = asyncio.run_coroutine_threadsafe(coro, loop)
+ return future if return_future else future.result()
diff --git a/hivemind/moe/server/connection_handler.py b/hivemind/moe/server/connection_handler.py
index d00827689..f6f0bcc85 100644
--- a/hivemind/moe/server/connection_handler.py
+++ b/hivemind/moe/server/connection_handler.py
@@ -28,36 +28,75 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
:param module_backends: a dict [UID -> ModuleBackend] with all active experts
"""
- def __init__(self, dht: DHT, module_backends: Dict[str, ModuleBackend]):
+ def __init__(
+ self,
+ dht: DHT,
+ module_backends: Dict[str, ModuleBackend],
+ *,
+ balanced: bool = True,
+ shutdown_timeout: float = 3,
+ start: bool = False,
+ ):
super().__init__()
self.dht, self.module_backends = dht, module_backends
+ self.balanced, self.shutdown_timeout = balanced, shutdown_timeout
self._p2p: Optional[P2P] = None
+ self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=False)
self.ready = MPFuture()
+ if start:
+ self.run_in_background(await_ready=True)
+
def run(self):
torch.set_num_threads(1)
loop = switch_to_uvloop()
+ stop = asyncio.Event()
+ loop.add_reader(self._inner_pipe.fileno(), stop.set)
async def _run():
try:
self._p2p = await self.dht.replicate_p2p()
- await self.add_p2p_handlers(self._p2p, balanced=True)
-
- # wait forever
- await asyncio.Future()
-
+ await self.add_p2p_handlers(self._p2p, balanced=self.balanced)
+ self.ready.set_result(None)
except Exception as e:
+ logger.error("ConnectionHandler failed to start:", exc_info=True)
self.ready.set_exception(e)
- return
- self.ready.set_result(None)
+ try:
+ await stop.wait()
+ finally:
+ await self.remove_p2p_handlers(self._p2p)
try:
loop.run_until_complete(_run())
except KeyboardInterrupt:
logger.debug("Caught KeyboardInterrupt, shutting down")
+ def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
+ """
+ Starts ConnectionHandler in a background process. If :await_ready:, this method will wait until
+ it is ready to process incoming requests or for :timeout: seconds max.
+ """
+ self.start()
+ if await_ready:
+ self.wait_until_ready(timeout)
+
+ def wait_until_ready(self, timeout: Optional[float] = None) -> None:
+ self.ready.result(timeout=timeout)
+
+ def shutdown(self):
+ if self.is_alive():
+ self._outer_pipe.send("_shutdown")
+ self.join(self.shutdown_timeout)
+ if self.is_alive():
+ logger.warning(
+ "ConnectionHandler did not shut down within the grace period; terminating it the hard way"
+ )
+ self.terminate()
+ else:
+ logger.warning("ConnectionHandler shutdown had no effect, the process is already dead")
+
async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
module_info = self.module_backends[request.uid].get_info()
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(module_info))
diff --git a/hivemind/moe/server/layers/optim.py b/hivemind/moe/server/layers/optim.py
index f280ba427..00eb85e75 100644
--- a/hivemind/moe/server/layers/optim.py
+++ b/hivemind/moe/server/layers/optim.py
@@ -1,11 +1,10 @@
import torch
-class OptimizerWrapper(torch.optim.Optimizer):
+class OptimizerWrapper:
"""A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer"""
def __init__(self, optim: torch.optim.Optimizer):
- super().__init__(optim.param_groups, optim.defaults)
self.optim = optim
@property
diff --git a/hivemind/moe/server/module_backend.py b/hivemind/moe/server/module_backend.py
index f6260371a..199cc9f28 100644
--- a/hivemind/moe/server/module_backend.py
+++ b/hivemind/moe/server/module_backend.py
@@ -8,9 +8,13 @@
from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
logger = get_logger(__name__)
+try:
+ LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
+except AttributeError: # torch < 2.0.0
+ LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
+
class ModuleBackend:
"""
@@ -118,9 +122,7 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
with torch.enable_grad():
args = [
- tensor.detach().requires_grad_(True)
- if tensor.dtype in (torch.half, torch.float, torch.double)
- else tensor.detach()
+ tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach()
for tensor in args
]
kwargs = {
diff --git a/hivemind/moe/server/runtime.py b/hivemind/moe/server/runtime.py
index 1e750812f..14ad97045 100644
--- a/hivemind/moe/server/runtime.py
+++ b/hivemind/moe/server/runtime.py
@@ -6,13 +6,14 @@
from queue import SimpleQueue
from selectors import EVENT_READ, DefaultSelector
from statistics import mean
-from time import time
-from typing import Dict, NamedTuple, Optional
+from time import perf_counter
+from typing import Any, Dict, NamedTuple, Optional, Tuple
import torch
from prefetch_generator import BackgroundGenerator
from hivemind.moe.server.module_backend import ModuleBackend
+from hivemind.moe.server.task_pool import TaskPoolBase
from hivemind.utils import get_logger
logger = get_logger(__name__)
@@ -79,26 +80,36 @@ def run(self):
self.stats_reporter.start()
logger.info("Started")
- for pool, batch_index, batch in BackgroundGenerator(
- self.iterate_minibatches_from_pools(), self.prefetch_batches
- ):
- logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
-
- start = time()
- outputs = pool.process_func(*batch)
- batch_processing_time = time() - start
-
- batch_size = outputs[0].size(0)
- logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
+ batch_iterator = self.iterate_minibatches_from_pools()
+ if self.prefetch_batches > 0:
+ batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)
- if self.stats_report_interval is not None:
- self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
+ for pool, batch_index, batch in batch_iterator:
+ logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
+ start = perf_counter()
+ try:
+ outputs, batch_size = self.process_batch(pool, batch_index, *batch)
+ batch_processing_time = perf_counter() - start
+ logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
+ if self.stats_report_interval is not None:
+ self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
+
+ output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
+ except KeyboardInterrupt:
+ raise
+ except BaseException as exception:
+ logger.exception(f"Caught {exception}, attempting to recover")
+ output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])
- output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
finally:
if not self.shutdown_trigger.is_set():
self.shutdown()
+ def process_batch(self, pool: TaskPoolBase, batch_index: int, *batch: torch.Tensor) -> Tuple[Any, int]:
+ """process one batch of tasks from a given pool, return a batch of results and total batch size"""
+ outputs = pool.process_func(*batch)
+ return outputs, outputs[0].size(0)
+
def shutdown(self):
"""Gracefully terminate a running runtime."""
logger.info("Shutting down")
@@ -120,9 +131,7 @@ def shutdown(self):
self.shutdown_trigger.set()
def iterate_minibatches_from_pools(self, timeout=None):
- """
- Chooses pool according to priority, then copies exposed batch and frees the buffer
- """
+ """Iteratively select non-empty pool with highest priority and loads a batch from that pool"""
with DefaultSelector() as selector:
for pool in self.pools:
selector.register(pool.batch_receiver, EVENT_READ, pool)
@@ -136,8 +145,8 @@ def iterate_minibatches_from_pools(self, timeout=None):
if self.SHUTDOWN_TRIGGER in ready_objects:
break # someone asked us to shutdown, break from the loop
- logger.debug("Choosing the pool with highest priority")
- pool = max(ready_objects, key=lambda pool: pool.priority)
+ logger.debug("Choosing the pool with first priority")
+ pool = min(ready_objects, key=lambda pool: pool.priority)
logger.debug(f"Loading batch from {pool.name}")
batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
diff --git a/hivemind/moe/server/server.py b/hivemind/moe/server/server.py
index f4d7d7a77..87488f334 100644
--- a/hivemind/moe/server/server.py
+++ b/hivemind/moe/server/server.py
@@ -247,10 +247,8 @@ def run(self):
if self.checkpoint_saver is not None:
self.checkpoint_saver.start()
- for process in self.conn_handlers:
- if not process.is_alive():
- process.start()
- process.ready.result()
+ for handler in self.conn_handlers:
+ handler.run_in_background()
try:
self.runtime.run()
@@ -287,9 +285,8 @@ def shutdown(self):
"""
self.ready.clear()
- for process in self.conn_handlers:
- process.terminate()
- process.join()
+ for handler in self.conn_handlers:
+ handler.shutdown()
logger.debug("Connection handlers terminated")
if self.module_backends:
@@ -301,12 +298,11 @@ def shutdown(self):
self.checkpoint_saver.join()
self.dht.shutdown()
- self.dht.join()
logger.debug(f"Shutting down runtime")
-
self.runtime.shutdown()
- logger.info("Server shutdown succesfully")
+
+ logger.info("Server shutdown successfully")
@contextmanager
diff --git a/hivemind/moe/server/task_pool.py b/hivemind/moe/server/task_pool.py
index 8f9342bad..c763dc776 100644
--- a/hivemind/moe/server/task_pool.py
+++ b/hivemind/moe/server/task_pool.py
@@ -27,7 +27,7 @@ class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta):
def __init__(self, process_func: callable, daemon=True, **kwargs):
super().__init__(daemon=daemon, **kwargs)
self.process_func = process_func
- self._priority = mp.Value(ctypes.c_double, 1.0) # higher priority = the more urgent to process this pool
+ self._priority = mp.Value(ctypes.c_double, 1.0) # lower priority = the more urgent to process this pool
@abstractmethod
def run(self):
@@ -38,7 +38,7 @@ def submit_task(self, *args: torch.Tensor) -> Future:
pass
@abstractmethod
- def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
+ def load_batch_to_runtime(self) -> Tuple[Any, List[torch.Tensor]]:
pass
@property
@@ -170,7 +170,7 @@ def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwar
for skip_i in range(prev_num_tasks):
finished_task_timestamp = (
self.undispatched_task_timestamps.get()
- ) # earlier timestamp = higher priority
+ ) # earlier timestamp = smaller (better) priority, earlier processing
if skip_i == prev_num_tasks - 1:
self.priority = finished_task_timestamp
@@ -195,28 +195,42 @@ def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
while True:
logger.debug(f"{self.name} waiting for results from runtime")
- batch_index, batch_outputs = self.outputs_receiver.recv()
- logger.debug(f"{self.name}, batch {batch_index}: got results")
-
- # split batch into partitions for individual tasks
+ batch_index, batch_outputs_or_exception = self.outputs_receiver.recv()
batch_tasks = pending_batches.pop(batch_index)
- task_sizes = [self.get_task_size(task) for task in batch_tasks]
- outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
- logger.debug(f"{self.name}, batch {batch_index}: sending outputs to handlers")
- # dispatch results to futures
- for task, task_outputs in zip(batch_tasks, outputs_per_task):
- try:
- task.future.set_result(tuple(task_outputs))
- except InvalidStateError as e:
- logger.debug(f"Failed to send task result due to an exception: {e}")
+ if isinstance(batch_outputs_or_exception, BaseException):
+ logger.debug(f"{self.name}, batch {batch_index}: got exception, propagating to handlers")
+ exception = batch_outputs_or_exception
+ for task in batch_tasks:
+ try:
+ task.future.set_exception(exception)
+ except InvalidStateError as e:
+ logger.debug(f"Failed to send runtime error to a task: {e}")
+
+ else:
+ logger.debug(f"{self.name}, batch {batch_index}: got results")
+ batch_outputs = batch_outputs_or_exception
+
+ # split batch into partitions for individual tasks
+ task_sizes = [self.get_task_size(task) for task in batch_tasks]
+ outputs_per_task = zip(
+ *(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs)
+ )
+ logger.debug(f"{self.name}, batch {batch_index}: sending outputs to handlers")
+
+ # dispatch results to futures
+ for task, task_outputs in zip(batch_tasks, outputs_per_task):
+ try:
+ task.future.set_result(tuple(task_outputs))
+ except InvalidStateError as e:
+ logger.debug(f"Failed to send task result due to an exception: {e}")
@property
def empty(self):
return not self.batch_receiver.poll()
def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[torch.Tensor]]:
- """receive next batch of numpy arrays"""
+ """receive next batch of tensors"""
if not self.batch_receiver.poll(timeout):
raise TimeoutError()
@@ -227,11 +241,15 @@ def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[to
def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[torch.Tensor]):
"""send results for a processed batch, previously loaded through load_batch_to_runtime"""
batch_outputs = [
- tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
+ tensor.to(device="cpu", non_blocking=False).share_memory_().detach().requires_grad_(tensor.requires_grad)
+ # note: tensor.to deliberately does NOT use non_blocking; non_blocking + share_memory = undefined behavior
for tensor in batch_outputs
]
self.outputs_sender.send((batch_index, batch_outputs))
+ def send_exception_from_runtime(self, batch_index: int, exception: BaseException):
+ self.outputs_sender.send((batch_index, exception))
+
def get_task_size(self, task: Task) -> int:
"""compute task processing complexity (used for batching); defaults to batch size"""
return len(task.args[0]) if task.args else 1
diff --git a/hivemind/optim/grad_averager.py b/hivemind/optim/grad_averager.py
index e168e5eae..487a411f7 100644
--- a/hivemind/optim/grad_averager.py
+++ b/hivemind/optim/grad_averager.py
@@ -29,7 +29,7 @@ class GradientAverager(DecentralizedAverager):
(3) averaged gradients - gradient buffers that are aggregated in-place with peers, always in host memory
:param parameters: pytorch parameters for which to aggregate gradients
- :param dht: a DHT isntance connected to the rest of the swarm. See hivemind.DHT docs
+ :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
:param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
:param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
diff --git a/hivemind/optim/optimizer.py b/hivemind/optim/optimizer.py
index 2e93aa185..ef0a05cc9 100644
--- a/hivemind/optim/optimizer.py
+++ b/hivemind/optim/optimizer.py
@@ -15,6 +15,7 @@
from hivemind.optim.grad_scaler import GradScaler
from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
from hivemind.optim.state_averager import (
+ ZERO_GRAD_SET_TO_NONE_DEFAULT,
LRSchedulerBase,
OptimizerFactory,
Parameters,
@@ -56,11 +57,11 @@ class Optimizer(torch.optim.Optimizer):
Unlike regular training, your device may join midway through training, when other peers already made some progress.
For this reason, any learning rate schedulers, curriculum and other **time-dependent features should be based on**
- ``optimizer.local_epoch`` (and not the number ot calls to opt.step). Otherwise, peers that joined training late
+ ``optimizer.local_epoch`` (and not the number of calls to opt.step). Otherwise, peers that joined training late
may end up having different learning rates. To do so automatically, specify ``scheduler=...`` parameter below.
:What is an epoch?: Optimizer uses the term ``epoch`` to describe intervals between synchronizations. One epoch
- coresponds to processing certain number of training samples (``target_batch_size``) in total across all peers.
+ corresponds to processing certain number of training samples (``target_batch_size``) in total across all peers.
Like in PyTorch LR Scheduler, **epoch does not necessarily correspond to a full pass over the training data.**
At the end of epoch, peers perform synchronous actions such as averaging gradients for a global optimizer update,
updating the learning rate scheduler or simply averaging parameters (if using local updates).
@@ -621,7 +622,10 @@ def _load_averaged_gradients_into_optimizer_(self):
with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
assert len(averaged_gradients) == len(optimized_parameters)
for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
- opt_param.grad.copy_(averaged_grad, non_blocking=True)
+ if opt_param.grad is None:
+ opt_param.grad = averaged_grad.clone()
+ else:
+ opt_param.grad.copy_(averaged_grad, non_blocking=True)
self.grad_averager.notify_used_averaged_gradients()
@@ -634,7 +638,7 @@ def _load_local_gradients_into_optimizer(self):
# - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
self._load_averaged_gradients_into_optimizer_()
- def zero_grad(self, set_to_none: bool = False):
+ def zero_grad(self, set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT):
"""Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
raise ValueError(
@@ -643,11 +647,9 @@ def zero_grad(self, set_to_none: bool = False):
)
for param_group in self.param_groups:
for param in param_group["params"]:
- if param.grad is None:
- pass
- elif set_to_none:
+ if set_to_none:
param.grad = None
- else:
+ elif param.grad is not None:
param.grad.zero_()
def _should_load_state_from_peers(self) -> bool:
diff --git a/hivemind/optim/power_sgd_averager.py b/hivemind/optim/power_sgd_averager.py
index ce8603103..aab5b48ea 100644
--- a/hivemind/optim/power_sgd_averager.py
+++ b/hivemind/optim/power_sgd_averager.py
@@ -51,7 +51,7 @@ class PowerSGDGradientAverager(GradientAverager):
:param parameters: pytorch parameters for which to aggregate gradients
:param averager_rank: rank of compressed gradients
- :param dht: a DHT isntance connected to the rest of the swarm. See hivemind.DHT docs
+ :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
:param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
:param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
diff --git a/hivemind/optim/state_averager.py b/hivemind/optim/state_averager.py
index 794260fc4..f7a94f7b3 100644
--- a/hivemind/optim/state_averager.py
+++ b/hivemind/optim/state_averager.py
@@ -8,6 +8,7 @@
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
import torch
+from packaging.version import Version
import hivemind
from hivemind.averaging import DecentralizedAverager
@@ -22,7 +23,12 @@
Parameters = Iterable[torch.Tensor]
ParamGroups = Iterable[Dict[str, Any]]
TorchOptimizer = torch.optim.Optimizer
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+if Version(torch.__version__).major >= 2:
+ ZERO_GRAD_SET_TO_NONE_DEFAULT = True
+ LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
+else:
+ ZERO_GRAD_SET_TO_NONE_DEFAULT = False
+ LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]
@@ -332,6 +338,7 @@ def step(
averaging_control: Optional[StepControl] = None,
wait_for_trigger: Optional[Callable[[], Any]] = None,
grad_scaler: Optional[GradScaler] = None,
+ set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT,
averaging_opts: Optional[Dict[str, Any]] = None,
):
"""
@@ -353,6 +360,8 @@ def step(
:param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
:note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
:param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
+ :param set_to_none: if True, zero_grad sets local gradients to None instead of zero tensors
+ (default in PyTorch 2.0+)
:param averaging_opts: a dict of keyword arguments forwarded into averaging round
"""
if delay_averaging is None:
@@ -430,6 +439,7 @@ def step(
averaging_round,
averaging_control,
grad_scaler,
+ set_to_none,
**averaging_opts or {},
)
self.pending_updates.add(pending_update)
@@ -472,6 +482,7 @@ def _do(
averaging_round: bool,
averaging_control: Optional[StepControl],
grad_scaler: Optional[GradScaler],
+ set_to_none: bool,
timeout: Optional[float] = None,
**kwargs,
):
@@ -515,7 +526,9 @@ def _do(
self.optimizer.zero_grad()
if self.offload_optimizer:
for parameter in self.main_parameters:
- if parameter.grad is not None:
+ if set_to_none:
+ parameter.grad = None
+ elif parameter.grad is not None:
parameter.grad.zero_()
self._update_scheduler()
@@ -566,7 +579,10 @@ def _load_local_grads_into_optimizer_(self):
opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
for main_param, opt_param in zip(self.main_parameters, opt_parameters):
if main_param.grad is not None:
- opt_param.grad.copy_(main_param.grad, non_blocking=True)
+ if opt_param.grad is None:
+ opt_param.grad = main_param.grad.clone()
+ else:
+ opt_param.grad.copy_(main_param.grad, non_blocking=True)
@torch.no_grad()
def _apply_optimizer_parameters_(self):
diff --git a/hivemind/p2p/__init__.py b/hivemind/p2p/__init__.py
index 383121a77..425746176 100644
--- a/hivemind/p2p/__init__.py
+++ b/hivemind/p2p/__init__.py
@@ -1,3 +1,3 @@
-from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PDaemonError, P2PHandlerError
-from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
+from hivemind.p2p.p2p_daemon import P2P, P2PContext
+from hivemind.p2p.p2p_daemon_bindings import P2PDaemonError, P2PHandlerError, PeerID, PeerInfo
from hivemind.p2p.servicer import ServicerBase, StubBase
diff --git a/hivemind/p2p/p2p_daemon.py b/hivemind/p2p/p2p_daemon.py
index 806d3cd6d..4acea4d72 100644
--- a/hivemind/p2p/p2p_daemon.py
+++ b/hivemind/p2p/p2p_daemon.py
@@ -18,6 +18,7 @@
import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PDaemonError, P2PHandlerError
from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure
from hivemind.proto import crypto_pb2
from hivemind.proto.p2pd_pb2 import RPCError
from hivemind.utils.asyncio import as_aiter, asingle
@@ -102,6 +103,9 @@ async def create(
quic: Optional[bool] = None,
use_relay_hop: Optional[bool] = None,
use_relay_discovery: Optional[bool] = None,
+ check_if_identity_free: bool = True,
+ no_listen: bool = False,
+ trusted_relays: Optional[Sequence[Union[Multiaddr, str]]] = None,
) -> "P2P":
"""
Start a new p2pd process and connect to it.
@@ -123,12 +127,21 @@ async def create(
:param relay_hop_limit: sets the hop limit for hop relays
:param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
:param tls: Enables TLS1.3 channel security protocol
- :param use_auto_relay: enables autorelay
:param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
- :param use_relay: enables circuit relay
+ :param use_relay: Enable circuit relay functionality in libp2p
+ (see https://docs.libp2p.io/concepts/nat/circuit-relay/).
+ If enabled (default), you can reach peers behind NATs/firewalls through libp2p relays.
+ If you are behind NAT/firewall yourself,
+ please pass `use_auto_relay=True` to become reachable.
+ :param use_auto_relay: Look for libp2p relays to become reachable if we are behind NAT/firewall
:param quic: Deprecated, has no effect since libp2p 0.17.0
:param use_relay_hop: Deprecated, has no effect since libp2p 0.17.0
:param use_relay_discovery: Deprecated, has no effect since libp2p 0.17.0
+ :param check_if_identity_free: If enabled (default), ``identity_path`` is provided,
+ and we are connecting to an existing swarm,
+ ensure that this identity is not used by other peers already.
+ This slows down ``P2P.create()`` but protects from unintuitive libp2p errors
+ appearing in case of the identity collision.
:return: a wrapper for the p2p daemon
"""
@@ -164,14 +177,29 @@ async def create(
("bootstrapPeers", initial_peers),
("hostAddrs", host_maddrs),
("announceAddrs", announce_maddrs),
+ ("trustedRelays", trusted_relays),
]:
if value:
process_kwargs[param] = self._maddrs_to_str(value)
-
+ if no_listen:
+ process_kwargs["noListenAddrs"] = 1
if identity_path is not None:
- if not os.path.isfile(identity_path):
- logger.info(f"Generating new identity (libp2p private key) in `{identity_path}`")
+ if os.path.isfile(identity_path):
+ if check_if_identity_free and need_bootstrap:
+ logger.info(f"Checking that identity from `{identity_path}` is not used by other peers")
+ if await cls.is_identity_taken(
+ identity_path,
+ initial_peers=initial_peers,
+ tls=tls,
+ use_auto_relay=use_auto_relay,
+ use_ipfs=use_ipfs,
+ use_relay=use_relay,
+ ):
+ raise P2PDaemonError(f"Identity from `{identity_path}` is already taken by another peer")
+ else:
+ logger.info(f"Generating new identity to be saved in `{identity_path}`")
self.generate_identity(identity_path)
+ # A newly generated identity is not taken with ~100% probability
process_kwargs["id"] = identity_path
proc_args = self._make_process_args(
@@ -217,6 +245,36 @@ async def create(
await self._ping_daemon()
return self
+ @classmethod
+ async def is_identity_taken(
+ cls,
+ identity_path: str,
+ *,
+ initial_peers: Optional[Sequence[Union[Multiaddr, str]]],
+ tls: bool,
+ use_auto_relay: bool,
+ use_ipfs: bool,
+ use_relay: bool,
+ ) -> bool:
+ with open(identity_path, "rb") as f:
+ peer_id = PeerID.from_identity(f.read())
+
+ anonymous_p2p = await cls.create(
+ initial_peers=initial_peers,
+ dht_mode="client",
+ tls=tls,
+ use_auto_relay=use_auto_relay,
+ use_ipfs=use_ipfs,
+ use_relay=use_relay,
+ )
+ try:
+ await anonymous_p2p._client.connect(peer_id, [])
+ return True
+ except ControlFailure:
+ return False
+ finally:
+ await anonymous_p2p.shutdown()
+
@staticmethod
def generate_identity(identity_path: str) -> None:
private_key = RSAPrivateKey()
@@ -475,6 +533,19 @@ async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2
await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type, balanced=balanced)
+ async def remove_protobuf_handler(
+ self,
+ name: str,
+ *,
+ stream_input: bool = False,
+ stream_output: bool = False,
+ ) -> None:
+ if not stream_input and not stream_output:
+ await self._client.remove_unary_handler(name)
+ return
+
+ await self.remove_binary_stream_handler(name)
+
async def _add_protobuf_unary_handler(
self,
handle_name: str,
@@ -553,6 +624,9 @@ async def add_binary_stream_handler(
self._start_listening()
await self._client.stream_handler(name, handler, balanced)
+ async def remove_binary_stream_handler(self, name: str) -> None:
+ await self._client.remove_stream_handler(name)
+
async def call_binary_stream_handler(
self, peer_id: PeerID, handler_name: str
) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
@@ -580,12 +654,13 @@ def _terminate(self) -> None:
self._alive = False
if self._child is not None and self._child.returncode is None:
- self._child.terminate()
- logger.debug(f"Terminated p2pd with id = {self.peer_id}")
+ with suppress(ProcessLookupError):
+ self._child.terminate()
+ logger.debug(f"Terminated p2pd with id = {self.peer_id}")
- with suppress(FileNotFoundError):
+ with suppress(FileNotFoundError, TypeError):
os.remove(self._daemon_listen_maddr["unix"])
- with suppress(FileNotFoundError):
+ with suppress(FileNotFoundError, TypeError):
os.remove(self._client_listen_maddr["unix"])
@staticmethod
diff --git a/hivemind/p2p/p2p_daemon_bindings/__init__.py b/hivemind/p2p/p2p_daemon_bindings/__init__.py
index e69de29bb..7c50a2e08 100644
--- a/hivemind/p2p/p2p_daemon_bindings/__init__.py
+++ b/hivemind/p2p/p2p_daemon_bindings/__init__.py
@@ -0,0 +1,2 @@
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
+from hivemind.p2p.p2p_daemon_bindings.utils import P2PDaemonError, P2PHandlerError
diff --git a/hivemind/p2p/p2p_daemon_bindings/control.py b/hivemind/p2p/p2p_daemon_bindings/control.py
index 9b4376b79..4f229bbdb 100644
--- a/hivemind/p2p/p2p_daemon_bindings/control.py
+++ b/hivemind/p2p/p2p_daemon_bindings/control.py
@@ -12,7 +12,14 @@
from multiaddr import Multiaddr, protocols
from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
-from hivemind.p2p.p2p_daemon_bindings.utils import DispatchFailure, raise_if_failed, read_pbmsg_safe, write_pbmsg
+from hivemind.p2p.p2p_daemon_bindings.utils import (
+ DispatchFailure,
+ P2PDaemonError,
+ P2PHandlerError,
+ raise_if_failed,
+ read_pbmsg_safe,
+ write_pbmsg,
+)
from hivemind.proto import p2pd_pb2 as p2pd_pb
from hivemind.utils.logging import get_logger
@@ -27,6 +34,9 @@
logger = get_logger(__name__)
DEFAULT_MAX_MSG_SIZE = 4 * 1024**2
+MAX_UNARY_PAYLOAD_SIZE = DEFAULT_MAX_MSG_SIZE // 2
+# note: we check vs. 2x max message size to account for serialization overhead. The actual overhead is
+# typically smaller. We err on the side of streaming, because even 2MB messages can be streamed efficiently.
def parse_conn_protocol(maddr: Multiaddr) -> int:
@@ -246,20 +256,37 @@ async def _ensure_persistent_conn(self):
self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
- async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
+ async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False) -> None:
+ if proto in self.unary_handlers:
+ raise P2PDaemonError(f"Handler for protocol {proto} already registered")
+ self.unary_handlers[proto] = handler
+
call_id = uuid4()
+ req = p2pd_pb.PersistentConnectionRequest(
+ callId=call_id.bytes,
+ addUnaryHandler=p2pd_pb.AddUnaryHandlerRequest(proto=proto, balanced=balanced),
+ )
- add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto, balanced=balanced)
- req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
+ self._pending_calls[call_id] = asyncio.Future()
+ await self._pending_messages.put(req)
+ await self._pending_calls[call_id]
- if self.unary_handlers.get(proto):
- raise P2PDaemonError(f"Handler for protocol {proto} already registered")
- self.unary_handlers[proto] = handler
+ async def remove_unary_handler(self, proto: str) -> None:
+ if proto not in self.unary_handlers:
+ raise P2PDaemonError(f"Handler for protocol {proto} is not registered")
+
+ call_id = uuid4()
+ req = p2pd_pb.PersistentConnectionRequest(
+ callId=call_id.bytes,
+ removeUnaryHandler=p2pd_pb.RemoveUnaryHandlerRequest(proto=proto),
+ )
self._pending_calls[call_id] = asyncio.Future()
await self._pending_messages.put(req)
await self._pending_calls[call_id]
+ del self.unary_handlers[proto]
+
async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
call_id = uuid4()
call_unary_req = p2pd_pb.CallUnaryRequest(
@@ -359,13 +386,18 @@ async def stream_open(
return stream_info, reader, writer
async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: bool = False) -> None:
+ self.handlers[proto] = handler_cb
+
reader, writer = await self.daemon_connector.open_connection()
- listen_path_maddr_bytes = self.listen_maddr.to_bytes()
- stream_handler_req = p2pd_pb.StreamHandlerRequest(
- addr=listen_path_maddr_bytes, proto=[proto], balanced=balanced
+ req = p2pd_pb.Request(
+ type=p2pd_pb.Request.STREAM_HANDLER,
+ streamHandler=p2pd_pb.StreamHandlerRequest(
+ addr=self.listen_maddr.to_bytes(),
+ proto=[proto],
+ balanced=balanced,
+ ),
)
- req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req)
await write_pbmsg(writer, req)
resp = p2pd_pb.Response() # type: ignore
@@ -373,17 +405,21 @@ async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced:
writer.close()
raise_if_failed(resp)
- # if success, add the handler to the dict
- self.handlers[proto] = handler_cb
-
+ async def remove_stream_handler(self, proto: str) -> None:
+ reader, writer = await self.daemon_connector.open_connection()
-class P2PHandlerError(Exception):
- """
- Raised if remote handled a request with an exception
- """
+ req = p2pd_pb.Request(
+ type=p2pd_pb.Request.REMOVE_STREAM_HANDLER,
+ removeStreamHandler=p2pd_pb.RemoveStreamHandlerRequest(
+ addr=self.listen_maddr.to_bytes(),
+ proto=[proto],
+ ),
+ )
+ await write_pbmsg(writer, req)
+ resp = p2pd_pb.Response() # type: ignore
+ await read_pbmsg_safe(reader, resp)
+ writer.close()
+ raise_if_failed(resp)
-class P2PDaemonError(Exception):
- """
- Raised if daemon failed to handle request
- """
+ del self.handlers[proto]
diff --git a/hivemind/p2p/p2p_daemon_bindings/datastructures.py b/hivemind/p2p/p2p_daemon_bindings/datastructures.py
index 063f0ba46..920aa920b 100644
--- a/hivemind/p2p/p2p_daemon_bindings/datastructures.py
+++ b/hivemind/p2p/p2p_daemon_bindings/datastructures.py
@@ -9,31 +9,10 @@
import base58
import multihash
+from cryptography.hazmat.primitives import serialization
from multiaddr import Multiaddr, protocols
-from hivemind.proto import p2pd_pb2
-
-# NOTE: On inlining...
-# See: https://github.com/libp2p/specs/issues/138
-# NOTE: enabling to be interoperable w/ the Go implementation
-ENABLE_INLINING = True
-MAX_INLINE_KEY_LENGTH = 42
-
-IDENTITY_MULTIHASH_CODE = 0x00
-
-if ENABLE_INLINING:
-
- class IdentityHash:
- def __init__(self) -> None:
- self._digest = bytearray()
-
- def update(self, input: bytes) -> None:
- self._digest += input
-
- def digest(self) -> bytes:
- return self._digest
-
- multihash.FuncReg.register(IDENTITY_MULTIHASH_CODE, "identity", hash_new=IdentityHash)
+from hivemind.proto import crypto_pb2, p2pd_pb2
class PeerID:
@@ -88,6 +67,31 @@ def from_base58(cls, base58_id: str) -> "PeerID":
peer_id_bytes = base58.b58decode(base58_id)
return cls(peer_id_bytes)
+ @classmethod
+ def from_identity(cls, data: bytes) -> "PeerID":
+ """
+ See [1] for the specification of how this conversion should happen.
+
+ [1] https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md#peer-ids
+ """
+ key_data = crypto_pb2.PrivateKey.FromString(data).data
+ private_key = serialization.load_der_private_key(key_data, password=None)
+
+ encoded_public_key = private_key.public_key().public_bytes(
+ encoding=serialization.Encoding.DER,
+ format=serialization.PublicFormat.SubjectPublicKeyInfo,
+ )
+ encoded_public_key = crypto_pb2.PublicKey(
+ key_type=crypto_pb2.RSA,
+ data=encoded_public_key,
+ ).SerializeToString()
+
+ encoded_digest = multihash.encode(
+ hashlib.sha256(encoded_public_key).digest(),
+ multihash.coerce_code("sha2-256"),
+ )
+ return cls(encoded_digest)
+
def sha256_digest(data: Union[str, bytes]) -> bytes:
if isinstance(data, str):
diff --git a/hivemind/p2p/p2p_daemon_bindings/p2pclient.py b/hivemind/p2p/p2p_daemon_bindings/p2pclient.py
index e002e8aa5..7fee00477 100644
--- a/hivemind/p2p/p2p_daemon_bindings/p2pclient.py
+++ b/hivemind/p2p/p2p_daemon_bindings/p2pclient.py
@@ -47,7 +47,8 @@ async def create(
return client
def close(self) -> None:
- self.control.close()
+ if self.control is not None:
+ self.control.close()
def __del__(self):
self.close()
@@ -61,9 +62,12 @@ async def listen(self) -> AsyncIterator["Client"]:
async with self.control.listen():
yield self
- async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
+ async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False) -> None:
await self.control.add_unary_handler(proto, handler, balanced=balanced)
+ async def remove_unary_handler(self, proto: str) -> None:
+ await self.control.remove_unary_handler(proto)
+
async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
return await self.control.call_unary_handler(peer_id, proto, data)
@@ -114,3 +118,6 @@ async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced:
:return:
"""
await self.control.stream_handler(proto=proto, handler_cb=handler_cb, balanced=balanced)
+
+ async def remove_stream_handler(self, proto: str) -> None:
+ await self.control.remove_stream_handler(proto=proto)
diff --git a/hivemind/p2p/p2p_daemon_bindings/utils.py b/hivemind/p2p/p2p_daemon_bindings/utils.py
index 2a0d5b97c..c8ca87901 100644
--- a/hivemind/p2p/p2p_daemon_bindings/utils.py
+++ b/hivemind/p2p/p2p_daemon_bindings/utils.py
@@ -13,11 +13,23 @@
DEFAULT_MAX_BITS: int = 64
-class ControlFailure(Exception):
+class P2PHandlerError(Exception):
+ """
+ Raised if remote handled a request with an exception
+ """
+
+
+class P2PDaemonError(Exception):
+ """
+ Raised if daemon failed to handle request
+ """
+
+
+class ControlFailure(P2PDaemonError):
pass
-class DispatchFailure(Exception):
+class DispatchFailure(P2PDaemonError):
pass
diff --git a/hivemind/p2p/servicer.py b/hivemind/p2p/servicer.py
index 4ceb7bc9b..69dd4ce22 100644
--- a/hivemind/p2p/servicer.py
+++ b/hivemind/p2p/servicer.py
@@ -18,7 +18,7 @@ class RPCHandler:
class StubBase:
"""
- Base class for P2P RPC stubs. The interface mimicks gRPC stubs.
+ Base class for P2P RPC stubs. The interface mimics gRPC stubs.
Servicer derives stub classes for particular services (e.g. DHT, averager, etc.) from StubBase,
adding the necessary rpc_* methods. Calls to these methods are translated to calls to the remote peer.
@@ -32,7 +32,7 @@ def __init__(self, p2p: P2P, peer: PeerID, namespace: Optional[str]):
class ServicerBase:
"""
- Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimicks gRPC servicers.
+ Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimics gRPC servicers.
- ``add_p2p_handlers(self, p2p)`` registers all rpc_* methods of the derived class as P2P handlers, allowing
other peers to call them. It uses type annotations for the ``request`` parameter and the return value
@@ -124,6 +124,20 @@ async def add_p2p_handlers(
]
)
+ async def remove_p2p_handlers(self, p2p: P2P, *, namespace: Optional[str] = None) -> None:
+ self._collect_rpc_handlers()
+
+ await asyncio.gather(
+ *[
+ p2p.remove_protobuf_handler(
+ self._get_handle_name(namespace, handler.method_name),
+ stream_input=handler.stream_input,
+ stream_output=handler.stream_output,
+ )
+ for handler in self._rpc_handlers
+ ]
+ )
+
@classmethod
def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:
cls._collect_rpc_handlers()
diff --git a/hivemind/proto/auth_pb2.py b/hivemind/proto/auth_pb2.py
new file mode 100644
index 000000000..39d6f65ff
--- /dev/null
+++ b/hivemind/proto/auth_pb2.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: auth.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\nauth.proto\"_\n\x0b\x41\x63\x63\x65ssToken\x12\x10\n\x08username\x18\x01 \x01(\t\x12\x12\n\npublic_key\x18\x02 \x01(\x0c\x12\x17\n\x0f\x65xpiration_time\x18\x03 \x01(\t\x12\x11\n\tsignature\x18\x04 \x01(\x0c\"\x88\x01\n\x0fRequestAuthInfo\x12)\n\x13\x63lient_access_token\x18\x01 \x01(\x0b\x32\x0c.AccessToken\x12\x1a\n\x12service_public_key\x18\x02 \x01(\x0c\x12\x0c\n\x04time\x18\x03 \x01(\x01\x12\r\n\x05nonce\x18\x04 \x01(\x0c\x12\x11\n\tsignature\x18\x05 \x01(\x0c\"`\n\x10ResponseAuthInfo\x12*\n\x14service_access_token\x18\x01 \x01(\x0b\x32\x0c.AccessToken\x12\r\n\x05nonce\x18\x02 \x01(\x0c\x12\x11\n\tsignature\x18\x03 \x01(\x0c\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'auth_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals['_ACCESSTOKEN']._serialized_start=14
+ _globals['_ACCESSTOKEN']._serialized_end=109
+ _globals['_REQUESTAUTHINFO']._serialized_start=112
+ _globals['_REQUESTAUTHINFO']._serialized_end=248
+ _globals['_RESPONSEAUTHINFO']._serialized_start=250
+ _globals['_RESPONSEAUTHINFO']._serialized_end=346
+# @@protoc_insertion_point(module_scope)
diff --git a/hivemind/proto/averaging_pb2.py b/hivemind/proto/averaging_pb2.py
new file mode 100644
index 000000000..bc85e88ec
--- /dev/null
+++ b/hivemind/proto/averaging_pb2.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: averaging.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+from . import runtime_pb2 as runtime__pb2
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x61veraging.proto\x1a\rruntime.proto\"n\n\x0bJoinRequest\x12\x13\n\x0bschema_hash\x18\x02 \x01(\x0c\x12\x12\n\nexpiration\x18\x03 \x01(\x01\x12\x0e\n\x06gather\x18\x04 \x01(\x0c\x12\x13\n\x0b\x63lient_mode\x18\x05 \x01(\x08\x12\x11\n\tgroup_key\x18\x06 \x01(\t\"\x87\x01\n\x11MessageFromLeader\x12\x1a\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x0c.MessageCode\x12\x10\n\x08group_id\x18\x02 \x01(\x0c\x12\x18\n\x10suggested_leader\x18\x03 \x01(\x0c\x12\x18\n\x10ordered_peer_ids\x18\x04 \x03(\x0c\x12\x10\n\x08gathered\x18\x05 \x03(\x0c\"|\n\rAveragingData\x12\x1a\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x0c.MessageCode\x12\x10\n\x08group_id\x18\x02 \x01(\x0c\x12\x0f\n\x07peer_id\x18\x03 \x01(\x0c\x12\x1c\n\x0btensor_part\x18\x04 \x01(\x0b\x32\x07.Tensor\x12\x0e\n\x06weight\x18\x05 \x01(\x01\"\x11\n\x0f\x44ownloadRequest\">\n\x0c\x44ownloadData\x12\x10\n\x08metadata\x18\x01 \x01(\x0c\x12\x1c\n\x0btensor_part\x18\x02 \x01(\x0b\x32\x07.Tensor*\x86\x03\n\x0bMessageCode\x12\x0b\n\x07NO_CODE\x10\x00\x12\x10\n\x0cREQUEST_JOIN\x10\x01\x12\x0c\n\x08\x41\x43\x43\x45PTED\x10\x02\x12\x13\n\x0f\x42\x45GIN_ALLREDUCE\x10\x03\x12\x16\n\x12PART_FOR_AVERAGING\x10\x04\x12\x11\n\rAVERAGED_PART\x10\x05\x12\x10\n\x0cNOT_DECLARED\x10\x06\x12\x10\n\x0cNOT_A_LEADER\x10\x07\x12\x17\n\x13\x42\x41\x44_EXPIRATION_TIME\x10\x08\x12\x13\n\x0f\x42\x41\x44_SCHEMA_HASH\x10\t\x12\x10\n\x0c\x42\x41\x44_GROUP_ID\x10\n\x12\x15\n\x11\x44UPLICATE_PEER_ID\x10\x0b\x12\x11\n\rGROUP_IS_FULL\x10\x0c\x12\x19\n\x15NOT_LOOKING_FOR_GROUP\x10\r\x12\x16\n\x12PROTOCOL_VIOLATION\x10\x0e\x12\x12\n\x0eINTERNAL_ERROR\x10\x0f\x12\r\n\tCANCELLED\x10\x10\x12\x13\n\x0fGROUP_DISBANDED\x10\x11\x12\x11\n\rBAD_GROUP_KEY\x10\x12\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'averaging_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals['_MESSAGECODE']._serialized_start=494
+ _globals['_MESSAGECODE']._serialized_end=884
+ _globals['_JOINREQUEST']._serialized_start=34
+ _globals['_JOINREQUEST']._serialized_end=144
+ _globals['_MESSAGEFROMLEADER']._serialized_start=147
+ _globals['_MESSAGEFROMLEADER']._serialized_end=282
+ _globals['_AVERAGINGDATA']._serialized_start=284
+ _globals['_AVERAGINGDATA']._serialized_end=408
+ _globals['_DOWNLOADREQUEST']._serialized_start=410
+ _globals['_DOWNLOADREQUEST']._serialized_end=427
+ _globals['_DOWNLOADDATA']._serialized_start=429
+ _globals['_DOWNLOADDATA']._serialized_end=491
+# @@protoc_insertion_point(module_scope)
diff --git a/hivemind/proto/crypto_pb2.py b/hivemind/proto/crypto_pb2.py
new file mode 100644
index 000000000..09a408247
--- /dev/null
+++ b/hivemind/proto/crypto_pb2.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: crypto.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x63rypto.proto\x12\tcrypto.pb\"?\n\tPublicKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c\"@\n\nPrivateKey\x12$\n\x08key_type\x18\x01 \x02(\x0e\x32\x12.crypto.pb.KeyType\x12\x0c\n\x04\x64\x61ta\x18\x02 \x02(\x0c*9\n\x07KeyType\x12\x07\n\x03RSA\x10\x00\x12\x0b\n\x07\x45\x64\x32\x35\x35\x31\x39\x10\x01\x12\r\n\tSecp256k1\x10\x02\x12\t\n\x05\x45\x43\x44SA\x10\x03')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'crypto_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals['_KEYTYPE']._serialized_start=158
+ _globals['_KEYTYPE']._serialized_end=215
+ _globals['_PUBLICKEY']._serialized_start=27
+ _globals['_PUBLICKEY']._serialized_end=90
+ _globals['_PRIVATEKEY']._serialized_start=92
+ _globals['_PRIVATEKEY']._serialized_end=156
+# @@protoc_insertion_point(module_scope)
diff --git a/hivemind/proto/dht_pb2.py b/hivemind/proto/dht_pb2.py
new file mode 100644
index 000000000..864bb3c9f
--- /dev/null
+++ b/hivemind/proto/dht_pb2.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: dht.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+from . import auth_pb2 as auth__pb2
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tdht.proto\x1a\nauth.proto\"\x1b\n\x08NodeInfo\x12\x0f\n\x07node_id\x18\x01 \x01(\x0c\"X\n\x0bPingRequest\x12\x1e\n\x04\x61uth\x18\x01 \x01(\x0b\x32\x10.RequestAuthInfo\x12\x17\n\x04peer\x18\x02 \x01(\x0b\x32\t.NodeInfo\x12\x10\n\x08validate\x18\x03 \x01(\x08\"m\n\x0cPingResponse\x12\x1f\n\x04\x61uth\x18\x01 \x01(\x0b\x32\x11.ResponseAuthInfo\x12\x17\n\x04peer\x18\x02 \x01(\x0b\x32\t.NodeInfo\x12\x10\n\x08\x64ht_time\x18\x04 \x01(\x01\x12\x11\n\tavailable\x18\x05 \x01(\x08\"\xa1\x01\n\x0cStoreRequest\x12\x1e\n\x04\x61uth\x18\x01 \x01(\x0b\x32\x10.RequestAuthInfo\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\x12\x0f\n\x07subkeys\x18\x03 \x03(\x0c\x12\x0e\n\x06values\x18\x04 \x03(\x0c\x12\x17\n\x0f\x65xpiration_time\x18\x05 \x03(\x01\x12\x10\n\x08in_cache\x18\x06 \x03(\x08\x12\x17\n\x04peer\x18\x07 \x01(\x0b\x32\t.NodeInfo\"[\n\rStoreResponse\x12\x1f\n\x04\x61uth\x18\x01 \x01(\x0b\x32\x11.ResponseAuthInfo\x12\x10\n\x08store_ok\x18\x02 \x03(\x08\x12\x17\n\x04peer\x18\x03 \x01(\x0b\x32\t.NodeInfo\"T\n\x0b\x46indRequest\x12\x1e\n\x04\x61uth\x18\x01 \x01(\x0b\x32\x10.RequestAuthInfo\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\x12\x17\n\x04peer\x18\x03 \x01(\x0b\x32\t.NodeInfo\"\x83\x01\n\nFindResult\x12\x19\n\x04type\x18\x01 \x01(\x0e\x32\x0b.ResultType\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x17\n\x0f\x65xpiration_time\x18\x03 \x01(\x01\x12\x18\n\x10nearest_node_ids\x18\x04 \x03(\x0c\x12\x18\n\x10nearest_peer_ids\x18\x05 \x03(\x0c\"f\n\x0c\x46indResponse\x12\x1f\n\x04\x61uth\x18\x01 \x01(\x0b\x32\x11.ResponseAuthInfo\x12\x1c\n\x07results\x18\x02 \x03(\x0b\x32\x0b.FindResult\x12\x17\n\x04peer\x18\x03 \x01(\x0b\x32\t.NodeInfo*D\n\nResultType\x12\r\n\tNOT_FOUND\x10\x00\x12\x11\n\rFOUND_REGULAR\x10\x01\x12\x14\n\x10\x46OUND_DICTIONARY\x10\x02\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'dht_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals['_RESULTTYPE']._serialized_start=836
+ _globals['_RESULTTYPE']._serialized_end=904
+ _globals['_NODEINFO']._serialized_start=25
+ _globals['_NODEINFO']._serialized_end=52
+ _globals['_PINGREQUEST']._serialized_start=54
+ _globals['_PINGREQUEST']._serialized_end=142
+ _globals['_PINGRESPONSE']._serialized_start=144
+ _globals['_PINGRESPONSE']._serialized_end=253
+ _globals['_STOREREQUEST']._serialized_start=256
+ _globals['_STOREREQUEST']._serialized_end=417
+ _globals['_STORERESPONSE']._serialized_start=419
+ _globals['_STORERESPONSE']._serialized_end=510
+ _globals['_FINDREQUEST']._serialized_start=512
+ _globals['_FINDREQUEST']._serialized_end=596
+ _globals['_FINDRESULT']._serialized_start=599
+ _globals['_FINDRESULT']._serialized_end=730
+ _globals['_FINDRESPONSE']._serialized_start=732
+ _globals['_FINDRESPONSE']._serialized_end=834
+# @@protoc_insertion_point(module_scope)
diff --git a/hivemind/proto/generate.sh b/hivemind/proto/generate.sh
new file mode 100644
index 000000000..e932c088f
--- /dev/null
+++ b/hivemind/proto/generate.sh
@@ -0,0 +1 @@
+python -m grpc_tools.protoc --proto_path=. --python_out=. auth.proto averaging.proto crypto.proto dht.proto p2pd.proto runtime.proto test.proto
diff --git a/hivemind/proto/p2pd.proto b/hivemind/proto/p2pd.proto
index b0962d62a..b2ae15107 100644
--- a/hivemind/proto/p2pd.proto
+++ b/hivemind/proto/p2pd.proto
@@ -12,12 +12,12 @@ message Request {
CONNECT = 1;
STREAM_OPEN = 2;
STREAM_HANDLER = 3;
+ REMOVE_STREAM_HANDLER = 10;
DHT = 4;
LIST_PEERS = 5;
CONNMANAGER = 6;
DISCONNECT = 7;
PUBSUB = 8;
-
PERSISTENT_CONN_UPGRADE = 9;
}
@@ -26,6 +26,7 @@ message Request {
optional ConnectRequest connect = 2;
optional StreamOpenRequest streamOpen = 3;
optional StreamHandlerRequest streamHandler = 4;
+ optional RemoveStreamHandlerRequest removeStreamHandler = 9;
optional DHTRequest dht = 5;
optional ConnManagerRequest connManager = 6;
optional DisconnectRequest disconnect = 7;
@@ -52,6 +53,7 @@ message PersistentConnectionRequest {
oneof message {
AddUnaryHandlerRequest addUnaryHandler = 2;
+ RemoveUnaryHandlerRequest removeUnaryHandler = 6;
CallUnaryRequest callUnary = 3;
CallUnaryResponse unaryResponse = 4;
Cancel cancel = 5;
@@ -93,6 +95,11 @@ message StreamHandlerRequest {
required bool balanced = 3;
}
+message RemoveStreamHandlerRequest {
+ required bytes addr = 1;
+ repeated string proto = 2;
+}
+
message ErrorResponse {
required string msg = 1;
}
@@ -205,6 +212,10 @@ message AddUnaryHandlerRequest {
required bool balanced = 2;
}
+message RemoveUnaryHandlerRequest {
+ required string proto = 1;
+}
+
message DaemonError {
optional string message = 1;
}
diff --git a/hivemind/proto/p2pd_pb2.py b/hivemind/proto/p2pd_pb2.py
new file mode 100644
index 000000000..0f42df008
--- /dev/null
+++ b/hivemind/proto/p2pd_pb2.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: p2pd.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\np2pd.proto\x12\x11p2pclient.p2pd.pb\"\xc9\x05\n\x07Request\x12-\n\x04type\x18\x01 \x02(\x0e\x32\x1f.p2pclient.p2pd.pb.Request.Type\x12\x32\n\x07\x63onnect\x18\x02 \x01(\x0b\x32!.p2pclient.p2pd.pb.ConnectRequest\x12\x38\n\nstreamOpen\x18\x03 \x01(\x0b\x32$.p2pclient.p2pd.pb.StreamOpenRequest\x12>\n\rstreamHandler\x18\x04 \x01(\x0b\x32\'.p2pclient.p2pd.pb.StreamHandlerRequest\x12J\n\x13removeStreamHandler\x18\t \x01(\x0b\x32-.p2pclient.p2pd.pb.RemoveStreamHandlerRequest\x12*\n\x03\x64ht\x18\x05 \x01(\x0b\x32\x1d.p2pclient.p2pd.pb.DHTRequest\x12:\n\x0b\x63onnManager\x18\x06 \x01(\x0b\x32%.p2pclient.p2pd.pb.ConnManagerRequest\x12\x38\n\ndisconnect\x18\x07 \x01(\x0b\x32$.p2pclient.p2pd.pb.DisconnectRequest\x12,\n\x06pubsub\x18\x08 \x01(\x0b\x32\x1c.p2pclient.p2pd.pb.PSRequest\"\xc4\x01\n\x04Type\x12\x0c\n\x08IDENTIFY\x10\x00\x12\x0b\n\x07\x43ONNECT\x10\x01\x12\x0f\n\x0bSTREAM_OPEN\x10\x02\x12\x12\n\x0eSTREAM_HANDLER\x10\x03\x12\x19\n\x15REMOVE_STREAM_HANDLER\x10\n\x12\x07\n\x03\x44HT\x10\x04\x12\x0e\n\nLIST_PEERS\x10\x05\x12\x0f\n\x0b\x43ONNMANAGER\x10\x06\x12\x0e\n\nDISCONNECT\x10\x07\x12\n\n\x06PUBSUB\x10\x08\x12\x1b\n\x17PERSISTENT_CONN_UPGRADE\x10\t\"\xf8\x02\n\x08Response\x12.\n\x04type\x18\x01 \x02(\x0e\x32 .p2pclient.p2pd.pb.Response.Type\x12/\n\x05\x65rror\x18\x02 \x01(\x0b\x32 .p2pclient.p2pd.pb.ErrorResponse\x12\x31\n\nstreamInfo\x18\x03 \x01(\x0b\x32\x1d.p2pclient.p2pd.pb.StreamInfo\x12\x35\n\x08identify\x18\x04 \x01(\x0b\x32#.p2pclient.p2pd.pb.IdentifyResponse\x12+\n\x03\x64ht\x18\x05 \x01(\x0b\x32\x1e.p2pclient.p2pd.pb.DHTResponse\x12*\n\x05peers\x18\x06 \x03(\x0b\x32\x1b.p2pclient.p2pd.pb.PeerInfo\x12-\n\x06pubsub\x18\x07 \x01(\x0b\x32\x1d.p2pclient.p2pd.pb.PSResponse\"\x19\n\x04Type\x12\x06\n\x02OK\x10\x00\x12\t\n\x05\x45RROR\x10\x01\"\xf0\x02\n\x1bPersistentConnectionRequest\x12\x0e\n\x06\x63\x61llId\x18\x01 \x02(\x0c\x12\x44\n\x0f\x61\x64\x64UnaryHandler\x18\x02 \x01(\x0b\x32).p2pclient.p2pd.pb.AddUnaryHandlerRequestH\x00\x12J\n\x12removeUnaryHandler\x18\x06 \x01(\x0b\x32,.p2pclient.p2pd.pb.RemoveUnaryHandlerRequestH\x00\x12\x38\n\tcallUnary\x18\x03 \x01(\x0b\x32#.p2pclient.p2pd.pb.CallUnaryRequestH\x00\x12=\n\runaryResponse\x18\x04 \x01(\x0b\x32$.p2pclient.p2pd.pb.CallUnaryResponseH\x00\x12+\n\x06\x63\x61ncel\x18\x05 \x01(\x0b\x32\x19.p2pclient.p2pd.pb.CancelH\x00\x42\t\n\x07message\"\xa0\x02\n\x1cPersistentConnectionResponse\x12\x0e\n\x06\x63\x61llId\x18\x01 \x02(\x0c\x12\x41\n\x11\x63\x61llUnaryResponse\x18\x02 \x01(\x0b\x32$.p2pclient.p2pd.pb.CallUnaryResponseH\x00\x12>\n\x0frequestHandling\x18\x03 \x01(\x0b\x32#.p2pclient.p2pd.pb.CallUnaryRequestH\x00\x12\x35\n\x0b\x64\x61\x65monError\x18\x04 \x01(\x0b\x32\x1e.p2pclient.p2pd.pb.DaemonErrorH\x00\x12+\n\x06\x63\x61ncel\x18\x05 \x01(\x0b\x32\x19.p2pclient.p2pd.pb.CancelH\x00\x42\t\n\x07message\"-\n\x10IdentifyResponse\x12\n\n\x02id\x18\x01 \x02(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\">\n\x0e\x43onnectRequest\x12\x0c\n\x04peer\x18\x01 \x02(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12\x0f\n\x07timeout\x18\x03 \x01(\x03\"A\n\x11StreamOpenRequest\x12\x0c\n\x04peer\x18\x01 \x02(\x0c\x12\r\n\x05proto\x18\x02 \x03(\t\x12\x0f\n\x07timeout\x18\x03 \x01(\x03\"E\n\x14StreamHandlerRequest\x12\x0c\n\x04\x61\x64\x64r\x18\x01 \x02(\x0c\x12\r\n\x05proto\x18\x02 \x03(\t\x12\x10\n\x08\x62\x61lanced\x18\x03 \x02(\x08\"9\n\x1aRemoveStreamHandlerRequest\x12\x0c\n\x04\x61\x64\x64r\x18\x01 \x02(\x0c\x12\r\n\x05proto\x18\x02 \x03(\t\"\x1c\n\rErrorResponse\x12\x0b\n\x03msg\x18\x01 \x02(\t\"7\n\nStreamInfo\x12\x0c\n\x04peer\x18\x01 \x02(\x0c\x12\x0c\n\x04\x61\x64\x64r\x18\x02 \x02(\x0c\x12\r\n\x05proto\x18\x03 \x02(\t\"\xcb\x02\n\nDHTRequest\x12\x30\n\x04type\x18\x01 \x02(\x0e\x32\".p2pclient.p2pd.pb.DHTRequest.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12\x0b\n\x03\x63id\x18\x03 \x01(\x0c\x12\x0b\n\x03key\x18\x04 \x01(\x0c\x12\r\n\x05value\x18\x05 \x01(\x0c\x12\r\n\x05\x63ount\x18\x06 \x01(\x05\x12\x0f\n\x07timeout\x18\x07 \x01(\x03\"\xb3\x01\n\x04Type\x12\r\n\tFIND_PEER\x10\x00\x12 \n\x1c\x46IND_PEERS_CONNECTED_TO_PEER\x10\x01\x12\x12\n\x0e\x46IND_PROVIDERS\x10\x02\x12\x15\n\x11GET_CLOSEST_PEERS\x10\x03\x12\x12\n\x0eGET_PUBLIC_KEY\x10\x04\x12\r\n\tGET_VALUE\x10\x05\x12\x10\n\x0cSEARCH_VALUE\x10\x06\x12\r\n\tPUT_VALUE\x10\x07\x12\x0b\n\x07PROVIDE\x10\x08\"\xa1\x01\n\x0b\x44HTResponse\x12\x31\n\x04type\x18\x01 \x02(\x0e\x32#.p2pclient.p2pd.pb.DHTResponse.Type\x12)\n\x04peer\x18\x02 \x01(\x0b\x32\x1b.p2pclient.p2pd.pb.PeerInfo\x12\r\n\x05value\x18\x03 \x01(\x0c\"%\n\x04Type\x12\t\n\x05\x42\x45GIN\x10\x00\x12\t\n\x05VALUE\x10\x01\x12\x07\n\x03\x45ND\x10\x02\"%\n\x08PeerInfo\x12\n\n\x02id\x18\x01 \x02(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\"\xa9\x01\n\x12\x43onnManagerRequest\x12\x38\n\x04type\x18\x01 \x02(\x0e\x32*.p2pclient.p2pd.pb.ConnManagerRequest.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12\x0b\n\x03tag\x18\x03 \x01(\t\x12\x0e\n\x06weight\x18\x04 \x01(\x03\".\n\x04Type\x12\x0c\n\x08TAG_PEER\x10\x00\x12\x0e\n\nUNTAG_PEER\x10\x01\x12\x08\n\x04TRIM\x10\x02\"!\n\x11\x44isconnectRequest\x12\x0c\n\x04peer\x18\x01 \x02(\x0c\"\x9d\x01\n\tPSRequest\x12/\n\x04type\x18\x01 \x02(\x0e\x32!.p2pclient.p2pd.pb.PSRequest.Type\x12\r\n\x05topic\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"B\n\x04Type\x12\x0e\n\nGET_TOPICS\x10\x00\x12\x0e\n\nLIST_PEERS\x10\x01\x12\x0b\n\x07PUBLISH\x10\x02\x12\r\n\tSUBSCRIBE\x10\x03\"h\n\tPSMessage\x12\x0c\n\x04\x66rom\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"-\n\nPSResponse\x12\x0e\n\x06topics\x18\x01 \x03(\t\x12\x0f\n\x07peerIDs\x18\x02 \x03(\x0c\"=\n\x10\x43\x61llUnaryRequest\x12\x0c\n\x04peer\x18\x01 \x02(\x0c\x12\r\n\x05proto\x18\x02 \x02(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x02(\x0c\"B\n\x11\x43\x61llUnaryResponse\x12\x12\n\x08response\x18\x01 \x01(\x0cH\x00\x12\x0f\n\x05\x65rror\x18\x02 \x01(\x0cH\x00\x42\x08\n\x06result\"9\n\x16\x41\x64\x64UnaryHandlerRequest\x12\r\n\x05proto\x18\x01 \x02(\t\x12\x10\n\x08\x62\x61lanced\x18\x02 \x02(\x08\"*\n\x19RemoveUnaryHandlerRequest\x12\r\n\x05proto\x18\x01 \x02(\t\"\x1e\n\x0b\x44\x61\x65monError\x12\x0f\n\x07message\x18\x01 \x01(\t\"\x08\n\x06\x43\x61ncel\"\x1b\n\x08RPCError\x12\x0f\n\x07message\x18\x01 \x01(\t')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'p2pd_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals['_REQUEST']._serialized_start=34
+ _globals['_REQUEST']._serialized_end=747
+ _globals['_REQUEST_TYPE']._serialized_start=551
+ _globals['_REQUEST_TYPE']._serialized_end=747
+ _globals['_RESPONSE']._serialized_start=750
+ _globals['_RESPONSE']._serialized_end=1126
+ _globals['_RESPONSE_TYPE']._serialized_start=1101
+ _globals['_RESPONSE_TYPE']._serialized_end=1126
+ _globals['_PERSISTENTCONNECTIONREQUEST']._serialized_start=1129
+ _globals['_PERSISTENTCONNECTIONREQUEST']._serialized_end=1497
+ _globals['_PERSISTENTCONNECTIONRESPONSE']._serialized_start=1500
+ _globals['_PERSISTENTCONNECTIONRESPONSE']._serialized_end=1788
+ _globals['_IDENTIFYRESPONSE']._serialized_start=1790
+ _globals['_IDENTIFYRESPONSE']._serialized_end=1835
+ _globals['_CONNECTREQUEST']._serialized_start=1837
+ _globals['_CONNECTREQUEST']._serialized_end=1899
+ _globals['_STREAMOPENREQUEST']._serialized_start=1901
+ _globals['_STREAMOPENREQUEST']._serialized_end=1966
+ _globals['_STREAMHANDLERREQUEST']._serialized_start=1968
+ _globals['_STREAMHANDLERREQUEST']._serialized_end=2037
+ _globals['_REMOVESTREAMHANDLERREQUEST']._serialized_start=2039
+ _globals['_REMOVESTREAMHANDLERREQUEST']._serialized_end=2096
+ _globals['_ERRORRESPONSE']._serialized_start=2098
+ _globals['_ERRORRESPONSE']._serialized_end=2126
+ _globals['_STREAMINFO']._serialized_start=2128
+ _globals['_STREAMINFO']._serialized_end=2183
+ _globals['_DHTREQUEST']._serialized_start=2186
+ _globals['_DHTREQUEST']._serialized_end=2517
+ _globals['_DHTREQUEST_TYPE']._serialized_start=2338
+ _globals['_DHTREQUEST_TYPE']._serialized_end=2517
+ _globals['_DHTRESPONSE']._serialized_start=2520
+ _globals['_DHTRESPONSE']._serialized_end=2681
+ _globals['_DHTRESPONSE_TYPE']._serialized_start=2644
+ _globals['_DHTRESPONSE_TYPE']._serialized_end=2681
+ _globals['_PEERINFO']._serialized_start=2683
+ _globals['_PEERINFO']._serialized_end=2720
+ _globals['_CONNMANAGERREQUEST']._serialized_start=2723
+ _globals['_CONNMANAGERREQUEST']._serialized_end=2892
+ _globals['_CONNMANAGERREQUEST_TYPE']._serialized_start=2846
+ _globals['_CONNMANAGERREQUEST_TYPE']._serialized_end=2892
+ _globals['_DISCONNECTREQUEST']._serialized_start=2894
+ _globals['_DISCONNECTREQUEST']._serialized_end=2927
+ _globals['_PSREQUEST']._serialized_start=2930
+ _globals['_PSREQUEST']._serialized_end=3087
+ _globals['_PSREQUEST_TYPE']._serialized_start=3021
+ _globals['_PSREQUEST_TYPE']._serialized_end=3087
+ _globals['_PSMESSAGE']._serialized_start=3089
+ _globals['_PSMESSAGE']._serialized_end=3193
+ _globals['_PSRESPONSE']._serialized_start=3195
+ _globals['_PSRESPONSE']._serialized_end=3240
+ _globals['_CALLUNARYREQUEST']._serialized_start=3242
+ _globals['_CALLUNARYREQUEST']._serialized_end=3303
+ _globals['_CALLUNARYRESPONSE']._serialized_start=3305
+ _globals['_CALLUNARYRESPONSE']._serialized_end=3371
+ _globals['_ADDUNARYHANDLERREQUEST']._serialized_start=3373
+ _globals['_ADDUNARYHANDLERREQUEST']._serialized_end=3430
+ _globals['_REMOVEUNARYHANDLERREQUEST']._serialized_start=3432
+ _globals['_REMOVEUNARYHANDLERREQUEST']._serialized_end=3474
+ _globals['_DAEMONERROR']._serialized_start=3476
+ _globals['_DAEMONERROR']._serialized_end=3506
+ _globals['_CANCEL']._serialized_start=3508
+ _globals['_CANCEL']._serialized_end=3516
+ _globals['_RPCERROR']._serialized_start=3518
+ _globals['_RPCERROR']._serialized_end=3545
+# @@protoc_insertion_point(module_scope)
diff --git a/hivemind/proto/runtime.proto b/hivemind/proto/runtime.proto
index 14704113f..bba7a268c 100644
--- a/hivemind/proto/runtime.proto
+++ b/hivemind/proto/runtime.proto
@@ -12,10 +12,12 @@ message ExpertInfo {
message ExpertRequest {
string uid = 1;
repeated Tensor tensors = 2;
+ bytes metadata = 3;
}
message ExpertResponse {
repeated Tensor tensors = 2;
+ bytes metadata = 3;
}
enum CompressionType{
@@ -24,6 +26,7 @@ enum CompressionType{
FLOAT16 = 2;
QUANTILE_8BIT = 3;
UNIFORM_8BIT = 4;
+ BLOCKWISE_8BIT = 5;
}
message Tensor {
diff --git a/hivemind/proto/runtime_pb2.py b/hivemind/proto/runtime_pb2.py
new file mode 100644
index 000000000..3ed17b3d8
--- /dev/null
+++ b/hivemind/proto/runtime_pb2.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: runtime.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rruntime.proto\"\x18\n\tExpertUID\x12\x0b\n\x03uid\x18\x01 \x01(\t\"%\n\nExpertInfo\x12\x17\n\x0fserialized_info\x18\x01 \x01(\x0c\"H\n\rExpertRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x18\n\x07tensors\x18\x02 \x03(\x0b\x32\x07.Tensor\x12\x10\n\x08metadata\x18\x03 \x01(\x0c\"<\n\x0e\x45xpertResponse\x12\x18\n\x07tensors\x18\x02 \x03(\x0b\x32\x07.Tensor\x12\x10\n\x08metadata\x18\x03 \x01(\x0c\"\x83\x01\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\x0c\n\x04size\x18\x02 \x03(\r\x12\x15\n\rrequires_grad\x18\x03 \x01(\x08\x12\r\n\x05\x64type\x18\x04 \x01(\t\x12%\n\x0b\x63ompression\x18\x05 \x01(\x0e\x32\x10.CompressionType\x12\x0e\n\x06\x63hunks\x18\x06 \x01(\x05*t\n\x0f\x43ompressionType\x12\x08\n\x04NONE\x10\x00\x12\x11\n\rMEANSTD_16BIT\x10\x01\x12\x0b\n\x07\x46LOAT16\x10\x02\x12\x11\n\rQUANTILE_8BIT\x10\x03\x12\x10\n\x0cUNIFORM_8BIT\x10\x04\x12\x12\n\x0e\x42LOCKWISE_8BIT\x10\x05\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'runtime_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals['_COMPRESSIONTYPE']._serialized_start=352
+ _globals['_COMPRESSIONTYPE']._serialized_end=468
+ _globals['_EXPERTUID']._serialized_start=17
+ _globals['_EXPERTUID']._serialized_end=41
+ _globals['_EXPERTINFO']._serialized_start=43
+ _globals['_EXPERTINFO']._serialized_end=80
+ _globals['_EXPERTREQUEST']._serialized_start=82
+ _globals['_EXPERTREQUEST']._serialized_end=154
+ _globals['_EXPERTRESPONSE']._serialized_start=156
+ _globals['_EXPERTRESPONSE']._serialized_end=216
+ _globals['_TENSOR']._serialized_start=219
+ _globals['_TENSOR']._serialized_end=350
+# @@protoc_insertion_point(module_scope)
diff --git a/hivemind/proto/test_pb2.py b/hivemind/proto/test_pb2.py
new file mode 100644
index 000000000..da89acd62
--- /dev/null
+++ b/hivemind/proto/test_pb2.py
@@ -0,0 +1,28 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: test.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ntest.proto\"\x1d\n\x0bTestRequest\x12\x0e\n\x06number\x18\x01 \x01(\x05\"\x1e\n\x0cTestResponse\x12\x0e\n\x06number\x18\x01 \x01(\x05\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'test_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals['_TESTREQUEST']._serialized_start=14
+ _globals['_TESTREQUEST']._serialized_end=43
+ _globals['_TESTRESPONSE']._serialized_start=45
+ _globals['_TESTRESPONSE']._serialized_end=75
+# @@protoc_insertion_point(module_scope)
diff --git a/hivemind/utils/asyncio.py b/hivemind/utils/asyncio.py
index e86f8a15c..af2af1aca 100644
--- a/hivemind/utils/asyncio.py
+++ b/hivemind/utils/asyncio.py
@@ -1,5 +1,7 @@
import asyncio
import concurrent.futures
+import multiprocessing as mp
+import os
from concurrent.futures import ThreadPoolExecutor
from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar, Union
@@ -167,12 +169,25 @@ async def attach_event_on_finished(iterable: AsyncIterable[T], event: asyncio.Ev
class _AsyncContextWrapper(AbstractAsyncContextManager):
"""Wrapper for a non-async context manager that allows entering and exiting it in EventLoop-friendly manner"""
+ EXECUTOR_PID = None
+ CONTEXT_EXECUTOR = None
+ EXECUTOR_LOCK = mp.Lock()
+
def __init__(self, context: AbstractContextManager):
self._context = context
+ @classmethod
+ def get_process_wide_executor(cls):
+ if os.getpid() != cls.EXECUTOR_PID:
+ with cls.EXECUTOR_LOCK:
+ if os.getpid() != cls.EXECUTOR_PID:
+ cls.CONTEXT_EXECUTOR = ThreadPoolExecutor(max_workers=float("inf"))
+ cls.EXECUTOR_PID = os.getpid()
+ return cls.CONTEXT_EXECUTOR
+
async def __aenter__(self):
loop = asyncio.get_event_loop()
- return await loop.run_in_executor(None, self._context.__enter__)
+ return await loop.run_in_executor(self.get_process_wide_executor(), self._context.__enter__)
async def __aexit__(self, exc_type, exc_value, traceback):
return self._context.__exit__(exc_type, exc_value, traceback)
diff --git a/hivemind/utils/logging.py b/hivemind/utils/logging.py
index b42052e48..79c347122 100644
--- a/hivemind/utils/logging.py
+++ b/hivemind/utils/logging.py
@@ -3,20 +3,32 @@
import sys
import threading
from enum import Enum
-from typing import Optional, Union
+from typing import Any, Optional, Union
-logging.addLevelName(logging.WARNING, "WARN")
+def in_ipython() -> bool:
+ """Check if the code is run in IPython, Jupyter, or Colab"""
+
+ try:
+ __IPYTHON__
+ return True
+ except NameError:
+ return False
+
+
+logging.addLevelName(logging.WARNING, "WARN")
loglevel = os.getenv("HIVEMIND_LOGLEVEL", "INFO")
+TRUE_CONSTANTS = ["TRUE", "1"]
+
_env_colors = os.getenv("HIVEMIND_COLORS")
if _env_colors is not None:
- use_colors = _env_colors.lower() == "true"
+ use_colors = _env_colors.upper() in TRUE_CONSTANTS
else:
- use_colors = sys.stderr.isatty()
+ use_colors = sys.stderr.isatty() or in_ipython()
-_env_log_caller = os.getenv("HIVEMIND_ALWAYS_LOG_CALLER")
-always_log_caller = _env_log_caller is not None and _env_log_caller.lower() == "true"
+_env_log_caller = os.getenv("HIVEMIND_ALWAYS_LOG_CALLER", "0")
+always_log_caller = _env_log_caller.upper() in TRUE_CONSTANTS
class HandlerMode(Enum):
@@ -30,7 +42,14 @@ class HandlerMode(Enum):
_default_handler = None
-class TextStyle:
+class _DisableIfNoColors(type):
+ def __getattribute__(self, name: str) -> Any:
+ if name.isupper() and not use_colors:
+ return ""
+ return super().__getattribute__(name)
+
+
+class TextStyle(metaclass=_DisableIfNoColors):
"""
ANSI escape codes. Details: https://en.wikipedia.org/wiki/ANSI_escape_code#Colors
"""
@@ -42,11 +61,6 @@ class TextStyle:
PURPLE = "\033[35m"
ORANGE = "\033[38;5;208m" # From 8-bit palette
- if not use_colors:
- # Set the constants above to empty strings
- _codes = locals()
- _codes.update({_name: "" for _name in list(_codes) if _name.isupper()})
-
class CustomFormatter(logging.Formatter):
"""
@@ -115,14 +129,21 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
return logging.getLogger(name)
-def _enable_default_handler(name: str) -> None:
+def _enable_default_handler(name: Optional[str]) -> None:
logger = get_logger(name)
+
+ # Remove the extra default handler in the Colab's default logger before adding a new one
+ if isinstance(logger, logging.RootLogger):
+ for handler in list(logger.handlers):
+ if isinstance(handler, logging.StreamHandler) and handler.stream is sys.stderr:
+ logger.removeHandler(handler)
+
logger.addHandler(_default_handler)
logger.propagate = False
logger.setLevel(loglevel)
-def _disable_default_handler(name: str) -> None:
+def _disable_default_handler(name: Optional[str]) -> None:
logger = get_logger(name)
logger.removeHandler(_default_handler)
logger.propagate = True
diff --git a/hivemind/utils/math.py b/hivemind/utils/math.py
index f6d1098b5..901b971a3 100644
--- a/hivemind/utils/math.py
+++ b/hivemind/utils/math.py
@@ -15,7 +15,7 @@ def orthogonalize_(matrix, eps: float = 1e-8):
def get_flatten_greedy_dims(tensor: torch.Tensor, max_ndim: int = 2):
- """get dims to flatten tensor upto max_ndim dimensions by merging small axes together"""
+ """get dims to flatten tensor up to max_ndim dimensions by merging small axes together"""
dims = list(tensor.shape)
while len(dims) > max_ndim:
squeeze_ix = min(range(len(dims) - 1), key=lambda i: dims[i] * dims[i + 1])
diff --git a/hivemind/utils/mpfuture.py b/hivemind/utils/mpfuture.py
index 49f219067..11952811a 100644
--- a/hivemind/utils/mpfuture.py
+++ b/hivemind/utils/mpfuture.py
@@ -3,12 +3,12 @@
import asyncio
import concurrent.futures._base as base
import multiprocessing as mp
-import multiprocessing.connection
import os
import threading
import uuid
from contextlib import nullcontext
from enum import Enum, auto
+from multiprocessing.reduction import ForkingPickler
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
from weakref import ref
@@ -95,6 +95,8 @@ class MPFuture(base.Future, Generic[ResultType]):
_active_pid: Optional[PID] = None # pid of currently active process; used to handle forks natively
def __init__(self, *, use_lock: bool = True):
+ self._maybe_initialize_mpfuture_backend()
+
self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
self._shared_state_code = SharedBytes.next()
self._state_cache: Dict[State, State] = {}
@@ -105,11 +107,6 @@ def __init__(self, *, use_lock: bool = True):
self._state, self._result, self._exception = base.PENDING, None, None
self._use_lock = use_lock
- if self._origin_pid != MPFuture._active_pid:
- with MPFuture._initialization_lock:
- if self._origin_pid != MPFuture._active_pid:
- # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
- self._initialize_mpfuture_backend()
assert self._uid not in MPFuture._active_futures
MPFuture._active_futures[self._uid] = ref(self)
self._sender_pipe = MPFuture._global_sender_pipe
@@ -127,7 +124,8 @@ def _state(self) -> State:
@_state.setter
def _state(self, new_state: State):
- self._shared_state_code[...] = ALL_STATES.index(new_state)
+ with torch.inference_mode():
+ self._shared_state_code[...] = ALL_STATES.index(new_state)
if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
self._set_event_threadsafe()
@@ -150,16 +148,23 @@ async def _event_setter():
self._loop.run_until_complete(_event_setter())
@classmethod
- def _initialize_mpfuture_backend(cls):
+ def _maybe_initialize_mpfuture_backend(cls):
pid = os.getpid()
- logger.debug(f"Initializing MPFuture backend for pid {pid}")
-
- receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
- cls._active_pid, cls._active_futures = pid, {}
- cls._pipe_waiter_thread = threading.Thread(
- target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
- )
- cls._pipe_waiter_thread.start()
+ if pid != MPFuture._active_pid:
+ with MPFuture._initialization_lock:
+ if pid != MPFuture._active_pid:
+ # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
+ logger.debug(f"Initializing MPFuture backend for pid {pid}")
+
+ receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
+ cls._active_pid, cls._active_futures = pid, {}
+ cls._pipe_waiter_thread = threading.Thread(
+ target=cls._process_updates_in_background,
+ args=[receiver_pipe],
+ name=f"{__name__}.BACKEND",
+ daemon=True,
+ )
+ cls._pipe_waiter_thread.start()
@staticmethod
def reset_backend():
@@ -295,7 +300,7 @@ def __await__(self):
raise asyncio.CancelledError()
def __del__(self):
- if getattr(self, "_origin_pid", None) == os.getpid():
+ if getattr(self, "_origin_pid", None) == os.getpid() and MPFuture._active_futures is not None:
MPFuture._active_futures.pop(self._uid, None)
if getattr(self, "_aio_event", None):
self._aio_event.set()
@@ -303,7 +308,7 @@ def __del__(self):
def __getstate__(self):
return dict(
_sender_pipe=self._sender_pipe,
- _shared_state_code=self._shared_state_code,
+ _shared_state_code=ForkingPickler.dumps(self._shared_state_code).tobytes(),
_origin_pid=self._origin_pid,
_uid=self._uid,
_use_lock=self._use_lock,
@@ -313,7 +318,14 @@ def __getstate__(self):
def __setstate__(self, state):
self._sender_pipe = state["_sender_pipe"]
- self._shared_state_code = state["_shared_state_code"]
+ try:
+ self._shared_state_code = ForkingPickler.loads(state["_shared_state_code"])
+ except RuntimeError:
+ # If the origin process garbage-collects all instances of MPFuture using the same shmem buffer,
+ # the underlying buffer is freed, and we will get RuntimeError ("unable to open shared memory object")
+ # here since it is not possible to connect to this buffer anymore. To address this, we just replace
+ # the buffer with a non-shared tensor since the origin process doesn't care about our state anymore.
+ self._shared_state_code = torch.tensor([ALL_STATES.index(base.PENDING)], dtype=torch.uint8)
self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
self._result, self._exception = state["_result"], state["_exception"]
self._use_lock = state["_use_lock"]
diff --git a/pyproject.toml b/pyproject.toml
index 12b2b642e..dd72f66e4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,3 +8,8 @@ line_length = 119
combine_as_imports = true
combine_star = true
known_local_folder = ["arguments", "test_utils", "tests", "utils"]
+
+[tool.coverage.run]
+concurrency = ["thread", "multiprocessing"]
+omit = ["hivemind/proto/*"]
+source = ["hivemind"]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index ba6da8000..8398751aa 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -5,7 +5,7 @@ pytest-cov
coverage==6.0.2 # see https://github.com/pytest-dev/pytest-cov/issues/520
tqdm
scikit-learn
-torchvision
black==22.3.0
isort==5.10.1
+codespell==2.2.2
psutil
diff --git a/requirements.txt b/requirements.txt
index a81dc9a8f..df60317d8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
PyYAML
-torch>=1.6.0
+torch>=1.9.0
numpy>=1.17
scipy>=1.2.1
prefetch_generator>=1.0.1
@@ -9,7 +9,8 @@ uvloop>=0.14.0
grpcio-tools>=1.33.2
protobuf>=3.12.2
configargparse>=1.2.3
-multiaddr>=0.0.9
-pymultihash>=0.8.2
+py-multihash>=0.2.3
+multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@e01dbd38f2c0464c0f78b556691d655265018cce
cryptography>=3.4.6
-pydantic>=1.8.1
+pydantic>=1.8.1,<2.0
+packaging>=20.9
diff --git a/setup.py b/setup.py
index 7a739c7c1..8f0dcedbe 100644
--- a/setup.py
+++ b/setup.py
@@ -2,6 +2,7 @@
import glob
import hashlib
import os
+import platform
import re
import subprocess
import tarfile
@@ -13,17 +14,19 @@
from setuptools.command.build_py import build_py
from setuptools.command.develop import develop
-P2PD_VERSION = "v0.3.9"
+P2PD_VERSION = "v0.5.0.hivemind1"
P2PD_SOURCE_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/"
# The value is sha256 of the binary from the release page
-EXECUTABLES = {
- "p2pd": "8f9434f4717f6e851430f75f07e283d5ddeb2c7cde1b3648e677d813703f4e40",
+P2P_BINARY_HASH = {
+ "p2pd-darwin-amd64": "fe00f9d79e8e4e4c007144d19da10b706c84187b3fb84de170f4664c91ecda80",
+ "p2pd-darwin-arm64": "0404981a9c2b7cab5425ead2633d006c61c2c7ec85ac564ef69413ed470e65bd",
+ "p2pd-linux-amd64": "42f8f48e62583b97cdba3c31439c08029fb2b9fc506b5bdd82c46b7cc1d279d8",
+ "p2pd-linux-arm64": "046f18480c785a84bdf139d7486086d379397ca106cb2f0191598da32f81447a",
}
-
here = os.path.abspath(os.path.dirname(__file__))
@@ -72,31 +75,44 @@ def build_p2p_daemon():
with tarfile.open(dest, "r:gz") as tar:
tar.extractall(tempdir)
- for executable in EXECUTABLES:
- result = subprocess.run(
- ["go", "build", "-o", os.path.join(here, "hivemind", "hivemind_cli", executable)],
- cwd=os.path.join(tempdir, f"go-libp2p-daemon-{P2PD_VERSION.lstrip('v')}", executable),
- )
- if result.returncode != 0:
- raise RuntimeError(f"Failed to build {executable}: exited with status code: {result.returncode}")
+ result = subprocess.run(
+ ["go", "build", "-o", os.path.join(here, "hivemind", "hivemind_cli", "p2pd")],
+ cwd=os.path.join(tempdir, f"go-libp2p-daemon-{P2PD_VERSION.lstrip('v')}", "p2pd"),
+ )
+ if result.returncode != 0:
+ raise RuntimeError(f"Failed to build p2pd: exited with status code: {result.returncode}")
def download_p2p_daemon():
- for executable, expected_hash in EXECUTABLES.items():
- binary_path = os.path.join(here, "hivemind", "hivemind_cli", executable)
-
- if sha256(binary_path) != expected_hash:
- binary_url = os.path.join(P2PD_BINARY_URL, executable)
- print(f"Downloading {binary_url}")
-
- urllib.request.urlretrieve(binary_url, binary_path)
- os.chmod(binary_path, 0o777)
-
- actual_hash = sha256(binary_path)
- if actual_hash != expected_hash:
- raise RuntimeError(
- f"The sha256 checksum for {executable} does not match (expected: {expected_hash}, actual: {actual_hash})"
- )
+ binary_path = os.path.join(here, "hivemind", "hivemind_cli", "p2pd")
+ arch = platform.machine()
+ # An architecture name may vary depending on the OS (e.g., the same CPU is arm64 on macOS and aarch64 on Linux).
+ # We consider multiple aliases here, see https://stackoverflow.com/questions/45125516/possible-values-for-uname-m
+ if arch in ("x86_64", "x64"):
+ arch = "amd64"
+ if arch in ("aarch64", "aarch64_be", "armv8b", "armv8l"):
+ arch = "arm64"
+ binary_name = f"p2pd-{platform.system().lower()}-{arch}"
+
+ if binary_name not in P2P_BINARY_HASH:
+ raise RuntimeError(
+ f"hivemind does not provide a precompiled p2pd binary for {platform.system()} ({arch}). "
+ f"Please install Go and build it from source: https://github.com/learning-at-home/hivemind#from-source"
+ )
+ expected_hash = P2P_BINARY_HASH[binary_name]
+
+ if sha256(binary_path) != expected_hash:
+ binary_url = os.path.join(P2PD_BINARY_URL, binary_name)
+ print(f"Downloading {binary_url}")
+
+ urllib.request.urlretrieve(binary_url, binary_path)
+ os.chmod(binary_path, 0o777)
+
+ actual_hash = sha256(binary_path)
+ if actual_hash != expected_hash:
+ raise RuntimeError(
+ f"The sha256 checksum for p2pd does not match (expected: {expected_hash}, actual: {actual_hash})"
+ )
class BuildPy(build_py):
@@ -140,7 +156,9 @@ def run(self):
with open("requirements-docs.txt") as docs_requirements_file:
extras["docs"] = list(map(str, parse_requirements(docs_requirements_file)))
-extras["all"] = extras["dev"] + extras["docs"]
+extras["bitsandbytes"] = ["bitsandbytes~=0.41.1"]
+
+extras["all"] = extras["dev"] + extras["docs"] + extras["bitsandbytes"]
setup(
name="hivemind",
diff --git a/tests/test_averaging.py b/tests/test_averaging.py
index 79a8511da..1059e321b 100644
--- a/tests/test_averaging.py
+++ b/tests/test_averaging.py
@@ -356,7 +356,7 @@ def test_load_state_from_peers():
class TestAverager(DecentralizedAverager):
def get_current_state(self):
"""
- Get current state and send it to a peer. executed in the host process. Meant to be overriden.
+ Get current state and send it to a peer. executed in the host process. Meant to be overridden.
:returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
"""
nonlocal num_calls, super_metadata, super_tensors
@@ -528,7 +528,7 @@ def test_averaging_cancel():
step_controls = [averager.step(wait=False, scheduled_time=hivemind.get_dht_time() + 1) for averager in averagers]
- time.sleep(0.1)
+ time.sleep(0.05)
step_controls[0].cancel()
step_controls[1].cancel()
diff --git a/tests/test_cli_scripts.py b/tests/test_cli_scripts.py
index d69ef5ffb..3047d4c04 100644
--- a/tests/test_cli_scripts.py
+++ b/tests/test_cli_scripts.py
@@ -35,7 +35,7 @@ def test_dht_connection_successful():
dht_client_proc.stderr.readline()
first_report_msg = dht_client_proc.stderr.readline()
- assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg
+ assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg, first_report_msg
# ensure we get the output of dht_proc after the start of dht_client_proc
sleep(dht_refresh_period)
diff --git a/tests/test_compression.py b/tests/test_compression.py
index 172bf47e1..a75ea76a0 100644
--- a/tests/test_compression.py
+++ b/tests/test_compression.py
@@ -38,31 +38,81 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
assert error.square().mean() < beta
error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
assert error.square().mean() < beta
+ error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.BLOCKWISE_8BIT)) - X
+ assert error.square().mean() < beta
zeros = torch.zeros(5, 5)
for compression_type in CompressionType.values():
assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
+def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
+ serialized_tensor = serialize_torch_tensor(tensor, compression)
+ chunks = list(split_for_streaming(serialized_tensor, chunk_size))
+ assert len(chunks) == max((len(serialized_tensor.buffer) - 1) // chunk_size + 1, 1)
+ restored = combine_from_streaming(chunks)
+ result = deserialize_torch_tensor(restored)
+ assert result.dtype == tensor.dtype, compression
+ assert result.requires_grad == tensor.requires_grad
+ assert torch.allclose(result, tensor, rtol=rtol, atol=atol)
+
+
@pytest.mark.forked
def test_serialize_tensor():
- def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
- serialized_tensor = serialize_torch_tensor(tensor, compression)
- chunks = list(split_for_streaming(serialized_tensor, chunk_size))
- assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
- restored = combine_from_streaming(chunks)
- assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
-
- tensor = torch.randn(512, 12288)
+ tensor = torch.randn(512, 12288, requires_grad=True)
for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
_check(tensor, CompressionType.NONE, chunk_size=chunk_size)
_check(tensor, CompressionType.FLOAT16, rtol=0.0, atol=1e-2)
_check(torch.randint(0, 100, (512, 1, 1)), CompressionType.NONE)
+ _check(torch.randn(10, 20), CompressionType.MEANSTD_16BIT, atol=0.1)
_check(torch.tensor(1.0), CompressionType.NONE)
_check(torch.tensor(1.0), CompressionType.FLOAT16)
+@pytest.mark.parametrize(
+ "dtype",
+ [
+ torch.float32,
+ torch.float16,
+ torch.bfloat16,
+ torch.float64,
+ torch.complex64,
+ torch.int64,
+ torch.int32,
+ torch.uint8,
+ torch.bool,
+ ],
+)
+@pytest.mark.parametrize("requires_grad", [False, True])
+@pytest.mark.forked
+def test_serialize_tensor_properties(dtype: torch.dtype, requires_grad: bool):
+ tensor = torch.randn(123, 45, requires_grad=requires_grad).to(dtype)
+ if dtype == torch.bfloat16:
+ compression_types = [
+ type
+ for type in CompressionType.values()
+ if type not in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT)
+ ]
+ elif torch.is_floating_point(tensor): # nb: complex and qint data types are not is_floating_point
+ compression_types = CompressionType.values()
+ else:
+ compression_types = [CompressionType.NONE]
+
+ for compression_type in compression_types:
+ _check(tensor, compression_type, atol=float("inf"))
+
+
+@pytest.mark.parametrize("use_legacy_bfloat16", [True, False])
+@pytest.mark.parametrize("tensor_size", [(4096, 16), (0, 0)])
+@pytest.mark.forked
+def test_serialize_bfloat16(use_legacy_bfloat16: bool, tensor_size: tuple):
+ hivemind.compression.base.USE_LEGACY_BFLOAT16 = use_legacy_bfloat16
+ tensor = torch.randn(tensor_size, dtype=torch.bfloat16)
+ _check(tensor, CompressionType.NONE)
+ _check(tensor, CompressionType.BLOCKWISE_8BIT, rtol=0.1, atol=0.01, chunk_size=1024)
+
+
@pytest.mark.forked
def test_allreduce_compression():
"""this test ensures that compression works correctly when multiple tensors have different compression types"""
@@ -210,5 +260,7 @@ def test_adaptive_compression():
assert FLOAT32.mp_part_size.value == 1250 # four-byte tensors
averager1.load_state_from_peers()
- assert STATE_FP16.mp_counter.value == STATE_FP32.mp_counter.value == 9
+ state_metadata, state_tensors, infos = averager1.get_current_state()
+ assert STATE_FP16.mp_counter.value == len([tensor for tensor in state_tensors if tensor.numel() >= 500])
+ assert STATE_FP32.mp_counter.value == len([tensor for tensor in state_tensors if tensor.numel() < 500])
assert STATE_FP16.mp_part_size.value == STATE_FP32.mp_part_size.value == 0 # not partitioned
diff --git a/tests/test_connection_handler.py b/tests/test_connection_handler.py
index afc6179f0..0f3220574 100644
--- a/tests/test_connection_handler.py
+++ b/tests/test_connection_handler.py
@@ -20,19 +20,25 @@
from hivemind.utils.tensor_descr import BatchTensorDescriptor
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_connection_handler_info():
- handler = ConnectionHandler(
- DHT(start=True),
- dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
- )
- handler.start()
+@pytest.fixture
+async def client_stub():
+ handler_dht = DHT(start=True)
+ module_backends = {"expert1": DummyModuleBackend("expert1", k=1), "expert2": DummyModuleBackend("expert2", k=2)}
+ handler = ConnectionHandler(handler_dht, module_backends, start=True)
client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
- # info
+ yield client_stub
+
+ client_dht.shutdown()
+ handler.shutdown()
+ handler_dht.shutdown()
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler_info(client_stub):
response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")
@@ -45,16 +51,7 @@ async def test_connection_handler_info():
@pytest.mark.forked
@pytest.mark.asyncio
-async def test_connection_handler_forward():
- handler = ConnectionHandler(
- DHT(start=True),
- dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
- )
- handler.start()
-
- client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
- client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
-
+async def test_connection_handler_forward(client_stub):
inputs = torch.randn(1, 2)
inputs_long = torch.randn(2**21, 2)
@@ -106,16 +103,7 @@ async def test_connection_handler_forward():
@pytest.mark.forked
@pytest.mark.asyncio
-async def test_connection_handler_backward():
- handler = ConnectionHandler(
- DHT(start=True),
- dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
- )
- handler.start()
-
- client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
- client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
-
+async def test_connection_handler_backward(client_stub):
inputs = torch.randn(1, 2)
inputs_long = torch.randn(2**21, 2)
@@ -165,8 +153,20 @@ async def test_connection_handler_backward():
# check that handler did not crash after failed request
await client_stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)]))
- handler.terminate()
- handler.join()
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler_shutdown():
+ # Here, all handlers will have the common hivemind.DHT and hivemind.P2P instances
+ handler_dht = DHT(start=True)
+ module_backends = {"expert1": DummyModuleBackend("expert1", k=1), "expert2": DummyModuleBackend("expert2", k=2)}
+
+ for _ in range(3):
+ handler = ConnectionHandler(handler_dht, module_backends, balanced=False, start=True)
+ # The line above would raise an exception if the previous handlers were not removed from hivemind.P2P
+ handler.shutdown()
+
+ handler_dht.shutdown()
class DummyPool(TaskPool):
diff --git a/tests/test_dht_node.py b/tests/test_dht_node.py
index 20d798c09..3dd6314b5 100644
--- a/tests/test_dht_node.py
+++ b/tests/test_dht_node.py
@@ -301,4 +301,4 @@ async def test_dhtnode_edge_cases():
assert subkey in stored.value
assert stored.value[subkey].value == value
- await asyncio.wait([node.shutdown() for node in peers])
+ await asyncio.wait([asyncio.create_task(node.shutdown()) for node in peers])
diff --git a/tests/test_moe.py b/tests/test_moe.py
index 46e9279a8..f62c2159d 100644
--- a/tests/test_moe.py
+++ b/tests/test_moe.py
@@ -1,3 +1,9 @@
+import asyncio
+import ctypes
+import multiprocessing as mp
+import threading
+import time
+
import numpy as np
import pytest
import torch
@@ -5,12 +11,13 @@
from hivemind.dht import DHT
from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.moe.server import ModuleBackend, Server, background_server, declare_experts
from hivemind.moe.server.layers import name_to_block
-from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
-from hivemind.utils import BatchTensorDescriptor, get_dht_time
+from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
+from hivemind.utils import BatchTensorDescriptor, MPFuture, get_dht_time
@pytest.mark.forked
@@ -153,11 +160,17 @@ def test_remote_module_call(hidden_dim=16):
out3_again.norm().backward()
assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
- with pytest.raises(P2PDaemonError):
+ try:
real_expert(torch.randn(3, 11))
- with pytest.raises(P2PDaemonError):
+ except P2PHandlerError as e:
+ assert str(11) in repr(e), "Exception must relay the remote server error (i.e. incorrect dimensions)"
+ with pytest.raises(P2PHandlerError):
fake_expert(dummy_x)
+ # check that the server is still alive after processing a malformed request
+ out3_yet_again = real_expert(dummy_x[1:])
+ assert torch.allclose(out3_yet_again, out3[1:], atol=1e-5, rtol=0)
+
@pytest.mark.forked
def test_beam_search_correctness():
@@ -300,3 +313,43 @@ def test_client_anomaly_detection():
finally:
server.shutdown()
+
+
+def _measure_coro_running_time(n_coros, elapsed_fut, counter):
+ async def coro():
+ await asyncio.sleep(0.1)
+ counter.value += 1
+
+ try:
+ start_time = time.perf_counter()
+
+ futures = [
+ RemoteExpertWorker.run_coroutine(coro(), return_future=True) for _ in range(n_coros - 1)
+ ] # Non-blocking calls
+ RemoteExpertWorker.run_coroutine(coro(), return_future=False) # A blocking call
+ for fut in futures:
+ fut.result()
+
+ elapsed_fut.set_result(time.perf_counter() - start_time)
+ except Exception as e:
+ elapsed_fut.set_exception(e)
+
+
+@pytest.mark.forked
+def test_remote_expert_worker_runs_coros_concurrently(n_processes=4, n_coros=10):
+ processes = []
+ counter = mp.Value(ctypes.c_int64)
+ for i in range(n_processes):
+ elapsed_fut = MPFuture()
+ factory = threading.Thread if i % 2 == 0 else mp.Process # Test both threads and processes
+
+ proc = factory(target=_measure_coro_running_time, args=(n_coros, elapsed_fut, counter))
+ proc.start()
+ processes.append((proc, elapsed_fut))
+
+ for proc, elapsed_fut in processes:
+ # Ensure that the coroutines were run concurrently, not sequentially
+ assert elapsed_fut.result() < 0.2
+ proc.join()
+
+ assert counter.value == n_processes * n_coros # Ensure all couroutines have finished
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index e2361ae03..c859e3879 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -15,7 +15,7 @@
from hivemind.optim.optimizer import Optimizer
from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
from hivemind.optim.progress_tracker import ProgressTracker
-from hivemind.optim.state_averager import TrainingStateAverager
+from hivemind.optim.state_averager import ZERO_GRAD_SET_TO_NONE_DEFAULT, TrainingStateAverager
from hivemind.utils.crypto import RSAPrivateKey
@@ -79,8 +79,11 @@ def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
assert torch.allclose(model2.w.grad, ref_average)
# after no longer use_averaged_gradients
- assert not torch.allclose(model1.w.grad, ref_average)
- assert not torch.allclose(model2.w.grad, ref_average)
+ if ZERO_GRAD_SET_TO_NONE_DEFAULT: # averager1 has reuse_grad_buffers=False
+ assert model1.w.grad is None
+ else:
+ assert not torch.allclose(model1.w.grad, ref_average)
+ assert not torch.allclose(model2.w.grad, ref_average) # averager2 has reuse_grad_buffers=True
@pytest.mark.forked
@@ -151,7 +154,10 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
F.mse_loss(model2(x), -torch.ones(3)).backward()
avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)
- assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
+ if ZERO_GRAD_SET_TO_NONE_DEFAULT:
+ assert model1.weight.grad is None and model2.weight.grad is None, ".zero_grad() wasn't called"
+ else:
+ assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), ".zero_grad() wasn't called"
assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"
diff --git a/tests/test_p2p_daemon.py b/tests/test_p2p_daemon.py
index 70a0de084..55ff36af5 100644
--- a/tests/test_p2p_daemon.py
+++ b/tests/test_p2p_daemon.py
@@ -73,11 +73,37 @@ async def test_identity():
P2P.generate_identity(id1_path)
+@pytest.mark.asyncio
+async def test_check_if_identity_free():
+ with tempfile.TemporaryDirectory() as tempdir:
+ id1_path = os.path.join(tempdir, "id1")
+ id2_path = os.path.join(tempdir, "id2")
+
+ p2ps = [await P2P.create(identity_path=id1_path)]
+ initial_peers = await p2ps[0].get_visible_maddrs()
+
+ p2ps.append(await P2P.create(initial_peers=initial_peers))
+ p2ps.append(await P2P.create(initial_peers=initial_peers, identity_path=id2_path))
+
+ with pytest.raises(P2PDaemonError, match=r"Identity.+is already taken by another peer"):
+ await P2P.create(initial_peers=initial_peers, identity_path=id1_path)
+ with pytest.raises(P2PDaemonError, match=r"Identity.+is already taken by another peer"):
+ await P2P.create(initial_peers=initial_peers, identity_path=id2_path)
+
+ # Must work if a P2P with a certain identity is restarted
+ await p2ps[-1].shutdown()
+ p2ps.pop()
+ p2ps.append(await P2P.create(initial_peers=initial_peers, identity_path=id2_path))
+
+ for instance in p2ps:
+ await instance.shutdown()
+
+
@pytest.mark.parametrize(
"host_maddrs",
[
[Multiaddr("/ip4/127.0.0.1/tcp/0")],
- [Multiaddr("/ip4/127.0.0.1/udp/0/quic")],
+ [Multiaddr("/ip4/127.0.0.1/udp/0/quic-v1")],
[Multiaddr("/ip4/127.0.0.1/tcp/0"), Multiaddr("/ip4/127.0.0.1/udp/0/quic")],
],
)
diff --git a/tests/test_p2p_servicer.py b/tests/test_p2p_servicer.py
index ee8f187b1..1950260a2 100644
--- a/tests/test_p2p_servicer.py
+++ b/tests/test_p2p_servicer.py
@@ -3,7 +3,7 @@
import pytest
-from hivemind.p2p import P2P, P2PContext, ServicerBase
+from hivemind.p2p import P2P, P2PContext, P2PDaemonError, ServicerBase
from hivemind.proto import test_pb2
from hivemind.utils.asyncio import anext
@@ -17,35 +17,37 @@ async def server_client():
await asyncio.gather(server.shutdown(), client.shutdown())
+class UnaryUnaryServicer(ServicerBase):
+ async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
+ return test_pb2.TestResponse(number=request.number**2)
+
+
@pytest.mark.asyncio
async def test_unary_unary(server_client):
- class ExampleServicer(ServicerBase):
- async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
- return test_pb2.TestResponse(number=request.number**2)
-
server, client = server_client
- servicer = ExampleServicer()
+ servicer = UnaryUnaryServicer()
await servicer.add_p2p_handlers(server)
- stub = ExampleServicer.get_stub(client, server.peer_id)
+ stub = UnaryUnaryServicer.get_stub(client, server.peer_id)
assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)
+class StreamUnaryServicer(ServicerBase):
+ async def rpc_sum(
+ self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
+ ) -> test_pb2.TestResponse:
+ result = 0
+ async for item in stream:
+ result += item.number
+ return test_pb2.TestResponse(number=result)
+
+
@pytest.mark.asyncio
async def test_stream_unary(server_client):
- class ExampleServicer(ServicerBase):
- async def rpc_sum(
- self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
- ) -> test_pb2.TestResponse:
- result = 0
- async for item in stream:
- result += item.number
- return test_pb2.TestResponse(number=result)
-
server, client = server_client
- servicer = ExampleServicer()
+ servicer = StreamUnaryServicer()
await servicer.add_p2p_handlers(server)
- stub = ExampleServicer.get_stub(client, server.peer_id)
+ stub = StreamUnaryServicer.get_stub(client, server.peer_id)
async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
for i in range(10):
@@ -54,42 +56,40 @@ async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
assert await stub.rpc_sum(generate_requests()) == test_pb2.TestResponse(number=45)
+class UnaryStreamServicer(ServicerBase):
+ async def rpc_count(
+ self, request: test_pb2.TestRequest, _context: P2PContext
+ ) -> AsyncIterator[test_pb2.TestResponse]:
+ for i in range(request.number):
+ yield test_pb2.TestResponse(number=i)
+
+
@pytest.mark.asyncio
async def test_unary_stream(server_client):
- class ExampleServicer(ServicerBase):
- async def rpc_count(
- self, request: test_pb2.TestRequest, _context: P2PContext
- ) -> AsyncIterator[test_pb2.TestResponse]:
- for i in range(request.number):
- yield test_pb2.TestResponse(number=i)
-
server, client = server_client
- servicer = ExampleServicer()
+ servicer = UnaryStreamServicer()
await servicer.add_p2p_handlers(server)
- stub = ExampleServicer.get_stub(client, server.peer_id)
+ stub = UnaryStreamServicer.get_stub(client, server.peer_id)
stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
- i = 0
- async for item in stream:
- assert item == test_pb2.TestResponse(number=i)
- i += 1
- assert i == 10
+ assert [item.number async for item in stream] == list(range(10))
+
+
+class StreamStreamServicer(ServicerBase):
+ async def rpc_powers(
+ self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
+ ) -> AsyncIterator[test_pb2.TestResponse]:
+ async for item in stream:
+ yield test_pb2.TestResponse(number=item.number**2)
+ yield test_pb2.TestResponse(number=item.number**3)
@pytest.mark.asyncio
async def test_stream_stream(server_client):
- class ExampleServicer(ServicerBase):
- async def rpc_powers(
- self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
- ) -> AsyncIterator[test_pb2.TestResponse]:
- async for item in stream:
- yield test_pb2.TestResponse(number=item.number**2)
- yield test_pb2.TestResponse(number=item.number**3)
-
server, client = server_client
- servicer = ExampleServicer()
+ servicer = StreamStreamServicer()
await servicer.add_p2p_handlers(server)
- stub = ExampleServicer.get_stub(client, server.peer_id)
+ stub = StreamStreamServicer.get_stub(client, server.peer_id)
async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
for i in range(10):
@@ -153,3 +153,43 @@ async def rpc_wait(
await asyncio.sleep(0.25)
assert handler_cancelled
+
+
+@pytest.mark.asyncio
+async def test_removing_unary_handlers(server_client):
+ server1, client = server_client
+ server2 = await P2P.replicate(server1.daemon_listen_maddr)
+ servicer = UnaryUnaryServicer()
+ stub = UnaryUnaryServicer.get_stub(client, server1.peer_id)
+
+ for server in [server1, server2, server1]:
+ await servicer.add_p2p_handlers(server)
+ assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)
+
+ await servicer.remove_p2p_handlers(server)
+ with pytest.raises((P2PDaemonError, ConnectionError)):
+ await stub.rpc_square(test_pb2.TestRequest(number=10))
+
+ await asyncio.gather(server2.shutdown())
+
+
+@pytest.mark.asyncio
+async def test_removing_stream_handlers(server_client):
+ server1, client = server_client
+ server2 = await P2P.replicate(server1.daemon_listen_maddr)
+ servicer = UnaryStreamServicer()
+ stub = UnaryStreamServicer.get_stub(client, server1.peer_id)
+
+ for server in [server1, server2, server1]:
+ await servicer.add_p2p_handlers(server)
+ stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
+ assert [item.number async for item in stream] == list(range(10))
+
+ await servicer.remove_p2p_handlers(server)
+ with pytest.raises((P2PDaemonError, ConnectionError)):
+ stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
+ outputs = [item.number async for item in stream]
+ if not outputs:
+ raise P2PDaemonError("Daemon has reset the connection")
+
+ await asyncio.gather(server2.shutdown())
diff --git a/tests/test_relays.py b/tests/test_relays.py
new file mode 100644
index 000000000..d55b4ff8e
--- /dev/null
+++ b/tests/test_relays.py
@@ -0,0 +1,64 @@
+import time
+from functools import partial
+
+import pytest
+
+import hivemind
+
+
+async def ping_to_client(dht, node, peer_id: str):
+ return await node.protocol.call_ping(hivemind.PeerID.from_base58(str(peer_id)))
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize(
+ "use_auto_relay,use_relay",
+ [
+ (True, True),
+ (False, False),
+ ],
+)
+def test_autorelay(use_auto_relay: bool, use_relay: bool):
+ dht_first_peer = hivemind.DHT(
+ start=True,
+ use_auto_relay=use_auto_relay,
+ use_relay=use_relay,
+ force_reachability="public",
+ )
+ dht_first_peer_id = dht_first_peer.peer_id
+ initial_peers = dht_first_peer.get_visible_maddrs()
+ assert dht_first_peer_id is not None
+
+ dht_third_peer = hivemind.DHT(
+ initial_peers=initial_peers,
+ host_maddrs=[],
+ start=True,
+ no_listen=True,
+ use_relay=use_relay,
+ client_mode=False,
+ use_auto_relay=use_auto_relay,
+ )
+ time.sleep(5)
+ dht_second_peer = hivemind.DHT(
+ initial_peers=initial_peers,
+ start=True,
+ client_mode=False,
+ no_listen=False,
+ use_relay=use_relay,
+ use_auto_relay=use_auto_relay,
+ )
+
+ assert dht_first_peer.is_alive() and dht_second_peer.is_alive() and dht_third_peer.is_alive()
+
+ time_start = time.perf_counter()
+ while time.perf_counter() - time_start < 30:
+ reached_ip = dht_second_peer.run_coroutine(partial(ping_to_client, peer_id=dht_third_peer.peer_id))
+ if reached_ip:
+ assert use_relay
+ break
+ time.sleep(2)
+ else:
+ assert not use_relay
+
+ for peer in dht_first_peer, dht_second_peer, dht_third_peer:
+ peer.shutdown()
diff --git a/tests/test_start_server.py b/tests/test_start_server.py
index 512332070..b84dd5407 100644
--- a/tests/test_start_server.py
+++ b/tests/test_start_server.py
@@ -1,5 +1,6 @@
import os
import re
+from functools import partial
from subprocess import PIPE, Popen
from tempfile import TemporaryDirectory
@@ -10,10 +11,11 @@ def test_background_server_identity_path():
with TemporaryDirectory() as tempdir:
id_path = os.path.join(tempdir, "id")
- with background_server(num_experts=1, identity_path=id_path) as server_info_1, background_server(
- num_experts=1, identity_path=id_path
- ) as server_info_2, background_server(num_experts=1, identity_path=None) as server_info_3:
+ server_runner = partial(background_server, num_experts=1, device="cpu", hidden_dim=1)
+ with server_runner(identity_path=id_path) as server_info_1, server_runner(
+ identity_path=id_path
+ ) as server_info_2, server_runner(identity_path=None) as server_info_3:
assert server_info_1.peer_id == server_info_2.peer_id
assert server_info_1.peer_id != server_info_3.peer_id
assert server_info_3.peer_id == server_info_3.peer_id
@@ -32,10 +34,13 @@ def test_cli_run_server_identity_path():
encoding="utf-8",
)
- # Skip line "Generating new identity (libp2p private key) in {path to file}"
line = server_1_proc.stderr.readline()
+ assert "Generating new identity" in line
+
line = server_1_proc.stderr.readline()
- addrs_1 = set(re.search(pattern, line).group(1).split(", "))
+ addrs_pattern_result = re.search(pattern, line)
+ assert addrs_pattern_result is not None, line
+ addrs_1 = set(addrs_pattern_result.group(1).split(", "))
ids_1 = set(a.split("/")[-1] for a in addrs_1)
assert len(ids_1) == 1
@@ -48,7 +53,9 @@ def test_cli_run_server_identity_path():
)
line = server_2_proc.stderr.readline()
- addrs_2 = set(re.search(pattern, line).group(1).split(", "))
+ addrs_pattern_result = re.search(pattern, line)
+ assert addrs_pattern_result is not None, line
+ addrs_2 = set(addrs_pattern_result.group(1).split(", "))
ids_2 = set(a.split("/")[-1] for a in addrs_2)
assert len(ids_2) == 1
@@ -61,7 +68,9 @@ def test_cli_run_server_identity_path():
)
line = server_3_proc.stderr.readline()
- addrs_3 = set(re.search(pattern, line).group(1).split(", "))
+ addrs_pattern_result = re.search(pattern, line)
+ assert addrs_pattern_result is not None, line
+ addrs_3 = set(addrs_pattern_result.group(1).split(", "))
ids_3 = set(a.split("/")[-1] for a in addrs_3)
assert len(ids_3) == 1
diff --git a/tests/test_training.py b/tests/test_training.py
index c63b5116d..94c7ea993 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -20,7 +20,12 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
SGD = partial(torch.optim.SGD, lr=0.05)
with background_server(
- num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
+ num_experts=2,
+ device="cpu",
+ optim_cls=SGD,
+ hidden_dim=64,
+ num_handlers=1,
+ clip_grad_norm=1.0,
) as server_peer_info:
dht = DHT(initial_peers=server_peer_info.addrs, start=True)
expert1, expert2 = create_remote_experts(
diff --git a/tests/test_util_modules.py b/tests/test_util_modules.py
index 0cbb9e831..0243b91d7 100644
--- a/tests/test_util_modules.py
+++ b/tests/test_util_modules.py
@@ -507,6 +507,32 @@ async def coro2():
# running this without enter_asynchronously would deadlock the event loop
+@pytest.mark.asyncio
+async def test_async_context_flooding():
+ """
+ test for a possible deadlock when many coroutines await the lock and overwhelm the underlying ThreadPoolExecutor
+
+ Here's how the test below works: suppose that the thread pool has at most N workers;
+ If at least N + 1 coroutines await lock1 concurrently, N of them occupy workers and the rest are awaiting workers;
+ When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep(1e-2);
+ During that sleep, one of the worker-less coroutines will take up the worker freed by coroutine A.
+ Finally, coroutine A finishes sleeping and immediately gets stuck at lock2, because there are no free workers.
+ Thus, every single coroutine is either awaiting an already acquired lock, or awaiting for free workers in executor.
+
+ """
+ lock1, lock2 = mp.Lock(), mp.Lock()
+
+ async def coro():
+ async with enter_asynchronously(lock1):
+ await asyncio.sleep(1e-2)
+ async with enter_asynchronously(lock2):
+ await asyncio.sleep(1e-2)
+
+ num_coros = max(100, mp.cpu_count() * 5 + 1)
+ # note: if we deprecate py3.7, this can be reduced to max(33, cpu + 5); see https://bugs.python.org/issue35279
+ await asyncio.wait({asyncio.create_task(coro()) for _ in range(num_coros)})
+
+
def test_batch_tensor_descriptor_msgpack():
tensor_descr = BatchTensorDescriptor.from_tensor(torch.ones(1, 3, 3, 7))
tensor_descr_roundtrip = MSGPackSerializer.loads(MSGPackSerializer.dumps(tensor_descr))