
Commit

Update
Eric-mingjie committed Jul 6, 2018
1 parent ac06e38 commit 5b4d636
Showing 3 changed files with 6 additions and 39 deletions.
README.md (6 changes: 4 additions & 2 deletions)
@@ -82,8 +82,10 @@ python main.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg

| CIFAR100-Resnet-164 | Baseline | Sparsity (1e-5) | Prune (40%) | Fine-tune-160(40%) | Prune(60%) | Fine-tune-160(60%) |
| :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: |:--------------------: | :-----------------:|
| Top1 Accuracy (%) | ----- | 76.87 | 48.0 | --- | --- | -- |
| Parameters | 1.73M | 1.73M | 1.49M | --- |--- | -- |
| Top1 Accuracy (%) | 76.79 | 76.87 | 48.0 | 77.36 | --- | --- |
| Parameters | 1.73M | 1.73M | 1.49M | 1.49M |--- | --- |

Note: When pruning 60% of the channels of resnet164-cifar100, this implementation sometimes prunes all channels of a layer, which causes an error. However, we also provide a [mask implementation](https://github.com/Eric-mingjie/network-slimming/tree/master/mask-impl) where we apply a mask to the scaling factors in the BN layers. With the mask implementation, we can also train the pruned network when pruning 60% of the channels in resnet164-cifar100.

| CIFAR100-Densenet-40 | Baseline | Sparsity (1e-5) | Prune (40%) | Fine-tune-160(40%) | Prune(60%) | Fine-tune-160(60%) |
| :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: |:--------------------: | :-----------------:|
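
For context on the mask implementation mentioned in the note above: instead of physically removing channels, it zeroes the BatchNorm scaling factors of pruned channels, so a layer whose channels are all pruned does not break the network. Below is a minimal sketch of that idea, assuming a PyTorch model with `nn.BatchNorm2d` layers; the helper name and the explicit `threshold` argument are illustrative and not taken from the repository.

```python
import torch.nn as nn

def apply_bn_mask(model, threshold):
    """Zero out BN scaling factors (gamma) and shifts (beta) below threshold.

    Returns the per-layer 0/1 masks so they can be reused later to keep
    the pruned channels at zero while fine-tuning.
    """
    masks = []
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            mask = (m.weight.data.abs() > threshold).float()
            m.weight.data.mul_(mask)  # gamma: channel scaling factor
            m.bias.data.mul_(mask)    # beta: channel shift
            masks.append(mask)
    return masks
```
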
main.py (15 changes: 1 addition & 14 deletions)
@@ -122,8 +122,6 @@
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

history_score = np.zeros((args.epochs - args.start_epoch + 1, 3))

# additional subgradient descent on the sparsity-induced penalty term
def updateBN():
    for m in model.modules():
@@ -132,19 +130,14 @@ def updateBN():

def train(epoch):
    model.train()
    global history_score
    avg_loss = 0.
    train_acc = 0.
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        avg_loss += loss.data[0]
        pred = output.data.max(1, keepdim=True)[1]
        train_acc += pred.eq(target.data.view_as(pred)).cpu().sum()
        loss.backward()
        if args.sr:
            updateBN()
@@ -153,8 +146,6 @@ def train(epoch):
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))
    history_score[epoch][0] = avg_loss / len(train_loader)
    history_score[epoch][1] = train_acc / float(len(train_loader))

def test():
    model.eval()
@@ -187,8 +178,6 @@ def save_checkpoint(state, is_best, filepath):
            param_group['lr'] *= 0.1
    train(epoch)
    prec1 = test()
    history_score[epoch][2] = prec1
    np.savetxt(os.path.join(args.save, 'record.txt'), history_score, fmt = '%10.5f', delimiter=',')
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    save_checkpoint({
@@ -198,6 +187,4 @@ def save_checkpoint(state, is_best, filepath):
        'optimizer': optimizer.state_dict(),
    }, is_best, filepath=args.save)

print("Best accuracy: "+str(best_prec1))
history_score[-1][0] = best_prec1
np.savetxt(os.path.join(args.save, 'record.txt'), history_score, fmt = '%10.5f', delimiter=',')
print("Best accuracy: "+str(best_prec1))
mask-impl/main_mask.py (24 changes: 1 addition & 23 deletions)
@@ -122,8 +122,6 @@
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

history_score = np.zeros((args.epochs - args.start_epoch + 1, 3))

# additional subgradient descent on the sparsity-induced penalty term
def updateBN():
    for m in model.modules():
@@ -138,30 +136,16 @@ def BN_grad_zero():
            m.weight.grad.data.mul_(mask)
            m.bias.grad.data.mul_(mask)

def compute_cfg():
    cfg = []
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            weight_nonzero = (m.weight.data == 0).sum()
            bias_nonzero = (m.bias.data == 0).sum()
            cfg.append((weight_nonzero, bias_nonzero))
    return cfg

def train(epoch):
    model.train()
    global history_score
    avg_loss = 0.
    train_acc = 0.
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        avg_loss += loss.data[0]
        pred = output.data.max(1, keepdim=True)[1]
        train_acc += pred.eq(target.data.view_as(pred)).cpu().sum()
        loss.backward()
        if args.sr:
            updateBN()
@@ -171,8 +155,6 @@ def train(epoch):
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))
    history_score[epoch][0] = avg_loss / len(train_loader)
    history_score[epoch][1] = train_acc / float(len(train_loader))

def test():
    model.eval()
@@ -206,8 +188,6 @@ def save_checkpoint(state, is_best, filepath):
            param_group['lr'] *= 0.1
    train(epoch)
    prec1 = test()
    history_score[epoch][2] = prec1
    np.savetxt(os.path.join(args.save, 'record.txt'), history_score, fmt = '%10.5f', delimiter=',')
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    save_checkpoint({
@@ -217,6 +197,4 @@ def save_checkpoint(state, is_best, filepath):
        'optimizer': optimizer.state_dict(),
    }, is_best, filepath=args.save)

print("Best accuracy: "+str(best_prec1))
history_score[-1][0] = best_prec1
np.savetxt(os.path.join(args.save, 'record.txt'), history_score, fmt = '%10.5f', delimiter=',')
print("Best accuracy: "+str(best_prec1))
