Commit

Add option to use tf_data_service_config for validation
PiperOrigin-RevId: 645350562
Johannes Gasteiger authored and tensorflower-gardener committed Jun 21, 2024
1 parent 00f9191 commit d1167ca
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion tensorflow_gnn/runner/orchestration.py
@@ -379,6 +379,7 @@ def run(*,
     train_padding: Optional[GraphTensorPadding] = None,
     valid_padding: Optional[GraphTensorPadding] = None,
     tf_data_service_config: Optional[TFDataServiceConfig] = None,
+    use_data_service_for_validation: bool = False,
     steps_per_execution: Optional[int] = None,
     run_eagerly: bool = False):
   """Runs training (and validation) of a model on task(s) with the given data.
@@ -474,6 +475,11 @@ def run(*,
       runtime reducing input bottlenecks for model training. Particularly for
       training on accelerators consider enabling it. For more info please see:
       https://www.tensorflow.org/api_docs/python/tf/data/experimental/service.
+    use_data_service_for_validation: Whether to use tf.data service for
+      validation, in addition to training. Use with caution! Many ShardingPolicy
+      values do not visit every sample exactly once, which is critical for
+      validation. Increasing `validation_freq` of the trainer is another way to
+      reduce the fraction of time spent on validation.
     steps_per_execution: The number of batches to run during each training
       iteration. If not set, for TPU strategy default to 100 and to `None`
       otherwise.
@@ -561,11 +567,19 @@ def apply_fn(
       tf_data_service_config)
 
   if validate:
+    # TFTrainer doesn't support using a different tf_data_service_config for
+    # validation than for training (see b/346691297#comment5). We therefore use
+    # the same config for both or None.
+    valid_tf_data_service_config = (
+        tf_data_service_config if use_data_service_for_validation else None
+    )
     valid_ds_provider = _WrappedDatasetProvider(
         valid_apply_fn,
         valid_ds_provider,
         drop_remainder,
-        global_batch_size)
+        global_batch_size,
+        valid_tf_data_service_config,
+    )
 
   def adapted_model_fn():
     xs, *_ = preprocess_model.output
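For readers unfamiliar with the runner API, the sketch below shows how a caller might opt validation into the tf.data service via the new flag. It is an illustration, not part of this commit: the wrapper function and its `validate_with_data_service` keyword are invented here, it assumes `run` and `TFDataServiceConfig` are re-exported from `tensorflow_gnn.runner`, and all other `runner.run` arguments are simply passed through.

from typing import Optional

from tensorflow_gnn import runner


def run_with_optional_validation_service(
    *,
    tf_data_service_config: Optional[runner.TFDataServiceConfig] = None,
    validate_with_data_service: bool = False,
    **run_kwargs,
):
  """Forwards to runner.run, optionally routing validation through the service.

  run_kwargs carries the remaining runner.run arguments (dataset providers,
  model_fn, trainer, task, global_batch_size, ...), which are not restated here.
  """
  runner.run(
      tf_data_service_config=tf_data_service_config,
      # New flag from this commit: when True, the validation dataset reuses the
      # same tf.data service config as training; otherwise it reads its data
      # directly. Only enable this if the chosen ShardingPolicy visits every
      # validation example exactly once.
      use_data_service_for_validation=validate_with_data_service,
      **run_kwargs,
  )

Gating on a single boolean, rather than accepting a second config object, mirrors the constraint noted in the new code comment: the trainer cannot use different tf.data service configs for training and validation, so validation either reuses the training config or bypasses the service entirely.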