
Conversation

@tjtanaa (Contributor) commented Oct 22, 2025

Motivation

Loading large model weights has been a bottleneck in development. Moreover, more and more models do not have smaller counterparts suitable for rapid prototyping and development.
The cold start time during deployment is also a major downside.

We were lucky to find this hidden gem, which also has great integration with vLLM.

Purpose

Add support for ROCm. Since there is currently no alternative to NVIDIA GDS, this enablement is without GDS support.

According to the news on the ROCm 7.9 Tech Preview, there is a new library called hipFile. We hope to enable GDS support when it is released.

Performance

When loading DeepSeek-R1 weights with TP8, we saw a whopping 7.4x improvement compared to vLLM's default safetensors loading from NVMe.

[Image: benchmark chart comparing DeepSeek-R1 TP8 weight-loading time, fastsafetensors vs. safetensors from NVMe]

NOTE: By default, vLLM tries to load fastsafetensors with gds=True, and if gds=True is not supported it falls back to gds=False. There is no need to update any vLLM code; it just works after installing this fastsafetensors ROCm support!
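
For intuition, that fallback is just an exception-guarded retry. A minimal sketch of the pattern, with hypothetical helper names standing in for the real vLLM/fastsafetensors entry points:

    def _load_with_gds(path: str):
        # Hypothetical stand-in for the GDS-backed loader; on platforms
        # without GPUDirect Storage (e.g., current ROCm) this path fails.
        raise RuntimeError("GDS is not supported on this platform")

    def _load_without_gds(path: str):
        # Hypothetical stand-in for the gds=False (host-bounce) loader.
        return f"loaded {path} without GDS"

    def load_weights(path: str):
        """Prefer the GDS path; fall back to the non-GDS path on failure."""
        try:
            return _load_with_gds(path)
        except RuntimeError:
            return _load_without_gds(path)

    print(load_weights("model-00001-of-00009.safetensors"))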

Tests

It passes all of the core unit tests:

  1. tests/test_multi.py
  2. tests/test_fastsafetensors.py

I have also been using it heavily in my ROCm development since enabling this feature.

Setup steps

I have updated the README.md.
It is as easy as python3 setup.py develop

Updated setup steps:

Install from GitHub source

pip install git+https://github.com/foundation-model-stack/fastsafetensors.git

Install from local source

pip install .

@takeshi-yoshimura Are you interested in collaborating on a vLLM blog post at https://blog.vllm.ai/ showcasing how effective fastsafetensors is at improving cold-start efficiency?

@takeshi-yoshimura (Collaborator) commented

Excellent work! I will review your changes later this week or early next week. Of course, I can help with the blog post!

@tjtanaa (Contributor, Author) commented Oct 23, 2025

@takeshi-yoshimura
Sounds great. Let's discuss further through Slack https://vllm-dev.slack.com/team/U07RGJZ17Q8 or email [email protected]

@takeshi-yoshimura self-requested a review October 23, 2025 06:30
@takeshi-yoshimura (Collaborator) left a comment:

My major comments are twofold.

  1. Fastsafetensors dynamically loads shared libraries at run time, not compile time. I want you to try doing the same with the ROCm libraries. If that is possible, general ROCm users can avoid building the code themselves and can just use the downloaded pip packages.
  2. Please update and use get_cuda_ver() in FrameworkOpBase instead of torch.cuda.is_available().

@@ -0,0 +1,33 @@
/*
* Copyright 2024 IBM Inc. All rights reserved
@takeshi-yoshimura (Collaborator):

I decided not to display the copyright on the first line.
So, please start new files with SPDX-License-Identifier: Apache-2.0.

@takeshi-yoshimura (Collaborator):

Oh, maybe my comment was confusing. Please delete the IBM copyright line.

// Custom function pointer names that hipify-perl doesn't recognize
// These are our own naming in ext_funcs struct, not standard CUDA API
#ifdef USE_ROCM
#define cudaDeviceMalloc hipDeviceMalloc
@takeshi-yoshimura (Collaborator):

I know those kinds of macros enable minimal changes when switching to HIP APIs, but they will likely cause confusion. We could change the callback registration in load_nvidia_functions, but it would then need to be renamed to load_gpu_functions or another appropriate name, potentially causing even more renames... I will work on refactoring this later, so please just keep them for now.

"""Detect if we're running on ROCm or CUDA"""
try:
import torch
if torch.cuda.is_available():
@takeshi-yoshimura (Collaborator):

Is it possible to update get_cuda_ver in frameworks/_torch.py and use that instead of calling torch directly? We support frameworks other than torch, and this change may break them. You can add a suffix to indicate HIP or CUDA.
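
For illustration, a minimal sketch of such a platform-tagged version string, assuming a torch backend (the real method belongs on the framework class in frameworks/_torch.py):

    import torch

    def get_cuda_ver() -> str:
        # ROCm builds of torch expose a HIP version string; CUDA builds
        # expose torch.version.cuda; CPU-only builds expose neither.
        hip = getattr(torch.version, "hip", None)
        if hip:
            return f"hip-{hip}"
        if torch.version.cuda:
            return f"cuda-{torch.version.cuda}"
        return "none"  # CPU-only (the CPU fallback value is debated below)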



def MyExtension(name, sources, mod_name, *args, **kwargs):
def detect_platform():
@takeshi-yoshimura (Collaborator):

It is okay to have this for the moment, but it is not practical for general users to rebuild the code by hand.
I hope to see a simplified setup.py with unified binaries that switch runtimes dynamically.

#ifndef USE_ROCM
#define USE_ROCM
#endif
#include <hip/hip_runtime.h>
@takeshi-yoshimura (Collaborator):

I am not sure whether you noticed, but fastsafetensors does not include CUDA headers, to avoid CUDA dependencies (no -lcudart or other link options in setup.py). It dynamically searches for the shared library (i.e., libcudart.so) and loads symbols on demand with dlopen. I guess this header include is not required if you follow the same procedure, but have you tried it yet? If my guess is correct, you can simplify setup.py, and general users can even use ROCm via pip install fastsafetensors instead of building the code themselves.
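
The same on-demand loading can be sketched from Python with ctypes, which calls dlopen under the hood. The library names below are the usual runtime names (libamdhip64 for ROCm, libcudart for CUDA); the lookup order is illustrative, not fastsafetensors' actual search logic:

    import ctypes
    import ctypes.util

    def open_gpu_runtime():
        # Probe for a GPU runtime at run time instead of linking at build
        # time, so the wheel carries no hard CUDA/ROCm dependency.
        for name in ("amdhip64", "cudart"):
            path = ctypes.util.find_library(name)
            if path:
                return ctypes.CDLL(path)  # dlopen(3) under the hood
        return None  # CPU-only host: no GPU runtime found

    rt = open_gpu_runtime()
    print("GPU runtime loaded" if rt else "no GPU runtime found")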

def is_rocm_platform():
"""Detect if running on ROCm/AMD platform."""
try:
import torch
@takeshi-yoshimura (Collaborator):

Please do not use framework-specific calls, as I commented in another file.

@takeshi-yoshimura (Collaborator) commented

One more request: please make sure lint passes.

@tjtanaa (Contributor, Author) commented Oct 27, 2025

@takeshi-yoshimura Can you take another look? I have addressed your comments.

@takeshi-yoshimura (Collaborator) left a comment:

Sorry I could not respond promptly. Regarding the changes around setup.py, I need to learn more about hipify (e.g., I still don't understand how it converts dynamic function pointers), but it does not seem to affect the original build process. So let's go ahead with your proposal for now; maybe I will simplify it later if possible.

Before merging your changes, please make sure all the tests pass. You can find test logs in the last artifact of each run on the Actions page.

________________________________ test_framework ________________________________

fstcpp_log = None
framework = <fastsafetensors.frameworks._paddle.PaddleOp object at 0x7f3e0fdd65b0>

    def test_framework(fstcpp_log, framework) -> None:
        t = framework.get_empty_tensor([1], DType.F16, Device.from_str("cpu"))
        with pytest.raises(Exception):
            framework.is_equal(t, [float(0.0)])
        with pytest.raises(Exception):
            framework.get_process_group(int(0))
        # Test that get_cuda_ver() returns a string with platform prefix
        cuda_ver = framework.get_cuda_ver()
        assert isinstance(cuda_ver, str)
        # Should be "hip-X.Y.Z", "cuda-X.Y", or "none"
>       assert (
            cuda_ver.startswith("hip-")
            or cuda_ver.startswith("cuda-")
            or cuda_ver == "none"
        )
E       AssertionError: assert (False or False or '0.0' == 'none'
E        +  where False = <built-in method startswith of str object at 0x7f3e0e39db70>('hip-')
E        +    where <built-in method startswith of str object at 0x7f3e0e39db70> = '0.0'.startswith
E        +  and   False = <built-in method startswith of str object at 0x7f3e0e39db70>('cuda-')
E        +    where <built-in method startswith of str object at 0x7f3e0e39db70> = '0.0'.startswith
E         
E         - none
E         + 0.0)

test_fastsafetensors.py:117: AssertionError


@tjtanaa (Contributor, Author) commented Nov 7, 2025

@takeshi-yoshimura It is ready for review.

I have fixed all the unit tests; both the torch and paddle unit tests are passing.

I have set up two new GitHub Actions:

  1. (GitHub Action 1) Build the fastsafetensors ROCm packages for Python 3.9 to Python 3.13.
  2. We manually create two tags in the GitHub releases for the two architectures: v0.1.15-cuda and v0.1.15-rocm.
  3. (GitHub Action 2) We then manually trigger the GitHub Pages CI to set up the GitHub PyPI index. It looks for wheels in all releases tagged in the format vX.X.X-cuda or vX.X.X-rocm and adds them to the GitHub PyPI index page.
  4. Users can then choose whether to install for CUDA or ROCm as follows (these instructions already work):
pip install fastsafetensors --index-url https://embeddedllm.github.io/fastsafetensors-rocm/rocm/simple/
pip install fastsafetensors --index-url https://embeddedllm.github.io/fastsafetensors-rocm/cuda/simple/

Here is the proof-of-concept webpage: https://embeddedllm.github.io/fastsafetensors-rocm/
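
For context, such an index is just static HTML following PEP 503: one page per project listing wheel links, which pip resolves via --index-url. A minimal sketch of what a generator step might emit (the wheel URL and output paths are illustrative, not the actual workflow's):

    from pathlib import Path

    # Illustrative wheel URL; the real workflow would collect these from
    # the GitHub Releases API for tags like v0.1.15-rocm.
    wheels = [
        "https://github.com/embeddedllm/fastsafetensors-rocm/releases/download/"
        "v0.1.15-rocm/fastsafetensors-0.1.15-cp312-cp312-linux_x86_64.whl",
    ]

    links = "\n".join(
        f'<a href="{u}">{u.rsplit("/", 1)[-1]}</a><br/>' for u in wheels
    )
    out = Path("rocm/simple/fastsafetensors")
    out.mkdir(parents=True, exist_ok=True)
    (out / "index.html").write_text(f"<html><body>\n{links}\n</body></html>\n")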

We could work together to make sure users can:

pip install fastsafetensors --index-url https://foundation-model-stack.github.io/fastsafetensors/rocm/simple/
pip install fastsafetensors --index-url https://foundation-model-stack.github.io/fastsafetensors/cuda/simple/

between different GPU platforms without using torch directly.
"""
if torch.cuda.is_available():
return str(torch.version.cuda)
@tjtanaa (Contributor, Author) commented Nov 7, 2025:

@takeshi-yoshimura This is the part that differs from the previous behavior, and it is what was causing the unit-test failures. Let me know if it is an issue. Thank you.

Reverted the behavior: when on CPU, it returns the version string "0.0".

return (
str(paddle.version.cuda())
if paddle.device.is_compiled_with_cuda()
else "0.0"
@tjtanaa (Contributor, Author) commented Nov 7, 2025:

@takeshi-yoshimura Same here: this is the part that differs from the previous behavior and was causing the unit-test failures. Let me know if it is an issue. Thank you.

Reverted the behavior: when on CPU, it returns the version string "0.0".

@tjtanaa (Contributor, Author) commented Nov 10, 2025

@takeshi-yoshimura I have a proposal for another pathway to release the wheels for both CUDA and ROCm: a single GitHub PyPI index.

Creating Wheels for Releases

Two GitHub Actions have to be run to create the wheels:

  1. Run publish.yaml to build the CUDA-compatible wheels. Find the built wheels in the artifacts.
  2. Run build-rocm-wheels.yaml to build the ROCm-compatible wheels. Find the built wheels in the artifacts.

Create Releases

The release naming convention on GitHub should be vX.X.X-cuda for CUDA wheel releases and vX.X.X-rocm for ROCm wheel releases.

Remember to unzip the artifacts and upload the individual wheels to the GitHub release page.
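
For illustration, picking up releases that follow this convention is a simple tag filter; a minimal sketch (presumably the publish-to-index.yaml workflow does something similar against the GitHub Releases API):

    import re

    # Matches the vX.X.X-cuda / vX.X.X-rocm release tag convention.
    TAG_RE = re.compile(r"^v(\d+\.\d+\.\d+)-(cuda|rocm)$")

    for tag in ["v0.1.15-cuda", "v0.1.15-rocm", "v0.1.15"]:
        m = TAG_RE.match(tag)
        if m:
            version, platform = m.groups()
            print(f"{tag}: version {version} -> {platform} index")
        else:
            print(f"{tag}: skipped (not a platform-tagged release)")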

(Optional) Create GitHub PyPI Index

If you have not set up GitHub Pages to host the GitHub PyPI index, follow the instructions in GITHUB_PAGES_SETUP_GUIDE.md (see also setup-github-pages-pypi.sh).

Update the GitHub PyPI Index to Include the New Version Wheels

Trigger the publish-to-index.yaml workflow. For the version input, type vX.X.X, e.g. v0.1.15.
For the platforms input, select both if you want to release for both CUDA and ROCm (recommended).

When the GitHub Action completes, you should be able to install through pip:

pip install fastsafetensors --index-url https://foundation-model-stack.github.io/fastsafetensors/rocm/simple/
pip install fastsafetensors --index-url https://foundation-model-stack.github.io/fastsafetensors/cuda/simple/

@takeshi-yoshimura (Collaborator) commented

Thanks. I would like to add three requests:

  1. We've just had a new change land in parallel; please resolve the conflicts.
  2. Please do not add two new proposals in one PR, so please revert the changes under .github/.
  3. Your changes under .github/ have too big an influence on the overall operation of this package. For example, changing version names and the distribution method goes far beyond an individual contribution. So, please raise a new issue and a new draft PR for the proposal. Let's discuss it on the issue page.

@tjtanaa (Contributor, Author) commented Nov 11, 2025

@takeshi-yoshimura This PR is ready for review.
All your requests have been addressed.

  1. Synced with the new changes
  2. Removed the GitHub workflows

The test-torch and test-paddle checks passed on my GitHub Actions runs.
The unit tests pass on ROCm.

I have also updated the instructions, as we can now install directly from git+https://github.com/foundation-model-stack/fastsafetensors.git, like this:

pip install git+https://github.com/foundation-model-stack/fastsafetensors.git

@takeshi-yoshimura merged commit 1de5f13 into foundation-model-stack:main Nov 11, 2025
13 checks passed
@takeshi-yoshimura (Collaborator) commented

Great work! Thanks!
