-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
99 lines (88 loc) · 3.27 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#*******************************************************************************
# Imports and Setup
#*******************************************************************************
import cvxpy as cvx
import matplotlib.pyplot as plt
import torch
# Named hex color strings; presumably shared by plotting code that imports
# this module — verify against callers before renaming.
pastelBlue = "#0072B2"
pastelRed = "#F5615C"
pastelGreen = "#009E73"
pastelPurple = "#8770FE"
# Global matplotlib settings applied as an import-time side effect: render
# text with LaTeX and use a serif font family.
# NOTE(review): usetex=True makes matplotlib call an external LaTeX install
# when drawing text — confirm one is available wherever this is imported.
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
#*******************************************************************************
# Function Definitions
#*******************************************************************************
def plot_ellipsoid(sigma, mu, ax, color='b', n_points=100):
    '''
    Plot the surface of the 3D ellipsoid
    {x : (x - mu)^T sigma^{-1} (x - mu) <= 1}.

    Args:
        sigma: [3 x 3] symmetric covariance (shape) matrix tensor
        mu: [3] mean (center) tensor
        ax: matplotlib 3D axis object (must provide ``plot_surface``)
        color: ellipsoid color
        n_points: number of samples along each spherical angle; the surface
            mesh has n_points x n_points vertices
    Return:
        ax: matplotlib axis object with 3D ellipsoid plot
    '''
    # Work on detached copies so inputs that require grad (or live on an
    # accelerator) can still be converted for matplotlib at the end.
    sigma = torch.as_tensor(sigma).detach()
    if not sigma.is_floating_point():
        sigma = sigma.to(torch.get_default_dtype())
    mu = torch.as_tensor(mu).detach().to(dtype=sigma.dtype,
                                         device=sigma.device)

    eigvals, eigvecs = torch.linalg.eigh(sigma)
    order = eigvals.argsort(descending=True)
    eigvals, eigvecs = eigvals[order], eigvecs[:, order]
    # Round-off can make a PSD matrix report tiny negative eigenvalues;
    # clamp so the square root never yields NaN semi-axis lengths.
    radii = torch.sqrt(eigvals.clamp(min=0))

    # Sample the unit sphere on an (azimuth u, polar v) grid. The grid uses
    # sigma's dtype/device so the products below are type-consistent
    # (a float32 grid against a float64 sigma raises in matmul).
    grid = dict(dtype=sigma.dtype, device=sigma.device)
    u = torch.linspace(0, 2 * torch.pi, n_points, **grid)
    v = torch.linspace(0, torch.pi, n_points, **grid)
    x = torch.outer(torch.cos(u), torch.sin(v))
    y = torch.outer(torch.sin(u), torch.sin(v))
    z = torch.outer(torch.ones_like(u), torch.cos(v))
    sphere = torch.stack([x.flatten(), y.flatten(), z.flatten()], dim=1)

    # Stretch along the principal axes (same as @ diag(radii)), rotate onto
    # the eigenvector basis, then shift to the center.
    points = (sphere * radii) @ eigvecs.T + mu
    points = points.reshape(n_points, n_points, 3).cpu().numpy()

    ax.plot_surface(points[:, :, 0], points[:, :, 1], points[:, :, 2],
                    color=color, alpha=0.4)
    return ax
def min_enclosing_ellipsoid(X):
    '''
    Find the minimum-volume ellipsoid containing a set of points.

    Solves the convex program
        minimize -log det(A)  s.t.  ||A x_i + b||_2 <= 1  for every row x_i,
    with A symmetric PSD, so the ellipsoid is {x : ||A x + b|| <= 1}.

    Args:
        X: [Ns x D] tensor (or array-like) of data
    Returns:
        sigma: [D x D] float32 covariance matrix tensor, (A^T A)^{-1}
        mu: [D] float32 mean (center) tensor, -A^{-1} b
        Both are None if the solver raises or finishes without a solution
        (e.g. infeasible/unbounded, such as points that do not span D
        dimensions or an empty X).
    '''
    # cvxpy operates on numpy data: convert once, rather than mixing torch
    # tensors into every constraint expression.
    points = torch.as_tensor(X).detach().cpu().numpy()
    n, d = points.shape
    A = cvx.Variable((d, d), PSD=True)
    b = cvx.Variable((d))
    # One small constraint per point. A is symmetric (PSD=True), so A @ x
    # equals the row form x @ A, without rebuilding X @ A for every row.
    constraints = [cvx.norm(A @ x + b) <= 1 for x in points]
    objective = cvx.Minimize(-cvx.log_det(A))
    problem = cvx.Problem(objective, constraints)
    try:
        problem.solve(solver='SCS', verbose=False)
    except cvx.SolverError:
        return None, None
    # Infeasible/unbounded solves return normally but leave values unset.
    if A.value is None or b.value is None:
        return None, None
    A_mat = torch.tensor(A.value, dtype=torch.float32)
    b_vec = torch.tensor(b.value, dtype=torch.float32)
    sigma = torch.linalg.inv(A_mat.T @ A_mat)
    mu = -(sigma @ A_mat.T) @ b_vec
    return sigma, mu
def outside_ellipsoid(X, mu, sigma):
    '''
    Compute the signed Mahalanobis distance of points to an ellipsoid.

    The ellipsoid is {x : (x - mu)^T sigma^{-1} (x - mu) <= 1}. For each
    point the result is its Mahalanobis distance from mu minus 1. It is
    negative inside the ellipsoid, zero on its surface and positive outside.
    The distance is computed through the Cholesky factor sigma = L L^T as
    ||L^{-1} (x - mu)||, using a triangular solve instead of an inverse.

    Args:
        X: [Ns x D] tensor of data (a single [D] point is also accepted)
        mu: [D] mean tensor
        sigma: [D x D] symmetric positive-definite covariance matrix tensor
    Return:
        mahalanobis_dist: [Ns] tensor (0-d for a single [D] point) of
            Mahalanobis distances minus 1
    '''
    diff = X - mu
    # Use one common floating dtype so e.g. a float64 sigma works with
    # float32 data (torch linear algebra requires matching dtypes).
    dtype = torch.promote_types(diff.dtype, sigma.dtype)
    chol_factor = torch.linalg.cholesky(sigma.to(dtype))
    batch_shape = diff.shape[:-1]
    # Stack every point as a column so a single solve whitens them all.
    rhs = diff.to(dtype).reshape(-1, diff.shape[-1]).T
    whitened = torch.linalg.solve_triangular(chol_factor, rhs, upper=False)
    # reshape (not squeeze) keeps an Ns=1 batch as shape [1].
    mahalanobis_dist = \
        torch.linalg.vector_norm(whitened, dim=0).reshape(batch_shape) - 1.
    return mahalanobis_dist