
'scaling.py' throws torch error: 'function definitions aren't supported' when executed by 'train.py' and other scripts #1834

Open
ChrystianKacki opened this issue Dec 11, 2024 · 7 comments

Comments

@ChrystianKacki

I am using the torch2.4.1-cuda12.4 icefall Docker image to train a Zipformer model on datasets such as LibriSpeech or Common Voice.
When I execute scripts such as train.py or export.py, which use the scaling.py module, I get the following error:

...
torch.jit.frontend.UnsupportedNodeError: function definitions aren't supported:
File "/workspace/icefall/egs/commonvoice/ASR/zipformer/scaling.py", line 882
# these limitations, as limits on the absolute value and the proportion of positive
# values, to limits on the RMS value and the (mean / stddev).
def _abs_to_rms(x):
~~~ <--- HERE
# for normally distributed data, if the expected absolute value is x, the
# expected rms value will be sqrt(pi/2) * x.

It seems that TorchScript does not support function definitions nested inside instance methods (it says: function definitions aren't supported).
There is another such function in the forward() method of the Balancer class: _proportion_positive_to_mean(), which itself contains the nested functions _atanh() and _approx_inverse_erf().
How could this be fixed?
As a temporary workaround, I extracted those nested functions to methods of the Balancer class, and that worked fine.
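
For illustration, here is a minimal, self-contained sketch of that workaround (the TinyBalancer class below is a hypothetical stand-in, not the real Balancer from scaling.py): the TorchScript frontend rejects a def statement nested inside a method it is compiling, but the same helper compiles fine once it is hoisted to module level.

import math
import torch
import torch.nn as nn

# Hoisted to module level: torch.jit.script can compile calls to this helper,
# whereas defining it inside forward() triggers
# "torch.jit.frontend.UnsupportedNodeError: function definitions aren't supported".
def _abs_to_rms(x: torch.Tensor) -> torch.Tensor:
    # For normally distributed data with expected absolute value x,
    # the expected RMS value is sqrt(pi/2) * x.
    return math.sqrt(math.pi / 2) * x

class TinyBalancer(nn.Module):
    # Hypothetical stand-in that only shows the structure of the fix.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _abs_to_rms(x)

scripted = torch.jit.script(TinyBalancer())  # parses and compiles without error
print(scripted(torch.randn(3)))

Turning the nested helpers into methods of Balancer, as you did, achieves the same thing: the frontend only refuses def statements that appear inside the body of a method it is currently compiling.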

@csukuangfj
Collaborator

Would you mind posting the complete command and complete logs?

Make sure you are using the latest master inside the Docker container.

@ChrystianKacki
Author

/workspace/icefall/egs/librispeech/ASR/zipformer/export.py/export.py \
  --epoch 150 \
  --avg 1 \
  --causal True \
  --chunk-size 16 \
  --left-context-frames 128 \
  --use-averaged-model False \
  --jit 1 \
  --tokens $data_dir/lang_bpe_500/tokens.txt \
  --exp-dir $base_dir

2024-12-10 10:23:42,190 INFO [export.py:439] device: cpu
2024-12-10 10:23:42,201 INFO [export.py:446] {
"avg": 1,
"batch_idx_train": 0,
"best_train_epoch": -1,
"best_train_loss": Infinity,
"best_valid_epoch": -1,
"best_valid_loss": Infinity,
"blank_id": 0,
"causal": true,
"chunk_size": "16",
"cnn_module_kernel": "31,31,15,15,15,31",
"context_size": 2,
"decoder_dim": 512,
"downsampling_factor": "1,2,4,8,4,2",
"encoder_dim": "192,256,384,512,384,256",
"encoder_unmasked_dim": "192,192,256,256,256,192",
"env_info": {
"IP address": "172.17.0.2",
"hostname": "ef67e375e469",
"icefall-git-branch": null,
"icefall-git-date": null,
"icefall-git-sha1": null,
"icefall-path": "/workspace/icefall",
"k2-build-type": "Release",
"k2-git-date": "Thu Sep 5 19:25:17 2024",
"k2-git-sha1": "cf664841c6d93e21e59b40aade84869b76c919c1",
"k2-path": "/opt/conda/lib/python3.11/site-packages/k2/init.py",
"k2-version": "1.24.4",
"k2-with-cuda": true,
"lhotse-path": "/opt/conda/lib/python3.11/site-packages/lhotse/init.py",
"lhotse-version": "1.28.0.dev+git.9648516.clean",
"python-version": "3.11",
"torch-cuda-available": true,
"torch-cuda-version": "12.4",
"torch-version": "2.4.1+cu124"
},
"eos_id": 1,
"epoch": 150,
"exp_dir": "/workspace/model-pl",
"feature_dim": 80,
"feedforward_dim": "512,768,1024,1536,1024,768",
"iter": 0,
"jit": true,
"joiner_dim": 512,
"left_context_frames": "128",
"log_interval": 50,
"num_encoder_layers": "2,2,3,4,3,2",
"num_heads": "4,4,4,8,4,4",
"pos_dim": 48,
"pos_head_dim": "4",
"query_head_dim": "32",
"reset_interval": 200,
"sos_id": 1,
"subsampling_factor": 4,
"tokens": "/workspace/model-pl/data/lang_bpe_500/tokens.txt",
"use_averaged_model": false,
"use_ctc": false,
"use_transducer": true,
"valid_interval": 3000,
"value_head_dim": "12",
"vocab_size": 500,
"warm_step": 2000
}
2024-12-10 10:23:42,201 INFO [export.py:448] About to create model
2024-12-10 10:23:42,479 INFO [checkpoint.py:112] Loading checkpoint from /workspace/model-pl/epoch-150.pt
2024-12-10 10:23:49,919 INFO [export.py:544] Using torch.jit.script
Traceback (most recent call last):
File "/workspace/zipformer/export.py", line 561, in
main()
File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/workspace/zipformer/export.py", line 545, in main
model = torch.jit.script(model)
^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_script.py", line 1432, in script
return _script_impl(
^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_script.py", line 1146, in _script_impl
return torch.jit._recursive.create_script_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_recursive.py", line 559, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_script.py", line 649, in _construct
init_fn(script_module)
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
scripted = create_script_module_impl(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_script.py", line 649, in _construct
init_fn(script_module)
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
scripted = create_script_module_impl(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_script.py", line 649, in _construct
init_fn(script_module)
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
scripted = create_script_module_impl(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_recursive.py", line 572, in create_script_module_impl
method_stubs = stubs_fn(nn_module)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_recursive.py", line 895, in infer_methods_to_compile
stubs.append(make_stub_from_method(nn_module, method))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_recursive.py", line 88, in make_stub_from_method
return make_stub(func, method_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/_recursive.py", line 72, in make_stub
ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/frontend.py", line 373, in get_jit_def
return build_def(
^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/frontend.py", line 434, in build_def
return Def(Ident(r, def_name), decl, build_stmts(ctx, body))
^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/frontend.py", line 196, in build_stmts
stmts = [build_stmt(ctx, s) for s in stmts]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/frontend.py", line 196, in
stmts = [build_stmt(ctx, s) for s in stmts]
^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/frontend.py", line 407, in call
return method(ctx, node)
^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/frontend.py", line 783, in build_If
build_stmts(ctx, stmt.body),
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/frontend.py", line 196, in build_stmts
stmts = [build_stmt(ctx, s) for s in stmts]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/frontend.py", line 196, in
stmts = [build_stmt(ctx, s) for s in stmts]
^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/jit/frontend.py", line 406, in call
raise UnsupportedNodeError(ctx, node)
torch.jit.frontend.UnsupportedNodeError: function definitions aren't supported:
File "/workspace/icefall/egs/commonvoice/ASR/zipformer/scaling.py", line 882
# these limitations, as limits on the absolute value and the proportion of positive
# values, to limits on the RMS value and the (mean / stddev).
def _abs_to_rms(x):
~~~ <--- HERE
# for normally distributed data, if the expected absolute value is x, the
# expected rms value will be sqrt(pi/2) * x.

There is also another error, which occurs before the one mentioned above, in the forward() method of the Balancer class in scaling.py.
It occurs at line 873, in the condition of the if statement: or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())).
I think self.mem_cutoff is used here as a method, not as an attribute.

@csukuangfj
Collaborator

What is

/workspace/icefall/egs/librispeech/ASR/zipformer/export.py/export.py

Could you post the real command you are using?

Also, are you using the latest master of icefall? If not, could you try it?

@csukuangfj
Collaborator

There is also another error, which occurs before the one mentioned above, in the forward() method of the Balancer class in scaling.py.
It occurs at line 873, in the condition of the if statement: or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())).

Balancer should never participate in torch.jit export.

Please tell us whether you have made any changes to icefall.
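
To illustrate what "should never participate" means in practice, here is a rough, generic sketch of the idea (the Balancer stand-in and the strip_training_only_modules helper below are hypothetical, not icefall code): a training-only module can be swapped for nn.Identity before torch.jit.script is called, so the scripting frontend never parses its forward() at all.

import torch
import torch.nn as nn

class Balancer(nn.Module):
    # Hypothetical stand-in for a training-only module whose forward()
    # the TorchScript frontend cannot parse.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def _abs_to_rms(t):  # nested def -> UnsupportedNodeError if scripted
            return 1.2533 * t  # roughly sqrt(pi/2) * t
        return _abs_to_rms(x)

def strip_training_only_modules(model: nn.Module) -> nn.Module:
    # Replace every Balancer with nn.Identity so torch.jit.script never
    # sees its forward(); other children are handled recursively.
    for name, child in model.named_children():
        if isinstance(child, Balancer):
            setattr(model, name, nn.Identity())
        else:
            strip_training_only_modules(child)
    return model

net = nn.Sequential(nn.Linear(4, 4), Balancer())
scripted = torch.jit.script(strip_training_only_modules(net))  # now succeeds
print(scripted(torch.randn(2, 4)).shape)

This only sketches the general idea; how (and where) the real export path removes such modules before scripting may differ.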

@ChrystianKacki
Author

What is

/workspace/icefall/egs/librispeech/ASR/zipformer/export.py/export.py
Could you post the real command you are using?

Also, are you using the latest master of icefall? If not, could you try it?

I am sorry, it should be /workspace/icefall/egs/librispeech/ASR/zipformer/export.py. I execute this command on the latest running icefall Docker image to get a model for the Python-only version of sherpa.

There is also another error, which occurs before the one mentioned above, in the forward() method of the Balancer class in scaling.py.
It occurs at line 873, in the condition of the if statement: or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())).

Balancer should never participate in torch.jit export.

This error at line 873 occurs in Balancer when I execute train.py of the Common Voice recipe. I think self.mem_cutoff is used here as a method, not as an attribute.

@csukuangfj
Collaborator

So there are two issues, right? Could you post the logs and commands for each issue separately?


I execute this command on the latest running icefall Docker image

Please answer whether you are using the latest master of icefall. You can use git pull to update the code inside your Docker container. Note that the latest icefall Docker image does not imply that you are using the latest master of icefall.


You said you are using /workspace/icefall/egs/librispeech/ASR/zipformer/export.py, yet your log contains commonvoice:

/workspace/icefall/egs/commonvoice/ASR/zipformer/scaling.py

Could you tell us whether you have changed any code?

It would be great if you could provide the information in such a way that we can follow your commands and reproduce your issue.

@ChrystianKacki
Author

So there are two issues, right? Could you post the logs and commands for each issue separately?

Yes, exactly, there are two issues with the scaling.py script. Please give me some time to prepare the commands and logs, because it's late at night here in Poland. Thank you for your friendly support, greetings!
