diff --git a/vmoe/data/pp_ops.py b/vmoe/data/pp_ops.py index c8bf68f..809607d 100644 --- a/vmoe/data/pp_ops.py +++ b/vmoe/data/pp_ops.py @@ -380,9 +380,9 @@ def tokenize( pad_value = tokenizer.string_to_id(pad_value) def _pp_tokenize(txt): - if sample_if_multi: - txt = bv_ops_text.ops_general.get_choice( - empty_fallback='', key='t')({'t': txt})['t'] + if sample_if_multi and tf.convert_to_tensor(txt).ndim: + txt = bv_ops_text.ops_general.get_choice(key='t')( + bv_ops_text.ops_general.get_setdefault('t', '')({'t': txt}))['t'] if lower: txt = tf.strings.lower(txt) if sample_if_multi else tf.map_fn(