diff --git a/src/cryo_sbi/wpa_simulator/noise.py b/src/cryo_sbi/wpa_simulator/noise.py index c1ed5aa..5d269ca 100644 --- a/src/cryo_sbi/wpa_simulator/noise.py +++ b/src/cryo_sbi/wpa_simulator/noise.py @@ -36,7 +36,7 @@ def get_snr(images, snr): ) signal_power = images[:, mask].pow(2).mean().sqrt() # torch.std(image[mask]) - noise_power = signal_power / torch.sqrt(snr.to(images.device)) + noise_power = signal_power / torch.sqrt(torch.pow(10, snr)) return noise_power