Skip to content

Commit b436d09

Browse files
authored
Fix the CI pipeline for the latest PyTorch release. (microsoft#759)
1 parent f1abea1 commit b436d09

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

test/test_processing.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -92,17 +92,17 @@ def test_gpt2_preprocessing(self):
9292
merges_file=util.get_test_data_file("data", "gpt2.merges.txt"),
9393
)
9494
inputs = tok.forward(test_sentence)
95-
pnp.export(tok, test_sentence, opset_version=12, output_path="temp_tok2.onnx")
95+
pnp.export(tok, test_sentence, opset_version=14, output_path="temp_tok2.onnx")
9696

9797
with open("temp_gpt2lmh.onnx", "wb") as f:
9898
torch.onnx.export(
99-
gpt2_m, inputs, f, opset_version=12, do_constant_folding=False
99+
gpt2_m, inputs, f, opset_version=14, do_constant_folding=False
100100
)
101-
pnp.export(gpt2_m, *inputs, opset_version=12, do_constant_folding=False)
101+
pnp.export(gpt2_m, *inputs, opset_version=14, do_constant_folding=False)
102102
full_model = pnp.SequentialProcessingModule(tok, gpt2_m)
103103
expected = full_model.forward(test_sentence)
104104
model = pnp.export(
105-
full_model, test_sentence, opset_version=12, do_constant_folding=False
105+
full_model, test_sentence, opset_version=14, do_constant_folding=False
106106
)
107107
mfunc = OrtPyFunction.from_model(model)
108108
actuals = mfunc(test_sentence)

0 commit comments

Comments
 (0)