-
Notifications
You must be signed in to change notification settings - Fork 182
/
Copy pathutils.py
40 lines (34 loc) · 1.23 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-strict
"""
This file contains helpers for unittest creation
"""
import torch
# for testing vanilla mlps
def create_normal_pdf_training_data(
input_dim: int, num_data_points: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
x are sampled from a multivariate normal distribution
y are the corresponding pdf values
"""
mean = torch.zeros(input_dim)
sigma = torch.eye(input_dim)
multi_variate_normal = torch.distributions.MultivariateNormal(mean, sigma)
x = multi_variate_normal.sample(
torch.Size((num_data_points,))
) # sample from a mvn distribution
y = torch.exp(
-0.5 * ((x - mean) @ torch.inverse(sigma) * (x - mean)).sum(dim=1)
) / (
# pyre-ignore[58]: `**` is not supported for operand types `Tensor` and `int`.
# PyTorch idiosyncrasy.
torch.sqrt((2 * torch.tensor(3.14)) ** mean.shape[0] * torch.det(sigma))
) # corresponding pdf of mvn
y_corrupted = y + 0.01 * torch.randn(num_data_points) # noise corrupted targets
return x, y_corrupted