Issue 452: Added the assertion for consistency check and evaluation frequency check #788

Closed · wants to merge 4 commits
16 changes: 8 additions & 8 deletions algorithmic_efficiency/random_utils.py
@@ -18,30 +18,30 @@

 # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an
 # unsigned int), while RandomState.randint only accepts and returns signed ints.
-MAX_INT32 = 2**31
-MIN_INT32 = -MAX_INT32
+MAX_UINT32 = 2**32 - 1
+MIN_UINT32 = 0

 SeedType = Union[int, list, np.ndarray]


 def _signed_to_unsigned(seed: SeedType) -> SeedType:
   if isinstance(seed, int):
-    return seed % 2**32
+    return seed % MAX_UINT32
   if isinstance(seed, list):
-    return [s % 2**32 for s in seed]
+    return [s % MAX_UINT32 for s in seed]
   if isinstance(seed, np.ndarray):
-    return np.array([s % 2**32 for s in seed.tolist()])
+    return np.array([s % MAX_UINT32 for s in seed.tolist()])


 def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
   rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
-  new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
+  new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
   return [new_seed, data]


 def _split(seed: SeedType, num: int = 2) -> SeedType:
   rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
-  return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
+  return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])


 def _PRNGKey(seed: SeedType) -> SeedType:  # pylint: disable=invalid-name
@@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType:
 def PRNGKey(seed: SeedType) -> SeedType:  # pylint: disable=invalid-name
   if FLAGS.framework == 'jax':
     _check_jax_install()
-    return jax_rng.PRNGKey(seed)
+    return jax_rng.key(seed)
   return _PRNGKey(seed)
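
As an aside on the constants above, here is a minimal standalone sketch (not part of the diff) of the seed-range handling that the module comment describes. Note that `x % (2**32)` keeps values in [0, 2**32 - 1], whereas `x % (2**32 - 1)` keeps them in [0, 2**32 - 2], so the two reductions are not identical.

# Sketch only: illustrates why signed seeds are folded into the unsigned range
# that np.random.RandomState expects ([0, 2**32 - 1]).
import numpy as np

MAX_UINT32 = 2**32 - 1

signed_seed = -12345
unsigned_seed = signed_seed % (MAX_UINT32 + 1)  # i.e. modulo 2**32
assert 0 <= unsigned_seed <= MAX_UINT32

rng = np.random.RandomState(seed=unsigned_seed)
# Draw unsigned 32-bit values the same way _split does above.
print(rng.randint(0, MAX_UINT32, dtype=np.uint32, size=[2, 2]))
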
25 changes: 20 additions & 5 deletions submission_runner.py
@@ -12,6 +12,7 @@
     --num_tuning_trials=3 \
     --experiment_dir=/home/znado/experiment_dir \
     --experiment_name=baseline
+    --skip_eval=True/False
 """

import datetime
@@ -88,6 +89,10 @@
 flags.DEFINE_string('librispeech_tokenizer_vocab_path',
                     '',
                     'Location to librispeech tokenizer.')
+flags.DEFINE_boolean(
+    'skip_eval',
+    True,
+    help='Whether to skip eval on the datasets.')

 flags.DEFINE_enum(
     'framework',
@@ -327,7 +332,10 @@ def train_once(
   train_state['last_step_end_time'] = global_start_time

   logging.info('Starting training loop.')
-  goals_reached = (
+  if FLAGS.skip_eval:
+    goals_reached = train_state['validation_goal_reached']
+  else:
+    goals_reached = (
       train_state['validation_goal_reached'] and
       train_state['test_goal_reached'])
   while train_state['is_time_remaining'] and \
@@ -402,9 +410,12 @@ def train_once(
         train_state['test_goal_reached'] = (
             workload.has_reached_test_target(latest_eval_result) or
             train_state['test_goal_reached'])
-        goals_reached = (
+        if FLAGS.skip_eval:
+          goals_reached = train_state['validation_goal_reached']
+        else:
+          goals_reached = (
             train_state['validation_goal_reached'] and
             train_state['test_goal_reached'])
         # Save last eval time.
         eval_end_time = get_time()
         train_state['last_eval_time'] = eval_end_time
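
The same goal logic now appears twice in train_once. A minimal sketch of what both branches compute (compute_goals_reached is a hypothetical helper name, not part of this PR):

# Sketch only: mirrors the goals_reached logic above. With skip_eval, a run is
# considered done once the validation target is hit; otherwise both the
# validation and test targets must be reached.
def compute_goals_reached(train_state, skip_eval):
  if skip_eval:
    return train_state['validation_goal_reached']
  return (train_state['validation_goal_reached'] and
          train_state['test_goal_reached'])
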
@@ -487,7 +498,11 @@ def train_once(
       preemption_count=preemption_count,
       checkpoint_dir=log_dir,
       save_intermediate_checkpoints=FLAGS.save_intermediate_checkpoints)
-
+  assert abs(metrics['eval_results'][-1][1]['total_duration'] -
+             (train_state['accumulated_submission_time'] +
+              train_state['accumulated_logging_time'] +
+              train_state['accumulated_eval_time'])) <= 10
+  assert int(train_state['accumulated_submission_time'] //
+             workload.eval_period_time_sec) <= len(metrics['eval_results']) + 2
   return train_state['accumulated_submission_time'], metrics
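
The two assertions above are the consistency check and evaluation frequency check named in the PR title. A standalone sketch of the same checks, with the tolerance and slack values used above (check_run_consistency is a hypothetical helper name, not part of the diff):

# Sketch only: restates the two post-training checks added above.
def check_run_consistency(train_state, metrics, eval_period_time_sec,
                          tolerance_sec=10, eval_count_slack=2):
  last_eval = metrics['eval_results'][-1][1]
  accounted_time = (train_state['accumulated_submission_time'] +
                    train_state['accumulated_logging_time'] +
                    train_state['accumulated_eval_time'])
  # The logged total wallclock duration should be close to the sum of its parts.
  assert abs(last_eval['total_duration'] - accounted_time) <= tolerance_sec
  # The number of recorded evals should be roughly what the eval period allows.
  expected_evals = int(
      train_state['accumulated_submission_time'] // eval_period_time_sec)
  assert expected_evals <= len(metrics['eval_results']) + eval_count_slack
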


126 changes: 126 additions & 0 deletions tests/test_evals_time.py
@@ -0,0 +1,126 @@
import copy
import os
import sys
from collections import namedtuple

from absl import flags
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized

from algorithmic_efficiency import random_utils as prng
from algorithmic_efficiency.profiler import PassThroughProfiler
from algorithmic_efficiency.workloads import workloads
import submission_runner
import reference_algorithms.development_algorithms.mnist.mnist_pytorch.submission as submission_pytorch
import reference_algorithms.development_algorithms.mnist.mnist_jax.submission as submission_jax

try:
  import jax.random as jax_rng
except (ImportError, ModuleNotFoundError):
  logging.warning(
      'Could not import jax.random for the submission runner, falling back to '
      'numpy random_utils.')
  jax_rng = None

FLAGS = flags.FLAGS
FLAGS(sys.argv)

class Hyperparameters:

  def __init__(self):
    self.learning_rate = 0.0005
    self.one_minus_beta_1 = 0.05
    self.beta2 = 0.999
    self.weight_decay = 0.01
    self.epsilon = 1e-25
    self.label_smoothing = 0.1
    self.dropout_rate = 0.1

class CheckTime(parameterized.TestCase):
  """Checks that submission_time + eval_time + logging_time ~= total wallclock time."""
  rng_seed = 0

  @parameterized.named_parameters(
      dict(
          testcase_name='mnist_pytorch',
          framework='pytorch',
          init_optimizer_state=submission_pytorch.init_optimizer_state,
          update_params=submission_pytorch.update_params,
          data_selection=submission_pytorch.data_selection,
          rng=prng.PRNGKey(rng_seed)),
      dict(
          testcase_name='mnist_jax',
          framework='jax',
          init_optimizer_state=submission_jax.init_optimizer_state,
          update_params=submission_jax.update_params,
          data_selection=submission_jax.data_selection,
          rng=jax_rng.PRNGKey(rng_seed) if jax_rng else None))
  def test_train_once_time_consistency(self, framework, init_optimizer_state,
                                       update_params, data_selection, rng):
    """Checks the consistency of the timing metrics reported by train_once."""
    workload_metadata = copy.deepcopy(workloads.WORKLOADS["mnist"])
    workload_metadata['workload_path'] = os.path.join(
        workloads.BASE_WORKLOADS_DIR,
        workload_metadata['workload_path'] + '_' + framework,
        'workload.py')
    workload = workloads.import_workload(
        workload_path=workload_metadata['workload_path'],
        workload_class_name=workload_metadata['workload_class_name'],
        workload_init_kwargs={})

    HPARAMS = {
        "dropout_rate": 0.1,
        "learning_rate": 0.0017486387539278373,
        "one_minus_beta_1": 0.06733926164,
        "beta2": 0.9955159689799007,
        "weight_decay": 0.08121616522670176,
        "warmup_factor": 0.02,
        "epsilon": 1e-25
    }
    Hp = namedtuple("Hp", ["dropout_rate", "learning_rate", "one_minus_beta_1",
                           "weight_decay", "beta2", "warmup_factor", "epsilon"])
    hp1 = Hp(**HPARAMS)


    accumulated_submission_time, metrics = submission_runner.train_once(
        workload=workload,
        workload_name="mnist",
        global_batch_size=32,
        global_eval_batch_size=256,
        data_dir='~/tensorflow_datasets',  # default tensorflow_datasets location
        imagenet_v2_data_dir=None,
        hyperparameters=hp1,
        init_optimizer_state=init_optimizer_state,
        update_params=update_params,
        data_selection=data_selection,
        rng=rng,
        rng_seed=0,
        profiler=PassThroughProfiler(),
        max_global_steps=500)


    # Check that the total wallclock time roughly equals
    # submission_time + eval_time + logging_time.
    last_eval = metrics['eval_results'][-1][1]
    time_discrepancy = (last_eval['total_duration'] -
                        (accumulated_submission_time +
                         last_eval['accumulated_logging_time'] +
                         last_eval['accumulated_eval_time']))

    # Allow a tolerance (in seconds) for overhead not captured by the timers.
    tolerance = 10
    self.assertAlmostEqual(
        time_discrepancy,
        0,
        delta=tolerance,
        msg='Total wallclock time does not match the sum of submission, '
            'eval, and logging times.')

    # Check that the expected number of evaluations occurred.
    num_evals = len(metrics['eval_results'])
    expected_evals = int(accumulated_submission_time //
                         workload.eval_period_time_sec)
    self.assertLessEqual(
        expected_evals,
        num_evals + 2,
        msg=f'Expected roughly {expected_evals} evaluations given the eval '
            f'period, but only {num_evals} were recorded.')

if __name__ == '__main__':
  absltest.main()
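
Assuming the MNIST workload data is available under the default data_dir, the test above can presumably be run from the repository root with `python3 tests/test_evals_time.py` (it parses flags via FLAGS(sys.argv) and calls absltest.main()).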