Skip to content

Commit 7772f96

Browse files
author
v-chen_data
committed
algo
1 parent 568596b commit 7772f96

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/test_full_nlp.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torchmetrics.classification import MulticlassAccuracy
1212
from transformers import BertConfig, BertForMaskedLM, BertForSequenceClassification, BertTokenizerFast
1313

14-
from composer.algorithms import GradientClipping
14+
from composer.algorithms import GatedLinearUnits
1515
from composer.loggers import RemoteUploaderDownloader
1616
from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
1717
from composer.models import HuggingFaceModel
@@ -233,7 +233,7 @@ def inference_test_helper(
233233
@pytest.mark.parametrize(
234234
'model_type,algorithms,save_format',
235235
[
236-
('tinybert_hf', [GradientClipping(clipping_type='norm', clipping_threshold=1.0)], 'onnx'),
236+
('tinybert_hf', [GatedLinearUnits], 'onnx'),
237237
('simpletransformer', [], 'torchscript'),
238238
],
239239
)
@@ -257,6 +257,7 @@ def test_full_nlp_pipeline(
257257
if onnx_opset_version == None and version.parse(torch.__version__) < version.parse('1.13'):
258258
pytest.skip("Don't test prior PyTorch version's default Opset version.")
259259

260+
algorithms = [algorithm() for algorithm in algorithms]
260261
device = get_device(device)
261262
config = None
262263
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', model_max_length=128)

0 commit comments

Comments
 (0)