
Commit 7afa63e

add an hparam for ContinuousAutoregressiveWrapper
1 parent 8698122 commit 7afa63e

2 files changed (+11 -2 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.7.0"
+version = "2.7.1"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

x_transformers/continuous.py

Lines changed: 10 additions & 1 deletion

@@ -241,6 +241,7 @@ def __init__(
         self,
         net: ContinuousTransformerWrapper,
         loss_fn: Module | None = None,
+        use_l1_loss = False,
         equal_loss_weight_batch = False, # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
     ):
         super().__init__()
@@ -250,7 +251,15 @@ def __init__(
         probabilistic = net.probabilistic
         self.probabilistic = probabilistic

-        loss_fn = default(loss_fn, nn.MSELoss(reduction = 'none') if not probabilistic else GaussianNLL())
+        # default loss function
+
+        if not exists(loss_fn):
+            if probabilistic:
+                loss_fn = GaussianNLL()
+            elif use_l1_loss:
+                loss_fn = nn.L1Loss(reduction = 'none')
+            else:
+                loss_fn = nn.MSELoss(reduction = 'none')

         self.loss_fn = loss_fn
         self.equal_loss_weight_batch = equal_loss_weight_batch
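For context, a minimal usage sketch of the new flag. The model dimensions, depth, and mock data below are hypothetical and chosen only for illustration; the construction pattern follows the library's documented continuous-wrapper usage. Setting use_l1_loss = True only changes the default loss: an explicitly passed loss_fn still takes precedence, and a probabilistic net still defaults to GaussianNLL.

import torch
from x_transformers import Decoder
from x_transformers.continuous import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper

# continuous-valued transformer (hypothetical dimensions)
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

# new hparam from this commit: default to L1 loss rather than MSE
# (only applies when no loss_fn is given and the net is not probabilistic)
wrapper = ContinuousAutoregressiveWrapper(model, use_l1_loss = True)

# mock continuous sequence data
seq = torch.randn(1, 1024, 32)
mask = torch.ones(1, 1024).bool()

loss = wrapper(seq, mask = mask)
loss.backward()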
