diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml new file mode 100644 index 0000000..f8c87de --- /dev/null +++ b/.github/workflows/publish.yaml @@ -0,0 +1,102 @@ +name: Build wheel and sdist + +on: + workflow_dispatch: + release: + types: [published] + +jobs: + build-wheels: + name: Build wheel for ${{ matrix.python-version }} + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["cp39", "cp310", "cp311", "cp312", "cp313"] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install cibuildwheel + run: python -m pip install cibuildwheel setuptools pybind11 + - name: Build wheels with cibuildwheel + run: | + cibuildwheel --output-dir wheelhouse + env: + CIBW_BUILD: "${{ matrix.python-version }}-manylinux_x86_64" + CIBW_SKIP: "*-musllinux_* *-win32 *-manylinux_i686" + CIBW_TEST_SKIP: "*" + CIBW_ARCHS: "x86_64" + + - name: Upload wheel artifact + uses: actions/upload-artifact@v4 + with: + name: wheels-${{ matrix.python-version }} + path: wheelhouse/*.whl + + build-sdist: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install build + run: pip install build + + - name: Build sdist + run: python -m build --sdist --outdir dist/ + + - name: Upload sdist artifact + uses: actions/upload-artifact@v4 + with: + name: sdist + path: dist/*.tar.gz + + upload: + needs: [build-wheels, build-sdist] + runs-on: ubuntu-latest + + steps: + - name: Download all wheels + uses: actions/download-artifact@v4 + with: + path: dist + - name: Download sdist artifact + uses: actions/download-artifact@v4 + with: + name: sdist + path: dist + + - name: Flatten all artifacts + run: | + mkdir final_dist + find dist -name '*.whl' -exec cp {} final_dist/ \; + find dist -name '*.tar.gz' -exec cp {} final_dist/ \; + + - name: Upload all artifact + uses: actions/upload-artifact@v4 + with: + name: final_dist + path: final_dist + +# - name: Set up Python +# uses: actions/setup-python@v5 +# with: +# python-version: "3.11" +# - name: Publish to PyPI +# run: | +# python -m pip install twine +# twine upload --non-interactive final_dist/* +# env: +# TWINE_USERNAME: __token__ +# TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/test-paddle.yaml b/.github/workflows/test-paddle.yaml index 985394f..c34c00d 100644 --- a/.github/workflows/test-paddle.yaml +++ b/.github/workflows/test-paddle.yaml @@ -32,9 +32,23 @@ jobs: - name: Install Python dependencies run: | python -m pip install --upgrade pip - pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/cpu + tf_ver=4.52 + npy_ver=2.2 + torch_ver=2.7 + if [ "${{ matrix.python-version }}" = "3.9" ]; then + npy_ver=1.26 + tf_ver=4.40 + torch_ver=2.1 + elif [ "${{ matrix.python-version }}" = "3.10" ]; then + torch_ver=2.3 + elif [ "${{ matrix.python-version }}" = "3.11" ]; then + torch_ver=2.5 + elif [ "${{ matrix.python-version }}" = "3.12" ]; then + torch_ver=2.6 + fi + pip install torch==${torch_ver} --index-url https://download.pytorch.org/whl/cpu # transformers requires torch pip install paddlepaddle==3.0.0 - pip install pytest pytest-cov setuptools_scm safetensors transformers==4.52 + pip install pytest pytest-cov setuptools_scm safetensors transformers==${tf_ver} numpy==${npy_ver} - name: Build Package run: | pip install . @@ -43,12 +57,11 @@ jobs: cd tests LIBDIR=`python3 -c "import os; os.chdir('/tmp'); import fastsafetensors; print(os.path.dirname(fastsafetensors.__file__))"` mkdir -p /tmp/pytest-log + export TEST_FASTSAFETENSORS_FRAMEWORK=paddle COVERAGE_FILE=.coverage_0 pytest -s --cov=${LIBDIR} test_fastsafetensors.py > /tmp/pytest-log/0.log 2>&1 - COVERAGE_FILE=.coverage_1 CUDA_VISIBLE_DEVICES="" pytest -s --cov=${LIBDIR} test_fastsafetensors.py > /tmp/pytest-log/1.log 2>&1 - COVERAGE_FILE=.coverage_2 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 --no-python pytest -s --cov=${LIBDIR} test_multi.py > /tmp/pytest-log/2.log 2>&1 & - COVERAGE_FILE=.coverage_3 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 --no-python pytest -s --cov=${LIBDIR} test_multi.py > /tmp/pytest-log/3.log 2>&1 - python -m paddle.distributed.launch --nproc_per_node 2 run_distributed_paddle_test.py -s --cov=${LIBDIR} test_multi_paddle.py - coverage combine .coverage_0 .coverage_1 .coverage_2 .coverage_3 .coverage_4 .coverage_5 + COVERAGE_FILE=.coverage_1 WORLD_SIZE=2 python3 -m paddle.distributed.launch --nnodes 2 --master 127.0.0.1:1234 --rank 0 test_multi.py --cov=${LIBDIR} -s test_multi.py > /tmp/pytest-log/1.log 2>&1 & \ + COVERAGE_FILE=.coverage_2 WORLD_SIZE=2 python3 -m paddle.distributed.launch --nnodes 2 --master 127.0.0.1:1234 --rank 1 test_multi.py --cov=${LIBDIR} -s test_multi.py > /tmp/pytest-log/2.log 2>&1 && \ + coverage combine .coverage_* coverage html mv htmlcov /tmp/pytest-log - name: upload pytest log diff --git a/.github/workflows/test-torch.yaml b/.github/workflows/test-torch.yaml index 8ea8fdb..834ea4f 100644 --- a/.github/workflows/test-torch.yaml +++ b/.github/workflows/test-torch.yaml @@ -56,11 +56,14 @@ jobs: cd tests LIBDIR=`python3 -c "import os; os.chdir('/tmp'); import fastsafetensors; print(os.path.dirname(fastsafetensors.__file__))"` mkdir -p /tmp/pytest-log + export TEST_FASTSAFETENSORS_FRAMEWORK=pytorch COVERAGE_FILE=.coverage_0 pytest -s --cov=${LIBDIR} test_fastsafetensors.py > /tmp/pytest-log/0.log 2>&1 - COVERAGE_FILE=.coverage_1 CUDA_VISIBLE_DEVICES="" pytest -s --cov=${LIBDIR} test_fastsafetensors.py > /tmp/pytest-log/1.log 2>&1 - COVERAGE_FILE=.coverage_2 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 --no-python pytest -s --cov=${LIBDIR} test_multi.py > /tmp/pytest-log/2.log 2>&1 & - COVERAGE_FILE=.coverage_3 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 --no-python pytest -s --cov=${LIBDIR} test_multi.py > /tmp/pytest-log/3.log 2>&1 - coverage combine .coverage_0 .coverage_1 .coverage_2 .coverage_3 + COVERAGE_FILE=.coverage_1 torchrun --nnodes=1 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 test_multi.py --cov=${LIBDIR} -s test_multi.py > /tmp/pytest-log/1.log 2>&1 + COVERAGE_FILE=.coverage_2 torchrun --nnodes=4 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 test_multi.py --cov=${LIBDIR} -s test_multi.py > /tmp/pytest-log/2.log 2>&1 & + COVERAGE_FILE=.coverage_3 torchrun --nnodes=4 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 test_multi.py --cov=${LIBDIR} -s test_multi.py > /tmp/pytest-log/3.log 2>&1 & + COVERAGE_FILE=.coverage_4 torchrun --nnodes=4 --master_addr=0.0.0.0 --master_port=1234 --node_rank=2 test_multi.py --cov=${LIBDIR} -s test_multi.py > /tmp/pytest-log/4.log 2>&1 & + COVERAGE_FILE=.coverage_5 torchrun --nnodes=4 --master_addr=0.0.0.0 --master_port=1234 --node_rank=3 test_multi.py --cov=${LIBDIR} -s test_multi.py > /tmp/pytest-log/5.log 2>&1 + coverage combine .coverage_* coverage html mv htmlcov /tmp/pytest-log - name: Upload Pytest log diff --git a/Makefile b/Makefile index d44b15e..d3cc75c 100644 --- a/Makefile +++ b/Makefile @@ -6,24 +6,44 @@ CONCMD := docker ifdef PODMAN CONCMD = podman endif -FST_DIR := $(shell python3 -c "import os; os.chdir('/tmp'); import fastsafetensors; print(os.path.dirname(fastsafetensors.__file__))") .PHONY: install install: pip install . --no-cache-dir --no-build-isolation -.PHONY: unittest +.PHONY: unittest unittest-parallel unittest-paddle unittest-paddle-gpu htmlcov + +FST_DIR := $(shell python3 -c "import os; os.chdir('/tmp'); import fastsafetensors; print(os.path.dirname(fastsafetensors.__file__))") + unittest: - COVERAGE_FILE=.coverage_0 pytest -s --cov=$(FST_DIR) tests/test_fastsafetensors.py - COVERAGE_FILE=.coverage_1 CUDA_VISIBLE_DEVICES="" pytest -s --cov=$(FST_DIR) tests/test_fastsafetensors.py - COVERAGE_FILE=.coverage_2 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 --no-python pytest -s --cov=${FST_DIR} tests/test_multi.py > /tmp/2.log 2>&1 & - COVERAGE_FILE=.coverage_3 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 --no-python pytest -s --cov=${FST_DIR} tests/test_multi.py > /tmp/3.log 2>&1 - coverage combine .coverage_0 .coverage_1 .coverage_2 .coverage_3 - coverage html + @FST_DIR=$(FST_DIR); \ + TEST_FASTSAFETENSORS_FRAMEWORK=torch COVERAGE_FILE=.coverage_0 pytest -s --cov=$(FST_DIR) tests/test_fastsafetensors.py && \ + TEST_FASTSAFETENSORS_FRAMEWORK=torch COVERAGE_FILE=.coverage_1 CUDA_VISIBLE_DEVICES="" pytest -s --cov=$(FST_DIR) tests/test_fastsafetensors.py && \ + TEST_FASTSAFETENSORS_FRAMEWORK=torch COVERAGE_FILE=.coverage_2 pytest -s --cov=$(FST_DIR) -s tests/test_vllm.py + +unittest-parallel: + TEST_FASTSAFETENSORS_FRAMEWORK=torch COVERAGE_FILE=.coverage_3 torchrun --nnodes=4 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 tests/test_multi.py --cov=$(FST_DIR) -s tests/test_multi.py > /tmp/3.log 2>&1 & \ + TEST_FASTSAFETENSORS_FRAMEWORK=torch COVERAGE_FILE=.coverage_4 torchrun --nnodes=4 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 tests/test_multi.py --cov=$(FST_DIR) -s tests/test_multi.py > /tmp/4.log 2>&1 & \ + TEST_FASTSAFETENSORS_FRAMEWORK=torch COVERAGE_FILE=.coverage_5 torchrun --nnodes=4 --master_addr=0.0.0.0 --master_port=1234 --node_rank=2 tests/test_multi.py --cov=$(FST_DIR) -s tests/test_multi.py > /tmp/5.log 2>&1 & \ + TEST_FASTSAFETENSORS_FRAMEWORK=torch COVERAGE_FILE=.coverage_6 torchrun --nnodes=4 --master_addr=0.0.0.0 --master_port=1234 --node_rank=3 tests/test_multi.py --cov=$(FST_DIR) -s tests/test_multi.py > /tmp/6.log 2>&1 && \ + wait && \ + TEST_FASTSAFETENSORS_FRAMEWORK=torch COVERAGE_FILE=.coverage_7 torchrun --nnodes=1 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 tests/test_multi.py --cov=$(FST_DIR) -s tests/test_multi.py > /tmp/7.log 2>&1 & \ + wait -.PHONY: integrationtest -integrationtest: - cd tests && COVERAGE_FILE=.coverage pytest -s test_vllm.py +unittest-paddle: + @FST_DIR=$(FST_DIR); \ + TEST_FASTSAFETENSORS_FRAMEWORK=paddle COVERAGE_FILE=.coverage_8 CUDA_VISIBLE_DEVICES="" pytest -s --cov=$(FST_DIR) tests/test_fastsafetensors.py && \ + TEST_FASTSAFETENSORS_FRAMEWORK=paddle COVERAGE_FILE=.coverage_9 CUDA_VISIBLE_DEVICES="" WORLD_SIZE=2 python3 -m paddle.distributed.launch --nnodes 2 --master 127.0.0.1:1234 --rank 0 tests/test_multi.py --cov=$(FST_DIR) -s tests/test_multi.py > /tmp/9.log 2>&1 & \ + TEST_FASTSAFETENSORS_FRAMEWORK=paddle COVERAGE_FILE=.coverage_10 CUDA_VISIBLE_DEVICES="" WORLD_SIZE=2 python3 -m paddle.distributed.launch --nnodes 2 --master 127.0.0.1:1234 --rank 1 tests/test_multi.py --cov=$(FST_DIR) -s tests/test_multi.py > /tmp/10.log 2>&1 && \ + wait + +unittest-paddle-gpu: + @FST_DIR=$(FST_DIR); \ + TEST_FASTSAFETENSORS_FRAMEWORK=paddle COVERAGE_FILE=.coverage_11 pytest -s --cov=$(FST_DIR) tests/test_fastsafetensors.py + +htmlcov: + coverage combine .coverage_* && \ + coverage html .PHONY: builder builder: Dockerfile.build @@ -45,7 +65,20 @@ upload: python3 -m twine upload -u __token__ dist/fastsafetensors-$(shell grep version pyproject.toml | sed -e 's/version = "\([0-9.]\+\)"/\1/g')* perf/dist: - cd perf && python3 -m build + cd perf && pip install . + +.PHONY: format +format: + black8 . + isort . + +.PHONY: lint +lint: + black . + isort . + flake8 . --select=E9,F63,F7,F82 + mypy . --ignore-missing-imports +.PHONY: clean clean: rm -rf dist build fastsafetensors.egg-info \ No newline at end of file diff --git a/examples/paddle_case/a_paddle.safetensors b/examples/a_paddle.safetensors similarity index 100% rename from examples/paddle_case/a_paddle.safetensors rename to examples/a_paddle.safetensors diff --git a/examples/paddle_case/b_paddle.safetensors b/examples/b_paddle.safetensors similarity index 100% rename from examples/paddle_case/b_paddle.safetensors rename to examples/b_paddle.safetensors diff --git a/examples/extract_keys.py b/examples/extract_keys.py index 9a3b4f9..142a09b 100644 --- a/examples/extract_keys.py +++ b/examples/extract_keys.py @@ -1,32 +1,38 @@ -import sys import os -import torch +import sys +from typing import Dict, List + +from safetensors.torch import load_file + from fastsafetensors import SafeTensorsFileLoader, SingleGroup -from safetensors import safe_open if __name__ == "__main__": if len(sys.argv) != 2: print("specify a directory containing safetensors files") sys.exit(1) - loader = SafeTensorsFileLoader(SingleGroup(), torch.device("cpu"), nogds=True) + loader = SafeTensorsFileLoader(SingleGroup(), device="cpu", nogds=True) input_file_or_dir = sys.argv[1] - src_files = {0: []} + src_files: Dict[int, List[str]] = {0: []} orig_keys = {} if os.path.isdir(input_file_or_dir): for dir, _, files in os.walk(input_file_or_dir): - for filename in files: - if filename.endswith(".safetensors"): - src_files[0].append(f"{dir}/{filename}") - elif os.path.exists(input_file_or_dir) and input_file_or_dir.endswith(".safetensors"): - src_files[0].append(input_file_or_dir) - with safe_open(input_file_or_dir, framework="pytorch") as f: - for key in f.keys(): - orig_keys[key] = f.get_tensor(key) + for filename in files: + if filename.endswith(".safetensors"): + src_files[0].append(f"{dir}/{filename}") + elif os.path.exists(input_file_or_dir) and input_file_or_dir.endswith( + ".safetensors" + ): + src_files[0].append(input_file_or_dir) + orig_keys = load_file(input_file_or_dir) loader.add_filenames(src_files) fb = loader.copy_files_to_device() if len(orig_keys) > 0: for key in loader.get_keys(): - print(f"\"{key}\",{loader.get_shape(key)},{loader.frames[key].data_offsets},{fb.get_tensor(key).dtype},{orig_keys[key].dtype}") + print( + f'"{key}",{loader.get_shape(key)},{loader.frames[key].data_offsets},{fb.get_tensor(key).dtype},{orig_keys[key].dtype}' + ) else: for key in loader.get_keys(): - print(f"\"{key}\",{loader.get_shape(key)},{loader.frames[key].data_offsets},{fb.get_tensor(key).dtype}") + print( + f'"{key}",{loader.get_shape(key)},{loader.frames[key].data_offsets},{fb.get_tensor(key).dtype}' + ) diff --git a/examples/fix_alignment.py b/examples/fix_alignment.py index a0b4b6f..fe7897f 100644 --- a/examples/fix_alignment.py +++ b/examples/fix_alignment.py @@ -1,57 +1,75 @@ #!/usr/bin/env python3 -import sys import os +import sys + os.environ["CUDA_VISIBLE_DEVICES"] = "" -from fastsafetensors import SafeTensorsMetadata -from fastsafetensors.common import CUDA_PTR_ALIGN +import json import shutil from copy import deepcopy -import json + +from fastsafetensors import SafeTensorsMetadata +from fastsafetensors.frameworks._torch import TorchOp + def fix_sten_file(src_file: str, dst_file: str): - pad_key='p' - pad_value='P' + pad_key = "p" + pad_value = "P" src_fd = os.open(src_file, os.O_RDONLY, 0o644) if src_fd < 0: raise Exception(f"FAIL: open, src_file={src_file}") - meta = SafeTensorsMetadata.from_fd(src_fd, src_file, keep_orig_dict=True) - print(f"src: filename={src_file}, header_len={meta.header_length}, size={meta.size_bytes}") + meta = SafeTensorsMetadata.from_fd( + src_fd, src_file, framework=TorchOp(), keep_orig_dict=True + ) + print( + f"src: filename={src_file}, header_len={meta.header_length}, size={meta.size_bytes}" + ) - min_head_pad_len = len(bytes(f',"{pad_key}":""', encoding='utf-8')) - dst_header = {'__metadata__': deepcopy(meta.metadata)} + min_head_pad_len = len(bytes(f',"{pad_key}":""', encoding="utf-8")) + dst_header = {"__metadata__": deepcopy(meta.metadata)} dst_header.update(meta.ser) - dst_header_str = str.encode(json.dumps(dst_header, separators=(',', ':')), 'utf-8') + dst_header_str = str.encode(json.dumps(dst_header, separators=(",", ":")), "utf-8") dst_header_len = len(dst_header_str) + 8 head_pad = 0 need_copy = True + CUDA_PTR_ALIGN = meta.framework.get_device_ptr_align() if dst_header_len % CUDA_PTR_ALIGN > 0: head_pad = CUDA_PTR_ALIGN - dst_header_len % CUDA_PTR_ALIGN if head_pad < min_head_pad_len: head_pad += CUDA_PTR_ALIGN - dst_header['__metadata__'][pad_key] = pad_value*(head_pad - min_head_pad_len) - dst_header_str = str.encode(json.dumps(dst_header, separators=(',', ':')), 'utf-8') + dst_header["__metadata__"][pad_key] = pad_value * (head_pad - min_head_pad_len) + dst_header_str = str.encode( + json.dumps(dst_header, separators=(",", ":")), "utf-8" + ) dst_header_len = len(dst_header_str) + 8 - print(f"dst: filename={dst_file}, header_len={dst_header_len} (pad={head_pad}), size={dst_header_len + meta.size_bytes - meta.header_length}") + print( + f"dst: filename={dst_file}, header_len={dst_header_len} (pad={head_pad}), size={dst_header_len + meta.size_bytes - meta.header_length}" + ) - dst_fd = os.open(dst_file, os.O_WRONLY|os.O_CREAT|os.O_EXCL, 0o644) + dst_fd = os.open(dst_file, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644) if dst_fd < 0: raise Exception(f"FAIL: open, dst_fd={dst_fd}") - os_write_full(dst_fd, (dst_header_len - 8).to_bytes(length=8, byteorder='little', signed=False)) + os_write_full( + dst_fd, + (dst_header_len - 8).to_bytes(length=8, byteorder="little", signed=False), + ) os_write_full(dst_fd, dst_header_str) os.lseek(src_fd, meta.header_length, os.SEEK_SET) os.lseek(dst_fd, dst_header_len, os.SEEK_SET) - os_sendfile_full(dst_fd, src_fd, dst_header_len, meta.size_bytes-meta.header_length) + os_sendfile_full( + dst_fd, src_fd, dst_header_len, meta.size_bytes - meta.header_length + ) os.close(dst_fd) need_copy = False - meta2 = SafeTensorsMetadata.from_file(dst_file) + meta2 = SafeTensorsMetadata.from_file(dst_file, TorchOp()) print(f"new metadata: {meta2.metadata}") else: print(f"no fixes are required. skip") os.close(src_fd) return need_copy + def os_write_full(fd: int, buf: bytes): count = 0 while count < len(buf): @@ -59,22 +77,30 @@ def os_write_full(fd: int, buf: bytes): if c == 0: break elif c < 0: - raise IOError(f"os_write_full: os.write returned error, fd={fd}, len(buf)={len(buf)}, count={count}") + raise IOError( + f"os_write_full: os.write returned error, fd={fd}, len(buf)={len(buf)}, count={count}" + ) count += c + def os_sendfile_full(src_fd: int, dst_fd: int, offset: int, length: int): count = 0 while count < length: - c = os.sendfile(src_fd, dst_fd, None, length-count) + c = os.sendfile(src_fd, dst_fd, 0, length - count) if c == 0: break elif c < 0: - raise IOError(f"os_sendfile_full: os.sendfile returned error, src_fd={src_fd}, dst_fd={dst_fd}, offset={offset}, length={length}, count={count}") + raise IOError( + f"os_sendfile_full: os.sendfile returned error, src_fd={src_fd}, dst_fd={dst_fd}, offset={offset}, length={length}, count={count}" + ) count += c + if __name__ == "__main__": if len(sys.argv) != 4: - print("specify a transformers_cache directory, src model name, and dst model name") + print( + "specify a transformers_cache directory, src model name, and dst model name" + ) sys.exit(1) cache_dir = sys.argv[1] src_dir = os.path.join(cache_dir, "models--" + sys.argv[2].replace("/", "--")) diff --git a/examples/gen.py b/examples/gen.py index 8cf9409..2eb6e14 100644 --- a/examples/gen.py +++ b/examples/gen.py @@ -1,6 +1,44 @@ -import os -import torch -t0 = torch.concat([torch.full((1,8), i, dtype=torch.float16) for i in range(0, 16)], dim=0) -from safetensors.torch import save_file -for file_prefix in ["a", "b"]: - save_file({f"{file_prefix}0": t0}, f"{file_prefix}.safetensors", metadata={"fst": "sample"}) +def gen_torch(): + import torch + from safetensors.torch import save_file + + t0 = torch.concat( + [torch.full((1, 8), i, dtype=torch.float16) for i in range(0, 16)], dim=0 + ) + + for file_prefix in ["a", "b"]: + save_file( + {f"{file_prefix}0": t0}, + f"{file_prefix}.safetensors", + metadata={"fst": "sample"}, + ) + + +def gen_paddle(): + import paddle + from safetensors.paddle import save_file + + t0 = paddle.concat( + [paddle.full((1, 8), i, dtype=paddle.float16) for i in range(0, 16)], axis=0 + ) + + for file_prefix in ["a", "b"]: + save_file( + {f"{file_prefix}0": t0}, + f"{file_prefix}_paddle.safetensors", + metadata={"fst": "sample"}, + ) + + +gens = { + "torch": gen_torch, + "paddle": gen_paddle, +} + +if __name__ == "__main__": + import sys + + framework = "torch" + if len(sys.argv) > 1: + framework = sys.argv[1] + gens[framework]() diff --git a/examples/paddle_case/gen.py b/examples/paddle_case/gen.py deleted file mode 100644 index 8ca69fd..0000000 --- a/examples/paddle_case/gen.py +++ /dev/null @@ -1,6 +0,0 @@ -import os -import paddle -t0 = paddle.concat([paddle.full((1,8), i, dtype=paddle.float16) for i in range(0, 16)], dim=0) -from safetensors.paddle import save_file -for file_prefix in ["a", "b"]: - save_file({f"{file_prefix}0": t0}, f"{file_prefix}_paddle.safetensors", metadata={"fst": "sample"}) diff --git a/examples/paddle_case/run_parallel.py b/examples/paddle_case/run_parallel.py deleted file mode 100644 index f132337..0000000 --- a/examples/paddle_case/run_parallel.py +++ /dev/null @@ -1,32 +0,0 @@ -# !/usr/bin/env python3 -# PIDS=() - -# runner="python -m paddle.distributed.launch" - -# cd paddle_case -# ${runner} --nnodes=2 --master=127.0.0.1:12345 --rank=0 run_parallel.py & -# PIDS+=($!) -# ${runner} --nnodes=2 --master=127.0.0.1:12345 --rank=1 run_parallel.py & -# PIDS+=($!) -# wait "${PIDS[@]}" - -import paddle -import paddle.distributed as dist -from fastsafetensors import SafeTensorsFileLoader -dist.init_parallel_env() -backend = "nccl" if paddle.device.cuda.device_count() else "gloo" -pg = dist.new_group(ranks=[0,1], backend=backend) -device = "gpu" if paddle.device.cuda.device_count() else "cpu" -loader = SafeTensorsFileLoader(pg, device, nogds=False, debug_log=True, framework="paddle") -loader.add_filenames({0: ["a_paddle.safetensors"], 1:["b_paddle.safetensors"]}) # {rank: files} - -# load a.safetensors to rank 0 GPU and b.safetensors to rank 1 GPU -fb = loader.copy_files_to_device() - -# every rank must call get_tensor and get_sharded in the same order since they internally call paddle.distributed collective ops -tensor_a0 = fb.get_tensor(tensor_name="a0") # broadcast -tensor_b0_sharded = fb.get_sharded(tensor_name="b0", dim=1) # partition and scatter -print(f"RANK {pg.process_group.rank()}: tensor_a0={tensor_a0}") -print(f"RANK {pg.process_group.rank()}: tensor_b0_sharded={tensor_b0_sharded}") -fb.close() -loader.close() diff --git a/examples/paddle_case/run_single.py b/examples/paddle_case/run_single.py deleted file mode 100644 index 6d5b042..0000000 --- a/examples/paddle_case/run_single.py +++ /dev/null @@ -1,13 +0,0 @@ -import paddle -from fastsafetensors import SafeTensorsFileLoader, SingleGroup -device = "gpu:0" if paddle.device.cuda.device_count() else "cpu" -loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=False, debug_log=True, framework="paddle") -loader.add_filenames({0: ["a_paddle.safetensors", "b_paddle.safetensors"]}) # {rank: files} -fb = loader.copy_files_to_device() -tensor_a0 = fb.get_tensor(tensor_name="a0") -tensor_b0 = fb.get_tensor(tensor_name="b0") -print(f"a0: {tensor_a0}\n b0 : {tensor_b0}") -mycat = paddle.concat([tensor_a0, tensor_b0]) -print(f"cat: {mycat}, size={mycat.size}") -fb.close() -loader.close() diff --git a/examples/run_paddle_parallel_cpu.sh b/examples/run_paddle_parallel_cpu.sh index 6a95391..ac883ce 100755 --- a/examples/run_paddle_parallel_cpu.sh +++ b/examples/run_paddle_parallel_cpu.sh @@ -3,7 +3,6 @@ PIDS=() runner="python -m paddle.distributed.launch" -cd paddle_case rm -rf log # It can only be used on the CPU version of paddlepaddle -${runner} --nproc_per_node 2 run_parallel.py \ No newline at end of file +${runner} --nproc_per_node 2 run_parallel.py paddle \ No newline at end of file diff --git a/examples/run_paddle_parallel_gpu.sh b/examples/run_paddle_parallel_gpu.sh index 8995129..4618f04 100755 --- a/examples/run_paddle_parallel_gpu.sh +++ b/examples/run_paddle_parallel_gpu.sh @@ -3,9 +3,8 @@ PIDS=() runner="python -m paddle.distributed.launch" -cd paddle_case rm -rf log # It can only be used on the GPU version of paddlepaddle-gpu # A machine multy gpu (case : 1 machine 2 gpus) # Different to torch script because the paddle distributed use nccl to communicate in gpus -CUDA_VISIBLE_DEVICES=0,1 ${runner} --gpus 0,1 run_parallel.py \ No newline at end of file +CUDA_VISIBLE_DEVICES=0,1 ${runner} --gpus 0,1 run_parallel.py paddle \ No newline at end of file diff --git a/examples/run_parallel.py b/examples/run_parallel.py index eb25aae..a0491cd 100644 --- a/examples/run_parallel.py +++ b/examples/run_parallel.py @@ -7,23 +7,54 @@ # PIDS+=$($!) # wait ${PIDS[@]} -import torch -import torch.distributed as dist -from fastsafetensors import SafeTensorsFileLoader -dist.init_process_group(backend="gloo") -dist.barrier() -pg = dist.group.WORLD -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -loader = SafeTensorsFileLoader(pg, device, nogds=False, debug_log=True) -loader.add_filenames({0: ["a.safetensors"], 1:["b.safetensors"]}) # {rank: files} - -# load a.safetensors to rank 0 GPU and b.safetensors to rank 1 GPU -fb = loader.copy_files_to_device() - -# every rank must call get_tensor and get_sharded in the same order since they internally call torch.distributed collective ops -tensor_a0 = fb.get_tensor(tensor_name="a0") # broadcast -tensor_b0_sharded = fb.get_sharded(tensor_name="b0", dim=1) # partition and scatter -print(f"RANK {pg.rank()}: tensor_a0={tensor_a0}") -print(f"RANK {pg.rank()}: tensor_b0_sharded={tensor_b0_sharded}") -fb.close() -loader.close() + +def run_torch(): + import torch + import torch.distributed as dist + + dist.init_process_group(backend="gloo") + dist.barrier() + pg = dist.group.WORLD + device = "cuda:0" if torch.cuda.is_available() else "cpu" + return pg, device + + +def run_paddle(): + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + backend = "nccl" if paddle.device.cuda.device_count() else "gloo" + pg = dist.new_group(ranks=[0, 1], backend=backend) + device = "gpu" if paddle.device.cuda.device_count() else "cpu" + return pg, device + + +runs = { + "torch": run_torch, + "paddle": run_paddle, +} + +if __name__ == "__main__": + import sys + + from fastsafetensors import SafeTensorsFileLoader + + framework = "torch" + if len(sys.argv) > 1: + framework = sys.argv[1] + + pg, device = runs[framework]() + loader = SafeTensorsFileLoader(pg, device, nogds=False, debug_log=True) + loader.add_filenames({0: ["a.safetensors"], 1: ["b.safetensors"]}) # {rank: files} + + # load a.safetensors to rank 0 GPU and b.safetensors to rank 1 GPU + fb = loader.copy_files_to_device() + + # every rank must call get_tensor and get_sharded in the same order since they internally call torch.distributed collective ops + tensor_a0 = fb.get_tensor(tensor_name="a0") # broadcast + tensor_b0_sharded = fb.get_sharded(tensor_name="b0", dim=1) # partition and scatter + print(f"RANK {pg.rank()}: tensor_a0={tensor_a0}") + print(f"RANK {pg.rank()}: tensor_b0_sharded={tensor_b0_sharded}") + fb.close() + loader.close() diff --git a/examples/run_reuse_loader.py b/examples/run_reuse_loader.py index de15f83..2049fba 100644 --- a/examples/run_reuse_loader.py +++ b/examples/run_reuse_loader.py @@ -1,26 +1,28 @@ +import sys + import torch + from fastsafetensors import SafeTensorsFileLoader, SingleGroup -import sys sys.path.insert(0, "/nvme/manish/repos/fastsafetensors/fastsafetensors") -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = "cuda:0" if torch.cuda.is_available() else "cpu" loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=True, debug_log=True) -loader.add_filenames({0: ["a.safetensors"]}) # {rank: files} +loader.add_filenames({0: ["a.safetensors"]}) # {rank: files} fb = loader.copy_files_to_device() keys = list(fb.key_to_rank_lidx.keys()) for k in keys: t = fb.get_tensor(k) - print(f' k, shape = {k, t.shape}\n') + print(f" k, shape = {k, t.shape}\n") fb.close() -loader.reset() # reset the loader for reusing with different set of files -loader.add_filenames({0: ["b.safetensors"]}) # {rank: files} +loader.reset() # reset the loader for reusing with different set of files +loader.add_filenames({0: ["b.safetensors"]}) # {rank: files} fb = loader.copy_files_to_device() keys = list(fb.key_to_rank_lidx.keys()) for k in keys: t = fb.get_tensor(k) - print(f' k, shape = {k, t.shape}\n') + print(f" k, shape = {k, t.shape}\n") fb.close() loader.close() diff --git a/examples/run_single.py b/examples/run_single.py index 28299cc..2bd610a 100644 --- a/examples/run_single.py +++ b/examples/run_single.py @@ -1,13 +1,32 @@ -import torch -from fastsafetensors import SafeTensorsFileLoader, SingleGroup -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=False, debug_log=True) -loader.add_filenames({0: ["a.safetensors", "b.safetensors"]}) # {rank: files} -fb = loader.copy_files_to_device() -tensor_a0 = fb.get_tensor(tensor_name="a0") -tensor_b0 = fb.get_tensor(tensor_name="b0") -print(f"a0: {tensor_a0}") -mycat = torch.concat([tensor_a0, tensor_b0], dim=1) -print(f"cat: {mycat}, size={mycat.size()}") -fb.close() -loader.close() +#!/usr/bin/env python3 + +if __name__ == "__main__": + import sys + + from fastsafetensors import cpp as fstcpp + from fastsafetensors import fastsafe_open + + framework = "torch" + filenames = ["a.safetensors", "b.safetensors"] + if len(sys.argv) > 1: + framework = sys.argv[1] + if framework == "torch": + import torch + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + elif framework == "paddle": + import paddle + + device = "gpu" if paddle.device.cuda.device_count() else "cpu" + filenames = ["a_paddle.safetensors", "b_paddle.safetensors"] + else: + raise Exception(f"unknown framework: {framework}") + + with fastsafe_open( + filenames, + device=device, + nogds=not fstcpp.is_cufile_found(), + framework=framework, + ) as f: + print(f"a0: {f.get_tensor(name='a0')}") + print(f"b0: {f.get_tensor(name='b0')}") diff --git a/examples/run_torch_parrallel.sh b/examples/run_torch_parrallel.sh index 9ca6d11..d707865 100755 --- a/examples/run_torch_parrallel.sh +++ b/examples/run_torch_parrallel.sh @@ -2,7 +2,7 @@ PIDS=() torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 run_parallel.py & -PIDS+=$($!) +PIDS+=($!) torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 run_parallel.py & -PIDS+=$($!) +PIDS+=($!) wait ${PIDS[@]} \ No newline at end of file diff --git a/examples/test_vllm.py b/examples/test_vllm.py deleted file mode 100644 index fd347ca..0000000 --- a/examples/test_vllm.py +++ /dev/null @@ -1,28 +0,0 @@ -import sys -import os -import vllm -from vllm.config import LoadFormat -from pathlib import Path - -def drop_cache(model_dir: str): - total = 0 - for f in Path(model_dir).rglob("*"): - if (f.suffix == ".safetensors"): - fd = os.open(f.resolve(), os.O_RDONLY) - s = os.fstat(fd) - os.posix_fadvise(fd, 0, s.st_size, os.POSIX_FADV_DONTNEED) - os.close(fd) - print(f"DROP_CACHE: {f}, {s.st_size/1024/1024/1024} GiB") - total += s.st_size - print(f"total={total/1024/1024/1024}GiB from {model_dir}") - -if __name__ == "__main__": - load_format = LoadFormat.AUTO - if len(sys.argv) > 1 and sys.argv[1] == "1": - load_format = LoadFormat.FASTSAFETENSORS - os.environ["FASTSAFETENSORS_ENABLE_INIT_LOG"] = "1" - print("export FASTSAFETENSORS_ENABLE_INIT_LOG=1") - if len(sys.argv) > 2 and sys.argv[2] == "1": - from transformers.utils import TRANSFORMERS_CACHE - drop_cache(os.path.join(TRANSFORMERS_CACHE, "models--ibm-granite--granite-3.0-8b-instruct/snapshots")) - _ = vllm.LLM(model="ibm-granite/granite-3.0-8b-instruct", load_format=load_format) diff --git a/examples/tgis_weight.py b/examples/tgis_weight.py index 19ea14f..27f256d 100644 --- a/examples/tgis_weight.py +++ b/examples/tgis_weight.py @@ -1,37 +1,50 @@ # Copyright 2024 IBM Inc. All rights reserved # SPDX-License-Identifier: Apache-2.0 -import os import glob +import json +import os +from typing import Any, Dict, List, Optional, Tuple + import torch import torch.distributed as dist -from typing import List, Optional, Dict, Tuple, Any from loguru import logger -import json + +from fastsafetensors.frameworks import TensorBase +from fastsafetensors.loader import SafeTensorsFileLoader +from fastsafetensors.st_types import DType QUANTIZE_CONFIG_FILENAME = "quantize_config.json" + def get_config(device_index: int) -> Tuple[bool, int, int]: auto_config = os.getenv("FST_CONFIG", "auto") - nogds = os.getenv("FST_NOGDS") # disable GDS if FST_NOGDS==1 - nogds = nogds is not None and nogds == "1" - max_copier_threads = int(os.getenv("FST_THREADS", "16")) # number of copy threads at host CPU - bbuf_size_kb_total = int(os.getenv("FST_BBUF_SIZE_KB", "163840")) # size of bounce buffer at host memory for FST_NOGDS==1 + nogds_str = os.getenv("FST_NOGDS") # disable GDS if FST_NOGDS==1 + nogds = nogds_str is not None and nogds_str == "1" + max_copier_threads = int( + os.getenv("FST_THREADS", "16") + ) # number of copy threads at host CPU + bbuf_size_kb_total = int( + os.getenv("FST_BBUF_SIZE_KB", "163840") + ) # size of bounce buffer at host memory for FST_NOGDS==1 if auto_config == "auto": - nogds = not os.path.exists("/run/udev") # udev directory is required for GDS + nogds = not os.path.exists("/run/udev") # udev directory is required for GDS from fastsafetensors.common import get_device_numa_node + node = get_device_numa_node(device_index) total_l2_size = 0 phys_cpus = {} failed = False for cpudir in glob.glob(f"/sys/devices/system/node/node{node}/cpu[0-9]*"): try: - with open(f"{cpudir}/cache/index2/size") as f: # L2 cache size for a cpu + with open( + f"{cpudir}/cache/index2/size" + ) as f: # L2 cache size for a cpu size_str = f.read().strip() if size_str[-1] != "K": raise Exception(f"cannot parse {cpudir}/cache/index2/size") total_l2_size += int(size_str[:-1]) - with open(f"{cpudir}/topology/core_id") as f: # physical core ID + with open(f"{cpudir}/topology/core_id") as f: # physical core ID phys_cpus[f.read().strip()] = True except Exception as e: failed = True @@ -43,22 +56,61 @@ def get_config(device_index: int) -> Tuple[bool, int, int]: max_copier_threads = len(phys_cpus) return (nogds, max_copier_threads, bbuf_size_kb_total) + +dtype_convert: Dict[torch.dtype, DType] = { + torch.bool: DType.BOOL, + torch.uint8: DType.U8, + torch.int8: DType.I8, + torch.int16: DType.I16, + torch.int32: DType.I32, + torch.int64: DType.I64, + torch.float16: DType.F16, + torch.bfloat16: DType.BF16, + torch.float32: DType.F32, + torch.float64: DType.F64, +} + +if hasattr(torch, "float8_e5m2"): + dtype_convert[torch.float8_e5m2] = DType.F8_E5M2 +if hasattr(torch, "float8_e4m3fn"): + dtype_convert[torch.float8_e4m3fn] = DType.F8_E4M3 +if hasattr(torch, "uint16"): + dtype_convert[torch.uint16] = DType.U16 +if hasattr(torch, "uint32"): + dtype_convert[torch.uint32] = DType.U32 +if hasattr(torch, "uint64"): + dtype_convert[torch.uint64] = DType.U64 + + class FastWeights: - def __init__(self, filenames:List[str], - device: torch.device, - dtype: torch.dtype, - pg: dist.ProcessGroup, - debug_log: bool=False, - aliases: Optional[Dict[str, List[str]]] = None, - prefix: Optional[str] = None, - ): - from fastsafetensors.loader import SafeTensorsFileLoader + def __init__( + self, + filenames: List[str], + device: torch.device, + dtype: torch.dtype, + pg: dist.ProcessGroup, + debug_log: bool = False, + aliases: Optional[Dict[str, List[str]]] = None, + prefix: Optional[str] = None, + ): (nogds, max_copier_threads, bbuf_size_kb_total) = get_config(device.index) - self._loader = SafeTensorsFileLoader(pg, device, bbuf_size_kb=bbuf_size_kb_total//pg.size(), max_threads=max_copier_threads, nogds=nogds, debug_log=debug_log) - rank_filenames: Dict[str, List[str]] = {rank: [] for rank in range(0, pg.size())} + st_dtype = dtype_convert[dtype] + self._loader = SafeTensorsFileLoader( + pg, + str(device), + bbuf_size_kb=bbuf_size_kb_total // pg.size(), + max_threads=max_copier_threads, + nogds=nogds, + debug_log=debug_log, + ) + rank_filenames: Dict[int, List[str]] = { + rank: [] for rank in range(0, pg.size()) + } max_copy_block_size = 1 total_size = 0 - for idx, filename in enumerate(sorted(filenames, key=lambda x: os.path.basename(x))): + for idx, filename in enumerate( + sorted(filenames, key=lambda x: os.path.basename(x)) + ): rank_filenames[idx % pg.size()].append(filename) s = os.stat(filename) total_size += s.st_size @@ -67,20 +119,29 @@ def __init__(self, filenames:List[str], self._loader.add_filenames(rank_filenames) if len(filenames) < max_copier_threads: max_copy_block_size = total_size // pg.size() // max_copier_threads - if max_copy_block_size % bbuf_size_kb_total*1024 > 0: - max_copy_block_size = max_copy_block_size - max_copy_block_size % (bbuf_size_kb_total*1024) + (bbuf_size_kb_total*1024) + if max_copy_block_size % bbuf_size_kb_total * 1024 > 0: + max_copy_block_size = ( + max_copy_block_size + - max_copy_block_size % (bbuf_size_kb_total * 1024) + + (bbuf_size_kb_total * 1024) + ) msg = f"Fastsafetensors configuration: GDS={not nogds}, maximum number of file copy threads={max_copier_threads}, copy block size={max_copy_block_size}B" if nogds: msg += f", total bounce buffer size={bbuf_size_kb_total * 1024}B" print(msg) - self._fb = self._loader.copy_files_to_device(dtype, max_copy_block_size=max_copy_block_size) + self._fb = self._loader.copy_files_to_device( + st_dtype, max_copy_block_size=max_copy_block_size + ) self.device = device + self.st_device = self._loader.device self.dtype = dtype + self.st_dtype = st_dtype if aliases is None: aliases = {} self.prefix = prefix self.aliases = aliases self.process_group = pg + self.st_pg = self._loader.pg self.routing = {} for key in self._loader.get_keys(): self.routing[key] = True @@ -90,7 +151,7 @@ def close(self): self._loader.close() torch.cuda.empty_cache() - def _get_alias(self, tensor_name: str)->str: + def _get_alias(self, tensor_name: str) -> str: if self._fb.get_filename(tensor_name) is None: if tensor_name in self.aliases: for alias in self.aliases[tensor_name]: @@ -99,30 +160,56 @@ def _get_alias(self, tensor_name: str)->str: raise RuntimeError(f"weight {tensor_name} does not exist") return tensor_name - def get_shape(self, tensor_name: str)->torch.Size: - return self._fb.get_shape(self._get_alias(tensor_name)) + def get_shape(self, tensor_name: str) -> torch.Size: + return torch.Size(self._fb.get_shape(self._get_alias(tensor_name))) - def get_tensor(self, tensor_name: str)->torch.Tensor: - return self._fb.get_tensor(self._get_alias(tensor_name), device=self.device, dtype=self.dtype) + def get_tensor(self, tensor_name: str) -> torch.Tensor: + return self._fb.get_tensor( + self._get_alias(tensor_name), device=self.st_device, dtype=self.st_dtype + ).get_raw() - def push_tensor(self, tensor_name: str, dst_rank: int)->torch.Tensor: - return self._fb.push_tensor(self._get_alias(tensor_name), dst_rank, device=self.device, dtype=self.dtype) + def push_tensor(self, tensor_name: str, dst_rank: int) -> Optional[torch.Tensor]: + t = self._fb.push_tensor( + self._get_alias(tensor_name), + dst_rank, + device=self.st_device, + dtype=self.st_dtype, + ) + return t.get_raw() if isinstance(t, TensorBase) else None - def get_partial_sharded(self, tensor_name: str, dim: int)->torch.Tensor: - return self._fb.get_sharded(self._get_alias(tensor_name), dim, device=self.device, dtype=self.dtype) + def get_partial_sharded(self, tensor_name: str, dim: int) -> torch.Tensor: + return self._fb.get_sharded( + self._get_alias(tensor_name), + dim, + device=self.st_device, + dtype=self.st_dtype, + ).get_raw() - def get_sharded(self, tensor_name: str, dim: int=1)->torch.Tensor: - return self._fb.get_sharded(self._get_alias(tensor_name), dim, device=self.device, dtype=self.dtype) + def get_sharded(self, tensor_name: str, dim: int = 1) -> torch.Tensor: + return self._fb.get_sharded( + self._get_alias(tensor_name), + dim, + device=self.st_device, + dtype=self.st_dtype, + ).get_raw() def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): if quantize == "gptq": try: - qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1) + qweight = torch.cat( + [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) except RuntimeError: - raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) - qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1) - scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1) + qzeros = torch.cat( + [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + scales = torch.cat( + [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] for w2 in w[1:]: torch.testing.assert_close(w2, w[0]) @@ -140,9 +227,11 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda) else: tensor_names = [self._get_alias(f"{prefix}.weight") for prefix in prefixes] - weight = self._fb.get_multi_cols(tensor_names, dim, device=self.device, dtype=self.dtype) + weight = self._fb.get_multi_cols( + tensor_names, dim, device=self.st_device, dtype=self.st_dtype + ).get_raw() return weight - + def get_multi_weights_row(self, prefix: str, quantize: str): if quantize == "gptq": bits, groupsize = self._get_gptq_params() @@ -150,9 +239,18 @@ def get_multi_weights_row(self, prefix: str, quantize: str): use_gptq_cuda = bits == 4 if self.process_group.size() > 1: - g_idx = self.get_tensor(f"{prefix}.g_idx") + g_idx: Optional[torch.Tensor] = self.get_tensor(f"{prefix}.g_idx") if g_idx is not None: - if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all(): + if ( + not torch.equal( + g_idx.cpu(), + torch.tensor( + [i // groupsize for i in range(g_idx.shape[0])], + dtype=torch.int32, + ), + ) + and not (g_idx == 0).all() + ): # Exllama implementation does not support row tensor parallelism with act-order, as # it would require to reorder input activations that are split unto several GPUs use_gptq_cuda = False @@ -160,9 +258,12 @@ def get_multi_weights_row(self, prefix: str, quantize: str): try: qweight = self.get_sharded(f"{prefix}.qweight", dim=0) except RuntimeError: - raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) from text_generation_server.utils.layers import HAS_GPTQ_CUDA + if use_gptq_cuda: use_gptq_cuda = HAS_GPTQ_CUDA if self.process_group.rank == 0: @@ -199,22 +300,26 @@ def get_multi_weights_row(self, prefix: str, quantize: str): weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda) else: - weight = self._fb.get_sharded(self._get_alias(f"{prefix}.weight"), 1, device=self.device, dtype=self.dtype) + weight = self._fb.get_sharded( + self._get_alias(f"{prefix}.weight"), + 1, + device=self.st_device, + dtype=self.st_dtype, + ).get_raw() return weight - def _get_gptq_params(self) -> Tuple[int, int]: try: - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() - except (RuntimeError) as e: + bits = int(self.get_tensor("gptq_bits").item()) + groupsize = int(self.get_tensor("gptq_groupsize").item()) + except RuntimeError as e: try: bits = self.gptq_bits groupsize = self.gptq_groupsize except Exception: raise e return bits, groupsize - + def _set_gptq_params(self, model_config: Any, model_path: str): # Get quantization config from model's configuration # or else look for quantize_config.json in the model dir @@ -228,4 +333,4 @@ def _set_gptq_params(self, model_config: Any, model_path: str): quantize_config = json.load(f) self.gptq_bits = quantize_config["bits"] - self.gptq_groupsize = quantize_config["group_size"] \ No newline at end of file + self.gptq_groupsize = quantize_config["group_size"] diff --git a/fastsafetensors/__init__.py b/fastsafetensors/__init__.py index 2c867db..f005e5e 100644 --- a/fastsafetensors/__init__.py +++ b/fastsafetensors/__init__.py @@ -1,6 +1,6 @@ # Copyright 2024 IBM Inc. All rights reserved # SPDX-License-Identifier: Apache-2.0 +from .common import SafeTensorsMetadata, SingleGroup, TensorFrame, get_device_numa_node +from .file_buffer import FilesBufferOnDevice from .loader import SafeTensorsFileLoader, fastsafe_open -from .common import SafeTensorsMetadata, TensorFrame, SingleGroup, get_device_numa_node, str_to_dtype, alloc_tensor_memory, free_tensor_memory -from .file_buffer import FilesBufferOnDevice \ No newline at end of file diff --git a/fastsafetensors/common.py b/fastsafetensors/common.py index b8ac341..1855b1b 100644 --- a/fastsafetensors/common.py +++ b/fastsafetensors/common.py @@ -1,267 +1,263 @@ # Copyright 2024 IBM Inc. All rights reserved # SPDX-License-Identifier: Apache-2.0 -import torch -try: - import paddle - from paddle.framework import core as paddle_core - paddle_loaded = True -except: - paddle_loaded = False -import os import json +import os +import platform from collections import OrderedDict +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + from . import cpp as fstcpp -from typing import List, Dict, Tuple from .dlpack import from_cuda_buffer +from .frameworks import FrameworkOpBase, TensorBase +from .st_types import Device, DType -class SingleGroup: - def size(self): - return 1 - def rank(self): - return 0 - -ALIGN: int = fstcpp.get_alignment_size() -CUDA_PTR_ALIGN: int = 16 - -framework_index = { - "pytorch": 1, -} -dtype_convert = { - 'BOOL': (1, torch.bool), 'U8': (1, torch.uint8), 'I8': (1, torch.int8), 'F8_E5M2': (1, torch.float8_e5m2), 'F8_E4M3': (1, torch.float8_e4m3fn), - 'I16': (2, torch.int16), 'U16': (2, torch.int16), 'I32': (4, torch.int32), 'U32': (4, torch.int32), 'I64': (8, torch.int64), 'U64': (8, torch.int64), - 'F16': (2, torch.float16), 'BF16': (2, torch.bfloat16), 'F32': (4, torch.float32), 'F64': (8, torch.float64) -} - -need_workaround_dtypes = { - torch.float8_e5m2: torch.int8, - torch.float8_e4m3fn: torch.int8, -} - -if paddle_loaded: - framework_index = { - "pytorch": 1, - "paddle": 2, - } - dtype_convert = { - 'BOOL': (1, torch.bool, paddle.bool), 'U8': (1, torch.uint8, paddle.uint8), 'I8': (1, torch.int8, paddle.int8), 'F8_E5M2': (1, torch.float8_e5m2, paddle.float8_e5m2), 'F8_E4M3': (1, torch.float8_e4m3fn, paddle.float8_e4m3fn), - 'I16': (2, torch.int16, paddle.int16), 'U16': (2, torch.int16, paddle.bfloat16), 'I32': (4, torch.int32, paddle.int32), 'U32': (4, torch.int32, paddle.int32), 'I64': (8, torch.int64, paddle.int64), 'U64': (8, torch.int64, paddle.int64), - 'F16': (2, torch.float16, paddle.float16), 'BF16': (2, torch.bfloat16, paddle.bfloat16), 'F32': (4, torch.float32, paddle.float32), 'F64': (8, torch.float64, paddle.float64) - } - need_workaround_dtypes = { - torch.float8_e5m2: torch.int8, - torch.float8_e4m3fn: torch.int8, - paddle.float8_e5m2 : paddle.int8, - paddle.float8_e4m3fn : paddle.int8 - } - -def str_to_dtype(dtype_str: str, framework: str="pytorch")->torch.dtype: - if framework not in framework_index.keys(): - raise NotImplementedError(f"str_to_dtype: Not implemented for other frameworks than {framework_index.keys()}") - if dtype_str not in dtype_convert: - raise ValueError(f"str_to_dtype: Not supported dtype: {dtype_str}") - return dtype_convert[dtype_str][framework_index[framework]] - -def get_device_numa_node(device: int): - if device is None: - return +def get_device_numa_node(device: Optional[int]) -> Optional[int]: + if device is None or platform.system() != "Linux": + return None pci_addr = fstcpp.get_device_pci_bus(device) if pci_addr == "": - #raise Exception(f"get_device_numa_node, get_device_pci_bus failed, device={device}") - return - bus_addr = ':'.join(pci_addr.split(":")[:2]).lower() + return None + bus_addr = ":".join(pci_addr.split(":")[:2]).lower() syspath = f"/sys/class/pci_bus/{bus_addr}/device/numa_node" if not os.path.exists(syspath): - return 0 + return None with open(syspath) as f: return int(f.read().strip()) -def alloc_tensor_memory(length: int, dev: torch.device, framework: str="pytorch")->fstcpp.gds_device_buffer: - dev_is_gpu = True - if framework == "pytorch" and dev.type == 'cuda': - rbuf = torch.cuda.caching_allocator_alloc(length) - elif paddle_loaded and framework == "paddle" and "gpu" in dev: - rbuf = fstcpp.gpu_malloc(length) - else: - dev_is_gpu = False - rbuf = fstcpp.cpu_malloc(length) - return fstcpp.gds_device_buffer(rbuf, length, dev_is_gpu) -def free_tensor_memory(gbuf: fstcpp.gds_device_buffer, dev: torch.device, framework: str="pytorch"): - if framework =="pytorch" and dev.type == 'cuda': - rbuf = torch.cuda.caching_allocator_delete(gbuf.get_base_address()) - elif paddle_loaded and framework == "paddle" and "gpu" in dev: - rbuf = fstcpp.gpu_free(gbuf.get_base_address()) - else: - rbuf = fstcpp.cpu_free(gbuf.get_base_address()) - return rbuf +# keep this for compatibility +class SingleGroup: + def size(self): + return 1 + + def rank(self): + return 0 class SafeTensorsMetadata: - def __init__(self, string: str, header_length: int, size_bytes: int, src: str="", keep_orig_dict: bool=False, framework: str="pytorch"): + def __init__( + self, + string: str, + header_length: int, + size_bytes: int, + framework: FrameworkOpBase, + src: str = "", + keep_orig_dict: bool = False, + ): self.src = src self.framework = framework ser = json.loads(string, object_pairs_hook=OrderedDict) - self.metadata = ser.get('__metadata__', '') + self.metadata = ser.get("__metadata__", "") if self.metadata: - del(ser['__metadata__']) + del ser["__metadata__"] self.tensors: Dict[str, TensorFrame] = {} self.header_length = header_length - self.aligned = header_length % CUDA_PTR_ALIGN == 0 + self.aligned = header_length % framework.get_device_ptr_align() == 0 if keep_orig_dict: self.ser = ser start = 0 - for _, (k, buffer) in enumerate(sorted(ser.items(), key=lambda x: x[1]['data_offsets'][0])): - t = TensorFrame.from_buffer(buffer, self.framework) + for _, (k, buffer) in enumerate( + sorted(ser.items(), key=lambda x: x[1]["data_offsets"][0]) + ): + t: TensorFrame = TensorFrame.from_buffer(buffer) self.tensors[k] = t # validation s, e = t.data_offsets if s != start or e < s: - raise Exception(f"validate(tensor {k}): InvalidOffset s={s}, start={start}, e={e}, src={src}") - #if (header_length + s) % CUDA_PTR_ALIGN > 0: + raise Exception( + f"validate(tensor {k}): InvalidOffset s={s}, start={start}, e={e}, src={src}" + ) + # if (header_length + s) % CUDA_PTR_ALIGN > 0: # print(f"[WARNING] misaligned tensor is detected at {header_length + s}. this will cause cuda pointer alignment errors later.") start = e nelements = 1 for sh in t.shape: nelements *= sh - if self.framework == "pytorch": - t_dtype_size = t.dtype.itemsize - elif paddle_loaded and self.framework == "paddle": - t_dtype_size = paddle_core.size_of_dtype(t.dtype) - nbytes = nelements * t_dtype_size + nbytes = nelements * framework.get_dtype_size(t.dtype) if (e - s) != nbytes: - raise Exception(f"validate(tensor {k}): TensorInvalidInfo, e-s={e-s}, nbytes={nbytes}, src={src}") + raise Exception( + f"validate(tensor {k}): TensorInvalidInfo, e-s={e-s}, nbytes={nbytes}, src={src}" + ) self.size_bytes = size_bytes if start + header_length != size_bytes: - raise Exception(f"MetadataIncompleteBuffer, src={src}, start={start}, header_length={header_length}, size_bytes={size_bytes}") + raise Exception( + f"MetadataIncompleteBuffer, src={src}, start={start}, header_length={header_length}, size_bytes={size_bytes}" + ) @classmethod - def from_buffer(self, buf: int, buffer_len: int, filename: str): + def from_buffer( + self, buf: int, buffer_len: int, filename: str, framework: FrameworkOpBase + ): if buffer_len < 8: - raise Exception(f"from_buffer: HeaderTooSmall, filename={filename}, buffer_len={buffer_len}") + raise Exception( + f"from_buffer: HeaderTooSmall, filename={filename}, buffer_len={buffer_len}" + ) arr = fstcpp.read_buffer(buf, 8) - n = int.from_bytes(arr, byteorder='little', signed=False) + n = int.from_bytes(arr, byteorder="little", signed=False) if n > 100000000: - raise Exception(f"from_buffer: HeaderTooLarge, n={n}, filename={filename}, buffer_len={buffer_len}") + raise Exception( + f"from_buffer: HeaderTooLarge, n={n}, filename={filename}, buffer_len={buffer_len}" + ) if n > buffer_len - 8: - raise Exception(f"from_buffer: InvalidHeaderLength, n={n}, filename={filename}, buffer_len={buffer_len}") - string = fstcpp.read_buffer(buf+8, n).decode('utf-8') + raise Exception( + f"from_buffer: InvalidHeaderLength, n={n}, filename={filename}, buffer_len={buffer_len}" + ) + string = fstcpp.read_buffer(buf + 8, n).decode("utf-8") # Assert the string starts with { # NOTE: Add when we move to 0.4.0 - #if string.startswith('{'): + # if string.startswith('{'): # raise Exception(f"{filename}: InvalidHeaderStart") - return SafeTensorsMetadata(string, n + 8, buffer_len) + return SafeTensorsMetadata(string, n + 8, buffer_len, framework) @classmethod - def from_fd(self, fd: int, filename: str, keep_orig_dict: bool=False, framework: str="pytorch"): + def from_fd( + self, + fd: int, + filename: str, + framework: FrameworkOpBase, + keep_orig_dict: bool = False, + ): status = os.fstat(fd) buffer_len = status.st_size if buffer_len < 8: raise Exception(f"{filename}: HeaderTooSmall, buffer_len={buffer_len}") arr = os.read(fd, 8) - n = int.from_bytes(arr, byteorder='little', signed=False) + n = int.from_bytes(arr, byteorder="little", signed=False) if n > 100000000: - raise Exception(f"{filename}: HeaderTooLarge, n={n}, buffer_len={buffer_len}") + raise Exception( + f"{filename}: HeaderTooLarge, n={n}, buffer_len={buffer_len}" + ) if n > buffer_len - 8: - raise Exception(f"{filename}: InvalidHeaderLength, n={n}, buffer_len={buffer_len}") - string = os.read(fd, n).decode('utf-8') + raise Exception( + f"{filename}: InvalidHeaderLength, n={n}, buffer_len={buffer_len}" + ) + string = os.read(fd, n).decode("utf-8") # Assert the string starts with { # NOTE: Add when we move to 0.4.0 - #if string.startswith('{'): + # if string.startswith('{'): # raise Exception(f"{filename}: InvalidHeaderStart") - return SafeTensorsMetadata(string, n + 8, buffer_len, filename, keep_orig_dict=keep_orig_dict, framework=framework) + return SafeTensorsMetadata( + string, + n + 8, + buffer_len, + framework, + filename, + keep_orig_dict=keep_orig_dict, + ) @classmethod - def from_file(self, filename: str, framework: str="pytorch"): + def from_file(self, filename: str, framework: FrameworkOpBase): fd = os.open(filename, os.O_RDONLY, 0o644) - ret = self.from_fd(fd, filename, keep_orig_dict=False, framework=framework) + ret = self.from_fd(fd, filename, framework=framework, keep_orig_dict=False) os.close(fd) return ret - def get_tensors(self, gbuf: fstcpp.gds_device_buffer, device: torch.device, copy_start_offset: int, dtype: torch.dtype=None) -> Dict[str, torch.Tensor]: + def get_tensors( + self, + gbuf: fstcpp.gds_device_buffer, + device: Device, + copy_start_offset: int, + dtype: DType = DType.AUTO, + ) -> Dict[str, TensorBase]: ret = {} for tensor_name, t in self.tensors.items(): - dst_dev_ptr = gbuf.get_base_address() + self.header_length + t.data_offsets[0]-copy_start_offset - if self.framework == "pytorch": - if t.dtype in need_workaround_dtypes: - t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[t.dtype], device)).view(t.dtype) - else: - t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, t.dtype, device)) - if dtype is not None and dtype != t.dtype: - if dtype.itemsize > t.dtype.itemsize: - raise Exception(f"Online type conversion to larger sizes is not supported ({t.dtype} -> {dtype})") - t3 = t2.to(dtype=dtype) - if dtype in need_workaround_dtypes: - t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[dtype], device)).view(dtype) - else: - t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, dtype, device)) - t2.copy_(t3) - self.tensors[tensor_name].dtype = dtype - elif paddle_loaded and self.framework == "paddle": - if t.dtype in need_workaround_dtypes: - t2 = paddle.utils.dlpack.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[t.dtype], device)) - else: - t2 = paddle.utils.dlpack.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, t.dtype, device)) - if dtype is not None and dtype != t.dtype: - if paddle_core.size_of_dtype(dtype) > paddle_core.size_of_dtype(t.dtype): - raise Exception(f"Online type conversion to larger sizes is not supported ({t.dtype} -> {dtype})") - t3 = t2.to(dtype=dtype) - if t.dtype in need_workaround_dtypes: - t2 = paddle.utils.dlpack.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[dtype], device)) - else: - t2 = paddle.utils.dlpack.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, dtype, device)) - paddle.assign(t3, output=t2) - self.tensors[tensor_name].dtype = dtype + dst_dev_ptr = ( + gbuf.get_base_address() + + self.header_length + + t.data_offsets[0] + - copy_start_offset + ) + disk_dtype = self.framework.as_workaround_dtype(t.dtype) + dl_tensor = from_cuda_buffer( + dst_dev_ptr, + t.shape, + t.strides, + disk_dtype, + device, + ) + t2 = self.framework.from_dlpack(dl_tensor, device, disk_dtype) + if disk_dtype != t.dtype: + t2 = t2.view(t.dtype) + + if dtype != DType.AUTO and dtype != t.dtype: + if self.framework.get_dtype_size(dtype) > self.framework.get_dtype_size( + t.dtype + ): + raise Exception( + f"Online type conversion to larger sizes is not supported ({t.dtype} -> {dtype})" + ) + t3 = t2.to(dtype=dtype) + conv_dtype: DType = self.framework.as_workaround_dtype(dtype) + dl_tensor = from_cuda_buffer( + dst_dev_ptr, + t.shape, + t.strides, + conv_dtype, + device, + ) + t2 = self.framework.from_dlpack(dl_tensor, device, conv_dtype) + if dtype != conv_dtype: + t2 = t2.view(dtype) + self.framework.copy_tensor(t2, t3) + self.tensors[tensor_name].dtype = dtype ret[tensor_name] = t2 return ret - def __repr__(self)->str: + def __repr__(self) -> str: return str({"__metadata__": self.metadata, "tensors": self.tensors}) + +@dataclass class TensorFrame: - def __init__(self, dtype: torch.dtype, shape: torch.Size, data_offsets: List[int], strides: List[int], offsets: List[int], sliced: bool): - self.dtype = dtype - self.shape = shape - self.data_offsets = data_offsets - self.strides = strides - self.offsets = offsets - self.sliced = sliced + dtype: DType + shape: List[int] + data_offsets: List[int] + strides: List[int] + offsets: List[int] + sliced: bool @classmethod - def from_buffer(self, entry: OrderedDict[str, List[int]], framework:str="pytorch"): - dtype = str_to_dtype(entry['dtype'], framework=framework) - shape = entry['shape'] - if framework == "pytorch": - shape = torch.Size(shape) - data_offsets = list(entry['data_offsets']) + def from_buffer(self, entry: OrderedDict[str, List[int]]): + shape = entry["shape"] + data_offsets = list(entry["data_offsets"]) strides = [] offsets = [] for i in range(0, len(shape)): s = 1 - for j in range(i+1, len(shape)): + for j in range(i + 1, len(shape)): s *= shape[j] strides.append(s) offsets.append(0) - return TensorFrame(dtype, shape, data_offsets, strides, offsets, False) + return TensorFrame( + DType(entry["dtype"]), shape, data_offsets, strides, offsets, False + ) - def __repr__(self)->str: - return str({ - "dtype": self.dtype, "shape": self.shape, "data_offsets": self.data_offsets, - }) + def __repr__(self) -> str: + return str( + { + "dtype": self.dtype, + "shape": self.shape, + "data_offsets": self.data_offsets, + } + ) # TODO: reduce dim if isinstance(_val, int) == True - def __getitem__(self, _val): + def __getitem__(self, _val) -> "TensorFrame": val: Tuple = () if isinstance(_val, slice) or isinstance(_val, int): val = (_val,) - elif isinstance(_val, Tuple): + elif isinstance(_val, tuple): val = _val else: raise Exception(f"[BUG] Unsupported index type for DiskTensor: {_val}") if len(val) > len(self.shape): - raise Exception(f"[BUG] tried to get too large slice {_val} from {self.shape}") + raise Exception( + f"[BUG] tried to get too large slice {_val} from {self.shape}" + ) shape: List[int] = [] strides: List[int] = [] offsets: List[int] = [] @@ -269,7 +265,9 @@ def __getitem__(self, _val): if isinstance(val[dim], int): start = val[dim] if start >= self.shape[dim] or start < -self.shape[dim]: - raise IndexError(f"[BUG] tried to access index {start} at dim={dim} for shape={self.shape}") + raise IndexError( + f"[BUG] tried to access index {start} at dim={dim} for shape={self.shape}" + ) if start < 0: start = self.shape[dim] + start + 1 stop = start + 1 @@ -294,12 +292,18 @@ def __getitem__(self, _val): if step == 0: raise ValueError(f"[BUG] slice step cannot be zero") length = stop - start - if length == 0 or (length < 0 and step > 0) or (length > 0 and step < 0): - return TensorFrame(self.dtype, torch.Size([]), self.data_offsets, (), ()) + if ( + length == 0 + or (length < 0 and step > 0) + or (length > 0 and step < 0) + ): + return TensorFrame(self.dtype, [], self.data_offsets, [], [], False) if length < 0 and step < 0: length *= -1 else: - raise Exception(f"[BUG] Unsupported index type for DiskTensor: {_val} at dim={dim}") + raise Exception( + f"[BUG] Unsupported index type for DiskTensor: {_val} at dim={dim}" + ) offsets.append(self.offsets[dim] + start) strides.append(self.strides[dim] * step) shape.append(length // (step if step > 0 else -step)) @@ -307,4 +311,4 @@ def __getitem__(self, _val): offsets.append(self.offsets[rdim]) strides.append(self.strides[rdim]) shape.append(self.shape[rdim]) - return TensorFrame(self.dtype, torch.Size(shape), self.data_offsets, strides, offsets, True) + return TensorFrame(self.dtype, shape, self.data_offsets, strides, offsets, True) diff --git a/fastsafetensors/copier/gds.py b/fastsafetensors/copier/gds.py index 65a9b36..4dc1d7c 100644 --- a/fastsafetensors/copier/gds.py +++ b/fastsafetensors/copier/gds.py @@ -1,40 +1,49 @@ # Copyright 2024 IBM Inc. All rights reserved # SPDX-License-Identifier: Apache-2.0 -import torch +from typing import Dict, Optional + from .. import cpp as fstcpp -from typing import Dict -from ..common import alloc_tensor_memory, free_tensor_memory, SafeTensorsMetadata, ALIGN, CUDA_PTR_ALIGN, paddle_loaded -if paddle_loaded: - import paddle +from ..common import SafeTensorsMetadata +from ..frameworks import FrameworkOpBase, TensorBase +from ..st_types import Device, DeviceType, DType + class GdsFileCopier: - def __init__(self, metadata: SafeTensorsMetadata, device: torch.device, reader: fstcpp.gds_file_reader, debug_log: bool=False): + def __init__( + self, + metadata: SafeTensorsMetadata, + device: Device, + reader: fstcpp.gds_file_reader, + framework: FrameworkOpBase, + debug_log: bool = False, + ): + self.framework = framework self.metadata = metadata self.device = device self.reader = reader self.debug_log = debug_log self.gbuf = None - self.fh = 0 + self.fh: Optional[fstcpp.gds_file_handle] = None self.copy_reqs: Dict[int, int] = {} self.aligned_length = 0 - try: - if self.metadata.framework == "pytorch": - cuda_vers_list = torch.version.cuda.split('.') - elif paddle_loaded and self.metadata.framework == "paddle": - cuda_vers_list = paddle.version.cuda().split('.') - cudavers = list(map(int, cuda_vers_list)) - # CUDA 12.2 (GDS version 1.7) introduces support for non O_DIRECT file descriptors - # Compatible with CUDA 11.x - self.o_direct = not (cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2)) - except: - self.o_direct = True + cudavers = list(map(int, framework.get_cuda_ver().split("."))) + # CUDA 12.2 (GDS version 1.7) introduces support for non O_DIRECT file descriptors + # Compatible with CUDA 11.x + self.o_direct = not ( + cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2) + ) def set_o_direct(self, enable: bool): self.o_direct = enable - def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gds_device_buffer: - dev_is_cuda = (self.metadata.framework == "pytorch" and self.device.type == 'cuda') or (paddle_loaded and self.metadata.framework == "paddle" and "gpu" in self.device) + def submit_io( + self, use_buf_register: bool, max_copy_block_size: int + ) -> fstcpp.gds_device_buffer: + dev_is_cuda = ( + self.device.type == DeviceType.CUDA or self.device.type == DeviceType.GPU + ) + ALIGN: int = fstcpp.get_alignment_size() self.fh = fstcpp.gds_file_handle(self.metadata.src, self.o_direct, dev_is_cuda) offset = self.metadata.header_length length = self.metadata.size_bytes - self.metadata.header_length @@ -47,7 +56,7 @@ def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gd aligned_length = length + head_bytes aligned_offset = offset - head_bytes - gbuf = alloc_tensor_memory(aligned_length, self.device, self.metadata.framework) + gbuf = self.framework.alloc_tensor_memory(aligned_length, self.device) if use_buf_register: count = 0 while count < aligned_length: @@ -55,7 +64,11 @@ def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gd if req_len > max_copy_block_size: req_len = max_copy_block_size if gbuf.cufile_register(count, req_len) < 0: - raise Exception("submit_io: register_buffer failed, ptr=0x{:x}, count={}, len={}".format(gbuf.get_base_address(), count, req_len)) + raise Exception( + "submit_io: register_buffer failed, ptr=0x{:x}, count={}, len={}".format( + gbuf.get_base_address(), count, req_len + ) + ) count += req_len count = 0 @@ -64,40 +77,65 @@ def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gd if req_len > max_copy_block_size: req_len = max_copy_block_size # TODO: pass timeout so that wait_copy_tensors can recognize too slow pread() - req = self.reader.submit_read(self.fh, gbuf, aligned_offset + count, req_len, count, self.metadata.size_bytes) + req = self.reader.submit_read( + self.fh, + gbuf, + aligned_offset + count, + req_len, + count, + self.metadata.size_bytes, + ) self.copy_reqs[req] = -1 if not use_buf_register else count count += req_len self.aligned_offset = aligned_offset self.aligned_length = aligned_length return gbuf - def wait_io(self, gbuf: fstcpp.gds_device_buffer, dtype: torch.dtype=None, noalign: bool=False)->Dict[str, torch.Tensor]: + def wait_io( + self, + gbuf: fstcpp.gds_device_buffer, + dtype: DType = DType.AUTO, + noalign: bool = False, + ) -> Dict[str, TensorBase]: failed = [] - for req, c in sorted(self.copy_reqs.items(), key=lambda x:x[0]): + for req, c in sorted(self.copy_reqs.items(), key=lambda x: x[0]): count = self.reader.wait_read(req) if count < 0: failed.append(req) if c != -1: gbuf.cufile_deregister(c) - if self.fh != 0: + if self.fh is not None: del self.fh - self.fh = 0 + self.fh = None if len(failed) > 0: - raise Exception(f"wait_io: wait_gds_read failed, failed={failed}, reqs={self.copy_reqs}") + raise Exception( + f"wait_io: wait_gds_read failed, failed={failed}, reqs={self.copy_reqs}" + ) self.copy_reqs = {} if not noalign and not self.metadata.aligned and self.aligned_length > 0: - misaligned_bytes = self.metadata.header_length % CUDA_PTR_ALIGN - length = 1024*1024*1024 - tmp_gbuf = alloc_tensor_memory(length, self.device, self.metadata.framework) + misaligned_bytes = ( + self.metadata.header_length % self.framework.get_device_ptr_align() + ) + length = 1024 * 1024 * 1024 + tmp_gbuf = self.framework.alloc_tensor_memory(length, self.device) count = 0 while count + misaligned_bytes < self.aligned_length: l = self.aligned_length - misaligned_bytes - count if l > length: l = length if self.debug_log: - print("wait_io: fix misalignment, src=0x{:x}, misaligned_bytes={}, count={}, tmp=0x{:x}".format(gbuf.get_base_address(), misaligned_bytes, count, tmp_gbuf.get_base_address())) + print( + "wait_io: fix misalignment, src=0x{:x}, misaligned_bytes={}, count={}, tmp=0x{:x}".format( + gbuf.get_base_address(), + misaligned_bytes, + count, + tmp_gbuf.get_base_address(), + ) + ) gbuf.memmove(count, misaligned_bytes + count, tmp_gbuf, l) count += l - free_tensor_memory(tmp_gbuf, self.device, self.metadata.framework) + self.framework.free_tensor_memory(tmp_gbuf, self.device) self.aligned_offset += misaligned_bytes - return self.metadata.get_tensors(gbuf, self.device, self.aligned_offset, dtype=dtype) + return self.metadata.get_tensors( + gbuf, self.device, self.aligned_offset, dtype=dtype + ) diff --git a/fastsafetensors/copier/nogds.py b/fastsafetensors/copier/nogds.py index f8359b7..18183ac 100644 --- a/fastsafetensors/copier/nogds.py +++ b/fastsafetensors/copier/nogds.py @@ -1,39 +1,61 @@ # Copyright 2024 IBM Inc. All rights reserved # SPDX-License-Identifier: Apache-2.0 -import torch import os +from typing import Dict, List + from .. import cpp as fstcpp -from typing import Dict -from ..common import alloc_tensor_memory, SafeTensorsMetadata, ALIGN, CUDA_PTR_ALIGN +from ..common import SafeTensorsMetadata +from ..frameworks import FrameworkOpBase, TensorBase +from ..st_types import Device, DType + class NoGdsFileCopier: - def __init__(self, metadata: SafeTensorsMetadata, device: torch.device, reader: fstcpp.nogds_file_reader, debug_log: bool=False): + def __init__( + self, + metadata: SafeTensorsMetadata, + device: Device, + reader: fstcpp.nogds_file_reader, + framework: FrameworkOpBase, + debug_log: bool = False, + ): + self.framework = framework self.metadata = metadata self.reader = reader self.fd = os.open(metadata.src, os.O_RDONLY, 0o644) if self.fd < 0: - raise Exception(f"NoGdsFileCopier.__init__: failed to open, file={metadata.src}") + raise Exception( + f"NoGdsFileCopier.__init__: failed to open, file={metadata.src}" + ) self.device = device self.debug_log = debug_log - self.reqs = [] + self.reqs: List[int] = [] - def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gds_device_buffer: + def submit_io( + self, use_buf_register: bool, max_copy_block_size: int + ) -> fstcpp.gds_device_buffer: total_length = self.metadata.size_bytes - self.metadata.header_length - gbuf = alloc_tensor_memory(total_length, self.device, self.metadata.framework) + gbuf = self.framework.alloc_tensor_memory(total_length, self.device) count = 0 while count < total_length: l = total_length - count if max_copy_block_size < l: l = max_copy_block_size - req = self.reader.submit_read(self.fd, gbuf, self.metadata.header_length + count, l, count) + req = self.reader.submit_read( + self.fd, gbuf, self.metadata.header_length + count, l, count + ) if req < 0: raise Exception(f"submit_io: submit_nogds_read failed, err={req}") self.reqs.append(req) count += l return gbuf - def wait_io(self, gbuf: fstcpp.gds_device_buffer, dtype: torch.dtype=None, noalign: bool=False)->Dict[str, torch.Tensor]: + def wait_io( + self, + gbuf: fstcpp.gds_device_buffer, + dtype: DType = DType.AUTO, + noalign: bool = False, + ) -> Dict[str, TensorBase]: for req in self.reqs: count = self.reader.wait_read(req) if count < 0: @@ -41,4 +63,6 @@ def wait_io(self, gbuf: fstcpp.gds_device_buffer, dtype: torch.dtype=None, noali if self.fd > 0: os.close(self.fd) self.fd = 0 - return self.metadata.get_tensors(gbuf, self.device, self.metadata.header_length, dtype=dtype) + return self.metadata.get_tensors( + gbuf, self.device, self.metadata.header_length, dtype=dtype + ) diff --git a/fastsafetensors/cpp.pyi b/fastsafetensors/cpp.pyi new file mode 100644 index 0000000..6222b6c --- /dev/null +++ b/fastsafetensors/cpp.pyi @@ -0,0 +1,52 @@ +# Copyright 2025 IBM Inc. All rights reserved +# SPDX-License-Identifier: Apache-2.0 + +class gds_device_buffer: + def __init__(self, devPtr_base: int, length: int, use_cuda: bool) -> None: ... + def cufile_register(self, offset: int, length: int) -> int: ... + def cufile_deregister(self, offset: int) -> int: ... + def memmove( + self, dst_off: int, src_off: int, tmp: "gds_device_buffer", length: int + ) -> int: ... + def get_base_address(self) -> int: ... + +class nogds_file_reader: + def __init__( + self, use_mmap: bool, bbuf_size_kb: int, max_threads: int, use_cuda: bool + ) -> None: ... + def submit_read( + self, fd: int, dst: gds_device_buffer, offset: int, length: int, ptr_off: int + ) -> int: ... + def wait_read(self, thread_id: int) -> int: ... + +class gds_file_handle: + def __init__(self, filename: str, o_direct: bool, use_cuda: bool) -> None: ... + +class gds_file_reader: + def __init__(self, max_threads: int, use_cuda: bool) -> None: ... + def submit_read( + self, + fh: gds_file_handle, + dst: gds_device_buffer, + offset: int, + length: int, + ptr_off: int, + file_length: int, + ) -> int: ... + def wait_read(self, id: int) -> int: ... + +def is_cuda_found() -> bool: ... +def is_cufile_found() -> bool: ... +def cufile_version() -> int: ... +def get_alignment_size() -> int: ... +def set_debug_log(debug_log: bool) -> None: ... +def init_gds() -> int: ... +def close_gds() -> int: ... +def get_device_pci_bus(deviceId: int) -> str: ... +def set_numa_node(numa_node: int) -> int: ... +def read_buffer(dst: int, length: int) -> bytes: ... +def cpu_malloc(length: int) -> int: ... +def cpu_free(addr: int) -> None: ... +def gpu_malloc(length: int) -> int: ... +def gpu_free(addr: int) -> None: ... +def load_nvidia_functions() -> None: ... diff --git a/fastsafetensors/cpp/ext.cpp b/fastsafetensors/cpp/ext.cpp index c88e5f2..bf47410 100644 --- a/fastsafetensors/cpp/ext.cpp +++ b/fastsafetensors/cpp/ext.cpp @@ -18,18 +18,18 @@ static bool debug_log = false; /* cpu_mode functions: for tests and debugs */ -static CUfileError_t cpu_cuFileDriverOpen() { return CUfileError_t{err: CU_FILE_SUCCESS}; } -static CUfileError_t cpu_cuFileDriverClose() { return CUfileError_t{err: CU_FILE_SUCCESS}; } -static CUfileError_t cpu_cuFileDriverSetMaxDirectIOSize(size_t) { return CUfileError_t{err: CU_FILE_SUCCESS}; } -static CUfileError_t cpu_cuFileDriverSetMaxPinnedMemSize(size_t) { return CUfileError_t{err: CU_FILE_SUCCESS}; } -static CUfileError_t cpu_cuFileBufRegister(const void *, size_t, int) { return CUfileError_t{err: CU_FILE_SUCCESS}; } -static CUfileError_t cpu_cuFileBufDeregister(const void *) { return CUfileError_t{err: CU_FILE_SUCCESS}; } +static CUfileError_t cpu_cuFileDriverOpen() { return CUfileError_t{.err = CU_FILE_SUCCESS}; } +static CUfileError_t cpu_cuFileDriverClose() { return CUfileError_t{.err = CU_FILE_SUCCESS}; } +static CUfileError_t cpu_cuFileDriverSetMaxDirectIOSize(size_t) { return CUfileError_t{.err = CU_FILE_SUCCESS}; } +static CUfileError_t cpu_cuFileDriverSetMaxPinnedMemSize(size_t) { return CUfileError_t{.err = CU_FILE_SUCCESS}; } +static CUfileError_t cpu_cuFileBufRegister(const void *, size_t, int) { return CUfileError_t{.err = CU_FILE_SUCCESS}; } +static CUfileError_t cpu_cuFileBufDeregister(const void *) { return CUfileError_t{.err = CU_FILE_SUCCESS}; } static CUfileError_t cpu_cuFileHandleRegister(CUfileHandle_t * in, CUfileDescr_t *) { *in = reinterpret_cast(malloc(sizeof(CUfileHandle_t))); if (*in != nullptr) { - return CUfileError_t{err: CU_FILE_SUCCESS}; + return CUfileError_t{.err = CU_FILE_SUCCESS}; } - return CUfileError_t{err: CU_FILE_INTERNAL_ERROR}; + return CUfileError_t{.err = CU_FILE_INTERNAL_ERROR}; } static void cpu_cuFileHandleDeregister(CUfileHandle_t h) { free(reinterpret_cast(h)); @@ -57,21 +57,21 @@ static cudaError_t cpu_cudaDeviceGetPCIBusId(char * in, int s, int) { static int cpu_numa_run_on_node(int) {return 0; } ext_funcs_t cpu_fns = ext_funcs_t { - cuFileDriverOpen: cpu_cuFileDriverOpen, - cuFileDriverClose: cpu_cuFileDriverClose, - cuFileDriverSetMaxDirectIOSize: cpu_cuFileDriverSetMaxDirectIOSize, - cuFileDriverSetMaxPinnedMemSize: cpu_cuFileDriverSetMaxPinnedMemSize, - cuFileBufRegister: cpu_cuFileBufRegister, - cuFileBufDeregister: cpu_cuFileBufDeregister, - cuFileHandleRegister: cpu_cuFileHandleRegister, - cuFileHandleDeregister: cpu_cuFileHandleDeregister, - cuFileRead: nullptr, - cudaMemcpy: cpu_cudaMemcpy, - cudaDeviceSynchronize: cpu_cudaDeviceSynchronize, - cudaHostAlloc: cpu_cudaHostAlloc, - cudaFreeHost: cpu_cudaFreeHost, - cudaDeviceGetPCIBusId: cpu_cudaDeviceGetPCIBusId, - numa_run_on_node: cpu_numa_run_on_node, + .cuFileDriverOpen = cpu_cuFileDriverOpen, + .cuFileDriverClose = cpu_cuFileDriverClose, + .cuFileDriverSetMaxDirectIOSize = cpu_cuFileDriverSetMaxDirectIOSize, + .cuFileDriverSetMaxPinnedMemSize = cpu_cuFileDriverSetMaxPinnedMemSize, + .cuFileBufRegister = cpu_cuFileBufRegister, + .cuFileBufDeregister = cpu_cuFileBufDeregister, + .cuFileHandleRegister = cpu_cuFileHandleRegister, + .cuFileHandleDeregister = cpu_cuFileHandleDeregister, + .cuFileRead = nullptr, + .cudaMemcpy = cpu_cudaMemcpy, + .cudaDeviceSynchronize = cpu_cudaDeviceSynchronize, + .cudaHostAlloc = cpu_cudaHostAlloc, + .cudaFreeHost = cpu_cudaFreeHost, + .cudaDeviceGetPCIBusId = cpu_cudaDeviceGetPCIBusId, + .numa_run_on_node = cpu_numa_run_on_node, }; ext_funcs_t cuda_fns; @@ -241,7 +241,7 @@ void set_debug_log(bool _debug_log) debug_log = _debug_log; } -int init_gds(uint64_t _max_direct_io_size_in_kb, uint64_t max_pinned_memory_size_in_kb) +int init_gds() { CUfileError_t err; @@ -253,32 +253,9 @@ int init_gds(uint64_t _max_direct_io_size_in_kb, uint64_t max_pinned_memory_size return -1; } } - - std::chrono::steady_clock::time_point begin_set_dio = std::chrono::steady_clock::now(); - if (cuda_fns.cuFileDriverSetMaxDirectIOSize) { - err = cuda_fns.cuFileDriverSetMaxDirectIOSize(_max_direct_io_size_in_kb); - if (err.err != CU_FILE_SUCCESS) { - std::fprintf(stderr, "init_gds: cuFileDriverGetProperties(%ld) returned an error = %d\n", _max_direct_io_size_in_kb, err.err); - close_gds(); - return -1; - } - } - - std::chrono::steady_clock::time_point begin_set_pin = std::chrono::steady_clock::now(); - if (cuda_fns.cuFileDriverSetMaxPinnedMemSize) { - err = cuda_fns.cuFileDriverSetMaxPinnedMemSize(max_pinned_memory_size_in_kb); - if (err.err != CU_FILE_SUCCESS) { - std::fprintf(stderr, "init_gds: cuFileDriverSetMaxPinnedMemSize(%ld) returned an error = %d\n", max_pinned_memory_size_in_kb, err.err); - close_gds(); - return -1; - } - } if (debug_log) { std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); - std::printf("[DEBUG] init_gds: cuFileDriverOpen=%ld us, cuFileDriverSetMaxDirectIOSize=%ld us, cuFileDriverSetMaxPinnedMemSize=%ld us, elapsed=%ld us\n", - std::chrono::duration_cast(begin_set_dio - begin).count(), - std::chrono::duration_cast(begin_set_pin - begin_set_dio).count(), - std::chrono::duration_cast(end - begin_set_pin).count(), + std::printf("[DEBUG] init_gds: cuFileDriverOpen=%lld us\n", std::chrono::duration_cast(end - begin).count()); } return 0; @@ -298,7 +275,7 @@ int close_gds() } if (debug_log) { std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); - std::printf("[DEBUG] close_gds: cuFileDriverClose, elapsed=%ld us\n", + std::printf("[DEBUG] close_gds: cuFileDriverClose, elapsed=%lld us\n", std::chrono::duration_cast(end - begin).count()); } return 0; @@ -375,7 +352,7 @@ const int gds_device_buffer::cufile_register(uint64_t offset, uint64_t length) { } if (debug_log) { std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); - std::printf("[DEBUG] gds_device_buffer.cufile_register: addr=%p, offset=%lu, length=%lu, register=%ld us\n", dst, offset, length, + std::printf("[DEBUG] gds_device_buffer.cufile_register: addr=%p, offset=%" PRIu64 ", length=%" PRIu64 ", register=%lld us\n", dst, offset, length, std::chrono::duration_cast(end - begin_register).count()); } return 0; @@ -392,7 +369,7 @@ const int gds_device_buffer::cufile_deregister(uint64_t offset) { } if (debug_log) { std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); - std::printf("[DEBUG] gds_device_buffer.cufile_deregister: addr=%p, offset=%ld, elapsed=%ld us\n", dst, offset, + std::printf("[DEBUG] gds_device_buffer.cufile_deregister: addr=%p, offset=%" PRIu64 ", elapsed=%lld us\n", dst, offset, std::chrono::duration_cast(end - begin).count()); } return 0; @@ -405,15 +382,15 @@ const int gds_device_buffer::memmove(uint64_t _dst_off, uint64_t _src_off, const void *tmp = const_cast(_tmp._devPtr_base->get_raw()); if (this->_length < _dst_off) { - std::fprintf(stderr, "gds_device_buffer.memmove: length is smaller than request dst_off, tmp.length=%ld, _dst_off=%ld\n", _tmp._length, _dst_off); + std::fprintf(stderr, "gds_device_buffer.memmove: length is smaller than request dst_off, tmp.length=%" PRIu64 ", _dst_off=%" PRIu64 "\n", _tmp._length, _dst_off); return -1; } if (this->_length < _src_off) { - std::fprintf(stderr, "gds_device_buffer.memmove: length is smaller than request dst_off, tmp.length=%ld, _src_off=%ld\n", _tmp._length, _src_off); + std::fprintf(stderr, "gds_device_buffer.memmove: length is smaller than request dst_off, tmp.length=%" PRIu64 ", _src_off=%" PRIu64 "\n", _tmp._length, _src_off); return -1; } if (_tmp._length < length) { - std::fprintf(stderr, "gds_device_buffer.memmove: tmp is smaller than request length, tmp.length=%ld, length=%ld\n", _tmp._length, length); + std::fprintf(stderr, "gds_device_buffer.memmove: tmp is smaller than request length, tmp.length=%" PRIu64 ", length=%" PRIu64 "\n", _tmp._length, length); return -1; } if (length == 0) { @@ -423,17 +400,17 @@ const int gds_device_buffer::memmove(uint64_t _dst_off, uint64_t _src_off, const std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); err = _fns->cudaMemcpy(tmp, src, length, cudaMemcpyDefault); if (err != cudaSuccess) { - std::printf("gds_device_buffer.memmove: cudaMemcpy[0](tmp=%p, src=%p, length=%ld) failed, err=%d\n", tmp, src, length, err); + std::printf("gds_device_buffer.memmove: cudaMemcpy[0](tmp=%p, src=%p, length=%" PRIu64 ") failed, err=%d\n", tmp, src, length, err); return -1; } err = _fns->cudaMemcpy(dst, tmp, length, cudaMemcpyDefault); if (err != cudaSuccess) { - std::printf("gds_device_buffer.memmove: cudaMemcpy[1](dst=%p, tmp=%p, length=%ld) failed, err=%d\n", dst, tmp, length, err); + std::printf("gds_device_buffer.memmove: cudaMemcpy[1](dst=%p, tmp=%p, length=%" PRIu64 ") failed, err=%d\n", dst, tmp, length, err); return -1; } if (debug_log) { std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); - std::printf("[DEBUG] gds_device_buffer.memmove: dst=%p, src=%p, tmp=%p, length=%ld, elapsed=%ld us\n", dst, src, tmp, length, + std::printf("[DEBUG] gds_device_buffer.memmove: dst=%p, src=%p, tmp=%p, length=%" PRIu64 ", elapsed=%lld us\n", dst, src, tmp, length, std::chrono::duration_cast(end - begin).count()); } return 0; @@ -451,13 +428,13 @@ void nogds_file_reader::_thread(const int thread_id, ext_funcs_t *fns, const int std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); src = mmap(NULL, length, PROT_READ, MAP_PRIVATE, fd, offset); if (src == MAP_FAILED) { - std::printf("nogds_file_reader._thread: mmap(fd=%d, offset=%ld, length=%ld) failed\n", fd, offset, length); + std::printf("nogds_file_reader._thread: mmap(fd=%d, offset=%" PRIu64 ", length=%" PRIu64 ") failed\n", fd, offset, length); failed = true; goto out; } if (debug_log) { std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); - std::printf("[DEBUG] nogds_file_reader._thread: mmap, fd=%d, offset=%ld, length=%ld, elapsed=%ld us\n", + std::printf("[DEBUG] nogds_file_reader._thread: mmap, fd=%d, offset=%" PRIu64 ", length=%" PRIu64 ", elapsed=%lld us\n", fd, offset, length, std::chrono::duration_cast(end - begin).count()); } } @@ -475,7 +452,7 @@ void nogds_file_reader::_thread(const int thread_id, ext_funcs_t *fns, const int } else { c = pread(fd, buffer, l, offset + count); if (c != l) { - std::printf("nogds_file_reader._thread failed: pread(fd=%d, buffer=%p, offset=%ld, count=%ld, l=%ld), c=%ld\n", fd, buffer, offset, count, l, c); + std::printf("nogds_file_reader._thread failed: pread(fd=%d, buffer=%p, offset=%" PRIu64 ", count=%" PRIi64 ", l=%" PRIi64 "), c=%" PRIi64 "\n", fd, buffer, offset, count, l, c); failed = true; goto out; } @@ -483,7 +460,7 @@ void nogds_file_reader::_thread(const int thread_id, ext_funcs_t *fns, const int std::chrono::steady_clock::time_point memcpy_begin = std::chrono::steady_clock::now(); err = fns->cudaMemcpy(dst._get_raw_pointer(ptr_off + count, c), buffer, c, cudaMemcpyHostToDevice); if (err != cudaSuccess) { - std::printf("nogds_file_reader._thread: cudaMemcpy(%p, %p, %ld) failed, err=%d\n", dst._get_raw_pointer(ptr_off + count, c), buffer, count, err); + std::printf("nogds_file_reader._thread: cudaMemcpy(%p, %p, %" PRIi64 ") failed, err=%d\n", dst._get_raw_pointer(ptr_off + count, c), buffer, count, err); failed = true; goto out; } else if (c <= 64 * 1024) { @@ -492,7 +469,7 @@ void nogds_file_reader::_thread(const int thread_id, ext_funcs_t *fns, const int count += c; if (debug_log) { std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); - std::printf("[DEBUG] nogds_file_reader._thread: read (mmap=%d), fd=%d, offset=%ld, count=%ld, c=%ld, copy=%ld us, cuda_copy=%ld us\n", + std::printf("[DEBUG] nogds_file_reader._thread: read (mmap=%d), fd=%d, offset=%" PRIu64 ", count=%" PRIi64 ", c=%" PRIi64 ", copy=%lld us, cuda_copy=%lld us\n", s->_use_mmap, fd, offset, count, c, std::chrono::duration_cast(memcpy_begin - begin).count(), std::chrono::duration_cast(end - memcpy_begin).count()); } } @@ -525,12 +502,12 @@ const int nogds_file_reader::submit_read(const int fd, const gds_device_buffer& std::chrono::steady_clock::time_point alloc_begin = std::chrono::steady_clock::now(); err = _fns->cudaHostAlloc(&this->_s._read_buffer, this->_s._bbuf_size_kb * 1024 * this->_s._max_threads, 0); if (err != cudaSuccess) { - std::printf("nogds_file_reader.submit_read: cudaHostAlloc(%lu) failed\n", this->_s._bbuf_size_kb * 1024 * this->_s._max_threads); + std::printf("nogds_file_reader.submit_read: cudaHostAlloc(%" PRIi64 ") failed\n", this->_s._bbuf_size_kb * 1024 * this->_s._max_threads); return -1; } if (debug_log) { std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); - std::printf("[DEBUG] nogds_file_reader.submit_read: cudaHostAlloc, size=%ld, elapsed=%ld us\n", + std::printf("[DEBUG] nogds_file_reader.submit_read: cudaHostAlloc, size=%" PRIi64 ", elapsed=%lld us\n", this->_s._bbuf_size_kb * 1024, std::chrono::duration_cast(end - alloc_begin).count()); } } @@ -579,7 +556,7 @@ nogds_file_reader::~nogds_file_reader() { } if (debug_log) { std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); - std::printf("[DEBUG] ~nogds_file_reader: elapsed=%ld us\n", + std::printf("[DEBUG] ~nogds_file_reader: elapsed=%lld us\n", std::chrono::duration_cast(end - begin).count()); } } @@ -592,9 +569,11 @@ raw_gds_file_handle::raw_gds_file_handle(std::string filename, bool o_direct, bo int flags = O_RDONLY; std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); +#if defined(O_DIRECT) if (o_direct) { flags |= O_DIRECT; } +#endif fd = open(filename.c_str(), flags, 0644); if (fd < 0) { char msg[256]; @@ -616,7 +595,7 @@ raw_gds_file_handle::raw_gds_file_handle(std::string filename, bool o_direct, bo } if (debug_log) { std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); - std::printf("[DEBUG] raw_gds_file_handle: fd=%d, cf_handle=%p, elapsed=%ld us\n", fd, cf_handle, + std::printf("[DEBUG] raw_gds_file_handle: fd=%d, cf_handle=%p, elapsed=%lld us\n", fd, cf_handle, std::chrono::duration_cast(end - begin).count()); } this->_cf_handle = cf_handle; @@ -653,7 +632,7 @@ void gds_file_reader::_thread(const int thread_id, ext_funcs_t *fns, const gds_f c = fns->cuFileRead(fh._get_cf_handle(), devPtr_base, length - count, offset + count, count); } if (debug_log) { - std::printf("[DEBUG] gds_file_reader._thread: cuFileRead(fh, %p, length=%lu, off=%lu, ptr_off=%lu, count=%ld)=%ld\n", devPtr_base, length, offset, ptr_off, count, c); + std::printf("[DEBUG] gds_file_reader._thread: cuFileRead(fh, %p, length=%" PRIu64 ", off=%" PRIu64 ", ptr_off=%" PRIu64 ", count=%zd)=%zd\n", devPtr_base, length, offset, ptr_off, count, c); } if (c < 0) { std::fprintf(stderr, "gds_file_reader._thread: cuFileRead returned an error: errno=%d\n", errno); @@ -671,7 +650,7 @@ void gds_file_reader::_thread(const int thread_id, ext_funcs_t *fns, const gds_f } if (debug_log) { std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); - std::printf("[DEBUG] gds_file_reader._thread: fh=%p, offset=%lu, length=%lu, count=%ld, read=%ld us, notify=%ld us\n", + std::printf("[DEBUG] gds_file_reader._thread: fh=%p, offset=%" PRIu64 ", length=%" PRIu64 ", count=%zd, read=%lld us, notify=%lld us\n", fh._get_cf_handle(), offset, length, count, std::chrono::duration_cast(begin_notify - begin).count(), std::chrono::duration_cast(end - begin_notify).count()); diff --git a/fastsafetensors/cpp/ext.hpp b/fastsafetensors/cpp/ext.hpp index acd82d7..465d4f7 100644 --- a/fastsafetensors/cpp/ext.hpp +++ b/fastsafetensors/cpp/ext.hpp @@ -61,7 +61,7 @@ typedef struct CUfileDrvProps { int get_alignment_size(); void set_debug_log(bool _debug_log); -int init_gds(uint64_t _max_direct_io_size_in_kb, uint64_t max_pinned_memory_size_in_kb); +int init_gds(); int close_gds(); std::string get_device_pci_bus(int deviceId); int set_numa_node(int numa_node); @@ -96,7 +96,7 @@ class gds_device_buffer { void * _get_raw_pointer(uint64_t offset, uint64_t length) const { // not exposed to python if (this->_length < offset + length) { char msg[256]; - snprintf(msg, 256, "out of bound access: 0x%p, length=%ld, request offset=%ld, request length=%ld", this->_devPtr_base->get_raw(), this->_length, offset, length); + snprintf(msg, 256, "out of bound access: 0x%p, length=%" PRIu64 ", request offset=%" PRIu64 ", request length=%" PRIu64, this->_devPtr_base->get_raw(), this->_length, offset, length); throw std::out_of_range(msg); } return reinterpret_cast(this->_devPtr_base->get_uintptr() + offset); @@ -127,8 +127,8 @@ class nogds_file_reader { public: nogds_file_reader(const bool use_mmap, const uint64_t bbuf_size_kb, const uint64_t max_threads, bool use_cuda): _next_thread_id(1), _threads(nullptr), _fns(use_cuda?&cuda_fns:&cpu_fns), - _s(thread_states_t{_read_buffer: nullptr, _use_mmap: use_mmap, - _bbuf_size_kb: (bbuf_size_kb + max_threads - 1)/max_threads, _max_threads: max_threads}) + _s(thread_states_t{._read_buffer = nullptr, ._use_mmap = use_mmap, + ._bbuf_size_kb = (bbuf_size_kb + max_threads - 1)/max_threads, ._max_threads = max_threads}) {} static void _thread(const int thread_id, ext_funcs_t *fns, const int fd, const gds_device_buffer& dst, const int64_t offset, const int64_t length, const uint64_t ptr_off, thread_states_t *s); // not exposed to python @@ -172,7 +172,7 @@ class gds_file_reader { thread_states_t _s; ext_funcs_t *_fns; public: - gds_file_reader(const int max_threads, bool use_cuda): _next_id(1), _threads(nullptr), _s(thread_states_t{_max_threads: max_threads}), _fns(use_cuda?&cuda_fns:&cpu_fns) {} + gds_file_reader(const int max_threads, bool use_cuda): _next_id(1), _threads(nullptr), _s(thread_states_t{._max_threads = max_threads}), _fns(use_cuda?&cuda_fns:&cpu_fns) {} static void _thread(const int thread_id, ext_funcs_t *fns, const gds_file_handle &fh, const gds_device_buffer &dst, const uint64_t offset, const uint64_t length, const uint64_t ptr_off, const uint64_t file_length, thread_states_t *s); const int submit_read(const gds_file_handle &fh, const gds_device_buffer &dst, const uint64_t offset, const uint64_t length, const uint64_t ptr_off, const uint64_t file_length); const ssize_t wait_read(const int id); diff --git a/fastsafetensors/dlpack.py b/fastsafetensors/dlpack.py index f9dc820..007487a 100644 --- a/fastsafetensors/dlpack.py +++ b/fastsafetensors/dlpack.py @@ -5,30 +5,17 @@ # to add from_cuda_buffer() import ctypes -import torch -from .common import paddle_loaded -from typing import List -if paddle_loaded: - import paddle +from typing import Dict, List, Union + +from .st_types import Device, DeviceType, DType _c_str_dltensor = b"dltensor" + class DLDevice(ctypes.Structure): - def __init__(self, device: torch.device): - if isinstance(device, str): - self.device_id = 0 - if device == "cpu": - self.device_type = self.TYPE_MAP[device] - else: - device = device.split(":") - if len(device) == 2: - self.device_id = int(device[1]) - self.device_type = self.TYPE_MAP[device[0]] - else: - self.device_type = self.TYPE_MAP[device.type] - self.device_id = 0 - if device.index: - self.device_id = device.index + def __init__(self, dev: Device): + self.device_type = self.DeviceToDL[dev.type] + self.device_id = dev.index if dev.index is not None else 0 kDLCPU = 1 kDLCUDA = 2 @@ -36,83 +23,79 @@ def __init__(self, device: torch.device): ("device_type", ctypes.c_int), ("device_id", ctypes.c_int), ] - TYPE_MAP= { - "cpu": kDLCPU, - "cuda": kDLCUDA, - "gpu": kDLCUDA + + DeviceToDL = { + DeviceType.CPU: kDLCPU, + DeviceType.CUDA: kDLCUDA, + DeviceType.GPU: kDLCUDA, } -class DLDataTypeCode(ctypes.c_uint8): +class c_DLDataType(ctypes.Structure): + def __init__(self, dtype: DType): + (self.type_code, self.bits, self.lanes) = self.STDataToDL[dtype] + kDLInt = 0 kDLUInt = 1 kDLFloat = 2 kDLBfloat = 4 - - def __str__(self): - return { - self.kDLInt: "int", - self.kDLUInt: "uint", - self.kDLFloat: "float", - self.kDLBfloat: "bfloat", - }[self.value] - - -class DLDataType(ctypes.Structure): + kDLBool = 6 _fields_ = [ - ("type_code", DLDataTypeCode), + ("type_code", ctypes.c_uint8), ("bits", ctypes.c_uint8), ("lanes", ctypes.c_uint16), ] - TYPE_MAP = { - torch.bool: (6, 8, 1), - torch.int8: (0, 8, 1), - torch.int16: (0, 16, 1), - torch.int32: (0, 32, 1), - torch.int: (0, 32, 1), - torch.int64: (0, 64, 1), - torch.uint8: (1, 8, 1), - torch.float16: (2, 16, 1), - torch.float32: (2, 32, 1), - torch.float64: (2, 64, 1), - torch.bfloat16: (4, 16, 1), + + STDataToDL: Dict[DType, tuple[int, int, int]] = { + DType.BOOL: (kDLBool, 8, 1), + DType.I8: (kDLInt, 8, 1), + DType.I16: (kDLInt, 16, 1), + DType.I32: (kDLInt, 32, 1), + DType.I64: (kDLInt, 64, 1), + DType.U8: (kDLUInt, 8, 1), + DType.U16: (kDLUInt, 16, 1), + DType.U32: (kDLUInt, 32, 1), + DType.U64: (kDLUInt, 64, 1), + DType.F16: (kDLFloat, 16, 1), + DType.F32: (kDLFloat, 32, 1), + DType.F64: (kDLFloat, 64, 1), + DType.BF16: (kDLBfloat, 16, 1), } - if paddle_loaded: - TYPE_MAP = { - torch.bool: (6, 8, 1), - torch.int8: (0, 8, 1), - torch.int16: (0, 16, 1), - torch.int32: (0, 32, 1), - torch.int: (0, 32, 1), - torch.int64: (0, 64, 1), - torch.uint8: (1, 8, 1), - torch.float16: (2, 16, 1), - torch.float32: (2, 32, 1), - torch.float64: (2, 64, 1), - torch.bfloat16: (4, 16, 1), - paddle.bool: (6, 8, 1), - paddle.int8: (0, 8, 1), - paddle.int16: (0, 16, 1), - paddle.int32: (0, 32, 1), - paddle.int64: (0, 64, 1), - paddle.uint8: (1, 8, 1), - paddle.float16: (2, 16, 1), - paddle.float32: (2, 32, 1), - paddle.float64: (2, 64, 1), - paddle.bfloat16: (4, 16, 1), - } + + +class _Holder: + def __init__(self, shape: List[int], strides: List[int]): + self.shape = (ctypes.c_int64 * len(shape))(*shape) + self.strides = (ctypes.c_int64 * len(strides))(*strides) + + def _as_manager_ctx(self) -> ctypes.c_void_p: + py_obj = ctypes.py_object(self) + py_obj_ptr = ctypes.pointer(py_obj) + ctypes.pythonapi.Py_IncRef(py_obj) + ctypes.pythonapi.Py_IncRef(ctypes.py_object(py_obj_ptr)) + return ctypes.cast(py_obj_ptr, ctypes.c_void_p) + class DLTensor(ctypes.Structure): _fields_ = [ ("data", ctypes.c_void_p), ("device", DLDevice), ("ndim", ctypes.c_int), - ("dtype", DLDataType), + ("dtype", c_DLDataType), ("shape", ctypes.POINTER(ctypes.c_int64)), ("strides", ctypes.POINTER(ctypes.c_int64)), ("byte_offset", ctypes.c_uint64), ] + def __init__(self, dev_ptr: int, dev: Device, dtype: DType, holder: _Holder): + self.data = dev_ptr + self.device = DLDevice(dev) + self.ndim = len(holder.shape) + self.dtype = c_DLDataType(dtype) + self.shape = holder.shape + self.strides = holder.strides + self.byte_offset = 0 + @property def itemsize(self): return self.dtype.lanes * self.dtype.bits // 8 @@ -151,6 +134,24 @@ class DLManagedTensor(ctypes.Structure): ("deleter", ctypes.CFUNCTYPE(None, ctypes.c_void_p)), ] + def as_py( + self, + dev_ptr: int, + shape: List[int], + strides: List[int], + dtype: DType, + dev: Device, + ): + holder = _Holder(shape, strides) + self.dl_tensor = DLTensor(dev_ptr, dev, dtype, holder) + self.manager_ctx = holder._as_manager_ctx() + self.deleter = _numpy_buffer_deleter + return ctypes.pythonapi.PyCapsule_New( + ctypes.byref(self), + _c_str_dltensor, + _numpy_pycapsule_deleter, + ) + @property def __array_interface__(self): return self.dl_tensor.__array_interface__ @@ -159,26 +160,25 @@ def __array_interface__(self): ctypes.pythonapi.PyMem_RawMalloc.restype = ctypes.c_void_p ctypes.pythonapi.PyMem_RawFree.argtypes = [ctypes.c_void_p] -ctypes.pythonapi.PyCapsule_New.restype=ctypes.py_object -ctypes.pythonapi.PyCapsule_New.argtypes=[ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p] +ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object +ctypes.pythonapi.PyCapsule_New.argtypes = [ + ctypes.c_void_p, + ctypes.c_char_p, + ctypes.c_void_p, +] -class _Holder: - def __init__(self, shape: List[int], strides: List[int]): - self.shape = (ctypes.c_int64*len(shape))(*shape) - self.strides = (ctypes.c_int64*len(strides))(*strides) - - def _as_manager_ctx(self) -> ctypes.c_void_p: - py_obj = ctypes.py_object(self) - py_obj_ptr = ctypes.pointer(py_obj) - ctypes.pythonapi.Py_IncRef(py_obj) - ctypes.pythonapi.Py_IncRef(ctypes.py_object(py_obj_ptr)) - return ctypes.cast(py_obj_ptr, ctypes.c_void_p) - @ctypes.CFUNCTYPE(None, ctypes.c_void_p) -def _numpy_cuda_buffer_deleter(handle: ctypes.c_void_p) -> None: +def _numpy_buffer_deleter(handle: Union[int, ctypes.c_void_p]) -> None: """A function to deallocate the memory of a cuda buffer.""" - dl_managed_tensor = DLManagedTensor.from_address(handle) + if isinstance(handle, int): + dl_managed_tensor = DLManagedTensor.from_address(handle) + elif isinstance(handle, ctypes.c_void_p): + dl_managed_tensor = DLManagedTensor.from_address( + handle.value if handle.value else 0 + ) + else: + raise Exception("invalid type of handle!") py_obj_ptr = ctypes.cast( dl_managed_tensor.manager_ctx, ctypes.POINTER(ctypes.py_object) ) @@ -187,6 +187,7 @@ def _numpy_cuda_buffer_deleter(handle: ctypes.c_void_p) -> None: ctypes.pythonapi.Py_DecRef(ctypes.py_object(py_obj_ptr)) ctypes.pythonapi.PyMem_RawFree(handle) + @ctypes.CFUNCTYPE(None, ctypes.c_void_p) def _numpy_pycapsule_deleter(handle: ctypes.c_void_p) -> None: """A function to deallocate a pycapsule that wraps a cuda buffer.""" @@ -195,27 +196,15 @@ def _numpy_pycapsule_deleter(handle: ctypes.c_void_p) -> None: dl_managed_tensor = ctypes.pythonapi.PyCapsule_GetPointer( pycapsule, _c_str_dltensor ) - _numpy_cuda_buffer_deleter(dl_managed_tensor) + _numpy_buffer_deleter(dl_managed_tensor) ctypes.pythonapi.PyCapsule_SetDestructor(pycapsule, None) -def from_cuda_buffer(dev_ptr: int, shape: List[int], strides: List[int], dtype: torch.dtype, device: torch.device): - holder = _Holder(shape, strides) + +def from_cuda_buffer( + dev_ptr: int, shape: List[int], strides: List[int], dtype: DType, dev: Device +): size = ctypes.c_size_t(ctypes.sizeof(DLManagedTensor)) dl_managed_tensor = DLManagedTensor.from_address( ctypes.pythonapi.PyMem_RawMalloc(size) ) - dl_managed_tensor.dl_tensor.data = dev_ptr - dl_managed_tensor.dl_tensor.device = DLDevice(device) - dl_managed_tensor.dl_tensor.ndim = len(holder.shape) - dl_managed_tensor.dl_tensor.dtype = DLDataType.TYPE_MAP[dtype] - dl_managed_tensor.dl_tensor.shape = holder.shape - dl_managed_tensor.dl_tensor.strides = holder.strides - dl_managed_tensor.dl_tensor.byte_offset = 0 - dl_managed_tensor.manager_ctx = holder._as_manager_ctx() - dl_managed_tensor.deleter = _numpy_cuda_buffer_deleter - pycapsule = ctypes.pythonapi.PyCapsule_New( - ctypes.byref(dl_managed_tensor), - _c_str_dltensor, - _numpy_pycapsule_deleter, - ) - return pycapsule + return dl_managed_tensor.as_py(dev_ptr, shape, strides, dtype, dev) diff --git a/fastsafetensors/file_buffer.py b/fastsafetensors/file_buffer.py index 9244585..2c49b75 100644 --- a/fastsafetensors/file_buffer.py +++ b/fastsafetensors/file_buffer.py @@ -1,19 +1,16 @@ # Copyright 2024 IBM Inc. All rights reserved # SPDX-License-Identifier: Apache-2.0 -import os -import torch -import torch.distributed as dist -from typing import Dict, List, Tuple, Generator from collections import OrderedDict +from typing import Any, Dict, List, Optional, Tuple +from .frameworks import FrameworkOpBase, ProcessGroupBase, TensorBase +from .st_types import Device, DType from .tensor_factory import LazyTensorFactory -from .common import SingleGroup, paddle_loaded -if paddle_loaded: - import paddle + class FilesBufferOnDevice: - r""" Device buffer for .safetensors files. + r"""Device buffer for .safetensors files. Users can call get_tensor(), get_sharded(), etc. to instantiate (sharded) tensors from the device buffer. Note that for multi-process loading, users must follow the single-program multiple-data (SPMD) paradigm, which is common for torch.distributed programs. In other words, users must ensure that every worker process calls the methods here in the same order. @@ -24,31 +21,35 @@ class FilesBufferOnDevice: Args: rank_loaders (Dict): Tensor factories per rank, which hold device pointers for buffers. - pg (dist.ProcessGroup): process group for pytorch distributed. SingleGroup is available for single GPU use-cases. + pg (ProcessGroupBase): process group for calling distributed ops. auto_mem_delete (bool): automatically release device buffers when all the tensors are shuffled. Examples: See examples/run_single.py and examples/run_parallel.py. """ - def __init__(self, rank_loaders: Dict[int, List[LazyTensorFactory]], pg: dist.ProcessGroup, auto_mem_delete=True, framework="pytorch"): + + def __init__( + self, + rank_loaders: Dict[int, List[LazyTensorFactory]], + pg: ProcessGroupBase, + framework: FrameworkOpBase, + auto_mem_delete: bool = True, + ): + self.framework = framework self.rank_loaders: Dict[int, List[LazyTensorFactory]] = rank_loaders self.key_to_rank_lidx: Dict[str, Tuple[int, int]] = {} - self.instantiated: Dict[int, Dict[int, Dict[str, bool]]] = {} # rank, key name + self.instantiated: Dict[int, Dict[int, Dict[str, bool]]] = {} # rank, key name for rank, loaders in rank_loaders.items(): self.instantiated[rank] = {} for lidx, loader in enumerate(loaders): for key in loader.metadata.tensors.keys(): if key in self.key_to_rank_lidx: - raise Exception(f"FilesBufferOnDevice: key {key} must be unique among files") + raise Exception( + f"FilesBufferOnDevice: key {key} must be unique among files" + ) self.key_to_rank_lidx[key] = (rank, lidx) self.instantiated[rank][lidx] = {} - self.framework = framework - if self.framework == "pytorch" or isinstance(pg, SingleGroup): - self.pg = pg - self.group = None - elif paddle_loaded and self.framework == "paddle": - self.pg = pg.process_group - self.group = pg + self.pg = pg self.auto_mem_delete = auto_mem_delete and self.pg.size() > 1 def close(self): @@ -57,55 +58,95 @@ def close(self): loader.free_dev_ptrs() self.rank_loaders = {} - def get_filename(self, tensor_name: str)->str: + def get_filename(self, tensor_name: str) -> str: if tensor_name not in self.key_to_rank_lidx: - return None + return "" (rank, lidx) = self.key_to_rank_lidx[tensor_name] return self.rank_loaders[rank][lidx].metadata.src - def get_shape(self, tensor_name: str)->torch.Size: + def get_shape(self, tensor_name: str) -> List[int]: (rank, lidx) = self._get_rank_lidx(tensor_name) return self.rank_loaders[rank][lidx].metadata.tensors[tensor_name].shape - def _get_rank_lidx(self, tensor_name: str)->Tuple[int, int]: + def _get_rank_lidx(self, tensor_name: str) -> Tuple[int, int]: if tensor_name not in self.key_to_rank_lidx: raise ValueError(f"_get_rank: key {tensor_name} was not found in files") return self.key_to_rank_lidx[tensor_name] - def _get_tensor(self, rank: int, lidx: int, tensor_name: str, ret: torch.Tensor, device: torch.device, dtype: torch.dtype)->torch.Tensor: + def _get_tensor( + self, + rank: int, + lidx: int, + tensor_name: str, + ret: TensorBase, + device: Optional[Device], + dtype: DType, + ) -> TensorBase: loader = self.rank_loaders[rank][lidx] if self.auto_mem_delete: self.instantiated[rank][lidx][tensor_name] = True if len(self.instantiated[rank][lidx]) == len(loader.metadata.tensors): if loader.debug_log and self.pg.rank() == rank: - print(f"_get_tensor: free_dev_ptrs, lidx={lidx}, src={loader.metadata.src}") + print( + f"_get_tensor: free_dev_ptrs, lidx={lidx}, src={loader.metadata.src}" + ) loader.free_dev_ptrs() - if ret is None: - return ret - if (device is not None and (device != ret.device)) or (dtype is not None and (dtype != ret.dtype)): - ret = ret.to(device=device, dtype=dtype) - return ret + return ret.to(device=device, dtype=dtype) + + def get_sharded_wrapped( + self, + tensor_name: str, + dim: int, + device: Optional[Device] = None, + dtype: DType = DType.AUTO, + ) -> TensorBase: + (rank, lidix) = self._get_rank_lidx(tensor_name) + t = self.rank_loaders[rank][lidix].shuffle(self.pg, tensor_name, dim) + return self._get_tensor(rank, lidix, tensor_name, t, device, dtype) - def get_sharded(self, tensor_name: str, dim: int, device: torch.device=None, dtype: torch.dtype=None)->torch.Tensor: + def get_sharded( + self, + tensor_name: str, + dim: int, + device: Optional[Device] = None, + dtype: DType = DType.AUTO, + ) -> Any: """ partition a tensor instance with the key tensor_name at the dimension dim and return it. In multi-process loading, this eventually calls torch.distributed.scatter. A special dim is -1, which broadcast a tensor to all the ranks (== get_tensor()). """ - (rank, lidix) = self._get_rank_lidx(tensor_name) - t = self.rank_loaders[rank][lidix].shuffle(self.pg, tensor_name, dim, group=self.group) - return self._get_tensor(rank, lidix, tensor_name, t, device, dtype) - - def get_tensor(self, tensor_name: str, device: torch.device=None, dtype: torch.dtype=None)->torch.Tensor: + return self.get_sharded_wrapped(tensor_name, dim, device, dtype).get_raw() + + def get_tensor_wrapped( + self, + tensor_name: str, + device: Optional[Device] = None, + dtype: DType = DType.AUTO, + ) -> TensorBase: + return self.get_sharded_wrapped(tensor_name, -1, device, dtype) + + def get_tensor( + self, + tensor_name: str, + device: Optional[Device] = None, + dtype: DType = DType.AUTO, + ) -> Any: """ get a tensor instance with the key tensor_name from a local or remote rank. In multi-process loading, this eventually calls torch.distributed.broadcast. So, every rank will allocate the same tensor at each device memroy. In single-process loading, this directly instantiates a tensor from the device buffer with zero copy. """ - return self.get_sharded(tensor_name, -1, device, dtype) - - def push_tensor(self, tensor_name: str, dst_rank: int, device: torch.device=None, dtype: torch.dtype=None) -> torch.Tensor: + return self.get_tensor_wrapped(tensor_name, device, dtype).get_raw() + + def push_tensor( + self, + tensor_name: str, + dst_rank: int, + device: Optional[Device] = None, + dtype: DType = DType.AUTO, + ) -> Optional[Any]: """ push a tensor instance with the key tensor_name from a rank to a destination rank dst_rank. In multi-process loading, this eventually calls torch.distributed.send if the rank has the tensor instance. @@ -113,15 +154,20 @@ def push_tensor(self, tensor_name: str, dst_rank: int, device: torch.device=Non Other ranks do nothing. """ (rank, lidix) = self._get_rank_lidx(tensor_name) - t = self.rank_loaders[rank][lidix].push(self.pg, tensor_name, dst_rank, rank, group=self.group) - return self._get_tensor(rank, lidix, tensor_name, t, device, dtype) - - def get_sharded_packed_qkv(self, tensor_name: str, device: torch.device=None, dtype: torch.dtype=None)->torch.Tensor: - (rank, lidix) = self._get_rank_lidx(tensor_name) - t = self.rank_loaders[rank][lidix].shuffle_packed_qkv(self.pg, tensor_name, group=self.group) - return self._get_tensor(rank, lidix, tensor_name, t, device, dtype) - - def get_multi_cols(self, tensor_names: List[str], dim: int, device: torch.device=None, dtype: torch.dtype=None)->torch.Tensor: + t = self.rank_loaders[rank][lidix].push(self.pg, tensor_name, dst_rank, rank) + if t: + return self._get_tensor( + rank, lidix, tensor_name, t, device, dtype + ).get_raw() + return None + + def get_multi_cols( + self, + tensor_names: List[str], + dim: int, + device: Optional[Device] = None, + dtype: DType = DType.AUTO, + ) -> TensorBase: rank_lidixs: Dict[Tuple[int, int], List[str]] = {} for tensor_name in tensor_names: ranklidx = self._get_rank_lidx(tensor_name) @@ -129,13 +175,17 @@ def get_multi_cols(self, tensor_names: List[str], dim: int, device: torch.device rank_lidixs[ranklidx].append(tensor_name) else: rank_lidixs[ranklidx] = [tensor_name] - ts: List[torch.Tensor] = [] - for (rank, lidix), tns in sorted(rank_lidixs.items(), key=lambda x:x[0]): - ts.append(self.rank_loaders[rank][lidix].shuffle_multi_cols(self.pg, tns, dim,group=self.group)) + ts: List[TensorBase] = [] + for (rank, lidix), tns in sorted(rank_lidixs.items(), key=lambda x: x[0]): + ts.append( + self.rank_loaders[rank][lidix].shuffle_multi_cols(self.pg, tns, dim) + ) if len(ts) == 1: # fastpath: tensors at the same layer are often in the same file - return self._get_tensor(rank, lidix, rank_lidixs[(rank, lidix)][0], ts[0], device, dtype) - ret = torch.cat(ts, dim=dim) + return self._get_tensor( + rank, lidix, rank_lidixs[(rank, lidix)][0], ts[0], device, dtype + ) + ret = self.framework.concat_tensors(ts, dim=dim) if self.auto_mem_delete: for tensor_name in tensor_names: (rank, lidx) = self._get_rank_lidx(tensor_name) @@ -143,24 +193,26 @@ def get_multi_cols(self, tensor_names: List[str], dim: int, device: torch.device self.instantiated[rank][lidx][tensor_name] = True if len(self.instantiated[rank][lidx]) == len(loader.metadata.tensors): if loader.debug_log and self.pg.rank() == rank: - print(f"get_multi_cols: free_dev_ptrs, rank={rank}, lidx={lidx}, src={loader.metadata.src}") + print( + f"get_multi_cols: free_dev_ptrs, rank={rank}, lidx={lidx}, src={loader.metadata.src}" + ) loader.free_dev_ptrs() - if (device is not None and (device != ret.device)) or (dtype is not None and (dtype != ret.dtype)): - ret = ret.to(device=device, dtype=dtype) - return ret + return ret.to(device=device, dtype=dtype) - def as_dict(self, tensor_shard_dim: OrderedDict[str, int])->Dict[str, torch.Tensor]: - tensors: Dict[str, torch.Tensor] = {} + def as_dict(self, tensor_shard_dim: OrderedDict[str, int]) -> Dict[str, TensorBase]: + tensors: Dict[str, TensorBase] = {} for tensor_name, dim in tensor_shard_dim.items(): (rank, lidx) = self._get_rank_lidx(tensor_name) loader = self.rank_loaders[rank][lidx] - tensors[tensor_name] = loader.shuffle(self.pg, tensor_name, dim, group=self.group) + tensors[tensor_name] = loader.shuffle(self.pg, tensor_name, dim) if self.auto_mem_delete: self.instantiated[rank][lidx][tensor_name] = True if len(self.instantiated[rank][lidx]) == len(loader.metadata.tensors): if loader.debug_log and self.pg.rank() == rank: - print(f"as_dict: free_dev_ptrs, rank={rank}, src={loader.metadata.src}") + print( + f"as_dict: free_dev_ptrs, rank={rank}, src={loader.metadata.src}" + ) loader.free_dev_ptrs() if self.auto_mem_delete: - self.loaders = {} + self.rank_loaders = {} return tensors diff --git a/fastsafetensors/frameworks/__init__.py b/fastsafetensors/frameworks/__init__.py new file mode 100644 index 0000000..de13790 --- /dev/null +++ b/fastsafetensors/frameworks/__init__.py @@ -0,0 +1,177 @@ +# Copyright 2025 IBM Inc. All rights reserved +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Generic, List, Optional, TypeVar + +from ..cpp import gds_device_buffer +from ..st_types import Device, DType + + +@dataclass +class TensorBase: + device: Device + dtype: DType + + @abstractmethod + def get_raw(self) -> Any: + pass + + @abstractmethod + def contiguous(self) -> "TensorBase": + pass + + @abstractmethod + def to( + self, + device: Optional[Device] = None, + dtype: DType = DType.AUTO, + ) -> "TensorBase": + pass + + @abstractmethod + def clone(self) -> "TensorBase": + pass + + @abstractmethod + def detach(self) -> "TensorBase": + pass + + @abstractmethod + def view(self, dtype: DType) -> "TensorBase": + pass + + @abstractmethod + def __getitem__(self, _val) -> "TensorBase": + pass + + +T = TypeVar("T", bound=TensorBase) + + +class ProcessGroupBase(ABC, Generic[T]): + @abstractmethod + def size(self) -> int: + pass + + @abstractmethod + def rank(self) -> int: + pass + + @abstractmethod + def broadcast(self, dst: T, rank: int) -> None: + pass + + @abstractmethod + def scatter( + self, + dst: T, + scatter_list: List[T], + src: int, + ) -> None: + pass + + @abstractmethod + def send( + self, + t: T, + dst_rank: int, + tag: int, + ) -> None: + pass + + @abstractmethod + def recv( + self, + t: T, + src_rank: int, + tag: int, + ) -> None: + pass + + +K = TypeVar("K", bound=ProcessGroupBase) + + +class FrameworkOpBase(ABC, Generic[T, K]): + @abstractmethod + def get_name(self) -> str: + pass + + @abstractmethod + def get_device(self, device: str, pg: K) -> Device: + pass + + @abstractmethod + def set_device(self, device: Device) -> None: + pass + + @abstractmethod + def alloc_tensor_memory(self, length: int, dev: Device) -> gds_device_buffer: + pass + + @abstractmethod + def free_tensor_memory(self, gbuf: gds_device_buffer, dev: Device) -> None: + pass + + @abstractmethod + def get_empty_tensor(self, shape: List[int], dtype: DType, device: Device) -> T: + pass + + @abstractmethod + def concat_tensors(self, tensors: List[T], dim: int) -> T: + pass + + @abstractmethod + def copy_tensor(self, dst: T, src: T) -> None: + pass + + @abstractmethod + def get_dtype_size(self, dtype: DType) -> int: + pass + + @abstractmethod + def from_dlpack(self, dl_tensor: Any, device: Device, dtype: DType) -> T: + pass + + @abstractmethod + def get_cuda_ver(self) -> str: + pass + + @abstractmethod + def get_device_ptr_align(self) -> int: + pass + + @abstractmethod + def as_workaround_dtype(self, dtype: DType) -> DType: + pass + + @abstractmethod + def get_process_group(self, pg: Optional[Any]) -> ProcessGroupBase: + pass + + @abstractmethod + def is_equal(self, wrapped: T, real: Any) -> bool: + pass + + @abstractmethod + def randn(self, s: tuple, device: Device, dtype: DType) -> T: + pass + + @abstractmethod + def support_fp8(self) -> bool: + pass + + +def get_framework_op(name: str) -> FrameworkOpBase: + if name == "pt" or name == "pytorch" or name == "torch": + from ._torch import TorchOp + + return TorchOp() + elif name == "paddle": + from ._paddle import PaddleOp + + return PaddleOp() + else: + raise Exception(f"Unknown framework name: {name}") diff --git a/fastsafetensors/frameworks/_paddle.py b/fastsafetensors/frameworks/_paddle.py new file mode 100644 index 0000000..b089d00 --- /dev/null +++ b/fastsafetensors/frameworks/_paddle.py @@ -0,0 +1,249 @@ +# Copyright 2025 IBM Inc. All rights reserved +# SPDX-License-Identifier: Apache-2.0 + +try: + import paddle + import paddle.distributed as pdist + from paddle.distributed.communication.group import Group + from paddle.framework import core as paddle_core +except ImportError as e: + raise ImportError( + "could not import paddle, paddle_core, or numpy. Please install them." + ) from e + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from ..common import SingleGroup +from ..cpp import cpu_free, cpu_malloc, gds_device_buffer, gpu_free, gpu_malloc +from ..st_types import Device, DeviceType, DType +from . import FrameworkOpBase, ProcessGroupBase, TensorBase + +dtype_convert: Dict[DType, Any] = { + DType.BOOL: paddle.bool, + DType.I8: paddle.uint8, + DType.I8: paddle.int8, + DType.I16: paddle.int16, + DType.U16: paddle.bfloat16, + DType.I32: paddle.int32, + DType.U32: paddle.int32, + DType.I64: paddle.int64, + DType.U64: paddle.int64, + DType.F16: paddle.float16, + DType.BF16: paddle.bfloat16, + DType.F32: paddle.float32, + DType.F64: paddle.float64, +} +need_workaround_dtypes: Dict[DType, DType] = { + DType.F8_E5M2: DType.I8, + DType.F8_E4M3: DType.I8, +} + +if hasattr(paddle, "float8_e5m2"): + dtype_convert[DType.F8_E5M2] = paddle.float8_e5m2 +if hasattr(paddle, "float8_e4m3fn"): + dtype_convert[DType.F8_E4M3] = paddle.float8_e4m3fn + + +@dataclass +class PaddleTensor(TensorBase): + real_tensor: paddle.Tensor + + def get_raw(self) -> paddle.Tensor: + return self.real_tensor + + def contiguous(self) -> "PaddleTensor": + return PaddleTensor(self.device, self.dtype, self.real_tensor.contiguous()) + + def to( + self, + device: Optional[Device] = None, + dtype: DType = DType.AUTO, + ) -> "PaddleTensor": + to_dev: Optional[str] = None + if device is not None and self.device != device: + to_dev = device.as_str() + else: + device = self.device + to_dtype: Optional[paddle.dtype] = None + if dtype != DType.AUTO and (dtype != self.dtype): + to_dtype = dtype_convert[dtype] + else: + dtype = self.dtype + if to_dev is not None or to_dtype is not None: + return PaddleTensor( + device, dtype, self.real_tensor.to(device=to_dev, dtype=to_dtype) + ) + return self + + def clone(self) -> "PaddleTensor": + return PaddleTensor(self.device, self.dtype, self.real_tensor.clone()) + + def detach(self) -> "PaddleTensor": + return PaddleTensor(self.device, self.dtype, self.real_tensor.detach()) + + def view(self, dtype: DType) -> "PaddleTensor": + t2 = self.real_tensor.view(dtype_convert[dtype]) + return PaddleTensor(self.device, dtype, t2) + + def __getitem__(self, _val) -> "PaddleTensor": + return PaddleTensor(self.device, self.dtype, self.real_tensor[_val]) + + +@dataclass +class PaddleProcessGroup(ProcessGroupBase[PaddleTensor]): + real_pg: Optional[Group] + + def size(self) -> int: + return self.real_pg.process_group.size() if self.real_pg else 1 + + def rank(self) -> int: + return self.real_pg.process_group.rank() if self.real_pg else 0 + + def broadcast(self, dst: PaddleTensor, rank: int) -> None: + if self.real_pg: + pdist.broadcast(dst.real_tensor, rank, group=self.real_pg) + + def scatter( + self, + dst: PaddleTensor, + scatter_list: List[PaddleTensor], + src: int, + ) -> None: + if self.real_pg: + sl = [t.real_tensor for t in scatter_list] + pdist.scatter( + dst.real_tensor, + tensor_list=sl, + src=src, + group=self.real_pg, + ) + + def send( + self, + t: PaddleTensor, + dst_rank: int, + tag: int, + ) -> None: + if self.real_pg: + pdist.send(t.real_tensor, dst_rank, group=self.real_pg) + + def recv( + self, + t: PaddleTensor, + src_rank: int, + tag: int, + ) -> None: + if self.real_pg: + pdist.recv(t.real_tensor, src_rank, group=self.real_pg) + + +class PaddleOp(FrameworkOpBase[PaddleTensor, PaddleProcessGroup]): + def get_name(self) -> str: + return "paddle" + + def get_device(self, device: str, pg: PaddleProcessGroup) -> Device: + dev_index: Optional[int] = None + try: + dev_split = device.split(":", 1) + dev_type = DeviceType(dev_split[0].lower()) + if dev_type != DeviceType.CPU: + dev_index = 0 + if len(dev_split) > 1: + dev_index = int(dev_split[1]) + except ValueError: + raise ValueError(f"Unknown device: {device}") + + if paddle.device.cuda.device_count() > 0 and pg.real_pg is not None: + # For single (gpu:x, gpu) + # gpu:x, like gpu:0, gpu:1, ... + # For distributed + # The gpu determines the current rank + # rank0 use gpu:0, rank1 use gpu:1 + dev_index = pg.rank() % paddle.device.cuda.device_count() + return Device(dev_type, dev_index) + + def set_device(self, device: Device) -> None: + if device.type != DeviceType.CPU: + paddle.set_device(device.as_str()) + + def alloc_tensor_memory(self, length: int, dev: Device) -> gds_device_buffer: + if dev.type == DeviceType.GPU: + rbuf = gpu_malloc(length) + else: + rbuf = cpu_malloc(length) + return gds_device_buffer(rbuf, length, dev.type == DeviceType.GPU) + + def free_tensor_memory(self, gbuf: gds_device_buffer, dev: Device) -> None: + if dev.type == DeviceType.GPU: + gpu_free(gbuf.get_base_address()) + else: + cpu_free(gbuf.get_base_address()) + + def get_empty_tensor( + self, shape: List[int], dtype: DType, device: Device + ) -> PaddleTensor: + dst = paddle.to_tensor( + paddle.empty(shape=shape, dtype=dtype_convert[dtype]), + place=device.as_str(), + ) + return PaddleTensor(device, dtype, dst) + + def concat_tensors(self, tensors: List[PaddleTensor], dim) -> PaddleTensor: + ts = [tensor.real_tensor for tensor in tensors] + return PaddleTensor( + tensors[0].device, tensors[0].dtype, paddle.concat(ts, axis=dim) + ) + + def get_dtype_size(self, dtype: DType) -> int: + return paddle_core.size_of_dtype(dtype_convert[dtype]) + + def from_dlpack(self, dl_tensor: Any, device: Device, dtype: DType) -> PaddleTensor: + return PaddleTensor(device, dtype, paddle.utils.dlpack.from_dlpack(dl_tensor)) + + def copy_tensor(self, dst: PaddleTensor, src: PaddleTensor) -> None: + paddle.assign(src.real_tensor, output=dst.real_tensor) + dst.dtype = src.dtype + dst.device = src.device + + def get_cuda_ver(self) -> str: + return ( + str(paddle.version.cuda()) + if paddle.device.is_compiled_with_cuda() + else "0.0" + ) + + def get_device_ptr_align(self) -> int: + CUDA_PTR_ALIGN: int = 16 + return CUDA_PTR_ALIGN + + def as_workaround_dtype(self, dtype: DType) -> DType: + if dtype in need_workaround_dtypes: + return need_workaround_dtypes[dtype] + return dtype + + def get_process_group(self, pg: Optional[Any]) -> PaddleProcessGroup: + if pg is not None: + if isinstance(pg, SingleGroup): + pg = None + elif not isinstance(pg, Group): + raise Exception( + "pg must be an instance of paddle.distributed.communication.group.Group" + ) + return PaddleProcessGroup(pg) + + # for testing + def is_equal(self, wrapped: PaddleTensor, real: Any) -> bool: + if isinstance(real, paddle.Tensor): + return paddle.all(wrapped.real_tensor.equal(real)) + raise Exception("real is not paddle.Tensor") + + def randn(self, s: tuple, device: Device, dtype: DType) -> PaddleTensor: + return PaddleTensor( + device, + dtype, + paddle.randn(s, dtype=dtype_convert[dtype]).to(device=device.as_str()), + ) + + def support_fp8(self) -> bool: + return DType.F8_E5M2 in dtype_convert diff --git a/fastsafetensors/frameworks/_torch.py b/fastsafetensors/frameworks/_torch.py new file mode 100644 index 0000000..1b77987 --- /dev/null +++ b/fastsafetensors/frameworks/_torch.py @@ -0,0 +1,220 @@ +# Copyright 2025 IBM Inc. All rights reserved +# SPDX-License-Identifier: Apache-2.0 + +try: + import torch + import torch.distributed as dist +except ImportError as e: + raise ImportError("could not import torch. Please install it.") from e + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from ..common import SingleGroup +from ..cpp import cpu_free, cpu_malloc, gds_device_buffer +from ..st_types import Device, DeviceType, DType +from . import FrameworkOpBase, ProcessGroupBase, TensorBase + +dtype_convert: Dict[DType, Any] = { + DType.BOOL: torch.bool, + DType.U8: torch.uint8, + DType.I8: torch.int8, + DType.I16: torch.int16, + DType.I32: torch.int32, + DType.I64: torch.int64, + DType.F16: torch.float16, + DType.BF16: torch.bfloat16, + DType.F32: torch.float32, + DType.F64: torch.float64, +} +need_workaround_dtypes: Dict[DType, DType] = { + DType.F8_E5M2: DType.I8, + DType.F8_E4M3: DType.I8, +} + +if hasattr(torch, "float8_e5m2"): + dtype_convert[DType.F8_E5M2] = torch.float8_e5m2 +if hasattr(torch, "float8_e4m3fn"): + dtype_convert[DType.F8_E4M3] = torch.float8_e4m3fn +if hasattr(torch, "uint16"): + dtype_convert[DType.U16] = torch.uint16 +if hasattr(torch, "uint32"): + dtype_convert[DType.U32] = torch.uint32 +if hasattr(torch, "uint64"): + dtype_convert[DType.U64] = torch.uint64 + + +@dataclass +class TorchTensor(TensorBase): + real_tensor: torch.Tensor + + def get_raw(self) -> torch.Tensor: + return self.real_tensor + + def contiguous(self) -> "TorchTensor": + return TorchTensor(self.device, self.dtype, self.real_tensor.contiguous()) + + def to( + self, + device: Optional[Device] = None, + dtype: DType = DType.AUTO, + ) -> "TorchTensor": + to_dev: Optional[str] = None + if device is not None and self.device != device: + to_dev = device.as_str() + else: + device = self.device + to_dtype: Optional[torch.dtype] = None + if dtype != DType.AUTO and (dtype != self.dtype): + to_dtype = dtype_convert[dtype] + else: + dtype = self.dtype + if to_dev is not None or to_dtype is not None: + return TorchTensor( + device, dtype, self.real_tensor.to(device=to_dev, dtype=to_dtype) + ) + return self + + def clone(self) -> "TorchTensor": + return TorchTensor(self.device, self.dtype, self.real_tensor.clone()) + + def detach(self) -> "TorchTensor": + return TorchTensor(self.device, self.dtype, self.real_tensor.detach()) + + def view(self, dtype: DType) -> "TorchTensor": + t2 = self.real_tensor.view(dtype_convert[dtype]) + return TorchTensor(self.device, dtype, t2) + + def __getitem__(self, _val) -> "TorchTensor": + return TorchTensor(self.device, self.dtype, self.real_tensor[_val]) + + +@dataclass +class TorchProcessGroup(ProcessGroupBase[TorchTensor]): + real_pg: Optional[dist.ProcessGroup] + + def size(self) -> int: + return self.real_pg.size() if self.real_pg else 1 + + def rank(self) -> int: + return self.real_pg.rank() if self.real_pg else 0 + + def broadcast(self, dst: TorchTensor, rank: int) -> None: + if self.real_pg: + dist.broadcast(dst.real_tensor, rank, group=self.real_pg) + + def scatter( + self, + dst: TorchTensor, + scatter_list: List[TorchTensor], + src: int, + ) -> None: + if self.real_pg: + sl = [t.real_tensor for t in scatter_list] + dist.scatter(dst.real_tensor, scatter_list=sl, src=src, group=self.real_pg) + + def send( + self, + t: TorchTensor, + dst_rank: int, + tag: int, + ): + if self.real_pg: + dist.send(t.real_tensor, dst_rank, group=self.real_pg, tag=tag) + + def recv( + self, + t: TorchTensor, + src_rank: int, + tag: int, + ): + if self.real_pg: + dist.recv(t.real_tensor, src_rank, group=self.real_pg, tag=tag) + + +class TorchOp(FrameworkOpBase[TorchTensor, TorchProcessGroup]): + def get_name(self) -> str: + return "pytorch" + + def get_device(self, device: str, pg: TorchProcessGroup) -> Device: + dev = torch.device(device) + return Device(DeviceType(dev.type), dev.index) + + def set_device(self, device: Device) -> None: + if device.type != DeviceType.CPU: + torch.cuda.set_device(device.as_str()) + + def alloc_tensor_memory(self, length: int, dev: Device) -> gds_device_buffer: + if dev.type == DeviceType.CUDA: + rbuf = torch.cuda.caching_allocator_alloc(length) + else: + rbuf = cpu_malloc(length) + return gds_device_buffer(rbuf, length, dev.type == DeviceType.CUDA) + + def free_tensor_memory(self, gbuf: gds_device_buffer, dev: Device): + if dev.type == DeviceType.CUDA: + torch.cuda.caching_allocator_delete(gbuf.get_base_address()) + else: + cpu_free(gbuf.get_base_address()) + + def get_empty_tensor( + self, shape: List[int], dtype: DType, device: Device + ) -> TorchTensor: + dst = torch.empty( + size=shape, dtype=dtype_convert[dtype], device=device.as_str() + ) + return TorchTensor(device, dtype, dst) + + def concat_tensors(self, tensors: List[TorchTensor], dim: int) -> TorchTensor: + ts = [tensor.real_tensor for tensor in tensors] + return TorchTensor(tensors[0].device, tensors[0].dtype, torch.cat(ts, dim=dim)) + + def get_dtype_size(self, dtype: DType) -> int: + return dtype_convert[dtype].itemsize + + def from_dlpack(self, dl_tensor: Any, device: Device, dtype: DType) -> TorchTensor: + t = torch.from_dlpack(dl_tensor) + return TorchTensor(device, dtype, t) + + def copy_tensor(self, dst: TorchTensor, src: TorchTensor): + dst.real_tensor.copy_(src.real_tensor) + + def get_cuda_ver(self) -> str: + if torch.cuda.is_available(): + return str(torch.version.cuda) + return "0.0" + + def get_device_ptr_align(self) -> int: + CUDA_PTR_ALIGN: int = 16 + return CUDA_PTR_ALIGN + + def as_workaround_dtype(self, dtype: DType) -> DType: + if dtype in need_workaround_dtypes: + return need_workaround_dtypes[dtype] + return dtype + + def get_process_group(self, pg: Optional[Any]) -> TorchProcessGroup: + if pg is not None: + if isinstance(pg, SingleGroup): + pg = None + elif not isinstance(pg, dist.ProcessGroup): + raise Exception( + "pg must be an instance of torch.disributed.ProcessGroup" + ) + return TorchProcessGroup(pg) + + # for testing + def is_equal(self, wrapped: TorchTensor, real: Any) -> bool: + if isinstance(real, torch.Tensor): + return bool(torch.all(wrapped.real_tensor.eq(real))) + raise Exception("real is not torch.Tensor") + + def randn(self, s: tuple, device: Device, dtype: DType) -> TorchTensor: + return TorchTensor( + device, + dtype, + torch.randn(s, device=device.as_str(), dtype=dtype_convert[dtype]), + ) + + def support_fp8(self) -> bool: + return DType.F8_E5M2 in dtype_convert diff --git a/fastsafetensors/loader.py b/fastsafetensors/loader.py index a26cbb4..8e4b9b7 100644 --- a/fastsafetensors/loader.py +++ b/fastsafetensors/loader.py @@ -1,120 +1,101 @@ # Copyright 2024 IBM Inc. All rights reserved # SPDX-License-Identifier: Apache-2.0 -import torch -import torch.distributed as dist -import os import math -from . import cpp as fstcpp -from typing import List, Dict, Tuple, OrderedDict, Union import warnings +from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union -from .common import paddle_loaded, SafeTensorsMetadata, ALIGN, CUDA_PTR_ALIGN, TensorFrame, get_device_numa_node, SingleGroup -from .tensor_factory import LazyTensorFactory +from . import cpp as fstcpp +from .common import SafeTensorsMetadata, TensorFrame, get_device_numa_node from .file_buffer import FilesBufferOnDevice -if paddle_loaded: - import paddle +from .frameworks import TensorBase, get_framework_op +from .st_types import DeviceType, DType +from .tensor_factory import LazyTensorFactory + +gl_set_numa = False -initialized: bool = False -loaded_nvidia: bool = False -if not loaded_nvidia: - fstcpp.load_nvidia_functions() - loaded_nvidia = True +loaded_nvidia = False -support_framework = ["pytorch", "pt"] -if paddle_loaded: - support_framework.append("paddle") class SafeTensorsFileLoader: - r""" Load .safetensors files lazily. + r"""Load .safetensors files lazily. Args: - pg (dist.ProcessGroup): process group for pytorch distributed. SingleGroup is available for single GPU use-cases. - devcie (torch.device): target device. + devcie (str): target device. + pg (Optional[Any]): process group-like objects for distributed. None for single GPU use-cases. bbuf_size_kb (int): bounce buffer size for file copies. - max_pinned_memory_in_kb (int): maximum KiB of pinned memory for GDS configuration. max_threads (int): maximum number of threads for memory copies. nogds (bool): if True, trun off GDS and fallback to pread with bounce buffer. debug_log (bool): enable debug logs. Examples: - >> from fastsafetensors import SafeTensorsFileLoader, SingleGroup + >> from fastsafetensors import SafeTensorsFileLoader >> src_files = download(target_dir, "gpt2") - >> loader = SafeTensorsFileLoader(SingleGroup, torch.device("cpu"), nogds=True, debug_log=True) + >> loader = SafeTensorsFileLoader(Device("cpu"), nogds=True, debug_log=True) >> loader.add_filenames({0: src_files}) >> bufs = loader.copy_files_to_device() >> print(bufs.get_tensor(loader.get_keys()[0])) >> loader.close() """ - def __init__(self, pg: dist.ProcessGroup, device: torch.device, bbuf_size_kb: int = 16 * 1024, max_pinned_memory_in_kb: int = 64 * 1024 * 1024, max_threads: int=16, nogds: bool=False, debug_log: bool=False, framework="pytorch"): - if framework not in support_framework: - raise NotImplementedError(f"fastsafetensors only supports {support_framework} framework") - self.device = device + + def __init__( + self, + pg: Optional[Any], + device: str = "cpu", + bbuf_size_kb: int = 16 * 1024, + max_threads: int = 16, + nogds: bool = False, + set_numa: bool = True, + debug_log: bool = False, + framework="pytorch", + ): + self.framework = get_framework_op(framework) + self.pg = self.framework.get_process_group(pg) + self.device = self.framework.get_device(device, self.pg) self.debug_log = debug_log self.meta: Dict[str, Tuple[SafeTensorsMetadata, int]] = {} - self.need_gds_close = False - self.frames: OrderedDict[str, TensorFrame] = {} - self.framework = framework - if self.framework == "pytorch" or isinstance(pg, SingleGroup): - self.pg = pg - self.group = pg - elif paddle_loaded and self.framework == "paddle": - self.pg = pg.process_group - self.group = pg - self.nogds = nogds - global initialized - if not initialized: - fstcpp.set_debug_log(debug_log) - if self.framework == "pytorch": - d_id = device.index - elif paddle_loaded and self.framework == "paddle": - if device == "cpu": - d_id = None - else: - if isinstance(self.pg, SingleGroup): - # For single (gpu:x, gpu) - # gpu:x, like gpu:0, gpu:1, ... - d_id = device.split(":") - d_id = int(d_id[1]) if len(d_id) == 2 else 0 - else: - # For distributed - # The gpu determines the current rank - # rank0 use gpu:0, rank1 use gpu:1 - d_id = self.pg.rank() % paddle.device.cuda.device_count() - self.device = f"gpu:{d_id}" - node = get_device_numa_node(d_id) + self.frames = OrderedDict[str, TensorFrame]() + global loaded_nvidia + if not loaded_nvidia: + fstcpp.load_nvidia_functions() + if fstcpp.init_gds() != 0: + raise Exception(f"[FAIL] init_gds()") + loaded_nvidia = True + global gl_set_numa + if not gl_set_numa and set_numa: + node = get_device_numa_node(self.device.index) if node is not None: fstcpp.set_numa_node(node) - if False and fstcpp.is_cufile_found() and not nogds: # TODO: init_gds should be called but too slow for parallel initialization - if fstcpp.init_gds(bbuf_size_kb, max_pinned_memory_in_kb) != 0: - raise Exception(f"[FAIL] init_gds max_io_block_in_kb={max_io_block_in_kb}, max_pinned_memory_in_kb={max_pinned_memory_in_kb}") - self.need_gds_close = True - initialized = True - device_is_not_cpu = not (paddle_loaded and self.framework == "paddle" and device == "cpu") and not (self.framework == "pytorch" and device.type == "cpu") + gl_set_numa = True + fstcpp.set_debug_log(debug_log) + device_is_not_cpu = self.device.type != DeviceType.CPU if device_is_not_cpu and not fstcpp.is_cuda_found(): raise Exception("[FAIL] libcudart.so does not exist") if not fstcpp.is_cufile_found() and not nogds: - warnings.warn("libcufile.so does not exist but nogds is False. use nogds=True", UserWarning) + warnings.warn( + "libcufile.so does not exist but nogds is False. use nogds=True", + UserWarning, + ) nogds = True + self.reader: Union[fstcpp.nogds_file_reader, fstcpp.gds_file_reader] if nogds: - self.reader = fstcpp.nogds_file_reader(False, bbuf_size_kb, max_threads, device_is_not_cpu) + self.reader = fstcpp.nogds_file_reader( + False, bbuf_size_kb, max_threads, device_is_not_cpu + ) else: self.reader = fstcpp.gds_file_reader(max_threads, device_is_not_cpu) - self.nogds = nogds def reset(self): self.frames = {} self.meta = {} def close(self): - if self.need_gds_close: - fstcpp.close_gds() - self.need_gds_close = False + self.reset() def get_keys(self) -> List[str]: - return self.frames.keys() + return list(self.frames.keys()) - def get_shape(self, tensor_name: str) -> torch.Size: + def get_shape(self, tensor_name: str) -> List[int]: return self.frames[tensor_name].shape def add_filenames(self, filenames: Dict[int, List[str]]): @@ -129,7 +110,7 @@ def add_filenames(self, filenames: Dict[int, List[str]]): for rank in filenames.keys(): next_idx = rank_next_idx[rank] if next_idx < len(filenames[rank]): - realpath = filenames[rank][next_idx] #os.path.realpath(filename) + realpath = filenames[rank][next_idx] # os.path.realpath(filename) metadata = SafeTensorsMetadata.from_file(realpath, self.framework) self.meta[realpath] = (metadata, rank) self.frames.update(metadata.tensors) @@ -139,18 +120,18 @@ def add_filenames(self, filenames: Dict[int, List[str]]): else: completed += 1 - def copy_files_to_device(self, dtype: torch.dtype=None, use_buf_register: bool=True, max_copy_block_size: int=16*1024*1024*1024)->FilesBufferOnDevice: + def copy_files_to_device( + self, + dtype: DType = DType.AUTO, + use_buf_register: bool = True, + max_copy_block_size: int = 16 * 1024 * 1024 * 1024, + ) -> FilesBufferOnDevice: """ trigger copying all the files to device buffers. At this moment, we do not instantiate tensors but just creating copies at device buffers with or without GDS. Users can instantiate and/or partition tensors with FilesBufferOnDevice returned by this function. """ - if self.framework == "pytorch": - if self.device.type != "cpu": - torch.cuda.set_device(self.device) - elif paddle_loaded and self.framework == "paddle": - if "gpu" in self.device: - paddle.set_device(self.device) + self.framework.set_device(self.device) need_wait: List[LazyTensorFactory] = [] factories: Dict[int, List[LazyTensorFactory]] = {} @@ -159,22 +140,31 @@ def copy_files_to_device(self, dtype: torch.dtype=None, use_buf_register: bool=T factory_idx_bits = math.ceil(math.log2(len(self.meta) + 1)) lidx = 1 - + for _, (meta, rank) in sorted(self.meta.items(), key=lambda x: x[0]): self_rank = self.pg.rank() == rank - factory = LazyTensorFactory(meta, self.device, rank, self_rank, factory_idx_bits, lidx, self.nogds, self.reader, self.debug_log) + factory = LazyTensorFactory( + meta, + self.device, + rank, + self_rank, + factory_idx_bits, + lidx, + self.reader, + self.framework, + self.debug_log, + ) factory.submit_io(use_buf_register, max_copy_block_size) factories[rank].append(factory) if self_rank: need_wait.append(factory) lidx += 1 for factory in need_wait: - factory.wait_io(dtype=dtype, noalign=self.nogds) - if self.framework == "pytorch": - return FilesBufferOnDevice(factories, pg=self.pg) - elif paddle_loaded and self.framework == "paddle": - return FilesBufferOnDevice(factories, pg=self.group, framework=self.framework) - return None + factory.wait_io( + dtype=dtype, noalign=isinstance(self.reader, fstcpp.nogds_file_reader) + ) + return FilesBufferOnDevice(factories, pg=self.pg, framework=self.framework) + class fastsafe_open: """ @@ -183,40 +173,49 @@ class fastsafe_open: Args: filenames (:obj:`str`|`list[str]`|`dict[int, str]`): The filename(s) or rank-file map to open - framework (:obj:`str`): `pt` and `paddle` are only supported currently + framework (:obj:`str`): `pt`, `pytorch`, and `paddle` are only supported currently device (:obj:`str`, defaults to :obj:`"cpu"`): The device on which you want the tensors. """ - def __init__(self, filenames: Union[str, List[str], Dict[int, str]], - framework: str="pt", pg: dist.ProcessGroup=SingleGroup(), - device: Union[str, torch.device]="cpu", - nogds: bool=False, - debug_log: bool=False, - max_copy_block_size: int=16*1024*1024*1024): - if framework not in support_framework: - raise NotImplementedError("pytorch is only a framework that current fastsafetensors supports") - if isinstance(device, str) and framework == "pt": - device = torch.device(device) - self.loader = SafeTensorsFileLoader(pg, device, nogds=nogds, debug_log=debug_log, framework= framework if framework != "pt" else "pytorch") + def __init__( + self, + filenames: Union[str, List[str], Dict[int, List[str]]], + framework: str = "pt", + pg: Optional[Any] = None, + device: str = "cpu", + nogds: bool = False, + debug_log: bool = False, + max_copy_block_size: int = 16 * 1024 * 1024 * 1024, + ): + self.loader = SafeTensorsFileLoader( + pg, device, nogds=nogds, debug_log=debug_log, framework=framework + ) + file_dict: Dict[int, List[str]] = {} if isinstance(filenames, str): - filenames = [filenames] + file_dict = {0: [filenames]} if isinstance(filenames, list): - self.loader.add_filenames({0: filenames}) + file_dict = {0: filenames} elif isinstance(filenames, dict): - self.loader.add_filenames(filenames) - self.fb = self.loader.copy_files_to_device(max_copy_block_size=max_copy_block_size) + file_dict = filenames + self.loader.add_filenames(file_dict) + self.fb = self.loader.copy_files_to_device( + max_copy_block_size=max_copy_block_size + ) - def metadata(self)->Dict[str, Dict[str, str]]: + def metadata(self) -> Dict[str, Dict[str, str]]: ret = {} for filename, (metadata, _) in self.loader.meta.items(): ret[filename] = metadata.metadata return ret - def get_keys(self)->List[str]: - return self.fb.key_to_rank_lidx.keys() + def keys(self) -> List[str]: + return list(self.fb.key_to_rank_lidx.keys()) + + def get_tensor_wrapped(self, name: str) -> TensorBase: + return self.fb.get_tensor_wrapped(name) - def get_tensor(self, name: str)->torch.Tensor: - return self.fb.get_tensor(name) + def get_tensor(self, name: str) -> Any: + return self.get_tensor_wrapped(name).get_raw() def __enter__(self): return self diff --git a/fastsafetensors/st_types.py b/fastsafetensors/st_types.py new file mode 100644 index 0000000..dfead2e --- /dev/null +++ b/fastsafetensors/st_types.py @@ -0,0 +1,62 @@ +# Copyright 2025 IBM Inc. All rights reserved +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from enum import Enum +from typing import Optional + + +class DeviceType(Enum): + CPU = "cpu" + CUDA = "cuda" + GPU = "gpu" + + +@dataclass(frozen=True) +class Device: + type: DeviceType = DeviceType.CPU + index: Optional[int] = None + + def as_str(self) -> str: + if self.index is None: + return self.type.value + return f"{self.type.value}:{self.index}" + + @classmethod + def from_str(cls, s: str) -> "Device": + if ":" in s: + type_str, index_str = s.split(":", 1) + try: + dev_type = DeviceType(type_str.lower()) + except ValueError: + raise ValueError(f"Unknown device type: {type_str}") + try: + index = int(index_str) + except ValueError: + raise ValueError(f"Invalid index: {index_str}") + return cls(type=dev_type, index=index) + else: + try: + dev_type = DeviceType(s.lower()) + except ValueError: + raise ValueError(f"Unknown device type: {s}") + return cls(type=dev_type, index=None) + + +class DType(Enum): + BOOL = "BOOL" + I8 = "I8" + I16 = "I16" + I32 = "I32" + I64 = "I64" + U8 = "U8" + U16 = "U16" + U32 = "U32" + U64 = "U64" + F16 = "F16" + F32 = "F32" + F64 = "F64" + BF16 = "BF16" + F8_E5M2 = "F8_E5M2" + F8_E4M3 = "F8_E4M3" + AUTO = "AUTO" diff --git a/fastsafetensors/tensor_factory.py b/fastsafetensors/tensor_factory.py index 46b0174..43b7403 100644 --- a/fastsafetensors/tensor_factory.py +++ b/fastsafetensors/tensor_factory.py @@ -1,33 +1,46 @@ # Copyright 2024 IBM Inc. All rights reserved # SPDX-License-Identifier: Apache-2.0 -import torch -import torch.distributed as dist -from typing import Dict, List, Tuple from collections import OrderedDict +from typing import Dict, List, Optional, Tuple, Union from . import cpp as fstcpp -from .common import SafeTensorsMetadata, free_tensor_memory, paddle_loaded +from .common import SafeTensorsMetadata from .copier.gds import GdsFileCopier from .copier.nogds import NoGdsFileCopier +from .frameworks import FrameworkOpBase, ProcessGroupBase, TensorBase +from .st_types import Device, DType -if paddle_loaded: - import paddle - import paddle.distributed as pdist class LazyTensorFactory: - def __init__(self, metadata: SafeTensorsMetadata, device: torch.device, rank: int, local_rank: bool, factory_idx_bits: int, lidx: int, nogds: bool, reader, debug_log: bool=False): + def __init__( + self, + metadata: SafeTensorsMetadata, + device: Device, + rank: int, + local_rank: bool, + factory_idx_bits: int, + lidx: int, + reader: Union[fstcpp.gds_file_reader, fstcpp.nogds_file_reader], + framework: FrameworkOpBase, + debug_log: bool = False, + ): + self.framework = framework self.metadata = metadata self.device = device - if not local_rank: - self.copier = None - elif nogds: - self.copier = NoGdsFileCopier(metadata, device, reader, debug_log) - else: - self.copier = GdsFileCopier(metadata, device, reader, debug_log) - self.tensors: Dict[str, torch.Tensor] = {} - self.shuffled: Dict[str, torch.Tensor] = {} - self.gbuf: fstcpp.gds_device_buffer = None + self.copier: Optional[Union[NoGdsFileCopier, GdsFileCopier]] = None + if local_rank: + if isinstance(reader, fstcpp.nogds_file_reader): + self.copier = NoGdsFileCopier( + metadata, device, reader, framework, debug_log + ) + else: + self.copier = GdsFileCopier( + metadata, device, reader, framework, debug_log + ) + self.tensors: Dict[str, TensorBase] = {} + self.shuffled: Dict[str, TensorBase] = {} + self.gbuf: Optional[fstcpp.gds_device_buffer] = None self.debug_log = debug_log self.rank = rank self.factory_idx_bits = factory_idx_bits @@ -38,52 +51,61 @@ def submit_io(self, use_buf_register: bool, max_copy_block_size: int): if self.copier is not None: self.gbuf = self.copier.submit_io(use_buf_register, max_copy_block_size) - def wait_io(self, dtype: torch.dtype=None, noalign: bool=False): - if self.copier is not None: + def wait_io(self, dtype: DType = DType.AUTO, noalign: bool = False): + if self.copier is not None and self.gbuf is not None: self.tensors = self.copier.wait_io(self.gbuf, dtype=dtype, noalign=noalign) if self.debug_log: for name in self.tensors.keys(): print(f"wait_io: tensor={name}") self.copier = None - def push(self, pg: dist.ProcessGroup, tensor_name: str, dst_rank: int, src_rank: int, group = None)->torch.Tensor: + def push( + self, + pg: ProcessGroupBase, + tensor_name: str, + dst_rank: int, + src_rank: int, + ) -> Optional[TensorBase]: if pg.size() == 1: return self.tensors[tensor_name] tag = (self.next_tag << self.factory_idx_bits) + self.lidx self.next_tag += 1 if pg.rank() != dst_rank and pg.rank() != src_rank: if self.debug_log: - print(f"push: skip, tensor_name={tensor_name}, dst_rank={dst_rank}, pg.rank()={pg.rank()}, tag={tag}") + print( + f"push: skip, tensor_name={tensor_name}, dst_rank={dst_rank}, pg.rank()={pg.rank()}, tag={tag}" + ) return None elif pg.rank() == dst_rank and src_rank == dst_rank: if self.debug_log: - print(f"push: nocopy, tensor_name={tensor_name}, dst_rank={dst_rank}, pg.rank()={pg.rank()}, tag={tag}") + print( + f"push: nocopy, tensor_name={tensor_name}, dst_rank={dst_rank}, pg.rank()={pg.rank()}, tag={tag}" + ) return self.tensors[tensor_name].clone().detach() frame = self.metadata.tensors[tensor_name] if pg.rank() == src_rank: if tensor_name not in self.tensors: - raise Exception(f"push: tensor {tensor_name} was not found. released? lidx={self.lidx}") + raise Exception( + f"push: tensor {tensor_name} was not found. released? lidx={self.lidx}" + ) t = self.tensors[tensor_name].clone().detach() if self.debug_log: - print(f"push: send, tensor_name={tensor_name}, shape={frame.shape}, dst_rank={dst_rank}, pg.rank()={pg.rank()}, tag={tag}") - if self.metadata.framework == "pytorch": - dist.send(t, dst_rank, group=pg, tag=tag) - elif paddle_loaded and self.metadata.framework == "paddle": - pdist.send(t, dst_rank, group=group) + print( + f"push: send, tensor_name={tensor_name}, shape={frame.shape}, dst_rank={dst_rank}, pg.rank()={pg.rank()}, tag={tag}" + ) + pg.send(t, dst_rank, tag=tag) return None - + if self.debug_log: - print(f"push: recv, tensor_name={tensor_name}, shape={frame.shape}, src_rank={src_rank}, pg.rank()={pg.rank()}, tag={tag}") - - if self.metadata.framework == "pytorch": - t = torch.empty(size=frame.shape, dtype=frame.dtype, device=self.device) - dist.recv(t, src_rank, group=pg, tag=tag) - elif paddle_loaded and self.metadata.framework == "paddle": - t = paddle.to_tensor(paddle.empty(size=frame.shape, dtype=frame.dtype), place=self.device) - pdist.recv(t,src_rank, group=group) + print( + f"push: recv, tensor_name={tensor_name}, shape={frame.shape}, src_rank={src_rank}, pg.rank()={pg.rank()}, tag={tag}" + ) + + t = self.framework.get_empty_tensor(frame.shape, frame.dtype, self.device) + pg.recv(t, src_rank, tag=tag) return t - def shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int, group = None)->torch.Tensor: + def shuffle(self, pg: ProcessGroupBase, tensor_name: str, dim: int) -> TensorBase: if pg.size() == 1: return self.tensors[tensor_name] if tensor_name in self.shuffled: @@ -96,17 +118,14 @@ def shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int, group = Non if tensor_name in self.tensors: dst = self.tensors[tensor_name].clone().detach() else: - if self.metadata.framework == "pytorch": - dst = torch.empty(size=frame.shape, dtype=frame.dtype, device=self.device) - elif paddle_loaded and self.metadata.framework == "paddle": - dst = paddle.to_tensor(paddle.empty(shape=frame.shape, dtype=frame.dtype), place=self.device) - + dst = self.framework.get_empty_tensor( + frame.shape, frame.dtype, self.device + ) if self.debug_log: - print(f"shuffle: broadcast, tensor_name={tensor_name}, shape={frame.shape}, self.rank={self.rank}, pg.rank()={pg.rank()}, has_tensor={tensor_name in self.tensors}") - if self.metadata.framework == "pytorch": - dist.broadcast(dst, self.rank, group=pg) - elif paddle_loaded and self.metadata.framework == "paddle": - pdist.broadcast(dst, self.rank, group=group) + print( + f"shuffle: broadcast, tensor_name={tensor_name}, shape={frame.shape}, self.rank={self.rank}, pg.rank()={pg.rank()}, has_tensor={tensor_name in self.tensors}" + ) + pg.broadcast(dst, self.rank) else: rank_slices: List[Tuple] = [() for i in range(0, pg.size())] size = frame.shape[dim] @@ -114,118 +133,89 @@ def shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int, group = Non for rank in range(0, pg.size()): for i in range(0, len(frame.shape)): if i < dim: - rank_slices[rank] += (slice(None,None,None),) + rank_slices[rank] += (slice(None, None, None),) elif i == dim: - rank_slices[rank] += (slice(rank * block_size, (rank + 1) * block_size, 1),) + rank_slices[rank] += ( + slice(rank * block_size, (rank + 1) * block_size, 1), + ) break - scatter_list: List[torch.Tensor] = [] + scatter_list: List[TensorBase] = [] new_frame = frame[rank_slices[pg.rank()]] - - if self.metadata.framework == "pytorch": - dst = torch.empty(size=new_frame.shape, dtype=new_frame.dtype, device=self.device) - elif paddle_loaded and self.metadata.framework == "paddle": - dst = paddle.to_tensor(paddle.empty(shape=new_frame.shape, dtype=frame.dtype), place=self.device) + dst = self.framework.get_empty_tensor( + new_frame.shape, new_frame.dtype, self.device + ) if self.rank == pg.rank(): if tensor_name not in self.tensors: - raise Exception(f"shuffle: tensor {tensor_name} was not found, released? lidx={self.lidx}") + raise Exception( + f"shuffle: tensor {tensor_name} was not found, released? lidx={self.lidx}" + ) t = self.tensors[tensor_name] - scatter_list = [t[rank_slices[rank]].contiguous() for rank in range(0, pg.size())] # scatter requires contiguous tensor + scatter_list = [ + t[rank_slices[rank]].contiguous() for rank in range(0, pg.size()) + ] # scatter requires contiguous tensor if self.debug_log: - print(f"shuffle: scatter, tensor_name={tensor_name}, shape={frame.shape}->{new_frame.shape}, self.rank={self.rank}, pg.rank()={pg.rank()}, rank_slices={rank_slices}, len(scatter_list)={len(scatter_list)}") - if self.metadata.framework == "pytorch": - dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg) - elif paddle_loaded and self.metadata.framework == "paddle": - pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group) + print( + f"shuffle: scatter, tensor_name={tensor_name}, shape={frame.shape}->{new_frame.shape}, self.rank={self.rank}, pg.rank()={pg.rank()}, rank_slices={rank_slices}, len(scatter_list)={len(scatter_list)}" + ) + pg.scatter(dst, scatter_list=scatter_list, src=self.rank) self.shuffled[tensor_name] = dst return dst - def shuffle_packed_qkv(self, pg: dist.ProcessGroup, tensor_name: str, group = None)->torch.Tensor: - if tensor_name in self.shuffled: - if self.debug_log: - print(f"shuffle: use cache, tensor_name={tensor_name}") - t = self.shuffled[tensor_name].clone().detach() - return t - frame = self.metadata.tensors[tensor_name] - total_size = frame.shape[0] - single_size = total_size // 3 - block_size = (single_size + pg.size() - 1) // pg.size() - scatter_list: List[torch.Tensor] = [] - if tensor_name in self.tensors: - t = self.tensors[tensor_name] - for rank in range(0, pg.size()): - q = t[(slice(rank * block_size, (rank + 1) * block_size, 1))] - k = t[(slice(single_size + rank * block_size, single_size + (rank + 1) * block_size, 1))] - v = t[(slice(single_size * 2 + rank * block_size, single_size * 2 + (rank + 1) * block_size, 1))] - if self.metadata.framework == "pytorch": - cat_res = torch.cat([q, k, v], dim=0) - elif paddle_loaded and self.metadata.framework == "paddle": - cat_res = paddle.concat([q, k, v], axis=0) - scatter_list.append(cat_res) - if pg.size() == 1: - self.shuffled[tensor_name] = scatter_list[0] - return scatter_list[0] - new_shape = (block_size * 3,) + frame.shape[1:] - - if self.debug_log: - print(f"shuffle_packed_qkv: scatter, tensor_name={tensor_name}, shape={frame.shape}->{new_shape}, self.rank={self.rank}, pg.rank()={pg.rank()}, len(scatter_list)={len(scatter_list)}") - if self.metadata.framework == "pytorch": - dst = torch.empty(size=new_shape, dtype=frame.dtype, device=self.device) - dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg) - elif paddle_loaded and self.metadata.framework == "paddle": - dst = paddle.to_tensor(paddle.empty(shape=new_shape, dtype=frame.dtype),place=self.device) - pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group) - self.shuffled[tensor_name] = dst - return dst - - def shuffle_multi_cols(self, pg: dist.ProcessGroup, tensor_names: List[str], dim: int, group = None)->torch.Tensor: - rank_tensors: List[List[torch.Tensor]] = [[] for i in range(0, pg.size())] - new_shape: List = [] + def shuffle_multi_cols( + self, pg: ProcessGroupBase, tensor_names: List[str], dim: int + ) -> TensorBase: + rank_tensors: List[List[TensorBase]] = [[] for i in range(0, pg.size())] + new_shape: List[int] = [] for tensor_name in tensor_names: frame = self.metadata.tensors[tensor_name] - total_size = frame.shape[0] + total_size = frame.shape[dim] block_size = (total_size + pg.size() - 1) // pg.size() if len(new_shape) == 0: - new_shape = [block_size] + list(frame.shape[1:]) - elif dim == 0: - new_shape[0] += block_size + new_shape = frame.shape + new_shape[dim] = 0 else: - new_shape[dim] += frame.shape[dim] + for dim2 in range(0, len(frame.shape)): + if dim2 != dim and frame.shape[dim2] != new_shape[dim2]: + raise Exception( + f"dim {dim2} mismatch: tensor {tensor_name} has {frame.shape} vs. {new_shape} (dim={dim})" + ) + new_shape[dim] += block_size if self.rank == pg.rank(): if tensor_name not in self.tensors: - raise Exception(f"shuffle_multi_cols: tensor {tensor_name} was not found, released? lidx={self.lidx}") + raise Exception( + f"shuffle_multi_cols: tensor {tensor_name} was not found, released? lidx={self.lidx}" + ) t = self.tensors[tensor_name] for rank in range(0, pg.size()): - rank_tensors[rank].append(t[(slice(rank * block_size, (rank + 1) * block_size, 1))]) + rank_slices: Tuple[slice, ...] = () + for i in range(0, len(frame.shape)): + if i < dim: + rank_slices += (slice(None, None, None),) + elif i == dim: + rank_slices += ( + slice(rank * block_size, (rank + 1) * block_size, 1), + ) + break + rank_tensors[rank].append(t[rank_slices]) if pg.size() == 1: - if self.metadata.framework == "pytorch": - return torch.cat(rank_tensors[self.rank], dim=dim) - elif paddle_loaded and self.metadata.framework == "paddle": - return paddle.concat(rank_tensors[self.rank], axis=dim) - return None - scatter_list: List[torch.Tensor] = [] - - if len(rank_tensors[0]) > 0: + return self.framework.concat_tensors(rank_tensors[self.rank], dim=dim) + scatter_list: List[TensorBase] = [] + + if self.rank == pg.rank(): for rank in range(0, pg.size()): - scatter_list.append(torch.cat(rank_tensors[rank], dim=dim)) + scatter_list.append( + self.framework.concat_tensors(rank_tensors[rank], dim=dim) + ) if self.debug_log: - print(f"shuffle_multi_cols: scatter, tensor_name={tensor_name}, shape={frame.shape}->{new_shape}, self.rank={self.rank}, pg.rank()={pg.rank()}, len(scatter_list)={len(scatter_list)}") - if self.metadata.framework == "pytorch": - dst = torch.empty(size=new_shape, dtype=frame.dtype, device=self.device) # dst should be eariler than scatter_list for less fragmentation - dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg) - elif paddle_loaded and self.metadata.framework == "paddle": - dst = paddle.to_tensor(paddle.empty(shape=new_shape, dtype=frame.dtype), place=self.device )# dst should be eariler than scatter_list for less fragmentation - pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group) + print( + f"shuffle_multi_cols: scatter, tensor_name={tensor_name}, shape={frame.shape}->{new_shape}, self.rank={self.rank}, pg.rank()={pg.rank()}, len(scatter_list)={len(scatter_list)}" + ) + dst = self.framework.get_empty_tensor(new_shape, frame.dtype, self.device) + pg.scatter(dst, scatter_list=scatter_list, src=self.rank) return dst def free_dev_ptrs(self): self.tensors = {} if self.gbuf is not None: - free_tensor_memory(self.gbuf, self.device, self.metadata.framework) + self.framework.free_tensor_memory(self.gbuf, self.device) self.gbuf = None - - def shuffle_all(self, pg: dist.ProcessGroup, tensor_shard_dim: OrderedDict[str, int])->Tuple[int, Dict[str, torch.Tensor]]: - ret: Dict[str, torch.Tensor] = {} - for tensor_name, dim in tensor_shard_dim.items(): - if tensor_name in self.metadata.tensors: - ret[tensor_name] = self.shuffle(pg, tensor_name, dim) - return (0, ret) diff --git a/perf/fastsafetensors_perf/perf.py b/perf/fastsafetensors_perf/perf.py index 59cdca6..dd74fda 100644 --- a/perf/fastsafetensors_perf/perf.py +++ b/perf/fastsafetensors_perf/perf.py @@ -1,27 +1,37 @@ # Copyright 2024 IBM Inc. All rights reserved # SPDX-License-Identifier: Apache-2.0 +import json import os -import sys -import typer import re -from typing import Dict, List, Tuple, Any -import json import subprocess -import torch -import torch.distributed as dist -from collections import OrderedDict +import sys import time +from collections import OrderedDict from copy import deepcopy +from typing import Any, Dict, List, Tuple, Union + +import torch +import torch.distributed as dist +import typer from safetensors import safe_open -from fastsafetensors import str_to_dtype, SafeTensorsFileLoader, SingleGroup + +from fastsafetensors import SafeTensorsFileLoader, SingleGroup +from fastsafetensors.st_types import Device, DType app = typer.Typer() script_path = __file__ + class FilesBufferOnMmap: - def __init__(self, device: torch.device|None, dtype: torch.dtype|None=None, opt: bool=False, debug_log: bool=False): + def __init__( + self, + device: Union[torch.device, None], + dtype: Union[torch.dtype, None] = None, + opt: bool = False, + debug_log: bool = False, + ): self.device = device self.dtype = dtype self.debug_log = debug_log @@ -32,116 +42,112 @@ def __init__(self, device: torch.device|None, dtype: torch.dtype|None=None, opt: def add_filenames(self, filenames: List[str]): for filename in filenames: - f = safe_open(os.path.realpath(filename), framework="pytorch") - for k in f.keys(): - self.handles[k] = f + with safe_open( + os.path.realpath(filename), framework="pytorch" + ) as f: # type: ignore[attr-defined] + for k in f.keys(): + self.handles[k] = filename + self.filenames = filenames - def get_keys(self)->List[str]: - return self.handles.keys() + def get_keys(self) -> List[str]: + return list(self.handles.keys()) def enable_opt(self): self.opt = True - def get_tensor(self, tensor_name: str)->torch.Tensor: + def get_tensor(self, tensor_name: str) -> torch.Tensor: if tensor_name not in self.key_to_handle: raise ValueError(f"get_tensor: key {tensor_name} was not found in files") f = self.key_to_handle[tensor_name] - t = f.get_tensor(tensor_name) # tensor at pageable area (mmap) + t = f.get_tensor(tensor_name) # tensor at pageable area (mmap) t = t.clone().detach() if self.opt else t return t.to(device=self.device, dtype=self.dtype) - def get_sharded(self, pg: dist.ProcessGroup, tensor_name: str, dim: int)->torch.Tensor: + def get_sharded( + self, pg: dist.ProcessGroup, tensor_name: str, dim: int + ) -> torch.Tensor: if tensor_name not in self.key_to_handle: raise ValueError(f"get_sharded: key {tensor_name} was not found in files") f = self.key_to_handle[tensor_name] - t = f.get_slice(tensor_name) # tensor at pageable area (mmap) - rank_slices = () + t = f.get_slice(tensor_name) # tensor at pageable area (mmap) + rank_slices: tuple[slice, ...] = () shape = t.get_shape() size = shape[dim] block_size = (size + pg.size() - 1) // pg.size() for i in range(0, len(shape)): if i < dim: - rank_slices += (slice(None,None,None),) + rank_slices += (slice(None, None, None),) elif i == dim: - rank_slices += (slice(pg.rank() * block_size, (pg.rank() + 1) * block_size, 1),) + rank_slices += ( + slice(pg.rank() * block_size, (pg.rank() + 1) * block_size, 1), + ) break t = t[rank_slices] t = t.clone().detach() if self.opt else t return t.to(device=self.device, dtype=self.dtype) - def as_dict(self)->Dict[str, torch.Tensor]: + def as_dict(self) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} - for key, f in self.handles.items(): - t = f.get_tensor(key) # tensor at pageable area (mmap) - t = t.clone().detach() if self.opt else t # opt==True: copy to pinned area? - t = t.to(device=self.device, dtype=self.dtype) - tensors[key] = t + for filename in self.filenames: + with safe_open(filename, framework="pytorch") as f: # type: ignore[attr-defined] + for key in f.keys(): + t = f.get_tensor(key) # tensor at pageable area (mmap) + t = ( + t.clone().detach() if self.opt else t + ) # opt==True: copy to pinned area? + t = t.to(device=self.device, dtype=self.dtype) + tensors[key] = t return tensors - def as_dict_sharded(self, pg: dist.ProcessGroup, tensor_shard_dim: OrderedDict[str, int])->Dict[str, torch.Tensor]: + def as_dict_sharded( + self, pg: dist.ProcessGroup, tensor_shard_dim: OrderedDict[str, int] + ) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} - for key, dim in sorted(tensor_shard_dim.items(), key=lambda x:x[0]): - f = self.handles[key] - if dim == -1: - t = f.get_tensor(key) # tensor at pageable area (mmap) - else: - t = f.get_slice(key) - rank_slices = () - shape = t.get_shape() - size = shape[dim] - block_size = (size + pg.size() - 1) // pg.size() - for i in range(0, len(shape)): - if i < dim: - rank_slices += (slice(None,None,None),) - elif i == dim: - rank_slices += (slice(pg.rank() * block_size, (pg.rank() + 1) * block_size, 1),) - break - t = t[rank_slices] - t = t.clone().detach() if self.opt else t # opt==True: copy to pinned area? - t = t.to(device=self.device, dtype=self.dtype) - tensors[key] = t + for filename in self.filenames: + with safe_open(filename, framework="pytorch") as f: # type: ignore[attr-defined] + for key in f.keys(): + dim = tensor_shard_dim[key] + if dim == -1: + t = f.get_tensor(key) # tensor at pageable area (mmap) + else: + t = f.get_slice(key) + rank_slices: tuple = () + shape = t.get_shape() + size = shape[dim] + block_size = (size + pg.size() - 1) // pg.size() + for i in range(0, len(shape)): + if i < dim: + rank_slices += (slice(None, None, None),) + elif i == dim: + rank_slices += ( + slice( + pg.rank() * block_size, + (pg.rank() + 1) * block_size, + 1, + ), + ) + break + t = t[rank_slices] + t = ( + t.clone().detach() if self.opt else t + ) # opt==True: copy to pinned area? + t = t.to(device=self.device, dtype=self.dtype) + tensors[key] = t return tensors -"""" -sten_collection.json format: -{ - "model_name_0": { - "world_size_0": { - "rank_0": ["/path_0/file_0.safetensors", "/path_0/file_1.safetensors", "/path_0/file_2.safetensors"] - }, - "world_size_1": { - "rank_0": ["/path_0/file_0.safetensors", "/path_0/file_1.safetensors", "/path_0/file_2.safetensors"], - "rank_1": ["/path_1/file_0.safetensors", "/path_1/file_1.safetensors", "/path_1/file_2.safetensors"] - }, - ... - "world_size_N": { - "rank_0": ["/path_0/file_0.safetensors", "/path_0/file_1.safetensors", "/path_0/file_2.safetensors"], - "rank_1": ["/path_1/file_0.safetensors", "/path_1/file_1.safetensors", "/path_1/file_2.safetensors"], - ... - "rank_N": ["/path_N/file_0.safetensors", "/path_N/file_1.safetensors", "/path_N/file_2.safetensors"] - }, - "keys": { - "all": ["key.head.weight"], - "dim_0": ["key\.[0-9]+\.weight"], - "dim_1": ["key\.x\.[0-9]+\.weight"] - } - }, - "model_name_1": { - ... - }, - ... -} -""" -def get_sten_files(sten_collection_filepath: str, model_name: str, world_size: int) -> Dict[int, List[str]]: + +def get_sten_files( + sten_collection_filepath: str, model_name: str, world_size: int +) -> Dict[int, List[str]]: world_size_str = f"world_size_{world_size}" with open(sten_collection_filepath, "r") as f: m = json.load(f) - if not model_name in m: + if model_name not in m: return {} m = m[model_name] - if not world_size_str in m: + if world_size_str not in m: world_size_str = "world_size_0" - if not world_size_str in m: + if world_size_str not in m: return {} m = m[world_size_str] ret = {} @@ -149,13 +155,16 @@ def get_sten_files(sten_collection_filepath: str, model_name: str, world_size: i ret[rank] = m[f"rank_{rank}"] return ret -def get_key_pats(sten_collection_filepath: str, model_name: str) -> Tuple[Dict[re.Pattern, int], str]: + +def get_key_pats( + sten_collection_filepath: str, model_name: str +) -> Tuple[Dict[re.Pattern, int], str]: with open(sten_collection_filepath, "r") as f: m = json.load(f) - if not model_name in m: + if model_name not in m: return {}, "" m = m[model_name] - if not "keys" in m: + if "keys" not in m: return {}, "" m = m["keys"] ret = {} @@ -167,31 +176,60 @@ def get_key_pats(sten_collection_filepath: str, model_name: str) -> Tuple[Dict[r ret[re.compile(key)] = 1 return ret, m["layer_prefix"] -def get_key_dim(keys: List[str], pats: Dict[re.Pattern, int], layer_prefix: str) -> Dict[str, int]: + +def get_key_dim( + keys: List[str], pats: Dict[re.Pattern, int], layer_prefix: str +) -> OrderedDict[str, int]: from collections import OrderedDict - ret: OrderedDict[str, int] = {} - layer_tmp = {} - pat2 = re.compile(f"{layer_prefix}([0-9]+)\..*") + + ret: OrderedDict[str, int] = OrderedDict() + # layer_tmp = {} + # pat2 = re.compile(f"{layer_prefix}([0-9]+)\..*") for key in keys: found = False for pat, dim in pats.items(): m = pat.match(key) if m is not None: - #if m[0].startswith(layer_prefix): + # if m[0].startswith(layer_prefix): # layer_tmp[m[0]] = (int(pat2.match(m[0])[1]), dim) - #else: + # else: # ret[m[0]] = dim ret[m[0]] = dim found = True break if not found: ret[key] = -1 - #for key, (_, dim) in sorted(layer_tmp.items(), key=lambda x:x[1][0]): + # for key, (_, dim) in sorted(layer_tmp.items(), key=lambda x:x[1][0]): # ret[key] = dim return ret -mon_procs = {} -def start_sysstat(model_name: str, run_id: str|None=None, gpu_trace: bool=True, memtrace_enabled: bool=False) -> int: + +class MyProc: + def __init__(self, popen: Union[subprocess.Popen, None] = None): + self.popen: Union[subprocess.Popen, None] = popen + + def terminate(self): + if self.popen: + self.popen.terminate() + + def kill(self): + if self.popen: + self.popen.kill() + + def wait(self, timeout=Union[int, None]): + if self.popen: + return self.popen.wait(timeout=timeout) + + +mon_procs: Dict[int, Tuple[MyProc, MyProc, Any, str]] = {} + + +def start_sysstat( + model_name: str, + run_id: Union[str, None] = None, + gpu_trace: bool = True, + memtrace_enabled: bool = False, +) -> int: memtrace_file = "" if gpu_trace and memtrace_enabled: torch.cuda.memory._record_memory_history(max_entries=100000) @@ -207,17 +245,33 @@ def start_sysstat(model_name: str, run_id: str|None=None, gpu_trace: bool=True, dool_file = f"dool-{model_name.replace('/', '--')}.csv" iostat_f = open(f"iostat-{model_name.replace('/', '--')}.log", "w") dool_cmd = [ - "dool", "-cmdnpyg", + "dool", + "-cmdnpyg", ] if gpu_trace: dool_cmd.append("--nvidia-gpu") dool_cmd += ["--output", dool_file] - dool = subprocess.Popen(dool_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - iostat = subprocess.Popen(["iostat", "1"], stdout=iostat_f, stderr=subprocess.DEVNULL) + try: + dool = MyProc( + subprocess.Popen( + dool_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + ) + except: + dool = MyProc() + try: + iostat = MyProc( + subprocess.Popen( + ["iostat", "1"], stdout=iostat_f, stderr=subprocess.DEVNULL + ) + ) + except: + iostat = MyProc() id = len(mon_procs) mon_procs[id] = (dool, iostat, iostat_f, memtrace_file) return id + def stop_sysstat(id: int): (dool, iostat, iostat_f, memtrace_file) = mon_procs[id] dool.terminate() @@ -235,51 +289,71 @@ def stop_sysstat(id: int): torch.cuda.memory._dump_snapshot(memtrace_file) torch.cuda.memory._record_memory_history(enabled=None) -def as_torch_dtype(dtype_str: str)->torch.dtype|None: - if dtype_str == "auto": + +def as_safetensors_dtype(dtype_str: str) -> Union[torch.dtype, None]: + if dtype_str == "AUTO": return None - return str_to_dtype(dtype_str) + from fastsafetensors.frameworks._torch import dtype_convert + + dtype = DType(dtype_str) + + if dtype not in dtype_convert: + raise Exception( + f"unsupported type: {dtype}. supported types: {dtype_convert.keys()}" + ) + return dtype_convert[dtype] -def get_size(tensor: torch.Tensor)->int: + +def get_size(tensor: torch.Tensor) -> int: c = 1 for s in list(tensor.shape): c *= s c *= tensor.dtype.itemsize return c -def as_torch_device(device: str, rank: int)->torch.device: + +def as_torch_device(device: str, rank: int) -> torch.device: if device.startswith("cuda"): return torch.device(f"cuda:{rank}") elif device == "cpu": return torch.device("cpu") return torch.device(device) + @app.command() -def get_config(device_index: int, - model_name: str, - sten_collection_json: str, - rank: int, - world_size: int) -> Tuple[int, int]: +def get_config( + device_index: int, + model_name: str, + sten_collection_json: str, + rank: int, + world_size: int, +) -> Tuple[int, int]: rank_filenames = get_sten_files(sten_collection_json, model_name, world_size) filenames = rank_filenames[rank] - max_copier_threads = int(os.getenv("FST_THREADS", "16")) # number of copy threads at host CPU - bbuf_size_kb_total = int(os.getenv("FST_BBUF_SIZE_KB", "163840")) # size of bounce buffer at host memory for FST_NOGDS==1 + max_copier_threads = int( + os.getenv("FST_THREADS", "16") + ) # number of copy threads at host CPU + bbuf_size_kb_total = int( + os.getenv("FST_BBUF_SIZE_KB", "163840") + ) # size of bounce buffer at host memory for FST_NOGDS==1 from fastsafetensors.common import get_device_numa_node + node = get_device_numa_node(device_index) total_l2_size = 0 phys_cpus = {} failed = False import glob + for cpudir in glob.glob(f"/sys/devices/system/node/node{node}/cpu[0-9]*"): try: - with open(f"{cpudir}/cache/index2/size") as f: # L2 cache size for a cpu + with open(f"{cpudir}/cache/index2/size") as f: # L2 cache size for a cpu size_str = f.read().strip() if size_str[-1] != "K": raise Exception(f"cannot parse {cpudir}/cache/index2/size") total_l2_size += int(size_str[:-1]) - with open(f"{cpudir}/topology/core_id") as f: # physical core ID + with open(f"{cpudir}/topology/core_id") as f: # physical core ID phys_cpus[f.read().strip()] = True except Exception as e: failed = True @@ -299,16 +373,23 @@ def get_config(device_index: int, max_copy_block_size = s.st_size if len(filenames) < max_copier_threads: max_copy_block_size = total_size // world_size // max_copier_threads - if max_copy_block_size % bbuf_size_kb_total*1024 > 0: - max_copy_block_size = max_copy_block_size - max_copy_block_size % (bbuf_size_kb_total*1024) + (bbuf_size_kb_total*1024) - print(f"--max-threads={max_copier_threads} --max-direct-io-kb={int(bbuf_size_kb_total)} --max-block-size-mb={int(max_copy_block_size/1024/1024)}") + if max_copy_block_size % bbuf_size_kb_total * 1024 > 0: + max_copy_block_size = ( + max_copy_block_size + - max_copy_block_size % (bbuf_size_kb_total * 1024) + + (bbuf_size_kb_total * 1024) + ) + print( + f"--max-threads={max_copier_threads} --max-direct-io-kb={int(bbuf_size_kb_total)} --max-block-size-mb={int(max_copy_block_size/1024/1024)}" + ) return (max_copier_threads, bbuf_size_kb_total) + @app.command() def drop_cache( model_name: str, sten_collection_json: str, - world_size: int=1, + world_size: int = 1, ): total = 0 with open(sten_collection_json, "r") as f: @@ -321,45 +402,54 @@ def drop_cache( for filename in targets.keys(): fd = os.open(filename, os.O_RDONLY) s = os.fstat(fd) - os.posix_fadvise(fd, 0, s.st_size, os.POSIX_FADV_DONTNEED) + if hasattr(os, "posix_fadvise") and hasattr(os, "POSIX_FADV_DONTNEED"): + os.posix_fadvise(fd, 0, s.st_size, os.POSIX_FADV_DONTNEED) # type: ignore[attr-defined] os.close(fd) print(f"DROP_CACHE: {filename}, {s.st_size/1024/1024/1024} GiB") total += s.st_size fd = os.open(sten_collection_json, os.O_RDONLY) s = os.fstat(fd) - os.posix_fadvise(fd, 0, s.st_size, os.POSIX_FADV_DONTNEED) + if hasattr(os, "posix_fadvise") and hasattr(os, "POSIX_FADV_DONTNEED"): + os.posix_fadvise(fd, 0, s.st_size, os.POSIX_FADV_DONTNEED) # type: ignore[attr-defined] os.close(fd) print(f"DROP_CACHE: {sten_collection_json}, {s.st_size/1024/1024/1024} GiB") total += s.st_size print(f"total={total/1024/1024/1024}GiB from {sten_collection_json}") + @app.command() def run_mmap_sharded_internal( model_name: str, sten_collection_json: str, - device: str="cuda", - dtype: str="auto", - opt: bool=False, - debug_log: bool=False, + device: str = "cuda", + dtype: str = "AUTO", + opt: bool = False, + debug_log: bool = False, ): import torch.distributed as dist + backend = "nccl" if device == "cpu": backend = "gloo" dist.init_process_group(backend=backend) - dist.barrier() # ensure nccl is initialized + dist.barrier() # ensure nccl is initialized pg = dist.group.WORLD + if pg is None: + return rank_filenames = get_sten_files(sten_collection_json, model_name, pg.size()) filenames = [] - for _, files in sorted(rank_filenames.items(), key=lambda x:x[0]): + for _, files in sorted(rank_filenames.items(), key=lambda x: x[0]): for f in files: filenames.append(f) (key_pats, layer_prefix) = get_key_pats(sten_collection_json, model_name) - torch_dtype = as_torch_dtype(dtype) - ts = [] t0 = time.time_ns() - fb = FilesBufferOnMmap(device=as_torch_device(device, pg.rank()), dtype=torch_dtype, opt=opt, debug_log=debug_log) + fb = FilesBufferOnMmap( + device=torch.device(f"{device}:{pg.rank()}"), + dtype=as_safetensors_dtype(dtype), + opt=opt, + debug_log=debug_log, + ) fb.add_filenames(filenames) key_dim = get_key_dim(fb.get_keys(), key_pats, layer_prefix) t1 = time.time_ns() @@ -370,171 +460,227 @@ def run_mmap_sharded_internal( t2 = time.time_ns() print(f"{t0},{t1},{t2},{count}") + @app.command() def run_mmap( model_name: str, sten_collection_json: str, - device: str="cuda", - dtype: str="auto", - run_id: str=None, - rank: int=0, - world_size: int=1, - debug_log: bool=False, - sysstat_enabled: bool=True, - memtrace_enabled: bool=False, - opt: bool=False, - cache_drop: bool=False, + device: str = "cuda", + dtype: str = "AUTO", + run_id: Union[str, None] = None, + rank: int = 0, + world_size: int = 1, + debug_log: bool = False, + sysstat_enabled: bool = True, + memtrace_enabled: bool = False, + opt: bool = False, + cache_drop: bool = False, ): if cache_drop: drop_cache(model_name, sten_collection_json, world_size) - torch_dtype = as_torch_dtype(dtype) if sysstat_enabled: - stat_id = start_sysstat(model_name, run_id, device.startswith("cuda"), memtrace_enabled and world_size == 1) + stat_id = start_sysstat( + model_name, + run_id, + device.startswith("cuda"), + memtrace_enabled and world_size == 1, + ) t0 = time.time_ns() if world_size == 1: - device = as_torch_device(device, 0) filenames = get_sten_files(sten_collection_json, model_name, world_size)[rank] - ts = [] - fb = FilesBufferOnMmap(device=device, dtype=torch_dtype, opt=opt, debug_log=debug_log) + fb = FilesBufferOnMmap( + device=torch.device(f"{device}:0"), + dtype=as_safetensors_dtype(dtype), + opt=opt, + debug_log=debug_log, + ) fb.add_filenames(filenames) t1 = time.time_ns() ts = fb.as_dict() - count = 0 + count = 0.0 for _, t in ts.items(): count += get_size(t) t2 = time.time_ns() init_sec = (t1 - t0) / 1000 / 1000 / 1000 get_sec = (t2 - t1) / 1000 / 1000 / 1000 elapsed_sec = (t2 - t0) / 1000 / 1000 / 1000 - count = count/1024/1024/1024 + count = count / 1024 / 1024 / 1024 else: rank_procs = [] for rank in range(0, world_size): - rank_cmd = ["torchrun", "--nproc-per-node=1", f"--nnodes={world_size}", "--max-restarts=0", - "--master_addr=0.0.0.0", "--master_port=1234", f"--node_rank={rank}", - script_path, "run-mmap-sharded-internal", f"--dtype={dtype}", f"--device={device}", model_name, sten_collection_json, - ] + rank_cmd = [ + "torchrun", + "--nproc-per-node=1", + f"--nnodes={world_size}", + "--max-restarts=0", + "--master_addr=0.0.0.0", + "--master_port=1234", + f"--node_rank={rank}", + script_path, + "run-mmap-sharded-internal", + f"--dtype={dtype}", + f"--device={device}", + model_name, + sten_collection_json, + ] if opt: rank_cmd += ["--opt"] envs = deepcopy(os.environ) if world_size == 2: envs["CUDA_VISIBLE_DEVICES"] = "4,6" else: - envs["CUDA_VISIBLE_DEVICES"] = ",".join([str((i + 4) % 8) for i in range(0, world_size)]) # for vela cluster - rank_procs.append(subprocess.Popen(rank_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=envs)) + envs["CUDA_VISIBLE_DEVICES"] = ",".join( + [str((i + 4) % 8) for i in range(0, world_size)] + ) # for vela cluster + rank_procs.append( + subprocess.Popen( + rank_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=envs + ) + ) outs = [] for proc in rank_procs: stdout, stderr = proc.communicate() outs.append(stdout.decode("utf-8")) if len(stderr) > 0: print(stderr) - - count = 0 - t = [[-1,-1],[-1,-1],[-1,-1]] # (min, max) for t_n + + count = 0.0 + time_matrix = [[-1, -1], [-1, -1], [-1, -1]] # (min, max) for t_n for rank, out in enumerate(outs): out_split = out.split(",") t_ps = [int(s) for s in out_split] count += t_ps[-1] - p_init_sec = (t_ps[1]-t_ps[0]) / 1000 / 1000 / 1000 - p_get_sec = (t_ps[2]-t_ps[1]) / 1000 / 1000 / 1000 - p_elapsed_sec = (t_ps[2]-t_ps[0]) / 1000 / 1000 / 1000 - print(f"rank{rank}: elapsed={p_elapsed_sec}, init={p_init_sec} sec, get={p_get_sec} sec, bytes={t_ps[-1]/1024/1024/1024} GiB, bw={t_ps[-1]/1024/1024/1024/p_elapsed_sec} GiB/s") + p_init_sec = (t_ps[1] - t_ps[0]) / 1000 / 1000 / 1000 + p_get_sec = (t_ps[2] - t_ps[1]) / 1000 / 1000 / 1000 + p_elapsed_sec = (t_ps[2] - t_ps[0]) / 1000 / 1000 / 1000 + print( + f"rank{rank}: elapsed={p_elapsed_sec}, init={p_init_sec} sec, get={p_get_sec} sec, bytes={t_ps[-1]/1024/1024/1024} GiB, bw={t_ps[-1]/1024/1024/1024/p_elapsed_sec} GiB/s" + ) for i, t_p in enumerate(t_ps[:-1]): - if t[i][0] == -1 or t[i][0] > t_p: - t[i][0] = t_p - if t[i][1] == -1 or t[i][1] < t_p: - t[i][1] = t_p - init_sec = (t[1][1] - t[0][0]) / 1000 / 1000 / 1000 - get_sec = (t[2][1] - t[1][0]) / 1000 / 1000 /1000 - elapsed_sec = (t[2][1] - t[0][0]) / 1000 / 1000 /1000 - count = count/1024/1024/1024 + if time_matrix[i][0] == -1 or time_matrix[i][0] > t_p: + time_matrix[i][0] = t_p + if time_matrix[i][1] == -1 or time_matrix[i][1] < t_p: + time_matrix[i][1] = t_p + init_sec = (time_matrix[1][1] - time_matrix[0][0]) / 1000 / 1000 / 1000 + get_sec = (time_matrix[2][1] - time_matrix[1][0]) / 1000 / 1000 / 1000 + elapsed_sec = (time_matrix[2][1] - time_matrix[0][0]) / 1000 / 1000 / 1000 + count = count / 1024 / 1024 / 1024 if sysstat_enabled: stop_sysstat(stat_id) - print(f"elapsed: {elapsed_sec} sec, init: {init_sec} sec, get: {get_sec} sec, bytes={count}GiB, bw={count/elapsed_sec}GiB/s") + print( + f"elapsed: {elapsed_sec} sec, init: {init_sec} sec, get: {get_sec} sec, bytes={count}GiB, bw={count/elapsed_sec}GiB/s" + ) + @app.command() def run_gds_sharded_internal( model_name: str, sten_collection_json: str, device: str = "cuda", - dtype: str="auto", - max_block_size_mb: float=10*1024, - max_direct_io_kb: int=16*1024, - max_pinned_memory_in_kb: int=64*1024*1024, - max_threads: int=16, - use_buf_register: bool=False, - debug_log: bool=False, + dtype: str = "AUTO", + max_block_size_mb: float = 10 * 1024, + bbuf_size_kb: int = 16 * 1024, + max_threads: int = 16, + use_buf_register: bool = True, + debug_log: bool = False, nogds: bool = False, - exclude_gds_init: bool=True, + exclude_gds_init: bool = True, ): - torch_dtype: torch.dtype = as_torch_dtype(dtype) import torch.distributed as dist + backend = "nccl" if device == "cpu": backend = "gloo" dist.init_process_group(backend=backend) - dist.barrier() # ensure nccl is initialized + dist.barrier() # ensure nccl is initialized pg = dist.group.WORLD + if pg is None: + return filenames = get_sten_files(sten_collection_json, model_name, pg.size()) (key_pats, layer_prefix) = get_key_pats(sten_collection_json, model_name) - device = as_torch_device(device, pg.rank()) t0 = time.time_ns() if not nogds and exclude_gds_init: import fastsafetensors.cpp as fstcpp - fstcpp.init_gds(16 * 1024, 80 * 1024 * 1024) - loader = SafeTensorsFileLoader(pg, device, max_direct_io_kb, max_pinned_memory_in_kb, max_threads, nogds=nogds, debug_log=debug_log) + + fstcpp.init_gds() + loader = SafeTensorsFileLoader( + pg, + f"{device}:{pg.rank()}", + bbuf_size_kb, + max_threads, + nogds=nogds, + debug_log=debug_log, + ) loader.add_filenames(filenames) key_dim = get_key_dim(loader.get_keys(), key_pats, layer_prefix) t1 = time.time_ns() - fb = loader.copy_files_to_device(dtype=torch_dtype, use_buf_register=use_buf_register, max_copy_block_size=int(max_block_size_mb*1024*1024)) + fb = loader.copy_files_to_device( + dtype=DType(dtype), + use_buf_register=use_buf_register, + max_copy_block_size=int(max_block_size_mb * 1024 * 1024), + ) t2 = time.time_ns() ts = fb.as_dict(tensor_shard_dim=key_dim) t3 = time.time_ns() count = 0 for _, t in ts.items(): - count += get_size(t) + count += get_size(t.get_raw()) t4 = time.time_ns() print(f"{t0},{t1},{t2},{t3},{t4},{count}") loader.close() if not nogds and exclude_gds_init: fstcpp.close_gds() + @app.command() def run_gds( model_name: str, sten_collection_json: str, - dtype: str="auto", - run_id: str=None, - device: str="cuda", - max_block_size_mb: float=10*1024, - debug_log: bool=False, - max_direct_io_kb: int=16*1024, - max_pinned_memory_in_kb: int=64*1024*1024, - max_threads: int=16, - world_size: int=1, - use_buf_register: bool=False, - sysstat_enabled: bool=True, - memtrace_enabled: bool=False, + dtype: str = "AUTO", + run_id: Union[str, None] = None, + device: str = "cuda", + max_block_size_mb: float = 10 * 1024, + debug_log: bool = False, + bbuf_size_kb: int = 16 * 1024, + max_threads: int = 16, + world_size: int = 1, + use_buf_register: bool = True, + sysstat_enabled: bool = True, + memtrace_enabled: bool = False, nogds: bool = False, cache_drop: bool = False, - exclude_gds_init: bool=True, + exclude_gds_init: bool = True, ): if cache_drop: drop_cache(model_name, sten_collection_json, world_size) - torch_dtype = as_torch_dtype(dtype) + torch_dtype = as_safetensors_dtype(dtype) if sysstat_enabled: - stat_id = start_sysstat(model_name, run_id, device.startswith("cuda"), memtrace_enabled and world_size == 1) + stat_id = start_sysstat( + model_name, + run_id, + device.startswith("cuda"), + memtrace_enabled and world_size == 1, + ) if world_size > 1: rank_procs = {} for rank in range(0, world_size): - rank_cmd = ["torchrun", "--nproc-per-node=1", f"--nnodes={world_size}", "--max-restarts=0", - "--master_addr=0.0.0.0", "--master_port=1234", f"--node_rank={rank}", - script_path, "run-gds-sharded-internal", - f"--max-threads={max_threads}", f"--max-direct-io-kb={max_direct_io_kb}", f"--device={device}", - f"--max-block-size-mb={max_block_size_mb}", - ] + rank_cmd = [ + "torchrun", + "--nproc-per-node=1", + f"--nnodes={world_size}", + "--max-restarts=0", + "--master_addr=0.0.0.0", + "--master_port=1234", + f"--node_rank={rank}", + script_path, + "run-gds-sharded-internal", + f"--max-threads={max_threads}", + f"--bbuf-size-kb={bbuf_size_kb}", + f"--device={device}", + f"--max-block-size-mb={max_block_size_mb}", + ] if use_buf_register: rank_cmd += ["--use-buf-register"] if nogds: @@ -549,26 +695,36 @@ def run_gds( if world_size == 2: envs["CUDA_VISIBLE_DEVICES"] = "4,6" else: - envs["CUDA_VISIBLE_DEVICES"] = ",".join([str((i + 4) % 8) for i in range(0, world_size)]) # for vela cluster + envs["CUDA_VISIBLE_DEVICES"] = ",".join( + [str((i + 4) % 8) for i in range(0, world_size)] + ) # for vela cluster envs["NCCL_CUMEM_ENABLE"] = "0" if debug_log: - print(' '.join(rank_cmd)) - rank_procs[rank] = subprocess.Popen(rank_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=envs) + print(" ".join(rank_cmd)) + rank_procs[rank] = subprocess.Popen( + rank_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=envs + ) outs = [] - for rank, proc in sorted(rank_procs.items(), key=lambda x:x[0]): + for rank, proc in sorted(rank_procs.items(), key=lambda x: x[0]): stdout, stderr = proc.communicate() outs.append(stdout.decode("utf-8")) if len(stdout) > 0: - sys.stdout.buffer.write(bytes(f"{rank}: ", 'utf-8')) + sys.stdout.buffer.write(bytes(f"{rank}: ", "utf-8")) sys.stdout.buffer.write(stdout) if len(stderr) > 0: - sys.stderr.buffer.write(bytes(f"{rank}: ", 'utf-8')) + sys.stderr.buffer.write(bytes(f"{rank}: ", "utf-8")) sys.stderr.buffer.write(stderr) - pat = re.compile('^([0-9]+),([0-9]+),([0-9]+),([0-9]+),([0-9]+),([0-9]+)$') - count = 0 - t = [[-1,-1],[-1,-1],[-1,-1],[-1,-1],[-1,-1]] # (min, max) for t_n + pat = re.compile("^([0-9]+),([0-9]+),([0-9]+),([0-9]+),([0-9]+),([0-9]+)$") + count = 0.0 + time_matrix = [ + [-1, -1], + [-1, -1], + [-1, -1], + [-1, -1], + [-1, -1], + ] # (min, max) for t_n for rank, out in enumerate(outs): - for line in out.split('\n'): + for line in out.split("\n"): out_split = pat.match(line) if out_split is None: continue @@ -576,40 +732,52 @@ def run_gds( for i in range(0, 6): t_ps.append(int(out_split[i + 1])) count += t_ps[-1] - p_init_sec = (t_ps[1]-t_ps[0]) / 1000 / 1000 / 1000 - p_io_sec = (t_ps[2]-t_ps[1]) / 1000 / 1000 / 1000 - p_shuffle_sec = (t_ps[3]-t_ps[2]) / 1000 / 1000 / 1000 - p_get_sec = (t_ps[4]-t_ps[3]) / 1000 / 1000 / 1000 - p_elapsed_sec = (t_ps[4]-t_ps[0]) / 1000 / 1000 / 1000 - print(f"rank{rank}: elapsed={p_elapsed_sec}, init={p_init_sec} sec, io={p_io_sec} sec, shuffle={p_shuffle_sec} sec, get={p_get_sec} sec, bytes={t_ps[-1]/1024/1024/1024} GiB, bw={t_ps[-1]/1024/1024/1024/p_elapsed_sec} GiB/s") + p_init_sec = (t_ps[1] - t_ps[0]) / 1000 / 1000 / 1000 + p_io_sec = (t_ps[2] - t_ps[1]) / 1000 / 1000 / 1000 + p_shuffle_sec = (t_ps[3] - t_ps[2]) / 1000 / 1000 / 1000 + p_get_sec = (t_ps[4] - t_ps[3]) / 1000 / 1000 / 1000 + p_elapsed_sec = (t_ps[4] - t_ps[0]) / 1000 / 1000 / 1000 + print( + f"rank{rank}: elapsed={p_elapsed_sec}, init={p_init_sec} sec, io={p_io_sec} sec, shuffle={p_shuffle_sec} sec, get={p_get_sec} sec, bytes={t_ps[-1]/1024/1024/1024} GiB, bw={t_ps[-1]/1024/1024/1024/p_elapsed_sec} GiB/s" + ) for i, t_p in enumerate(t_ps[:-1]): - if t[i][0] == -1 or t[i][0] > t_p: - t[i][0] = t_p - if t[i][1] == -1 or t[i][1] < t_p: - t[i][1] = t_p - init_sec = (t[1][1] - t[0][0]) / 1000 / 1000 / 1000 - io_sec = (t[3][1] - t[1][0]) / 1000 / 1000 /1000 - get_sec = (t[4][1] - t[3][0]) / 1000 / 1000 /1000 - elapsed_sec = (t[4][1] - t[0][0]) / 1000 / 1000 /1000 - count = count/1024/1024/1024 + if time_matrix[i][0] == -1 or time_matrix[i][0] > t_p: + time_matrix[i][0] = t_p + if time_matrix[i][1] == -1 or time_matrix[i][1] < t_p: + time_matrix[i][1] = t_p + init_sec = (time_matrix[1][1] - time_matrix[0][0]) / 1000 / 1000 / 1000 + io_sec = (time_matrix[3][1] - time_matrix[1][0]) / 1000 / 1000 / 1000 + get_sec = (time_matrix[4][1] - time_matrix[3][0]) / 1000 / 1000 / 1000 + elapsed_sec = (time_matrix[4][1] - time_matrix[0][0]) / 1000 / 1000 / 1000 + count = count / 1024 / 1024 / 1024 else: - device = as_torch_device(device, 0) - filenames = get_sten_files(sten_collection_json, model_name, world_size) t0 = time.time_ns() if not nogds and exclude_gds_init: import fastsafetensors.cpp as fstcpp - fstcpp.init_gds(16 * 1024, 80 * 1024 * 1024) - loader = SafeTensorsFileLoader(SingleGroup(), device, max_direct_io_kb, max_pinned_memory_in_kb, max_threads, nogds=nogds, debug_log=debug_log) + + fstcpp.init_gds() + loader = SafeTensorsFileLoader( + SingleGroup(), + f"{device}:0", + bbuf_size_kb, + max_threads, + nogds=nogds, + debug_log=debug_log, + ) loader.add_filenames(filenames) t1 = time.time_ns() - tensors = loader.copy_files_to_device(dtype=torch_dtype, use_buf_register=use_buf_register, max_copy_block_size=int(max_block_size_mb*1024*1024)) + tensors = loader.copy_files_to_device( + dtype=DType(dtype), + use_buf_register=use_buf_register, + max_copy_block_size=int(max_block_size_mb * 1024 * 1024), + ) t2 = time.time_ns() - count = 0 - ts = tensors.as_dict({key: -1 for key in loader.get_keys()}) + count = 0.0 + ts = tensors.as_dict(OrderedDict({key: -1 for key in loader.get_keys()})) for key, t in ts.items(): - c = get_size(t) + c = get_size(t.get_raw()) count += c t3 = time.time_ns() init_sec = (t1 - t0) / 1000 / 1000 / 1000 @@ -621,10 +789,13 @@ def run_gds( if sysstat_enabled: stop_sysstat(stat_id) - print(f"elapsed: {elapsed_sec} sec, init: {init_sec} sec, io: {io_sec} sec, get: {get_sec} sec, bytes={count}GiB, bw={count/elapsed_sec}GiB/s") + print( + f"elapsed: {elapsed_sec} sec, init: {init_sec} sec, io: {io_sec} sec, get: {get_sec} sec, bytes={count}GiB, bw={count/elapsed_sec}GiB/s" + ) if world_size == 1: loader.close() + if __name__ == "__main__": app() diff --git a/perf/pyproject.toml b/perf/pyproject.toml index 2a7a7b4..0c4d2a2 100644 --- a/perf/pyproject.toml +++ b/perf/pyproject.toml @@ -7,7 +7,7 @@ authors = ["Takeshi Yoshimura "] [tool.poetry.dependencies] python = ">=3.10.0,<3.13" fastsafetensors = ">=0.1.0" -typer = "^0.9.0" +typer = ">0.9.0" torch = ">2.1.0" safetensors = "^0.4.3" diff --git a/pyproject.toml b/pyproject.toml index 593d429..dd3be8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "fastsafetensors" -version = "0.1.13" +version = "0.1.14" description = "High-performance safetensors model loader" authors = [{name = "Takeshi Yoshimura", email = "tyos@jp.ibm.com"}] maintainers = [{name = "Takeshi Yoshimura", email = "tyos@jp.ibm.com"}] @@ -10,11 +10,11 @@ keywords = ["fastsafetensors", "safetensors", "GDS"] requires-python = ">= 3.9" dependencies = [ "typer>=0.9.0", - "torch>=2.1", ] [project.optional-dependencies] test = [ + "torch>=2.5.1", "pytest>=8.1.1", "pytest-cov>=5.0.0", "transformers>=4.40.2", @@ -31,3 +31,6 @@ filterwarnings = ["ignore:Can't initialize NVML"] [build-system] requires = ["setuptools==69.5.1", "pybind11"] build-backend = "setuptools.build_meta" + +[tool.isort] +profile = "black" diff --git a/setup.py b/setup.py index 309807f..b052909 100644 --- a/setup.py +++ b/setup.py @@ -2,26 +2,37 @@ # SPDX-License-Identifier: Apache-2.0 import os -from setuptools import setup, Extension + +from setuptools import Extension, setup + def MyExtension(name, sources, mod_name, *args, **kwargs): import pybind11 + pybind11_path = os.path.dirname(pybind11.__file__) - kwargs['define_macros'] = [("__MOD_NAME__", mod_name)] - kwargs['libraries'] = ['stdc++'] - kwargs['include_dirs'] = kwargs.get('include_dirs', []) + [f"{pybind11_path}/include"] # for pybind11/pybind11.h - kwargs['language'] = 'c++' + kwargs["define_macros"] = [("__MOD_NAME__", mod_name)] + kwargs["libraries"] = ["stdc++"] + kwargs["include_dirs"] = kwargs.get("include_dirs", []) + [ + f"{pybind11_path}/include" + ] # for pybind11/pybind11.h + kwargs["language"] = "c++" # https://pybind11.readthedocs.io/en/stable/faq.html#someclass-declared-with-greater-visibility-than-the-type-of-its-field-someclass-member-wattributes - kwargs['extra_compile_args'] = ['-fvisibility=hidden', '-std=c++17'] + kwargs["extra_compile_args"] = ["-fvisibility=hidden", "-std=c++17"] return Extension(name, sources, *args, **kwargs) + setup( - packages=["fastsafetensors", "fastsafetensors.copier", "fastsafetensors.cpp"], + packages=[ + "fastsafetensors", + "fastsafetensors.copier", + "fastsafetensors.cpp", + "fastsafetensors.frameworks", + ], include_package_data=True, - package_data={"fastsafetensors.cpp": ["*.hpp"]}, + package_data={"fastsafetensors.cpp": ["*.hpp", "cpp.pyi"]}, ext_modules=[ MyExtension( name=f"fastsafetensors.cpp", diff --git a/tests/conftest.py b/tests/conftest.py index 7b38f55..5018399 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,12 @@ import os +from typing import List + import pytest -import torch -import torch.distributed as dist -from fastsafetensors import cpp as fstcpp + from fastsafetensors import SingleGroup -from fastsafetensors.common import paddle_loaded -from typing import List +from fastsafetensors import cpp as fstcpp +from fastsafetensors.frameworks import FrameworkOpBase, get_framework_op +from fastsafetensors.st_types import Device TESTS_DIR = os.path.dirname(__file__) REPO_ROOT = os.path.dirname(os.path.dirname(TESTS_DIR)) @@ -15,58 +16,68 @@ os.makedirs(TF_DIR, 0o777, True) os.makedirs(TMP_DIR, 0o777, True) -@pytest.fixture(scope='session', autouse=True) +FRAMEWORK = get_framework_op(os.getenv("TEST_FASTSAFETENSORS_FRAMEWORK", "please set")) + + +@pytest.fixture(scope="session", autouse=True) +def framework() -> FrameworkOpBase: + return FRAMEWORK + + +@pytest.fixture(scope="session", autouse=True) def input_files() -> List[str]: - os.environ["HF_HOME"] = TF_DIR - os.environ["HUGGINGFACE_HUB_CACHE"] = TF_DIR - from transformers import AutoModelForCausalLM, AutoTokenizer - AutoModelForCausalLM.from_pretrained("gpt2") - AutoTokenizer.from_pretrained("gpt2") + gpt_dir = os.path.join(TF_DIR, "models--gpt2") + if not os.path.exists(gpt_dir): + from transformers import AutoModelForCausalLM, AutoTokenizer + + AutoModelForCausalLM.from_pretrained( + "gpt2", trust_remote_code=True, use_safetensors=True, cache_dir=TF_DIR + ) + AutoTokenizer.from_pretrained("gpt2", cache_dir=TF_DIR) src_files = [] - for dir, _, files in os.walk(os.path.join(TF_DIR, "models--gpt2")): + for dir, _, files in os.walk(gpt_dir): for filename in files: if filename.endswith(".safetensors"): src_files.append(f"{dir}/{filename}") + print(src_files[-1]) return src_files -@pytest.fixture(scope='session', autouse=True) + +@pytest.fixture(scope="session", autouse=True) def pg(): - world_size = os.getenv("WORLD_SIZE") - if world_size is not None and int(world_size) > 1: - dist.init_process_group(backend="gloo") - dist.barrier() - PG = dist.group.WORLD - else: - PG = SingleGroup() - return PG - -@pytest.fixture(scope='session', autouse=True) -def pg_paddle(): - PG = SingleGroup() - - if paddle_loaded: - # The following code can only be successfully - # executed by running the code using - # `python -m paddle.distributed.launch` - try: - import paddle + world_size = int(os.getenv("WORLD_SIZE", "1")) + if world_size > 1: + if FRAMEWORK.get_name() == "pytorch": + import torch.distributed as dist + + dist.init_process_group(backend="gloo") + dist.barrier() + return dist.group.WORLD + elif FRAMEWORK.get_name() == "paddle": + # The following code can only be successfully + # executed by running the code using + # `python -m paddle.distributed.launch` import paddle.distributed as dist + dist.init_parallel_env() - backend = "nccl" if paddle.device.cuda.device_count() else "gloo" - PG = dist.new_group(ranks=[0,1], backend=backend) - except: - pass - return PG + return dist.new_group(ranks=list(range(world_size)), backend="gloo") + return SingleGroup() -@pytest.fixture(scope='session', autouse=True) + +@pytest.fixture(scope="session", autouse=True) def dev_init() -> None: - if torch.cuda.is_available(): - torch.cuda.set_device(0) + if fstcpp.is_cuda_found(): + dev_str = "cuda:0" if FRAMEWORK.get_name() == "pytorch" else "gpu:0" + else: + dev_str = "cpu" + FRAMEWORK.set_device(Device.from_str(dev_str)) -@pytest.fixture(scope='function') + +@pytest.fixture(scope="function") def fstcpp_log() -> None: fstcpp.set_debug_log(True) -@pytest.fixture(scope='function') + +@pytest.fixture(scope="function") def tmp_dir() -> str: - return TMP_DIR \ No newline at end of file + return TMP_DIR diff --git a/tests/run_distributed_paddle_test.py b/tests/run_distributed_paddle_test.py deleted file mode 100644 index ceea4e0..0000000 --- a/tests/run_distributed_paddle_test.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest -import sys -import os - -if __name__ == "__main__": - # There are 4 commands before this test - # GPU ditributed need at least 2 GPU - rank = int(os.getenv("PADDLE_TRAINER_ID")) + 4 - os.environ["COVERAGE_FILE"] = f".coverage_{rank}" - pytest_args = sys.argv[1:] - sys.exit(pytest.main(pytest_args)) \ No newline at end of file diff --git a/tests/sten-collection.json b/tests/sten-collection.json index 5859d6b..892457e 100644 --- a/tests/sten-collection.json +++ b/tests/sten-collection.json @@ -1,4 +1,18 @@ { + "gpt2": { + "world_size_1": { + "rank_0": [ + "/data/transformers_cache/models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/model.safetensors" + ] + } + }, + "openai-community/gpt2-medium": { + "world_size_1": { + "rank_0": [ + "/data/transformers_cache/models--openai-community--gpt2-medium/snapshots/6dcaa7a952f72f9298047fd5137cd6e4f05f41da/model.safetensors" + ] + } + }, "bigscience/bloom": { "world_size_8": { "rank_0": [ diff --git a/tests/test_fastsafetensors.py b/tests/test_fastsafetensors.py index c32e159..8ae32a1 100644 --- a/tests/test_fastsafetensors.py +++ b/tests/test_fastsafetensors.py @@ -2,37 +2,81 @@ # SPDX-License-Identifier: Apache-2.0 import os +from collections import OrderedDict +from typing import Any, Dict, List, Tuple + import pytest -import torch -from safetensors import safe_open -from safetensors.torch import save_file -from typing import Dict, Tuple -from fastsafetensors.dlpack import from_cuda_buffer -from fastsafetensors import SafeTensorsFileLoader, SingleGroup, SafeTensorsMetadata, fastsafe_open + +from fastsafetensors import SafeTensorsFileLoader, SafeTensorsMetadata, SingleGroup +from fastsafetensors import cpp as fstcpp +from fastsafetensors import fastsafe_open +from fastsafetensors.common import get_device_numa_node from fastsafetensors.copier.gds import GdsFileCopier from fastsafetensors.copier.nogds import NoGdsFileCopier -from fastsafetensors.common import alloc_tensor_memory, free_tensor_memory, need_workaround_dtypes, paddle_loaded -from fastsafetensors import cpp as fstcpp -if paddle_loaded: - import paddle - from safetensors.paddle import save_file as paddle_save_file - -def get_and_check_device(framework="pytorch"): - dev_is_gpu = fstcpp.is_cuda_found() - if framework == "pytorch" or framework == "pt": - device = torch.device("cuda:0" if dev_is_gpu else "cpu") - elif paddle_loaded and framework == "paddle": - device = "gpu:0" if dev_is_gpu else "cpu" +from fastsafetensors.dlpack import from_cuda_buffer +from fastsafetensors.frameworks import FrameworkOpBase +from fastsafetensors.st_types import Device, DeviceType, DType + + +def load_safetensors_file( + filename: str, + device: Device, + framework: FrameworkOpBase, + to_dtype: DType = DType.AUTO, +) -> Dict[str, Any]: + if framework.get_name() == "pytorch": + from safetensors.torch import load_file + elif framework.get_name() == "paddle": + from safetensors.paddle import load_file else: - raise NotImplementedError(f"Do not support framework: {framework}") - return device, dev_is_gpu + raise Exception(f"unkown framework: {framework.get_name()}") + d = load_file(filename, device.as_str()) + if to_dtype != DType.AUTO: + if framework.get_name() == "pytorch": + from fastsafetensors.frameworks._torch import dtype_convert + elif framework.get_name() == "paddle": + from fastsafetensors.frameworks._paddle import dtype_convert + + for k, t in d.items(): + d[k] = t.to(dtype=dtype_convert[to_dtype]) + return d + + +def save_safetensors_file( + tensors: Dict[str, Any], + filename: str, + metadata: Dict[str, str], + framework: FrameworkOpBase, +) -> None: + if framework.get_name() == "pytorch": + from safetensors.torch import save_file + elif framework.get_name() == "paddle": + from safetensors.paddle import save_file + else: + raise Exception(f"unkown framework: {framework.get_name()}") + save_file(tensors, filename, metadata) + -def run_nogds_file_read(input_file: str, framework="pytorch")->Tuple[SafeTensorsMetadata, fstcpp.gds_device_buffer]: +def get_and_check_device(framework: FrameworkOpBase): + dev_is_gpu = fstcpp.is_cuda_found() + device = "cpu" + if dev_is_gpu: + if framework.get_name() == "pytorch": + device = "cuda:0" + elif framework.get_name() == "paddle": + device = "gpu:0" + return Device.from_str(device), dev_is_gpu + + +def run_nogds_file_read( + input_file: str, + framework: FrameworkOpBase, +) -> Tuple[SafeTensorsMetadata, fstcpp.gds_device_buffer]: fd = os.open(input_file, os.O_RDONLY, 0o644) - meta = SafeTensorsMetadata.from_file(input_file, framework=framework) + meta = SafeTensorsMetadata.from_file(input_file, framework) size = meta.size_bytes - meta.header_length device, dev_is_gpu = get_and_check_device(framework) - gbuf = alloc_tensor_memory(size, device, framework=framework) + gbuf = framework.alloc_tensor_memory(size, device) reader = fstcpp.nogds_file_reader(False, 20 * 1024, 1, dev_is_gpu) req = reader.submit_read(fd, gbuf, meta.header_length, size, 0) assert req > 0 @@ -40,124 +84,192 @@ def run_nogds_file_read(input_file: str, framework="pytorch")->Tuple[SafeTensors os.close(fd) return (meta, gbuf) -def test_load_metadata_and_dlpack(fstcpp_log, input_files, framework="pytorch"): + +def test_device(fstcpp_log) -> None: + print("test_device") + with pytest.raises(ValueError, match="Unknown device type: aaaa"): + Device.from_str("aaaa:0") + with pytest.raises(ValueError, match="Invalid index: -xxx"): + Device.from_str("cpu:-xxx") + with pytest.raises(ValueError, match="Unknown device type: aaa"): + Device.from_str("aaa") + cuda = Device.from_str("cuda:4") + assert cuda.type == DeviceType.CUDA + assert cuda.index == 4 + cpu = Device(DeviceType.CPU, None) + assert cpu.type == DeviceType.CPU and cpu.index == None + + +def test_framework(fstcpp_log, framework) -> None: + print("test_framework") + t = framework.get_empty_tensor([1], DType.F16, Device.from_str("cpu")) + with pytest.raises(Exception): + framework.is_equal(t, [float(0.0)]) + with pytest.raises(Exception): + framework.get_process_group(int(0)) + if framework.get_name() == "pytorch": + import torch + + cuda_ver = str(torch.version.cuda) if torch.cuda.is_available() else "0.0" + elif framework.get_name() == "paddle": + import paddle + + if paddle.device.is_compiled_with_cuda(): + cuda_ver = str(paddle.version.cuda()) + else: + cuda_ver = "0.0" + assert framework.get_cuda_ver() == cuda_ver + + +def make_header_bytes(s: str): + header = s.encode("utf-8") + n = len(header) + return n.to_bytes(8, byteorder="little", signed=False) + header + + +def test_from_buffer_header_too_small(framework): + with pytest.raises(Exception, match="HeaderTooSmall"): + SafeTensorsMetadata.from_buffer( + buf=0, buffer_len=4, filename="testfile", framework=framework + ) + + +def test_from_buffer_header_too_large(monkeypatch, framework): + def fake_read_buffer(buf, size): + return (100_000_001).to_bytes(8, "little") + + monkeypatch.setattr(fstcpp, "read_buffer", fake_read_buffer) + + with pytest.raises(Exception, match="HeaderTooLarge"): + SafeTensorsMetadata.from_buffer( + buf=0, buffer_len=1024, filename="testfile", framework=framework + ) + + +def test_from_buffer_invalid_header_length(monkeypatch, framework): + def fake_read_buffer(buf, size): + return (100).to_bytes(8, "little") + + monkeypatch.setattr(fstcpp, "read_buffer", fake_read_buffer) + + with pytest.raises(Exception, match="InvalidHeaderLength"): + SafeTensorsMetadata.from_buffer( + buf=0, buffer_len=50, filename="testfile", framework=framework + ) + + +def test_from_buffer_success(monkeypatch, framework): + json_str = '{"__metadata__": {"data_offsets": [0, 123]}}' + header_bytes = make_header_bytes(json_str) + buf_data = header_bytes + + def fake_read_buffer(buf, size): + return buf_data[buf : buf + size] + + monkeypatch.setattr(fstcpp, "read_buffer", fake_read_buffer) + + meta = SafeTensorsMetadata.from_buffer( + buf=0, buffer_len=len(buf_data), filename="goodfile", framework=framework + ) + assert isinstance(meta, SafeTensorsMetadata) + + +def test_load_metadata_and_dlpack(fstcpp_log, input_files, framework) -> None: print("test_load_metadata_and_dlpack") assert len(input_files) > 0 device, _ = get_and_check_device(framework) for input_file in input_files: - expected_tensors: Dict[str, torch.Tensor] = {} - with safe_open(input_file, framework="pt") as f: - for k in f.keys(): - expected_tensors[k] = f.get_tensor(k) - if framework == "pytorch": - expected_tensors[k] = expected_tensors[k].to(device=device) - elif framework == "paddle": - expected_tensors[k] = paddle.to_tensor(expected_tensors[k].numpy(), place=device) - meta, gbuf = run_nogds_file_read(input_file, framework=framework) + expected_tensors = load_safetensors_file(input_files[0], device, framework) + meta, gbuf = run_nogds_file_read(input_file, framework) assert meta.header_length > 0 assert meta.size_bytes > 0 assert len(meta.tensors) > 0 printed = False - for name, actual_meta in sorted(meta.tensors.items(), key=lambda x:x[0]): + for name, actual_meta in sorted(meta.tensors.items(), key=lambda x: x[0]): dst_dev_ptr = gbuf.get_base_address() + actual_meta.data_offsets[0] - if actual_meta.dtype in need_workaround_dtypes: - wdtype = need_workaround_dtypes[actual_meta.dtype] - cu_buf = from_cuda_buffer(dst_dev_ptr, actual_meta.shape, actual_meta.strides, wdtype, device) - if framework == "pytorch": - actual = torch.from_dlpack(cu_buf).view(actual_meta.dtype) - elif framework == "paddle": - actual = paddle.utils.dlpack.from_dlpack(cu_buf).view(actual_meta.dtype) - else: - cu_buf = from_cuda_buffer(dst_dev_ptr, actual_meta.shape, actual_meta.strides, actual_meta.dtype, device) - if framework == "pytorch": - actual = torch.from_dlpack(cu_buf) - elif framework == "paddle": - actual = paddle.utils.dlpack.from_dlpack(cu_buf) + wdtype = framework.as_workaround_dtype(actual_meta.dtype) + cu_buf = from_cuda_buffer( + dst_dev_ptr, actual_meta.shape, actual_meta.strides, wdtype, device + ) + actual = framework.from_dlpack(cu_buf, device, wdtype) + if wdtype != actual_meta.dtype: + actual = actual.view(actual_meta.dtype) exp = expected_tensors[name] - if framework == "pytorch": - assert torch.all(exp.eq(actual)) - elif framework == "paddle": - assert paddle.all(exp.equal(actual)) + assert framework.is_equal(actual, exp) if not printed: print(actual_meta.__repr__()) printed = True -def test_load_metadata_and_dlpack_for_paddle(fstcpp_log, input_files): - if paddle_loaded: - test_load_metadata_and_dlpack(fstcpp_log, input_files, "paddle") -def test_set_debug_log(): +def test_set_debug_log() -> None: fstcpp.set_debug_log(False) assert True -def test_get_alignment_size(): + +def test_get_alignment_size() -> None: assert fstcpp.get_alignment_size() == 4096 -def test_init_gds(fstcpp_log): - assert fstcpp.init_gds(16 * 1024, 64 * 1024 * 1024) == 0 -def test_close_gds(fstcpp_log): +def test_init_gds(fstcpp_log) -> None: + assert fstcpp.init_gds() == 0 + + +def test_close_gds(fstcpp_log) -> None: assert fstcpp.close_gds() == 0 -def test_get_device_pci_bus(fstcpp_log): + +def test_get_device_pci_bus(fstcpp_log) -> None: bus = fstcpp.get_device_pci_bus(0) if not fstcpp.is_cuda_found(): - assert bus == "" + assert bus == "" else: print(f"bus for cuda:0: {bus}") assert len(bus) > 0 -def test_set_numa_node(fstcpp_log): + +def test_set_numa_node(fstcpp_log) -> None: assert fstcpp.set_numa_node(0) == 0 -def test_alloc_gds_buffer(fstcpp_log, framework="pytorch"): + +def test_alloc_gds_buffer(fstcpp_log, framework) -> None: print("test_alloc_gds_buffer") device, _ = get_and_check_device(framework) - gbuf = alloc_tensor_memory(1024, device, framework=framework) + gbuf = framework.alloc_tensor_memory(1024, device) addr = gbuf.get_base_address() assert addr != 0 -def test_alloc_gds_buffer_for_paddle(fstcpp_log): - if paddle_loaded: - test_alloc_gds_buffer(fstcpp_log, "paddle") -def test_cufile_register_deregister(fstcpp_log, framework="pytorch"): +def test_cufile_register_deregister(fstcpp_log, framework) -> None: print("test_cufile_register_deregister") device, _ = get_and_check_device(framework) - gbuf = alloc_tensor_memory(1024, device, framework=framework) + gbuf = framework.alloc_tensor_memory(1024, device) assert gbuf.cufile_register(0, 256) == 0 - assert gbuf.cufile_register(256, 1024-256) == 0 + assert gbuf.cufile_register(256, 1024 - 256) == 0 assert gbuf.cufile_deregister(0) == 0 assert gbuf.cufile_deregister(256) == 0 -def test_cufile_register_deregister_for_paddle(fstcpp_log): - if paddle_loaded: - test_alloc_gds_buffer(fstcpp_log, "paddle") -def test_memmove(fstcpp_log , framework="pytorch"): +def test_memmove(fstcpp_log, framework) -> None: print("test_memmove") device, _ = get_and_check_device(framework) - gbuf = alloc_tensor_memory(1024, device, framework=framework) - tmp = alloc_tensor_memory(1024, device, framework=framework) - assert gbuf.memmove(0, 12, tmp, 256*3) == 0 + gbuf = framework.alloc_tensor_memory(1024, device) + tmp = framework.alloc_tensor_memory(1024, device) + assert gbuf.memmove(0, 12, tmp, 256 * 3) == 0 # Confuse about this test : gbuf.memmove(0, 12, tmp, 1024) - # I think this test should start copying a section of 1024 memory - # from the position of gbuf+12 to the position of gbuf+0. - # However, this piece of memory itself is only 1024. - # After offsetting by 12, there is no 1024 left in the remaining memory. + # I think this test should start copying a section of 1024 memory + # from the position of gbuf+12 to the position of gbuf+0. + # However, this piece of memory itself is only 1024. + # After offsetting by 12, there is no 1024 left in the remaining memory. # This part really puzzles me. So I change the moving size to 256*3 (<1024) -def test_memmove_for_paddle(fstcpp_log): - if paddle_loaded: - test_memmove(fstcpp_log, "paddle") -def test_nogds_file_reader(fstcpp_log, input_files, framework="pytorch"): +def test_nogds_file_reader(fstcpp_log, input_files, framework) -> None: print("test_nogds_file_reader") fd = os.open(input_files[0], os.O_RDONLY, 0o644) s = os.fstat(fd) assert fd > 0 device, dev_is_gpu = get_and_check_device(framework) - gbuf = alloc_tensor_memory(s.st_size, device, framework=framework) + gbuf = framework.alloc_tensor_memory(s.st_size, device) reader = fstcpp.nogds_file_reader(False, 256 * 1024, 4, dev_is_gpu) step = s.st_size // 4 reqs = [] @@ -176,170 +288,199 @@ def test_nogds_file_reader(fstcpp_log, input_files, framework="pytorch"): assert reader.wait_read(req) > 0 os.close(fd) -def test_nogds_file_reader_for_paddle(fstcpp_log, input_files): - if paddle_loaded: - test_nogds_file_reader(fstcpp_log, input_files, "paddle") -def test_NoGdsFileCopier(fstcpp_log, input_files, framework="pytorch"): +def test_NoGdsFileCopier(fstcpp_log, input_files, framework) -> None: print("test_NoGdsFileCopier") meta = SafeTensorsMetadata.from_file(input_files[0], framework) device, dev_is_gpu = get_and_check_device(framework) reader = fstcpp.nogds_file_reader(False, 256 * 1024, 4, dev_is_gpu) - copier = NoGdsFileCopier(meta, device, reader, True) + copier = NoGdsFileCopier(meta, device, reader, framework, True) gbuf = copier.submit_io(False, 10 * 1024 * 1024 * 1024) - tensors = copier.wait_io(gbuf, None) - with safe_open(input_files[0], framework="pt") as f: - for key in tensors.keys(): - if framework == "pytorch": - assert torch.all(f.get_tensor(key).to(device=device).eq(tensors[key])) - elif framework == "paddle": - assert paddle.all(paddle.to_tensor(f.get_tensor(key).numpy(), place=device).equal(tensors[key])) - free_tensor_memory(gbuf, device, framework) - -def test_NoGdsFileCopier_for_paddle(fstcpp_log, input_files): - if paddle_loaded: - test_NoGdsFileCopier(fstcpp_log, input_files,"paddle") - -def test_GdsFileCopier(fstcpp_log, input_files, framework="pytorch"): + tensors = copier.wait_io(gbuf) + for key, exp in load_safetensors_file(input_files[0], device, framework).items(): + actual = tensors[key] + assert framework.is_equal(actual, exp) + framework.free_tensor_memory(gbuf, device) + + +def test_GdsFileCopier(fstcpp_log, input_files, framework) -> None: print("test_GdsFileCopier") - if not fstcpp.is_cufile_found(): - pytest.skip("cufile.so is not found") - return - meta = SafeTensorsMetadata.from_file(input_files[0], framework=framework) + meta = SafeTensorsMetadata.from_file(input_files[0], framework) device, dev_is_gpu = get_and_check_device(framework) reader = fstcpp.gds_file_reader(4, dev_is_gpu) - copier = GdsFileCopier(meta, device, reader, True) + copier = GdsFileCopier(meta, device, reader, framework, True) gbuf = copier.submit_io(False, 10 * 1024 * 1024 * 1024) - tensors = copier.wait_io(gbuf, None) - with safe_open(input_files[0], framework="pt") as f: - for key in tensors.keys(): - if framework == "torch": - assert torch.all(f.get_tensor(key).to(device=device).eq(tensors[key])) - elif framework == "paddle": - assert paddle.all(paddle.to_tensor(f.get_tensor(key).numpy(), place=device).equal(tensors[key])) - free_tensor_memory(gbuf, device, framework=framework) - -def test_GdsFileCopier_for_paddle(fstcpp_log, input_files): - if paddle_loaded: - test_GdsFileCopier(fstcpp_log, input_files, "paddle") - -def test_SafeTensorsFileLoader(fstcpp_log, input_files, framework="pytorch"): + tensors = copier.wait_io(gbuf) + for key, exp in load_safetensors_file(input_files[0], device, framework).items(): + actual = tensors[key] + assert framework.is_equal(actual, exp) + framework.free_tensor_memory(gbuf, device) + + +def test_SafeTensorsFileLoader(fstcpp_log, input_files, framework) -> None: device, _ = get_and_check_device(framework) - if framework == "pytorch": - data_type = torch.float16 - elif framework == "paddle": + if framework.get_name() == "pytorch": + import torch + + data_type = DType.F16 + data_type_real = torch.float16 + elif framework.get_name() == "paddle": # There are some lack of accuracy in paddle.float16 (about 1e-4) in cpu. - data_type = paddle.float32 + import paddle + + data_type = DType.F32 + data_type_real = paddle.float32 else: - raise NotImplementedError(f"Do not support the framework: {framework}") - loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=False, debug_log=True, framework=framework) + raise NotImplementedError( + f"Do not support the framework: {framework.get_name()}" + ) + loader = SafeTensorsFileLoader( + pg=SingleGroup(), + device=device.as_str(), + framework=framework.get_name(), + nogds=False, + debug_log=True, + ) loader.add_filenames({0: input_files}) - bufs = loader.copy_files_to_device(dtype=data_type, use_buf_register=True, max_copy_block_size=256*1024*1024) - key_dims = {key: -1 for key in loader.get_keys()} - tensors = bufs.as_dict(key_dims) + bufs = loader.copy_files_to_device( + dtype=data_type, use_buf_register=True, max_copy_block_size=256 * 1024 * 1024 + ) last_key = "" - last_shape: torch.Size = None - with safe_open(input_files[0], framework="pt") as f: - for key in tensors.keys(): - if framework == "pytorch": - exp = f.get_tensor(key).to(device=device, dtype=data_type) - assert torch.all(exp.eq(bufs.get_tensor(key))) - elif framework == "paddle": - exp = paddle.to_tensor(f.get_tensor(key).numpy(), place=device, dtype=data_type) - assert paddle.all(exp.equal(bufs.get_tensor(key))) - last_key = key - last_shape = exp.shape + last_shape: List[int] = [] + for key, exp in load_safetensors_file(input_files[0], device, framework).items(): + exp = exp.to(dtype=data_type_real) + actual = bufs.get_tensor_wrapped(key) + assert framework.is_equal(actual, exp) + last_key = key + last_shape = list(exp.shape) if last_key != "": assert bufs.get_filename(last_key) == input_files[0] assert bufs.get_shape(last_key) == last_shape assert loader.get_shape(last_key) == last_shape - assert bufs.get_filename("aaaaaaaaaaaaa") == None + assert bufs.get_filename("aaaaaaaaaaaaa") == "" bufs.close() loader.close() -def test_SafeTensorsFileLoader_for_paddle(fstcpp_log, input_files): - if paddle_loaded: - test_SafeTensorsFileLoader(fstcpp_log, input_files,"paddle") -def test_SafeTensorsFileLoaderNoGds(fstcpp_log, input_files, framework="pytorch"): +def test_SafeTensorsFileLoaderNoGds(fstcpp_log, input_files, framework) -> None: device, _ = get_and_check_device(framework) - loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=True, debug_log=True, framework=framework) + loader = SafeTensorsFileLoader( + pg=SingleGroup(), + device=device.as_str(), + framework=framework.get_name(), + nogds=True, + debug_log=True, + ) loader.add_filenames({0: input_files}) bufs = loader.copy_files_to_device() - key_dims = {key: -1 for key in loader.get_keys()} + key_dims = OrderedDict({key: -1 for key in loader.get_keys()}) tensors = bufs.as_dict(key_dims) - with safe_open(input_files[0], framework="pt") as f: - for key in tensors.keys(): - if framework == "pytorch": - assert torch.all(f.get_tensor(key).to(device=device).eq(tensors[key])) - elif framework == "paddle": - assert paddle.all(paddle.to_tensor(f.get_tensor(key).numpy(), place=device).equal(tensors[key])) + for key, exp in load_safetensors_file(input_files[0], device, framework).items(): + actual = tensors[key] + assert framework.is_equal(actual, exp) bufs.close() loader.close() -def test_SafeTensorsFileLoaderNoGds_for_paddle(fstcpp_log, input_files): - if paddle_loaded: - test_SafeTensorsFileLoaderNoGds(fstcpp_log, input_files, "paddle") -def test_fastsafe_open(fstcpp_log, input_files, framework="pt"): +def test_fastsafe_open(fstcpp_log, input_files, framework) -> None: device, _ = get_and_check_device(framework) + def weight_iterator(): - with fastsafe_open(input_files, pg=SingleGroup(), device=device, nogds=True, debug_log=True, framework=framework) as f: - for k in f.get_keys(): - t = f.get_tensor(k) + with fastsafe_open( + input_files, + device=device.as_str(), + nogds=True, + debug_log=True, + framework=framework.get_name(), + ) as f: + for k in f.keys(): + t = f.get_tensor_wrapped(k) yield k, t - tensors = {} - with safe_open(input_files[0], framework="pt") as f: - for key in f.keys(): - if framework == "pt": - tensors[key] = f.get_tensor(key).to(device=device) - elif framework == "paddle": - tensors[key] = paddle.to_tensor(f.get_tensor(key).numpy(), place=device) - for k, t in weight_iterator(): - if framework == "pt": - assert torch.all(tensors[k].eq(t)) - elif framework == "paddle": - assert paddle.all(tensors[k].equal(t)) - -def test_fastsafe_open_for_paddle(fstcpp_log, input_files): - if paddle_loaded: - test_fastsafe_open(fstcpp_log, input_files, "paddle") -def _test_type(tmp_dir, dtype, device): + tensors = load_safetensors_file(input_files[0], device, framework) + for k, t in weight_iterator(): + assert framework.is_equal(t, tensors[k]) + + with fastsafe_open( + input_files[0], + device=device.as_str(), + nogds=True, + framework=framework.get_name(), + ) as f: + for filename in f.metadata().keys(): + assert filename in input_files + + with fastsafe_open( + {0: input_files}, + device=device.as_str(), + nogds=True, + framework=framework.get_name(), + ) as f: + for k in f.keys(): + t = f.get_tensor(k) + if framework.get_name() == "pytorch": + import torch + + assert isinstance(t, torch.Tensor) + elif framework.get_name() == "paddle": + import paddle + + assert isinstance(t, paddle.Tensor) + break + + +def _test_type( + tmp_dir, + dtype: DType, + device: Device, + framework: FrameworkOpBase, + to_dtype: DType = DType.AUTO, +) -> None: filename = os.path.join(tmp_dir, f"a.safetensors") - t0 = torch.randn((8, 16), dtype=torch.float32).to(dtype=dtype) - save_file({f"a": t0}, filename, metadata={"fst": "sample"}) - with fastsafe_open(filenames=[filename], nogds=True, device=device, debug_log=True) as f: - for key in f.get_keys(): - t1 = f.get_tensor(key).clone().detach() - with safe_open(filename, framework='pt', device=device) as f: + t0 = framework.randn((8, 16), device=device, dtype=DType.F32).to(dtype=dtype) + if to_dtype is not DType.AUTO: + t0 = t0.to(dtype=to_dtype) + save_safetensors_file({f"a": t0.get_raw()}, filename, {"fst": "sample"}, framework) + t2 = load_safetensors_file(filename, device, framework, to_dtype=to_dtype) + with fastsafe_open( + filenames=[filename], + nogds=True, + device=device.as_str(), + framework=framework.get_name(), + debug_log=True, + ) as f: for key in f.keys(): - t2 = f.get_tensor(key) - assert torch.all(t2.eq(t1)) + t1 = f.get_tensor_wrapped(key).clone().detach() + assert framework.is_equal(t1, t2[key]) -def _test_type_for_paddle(tmp_dir, dtype, device): - filename = os.path.join(tmp_dir, f"a.safetensors") - t0 = paddle.randn((8, 16), dtype=paddle.float32).to(dtype=dtype) - paddle_save_file({f"a": t0}, filename, metadata={"fst": "sample"}) - with fastsafe_open(filenames=[filename], nogds=True, device=device, debug_log=True, framework="paddle") as f: - for key in f.get_keys(): - t1 = f.get_tensor(key).clone().detach() - with safe_open(filename, framework='pt') as f: - for key in f.keys(): - t2 = paddle.to_tensor(f.get_tensor(key).numpy(), place=device) - assert paddle.all(t2.equal(t1)) - -def test_int8(fstcpp_log, tmp_dir): - _test_type(tmp_dir, torch.int8, "cuda:0" if fstcpp.is_cuda_found() else "cpu") - if paddle_loaded: - _test_type_for_paddle(tmp_dir, paddle.int8, "gpu:0" if fstcpp.is_cuda_found() else "cpu") - -def test_float8_e5m2(fstcpp_log, tmp_dir): - _test_type(tmp_dir, torch.float8_e5m2, "cuda:0" if fstcpp.is_cuda_found() else "cpu") - if paddle_loaded: - _test_type_for_paddle(tmp_dir, paddle.float8_e5m2, "gpu:0" if fstcpp.is_cuda_found() else "cpu") - -def test_float8_e4m3fn(fstcpp_log, tmp_dir): - _test_type(tmp_dir, torch.float8_e4m3fn, "cuda:0" if fstcpp.is_cuda_found() else "cpu") - if paddle_loaded: - _test_type_for_paddle(tmp_dir, paddle.float8_e4m3fn, "gpu:0" if fstcpp.is_cuda_found() else "cpu") + +def test_int8(fstcpp_log, tmp_dir, framework) -> None: + if not framework.support_fp8(): + pytest.skip("FP8 is not supported") + return + device, _ = get_and_check_device(framework) + _test_type(tmp_dir, DType.I8, device, framework) + + +def test_float8_e5m2(fstcpp_log, tmp_dir, framework) -> None: + if not framework.support_fp8(): + pytest.skip("FP8 is not supported") + return + device, _ = get_and_check_device(framework) + _test_type(tmp_dir, DType.F8_E5M2, device, framework) + + +def test_float8_e4m3fn(fstcpp_log, tmp_dir, framework) -> None: + if not framework.support_fp8(): + pytest.skip("FP8 is not supported") + return + device, _ = get_and_check_device(framework) + _test_type(tmp_dir, DType.F8_E4M3, device, framework) + + +def test_float8_e4m3fn_to_int8(fstcpp_log, tmp_dir, framework) -> None: + if not framework.support_fp8(): + pytest.skip("FP8 is not supported") + return + device, _ = get_and_check_device(framework) + _test_type(tmp_dir, DType.F8_E4M3, device, framework, DType.I8) diff --git a/tests/test_multi.py b/tests/test_multi.py index 94c5f44..aaf304d 100644 --- a/tests/test_multi.py +++ b/tests/test_multi.py @@ -2,42 +2,121 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -import torch -from safetensors import safe_open +from fastsafetensors import SafeTensorsFileLoader from fastsafetensors import cpp as fstcpp -from fastsafetensors import SafeTensorsFileLoader, SingleGroup, SafeTensorsMetadata -def test_shuffle(fstcpp_log, input_files, pg): + +def test_shuffle(fstcpp_log, input_files, pg, framework): print("test_shuffle") - device = torch.device(f"cuda:0" if fstcpp.is_cuda_found() else "cpu") - loader = SafeTensorsFileLoader(pg, device, nogds=True, debug_log=True) + if framework.get_name() == "pytorch": + from safetensors.torch import load_file + + rank = pg.rank() + world_size = pg.size() + device = "cuda:0" if fstcpp.is_cuda_found() else "cpu" + elif framework.get_name() == "paddle": + from safetensors.paddle import load_file + + rank = pg.process_group.rank() + world_size = pg.process_group.size() + device = "gpu:0" if fstcpp.is_cuda_found() else "cpu" + else: + raise Exception(f"Unknown framework: {framework.get_name()}") + loader = SafeTensorsFileLoader( + pg=pg, device=device, nogds=True, framework=framework.get_name(), debug_log=True + ) loader.add_filenames({0: input_files}) bufs = loader.copy_files_to_device() key_dims = {key: -1 for key in loader.get_keys()} for i in range(0, 12): - key_dims[f"h.{i}.mlp.c_proj.weight"] = 0 - key_dims[f"h.{i}.mlp.c_fc.weight"] = 1 + key_dims[f"h.{i}.mlp.c_proj.weight"] = 1 + key_dims[f"h.{i}.mlp.c_fc.weight"] = 0 + key_dims[f"h.{i}.attn.c_proj.weight"] = 0 tensors = bufs.as_dict(key_dims) - with safe_open(input_files[0], framework="pt") as f: - for key in tensors.keys(): - dim = key_dims[key] - if dim == 0 or dim == 1: - t = f.get_slice(key) - rank_slices = () - shape = t.get_shape() - size = shape[dim] - block_size = (size + pg.size() - 1) // pg.size() - for i in range(0, len(shape)): - if i < dim: - rank_slices += (slice(None,None,None),) - elif i == dim: - rank_slices += (slice(pg.rank() * block_size, (pg.rank() + 1) * block_size, 1),) - break - t = t[rank_slices] - t = t.clone().detach() - else: - t = f.get_tensor(key) - assert torch.all(t.to(device=device).eq(tensors[key])) + f = load_file(input_files[0]) + origs = {} + for key in tensors.keys(): + dim = key_dims[key] + if dim == 0 or dim == 1: + t = f[key] + rank_slices = () + shape = t.shape + size = shape[dim] + block_size = (size + world_size - 1) // world_size + for i in range(0, len(shape)): + if i < dim: + rank_slices += (slice(None, None, None),) + elif i == dim: + rank_slices += ( + slice(rank * block_size, (rank + 1) * block_size, 1), + ) + break + t = t[rank_slices] + t = t.clone().detach() + else: + t = f[key] + t = t.to(device=device) + origs[key] = t + assert framework.is_equal(tensors[key], t) + + bufs.close() + loader.reset() + + loader.add_filenames({0: input_files}) + bufs = loader.copy_files_to_device() + if world_size > 1: + pushed = bufs.push_tensor("h.0.attn.c_proj.bias", 1) + assert pushed is not None if rank == 1 else pushed is None + pushed1 = bufs.push_tensor("h.1.attn.c_proj.bias", 0) + assert pushed1 is not None if rank == 0 else pushed1 is None + + tensors2 = bufs.get_tensor_wrapped("h.0.attn.c_proj.bias") # cached load + assert ( + framework.is_equal(tensors2, origs["h.0.attn.c_proj.bias"]) + if rank == 1 + else pushed is None + ) + + actual = bufs.get_multi_cols( + ["h.0.attn.c_proj.weight", "h.1.attn.c_proj.weight"], dim=0 + ) + if framework.get_name() == "pytorch": + import torch + + exp = torch.concat( + [origs["h.0.attn.c_proj.weight"], origs["h.1.attn.c_proj.weight"]], dim=0 + ) + assert framework.is_equal(actual, exp) + elif framework.get_name() == "paddle": + import paddle + + exp = paddle.concat( + [origs["h.0.attn.c_proj.weight"], origs["h.1.attn.c_proj.weight"]], axis=0 + ) + assert framework.is_equal(actual, exp) + + actual = bufs.get_multi_cols( + ["h.0.mlp.c_proj.weight", "h.1.mlp.c_proj.weight"], dim=1 + ) + if framework.get_name() == "pytorch": + exp = torch.concat( + [origs["h.0.mlp.c_proj.weight"], origs["h.1.mlp.c_proj.weight"]], dim=1 + ) + assert framework.is_equal(actual, exp) + elif framework.get_name() == "paddle": + exp = paddle.concat( + [origs["h.0.mlp.c_proj.weight"], origs["h.1.mlp.c_proj.weight"]], axis=1 + ) + assert framework.is_equal(actual, exp) + bufs.close() loader.close() + + +if __name__ == "__main__": + import os + import sys + + os.environ["PADDLE_DISTRI_BACKEND"] = "gloo" + sys.exit(pytest.main(sys.argv[1:])) diff --git a/tests/test_multi_paddle.py b/tests/test_multi_paddle.py deleted file mode 100644 index 6c8fda6..0000000 --- a/tests/test_multi_paddle.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024- IBM Inc. All rights reserved -# SPDX-License-Identifier: Apache-2.0 - -import pytest -import torch -import paddle -from safetensors import safe_open - -from fastsafetensors import cpp as fstcpp -from fastsafetensors import SafeTensorsFileLoader, SingleGroup, SafeTensorsMetadata -from fastsafetensors.common import paddle_loaded - -def test_shuffle_paddle(fstcpp_log, input_files, pg_paddle): - if paddle_loaded: - device = "gpu" if paddle.device.cuda.device_count() else "cpu" - loader = SafeTensorsFileLoader(pg_paddle, device, nogds=True, debug_log=True, framework="paddle") - loader.add_filenames({0: input_files}) - bufs = loader.copy_files_to_device() - key_dims = {key: -1 for key in loader.get_keys()} - for i in range(0, 12): - key_dims[f"h.{i}.mlp.c_proj.weight"] = 0 - key_dims[f"h.{i}.mlp.c_fc.weight"] = 1 - tensors = bufs.as_dict(key_dims) - with safe_open(input_files[0], framework="pt") as f: - for key in tensors.keys(): - dim = key_dims[key] - if dim == 0 or dim == 1: - t = f.get_slice(key) - rank_slices = () - shape = t.get_shape() - size = shape[dim] - block_size = (size + pg_paddle.process_group.size() - 1) // pg_paddle.process_group.size() - for i in range(0, len(shape)): - if i < dim: - rank_slices += (slice(None,None,None),) - elif i == dim: - rank_slices += (slice(pg_paddle.process_group.rank() * block_size, (pg_paddle.process_group.rank() + 1) * block_size, 1),) - break - t = t[rank_slices] - t = t.clone().detach() - else: - t = f.get_tensor(key) - assert paddle.all(paddle.to_tensor(t.numpy(),place=loader.device).equal(tensors[key])) - bufs.close() - loader.close() diff --git a/tests/test_vllm.py b/tests/test_vllm.py index dd34a98..b586295 100644 --- a/tests/test_vllm.py +++ b/tests/test_vllm.py @@ -5,11 +5,22 @@ import vllm from vllm.config import LoadFormat + def test_vllm_no_fastsafetensors(fstcpp_log): - _ = vllm.LLM(model="ibm-granite/granite-3.0-8b-instruct") + _ = vllm.LLM(model="ibm-granite/granite-3.3-2b-instruct", max_model_len=16) + def test_vllm_fastsafetensors(fstcpp_log): - _ = vllm.LLM(model="ibm-granite/granite-3.0-8b-instruct", load_format=LoadFormat.FASTSAFETENSORS) + _ = vllm.LLM( + model="ibm-granite/granite-3.3-2b-instruct", + load_format=LoadFormat.FASTSAFETENSORS, + max_model_len=16, + ) + def test_deepseek_r1(fstcpp_log): - _ = vllm.LLM(model="silence09/DeepSeek-R1-Small-2layers", load_format=LoadFormat.FASTSAFETENSORS) \ No newline at end of file + _ = vllm.LLM( + model="silence09/DeepSeek-R1-Small-2layers", + load_format=LoadFormat.FASTSAFETENSORS, + max_model_len=16, + )