diff --git a/tests/test_full_nlp.py b/tests/test_full_nlp.py index 8a4145a99f..5785ebb096 100644 --- a/tests/test_full_nlp.py +++ b/tests/test_full_nlp.py @@ -11,7 +11,7 @@ from torchmetrics.classification import MulticlassAccuracy from transformers import BertConfig, BertForMaskedLM, BertForSequenceClassification, BertTokenizerFast -from composer.algorithms import LayerFreezing +from composer.algorithms import GradientClipping from composer.loggers import RemoteUploaderDownloader from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy from composer.models import HuggingFaceModel @@ -233,7 +233,7 @@ def inference_test_helper( @pytest.mark.parametrize( 'model_type,algorithms,save_format', [ - ('tinybert_hf', [LayerFreezing], 'onnx'), + ('tinybert_hf', [GradientClipping], 'onnx'), ('simpletransformer', [], 'torchscript'), ], )