Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minimal Code Changes to Support Latest PyTorch and Bug Fixed for Extremely Low Adaptation Accuracy #29

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
A PyTorch implementation for [Adversarial Discriminative Domain Adaptation](https://arxiv.org/abs/1702.05464).

## Environment
- Python 3.6
- PyTorch 0.2.0
- Python >= 3.6 (Tested on Python 3.8)
- PyTorch >= 1.0.0 (Tested on PyTorch 1.11.0)

## Usage

Expand Down
6 changes: 3 additions & 3 deletions core/adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def train_tgt(src_encoder, tgt_encoder, critic,
params.num_epochs,
step + 1,
len_data_loader,
loss_critic.data[0],
loss_tgt.data[0],
acc.data[0]))
loss_critic.item(),
loss_tgt.item(),
acc.item()))

#############################
# 2.4 save model parameters #
Expand Down
8 changes: 4 additions & 4 deletions core/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def train_src(encoder, classifier, data_loader):
params.num_epochs_pre,
step + 1,
len(data_loader),
loss.data[0]))
loss.item()))

# eval model on test set
if ((epoch + 1) % params.eval_step_pre == 0):
Expand All @@ -78,8 +78,8 @@ def eval_src(encoder, classifier, data_loader):
classifier.eval()

# init loss and accuracy
loss = 0
acc = 0
loss = 0.
acc = 0.

# set loss function
criterion = nn.CrossEntropyLoss()
Expand All @@ -90,7 +90,7 @@ def eval_src(encoder, classifier, data_loader):
labels = make_variable(labels)

preds = classifier(encoder(images))
loss += criterion(preds, labels).data[0]
loss += criterion(preds, labels).item()

pred_cls = preds.data.max(1)[1]
acc += pred_cls.eq(labels.data).cpu().sum()
Expand Down
6 changes: 3 additions & 3 deletions core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def eval_tgt(encoder, classifier, data_loader):
classifier.eval()

# init loss and accuracy
loss = 0
acc = 0
loss = 0.
acc = 0.

# set loss function
criterion = nn.CrossEntropyLoss()
Expand All @@ -25,7 +25,7 @@ def eval_tgt(encoder, classifier, data_loader):
labels = make_variable(labels).squeeze_()

preds = classifier(encoder(images))
loss += criterion(preds, labels).data[0]
loss += criterion(preds, labels).item()

pred_cls = preds.data.max(1)[1]
acc += pred_cls.eq(labels.data).cpu().sum()
Expand Down
2 changes: 1 addition & 1 deletion datasets/usps.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, root, train=True, transform=None, download=False):
np.random.shuffle(indices)
self.train_data = self.train_data[indices[0:self.dataset_size], ::]
self.train_labels = self.train_labels[indices[0:self.dataset_size]]
self.train_data *= 255.0

self.train_data = self.train_data.transpose(
(0, 2, 3, 1)) # convert to HWC

Expand Down
1 change: 0 additions & 1 deletion models/discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def __init__(self, input_dims, hidden_dims, output_dims):
nn.Linear(hidden_dims, hidden_dims),
nn.ReLU(),
nn.Linear(hidden_dims, output_dims),
nn.LogSoftmax()
)

def forward(self, input):
Expand Down
4 changes: 2 additions & 2 deletions params.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
data_root = "data"
dataset_mean_value = 0.5
dataset_std_value = 0.5
dataset_mean = (dataset_mean_value, dataset_mean_value, dataset_mean_value)
dataset_std = (dataset_std_value, dataset_std_value, dataset_std_value)
dataset_mean = dataset_mean_value
dataset_std = dataset_std_value
batch_size = 50
image_size = 64

Expand Down
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def make_variable(tensor, volatile=False):
"""Convert Tensor to Variable."""
if torch.cuda.is_available():
tensor = tensor.cuda()
return Variable(tensor, volatile=volatile)
return tensor


def make_cuda(tensor):
Expand Down