first cut gamma coefficient - working, results poor
kanodiaayush committed Dec 2, 2023
1 parent b1e4263 commit 09d3e78
Showing 5 changed files with 166 additions and 21 deletions.
116 changes: 106 additions & 10 deletions bemb/model/bayesian_coefficient.py
@@ -9,6 +9,7 @@
import torch
import torch.nn as nn
from torch.distributions.lowrank_multivariate_normal import LowRankMultivariateNormal
from torch.distributions.gamma import Gamma


class BayesianCoefficient(nn.Module):
@@ -21,7 +22,8 @@ def __init__(self,
num_obs: Optional[int] = None,
dim: int = 1,
prior_mean: float = 0.0,
prior_variance: Union[float, torch.Tensor] = 1.0
prior_variance: Union[float, torch.Tensor] = 1.0,
distribution: str = 'gaussian'
) -> None:
"""The Bayesian coefficient object represents a learnable tensor mu_i in R^k, where i is from a family (e.g., user, item)
so there are num_classes * num_obs learnable weights in total.
@@ -63,12 +65,27 @@ def __init__(self,
If a tensor with shape (num_classes, dim) is supplied, this amounts to specifying a different
prior variance for each entry in the coefficient.
Defaults to 1.0.
distribution (str, optional): the distribution of the coefficient. Currently we support 'gaussian' and 'gamma'.
Defaults to 'gaussian'.
"""
super(BayesianCoefficient, self).__init__()
# do we use this at all? TODO: drop self.variation.
assert variation in ['item', 'user', 'constant', 'category']

self.variation = variation

assert distribution in ['gaussian', 'gamma'], f'Unsupported distribution {distribution}'
if distribution == 'gamma':
assert not obs2prior, 'obs2prior is not supported for gamma-distributed coefficients at present.'
mean = 1
variance = 10
assert mean > 0, 'Gamma distribution requires mean > 0'
assert variance > 0, 'Gamma distribution requires variance > 0'
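# For a Gamma with mean m and variance v, concentration = m**2 / v and
# rate = m / v; the two assignments below store that reparameterization.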
prior_mean = mean**2 / variance
prior_variance = mean / variance

self.distribution = distribution

self.obs2prior = obs2prior
if variation == 'constant' or variation == 'category':
if obs2prior:
@@ -95,7 +112,8 @@ def __init__(self,
dim=num_obs,
prior_variance=1.0,
H_zero_mask=self.H_zero_mask,
is_H=True) # this is a distribution responsible for the obs2prior H term.
is_H=True,
distribution=self.distribution) # this is a distribution responsible for the obs2prior H term.

else:
self.register_buffer(
@@ -117,6 +135,14 @@ def __init__(self,
self.variational_mean_flexible = nn.Parameter(
torch.randn(num_classes, dim), requires_grad=True)

# TODO(kanodiaayush): initialize the gamma distribution variational mean in a more principled way.
'''
if self.distribution == 'gamma':
# take absolute value of the variational mean.
self.variational_mean_flexible.data = torch.abs(
self.variational_mean_flexible.data)
'''

if self.is_H and self.H_zero_mask is not None:
assert self.H_zero_mask.shape == self.variational_mean_flexible.shape, \
f"The H_zero_mask should have exactly the shape as the H variable, `H_zero_mask`.shape is {self.H_zero_mask.shape}, `H`.shape is {self.variational_mean_flexible.shape} "
@@ -163,6 +189,9 @@ def variational_mean(self) -> torch.Tensor:
else:
M = self.variational_mean_fixed + self.variational_mean_flexible

if self.distribution == 'gamma':
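# Square the unconstrained mean to keep the gamma coefficient positive;
# the small constant keeps it strictly positive.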
M = torch.pow(M, 2) + 0.000001

if self.is_H and (self.H_zero_mask is not None):
# a H-variable with zero-entry restriction.
# multiply zeros to entries with H_zero_mask[i, j] = 1.
@@ -196,7 +225,11 @@ def log_prior(self,
Returns:
torch.Tensor: the log prior of the variable with shape (num_seeds, num_classes).
"""
# p(sample)
# DEBUG_MARKER
'''
print(sample)
print('log_prior')
'''
num_seeds, num_classes, dim = sample.shape
# shape (num_seeds, num_classes)
if self.obs2prior:
@@ -211,9 +244,46 @@

else:
mu = self.prior_zero_mean
out = LowRankMultivariateNormal(loc=mu,
cov_factor=self.prior_cov_factor,
cov_diag=self.prior_cov_diag).log_prob(sample)

if self.distribution == 'gaussian':
# DEBUG_MARKER
'''
print('sample.shape', sample.shape)
print('gaussian')
print("mu.shape, self.prior_cov_diag.shape")
print(mu.shape, self.prior_cov_diag.shape)
'''
out = LowRankMultivariateNormal(loc=mu,
cov_factor=self.prior_cov_factor,
cov_diag=self.prior_cov_diag).log_prob(sample)
elif self.distribution == 'gamma':
concentration = torch.pow(mu, 2)/self.prior_cov_diag
rate = mu/self.prior_cov_diag
# DEBUG_MARKER
'''
print('sample.shape', sample.shape)
print('gamma')
print("mu.shape, self.prior_cov_diag.shape")
print(mu.shape, self.prior_cov_diag.shape)
print("concentration.shape, rate.shape")
print(concentration.shape, rate.shape)
'''
out = Gamma(concentration=concentration,
rate=rate).log_prob(sample)
# Gamma.log_prob is elementwise, so out has shape (num_seeds, num_classes, dim);
# with dim == 1, take the first (only) element over the last dim.
out = out[:, :, 0]


# DEBUG_MARKER
'''
print("sample.shape")
print(sample.shape)
print("out.shape")
print(out.shape)
print("num_seeds, num_classes")
print(num_seeds, num_classes)
breakpoint()
'''
assert out.shape == (num_seeds, num_classes)
return out

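Unlike LowRankMultivariateNormal.log_prob, which reduces over the event (last) dimension, Gamma.log_prob is elementwise — hence the [:, :, 0] above. A standalone sketch (not part of the diff) illustrating the shapes:

import torch
from torch.distributions.gamma import Gamma

sample = torch.rand(4, 7, 1) + 0.1  # (num_seeds=4, num_classes=7, dim=1)
out = Gamma(concentration=0.1, rate=0.1).log_prob(sample)
print(out.shape)           # torch.Size([4, 7, 1]) -- elementwise, no reduction
print(out[:, :, 0].shape)  # torch.Size([4, 7])
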
@@ -250,6 +320,15 @@ def rsample(self, num_seeds: int = 1) -> Union[torch.Tensor, Tuple[torch.Tensor]]
"""
value_sample = self.variational_distribution.rsample(
torch.Size([num_seeds]))
# DEBUG_MARKER
'''
print("rsample")
print(self.distribution)
print("value_sample.shape")
print(value_sample.shape)
breakpoint()
'''
# DEBUG_MARKER
if self.obs2prior:
# sample obs2prior H as well.
H_sample = self.prior_H.rsample(num_seeds=num_seeds)
@@ -258,12 +337,29 @@ def rsample(self, num_seeds: int = 1) -> Union[torch.Tensor, Tuple[torch.Tensor]]
return value_sample

@property
def variational_distribution(self) -> LowRankMultivariateNormal:
def variational_distribution(self) -> Union[LowRankMultivariateNormal, Gamma]:
"""Constructs the current variational distribution of the coefficient from current variational mean and covariance.
"""
return LowRankMultivariateNormal(loc=self.variational_mean,
cov_factor=self.variational_cov_factor,
cov_diag=torch.exp(self.variational_logstd))
if self.distribution == 'gaussian':
return LowRankMultivariateNormal(loc=self.variational_mean,
cov_factor=self.variational_cov_factor,
cov_diag=torch.exp(self.variational_logstd))
elif self.distribution == 'gamma':
# concentration is mean**2 / var (std**2)
concentration = torch.pow(self.variational_mean, 2)/torch.pow(torch.exp(self.variational_logstd), 2)
# rate is mean / var (std**2)
rate = self.variational_mean/torch.pow(torch.exp(self.variational_logstd), 2)
# DEBUG_MARKER
'''
print("self.variational_mean, self.variational_logstd")
print(self.variational_mean, self.variational_logstd)
print("concentration, rate")
print(concentration, rate)
'''
# DEBUG_MARKER
return Gamma(concentration=concentration, rate=rate)
else:
raise NotImplementedError("Unknown variational distribution type.")

@property
def device(self) -> torch.device:
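Both log_prior and variational_distribution above map a (mean, variance) pair onto torch's Gamma(concentration, rate) parameterization. A minimal standalone sketch (not part of the commit) verifying the mapping:

import torch
from torch.distributions.gamma import Gamma

# A Gamma with mean m and variance v has concentration = m**2 / v and rate = m / v.
m, v = torch.tensor(1.0), torch.tensor(10.0)
dist = Gamma(concentration=m ** 2 / v, rate=m / v)
print(dist.mean, dist.variance)  # tensor(1.) tensor(10.)
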
30 changes: 29 additions & 1 deletion bemb/model/bemb.py
@@ -113,6 +113,7 @@ def __init__(self,
num_items: int,
pred_item: bool,
num_classes: int = 2,
coef_dist_dict: Dict[str, str] = {'default' : 'gaussian'},
H_zero_mask_dict: Optional[Dict[str, torch.BoolTensor]] = None,
prior_mean: Union[float, Dict[str, float]] = 0.0,
prior_variance: Union[float, Dict[str, float]] = 1.0,
@@ -140,6 +141,13 @@ def __init__(self,
lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs
See the doc-string of parse_utility for an example.
coef_dist_dict (Dict[str, str]): a dictionary mapping coefficient name to coefficient distribution name.
The coefficient distribution name can be one of the following:
1. 'gaussian'
2. 'gamma' (obs2prior is currently not supported for gamma coefficients)
If a coefficient does not appear in the dictionary, it will be assigned the distribution specified
by the 'default' key. By default, the default distribution is 'gaussian'.
obs2prior_dict (Dict[str, bool]): a dictionary maps coefficient name (e.g., 'lambda_item')
to a boolean indicating if observable (e.g., item_obs) enters the prior of the coefficient.
@@ -233,6 +241,7 @@ def __init__(self,
self.utility_formula = utility_formula
self.obs2prior_dict = obs2prior_dict
self.coef_dim_dict = coef_dim_dict
self.coef_dist_dict = coef_dist_dict
if H_zero_mask_dict is not None:
self.H_zero_mask_dict = H_zero_mask_dict
else:
@@ -348,6 +357,13 @@ def __init__(self,
warnings.warn(f"You provided a dictionary of prior variance, but coefficient {coef_name} is not a key in it. Supply a value for 'default' in the prior_variance dictionary to use that as default value (e.g., prior_variance['default'] = 0.3); now using variance=1.0 since this is not supplied.")
self.prior_variance[coef_name] = 1.0

if coef_name not in self.coef_dist_dict.keys():
if 'default' in self.coef_dist_dict.keys():
self.coef_dist_dict[coef_name] = self.coef_dist_dict['default']
else:
warnings.warn(f"You provided a dictionary of coef_dist_dict, but coefficient {coef_name} is not a key in it. Supply a value for 'default' in the coef_dist_dict dictionary to use that as default value (e.g., coef_dist_dict['default'] = 'gaussian'); now using distribution='gaussian' since this is not supplied.")
self.coef_dist_dict[coef_name] = 'gaussian'

s2 = self.prior_variance[coef_name] if isinstance(
self.prior_variance, dict) else self.prior_variance

@@ -367,7 +383,8 @@ def __init__(self,
prior_mean=mean,
prior_variance=s2,
H_zero_mask=H_zero_mask,
is_H=False)
is_H=False,
distribution=self.coef_dist_dict[coef_name])
self.coef_dict = nn.ModuleDict(coef_dict)

# ==============================================================================================================
@@ -653,6 +670,9 @@ def sample_coefficient_dictionary(self, num_seeds: int, deterministic: bool = Fa
"""
sample_dict = dict()
for coef_name, coef in self.coef_dict.items():
'''
print(coef_name)
'''
if deterministic:
sample_dict[coef_name] = coef.variational_distribution.mean.unsqueeze(dim=0) # (1, num_*, dim)
if coef.obs2prior:
@@ -935,6 +955,8 @@ def reshape_observable(obs, name):
assert obs.shape == (R, P, I, positive_integer)

additive_term = (coef_sample * obs).sum(dim=-1)
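# Flip the sign of the price term so a positive (e.g., gamma-distributed)
# coefficient lowers utility as price rises (applied to Type III and IV terms).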
if obs_name == 'price_obs':
additive_term *= -1.0

# Type IV: factorized coefficient multiplied by observable.
# e.g., gamma_user * beta_item * price_obs.
@@ -965,6 +987,8 @@ def reshape_observable(obs, name):
coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)

additive_term = (coef * obs).sum(dim=-1)
if obs_name == 'price_obs':
additive_term *= -1.0

else:
raise ValueError(f'Undefined term type: {term}')
@@ -1167,6 +1191,8 @@ def reshape_observable(obs, name):
assert obs.shape == (R, total_computation, positive_integer)

additive_term = (coef_sample * obs).sum(dim=-1)
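# As above, flip the sign of the price term so a positive (e.g.,
# gamma-distributed) coefficient lowers utility as price rises.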
if obs_name == 'price_obs':
additive_term *= -1.0

# Type IV: factorized coefficient multiplied by observable.
# e.g., gamma_user * beta_item * price_obs.
@@ -1196,6 +1222,8 @@ def reshape_observable(obs, name):
coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)

additive_term = (coef * obs).sum(dim=-1)
if obs_name == 'price_obs':
additive_term *= -1.0

else:
raise ValueError(f'Undefined term type: {term}')
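As documented in the coef_dist_dict docstring above, coefficients absent from the dictionary fall back to the 'default' entry. A standalone sketch (not part of the diff) of that lookup:

coef_dist_dict = {'default': 'gaussian', 'gamma_user': 'gamma'}
for coef_name in ['lambda_item', 'theta_user', 'alpha_item', 'gamma_user']:
    dist = coef_dist_dict.get(coef_name, coef_dist_dict['default'])
    print(coef_name, '->', dist)
# lambda_item -> gaussian
# theta_user -> gaussian
# alpha_item -> gaussian
# gamma_user -> gamma
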
8 changes: 8 additions & 0 deletions bemb/model/bemb_flex_lightning.py
@@ -79,6 +79,14 @@ def training_step(self, batch, batch_idx):
loss = - elbo
return loss

# DEBUG_MARKER
'''
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
print(f"Epoch {self.current_epoch} has ended")
breakpoint()
'''
# DEBUG_MARKER

def _get_performance_dict(self, batch):
if self.model.pred_item:
log_p = self.model(batch, return_type='log_prob',
26 changes: 16 additions & 10 deletions tutorials/supermarket/configs.yaml
@@ -1,30 +1,36 @@
device: cuda
# data_dir: /home/tianyudu/Data/MoreSupermarket/tsv/
data_dir: /home/tianyudu/Data/MoreSupermarket/20180101-20191231_13/tsv/
# data_dir: /home/tianyudu/Data/MoreSupermarket/20180101-20191231_13/tsv/
data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/1904/20180101-20191231_44/tsv
# utility: lambda_item
# utility: lambda_item + theta_user * alpha_item
# utility: lambda_item + theta_user * alpha_item + zeta_user * item_obs
utility: lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs
# utility: lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs
utility: lambda_item + theta_user * alpha_item + gamma_user * price_obs
# utility: alpha_item * gamma_user * price_obs
out_dir: ./output/
# model configuration.
coef_dist_dict:
default: 'gaussian'
gamma_user: 'gamma'
obs2prior_dict:
lambda_item: True
theta_user: True
alpha_item: True
lambda_item: False
theta_user: False
alpha_item: False
zeta_user: True
lota_item: True
gamma_user: True
gamma_user: False
beta_item: True
coef_dim_dict:
lambda_item: 1
theta_user: 10
alpha_item: 10
gamma_user: 10
beta_item: 10
gamma_user: 1
beta_item: 1
#### optimization.
trace_log_q: False
shuffle: False
batch_size: 100000
num_epochs: 3
num_epochs: 100
learning_rate: 0.03
num_mc_seeds: 1
num_mc_seeds: 2
7 changes: 7 additions & 0 deletions tutorials/supermarket/main.py
@@ -197,6 +197,7 @@ def load_tsv(file_name, data_dir):
bemb = LitBEMBFlex(
# trainings args.
pred_item = configs.pred_item,
coef_dist_dict=configs.coef_dist_dict,
learning_rate=configs.learning_rate,
num_seeds=configs.num_mc_seeds,
# model args, will be passed to BEMB constructor.
@@ -217,6 +218,12 @@ def load_tsv(file_name, data_dir):
bemb = bemb.to(configs.device)
bemb = run(bemb, dataset_list, batch_size=configs.batch_size, num_epochs=configs.num_epochs)

coeffs = bemb.model.coef_dict['gamma_user'].variational_mean.detach().cpu().numpy()
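# NOTE: variational_mean already returns the squared (positive) mean for
# gamma-distributed coefficients, so the line below squares it once more.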
coeffs = coeffs**2
# print summary statistics of the learned coefficients.
print('Coefficients statistics:')
print(pd.DataFrame(coeffs).describe())

# ==============================================================================================
# inference example
# ==============================================================================================
Expand Down
