Skip to content

Commit

Permalink
Merge pull request #878 from LaurentMazare/2.4
Browse files Browse the repository at this point in the history
Support for PyTorch 2.4
  • Loading branch information
LaurentMazare authored Jul 24, 2024
2 parents a90854b + 0d6c6b2 commit a4e9362
Show file tree
Hide file tree
Showing 20 changed files with 210,487 additions and 75 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Changed

## v0.16.0
### Changed
- PyTorch v2.4 support

## v0.16.0
### Changed
- PyTorch v2.3 support
Expand Down
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tch"
version = "0.16.0"
version = "0.17.0"
authors = ["Laurent Mazare <[email protected]>"]
edition = "2021"
build = "build.rs"
Expand All @@ -22,7 +22,7 @@ libc = "0.2.0"
ndarray = "0.15"
rand = "0.8"
thiserror = "1"
torch-sys = { version = "0.16.0", path = "torch-sys" }
torch-sys = { version = "0.17.0", path = "torch-sys" }
zip = "0.6"
half = "2"
safetensors = "0.3.0"
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ The code generation part for the C api on top of libtorch comes from

## Getting Started

This crate requires the C++ PyTorch library (libtorch) in version *v2.3.0* to be available on
This crate requires the C++ PyTorch library (libtorch) in version *v2.4.0* to be available on
your system. You can either:

- Use the system-wide libtorch installation (default).
Expand Down Expand Up @@ -85,7 +85,7 @@ seem to include `libtorch.a` by default so this would have to be compiled
manually, e.g. via the following:

```bash
git clone -b v2.3.0 --recurse-submodule https://github.com/pytorch/pytorch.git pytorch-static --depth 1
git clone -b v2.4.0 --recurse-submodule https://github.com/pytorch/pytorch.git pytorch-static --depth 1
cd pytorch-static
USE_CUDA=OFF BUILD_SHARED_LIBS=OFF python setup.py build
# export LIBTORCH to point at the build directory in pytorch-static.
Expand Down
2 changes: 1 addition & 1 deletion examples/min-gpt/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ fn causal_self_attention(p: &nn::Path, cfg: Config) -> impl ModuleT {
let q = xs.apply(&query).view(sizes).transpose(1, 2);
let v = xs.apply(&value).view(sizes).transpose(1, 2);
let att = q.matmul(&k.transpose(-2, -1)) * (1.0 / f64::sqrt(sizes[3] as f64));
let att = att.masked_fill(&mask.i((.., .., ..sz_t, ..sz_t)).eq(0.), std::f64::NEG_INFINITY);
let att = att.masked_fill(&mask.i((.., .., ..sz_t, ..sz_t)).eq(0.), f64::NEG_INFINITY);
let att = att.softmax(-1, Kind::Float).dropout(cfg.attn_pdrop, train);
let ys = att.matmul(&v).transpose(1, 2).contiguous().view([sz_b, sz_t, sz_c]);
ys.apply(&proj).dropout(cfg.resid_pdrop, train)
Expand Down
6 changes: 3 additions & 3 deletions examples/python-extension/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.21", features = ["extension-module"] }
pyo3-tch = { path = "../../pyo3-tch", version = "0.16.0" }
tch = { path = "../..", features = ["python-extension"], version = "0.16.0" }
torch-sys = { path = "../../torch-sys", features = ["python-extension"], version = "0.16.0" }
pyo3-tch = { path = "../../pyo3-tch", version = "0.17.0" }
tch = { path = "../..", features = ["python-extension"], version = "0.17.0" }
torch-sys = { path = "../../torch-sys", features = ["python-extension"], version = "0.17.0" }
7 changes: 5 additions & 2 deletions gen/gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ let excluded_functions =
; "_cummax_helper"
; "retain_grad"
; "_validate_sparse_coo_tensor_args"
; "_sparse_semi_structured_addmm"
; "_backward"
; "size"
; "stride"
Expand Down Expand Up @@ -92,6 +93,7 @@ let excluded_prefixes =
; "_amp_foreach"
; "_nested_tensor"
; "_fused_adam"
; "_fused_adagrad"
; "sym_"
; "_fused_sgd"
]
Expand Down Expand Up @@ -168,6 +170,7 @@ module Func = struct
| "at::tensoroptions" -> Some TensorOptions
| "at::intarrayref" -> Some (if is_nullable then IntListOption else IntList)
| "at::arrayref<double>" -> Some DoubleList
| "const c10::list<::std::optional<at::tensor>> &"
| "const c10::list<c10::optional<at::tensor>> &" -> Some TensorOptList
| "const at::itensorlistref &" | "at::tensorlist" -> Some TensorList
| "at::device" -> Some Device
Expand Down Expand Up @@ -590,7 +593,7 @@ let write_cpp funcs filename =
let pc s = p out_cpp s in
let ph s = p out_h s in
pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
pc "#include \"%s.h\"" (Caml.Filename.basename filename);
pc "#include \"%s.h\"" (Stdlib.Filename.basename filename);
pc "";
ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
ph "#include \"torch_api.h\"";
Expand Down Expand Up @@ -879,7 +882,7 @@ let run

let () =
run
~yaml_filename:"third_party/pytorch/Declarations-v2.3.0.yaml"
~yaml_filename:"third_party/pytorch/Declarations-v2.4.0.yaml"
~cpp_filename:"torch-sys/libtch/torch_api_generated"
~ffi_filename:"torch-sys/src/c_generated.rs"
~wrapper_filename:"src/wrappers/tensor_generated.rs"
Expand Down
8 changes: 4 additions & 4 deletions pyo3-tch/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pyo3-tch"
version = "0.16.0"
version = "0.17.0"
authors = ["Laurent Mazare <[email protected]>"]
edition = "2021"
build = "build.rs"
Expand All @@ -12,6 +12,6 @@ categories = ["science"]
license = "MIT/Apache-2.0"

[dependencies]
tch = { path = "..", features = ["python-extension"], version = "0.16.0" }
torch-sys = { path = "../torch-sys", features = ["python-extension"], version = "0.16.0" }
pyo3 = { version = "0.21", features = ["extension-module"] }
tch = { path = "..", features = ["python-extension"], version = "0.17.0" }
torch-sys = { path = "../torch-sys", features = ["python-extension"], version = "0.17.0" }
pyo3 = { version = "0.21", features = ["extension-module"] }
4 changes: 2 additions & 2 deletions src/nn/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError>
// Optimize the case for which a single C++ code can be done.
if cst == 0. {
Tensor::f_zeros(dims, (Kind::Float, device))
} else if (cst - 1.).abs() <= std::f64::EPSILON {
} else if (cst - 1.).abs() <= f64::EPSILON {
Tensor::f_ones(dims, (Kind::Float, device))
} else {
Tensor::f_ones(dims, (Kind::Float, device)).map(|t| t * cst)
Expand All @@ -117,7 +117,7 @@ pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError>
Tensor::f_zeros(dims, (Kind::Float, device))?.f_uniform_(lo, up)
}
Init::Randn { mean, stdev } => {
if mean == 0. && (stdev - 1.).abs() <= std::f64::EPSILON {
if mean == 0. && (stdev - 1.).abs() <= f64::EPSILON {
Tensor::f_randn(dims, (Kind::Float, device))
} else {
Tensor::f_randn(dims, (Kind::Float, device)).map(|t| t * stdev + mean)
Expand Down
Loading

0 comments on commit a4e9362

Please sign in to comment.