Skip to content

Commit 50f91e7

Browse files
authored
Merge pull request #78 from GTorlai/master
Added negative log-likelihood in training statistics
2 parents 4225211 + d5ec426 commit 50f91e7

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

qucumber/utils/training_statistics.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,58 @@ def rotate_psi(nn_state, basis, space, unitaries, psi=None):
108108
return psi_r
109109

110110

111+
def NLL(nn_state, samples, space, train_bases=None, **kwargs):
112+
r"""A function for calculating the negative log-likelihood.
113+
114+
:param nn_state: The neural network state (i.e. complex wavefunction or
115+
positive wavefunction).
116+
:type nn_state: WaveFunction
117+
:param samples: Samples to compute the NLL on.
118+
:type samples: torch.Tensor
119+
:param space: The hilbert space of the system.
120+
:type space: torch.Tensor
121+
:param train_bases: An array of bases where measurements were taken.
122+
:type train_bases: np.array(dtype=str)
123+
:param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.
124+
125+
:returns: The Negative Log-Likelihood.
126+
:rtype: torch.Tensor
127+
"""
128+
psi_r = torch.zeros(
129+
2, 1 << nn_state.num_visible, dtype=torch.double, device=nn_state.device
130+
)
131+
NLL = 0.0
132+
unitary_dict = unitaries.create_dict()
133+
Z = nn_state.compute_normalization(space)
134+
eps = 0.000001
135+
if train_bases is None:
136+
for i in range(len(samples)):
137+
NLL -= (cplx.norm_sqr(nn_state.psi(samples[i])) + eps).log()
138+
NLL += Z.log()
139+
else:
140+
for i in range(len(samples)):
141+
# Check whether the sample was measured the reference basis
142+
is_reference_basis = True
143+
# b_ID = 0
144+
for j in range(nn_state.num_visible):
145+
if train_bases[i][j] != "Z":
146+
is_reference_basis = False
147+
break
148+
if is_reference_basis is True:
149+
NLL -= (cplx.norm_sqr(nn_state.psi(samples[i])) + eps).log()
150+
NLL += Z.log()
151+
else:
152+
psi_r = rotate_psi(nn_state, train_bases[i], space, unitary_dict)
153+
# Get the index value of the sample state
154+
ind = 0
155+
for j in range(nn_state.num_visible):
156+
if samples[i, nn_state.num_visible - j - 1] == 1:
157+
ind += pow(2, j)
158+
NLL -= cplx.norm_sqr(psi_r[:, ind]).log().item()
159+
NLL += Z.log()
160+
return NLL / float(len(samples))
161+
162+
111163
def KL(nn_state, target_psi, space, bases=None, **kwargs):
112164
r"""A function for calculating the total KL divergence.
113165

0 commit comments

Comments
 (0)