From 2741e898547001a3979e4a2cb4c1834a6ec33c26 Mon Sep 17 00:00:00 2001 From: jahatef Date: Sun, 31 Mar 2024 19:45:56 +0000 Subject: [PATCH 01/20] add rwkv support --- megatron/model/gpt2_model.py | 13 ++++++++++++- megatron/model/init_functions.py | 1 + megatron/neox_arguments/deepspeed_args.py | 4 +++- megatron/neox_arguments/neox_args.py | 13 ++++++++++++- 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index e083351cc..29aecb093 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -37,6 +37,7 @@ ParallelLinear, ) from megatron.model.gmlp import GMLPBlock +from megatron.model.rwkv import RWKVResidualLayerPipe from megatron.model.word_embeddings import EmbeddingPipe, SoftEmbedding # Pipeline parallelism @@ -166,7 +167,9 @@ def insert_layers( topology=self.__topology__, activation_checkpoint_interval=self.activation_checkpoint_interval, partition_method=self.neox_args.pipe_partition_method, - checkpointable_layers=["GMLPBlock", "ParallelTransformerLayerPipe"], + checkpointable_layers=["GMLPBlock", + "ParallelTransformerLayerPipe", + "RWKVResidualLayerPipe"], ) def init_specs(self): @@ -242,6 +245,14 @@ def init_specs(self): mask_fn=gpt2_attention_mask_func, ) ) + elif layer_type == "rwkv": + self.specs.append( + LayerSpec( + RWKVResidualLayerPipe, + self.neox_args, + i, + ) + ) else: self.specs.append( LayerSpec( diff --git a/megatron/model/init_functions.py b/megatron/model/init_functions.py index ad8ebc05a..e903a4eb1 100644 --- a/megatron/model/init_functions.py +++ b/megatron/model/init_functions.py @@ -15,6 +15,7 @@ import math import torch +from rwkv import TimeMixing, ChannelMixing try: import mup diff --git a/megatron/neox_arguments/deepspeed_args.py b/megatron/neox_arguments/deepspeed_args.py index 708a5f5b1..a4d5c6a10 100644 --- a/megatron/neox_arguments/deepspeed_args.py +++ b/megatron/neox_arguments/deepspeed_args.py @@ -100,7 +100,9 @@ class NeoXArgsDeepspeedConfig(NeoXArgsTemplate): bf16: dict = None """ - Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options + Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). + + Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options """ # ---Automatic Mixed Precision (AMP) Training Options--- diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 76da42a43..f0b43afec 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -36,6 +36,7 @@ "gmlp", "amlp", "flash", + "rwkv", ] @@ -210,7 +211,7 @@ class NeoXArgsModel(NeoXArgsTemplate): The first item in the list specifies the attention type(s), and should be a list of strings. The second item specifies the number of times to repeat those attention types in the full list. 
- attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash"] + attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "rwkv"] So a 12 layer network with only global attention could be specified like: [[[`global`], 12]] @@ -1141,6 +1142,16 @@ class NeoXArgsTraining(NeoXArgsTemplate): What to scale width by when creating the delta model for mup """ + rwkv_pre_ffn: bool = False + """ + Use channel mix block as first time mix block + """ + + rwkv_mishglu: bool = False + """ + mishglu ffn + """ + @dataclass class NeoXArgsTextgen(NeoXArgsTemplate): From 282f800c1002c97608b9325cc3a8afd7b078be7c Mon Sep 17 00:00:00 2001 From: Jacob Hatef <74274091+jahatef@users.noreply.github.com> Date: Sun, 31 Mar 2024 15:48:46 -0400 Subject: [PATCH 02/20] Update init_functions.py --- megatron/model/init_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/megatron/model/init_functions.py b/megatron/model/init_functions.py index e903a4eb1..ad8ebc05a 100644 --- a/megatron/model/init_functions.py +++ b/megatron/model/init_functions.py @@ -15,7 +15,6 @@ import math import torch -from rwkv import TimeMixing, ChannelMixing try: import mup From 09ba65ef0c23ef85519280e4cb2bff6f67a0e7d8 Mon Sep 17 00:00:00 2001 From: jahatef Date: Sun, 31 Mar 2024 19:49:38 +0000 Subject: [PATCH 03/20] rwkv model files --- megatron/model/rwkv/__init__.py | 1 + megatron/model/rwkv/cuda/__init__.py | 43 ++++ megatron/model/rwkv/cuda/setup.py | 88 +++++++ megatron/model/rwkv/cuda/wkv6_cuda.cu | 242 ++++++++++++++++++++ megatron/model/rwkv/cuda/wkv6_op.cpp | 22 ++ megatron/model/rwkv/rwkv.py | 318 ++++++++++++++++++++++++++ 6 files changed, 714 insertions(+) create mode 100644 megatron/model/rwkv/__init__.py create mode 100644 megatron/model/rwkv/cuda/__init__.py create mode 100644 megatron/model/rwkv/cuda/setup.py create mode 100644 megatron/model/rwkv/cuda/wkv6_cuda.cu create mode 100644 megatron/model/rwkv/cuda/wkv6_op.cpp create mode 100644 megatron/model/rwkv/rwkv.py diff --git a/megatron/model/rwkv/__init__.py b/megatron/model/rwkv/__init__.py new file mode 100644 index 000000000..5c75f6622 --- /dev/null +++ b/megatron/model/rwkv/__init__.py @@ -0,0 +1 @@ +from .rwkv import RWKVResidualLayerPipe, RWKVResidualLayer \ No newline at end of file diff --git a/megatron/model/rwkv/cuda/__init__.py b/megatron/model/rwkv/cuda/__init__.py new file mode 100644 index 000000000..5eea67020 --- /dev/null +++ b/megatron/model/rwkv/cuda/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pathlib +import subprocess + +from pathlib import Path + +srcpath = Path(__file__).parent.absolute() + +# Setting this param to a list has a problem of generating different +# compilation commands (with different order of architectures) and +# leading to recompilation of fused kernels. 
Set it to empty string +# to avoid recompilation and assign arch flags explicitly in +# extra_cuda_cflags below +os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + +def load_fused_kernels(): + try: + import rwkv_cuda + except (ImportError, ModuleNotFoundError) as e: + print("\n") + print(e) + print("=" * 100) + print( + f"ERROR: RWKV kernels configured but not properly installed. Please run `pip install {str(srcpath)}` to install them" + ) + print("=" * 100) + exit() + return diff --git a/megatron/model/rwkv/cuda/setup.py b/megatron/model/rwkv/cuda/setup.py new file mode 100644 index 000000000..f9e560f16 --- /dev/null +++ b/megatron/model/rwkv/cuda/setup.py @@ -0,0 +1,88 @@ +# Copyright (c) 2024, EleutherAI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from setuptools import setup, find_packages +from torch.utils import cpp_extension +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +from pathlib import Path +import subprocess + + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + +class CommandMixin(object): + user_options = [ + ('head_size=', 64, 'head size for the kernel'), + ('max_seq_length=', 512, 'maximum sequence length for the kernel') + ] + + def initialize_options(self): + super().initialize_options() + # Initialize options + self.head_size = 64 + self.max_seq_length = 512 + + def finalize_options(self): + # Validate options + if self.head_size <= 0: + raise ValueError("head_size must be positive") + if self.max_seq_length <= 0: + raise ValueError("max_seq_length must be positive") + super().finalize_options() + + def run(self): + # Use options + global head_size, max_seq_length + head_size = self.head_size + max_seq_length = self.max_seq_length + global cuda_ext_args + cuda_ext_args = ["-res-usage", + "--use_fast_math", + "-O3", "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={head_size}", + f"-D_T_={max_seq_length}"] + print("here") + super().run() + +class ExtensionCommand(CommandMixin, BuildExtension): + user_options = getattr(BuildExtension, 'user_options', []) + CommandMixin.user_options + +srcpath = Path(__file__).parent.absolute() + +setup( + name="rwkv_cuda", + include_package_data=False, + ext_modules=[ + CUDAExtension( + name="wkv6_cuda", + sources=[ + str(srcpath / "wkv6_op.cpp"), + str(srcpath / "wkv6_cuda.cu"), + ], + extra_compile_args=cuda_ext_args, + ) + ], + cmdclass={"build_ext": ExtensionCommand}, +) diff --git a/megatron/model/rwkv/cuda/wkv6_cuda.cu b/megatron/model/rwkv/cuda/wkv6_cuda.cu new file mode 100644 index 000000000..7b7c8366c --- /dev/null +++ b/megatron/model/rwkv/cuda/wkv6_cuda.cu @@ -0,0 +1,242 @@ +#include +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +template 
+__global__ void kernel_forward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, + F *__restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + + __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; + float state[_N_] = {0}; + + __syncthreads(); + u[i] = float(_u[i]); + __syncthreads(); + + for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) + { + __syncthreads(); + w[i] = exp(_w[t]); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + + #pragma unroll + for (int j = 0; j < _N_; j+=4) + { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } +} + +template +__global__ void kernel_backward_111(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, + F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + + __shared__ float u_[_N_]; + __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; + __syncthreads(); + u_[i] = float(_u[i]); + __syncthreads(); + + const float u = u_[i]; + + float state[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; + + const int t_0 = b*T*C + h*_N_ + i; + const int t_T_1 = t_0 + (T-1)*C; + const int t_T = t_0 + T*C; + + float gu = 0; + for (int t = t_0; t < t_T; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float k = float(_k[t]); + const float w = exp(_w[t]); + float gr = 0, gu_ = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = state[j]; + float x = k * v[j]; + + gr += (u * x + s) * gy[j]; + gu_ += x * gy[j]; + s = s * w + x; + } + _gr[t] = F(gr); + gu += float(_r[t]) * gu_; + } + _gu[b*C + h*_N_ + i] = F(gu); + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float rr = float(_r[t]); + const float w = exp(_w[t]); + float gk = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + float x = rr * gy[j]; + + gk += (u * x + s) * v[j]; + s = x + s * w; + } + _gk[t] = F(gk); + } + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + w_[i] = exp(_w[t]); + __syncthreads(); + + const float gyy = float(_gy[t]); + float gv = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = sdddd[j]; + float x = gyy * r[j]; + + gv += (u_[j] * x + s) * k[j]; + s = x + s * w_[j]; + } + _gv[t] = F(gv); + } +} + +template +__global__ void kernel_backward_222(const int B, const int T, const 
int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, + F *__restrict__ const _gw) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + + __shared__ float v[_N_], gy[_N_]; + float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0}; + + const int t_0 = b*T*C + h*_N_ + i; + const int t_1 = t_0 + C; + const int t_2 = t_0 + 2*C; + const int t_T_1 = t_0 + (T-1)*C; + + for (int t = t_T_1; t > t_1; t -= C) + { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t-2*C]); + __syncthreads(); + + const float r = float(_r[t]); + const float w = exp(_w[t-C]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + float x = r * gy[j]; + s = (s + x) * w; + sum += s * v[j]; + } + sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]); + } + + float sss = sbbbb[0]; + _gw[t_0] = 0; + _gw[t_1] = F(sss * _w[t_1]); + + for (int t = t_2; t < t_T_1; t += C) + { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t-2*C]); + __syncthreads(); + + const float w = exp(_w[t-C]); + const float k = float(_k[t-2*C]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + float x = k * v[j]; + s = (s + x) * w; + sum += s * gy[j]; + } + sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t])); + _gw[t] = F(sss * _w[t]); + } + _gw[t_T_1] = 0; +} + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); +} + +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + kernel_backward_111<<>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gu); + kernel_backward_222<<>>(B, T, C, H, r, k, v, w, u, gy, gw); +} diff --git a/megatron/model/rwkv/cuda/wkv6_op.cpp b/megatron/model/rwkv/cuda/wkv6_op.cpp new file mode 100644 index 000000000..56d9fbc62 --- /dev/null +++ b/megatron/model/rwkv/cuda/wkv6_op.cpp @@ -0,0 +1,22 @@ +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); + +void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { + cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv6 forward"); + m.def("backward", &backward, "wkv6 backward"); +} + +TORCH_LIBRARY(wkv6, m) { + m.def("forward", forward); + 
m.def("backward", backward); +} \ No newline at end of file diff --git a/megatron/model/rwkv/rwkv.py b/megatron/model/rwkv/rwkv.py new file mode 100644 index 000000000..5b50c9b1a --- /dev/null +++ b/megatron/model/rwkv/rwkv.py @@ -0,0 +1,318 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import os, math, gc, importlib +import torch +# torch._C._jit_set_profiling_executor(True) +# torch._C._jit_set_profiling_mode(True) +import torch.nn as nn +from torch.nn import functional as F +from rwkv.cuda import rwkv_cuda + +######################################################################################################## +# CUDA Kernel +######################################################################################################## + +from torch.utils.cpp_extension import load + +HEAD_SIZE = 64 + +#wkv_cuda = load(name="wkv6", sources=["/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_op.cpp", f"/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_cuda.cu"], +# verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={512}"]) + +class WKV(torch.autograd.Function): + @staticmethod + def forward(WKV, ctx, B, T, C, H, r, k, v, w, u): + with torch.no_grad(): + assert r.dtype == torch.bfloat16 + assert k.dtype == torch.bfloat16 + assert v.dtype == torch.bfloat16 + assert w.dtype == torch.bfloat16 + assert u.dtype == torch.bfloat16 + assert HEAD_SIZE == C // H + ctx.B = B + ctx.T = T + ctx.C = C + ctx.H = H + assert r.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert w.is_contiguous() + assert u.is_contiguous() + ew = (-torch.exp(w.float())).contiguous() + ctx.save_for_backward(r, k, v, ew, u) + y = torch.empty((B, T, C), device=r.device, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100) + wkv_cuda.forward(B, T, C, H, r, k, v, ew, u, y) + return y + + @staticmethod + def backward(ctx, gy): + with torch.no_grad(): + assert gy.dtype == torch.bfloat16 + B = ctx.B + T = ctx.T + C = ctx.C + H = ctx.H + assert gy.is_contiguous() + r, k, v, ew, u = ctx.saved_tensors + gr = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100) + gk = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100) + gv = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100) + gw = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100) + gu = torch.empty((B, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100) + wkv_cuda.backward(B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu) + gu = torch.sum(gu, 0).view(H, C//H) + return (None, None, None, None, gr, gk, gv, gw, gu) + +def RUN_CUDA_RWKV(WKV, B, T, C, H, r, k, v, w, u): + return WKV.apply(B, T, C, H, r, k, v, w, u) + +# RWKV6 time mix +class RWKV_TimeMix(nn.Module): + def __init__(self, neox_args, layer_number): + super().__init__() + self.neox_args = neox_args 
+ self.layer_number = layer_number + + self.head_size = 64 + self.num_attention_heads = neox_args.dim_att // self.head_size + assert neox_args.dim_att % self.num_attention_heads == 0 + self.wkv_cuda = load(name="wkv6", sources=["/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_op.cpp", + f"/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_cuda.cu"], + verbose=True, + extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={self.neox_args.head_size}", f"-D_T_={self.neox_args.seq_length}"]) + + + with torch.no_grad(): + ratio_0_to_1 = layer_number / (neox_args.num_layers - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (layer_number / neox_args.num_layers) # 1 to ~0 + ddd = torch.ones(1, 1, neox_args.hidden_size) + for i in range(neox_args.hidden_size): + ddd[0, 0, i] = i / neox_args.hidden_size + + # fancy time_mix + self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_v = nn.Parameter(1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)) + self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)) + self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)) + + TIME_MIX_EXTRA_DIM = 32 # generate TIME_MIX for w,k,v,r,g + self.time_maa_w1 = nn.Parameter(torch.zeros(neox_args.hidden_size, TIME_MIX_EXTRA_DIM*5).uniform_(-1e-4, 1e-4)) + self.time_maa_w2 = nn.Parameter(torch.zeros(5, TIME_MIX_EXTRA_DIM, neox_args.hidden_size).uniform_(-1e-4, 1e-4)) + + # fancy time_decay + decay_speed = torch.ones(neox_args.dim_att) + for n in range(neox_args.dim_att): + decay_speed[n] = -6 + 5 * (n / (neox_args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1) + self.time_decay = nn.Parameter(decay_speed.reshape(1,1,neox_args.dim_att)) + + TIME_DECAY_EXTRA_DIM = 64 + self.time_decay_w1 = nn.Parameter(torch.zeros(neox_args.hidden_size, TIME_DECAY_EXTRA_DIM).uniform_(-1e-4, 1e-4)) + self.time_decay_w2 = nn.Parameter(torch.zeros(TIME_DECAY_EXTRA_DIM, neox_args.dim_att).uniform_(-1e-4, 1e-4)) + + tmp = torch.zeros(neox_args.dim_att) + for n in range(neox_args.dim_att): + zigzag = ((n + 1) % 3 - 1) * 0.1 + tmp[n] = ratio_0_to_1 * (1 - (n / (neox_args.dim_att - 1))) + zigzag + + self.time_faaaa = nn.Parameter(tmp.reshape(self.num_attention_heads, self.head_size)) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.receptance = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) + self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) + + self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) + self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False) + self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) + self.ln_x = nn.GroupNorm(self.num_attention_heads, neox_args.dim_att, eps=(1e-5)*(8**2)) + + #@torch.jit.script + def jit_func(self, x): + B, T, C = x.size() + + xx = self.time_shift(x) - x + + xxx = x + xx * self.time_maa_x + xxx = torch.tanh(xxx @ self.time_maa_w1).view(B*T, 5, -1).transpose(0, 1) + xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=0) + + xw = x + xx * (self.time_maa_w + mw) + xk = x + xx * (self.time_maa_k + mk) + xv = x + xx * (self.time_maa_v + mv) + xr = x + xx * (self.time_maa_r + mr) + xg = x + xx * (self.time_maa_g + mg) + + r = self.receptance(xr) + 
k = self.key(xk) + v = self.value(xv) + g = F.silu(self.gate(xg)) + + ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2 + w = self.time_decay + ww + + return r, k, v, g, w + + #@torch.jit.script + def jit_func_2(self, x, g): + B, T, C = x.size() + x = x.view(B * T, C) + + x = self.ln_x(x).view(B, T, C) + x = self.output(x * g) + return x + + def forward(self, x): + B, T, C = x.size() + H = self.num_attention_heads + + r, k, v, g, w = self.jit_func(x) + x = RUN_CUDA_RWKV(self.WKV, B, T, C, H, r, k, v, w, u=self.time_faaaa) + + return self.jit_func_2(x, g) + +######################################################################################################## + +class RWKV_ChannelMix(nn.Module): + def __init__(self, neox_args, layer_number): + super().__init__() + self.neox_args = neox_args + self.layer_number = layer_number + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + with torch.no_grad(): # fancy init of time_mix + ratio_1_to_almost0 = 1.0 - (layer_number / neox_args.num_layers) # 1 to ~0 + ddd = torch.ones(1, 1, neox_args.hidden_size) + for i in range(neox_args.hidden_size): + ddd[0, 0, i] = i / neox_args.hidden_size + self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + + self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) + self.receptance = nn.Linear(neox_args.hidden_size, neox_args.hidden_size, bias=False) + self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False) + + #@torch.jit.script + def forward(self, x): + xx = self.time_shift(x) - x + xk = x + xx * self.time_maa_k + xr = x + xx * self.time_maa_r + + k = self.key(xk) + k = torch.relu(k) ** 2 + kv = self.value(k) + return torch.sigmoid(self.receptance(xr)) * kv + +######################################################################################################## + +class MishGLU(nn.Module): + def __init__(self, neox_args, layer_number): + super().__init__() + self.neox_args = neox_args + self.layer_number = layer_number + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + with torch.no_grad(): + ratio_1_to_almost0 = 1.0 - (layer_number / neox_args.num_layers) + + x = torch.ones(1, 1, neox_args.hidden_size) + for i in range(neox_args.hidden_size): + x[0, 0, i] = i / neox_args.hidden_size + + self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + self.aa = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) + self.bb = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) + self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False) + + #@torch.jit.script + def forward(self, x): + xx = self.time_shift(x) + xa = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xb = x * self.time_mix_r + xx * (1 - self.time_mix_r) + a = self.aa(xa) + b = self.bb(xb) + return self.value(a * F.mish(b)) + +######################################################################################################## +# The RWKV Model with our blocks +######################################################################################################## + +class RWKVResidualLayer(nn.Module): + def __init__(self, neox_args, layer_number): + super().__init__() + self.neox_args = neox_args + self.layer_number = layer_number + self.fp16 = neox_args.precision == "fp16" + self.bf16 = neox_args.precision == "bfloat16" + if not hasattr(neox_args, 'dim_att'): + neox_args.dim_att = 
neox_args.hidden_size + if not hasattr(neox_args, 'dim_ffn'): + neox_args.dim_ffn = neox_args.hidden_size * 4 + assert neox_args.hidden_size % 32 == 0 + assert neox_args.dim_att % 32 == 0 + assert neox_args.dim_ffn % 32 == 0 + + if neox_args.attention_dropout > 0: + self.drop0 = nn.Dropout(p = neox_args.attention_dropout) + + self.ln1 = nn.LayerNorm(neox_args.hidden_size) + self.ln2 = nn.LayerNorm(neox_args.hidden_size) + + if self.layer_number == 0 and self.neox_args.rwkv_pre_ffn > 0: + self.ffnPre = RWKV_ChannelMix(neox_args, 0) + else: + self.att = RWKV_TimeMix(neox_args, layer_number) + + if neox_args.rwkv_mishglu: + self.ffn = MishGLU(neox_args, layer_number) + else: + self.ffn = RWKV_ChannelMix(neox_args, layer_number) + + if neox_args.attention_dropout > 0: + self.drop0 = nn.Dropout(p = neox_args.attention_dropout) + if neox_args.hidden_dropout > 0: + self.drop1 = nn.Dropout(p = neox_args.hidden_dropout) + + global wkv_cuda + wkv_cuda = load(name="wkv6", sources=["/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_op.cpp", f"/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_cuda.cu"], + verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={512}"]) + + def forward(self, x): + neox_args = self.neox_args + B, T, C = x.size() + if self.layer_number == 0: + x = self.ln1(x) + #if neox_args.pos_emb > 0: + # pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:] + # x = x + pos_emb + + if self.neox_args.attention_dropout == 0 and self.neox_args.hidden_dropout == 0: + if self.layer_number == 0 and neox_args.rwkv_pre_ffn > 0: + x = x + self.ffnPre(self.ln1(x)) + else: + x = x + self.att(self.ln1(x)) + x = x + self.ffn(self.ln2(x)) + else: + if self.layer_number == 0 and neox_args.rwkv_pre_ffn > 0: + x = self.drop0(x + self.ffnPre(self.ln1(x))) + else: + if self.neox_args.attention_dropout > 0: + x = self.drop0(x + self.att(self.ln1(x))) + x = self.drop0(x + self.att(self.ln1(x))) + if self.neox_args.hidden_dropout > 0: + x = self.drop1(x + self.ffn(self.ln2(x))) + + return x + +class RWKVResidualLayerPipe(RWKVResidualLayer): + def forward(self, args): + assert len(args) == 2 + hidden_states, mask = args + neox_args = self.neox_args + return super().forward(hidden_states), mask From 04b8fdb2629ec1b012d0e1a070dec870247b60ce Mon Sep 17 00:00:00 2001 From: jahatef Date: Sun, 31 Mar 2024 20:06:38 +0000 Subject: [PATCH 04/20] configs --- configs/rwkv/170M.yml | 105 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 configs/rwkv/170M.yml diff --git a/configs/rwkv/170M.yml b/configs/rwkv/170M.yml new file mode 100644 index 000000000..f0d8246e3 --- /dev/null +++ b/configs/rwkv/170M.yml @@ -0,0 +1,105 @@ +{ + # Parallelism is not yet supported for rwkv + "pipe_parallel_size": 0, + "model_parallel_size": 1, + + "num_layers": 12, + "hidden_size": 768, + "num_attention_heads": 12, # ignored when using rwkv + "seq_length": 512, + "max_position_embeddings": 2048, + "output_layer_parallelism": "column", + "norm": "rmsnorm", + "rms_norm_epsilon": 1.0e-5, + "train_micro_batch_size_per_gpu": 1, + + "attention_config": [[["rwkv"], 12]], + + "activation": "silu", + + "output_layer_init_method": "single_residual_scaled_normal", + + # model settings + + #"pos_emb": "rotary", + "rotary_pct": 0.25, + "no_weight_tying": true, + "gpt_j_residual": true, + "output_layer_parallelism": "column", + + # these should provide some speedup but takes a while to 
build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + "rope_fusion": false, + "layernorm_fusion": false, + + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0008, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00008, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "train_micro_batch_size_per_gpu": 32, + "data_impl": "mmap", + "num_workers": 1, + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "bf16": { + "bf16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 12, + "hysteresis": 2, + "min_loss_scale": 1, + }, + + # misc. training settings + "train_iters": 500, + "lr_decay_iters": 500, + "distributed_backend": "nccl", + "lr_decay_style": "constant", + "warmup": 0.01, + "checkpoint_factor": 100, + "eval_interval": 100000, + "eval_iters": 10, + + # logging + "log_interval": 10, + "steps_per_print": 10, + "wall_clock_breakdown": true, +} From 6e79cc28050f5adf4ea92f853e89e6b678a3d270 Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Wed, 10 Apr 2024 12:49:43 -0400 Subject: [PATCH 05/20] kernels --- configs/rwkv/170M.yml | 3 +- megatron/model/rwkv/cuda/__init__.py | 43 -------------- megatron/model/rwkv/cuda/setup.py | 88 ---------------------------- megatron/model/rwkv/rwkv.py | 47 +++++++-------- 4 files changed, 23 insertions(+), 158 deletions(-) delete mode 100644 megatron/model/rwkv/cuda/__init__.py delete mode 100644 megatron/model/rwkv/cuda/setup.py diff --git a/configs/rwkv/170M.yml b/configs/rwkv/170M.yml index f0d8246e3..4cdf9a376 100644 --- a/configs/rwkv/170M.yml +++ b/configs/rwkv/170M.yml @@ -5,7 +5,8 @@ "num_layers": 12, "hidden_size": 768, - "num_attention_heads": 12, # ignored when using rwkv + "num_attention_heads": 12, # head_size = dim_att / num_attention_heads. + # head_size is 64 for all rwkv models "seq_length": 512, "max_position_embeddings": 2048, "output_layer_parallelism": "column", diff --git a/megatron/model/rwkv/cuda/__init__.py b/megatron/model/rwkv/cuda/__init__.py deleted file mode 100644 index 5eea67020..000000000 --- a/megatron/model/rwkv/cuda/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import pathlib -import subprocess - -from pathlib import Path - -srcpath = Path(__file__).parent.absolute() - -# Setting this param to a list has a problem of generating different -# compilation commands (with different order of architectures) and -# leading to recompilation of fused kernels. Set it to empty string -# to avoid recompilation and assign arch flags explicitly in -# extra_cuda_cflags below -os.environ["TORCH_CUDA_ARCH_LIST"] = "" - - -def load_fused_kernels(): - try: - import rwkv_cuda - except (ImportError, ModuleNotFoundError) as e: - print("\n") - print(e) - print("=" * 100) - print( - f"ERROR: RWKV kernels configured but not properly installed. Please run `pip install {str(srcpath)}` to install them" - ) - print("=" * 100) - exit() - return diff --git a/megatron/model/rwkv/cuda/setup.py b/megatron/model/rwkv/cuda/setup.py deleted file mode 100644 index f9e560f16..000000000 --- a/megatron/model/rwkv/cuda/setup.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) 2024, EleutherAI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from setuptools import setup, find_packages -from torch.utils import cpp_extension -from torch.utils.cpp_extension import BuildExtension, CUDAExtension -from pathlib import Path -import subprocess - - -def _get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output( - [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True - ) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - bare_metal_minor = release[1][0] - - return raw_output, bare_metal_major, bare_metal_minor - -class CommandMixin(object): - user_options = [ - ('head_size=', 64, 'head size for the kernel'), - ('max_seq_length=', 512, 'maximum sequence length for the kernel') - ] - - def initialize_options(self): - super().initialize_options() - # Initialize options - self.head_size = 64 - self.max_seq_length = 512 - - def finalize_options(self): - # Validate options - if self.head_size <= 0: - raise ValueError("head_size must be positive") - if self.max_seq_length <= 0: - raise ValueError("max_seq_length must be positive") - super().finalize_options() - - def run(self): - # Use options - global head_size, max_seq_length - head_size = self.head_size - max_seq_length = self.max_seq_length - global cuda_ext_args - cuda_ext_args = ["-res-usage", - "--use_fast_math", - "-O3", "-Xptxas -O3", - "--extra-device-vectorization", - f"-D_N_={head_size}", - f"-D_T_={max_seq_length}"] - print("here") - super().run() - -class ExtensionCommand(CommandMixin, BuildExtension): - user_options = getattr(BuildExtension, 'user_options', []) + CommandMixin.user_options - -srcpath = Path(__file__).parent.absolute() - -setup( - name="rwkv_cuda", - include_package_data=False, - ext_modules=[ - CUDAExtension( - name="wkv6_cuda", - sources=[ - str(srcpath / "wkv6_op.cpp"), - str(srcpath / "wkv6_cuda.cu"), - ], - extra_compile_args=cuda_ext_args, - ) - ], - cmdclass={"build_ext": 
ExtensionCommand}, -) diff --git a/megatron/model/rwkv/rwkv.py b/megatron/model/rwkv/rwkv.py index 5b50c9b1a..9f1841bcf 100644 --- a/megatron/model/rwkv/rwkv.py +++ b/megatron/model/rwkv/rwkv.py @@ -8,7 +8,7 @@ # torch._C._jit_set_profiling_mode(True) import torch.nn as nn from torch.nn import functional as F -from rwkv.cuda import rwkv_cuda + ######################################################################################################## # CUDA Kernel @@ -16,21 +16,21 @@ from torch.utils.cpp_extension import load -HEAD_SIZE = 64 +#HEAD_SIZE = 64 #wkv_cuda = load(name="wkv6", sources=["/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_op.cpp", f"/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_cuda.cu"], # verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={512}"]) class WKV(torch.autograd.Function): @staticmethod - def forward(WKV, ctx, B, T, C, H, r, k, v, w, u): + def forward(ctx, B, T, C, H, r, k, v, w, u): with torch.no_grad(): assert r.dtype == torch.bfloat16 assert k.dtype == torch.bfloat16 assert v.dtype == torch.bfloat16 assert w.dtype == torch.bfloat16 assert u.dtype == torch.bfloat16 - assert HEAD_SIZE == C // H + #assert HEAD_SIZE == C // H ctx.B = B ctx.T = T ctx.C = C @@ -65,7 +65,7 @@ def backward(ctx, gy): gu = torch.sum(gu, 0).view(H, C//H) return (None, None, None, None, gr, gk, gv, gw, gu) -def RUN_CUDA_RWKV(WKV, B, T, C, H, r, k, v, w, u): +def RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u): return WKV.apply(B, T, C, H, r, k, v, w, u) # RWKV6 time mix @@ -75,17 +75,6 @@ def __init__(self, neox_args, layer_number): self.neox_args = neox_args self.layer_number = layer_number - self.head_size = 64 - self.num_attention_heads = neox_args.dim_att // self.head_size - assert neox_args.dim_att % self.num_attention_heads == 0 - self.wkv_cuda = load(name="wkv6", sources=["/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_op.cpp", - f"/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_cuda.cu"], - verbose=True, - extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", - "--extra-device-vectorization", - f"-D_N_={self.neox_args.head_size}", f"-D_T_={self.neox_args.seq_length}"]) - - with torch.no_grad(): ratio_0_to_1 = layer_number / (neox_args.num_layers - 1) # 0 to 1 ratio_1_to_almost0 = 1.0 - (layer_number / neox_args.num_layers) # 1 to ~0 @@ -120,7 +109,7 @@ def __init__(self, neox_args, layer_number): zigzag = ((n + 1) % 3 - 1) * 0.1 tmp[n] = ratio_0_to_1 * (1 - (n / (neox_args.dim_att - 1))) + zigzag - self.time_faaaa = nn.Parameter(tmp.reshape(self.num_attention_heads, self.head_size)) + self.time_faaaa = nn.Parameter(tmp.reshape(neox_args.num_attention_heads, neox_args.head_size)) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.receptance = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) @@ -129,7 +118,7 @@ def __init__(self, neox_args, layer_number): self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False) self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) - self.ln_x = nn.GroupNorm(self.num_attention_heads, neox_args.dim_att, eps=(1e-5)*(8**2)) + self.ln_x = nn.GroupNorm(neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5)*(8**2)) #@torch.jit.script def jit_func(self, x): @@ -169,10 +158,10 @@ def jit_func_2(self, x, g): def forward(self, x): B, T, C = x.size() - H = 
self.num_attention_heads + H = self.neox_args.num_attention_heads r, k, v, g, w = self.jit_func(x) - x = RUN_CUDA_RWKV(self.WKV, B, T, C, H, r, k, v, w, u=self.time_faaaa) + x = RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u=self.time_faaaa) return self.jit_func_2(x, g) @@ -257,6 +246,10 @@ def __init__(self, neox_args, layer_number): assert neox_args.hidden_size % 32 == 0 assert neox_args.dim_att % 32 == 0 assert neox_args.dim_ffn % 32 == 0 + self.neox_args.head_size = neox_args.dim_att // neox_args.num_attention_heads + self.head_size = self.neox_args.head_size + self.num_attention_heads = neox_args.num_attention_heads + assert neox_args.dim_att % self.num_attention_heads == 0 if neox_args.attention_dropout > 0: self.drop0 = nn.Dropout(p = neox_args.attention_dropout) @@ -279,18 +272,20 @@ def __init__(self, neox_args, layer_number): if neox_args.hidden_dropout > 0: self.drop1 = nn.Dropout(p = neox_args.hidden_dropout) - global wkv_cuda - wkv_cuda = load(name="wkv6", sources=["/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_op.cpp", f"/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_cuda.cu"], - verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={512}"]) + if layer_number == 0: + global wkv_cuda + wkv_cuda = load(name="wkv6", sources=["megatron/model/rwkv/cuda/wkv6_op.cpp", + f"megatron/model/rwkv/cuda/wkv6_cuda.cu"], + verbose=True, + extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={self.neox_args.head_size}", f"-D_T_={self.neox_args.seq_length}"]) def forward(self, x): neox_args = self.neox_args B, T, C = x.size() if self.layer_number == 0: x = self.ln1(x) - #if neox_args.pos_emb > 0: - # pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:] - # x = x + pos_emb if self.neox_args.attention_dropout == 0 and self.neox_args.hidden_dropout == 0: if self.layer_number == 0 and neox_args.rwkv_pre_ffn > 0: From cb49ff6ea1cb972917bfb4b081509fb3092f409d Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Tue, 16 Apr 2024 00:16:39 -0400 Subject: [PATCH 06/20] Cleanup --- configs/760M.yml | 2 +- megatron/model/rwkv/rwkv.py | 49 +++++++++++++--------------- megatron/neox_arguments/neox_args.py | 4 +-- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/configs/760M.yml b/configs/760M.yml index 6d62dc0f3..93b3d4400 100644 --- a/configs/760M.yml +++ b/configs/760M.yml @@ -76,7 +76,7 @@ }, # misc. 
training settings - "train_iters": 320000, + "train_iters": 32000, "lr_decay_iters": 320000, "distributed_backend": "nccl", "lr_decay_style": "cosine", diff --git a/megatron/model/rwkv/rwkv.py b/megatron/model/rwkv/rwkv.py index 9f1841bcf..5a6a28121 100644 --- a/megatron/model/rwkv/rwkv.py +++ b/megatron/model/rwkv/rwkv.py @@ -4,24 +4,14 @@ import os, math, gc, importlib import torch -# torch._C._jit_set_profiling_executor(True) -# torch._C._jit_set_profiling_mode(True) import torch.nn as nn from torch.nn import functional as F - - -######################################################################################################## -# CUDA Kernel -######################################################################################################## - from torch.utils.cpp_extension import load -#HEAD_SIZE = 64 - -#wkv_cuda = load(name="wkv6", sources=["/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_op.cpp", f"/weka/home-jacob/gpt-neox/megatron/model/rwkv/cuda/wkv6_cuda.cu"], -# verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={512}"]) - class WKV(torch.autograd.Function): + """ + WKV block, using cuda kernel. + """ @staticmethod def forward(ctx, B, T, C, H, r, k, v, w, u): with torch.no_grad(): @@ -30,7 +20,6 @@ def forward(ctx, B, T, C, H, r, k, v, w, u): assert v.dtype == torch.bfloat16 assert w.dtype == torch.bfloat16 assert u.dtype == torch.bfloat16 - #assert HEAD_SIZE == C // H ctx.B = B ctx.T = T ctx.C = C @@ -70,6 +59,11 @@ def RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u): # RWKV6 time mix class RWKV_TimeMix(nn.Module): + """ + Time Mixing Layer + The RWKV substitute for attention. + TODO: fix jit compiling. + """ def __init__(self, neox_args, layer_number): super().__init__() self.neox_args = neox_args @@ -120,7 +114,6 @@ def __init__(self, neox_args, layer_number): self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) self.ln_x = nn.GroupNorm(neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5)*(8**2)) - #@torch.jit.script def jit_func(self, x): B, T, C = x.size() @@ -147,7 +140,6 @@ def jit_func(self, x): return r, k, v, g, w - #@torch.jit.script def jit_func_2(self, x, g): B, T, C = x.size() x = x.view(B * T, C) @@ -165,9 +157,10 @@ def forward(self, x): return self.jit_func_2(x, g) -######################################################################################################## - class RWKV_ChannelMix(nn.Module): + """ + Channel Mix layer. 
The ffn in RWKV + """ def __init__(self, neox_args, layer_number): super().__init__() self.neox_args = neox_args @@ -186,7 +179,6 @@ def __init__(self, neox_args, layer_number): self.receptance = nn.Linear(neox_args.hidden_size, neox_args.hidden_size, bias=False) self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False) - #@torch.jit.script def forward(self, x): xx = self.time_shift(x) - x xk = x + xx * self.time_maa_k @@ -197,9 +189,10 @@ def forward(self, x): kv = self.value(k) return torch.sigmoid(self.receptance(xr)) * kv -######################################################################################################## - class MishGLU(nn.Module): + """ + MishGLU ffn, used in place of channel mixing if neox_args.rwkv_mishglu + """ def __init__(self, neox_args, layer_number): super().__init__() self.neox_args = neox_args @@ -219,7 +212,6 @@ def __init__(self, neox_args, layer_number): self.bb = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False) - #@torch.jit.script def forward(self, x): xx = self.time_shift(x) xa = x * self.time_mix_k + xx * (1 - self.time_mix_k) @@ -228,11 +220,10 @@ def forward(self, x): b = self.bb(xb) return self.value(a * F.mish(b)) -######################################################################################################## -# The RWKV Model with our blocks -######################################################################################################## - class RWKVResidualLayer(nn.Module): + """ + RWKV layer definition + """ def __init__(self, neox_args, layer_number): super().__init__() self.neox_args = neox_args @@ -274,6 +265,9 @@ def __init__(self, neox_args, layer_number): if layer_number == 0: global wkv_cuda + """ + Load cuda kernel at runtime. The kernel uses run time variables to build, ideally it should not. + """ wkv_cuda = load(name="wkv6", sources=["megatron/model/rwkv/cuda/wkv6_op.cpp", f"megatron/model/rwkv/cuda/wkv6_cuda.cu"], verbose=True, @@ -306,6 +300,9 @@ def forward(self, x): return x class RWKVResidualLayerPipe(RWKVResidualLayer): + """ + RWKV Pipeline Layer + """ def forward(self, args): assert len(args) == 2 hidden_states, mask = args diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index f0b43afec..f730f2867 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -1144,12 +1144,12 @@ class NeoXArgsTraining(NeoXArgsTemplate): rwkv_pre_ffn: bool = False """ - Use channel mix block as first time mix block + Use ffn block as first time mix block """ rwkv_mishglu: bool = False """ - mishglu ffn + mishglu ffn instead of channel mix """ From 96ea6f5b01e554c20da1f5097435d874036a6773 Mon Sep 17 00:00:00 2001 From: Jacob Hatef <74274091+jahatef@users.noreply.github.com> Date: Tue, 16 Apr 2024 00:17:55 -0400 Subject: [PATCH 07/20] Update 760M.yml --- configs/760M.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/760M.yml b/configs/760M.yml index 93b3d4400..6d62dc0f3 100644 --- a/configs/760M.yml +++ b/configs/760M.yml @@ -76,7 +76,7 @@ }, # misc. 
training settings - "train_iters": 32000, + "train_iters": 320000, "lr_decay_iters": 320000, "distributed_backend": "nccl", "lr_decay_style": "cosine", From 54f27755cba0c1a88d61db32199dd5f68e7412a7 Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Wed, 17 Apr 2024 18:35:19 -0400 Subject: [PATCH 08/20] remove preffn and mishglu --- megatron/model/rwkv/rwkv.py | 60 +++++----------------------- megatron/neox_arguments/neox_args.py | 10 ----- 2 files changed, 9 insertions(+), 61 deletions(-) diff --git a/megatron/model/rwkv/rwkv.py b/megatron/model/rwkv/rwkv.py index 5a6a28121..3ab148266 100644 --- a/megatron/model/rwkv/rwkv.py +++ b/megatron/model/rwkv/rwkv.py @@ -189,36 +189,6 @@ def forward(self, x): kv = self.value(k) return torch.sigmoid(self.receptance(xr)) * kv -class MishGLU(nn.Module): - """ - MishGLU ffn, used in place of channel mixing if neox_args.rwkv_mishglu - """ - def __init__(self, neox_args, layer_number): - super().__init__() - self.neox_args = neox_args - self.layer_number = layer_number - self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - - with torch.no_grad(): - ratio_1_to_almost0 = 1.0 - (layer_number / neox_args.num_layers) - - x = torch.ones(1, 1, neox_args.hidden_size) - for i in range(neox_args.hidden_size): - x[0, 0, i] = i / neox_args.hidden_size - - self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) - self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) - self.aa = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) - self.bb = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) - self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False) - - def forward(self, x): - xx = self.time_shift(x) - xa = x * self.time_mix_k + xx * (1 - self.time_mix_k) - xb = x * self.time_mix_r + xx * (1 - self.time_mix_r) - a = self.aa(xa) - b = self.bb(xb) - return self.value(a * F.mish(b)) class RWKVResidualLayer(nn.Module): """ @@ -248,15 +218,9 @@ def __init__(self, neox_args, layer_number): self.ln1 = nn.LayerNorm(neox_args.hidden_size) self.ln2 = nn.LayerNorm(neox_args.hidden_size) - if self.layer_number == 0 and self.neox_args.rwkv_pre_ffn > 0: - self.ffnPre = RWKV_ChannelMix(neox_args, 0) - else: - self.att = RWKV_TimeMix(neox_args, layer_number) + self.att = RWKV_TimeMix(neox_args, layer_number) - if neox_args.rwkv_mishglu: - self.ffn = MishGLU(neox_args, layer_number) - else: - self.ffn = RWKV_ChannelMix(neox_args, layer_number) + self.ffn = RWKV_ChannelMix(neox_args, layer_number) if neox_args.attention_dropout > 0: self.drop0 = nn.Dropout(p = neox_args.attention_dropout) @@ -281,21 +245,15 @@ def forward(self, x): if self.layer_number == 0: x = self.ln1(x) - if self.neox_args.attention_dropout == 0 and self.neox_args.hidden_dropout == 0: - if self.layer_number == 0 and neox_args.rwkv_pre_ffn > 0: - x = x + self.ffnPre(self.ln1(x)) - else: - x = x + self.att(self.ln1(x)) + if self.neox_args.attention_dropout == 0: + x = x + self.att(self.ln1(x)) + else: + x = self.drop0(x + self.att(self.ln1(x))) + + if self.neox_args.hidden_dropout == 0: x = x + self.ffn(self.ln2(x)) else: - if self.layer_number == 0 and neox_args.rwkv_pre_ffn > 0: - x = self.drop0(x + self.ffnPre(self.ln1(x))) - else: - if self.neox_args.attention_dropout > 0: - x = self.drop0(x + self.att(self.ln1(x))) - x = self.drop0(x + self.att(self.ln1(x))) - if self.neox_args.hidden_dropout > 0: - x = self.drop1(x + self.ffn(self.ln2(x))) + x = self.drop1(x + self.ffn(self.ln2(x))) return x diff --git a/megatron/neox_arguments/neox_args.py 
b/megatron/neox_arguments/neox_args.py index f730f2867..c55d2d877 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -1142,16 +1142,6 @@ class NeoXArgsTraining(NeoXArgsTemplate): What to scale width by when creating the delta model for mup """ - rwkv_pre_ffn: bool = False - """ - Use ffn block as first time mix block - """ - - rwkv_mishglu: bool = False - """ - mishglu ffn instead of channel mix - """ - @dataclass class NeoXArgsTextgen(NeoXArgsTemplate): From 276ffa9aa14446ccb57857fe3588c6b8c1dc219b Mon Sep 17 00:00:00 2001 From: github-actions Date: Fri, 19 Apr 2024 01:26:22 +0000 Subject: [PATCH 09/20] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index f0ea55eeb..e66e4b195 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 11a5537 + Default = 8d60cef current git hash of repository @@ -432,7 +432,7 @@ Model Arguments The first item in the list specifies the attention type(s), and should be a list of strings. The second item specifies the number of times to repeat those attention types in the full list. - attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba"] + attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba", "rwkv"] So a 12 layer network with only global attention could be specified like: [[[`global`], 12]] @@ -1965,7 +1965,9 @@ Args for deepspeed config Default = None - Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options + Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). 
+ + Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options From e20138cc1b6355c04d1b865a4a0b7a4f6133fa2b Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Thu, 18 Apr 2024 21:30:20 -0400 Subject: [PATCH 10/20] Add RWKV parallelism assertions --- megatron/neox_arguments/arguments.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index bf6e3f3e8..a4b1b8507 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -1066,6 +1066,15 @@ def calculate_derived(self): assert ( self.hidden_dropout == 0.0, ), "Mamba does not yet have dropout implemented" + if "rwkv" in self.attention_config: + assert ( + not self.is_pipe_parallel and self.model_parallel_size == 1 + ), "RWKV not currently compatible with parallelism" + if isinstance(self.zero_stage, int): + assert self.zero_stage <= 2, "Zero stage 3 not compatible with RWKV" + assert ( + self.hidden_dropout == 0.0, + ), "RWKV does not yet have dropout implemented" # Sparsity config if self.sparsity_config is None: From 428aad520a4fca1a72383899a1f7a05e24c0dfa4 Mon Sep 17 00:00:00 2001 From: github-actions Date: Fri, 19 Apr 2024 01:30:45 +0000 Subject: [PATCH 11/20] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index e66e4b195..1eb3d943e 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 8d60cef + Default = e20138c current git hash of repository From 1b0bbab41706a855a9f211c5cf9406dbfc31158c Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Thu, 18 Apr 2024 21:33:45 -0400 Subject: [PATCH 12/20] pre-commit and config cleanup --- configs/rwkv/170M.yml | 6 +- megatron/model/gpt2_model.py | 2 +- megatron/model/rwkv/__init__.py | 2 +- megatron/model/rwkv/cuda/wkv6_cuda.cu | 158 ++++++++++++--------- megatron/model/rwkv/cuda/wkv6_op.cpp | 91 ++++++++++-- megatron/model/rwkv/rwkv.py | 164 +++++++++++++++++----- megatron/neox_arguments/deepspeed_args.py | 4 +- 7 files changed, 306 insertions(+), 121 deletions(-) diff --git a/configs/rwkv/170M.yml b/configs/rwkv/170M.yml index 4cdf9a376..f31f2613b 100644 --- a/configs/rwkv/170M.yml +++ b/configs/rwkv/170M.yml @@ -5,7 +5,7 @@ "num_layers": 12, "hidden_size": 768, - "num_attention_heads": 12, # head_size = dim_att / num_attention_heads. + "num_attention_heads": 12, # head_size = dim_att / num_attention_heads. 
# head_size is 64 for all rwkv models "seq_length": 512, "max_position_embeddings": 2048, @@ -18,15 +18,12 @@ "activation": "silu", - "output_layer_init_method": "single_residual_scaled_normal", - # model settings #"pos_emb": "rotary", "rotary_pct": 0.25, "no_weight_tying": true, "gpt_j_residual": true, - "output_layer_parallelism": "column", # these should provide some speedup but takes a while to build, set to true if desired "scaled_upper_triang_masked_softmax_fusion": false, @@ -62,7 +59,6 @@ }, # batch / data settings - "train_micro_batch_size_per_gpu": 32, "data_impl": "mmap", "num_workers": 1, diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index d85ad2393..89f43c352 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -37,7 +37,7 @@ ParallelLinear, ) from megatron.model.gmlp import GMLPBlock -from megatron.model.rwkv import RWKVResidualLayerPipe +from megatron.model.rwkv import RWKVResidualLayerPipe from megatron.model.mamba import ParallelMambaResidualLayerPipe from megatron.model.word_embeddings import EmbeddingPipe, SoftEmbedding diff --git a/megatron/model/rwkv/__init__.py b/megatron/model/rwkv/__init__.py index 5c75f6622..c0d8d4ba1 100644 --- a/megatron/model/rwkv/__init__.py +++ b/megatron/model/rwkv/__init__.py @@ -1 +1 @@ -from .rwkv import RWKVResidualLayerPipe, RWKVResidualLayer \ No newline at end of file +from .rwkv import RWKVResidualLayerPipe, RWKVResidualLayer diff --git a/megatron/model/rwkv/cuda/wkv6_cuda.cu b/megatron/model/rwkv/cuda/wkv6_cuda.cu index 7b7c8366c..2b228e90f 100644 --- a/megatron/model/rwkv/cuda/wkv6_cuda.cu +++ b/megatron/model/rwkv/cuda/wkv6_cuda.cu @@ -1,17 +1,24 @@ -#include #include +#include #include "ATen/ATen.h" typedef at::BFloat16 bf16; template -__global__ void kernel_forward(const int B, const int T, const int C, const int H, - const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, - F *__restrict__ const _y) +__global__ void kernel_forward(const int B, + const int T, + const int C, + const int H, + const F* __restrict__ const _r, + const F* __restrict__ const _k, + const F* __restrict__ const _v, + const float* __restrict__ _w, + const F* __restrict__ _u, + F* __restrict__ const _y) { const int b = blockIdx.x / H; const int h = blockIdx.x % H; const int i = threadIdx.x; - _u += h*_N_; + _u += h * _N_; __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; float state[_N_] = {0}; @@ -20,8 +27,7 @@ __global__ void kernel_forward(const int B, const int T, const int C, const int u[i] = float(_u[i]); __syncthreads(); - for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) - { + for (int t = b * T * C + h * _N_ + i; t < (b + 1) * T * C + h * _N_ + i; t += C) { __syncthreads(); w[i] = exp(_w[t]); r[i] = float(_r[t]); @@ -31,9 +37,8 @@ __global__ void kernel_forward(const int B, const int T, const int C, const int const float v = float(_v[t]); float y = 0; - #pragma unroll - for (int j = 0; j < _N_; j+=4) - { +#pragma unroll + for (int j = 0; j < _N_; j += 4) { const float4& r_ = (float4&)(r[j]); const float4& k_ = (float4&)(k[j]); const float4& w_ = (float4&)(w[j]); @@ -61,14 +66,25 @@ __global__ void kernel_forward(const int B, const int T, const int C, const int } template -__global__ void kernel_backward_111(const int B, const int T, const int C, const int H, - const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, 
const F *__restrict__ _u, const F *__restrict__ const _gy, - F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu) +__global__ void kernel_backward_111(const int B, + const int T, + const int C, + const int H, + const F* __restrict__ const _r, + const F* __restrict__ const _k, + const F* __restrict__ const _v, + const float* __restrict__ _w, + const F* __restrict__ _u, + const F* __restrict__ const _gy, + F* __restrict__ const _gr, + F* __restrict__ const _gk, + F* __restrict__ const _gv, + F* __restrict__ const _gu) { const int b = blockIdx.x / H; const int h = blockIdx.x % H; const int i = threadIdx.x; - _u += h*_N_; + _u += h * _N_; __shared__ float u_[_N_]; __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; @@ -80,13 +96,12 @@ __global__ void kernel_backward_111(const int B, const int T, const int C, const float state[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; - const int t_0 = b*T*C + h*_N_ + i; - const int t_T_1 = t_0 + (T-1)*C; - const int t_T = t_0 + T*C; + const int t_0 = b * T * C + h * _N_ + i; + const int t_T_1 = t_0 + (T - 1) * C; + const int t_T = t_0 + T * C; float gu = 0; - for (int t = t_0; t < t_T; t += C) - { + for (int t = t_0; t < t_T; t += C) { __syncthreads(); v[i] = float(_v[t]); gy[i] = float(_gy[t]); @@ -96,9 +111,8 @@ __global__ void kernel_backward_111(const int B, const int T, const int C, const const float w = exp(_w[t]); float gr = 0, gu_ = 0; - #pragma unroll - for (int j = 0; j < _N_; j++) - { +#pragma unroll + for (int j = 0; j < _N_; j++) { float& s = state[j]; float x = k * v[j]; @@ -109,10 +123,9 @@ __global__ void kernel_backward_111(const int B, const int T, const int C, const _gr[t] = F(gr); gu += float(_r[t]) * gu_; } - _gu[b*C + h*_N_ + i] = F(gu); + _gu[b * C + h * _N_ + i] = F(gu); - for (int t = t_T_1; t >= t_0; t -= C) - { + for (int t = t_T_1; t >= t_0; t -= C) { __syncthreads(); v[i] = float(_v[t]); gy[i] = float(_gy[t]); @@ -122,20 +135,18 @@ __global__ void kernel_backward_111(const int B, const int T, const int C, const const float w = exp(_w[t]); float gk = 0; - #pragma unroll - for (int j = 0; j < _N_; j++) - { +#pragma unroll + for (int j = 0; j < _N_; j++) { float& s = scccc[j]; float x = rr * gy[j]; - + gk += (u * x + s) * v[j]; s = x + s * w; } _gk[t] = F(gk); } - for (int t = t_T_1; t >= t_0; t -= C) - { + for (int t = t_T_1; t >= t_0; t -= C) { __syncthreads(); r[i] = float(_r[t]); k[i] = float(_k[t]); @@ -145,12 +156,11 @@ __global__ void kernel_backward_111(const int B, const int T, const int C, const const float gyy = float(_gy[t]); float gv = 0; - #pragma unroll - for (int j = 0; j < _N_; j++) - { +#pragma unroll + for (int j = 0; j < _N_; j++) { float& s = sdddd[j]; float x = gyy * r[j]; - + gv += (u_[j] * x + s) * k[j]; s = x + s * w_[j]; } @@ -159,84 +169,102 @@ __global__ void kernel_backward_111(const int B, const int T, const int C, const } template -__global__ void kernel_backward_222(const int B, const int T, const int C, const int H, - const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, - F *__restrict__ const _gw) +__global__ void kernel_backward_222(const int B, + const int T, + const int C, + const int H, + const F* __restrict__ const _r, + const F* __restrict__ const _k, + const F* __restrict__ const _v, + const float* __restrict__ _w, + const F* __restrict__ _u, + const F* __restrict__ const _gy, + F* __restrict__ 
const _gw) { const int b = blockIdx.x / H; const int h = blockIdx.x % H; const int i = threadIdx.x; __shared__ float v[_N_], gy[_N_]; - float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0}; + float saaaa[_N_] = {0}, sbbbb[_T_ - 2] = {0}, scccc[_N_] = {0}; - const int t_0 = b*T*C + h*_N_ + i; + const int t_0 = b * T * C + h * _N_ + i; const int t_1 = t_0 + C; - const int t_2 = t_0 + 2*C; - const int t_T_1 = t_0 + (T-1)*C; + const int t_2 = t_0 + 2 * C; + const int t_T_1 = t_0 + (T - 1) * C; - for (int t = t_T_1; t > t_1; t -= C) - { + for (int t = t_T_1; t > t_1; t -= C) { __syncthreads(); gy[i] = float(_gy[t]); - v[i] = float(_v[t-2*C]); + v[i] = float(_v[t - 2 * C]); __syncthreads(); const float r = float(_r[t]); - const float w = exp(_w[t-C]); + const float w = exp(_w[t - C]); float sum = 0.0f; - #pragma unroll - for (int j = 0; j < _N_; j++) - { +#pragma unroll + for (int j = 0; j < _N_; j++) { float& s = saaaa[j]; float x = r * gy[j]; s = (s + x) * w; sum += s * v[j]; } - sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]); + sbbbb[(t - t_2) / C] = sum * float(_k[t - 2 * C]); } float sss = sbbbb[0]; _gw[t_0] = 0; _gw[t_1] = F(sss * _w[t_1]); - for (int t = t_2; t < t_T_1; t += C) - { + for (int t = t_2; t < t_T_1; t += C) { __syncthreads(); gy[i] = float(_gy[t]); - v[i] = float(_v[t-2*C]); + v[i] = float(_v[t - 2 * C]); __syncthreads(); - const float w = exp(_w[t-C]); - const float k = float(_k[t-2*C]); + const float w = exp(_w[t - C]); + const float k = float(_k[t - 2 * C]); float sum = 0.0f; - #pragma unroll - for (int j = 0; j < _N_; j++) - { +#pragma unroll + for (int j = 0; j < _N_; j++) { float& s = scccc[j]; float x = k * v[j]; s = (s + x) * w; sum += s * gy[j]; } - sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t])); + sss += sbbbb[(t - t_1) / C] - (sum * float(_r[t])); _gw[t] = F(sss * _w[t]); } _gw[t_T_1] = 0; } -void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) +void cuda_forward(int B, int T, int C, int H, bf16* r, bf16* k, bf16* v, float* w, bf16* u, bf16* y) { - assert(H*_N_ == C); - assert(_N_%4 == 0); + assert(H * _N_ == C); + assert(_N_ % 4 == 0); kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); } -void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) +void cuda_backward(int B, + int T, + int C, + int H, + bf16* r, + bf16* k, + bf16* v, + float* w, + bf16* u, + bf16* gy, + bf16* gr, + bf16* gk, + bf16* gv, + bf16* gw, + bf16* gu) { - assert(H*_N_ == C); - assert(_N_%4 == 0); + assert(H * _N_ == C); + assert(_N_ % 4 == 0); kernel_backward_111<<>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gu); kernel_backward_222<<>>(B, T, C, H, r, k, v, w, u, gy, gw); } diff --git a/megatron/model/rwkv/cuda/wkv6_op.cpp b/megatron/model/rwkv/cuda/wkv6_op.cpp index 56d9fbc62..385b47487 100644 --- a/megatron/model/rwkv/cuda/wkv6_op.cpp +++ b/megatron/model/rwkv/cuda/wkv6_op.cpp @@ -2,21 +2,94 @@ #include "ATen/ATen.h" typedef at::BFloat16 bf16; -void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); -void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); +void cuda_forward(int B, + int T, + int C, + int H, + bf16* r, + bf16* k, + bf16* v, + float* w, + bf16* u, + bf16* y); +void cuda_backward(int B, + int T, + int C, + int H, + bf16* r, + bf16* k, + bf16* v, + float* w, + bf16* u, + bf16* gy, + bf16* gr, 
+ bf16* gk, + bf16* gv, + bf16* gw, + bf16* gu); -void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { - cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); +void forward(int64_t B, + int64_t T, + int64_t C, + int64_t H, + torch::Tensor& r, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& w, + torch::Tensor& u, + torch::Tensor& y) +{ + cuda_forward(B, + T, + C, + H, + r.data_ptr(), + k.data_ptr(), + v.data_ptr(), + w.data_ptr(), + u.data_ptr(), + y.data_ptr()); } -void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { - cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); +void backward(int64_t B, + int64_t T, + int64_t C, + int64_t H, + torch::Tensor& r, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& w, + torch::Tensor& u, + torch::Tensor& gy, + torch::Tensor& gr, + torch::Tensor& gk, + torch::Tensor& gv, + torch::Tensor& gw, + torch::Tensor& gu) +{ + cuda_backward(B, + T, + C, + H, + r.data_ptr(), + k.data_ptr(), + v.data_ptr(), + w.data_ptr(), + u.data_ptr(), + gy.data_ptr(), + gr.data_ptr(), + gk.data_ptr(), + gv.data_ptr(), + gw.data_ptr(), + gu.data_ptr()); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ m.def("forward", &forward, "wkv6 forward"); m.def("backward", &backward, "wkv6 backward"); } -TORCH_LIBRARY(wkv6, m) { +TORCH_LIBRARY(wkv6, m) +{ m.def("forward", forward); m.def("backward", backward); -} \ No newline at end of file +} diff --git a/megatron/model/rwkv/rwkv.py b/megatron/model/rwkv/rwkv.py index 3ab148266..39cb9d6fb 100644 --- a/megatron/model/rwkv/rwkv.py +++ b/megatron/model/rwkv/rwkv.py @@ -8,10 +8,12 @@ from torch.nn import functional as F from torch.utils.cpp_extension import load + class WKV(torch.autograd.Function): """ WKV block, using cuda kernel. 
""" + @staticmethod def forward(ctx, B, T, C, H, r, k, v, w, u): with torch.no_grad(): @@ -31,7 +33,12 @@ def forward(ctx, B, T, C, H, r, k, v, w, u): assert u.is_contiguous() ew = (-torch.exp(w.float())).contiguous() ctx.save_for_backward(r, k, v, ew, u) - y = torch.empty((B, T, C), device=r.device, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100) + y = torch.empty( + (B, T, C), + device=r.device, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) wkv_cuda.forward(B, T, C, H, r, k, v, ew, u, y) return y @@ -45,25 +52,58 @@ def backward(ctx, gy): H = ctx.H assert gy.is_contiguous() r, k, v, ew, u = ctx.saved_tensors - gr = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100) - gk = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100) - gv = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100) - gw = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100) - gu = torch.empty((B, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100) + gr = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gk = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gv = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gw = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gu = torch.empty( + (B, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) wkv_cuda.backward(B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu) - gu = torch.sum(gu, 0).view(H, C//H) + gu = torch.sum(gu, 0).view(H, C // H) return (None, None, None, None, gr, gk, gv, gw, gu) + def RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u): return WKV.apply(B, T, C, H, r, k, v, w, u) + # RWKV6 time mix class RWKV_TimeMix(nn.Module): """ Time Mixing Layer - The RWKV substitute for attention. + The RWKV substitute for attention. TODO: fix jit compiling. 
""" + def __init__(self, neox_args, layer_number): super().__init__() self.neox_args = neox_args @@ -80,39 +120,69 @@ def __init__(self, neox_args, layer_number): self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) - self.time_maa_v = nn.Parameter(1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)) - self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)) - self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)) - - TIME_MIX_EXTRA_DIM = 32 # generate TIME_MIX for w,k,v,r,g - self.time_maa_w1 = nn.Parameter(torch.zeros(neox_args.hidden_size, TIME_MIX_EXTRA_DIM*5).uniform_(-1e-4, 1e-4)) - self.time_maa_w2 = nn.Parameter(torch.zeros(5, TIME_MIX_EXTRA_DIM, neox_args.hidden_size).uniform_(-1e-4, 1e-4)) + self.time_maa_v = nn.Parameter( + 1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + ) + self.time_maa_r = nn.Parameter( + 1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0) + ) + self.time_maa_g = nn.Parameter( + 1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0) + ) + + TIME_MIX_EXTRA_DIM = 32 # generate TIME_MIX for w,k,v,r,g + self.time_maa_w1 = nn.Parameter( + torch.zeros(neox_args.hidden_size, TIME_MIX_EXTRA_DIM * 5).uniform_( + -1e-4, 1e-4 + ) + ) + self.time_maa_w2 = nn.Parameter( + torch.zeros(5, TIME_MIX_EXTRA_DIM, neox_args.hidden_size).uniform_( + -1e-4, 1e-4 + ) + ) # fancy time_decay decay_speed = torch.ones(neox_args.dim_att) for n in range(neox_args.dim_att): - decay_speed[n] = -6 + 5 * (n / (neox_args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1) - self.time_decay = nn.Parameter(decay_speed.reshape(1,1,neox_args.dim_att)) + decay_speed[n] = -6 + 5 * (n / (neox_args.dim_att - 1)) ** ( + 0.7 + 1.3 * ratio_0_to_1 + ) + self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, neox_args.dim_att)) TIME_DECAY_EXTRA_DIM = 64 - self.time_decay_w1 = nn.Parameter(torch.zeros(neox_args.hidden_size, TIME_DECAY_EXTRA_DIM).uniform_(-1e-4, 1e-4)) - self.time_decay_w2 = nn.Parameter(torch.zeros(TIME_DECAY_EXTRA_DIM, neox_args.dim_att).uniform_(-1e-4, 1e-4)) + self.time_decay_w1 = nn.Parameter( + torch.zeros(neox_args.hidden_size, TIME_DECAY_EXTRA_DIM).uniform_( + -1e-4, 1e-4 + ) + ) + self.time_decay_w2 = nn.Parameter( + torch.zeros(TIME_DECAY_EXTRA_DIM, neox_args.dim_att).uniform_( + -1e-4, 1e-4 + ) + ) tmp = torch.zeros(neox_args.dim_att) for n in range(neox_args.dim_att): zigzag = ((n + 1) % 3 - 1) * 0.1 tmp[n] = ratio_0_to_1 * (1 - (n / (neox_args.dim_att - 1))) + zigzag - self.time_faaaa = nn.Parameter(tmp.reshape(neox_args.num_attention_heads, neox_args.head_size)) + self.time_faaaa = nn.Parameter( + tmp.reshape(neox_args.num_attention_heads, neox_args.head_size) + ) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - self.receptance = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) + self.receptance = nn.Linear( + neox_args.hidden_size, neox_args.dim_att, bias=False + ) self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False) self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) - self.ln_x = nn.GroupNorm(neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5)*(8**2)) + self.ln_x = nn.GroupNorm( + neox_args.num_attention_heads, neox_args.dim_att, 
eps=(1e-5) * (8**2) + ) def jit_func(self, x): B, T, C = x.size() @@ -120,7 +190,7 @@ def jit_func(self, x): xx = self.time_shift(x) - x xxx = x + xx * self.time_maa_x - xxx = torch.tanh(xxx @ self.time_maa_w1).view(B*T, 5, -1).transpose(0, 1) + xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1) xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1) mw, mk, mv, mr, mg = xxx.unbind(dim=0) @@ -157,10 +227,12 @@ def forward(self, x): return self.jit_func_2(x, g) + class RWKV_ChannelMix(nn.Module): """ Channel Mix layer. The ffn in RWKV """ + def __init__(self, neox_args, layer_number): super().__init__() self.neox_args = neox_args @@ -176,7 +248,9 @@ def __init__(self, neox_args, layer_number): self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) - self.receptance = nn.Linear(neox_args.hidden_size, neox_args.hidden_size, bias=False) + self.receptance = nn.Linear( + neox_args.hidden_size, neox_args.hidden_size, bias=False + ) self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False) def forward(self, x): @@ -194,26 +268,27 @@ class RWKVResidualLayer(nn.Module): """ RWKV layer definition """ + def __init__(self, neox_args, layer_number): super().__init__() self.neox_args = neox_args self.layer_number = layer_number self.fp16 = neox_args.precision == "fp16" self.bf16 = neox_args.precision == "bfloat16" - if not hasattr(neox_args, 'dim_att'): + if not hasattr(neox_args, "dim_att"): neox_args.dim_att = neox_args.hidden_size - if not hasattr(neox_args, 'dim_ffn'): + if not hasattr(neox_args, "dim_ffn"): neox_args.dim_ffn = neox_args.hidden_size * 4 assert neox_args.hidden_size % 32 == 0 assert neox_args.dim_att % 32 == 0 assert neox_args.dim_ffn % 32 == 0 - self.neox_args.head_size = neox_args.dim_att // neox_args.num_attention_heads + self.neox_args.head_size = neox_args.dim_att // neox_args.num_attention_heads self.head_size = self.neox_args.head_size self.num_attention_heads = neox_args.num_attention_heads assert neox_args.dim_att % self.num_attention_heads == 0 if neox_args.attention_dropout > 0: - self.drop0 = nn.Dropout(p = neox_args.attention_dropout) + self.drop0 = nn.Dropout(p=neox_args.attention_dropout) self.ln1 = nn.LayerNorm(neox_args.hidden_size) self.ln2 = nn.LayerNorm(neox_args.hidden_size) @@ -223,21 +298,32 @@ def __init__(self, neox_args, layer_number): self.ffn = RWKV_ChannelMix(neox_args, layer_number) if neox_args.attention_dropout > 0: - self.drop0 = nn.Dropout(p = neox_args.attention_dropout) + self.drop0 = nn.Dropout(p=neox_args.attention_dropout) if neox_args.hidden_dropout > 0: - self.drop1 = nn.Dropout(p = neox_args.hidden_dropout) + self.drop1 = nn.Dropout(p=neox_args.hidden_dropout) if layer_number == 0: global wkv_cuda """ Load cuda kernel at runtime. The kernel uses run time variables to build, ideally it should not. 
""" - wkv_cuda = load(name="wkv6", sources=["megatron/model/rwkv/cuda/wkv6_op.cpp", - f"megatron/model/rwkv/cuda/wkv6_cuda.cu"], - verbose=True, - extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", - "--extra-device-vectorization", - f"-D_N_={self.neox_args.head_size}", f"-D_T_={self.neox_args.seq_length}"]) + wkv_cuda = load( + name="wkv6", + sources=[ + "megatron/model/rwkv/cuda/wkv6_op.cpp", + f"megatron/model/rwkv/cuda/wkv6_cuda.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={self.neox_args.head_size}", + f"-D_T_={self.neox_args.seq_length}", + ], + ) def forward(self, x): neox_args = self.neox_args @@ -249,7 +335,7 @@ def forward(self, x): x = x + self.att(self.ln1(x)) else: x = self.drop0(x + self.att(self.ln1(x))) - + if self.neox_args.hidden_dropout == 0: x = x + self.ffn(self.ln2(x)) else: @@ -257,12 +343,14 @@ def forward(self, x): return x + class RWKVResidualLayerPipe(RWKVResidualLayer): """ RWKV Pipeline Layer """ + def forward(self, args): assert len(args) == 2 - hidden_states, mask = args + hidden_states, mask = args neox_args = self.neox_args return super().forward(hidden_states), mask diff --git a/megatron/neox_arguments/deepspeed_args.py b/megatron/neox_arguments/deepspeed_args.py index a4d5c6a10..270e67f8c 100644 --- a/megatron/neox_arguments/deepspeed_args.py +++ b/megatron/neox_arguments/deepspeed_args.py @@ -100,8 +100,8 @@ class NeoXArgsDeepspeedConfig(NeoXArgsTemplate): bf16: dict = None """ - Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). - + Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). + Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options """ From c0af56368a717268fe9a514108224ada7c743d58 Mon Sep 17 00:00:00 2001 From: github-actions Date: Fri, 19 Apr 2024 01:35:05 +0000 Subject: [PATCH 13/20] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 1eb3d943e..66f657992 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = e20138c + Default = 7550d64 current git hash of repository @@ -1965,8 +1965,8 @@ Args for deepspeed config Default = None - Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). - + Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). 
+ Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options From 1eb5f51750a4b34eb72ff7dd562d906ae7763a52 Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Fri, 3 May 2024 10:39:44 -0400 Subject: [PATCH 14/20] rwkv logging --- megatron/logging.py | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/megatron/logging.py b/megatron/logging.py index 6c9b7915e..8a5346726 100644 --- a/megatron/logging.py +++ b/megatron/logging.py @@ -92,19 +92,34 @@ def get_flops(neox_args, iter_time_s) -> float: hidden_size = neox_args.hidden_size num_layers = neox_args.num_layers ckpt_activations_factor = 4 if neox_args.checkpoint_activations else 3 - flops_per_iteration = ( - 24 - * ckpt_activations_factor - * batch_size - * seq_len - * num_layers - * (hidden_size**2) - * ( - 1.0 - + (seq_len / (6.0 * hidden_size)) - + (vocab_size / (16.0 * num_layers * hidden_size)) + if "rwkv" in neox_args.attention_config: + num_heads = neox_args.num_attention_heads + + flops_per_iteration = ( + batch_size * + seq_len * + ( + 78 * hidden_size * hidden_size * num_layers + + 84 * hidden_size * num_layers + + 16 * hidden_size + + 12 * hidden_size * vocab_size + + 18 * hidden_size * hidden_size * num_layers / num_heads + ) + ) + else: + flops_per_iteration = ( + 24 + * ckpt_activations_factor + * batch_size + * seq_len + * num_layers + * (hidden_size**2) + * ( + 1.0 + + (seq_len / (6.0 * hidden_size)) + + (vocab_size / (16.0 * num_layers * hidden_size)) + ) ) - ) return flops_per_iteration / (iter_time_s * world_size) From a599ac7c48eb2a0f2bd7d73070e5d550d33569db Mon Sep 17 00:00:00 2001 From: github-actions Date: Sat, 4 May 2024 17:15:10 +0000 Subject: [PATCH 15/20] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 3d54edba8..72143f166 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 47c93fb + Default = 1103663 current git hash of repository From 682f7e54be011a3f9f9f173cfabdb28f7b1477b1 Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Sat, 4 May 2024 13:43:52 -0400 Subject: [PATCH 16/20] Add rwkv version dirname, make hdim 3.5x --- megatron/model/rwkv/__init__.py | 1 - megatron/model/rwkv/v6/__init__.py | 1 + megatron/model/rwkv/{ => v6}/cuda/wkv6_cuda.cu | 0 megatron/model/rwkv/{ => v6}/cuda/wkv6_op.cpp | 0 megatron/model/rwkv/{ => v6}/rwkv.py | 7 ++++--- 5 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 megatron/model/rwkv/v6/__init__.py rename megatron/model/rwkv/{ => v6}/cuda/wkv6_cuda.cu (100%) rename megatron/model/rwkv/{ => v6}/cuda/wkv6_op.cpp (100%) rename megatron/model/rwkv/{ => v6}/rwkv.py (97%) diff --git a/megatron/model/rwkv/__init__.py b/megatron/model/rwkv/__init__.py index c0d8d4ba1..e69de29bb 100644 --- a/megatron/model/rwkv/__init__.py +++ b/megatron/model/rwkv/__init__.py @@ -1 +0,0 @@ -from .rwkv import RWKVResidualLayerPipe, RWKVResidualLayer diff --git a/megatron/model/rwkv/v6/__init__.py b/megatron/model/rwkv/v6/__init__.py new file mode 100644 index 000000000..c0d8d4ba1 --- /dev/null +++ b/megatron/model/rwkv/v6/__init__.py @@ -0,0 +1 @@ +from .rwkv import RWKVResidualLayerPipe, RWKVResidualLayer diff --git a/megatron/model/rwkv/cuda/wkv6_cuda.cu b/megatron/model/rwkv/v6/cuda/wkv6_cuda.cu similarity index 100% rename from 
megatron/model/rwkv/cuda/wkv6_cuda.cu rename to megatron/model/rwkv/v6/cuda/wkv6_cuda.cu diff --git a/megatron/model/rwkv/cuda/wkv6_op.cpp b/megatron/model/rwkv/v6/cuda/wkv6_op.cpp similarity index 100% rename from megatron/model/rwkv/cuda/wkv6_op.cpp rename to megatron/model/rwkv/v6/cuda/wkv6_op.cpp diff --git a/megatron/model/rwkv/rwkv.py b/megatron/model/rwkv/v6/rwkv.py similarity index 97% rename from megatron/model/rwkv/rwkv.py rename to megatron/model/rwkv/v6/rwkv.py index 39cb9d6fb..5d4e0d144 100644 --- a/megatron/model/rwkv/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -278,7 +278,8 @@ def __init__(self, neox_args, layer_number): if not hasattr(neox_args, "dim_att"): neox_args.dim_att = neox_args.hidden_size if not hasattr(neox_args, "dim_ffn"): - neox_args.dim_ffn = neox_args.hidden_size * 4 + # Make hidden size 3.5x. Round to nearest multiple of 32 until we add hdim rounding logic + neox_args.dim_ffn = int((neox_args.hidden_size * 3.5) // 32 * 32) assert neox_args.hidden_size % 32 == 0 assert neox_args.dim_att % 32 == 0 assert neox_args.dim_ffn % 32 == 0 @@ -310,8 +311,8 @@ def __init__(self, neox_args, layer_number): wkv_cuda = load( name="wkv6", sources=[ - "megatron/model/rwkv/cuda/wkv6_op.cpp", - f"megatron/model/rwkv/cuda/wkv6_cuda.cu", + "megatron/model/rwkv/v6/cuda/wkv6_op.cpp", + f"megatron/model/rwkv/v6/cuda/wkv6_cuda.cu", ], verbose=True, extra_cuda_cflags=[ From 921c41a5daeb5a74760c3dde270ecf491efaa189 Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Sat, 4 May 2024 13:44:27 -0400 Subject: [PATCH 17/20] pre-commit --- megatron/logging.py | 16 ++++++++-------- tests/cpu_tests/action.yml | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/megatron/logging.py b/megatron/logging.py index 8a5346726..247aeb1b5 100644 --- a/megatron/logging.py +++ b/megatron/logging.py @@ -96,14 +96,14 @@ def get_flops(neox_args, iter_time_s) -> float: num_heads = neox_args.num_attention_heads flops_per_iteration = ( - batch_size * - seq_len * - ( - 78 * hidden_size * hidden_size * num_layers + - 84 * hidden_size * num_layers + - 16 * hidden_size + - 12 * hidden_size * vocab_size + - 18 * hidden_size * hidden_size * num_layers / num_heads + batch_size + * seq_len + * ( + 78 * hidden_size * hidden_size * num_layers + + 84 * hidden_size * num_layers + + 16 * hidden_size + + 12 * hidden_size * vocab_size + + 18 * hidden_size * hidden_size * num_layers / num_heads ) ) else: diff --git a/tests/cpu_tests/action.yml b/tests/cpu_tests/action.yml index a7847d1ec..f8180605f 100644 --- a/tests/cpu_tests/action.yml +++ b/tests/cpu_tests/action.yml @@ -5,7 +5,7 @@ inputs: required: true type: string runs: - using: composite + using: composite steps: - uses: actions/checkout@v4 with: From 8f60a43192b472eb0dd6e898e17e91a7c989f2ee Mon Sep 17 00:00:00 2001 From: github-actions Date: Sat, 4 May 2024 17:44:55 +0000 Subject: [PATCH 18/20] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 72143f166..7dedf8a10 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 1103663 + Default = 921c41a current git hash of repository From 6fb840e9ecdda4ba69034722d4a6cd4a040834cf Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Sun, 5 May 2024 16:00:05 -0400 Subject: [PATCH 19/20] fix bug and set batch size to 32 --- configs/rwkv/170M.yml | 4 ++-- megatron/model/gpt2_model.py | 2 +- 
2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/rwkv/170M.yml b/configs/rwkv/170M.yml index f31f2613b..11311f441 100644 --- a/configs/rwkv/170M.yml +++ b/configs/rwkv/170M.yml @@ -1,6 +1,6 @@ { # Parallelism is not yet supported for rwkv - "pipe_parallel_size": 0, + "pipe_parallel_size": 1, "model_parallel_size": 1, "num_layers": 12, @@ -12,7 +12,7 @@ "output_layer_parallelism": "column", "norm": "rmsnorm", "rms_norm_epsilon": 1.0e-5, - "train_micro_batch_size_per_gpu": 1, + "train_micro_batch_size_per_gpu": 32, "attention_config": [[["rwkv"], 12]], diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index 89f43c352..9e643874a 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -37,7 +37,7 @@ ParallelLinear, ) from megatron.model.gmlp import GMLPBlock -from megatron.model.rwkv import RWKVResidualLayerPipe +from megatron.model.rwkv.v6 import RWKVResidualLayerPipe from megatron.model.mamba import ParallelMambaResidualLayerPipe from megatron.model.word_embeddings import EmbeddingPipe, SoftEmbedding From dd0138e8cf14024879beab599b8ca7efdc56445f Mon Sep 17 00:00:00 2001 From: github-actions Date: Sun, 5 May 2024 20:00:35 +0000 Subject: [PATCH 20/20] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 7dedf8a10..c8e1492ae 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 921c41a + Default = 6fb840e current git hash of repository
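
Note: the RWKV-specific branch of get_flops() added to megatron/logging.py in this series can be sanity-checked outside of training. The sketch below simply mirrors that formula; hidden_size, num_layers, num_heads, and seq_len follow configs/rwkv/170M.yml from these patches, while batch_size=32 (the per-GPU micro-batch size in that config) and vocab_size=50304 are illustrative assumptions, not values taken from the patches.

# Standalone sanity check of the RWKV FLOPs estimate (minimal sketch, not part of the patches).
def rwkv_flops_per_iteration(batch_size, seq_len, hidden_size, num_layers, vocab_size, num_heads):
    # Mirrors the "rwkv" branch of get_flops() in megatron/logging.py.
    return (
        batch_size
        * seq_len
        * (
            78 * hidden_size * hidden_size * num_layers
            + 84 * hidden_size * num_layers
            + 16 * hidden_size
            + 12 * hidden_size * vocab_size
            + 18 * hidden_size * hidden_size * num_layers / num_heads
        )
    )

if __name__ == "__main__":
    # Dimensions roughly follow configs/rwkv/170M.yml; vocab_size is assumed.
    flops = rwkv_flops_per_iteration(
        batch_size=32, seq_len=512, hidden_size=768, num_layers=12, vocab_size=50304, num_heads=12
    )
    print(f"~{flops / 1e12:.2f} TFLOPs per iteration")

The derived RWKV dimensions can be checked the same way. The helper below reproduces the defaulting and divisibility rules from RWKVResidualLayer.__init__ in this series (dim_att defaults to hidden_size, dim_ffn is 3.5x hidden_size rounded down to a multiple of 32, head_size is dim_att / num_attention_heads); the function name itself is ours, not part of the patches.

def rwkv_derived_dims(hidden_size, num_attention_heads, dim_att=None, dim_ffn=None):
    # Mirrors the defaults and asserts in RWKVResidualLayer.__init__.
    dim_att = hidden_size if dim_att is None else dim_att
    if dim_ffn is None:
        # 3.5x hidden size, rounded down to the nearest multiple of 32.
        dim_ffn = int((hidden_size * 3.5) // 32 * 32)
    assert hidden_size % 32 == 0
    assert dim_att % 32 == 0
    assert dim_ffn % 32 == 0
    assert dim_att % num_attention_heads == 0
    return dim_att, dim_ffn, dim_att // num_attention_heads

print(rwkv_derived_dims(768, 12))  # -> (768, 2688, 64): head_size 64, as noted in 170M.yml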