Getting "RuntimeError: expected scalar type Double but found Float" when running model.train() #5

Open
shobhitagrawal1 opened this issue Dec 19, 2022 · 0 comments

@shobhitagrawal1

Thank you for this module, which is highly relevant to my experiment: I have replicates of the same subject at different sites. Unfortunately, I get "RuntimeError: expected scalar type Double but found Float" when I try to train the model after setting up the data. The same data works with scVI without any errors, so I am assuming the data itself is fine. From searching online, this appears to be a dtype mismatch in PyTorch. One suggestion I found: "You need to cast your tensors to float32, either with dtype='float32' or calling float() on your input tensors."
I would be very grateful if you could look into this.
Cheers,
Shobhit
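The suggestion quoted above can be illustrated with a minimal, generic PyTorch sketch (not taken from mrvi): `nn.Linear` creates float32 ("Float") parameters by default, so passing it a float64 ("Double") tensor fails in the same `F.linear` call that appears at the bottom of the traceback below, and calling `.float()` on the input resolves it. The layer sizes and input values are arbitrary, and the exact error wording depends on the PyTorch version.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)  # parameters default to float32 ("Float")
x = torch.ones(3, 4, dtype=torch.float64)  # float64 ("Double") input

mismatch_error = None
try:
    layer(x)  # dtypes of input and weight disagree
except RuntimeError as err:
    # Older PyTorch: "expected scalar type Double but found Float";
    # newer releases word the dtype mismatch differently.
    mismatch_error = err

# Casting the input to float32 makes the dtypes agree
out = layer(x.float())
```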

Traceback (most recent call last):
File "", line 1, in
File "/home/agrawals/.local/lib/python3.9/site-packages/mrvi/_model.py", line 157, in train
super().train(**train_kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/model/base/_training_mixin.py", line 77, in train
return runner()
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainrunner.py", line 82, in __call__
self.trainer.fit(self.training_plan, self.data_splitter)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainer.py", line 188, in fit
super().fit(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
self._call_and_handle_interrupt(
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
results = self._run(model, ckpt_path=self.ckpt_path)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
results = self._run_stage()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
return self._run_train()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1283, in _run_train
self.fit_loop.run()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 271, in advance
self._outputs = self.epoch_loop.run(self._data_fetcher)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 203, in advance
batch_output = self.batch_loop.run(kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 87, in advance
outputs = self.optimizer_loop.run(optimizers, kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 201, in advance
result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 248, in _run_optimization
self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 358, in _optimizer_step
self.trainer._call_lightning_module_hook(
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1550, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/core/module.py", line 1705, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 168, in step
step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 216, in optimizer_step
return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 153, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/optim/optimizer.py", line 113, in wrapper
return func(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/optim/adam.py", line 118, in step
loss = closure()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 138, in _wrap_closure
closure_result = closure()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 146, in __call__
self._result = self.closure(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 132, in closure
step_output = self._step_fn()
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 407, in _training_step
training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1704, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 358, in training_step
return self.model.training_step(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainingplans.py", line 351, in training_step
_, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/train/_trainingplans.py", line 282, in forward
return self.module(*args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_decorators.py", line 33, in auto_transfer_args
return fn(self, *args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_base_module.py", line 276, in forward
return _generic_forward(
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_base_module.py", line 837, in _generic_forward
inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/scvi/module/base/_decorators.py", line 33, in auto_transfer_args
return fn(self, *args, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/mrvi/module.py", line 122, in inference
x_feat = self.x_featurizer(x)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
input = module(input)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/agrawals/.local/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Double but found Float
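The failing call is `F.linear` inside `x_featurizer` (mrvi/module.py, `inference`), so one plausible cause, assumed here and not confirmed, is that the expression matrix reaches mrvi as float64 while the layer weights are float32. A sketch of a workaround casts the matrix to float32 before setting up the data. `to_float32` is a hypothetical helper, and the example matrices are made up.

```python
import numpy as np
import scipy.sparse as sp


def to_float32(X):
    """Hypothetical helper: return X as float32, keeping sparse vs dense storage."""
    if sp.issparse(X):
        return X.astype(np.float32)
    return np.asarray(X, dtype=np.float32)


# Made-up counts stored as float64 (NumPy's default float dtype)
dense = np.array([[0.0, 3.0], [1.0, 0.0]])
sparse = sp.csr_matrix(dense)

dense32 = to_float32(dense)
sparse32 = to_float32(sparse)
```

With AnnData this would presumably be `adata.X = to_float32(adata.X)` (or the same for whichever layer is passed to setup) before calling the model's `setup_anndata`; this assumes mrvi passes the stored dtype straight through to the first linear layer.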
