From 54b1e6a06be06ee4b3e62d4445257e8d19bb8239 Mon Sep 17 00:00:00 2001
From: Scenic Authors
Date: Mon, 6 May 2024 23:55:53 -0700
Subject: [PATCH] Only compute top-5 metrics for datasets with more than 5
 classes.

With 5 or fewer classes the true label is always among the top-5
predictions, so the `*_accuracy_top_5` metrics are uninformative; they are
now only added when `num_classes > 5`, and the basic metric set is used
otherwise. Also moves the `flax_return_frozendict` config update from
import time in model_utils.py into the trainer's train() function, so
importing the module no longer mutates global flax state.

PiperOrigin-RevId: 631311997
---
 scenic/projects/vivit/model.py       | 17 ++++++++++-------
 scenic/projects/vivit/model_utils.py |  1 -
 scenic/projects/vivit/trainer.py     |  9 +++++----
 3 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/scenic/projects/vivit/model.py b/scenic/projects/vivit/model.py
index b1a8ef7f1..fffb9e554 100644
--- a/scenic/projects/vivit/model.py
+++ b/scenic/projects/vivit/model.py
@@ -947,18 +947,21 @@ def classification_metrics_function(logits, batch, metrics, class_splits,
           (model_utils.joint_accuracy(logits, one_hot_targets, class_splits,
                                       weights),
            base_model_utils.num_examples(logits, one_hot_targets, weights)))
-      pairwise_top_five = base_model_utils.psum_metric_normalizer(
-          (model_utils.joint_top_k(
-              logits, one_hot_targets, class_splits, k=5, weights=weights),
-           base_model_utils.num_examples(logits, one_hot_targets, weights)))
       eval_name = f'{split_names[0]}-{split_names[1]}'
       evaluated_metrics[f'{eval_name}_accuracy'] = pairwise_acc
-      evaluated_metrics[f'{eval_name}_accuracy_top_5'] = pairwise_top_five
+      if self.dataset_meta_data.get('num_classes', -1) > 5:
+        pairwise_top_five = base_model_utils.psum_metric_normalizer(
+            (model_utils.joint_top_k(
+                logits, one_hot_targets, class_splits, k=5, weights=weights),
+             base_model_utils.num_examples(logits, one_hot_targets, weights)))
+        evaluated_metrics[f'{eval_name}_accuracy_top_5'] = pairwise_top_five
       return evaluated_metrics
-
+    metrics = ViViT_CLASSIFICATION_METRICS
+    if self.dataset_meta_data.get('num_classes', -1) <= 5:
+      metrics = ViViT_CLASSIFICATION_METRICS_BASIC
     return functools.partial(
         classification_metrics_function,
-        metrics=ViViT_CLASSIFICATION_METRICS,
+        metrics=metrics,
         class_splits=self.class_splits,
         split_names=self.split_names)
 
diff --git a/scenic/projects/vivit/model_utils.py b/scenic/projects/vivit/model_utils.py
index 5b4f1174e..f29a9afbf 100644
--- a/scenic/projects/vivit/model_utils.py
+++ b/scenic/projects/vivit/model_utils.py
@@ -27,7 +27,6 @@
 from scenic.common_lib import debug_utils
 from scenic.model_lib.base_models import model_utils as base_model_utils
 import scipy
-flax.config.update('flax_return_frozendict', True)
 
 
 def reshape_to_1d_factorized(x: jnp.ndarray, axis: int):

diff --git a/scenic/projects/vivit/trainer.py b/scenic/projects/vivit/trainer.py
index a1e31d760..f02d95644 100644
--- a/scenic/projects/vivit/trainer.py
+++ b/scenic/projects/vivit/trainer.py
@@ -21,6 +21,7 @@
 from absl import logging
 from clu import metric_writers
 from clu import periodic_actions
+import flax
 from flax import jax_utils
 import jax
 import jax.numpy as jnp
@@ -66,6 +67,7 @@ def train(
     and eval_summary which are dict of metrics. These outputs are used for
     regression testing.
   """
+  flax.config.update('flax_return_frozendict', True)
   lead_host = jax.process_index() == 0
   # Build the loss_fn, metrics, and flax_model.
   model = model_cls(config, dataset.meta_data)
@@ -110,8 +112,9 @@ def train(
     restored_train_state = pretrain_utils.restore_pretrained_checkpoint(
         init_checkpoint_path, train_state, assert_exist=True)
   elif checkpoint_format == 'big_vision':
-    restored_train_state = pretrain_utils.convert_big_vision_to_scenic_checkpoint(
-        init_checkpoint_path, train_state)
+    restored_train_state = (
+        pretrain_utils.convert_big_vision_to_scenic_checkpoint(
+            init_checkpoint_path, train_state))
     # Config dict in big_vision is not the same format as scenic.
     # Therefore, make sure config match the config of the loaded model!
     restored_model_cfg = copy.deepcopy(config)
@@ -132,7 +135,6 @@ def train(
   # Replicate the optimzier, state, and rng.
   train_state = jax_utils.replicate(train_state)
   del params  # Do not keep a copy of the initial params.
-
   # Calculate the total number of training steps.
   total_steps, steps_per_epoch = train_utils.get_num_training_steps(
       config, dataset.meta_data)
@@ -241,7 +243,6 @@ def train(
       do_memory_defrag = True
   except RuntimeError:
     logging.warn('Memory defragmentation not possible, use the tfrt runtime')
-
  for step in range(start_step + 1, total_steps + 1):
     with jax.profiler.StepTraceAnnotation('train', step_num=step):
       train_batch = next(dataset.train_iter)
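
Note on the new `num_classes > 5` guard: top-k accuracy is only meaningful
when k is smaller than the number of classes. With 5 or fewer classes every
true label is among the top 5 predictions by construction, so the metric
pins at 1.0 regardless of model quality. A minimal standalone sketch of
that arithmetic follows; the `top_k_accuracy` helper is illustrative only
and is not scenic's `model_utils.joint_top_k`:

import jax.numpy as jnp


def top_k_accuracy(logits, labels, k):
  """Fraction of examples whose true label is among the k largest logits."""
  top_k_classes = jnp.argsort(logits, axis=-1)[:, -k:]    # (batch, k)
  hits = jnp.any(top_k_classes == labels[:, None], axis=-1)
  return jnp.mean(hits)


logits = jnp.array([[0.1, 0.9, 0.0, 0.3, 0.2],   # two examples, 5 classes
                    [0.5, 0.1, 0.2, 0.1, 0.1]])
labels = jnp.array([3, 4])                       # neither label is the argmax
print(top_k_accuracy(logits, labels, k=1))       # 0.0 -- top-1 is informative
print(top_k_accuracy(logits, labels, k=5))       # 1.0 for any 5-class logits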