@@ -108,6 +108,58 @@ def rotate_psi(nn_state, basis, space, unitaries, psi=None):
108108 return psi_r
109109
110110
111+ def NLL (nn_state , samples , space , train_bases = None , ** kwargs ):
112+ r"""A function for calculating the negative log-likelihood.
113+
114+ :param nn_state: The neural network state (i.e. complex wavefunction or
115+ positive wavefunction).
116+ :type nn_state: WaveFunction
117+ :param samples: Samples to compute the NLL on.
118+ :type samples: torch.Tensor
119+ :param space: The hilbert space of the system.
120+ :type space: torch.Tensor
121+ :param train_bases: An array of bases where measurements were taken.
122+ :type train_bases: np.array(dtype=str)
123+ :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.
124+
125+ :returns: The Negative Log-Likelihood.
126+ :rtype: torch.Tensor
127+ """
128+ psi_r = torch .zeros (
129+ 2 , 1 << nn_state .num_visible , dtype = torch .double , device = nn_state .device
130+ )
131+ NLL = 0.0
132+ unitary_dict = unitaries .create_dict ()
133+ Z = nn_state .compute_normalization (space )
134+ eps = 0.000001
135+ if train_bases is None :
136+ for i in range (len (samples )):
137+ NLL -= (cplx .norm_sqr (nn_state .psi (samples [i ])) + eps ).log ()
138+ NLL += Z .log ()
139+ else :
140+ for i in range (len (samples )):
141+ # Check whether the sample was measured the reference basis
142+ is_reference_basis = True
143+ # b_ID = 0
144+ for j in range (nn_state .num_visible ):
145+ if train_bases [i ][j ] != "Z" :
146+ is_reference_basis = False
147+ break
148+ if is_reference_basis is True :
149+ NLL -= (cplx .norm_sqr (nn_state .psi (samples [i ])) + eps ).log ()
150+ NLL += Z .log ()
151+ else :
152+ psi_r = rotate_psi (nn_state , train_bases [i ], space , unitary_dict )
153+ # Get the index value of the sample state
154+ ind = 0
155+ for j in range (nn_state .num_visible ):
156+ if samples [i , nn_state .num_visible - j - 1 ] == 1 :
157+ ind += pow (2 , j )
158+ NLL -= cplx .norm_sqr (psi_r [:, ind ]).log ().item ()
159+ NLL += Z .log ()
160+ return NLL / float (len (samples ))
161+
162+
111163def KL (nn_state , target_psi , space , bases = None , ** kwargs ):
112164 r"""A function for calculating the total KL divergence.
113165
0 commit comments