From 3032cf4fd302ae57b36f99080562e5fa7777d45f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 5 Jul 2022 08:11:28 -0700 Subject: [PATCH] default normalization to rmsnorm --- retro_pytorch/retro_pytorch.py | 2 ++ setup.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/retro_pytorch/retro_pytorch.py b/retro_pytorch/retro_pytorch.py index 344c2dc..3fd0eba 100644 --- a/retro_pytorch/retro_pytorch.py +++ b/retro_pytorch/retro_pytorch.py @@ -492,6 +492,8 @@ def __init__( # for deepnet, residual scales # follow equation in Figure 2. in https://arxiv.org/abs/2203.00555 + norm_klass = default(norm_klass, RMSNorm) + if use_deepnet: enc_scale_residual = default(enc_scale_residual, 0.81 * ((enc_depth ** 4) * dec_depth) ** .0625) dec_scale_residual = default(dec_scale_residual, (3 * dec_depth) ** 0.25) diff --git a/setup.py b/setup.py index 87eaecc..63edf1b 100644 --- a/setup.py +++ b/setup.py @@ -3,9 +3,10 @@ setup( name = 'retro-pytorch', packages = find_packages(exclude=[]), - version = '0.3.7', + version = '0.3.8', license='MIT', description = 'RETRO - Retrieval Enhanced Transformer - Pytorch', + long_description_content_type = 'text/markdown', author = 'Phil Wang', author_email = 'lucidrains@gmail.com', url = 'https://github.com/lucidrains/RETRO-pytorch',