From ad2038712eb8e75ff6719df28b2d936900941274 Mon Sep 17 00:00:00 2001
From: Joan Puigcerver
Date: Mon, 20 Nov 2023 07:11:32 -0800
Subject: [PATCH] Update EEE paper fine-tuning config after internal updates.

PiperOrigin-RevId: 584017588
---
 .../eee_s32_last2_ilsvrc2012_ft_cifar100.py | 20 +++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/vmoe/configs/eee_paper/eee_s32_last2_ilsvrc2012_ft_cifar100.py b/vmoe/configs/eee_paper/eee_s32_last2_ilsvrc2012_ft_cifar100.py
index 9a884b1..6d7179b 100644
--- a/vmoe/configs/eee_paper/eee_s32_last2_ilsvrc2012_ft_cifar100.py
+++ b/vmoe/configs/eee_paper/eee_s32_last2_ilsvrc2012_ft_cifar100.py
@@ -86,10 +86,23 @@ def get_config():
   config.description = 'EEE-S/32, K=1, M=2, Last 2'
   config.train_steps = 2_000
   config.initialization = ml_collections.ConfigDict({
-      'name': 'initialize_from_vmoe_release',
+      'name': 'initialize_from_vmoe',
       'prefix': 'gs://vmoe_checkpoints/eee_s32_last2_ilsvrc2012',
-      'keep': ['head'],
-      'reshape': ['Moe/Router/dense'],
+      'rules': [
+          ('head', ''),  # Do not restore the head params.
+          # We pre-trained on 224px and are fine-tuning on 128px, so
+          # resize the positional embeddings.
+          ('^(.*/pos_embedding)$', r'params/\1', 'vit_zoom'),
+          # Reshape router params to the appropriate shape for EEE.
+          ('^(.*/Moe/Router/dense/.*)$', r'params/\1', 'reshape'),
+          # Restore the rest of the parameters without any transformation.
+          ('^(.*)$', r'params/\1'),
+      ],
+      # Several arrays in the new train state are not initialized from the
+      # checkpoint, so do not raise an exception for them.
+      'raise_if_target_unmatched': False,
+      # Partition MoE parameters when reading from the checkpoint.
+      'axis_resources_regexes': [('Moe/Mlp/.*', ('expert',))],
   })
   config.model = ml_collections.ConfigDict({
       'name': 'VisionTransformerMoeEnsemble',
@@ -164,7 +177,6 @@ def get_config():
   config.save_checkpoint = ml_collections.ConfigDict()
   config.save_checkpoint.every_steps = 1_000
   config.save_checkpoint.keep_last = 1
-  config.save_checkpoint.num_shards = 32  # Target number of checkpoint shards.
   config.save_checkpoint.wait_seconds = 300
   # Report training progress every 100 steps.
   config.report_progress = ml_collections.ConfigDict()
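
Note (not part of the patch): the new 'rules' entries read as ordered
(pattern, replacement[, transform]) tuples, where an empty replacement means
"do not restore this parameter" and the first matching rule wins. The helper
below is a minimal, self-contained sketch of that matching logic under those
assumptions; the function name and the flat '/'-joined parameter names are
hypothetical and are not the vmoe API.

import re

# Hypothetical rule list mirroring the config above. Each rule is
# (pattern, replacement) or (pattern, replacement, transform); an empty
# replacement means "skip this parameter". First match wins.
RULES = [
    ('head', ''),
    (r'^(.*/pos_embedding)$', r'params/\1', 'vit_zoom'),
    (r'^(.*/Moe/Router/dense/.*)$', r'params/\1', 'reshape'),
    (r'^(.*)$', r'params/\1'),
]


def map_target_to_checkpoint(name, rules=RULES):
  """Returns (checkpoint_key, transform) for a target param, or None."""
  for pattern, replacement, *rest in rules:
    if re.search(pattern, name):
      if not replacement:
        return None  # Matched a skip rule, e.g. the newly trained head.
      return re.sub(pattern, replacement, name), (rest[0] if rest else None)
  return None


# Hypothetical flat parameter names: the head is skipped, the positional
# embedding gets the 'vit_zoom' transform, router params get 'reshape',
# and everything else is restored verbatim under the 'params/' prefix.
for name in ('head/kernel',
             'Encoder/pos_embedding',
             'Encoder/encoderblock_7/Moe/Router/dense/kernel',
             'Encoder/encoderblock_0/Mlp/Dense_0/kernel'):
  print(name, '->', map_target_to_checkpoint(name))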
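
Note (not part of the patch): 'axis_resources_regexes' pairs a regex with a
partitioning tuple, so here every MoE MLP parameter is sharded along the
'expert' mesh axis while being read from the checkpoint. A minimal sketch of
that lookup, again assuming flat '/'-joined parameter names (the names and
helper below are hypothetical):

import re

AXIS_RESOURCES_REGEXES = [('Moe/Mlp/.*', ('expert',))]


def axis_resources_for(name, regexes=AXIS_RESOURCES_REGEXES):
  """Returns the partitioning tuple of the first matching regex, else None."""
  for pattern, spec in regexes:
    if re.search(pattern, name):
      return spec
  return None  # No match: default (replicated) partitioning.


print(axis_resources_for('Encoder/encoderblock_7/Moe/Mlp/Dense_0/kernel'))
# ('expert',) -> the leading expert dimension is sharded across the mesh.
print(axis_resources_for('Encoder/encoderblock_0/Mlp/Dense_0/kernel'))
# None -> dense (non-expert) parameters are left unpartitioned here.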