Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Add freezing support for make_prediction and make_batch
Browse files Browse the repository at this point in the history
Summary: Add freezing support for make_prediction and make_batch

Differential Revision: D23486865

fbshipit-source-id: f840f7cad0465baa3a10430068e98858f9b90fbb
  • Loading branch information
Michael Gschwind authored and facebook-github-bot committed Sep 3, 2020
1 parent 21374b3 commit 23c5ee5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion pytext/task/new_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,9 @@ def torchscript_export(
model.half()
trace = model.trace(inputs)
if "nnpi" in accelerate:
trace._c = torch._C._freeze_module(trace._c)
trace._c = torch._C._freeze_module(
trace._c, preservedAttrs=["make_prediction", "make_batch"]
)
if hasattr(model, "torchscriptify"):
trace = model.torchscriptify(self.data.tensorizers, trace)
if padding_control is not None:
Expand Down
4 changes: 3 additions & 1 deletion pytext/task/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,9 @@ def torchscript_export(self, model, export_path=None, **options):
else:
trace = jit.trace(model.encoder1, (inputs[0],))
if "nnpi" in accelerate:
trace._c = torch._C._freeze_module(trace._c)
trace._c = torch._C._freeze_module(
trace._c, preservedAttrs=["make_prediction", "make_batch"]
)
if hasattr(model, "torchscriptify"):
trace = model.torchscriptify(
self.data.tensorizers, trace, self.trace_both_encoders
Expand Down

0 comments on commit 23c5ee5

Please sign in to comment.