 from torchmetrics.classification import MulticlassAccuracy
 from transformers import BertConfig, BertForMaskedLM, BertForSequenceClassification, BertTokenizerFast
 
-from composer.algorithms import GradientClipping
+from composer.algorithms import GatedLinearUnits
 from composer.loggers import RemoteUploaderDownloader
 from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
 from composer.models import HuggingFaceModel
@@ -233,7 +233,7 @@ def inference_test_helper(
 @pytest.mark.parametrize(
     'model_type,algorithms,save_format',
     [
-        ('tinybert_hf', [GradientClipping(clipping_type='norm', clipping_threshold=1.0)], 'onnx'),
+        ('tinybert_hf', [GatedLinearUnits], 'onnx'),
         ('simpletransformer', [], 'torchscript'),
     ],
 )
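For context (not part of this diff): GatedLinearUnits is the Composer surgery algorithm that replaces the feed-forward blocks of a BERT-style model with gated linear units. A minimal sketch of how such an algorithm is normally enabled, assuming a Composer install; `model` and `train_dataloader` are placeholder arguments:

```python
from composer import Trainer
from composer.algorithms import GatedLinearUnits

def build_trainer(model, train_dataloader):
    # Composer algorithms are enabled by passing instances in the Trainer's
    # `algorithms` list; GatedLinearUnits then rewrites the model's
    # feed-forward blocks before/at fit time.
    return Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration='1ep',
        algorithms=[GatedLinearUnits()],
    )
```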
@@ -257,6 +257,7 @@ def test_full_nlp_pipeline(
     if onnx_opset_version == None and version.parse(torch.__version__) < version.parse('1.13'):
         pytest.skip("Don't test prior PyTorch version's default Opset version.")
 
+    algorithms = [algorithm() for algorithm in algorithms]
     device = get_device(device)
     config = None
     tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', model_max_length=128)
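Note the pattern: the parametrize list above now carries the algorithm class rather than an instance, and this added line builds a fresh instance inside the test body. Presumably this keeps a single object created at collection time from being shared (and mutated) across parametrized runs. A standalone sketch of the same pattern, with a hypothetical stand-in class:

```python
import pytest

class _StatefulAlgorithm:
    # Hypothetical stand-in for a stateful Composer algorithm.
    def __init__(self):
        self.applied = False

@pytest.mark.parametrize('algorithm_classes', [[_StatefulAlgorithm]])
def test_fresh_instance_per_run(algorithm_classes):
    # Instantiating inside the test, as in the diff above, guarantees each
    # parametrized run gets its own unshared algorithm object.
    algorithms = [cls() for cls in algorithm_classes]
    assert all(not alg.applied for alg in algorithms)
```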