Commit
Merge branch 'gamma_coeff' into bemb_chunked
kanodiaayush committed Dec 8, 2023
2 parents 38abd25 + 189b1dc commit 24cb867
Showing 29 changed files with 3,457 additions and 163 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,5 +1,7 @@
*.DS_Store

*.history

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
5 changes: 4 additions & 1 deletion bemb/__init__.py
@@ -1,3 +1,6 @@
__version__ = '0.1.4'
__version__ = '0.1.7'
import bemb.model
import bemb.utils
import bemb.data

from .utils.run_helper_lightning import run
1 change: 1 addition & 0 deletions bemb/data/__init__.py
@@ -0,0 +1 @@
from .simulate_choice_dataset import load_simulation_dataset
42 changes: 42 additions & 0 deletions bemb/data/simulate_choice_dataset.py
@@ -0,0 +1,42 @@
"""
Generate a simulated choice dataset for tutorials, unit tests, and debugging.
"""
from typing import List

import numpy as np
import torch
from torch_choice.data import ChoiceDataset


def load_simulation_dataset(num_users: int, num_items: int, data_size: int) -> List[ChoiceDataset]:
    """Simulates a choice dataset and returns a [train, validation, test] list of ChoiceDatasets."""
    user_index = torch.LongTensor(np.random.choice(num_users, size=data_size))
    Us = np.arange(num_users)
    Is = np.sin(np.arange(num_users) / num_users * 4 * np.pi)
    Is = (Is + 1) / 2 * num_items
    Is = Is.astype(int)
    # sin can reach 1.0 exactly, which would map to the out-of-range index num_items; clip to be safe.
    Is = np.clip(Is, 0, num_items - 1)

    # PREFERENCE[u] is the preferred item of user u.
    PREFERENCE = dict(zip(Us, Is))

    # construct item choices: with probability 0.5, a user chooses their preferred item.
    item_index = torch.LongTensor(np.random.choice(num_items, size=data_size))

    for idx in range(data_size):
        if np.random.rand() <= 0.5:
            item_index[idx] = PREFERENCE[int(user_index[idx])]

    # user observables: a one-hot indicator of each user's preferred item.
    user_obs = torch.zeros(num_users, num_items)
    user_obs[torch.arange(num_users), Is] = 1

    item_obs = torch.eye(num_items)

    dataset = ChoiceDataset(user_index=user_index, item_index=item_index, user_obs=user_obs, item_obs=item_obs)

    # 80/10/10 train/validation/test split.
    idx = np.random.permutation(len(dataset))
    train_size = int(0.8 * len(dataset))
    val_size = int(0.1 * len(dataset))
    train_idx = idx[:train_size]
    val_idx = idx[train_size: train_size + val_size]
    test_idx = idx[train_size + val_size:]

    dataset_list = [dataset[train_idx], dataset[val_idx], dataset[test_idx]]
    return dataset_list
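As a quick orientation for readers (an illustration, not part of the commit): the helper returns train/validation/test ChoiceDataset splits, so it can be exercised as below; the printed sizes follow from the 80/10/10 split above, and the argument values are arbitrary.

from bemb.data import load_simulation_dataset

# illustrative sizes only
train, val, test = load_simulation_dataset(num_users=50, num_items=100, data_size=1500)
print(len(train), len(val), len(test))  # 1200, 150, 150 under the 80/10/10 split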
88 changes: 76 additions & 12 deletions bemb/model/bayesian_coefficient.py
@@ -6,9 +6,11 @@
"""
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.distributions.lowrank_multivariate_normal import LowRankMultivariateNormal
from torch.distributions.gamma import Gamma


class BayesianCoefficient(nn.Module):
@@ -21,7 +23,8 @@ def __init__(self,
                 num_obs: Optional[int] = None,
                 dim: int = 1,
                 prior_mean: float = 0.0,
                 prior_variance: Union[float, torch.Tensor] = 1.0
                 prior_variance: Union[float, torch.Tensor] = 1.0,
                 distribution: str = 'gaussian'
                 ) -> None:
        """The Bayesian coefficient object represents a learnable tensor mu_i in R^k, where i is from a family (e.g., user, item),
        so there are num_classes * num_obs learnable weights in total.
@@ -63,12 +66,34 @@ def __init__(self,
                Supplying a tensor with shape (num_classes, dim) amounts to specifying
                a different prior variance for each entry in the coefficient.
                Defaults to 1.0.
            distribution (str, optional): the distribution of the coefficient. Currently 'gaussian' and 'gamma' are supported.
                Defaults to 'gaussian'.
        """
        super(BayesianCoefficient, self).__init__()
        # do we use this at all? TODO: drop self.variation.
        assert variation in ['item', 'user', 'constant', 'category']

        self.variation = variation

        assert distribution in ['gaussian', 'gamma'], f'Unsupported distribution {distribution}'
        if distribution == 'gamma':
            '''
            assert not obs2prior, 'Gamma distribution is not supported for obs2prior at present.'
            mean = 1.0
            variance = 10.0
            assert mean > 0, 'Gamma distribution requires mean > 0'
            assert variance > 0, 'Gamma distribution requires variance > 0'
            # shape (concentration) is mean^2/variance, rate is mean/variance for the Gamma distribution.
            shape = prior_mean ** 2 / prior_variance
            rate = prior_mean / prior_variance
            prior_mean = np.log(shape)
            prior_variance = rate
            '''
            # prior_mean stores the log-concentration; prior_variance is reused directly as the Gamma rate.
            assert prior_mean > 0, 'the gamma parameterization requires prior_mean > 0 (its log is taken below).'
            prior_mean = np.log(prior_mean)

        self.distribution = distribution

        self.obs2prior = obs2prior
        if variation == 'constant' or variation == 'category':
            if obs2prior:
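For readers puzzling over the commented-out block above, a sketch of the moment matching it describes (my illustration, not part of the commit): a Gamma with concentration k and rate r has mean k/r and variance k/r^2, so matching a target mean m and variance v gives k = m^2/v and r = m/v. The live code takes a different route: it treats prior_mean itself as the concentration (storing its log) and reuses prior_variance as the rate.

import torch
from torch.distributions.gamma import Gamma

m, v = 1.0, 10.0                    # example target mean and variance from the disabled block
k, r = m ** 2 / v, m / v            # moment matching: concentration and rate
g = Gamma(concentration=torch.tensor(k), rate=torch.tensor(r))
print(g.mean.item(), g.variance.item())  # recovers 1.0 and 10.0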
@@ -89,13 +114,15 @@ def __init__(self,
        if self.obs2prior:
            # the mean of the prior distribution depends on observables.
            # initialize a BayesianCoefficient of shape (dim, num_obs) with a standard Gaussian.
            prior_H_dist = 'gaussian'
            self.prior_H = BayesianCoefficient(variation='constant',
                                               num_classes=dim,
                                               obs2prior=False,
                                               dim=num_obs,
                                               prior_variance=1.0,
                                               H_zero_mask=self.H_zero_mask,
                                               is_H=True)  # this is a distribution responsible for the obs2prior H term.
                                               is_H=True,
                                               distribution=prior_H_dist)  # this is a distribution responsible for the obs2prior H term.

        else:
            self.register_buffer(
@@ -114,13 +141,21 @@ def __init__(self,
                num_classes, dim) * self.prior_variance)

        # create the variational distribution.
        self.variational_mean_flexible = nn.Parameter(
            torch.randn(num_classes, dim), requires_grad=True)
        if self.distribution == 'gaussian':
            self.variational_mean_flexible = nn.Parameter(
                torch.randn(num_classes, dim), requires_grad=True)
        elif self.distribution == 'gamma':
            # TODO(kanodiaayush): initialize the gamma distribution variational mean in a more principled way.
            # initialize with a uniform distribution between 0.5 and 1.5.
            # for a gamma distribution, variational_mean_flexible stores log(concentration).
            self.variational_mean_flexible = nn.Parameter(
                torch.rand(num_classes, dim) + 0.5, requires_grad=True)

        if self.is_H and self.H_zero_mask is not None:
            assert self.H_zero_mask.shape == self.variational_mean_flexible.shape, \
                f"The H_zero_mask should have exactly the same shape as the H variable; `H_zero_mask`.shape is {self.H_zero_mask.shape}, `H`.shape is {self.variational_mean_flexible.shape}"

        # for a gamma distribution, variational_logstd stores log(rate).
        self.variational_logstd = nn.Parameter(
            torch.randn(num_classes, dim), requires_grad=True)
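To make the two log-parameterizations concrete, a hypothetical construction (assuming the class is importable from bemb.model.bayesian_coefficient; all values are illustrative):

from bemb.model.bayesian_coefficient import BayesianCoefficient

# A per-user positive coefficient with a Gamma variational posterior.
# prior_mean=1.0 is stored as log(1.0) = 0, so the prior concentration is exp(0) = 1;
# prior_variance=10.0 is reused as the prior rate.
coef = BayesianCoefficient(variation='user', num_classes=100, obs2prior=False,
                           dim=8, prior_mean=1.0, prior_variance=10.0,
                           distribution='gamma')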

@@ -163,6 +198,10 @@ def variational_mean(self) -> torch.Tensor:
        else:
            M = self.variational_mean_fixed + self.variational_mean_flexible

        if self.distribution == 'gamma':
            # the mean of Gamma(concentration, rate) is concentration / rate.
            M = M.exp() / self.variational_logstd.exp()

        if self.is_H and (self.H_zero_mask is not None):
            # an H-variable with zero-entry restriction.
            # zero out entries where H_zero_mask[i, j] = 1.
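A two-line check of the identity used in the gamma branch above (illustration only): with variational_mean_flexible = log k and variational_logstd = log r, the posterior mean k/r equals exp(log k)/exp(log r).

import torch

log_k, log_r = torch.randn(5, 2), torch.randn(5, 2)
assert torch.allclose(log_k.exp() / log_r.exp(), torch.exp(log_k - log_r))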
@@ -196,7 +235,11 @@ def log_prior(self,
        Returns:
            torch.Tensor: the log prior of the variable with shape (num_seeds, num_classes).
        """
        num_seeds, num_classes, dim = sample.shape
        # shape (num_seeds, num_classes)
        if self.obs2prior:
@@ -211,9 +254,19 @@

        else:
            mu = self.prior_zero_mean
        out = LowRankMultivariateNormal(loc=mu,
                                        cov_factor=self.prior_cov_factor,
                                        cov_diag=self.prior_cov_diag).log_prob(sample)

        if self.distribution == 'gaussian':
            out = LowRankMultivariateNormal(loc=mu,
                                            cov_factor=self.prior_cov_factor,
                                            cov_diag=self.prior_cov_diag).log_prob(sample)
        elif self.distribution == 'gamma':
            # mu stores the log-concentration; prior_variance is reused as the rate.
            concentration = torch.exp(mu)
            rate = self.prior_variance
            out = Gamma(concentration=concentration,
                        rate=rate).log_prob(sample)
            # sum the per-dimension log-densities over the last dimension.
            out = torch.sum(out, dim=-1)

        assert out.shape == (num_seeds, num_classes)
        return out
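A shape-level sketch of the gamma branch above (my illustration; values are arbitrary): the Gamma log-density broadcasts over the seed dimension and is then summed over dim.

import torch
from torch.distributions.gamma import Gamma

num_seeds, num_classes, dim = 4, 10, 3
sample = torch.rand(num_seeds, num_classes, dim) + 0.1  # Gamma support is (0, inf)
mu = torch.zeros(num_classes, dim)                      # log-concentration
rate = torch.ones(num_classes, dim)                     # prior_variance reused as the rate
out = Gamma(concentration=mu.exp(), rate=rate).log_prob(sample)
out = torch.sum(out, dim=-1)                            # -> (num_seeds, num_classes)
assert out.shape == (num_seeds, num_classes)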

@@ -250,6 +303,7 @@ def rsample(self, num_seeds: int = 1) -> Union[torch.Tensor, Tuple[torch.Tensor]
        """
        value_sample = self.variational_distribution.rsample(
            torch.Size([num_seeds]))
        if self.obs2prior:
            # sample the obs2prior H term as well.
            H_sample = self.prior_H.rsample(num_seeds=num_seeds)
@@ -258,12 +312,22 @@ def rsample(self, num_seeds: int = 1) -> Union[torch.Tensor, Tuple[torch.Tensor]
            return value_sample

    @property
    def variational_distribution(self) -> LowRankMultivariateNormal:
    def variational_distribution(self) -> Union[LowRankMultivariateNormal, Gamma]:
        """Constructs the current variational distribution of the coefficient from the current variational mean and covariance.
        """
        return LowRankMultivariateNormal(loc=self.variational_mean,
                                         cov_factor=self.variational_cov_factor,
                                         cov_diag=torch.exp(self.variational_logstd))
        if self.distribution == 'gaussian':
            return LowRankMultivariateNormal(loc=self.variational_mean,
                                             cov_factor=self.variational_cov_factor,
                                             cov_diag=torch.exp(self.variational_logstd))
        elif self.distribution == 'gamma':
            # variational_mean_flexible stores log(concentration) for the gamma case.
            assert self.variational_mean_fixed is None, 'Gamma distribution does not support a fixed mean.'
            concentration = self.variational_mean_flexible.exp()
            # variational_logstd stores log(rate) for the gamma case.
            rate = torch.exp(self.variational_logstd)
            return Gamma(concentration=concentration, rate=rate)
        else:
            raise NotImplementedError("Unknown variational distribution type.")
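Worth noting (a sketch under my own setup, not part of the commit): torch's Gamma supports reparameterized sampling, which is what rsample above relies on to keep gradients flowing through the variational parameters.

import torch
from torch.distributions.gamma import Gamma

log_k = torch.zeros(10, 3, requires_grad=True)  # log-concentration
log_r = torch.zeros(10, 3, requires_grad=True)  # log-rate
q = Gamma(concentration=log_k.exp(), rate=log_r.exp())
draws = q.rsample(torch.Size([4]))              # shape (4, 10, 3), differentiable
draws.sum().backward()
assert log_k.grad is not None and log_r.grad is not None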

    @property
    def device(self) -> torch.device: