forked from LilitYolyan/CutPaste
Commit
1 parent a96e8c0 · commit 0c8cec2
Showing 3 changed files with 91 additions and 6 deletions.
@@ -0,0 +1,80 @@
import numpy as np
import torch
from torch.distributions import MultivariateNormal, Normal
from torch.distributions.distribution import Distribution


# Code from https://discuss.pytorch.org/t/kernel-density-estimation-as-loss-function/62261/8
class GaussianKDE(Distribution):
    def __init__(self, X, bw):
        """
        X : tensor (n, d)
            `n` points with `d` dimensions to which KDE will be fit
        bw : numeric
            bandwidth for Gaussian kernel
        """
        self.X = X
        self.bw = bw
        self.dims = X.shape[-1]
        self.n = X.shape[0]
        # Standard multivariate normal used as the smoothing kernel; it is
        # shifted to each data point and scaled by `bw` when evaluated.
        self.mvn = MultivariateNormal(loc=torch.zeros(self.dims),
                                      covariance_matrix=torch.eye(self.dims))

    def sample(self, num_samples):
        # Pick `num_samples` fit points uniformly at random, then draw one
        # Gaussian sample (std `bw`) centred on each chosen point.
        idxs = (np.random.uniform(0, 1, num_samples) * self.n).astype(int)
        norm = Normal(loc=self.X[idxs], scale=self.bw)
        return norm.sample()

    def score_samples(self, Y, X=None):
        """Returns the kernel density estimates of each point in `Y`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the probability density
            will be calculated
        X : tensor (n, d), optional
            `n` points with `d` dimensions to which KDE will be fit. Provided
            to allow batch calculations in `log_prob`. By default, `X` is None
            and all points used to initialize KernelDensityEstimator are
            included.

        Returns
        -------
        log_probs : tensor (m)
            log probability densities for each of the queried points in `Y`
        """
        if X is None:
            X = self.X
        # (n, 1, d) - (m, d) broadcasts to (n, m, d); the kernel log-density is
        # evaluated per (fit point, query) pair, averaged over the n fit
        # points, and corrected by the bandwidth factor bw**(-d).
        log_probs = torch.log(
            (self.bw**(-self.dims) *
             torch.exp(self.mvn.log_prob(
                 (X.unsqueeze(1) - Y) / self.bw))).sum(dim=0) / self.n)

        return log_probs

    def log_prob(self, Y):
        """Returns the total log probability of one or more points, `Y`, using
        a Multivariate Normal kernel fit to `X` and scaled using `bw`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the probability density
            will be calculated

        Returns
        -------
        log_prob : numeric
            total log probability density for the queried points, `Y`
        """
        # Evaluate in chunks of 1000 points so the (n, m, d) intermediate in
        # `score_samples` stays within memory; per-chunk log-densities are
        # summed into a single scalar.
        X_chunks = self.X.split(1000)
        Y_chunks = Y.split(1000)

        log_prob = 0

        for x in X_chunks:
            for y in Y_chunks:
                log_prob += self.score_samples(y, x).sum(dim=0)

        return log_prob
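For context only (not part of this commit): a minimal usage sketch of GaussianKDE under assumed shapes; `embeds`, `queries`, the dimensions, and `bw=1.0` are hypothetical placeholders, not values from the repository.

# Hypothetical usage sketch, assuming 128-dimensional embeddings and bandwidth 1.0.
import torch

embeds = torch.randn(500, 128)        # 500 points to which the KDE is fit
kde = GaussianKDE(embeds, bw=1.0)

queries = torch.randn(10, 128)        # 10 query points
scores = kde.score_samples(queries)   # per-point log-densities, shape (10,)
total = kde.log_prob(queries)         # chunked total log-density (scalar tensor)
draws = kde.sample(5)                 # 5 draws from the fitted KDE, shape (5, 128)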