diff --git a/vmoe/projects/contrastive/evaluators.py b/vmoe/projects/contrastive/evaluators.py index 11ee1d9..7e0623a 100644 --- a/vmoe/projects/contrastive/evaluators.py +++ b/vmoe/projects/contrastive/evaluators.py @@ -57,6 +57,12 @@ def __init__( if bv_discriminative is None: raise NotImplementedError( 'Big Vision must be installed to run the discriminative evaluation.') + if ( + not hasattr(data_sharding.mesh, 'devices') + or data_sharding.mesh.devices is None + ): + raise ValueError( + 'data_sharding.mesh must be a Mesh, not an AbstractMesh.') bv_evaluator = bv_discriminative.Evaluator( predict_fn=apply_fn, devices=list(data_sharding.mesh.devices.flatten()), @@ -120,6 +126,12 @@ def __init__( if bv_retrieval is None: raise NotImplementedError( 'Big Vision must be installed to run the retrieval evaluation.') + if ( + not hasattr(data_sharding.mesh, 'devices') + or data_sharding.mesh.devices is None + ): + raise ValueError( + 'data_sharding.mesh must be a Mesh, not an AbstractMesh.') bv_evaluator = bv_retrieval.Evaluator( predict_fn=apply_fn, devices=list(data_sharding.mesh.devices.flatten()),