Commit
Merge branch 'main' of github.com:gsbDBI/torch-choice
TianyuDu committed Sep 29, 2022
2 parents 1a4058a + 0c8574f commit 239cc71
Showing 4 changed files with 810 additions and 6 deletions.
31 changes: 30 additions & 1 deletion torch_choice/model/conditional_logit_model.py
@@ -43,7 +43,9 @@ def __init__(self,
                 coef_variation_dict: Dict[str, str],
                 num_param_dict: Optional[Dict[str, int]]=None,
                 num_items: Optional[int]=None,
                 num_users: Optional[int]=None
                 num_users: Optional[int]=None,
                 regularization: Optional[str]=None,
                 regularization_weight: Optional[float]=None
                 ) -> None:
        """
        Args:
@@ -77,6 +79,16 @@ def __init__(self,
                as the `coef_variation_dict`. Values of `num_param_dict` record the number of features in each kind of variable.
                If None is supplied, `num_param_dict` will be a dictionary with the same keys as `coef_variation_dict`
                and all values equal to one. Defaults to None.
            regularization (Optional[str]): this argument takes values from {'L1', 'L2', None}, which specifies the type of
                regularization added to the log-likelihood.
                - 'L1' will subtract regularization_weight * 1-norm of parameters from the log-likelihood.
                - 'L2' will subtract regularization_weight * 2-norm of parameters from the log-likelihood.
                - None does not modify the log-likelihood.
                Defaults to None.
            regularization_weight (Optional[float]): the weight of the parameter norm subtracted from the log-likelihood.
                This term controls the strength of regularization. This argument is required if and only if regularization
                is not None.
                Defaults to None.
"""
super(ConditionalLogitModel, self).__init__()

@@ -93,6 +105,14 @@ def __init__(self,
        self.num_items = num_items
        self.num_users = num_users

        self.regularization = regularization
        assert self.regularization in ['L1', 'L2', None], f"Provided regularization={self.regularization} is not allowed, allowed values are ['L1', 'L2', None]."
        self.regularization_weight = regularization_weight
        if (self.regularization is not None) and (self.regularization_weight is None):
            raise ValueError(f'You specified regularization type {self.regularization} without providing regularization_weight.')
        if (self.regularization is None) and (self.regularization_weight is not None):
            raise ValueError(f'You provided regularization_weight={self.regularization_weight} without specifying a regularization type; leave regularization_weight as None if you do not want to regularize the model.')

        # check that the numbers of parameters specified are all positive.
        for var_type, num_params in self.num_param_dict.items():
            assert num_params > 0, f'num_params needs to be positive, got: {num_params}.'
@@ -210,6 +230,15 @@ def negative_log_likelihood(self, batch: ChoiceDataset, y: torch.Tensor, is_trai
        nll = - logP[torch.arange(len(y)), y].sum()
        return nll

    def loss(self, *args, **kwargs):
        """The loss function to be optimized. This is a wrapper around `negative_log_likelihood`, adding the regularization term when required."""
        nll = self.negative_log_likelihood(*args, **kwargs)
        if self.regularization is not None:
            L = {'L1': 1, 'L2': 2}[self.regularization]
            for param in self.parameters():
                nll += self.regularization_weight * torch.norm(param, p=L)
        return nll

    @property
    def device(self) -> torch.device:
        """Returns the device of the coefficient.
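Below is a minimal usage sketch of the new arguments when constructing a ConditionalLogitModel. The feature names ('price_obs', 'intercept'), the coefficient variation values, and num_items are illustrative placeholders rather than anything taken from this commit:

from torch_choice.model import ConditionalLogitModel

# Hypothetical example: 'price_obs' and 'intercept' are placeholder observable names;
# adjust them to match the columns of your ChoiceDataset.
model = ConditionalLogitModel(
    coef_variation_dict={'price_obs': 'constant', 'intercept': 'item'},
    num_param_dict={'price_obs': 1, 'intercept': 1},
    num_items=4,
    regularization='L1',          # penalize the 1-norm of the coefficients
    regularization_weight=0.01)   # required whenever regularization is not None

# model.loss(batch, item_index) then returns
# model.negative_log_likelihood(batch, item_index) + 0.01 * (sum of absolute coefficient values).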
36 changes: 35 additions & 1 deletion torch_choice/model/nested_logit_model.py
@@ -21,7 +21,9 @@ def __init__(self,
                 item_coef_variation_dict: Dict[str, str],
                 item_num_param_dict: Dict[str, int],
                 num_users: Optional[int]=None,
                 shared_lambda: bool=False
                 shared_lambda: bool=False,
                 regularization: Optional[str]=None,
                 regularization_weight: Optional[float]=None
                 ) -> None:
        """Initialization method of the nested logit model.
@@ -52,6 +54,18 @@ def __init__(self,
                If set to True, a single lambda will be learned for all categories; otherwise, the
                model learns an individual lambda for each category.
                Defaults to False.
            regularization (Optional[str]): this argument takes values from {'L1', 'L2', None}, which specifies the type of
                regularization added to the log-likelihood.
                - 'L1' will subtract regularization_weight * 1-norm of parameters from the log-likelihood.
                - 'L2' will subtract regularization_weight * 2-norm of parameters from the log-likelihood.
                - None does not modify the log-likelihood.
                Defaults to None.
            regularization_weight (Optional[float]): the weight of the parameter norm subtracted from the log-likelihood.
                This term controls the strength of regularization. This argument is required if and only if regularization
                is not None.
                Defaults to None.
        """
        super(NestedLogitModel, self).__init__()
        self.category_to_item = category_to_item
@@ -85,6 +99,14 @@ def __init__(self,
        # used to warn users if they forget to call clamp.
        self._clamp_called_flag = True

        self.regularization = regularization
        assert self.regularization in ['L1', 'L2', None], f"Provided regularization={self.regularization} is not allowed, allowed values are ['L1', 'L2', None]."
        self.regularization_weight = regularization_weight
        if (self.regularization is not None) and (self.regularization_weight is None):
            raise ValueError(f'You specified regularization type {self.regularization} without providing regularization_weight.')
        if (self.regularization is None) and (self.regularization_weight is not None):
            raise ValueError(f'You provided regularization_weight={self.regularization_weight} without specifying a regularization type; leave regularization_weight as None if you do not want to regularize the model.')

    @property
    def num_params(self) -> int:
        """Get the total number of parameters. For example, if there is only a user-specific coefficient to be multiplied
@@ -289,6 +311,18 @@ def negative_log_likelihood(self,
        nll = - logP[torch.arange(len(y)), y].sum()
        return nll

    def loss(self, *args, **kwargs):
        """The loss function to be optimized. This is a wrapper around `negative_log_likelihood`, adding the regularization term when required."""
        nll = self.negative_log_likelihood(*args, **kwargs)
        if self.regularization is not None:
            L = {'L1': 1, 'L2': 2}[self.regularization]
            for name, param in self.named_parameters():
                if name == 'lambda_weight':
                    # we don't regularize the lambda term; we only regularize coefficients.
                    continue
                nll += self.regularization_weight * torch.norm(param, p=L)
        return nll

    @property
    def device(self) -> torch.device:
        """Returns the device of the coefficient.
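As a sanity check on what the nested model's loss wrapper computes, the following sketch reconstructs the penalized objective by hand for the 'L2' case. It assumes a NestedLogitModel named model constructed with regularization='L2', plus a batch and its item_index already in scope; it is illustrative only and not part of the commit:

import torch

# model.loss should equal the raw negative log-likelihood plus regularization_weight
# times the 2-norm of every parameter except the nest-level 'lambda_weight',
# which the commit deliberately leaves unpenalized.
nll = model.negative_log_likelihood(batch, item_index)
penalty = sum(torch.norm(param, p=2)
              for name, param in model.named_parameters()
              if name != 'lambda_weight')
expected = nll + model.regularization_weight * penalty
assert torch.isclose(model.loss(batch, item_index), expected)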
16 changes: 12 additions & 4 deletions torch_choice/utils/run_helper.py
@@ -14,7 +14,7 @@
from torch_choice.model.nested_logit_model import NestedLogitModel


def run(model, dataset, batch_size=-1, learning_rate=0.01, num_epochs=5000):
def run(model, dataset, dataset_test=None, batch_size=-1, learning_rate=0.01, num_epochs=5000):
    """All in one script for the model training and result presentation."""
    assert isinstance(model, ConditionalLogitModel) or isinstance(model, NestedLogitModel), \
        f'A model of type {type(model)} is not supported by this runner.'
@@ -33,18 +33,26 @@ def run(model, dataset, batch_size=-1, learning_rate=0.01, num_epochs=5000):
        ll, count = 0.0, 0.0
        for batch in data_loader:
            item_index = batch['item'].item_index if isinstance(model, NestedLogitModel) else batch.item_index
            loss = model.negative_log_likelihood(batch, item_index)
            # the model.loss returns negative log-likelihood + regularization term.
            loss = model.loss(batch, item_index)

            ll -= loss.detach().item()  # * len(batch)
            count += len(batch)
            if e % (num_epochs // 10) == 0:
                # record log-likelihood.
                ll -= model.negative_log_likelihood(batch, item_index).detach().item()  # * len(batch)
                count += len(batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # ll /= count
        if e % (num_epochs // 10) == 0:
            print(f'Epoch {e}: Log-likelihood={ll}')

    if dataset_test is not None:
        test_ll = - model.negative_log_likelihood(dataset_test, dataset_test.item_index).detach().item()
        print('Test set log-likelihood: ', test_ll)

    # current methods of computing standard deviation will corrupt the model, load weights into another model for returning.
    state_dict = deepcopy(model.state_dict())
    trained_model.load_state_dict(state_dict)
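A usage sketch of the updated runner. Here model, dataset_train, and dataset_test are placeholder names for a torch-choice model and ChoiceDataset splits built elsewhere, and the keyword values simply override the defaults shown in the signature above:

from torch_choice.utils.run_helper import run

# The held-out split is optional; when provided, the runner also reports its log-likelihood.
run(model, dataset_train, dataset_test=dataset_test,
    batch_size=256, learning_rate=0.01, num_epochs=1000)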