From b56c1b0184cc833624b048b1ae1d033edbf399d6 Mon Sep 17 00:00:00 2001 From: Daniel Keysers Date: Wed, 21 Aug 2024 02:21:09 -0700 Subject: [PATCH] Fix pytype problem. PiperOrigin-RevId: 665776374 --- vmoe/projects/contrastive/evaluators.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vmoe/projects/contrastive/evaluators.py b/vmoe/projects/contrastive/evaluators.py index 11ee1d9..7e0623a 100644 --- a/vmoe/projects/contrastive/evaluators.py +++ b/vmoe/projects/contrastive/evaluators.py @@ -57,6 +57,12 @@ def __init__( if bv_discriminative is None: raise NotImplementedError( 'Big Vision must be installed to run the discriminative evaluation.') + if ( + not hasattr(data_sharding.mesh, 'devices') + or data_sharding.mesh.devices is None + ): + raise ValueError( + 'data_sharding.mesh must be a Mesh, not an AbstractMesh.') bv_evaluator = bv_discriminative.Evaluator( predict_fn=apply_fn, devices=list(data_sharding.mesh.devices.flatten()), @@ -120,6 +126,12 @@ def __init__( if bv_retrieval is None: raise NotImplementedError( 'Big Vision must be installed to run the retrieval evaluation.') + if ( + not hasattr(data_sharding.mesh, 'devices') + or data_sharding.mesh.devices is None + ): + raise ValueError( + 'data_sharding.mesh must be a Mesh, not an AbstractMesh.') bv_evaluator = bv_retrieval.Evaluator( predict_fn=apply_fn, devices=list(data_sharding.mesh.devices.flatten()),