diff --git a/Pytorch_MNIST.py b/Pytorch_MNIST.py index 5067a2c..4ab8a55 100644 --- a/Pytorch_MNIST.py +++ b/Pytorch_MNIST.py @@ -31,8 +31,7 @@ def forward(self, x): x = F.relu(x) x = self.dropout2(x) x = self.fc2(x) - output = F.log_softmax(x, dim=1) - return output + return F.log_softmax(x, dim=1) train_losses = [] train_counter = [] @@ -120,8 +119,8 @@ def main(): cuda_kwargs = {'num_workers': 1, 'pin_memory': True, 'shuffle': True} - train_kwargs.update(cuda_kwargs) - test_kwargs.update(cuda_kwargs) + train_kwargs |= cuda_kwargs + test_kwargs |= cuda_kwargs transform=transforms.Compose([ transforms.ToTensor(), diff --git a/Pytorch_ResNet.py b/Pytorch_ResNet.py index 6167898..d728d06 100644 --- a/Pytorch_ResNet.py +++ b/Pytorch_ResNet.py @@ -107,10 +107,7 @@ class ResNet(nn.Module): def __init__(self, block, layers, num_classes, grayscale): self.inplanes = 64 - if grayscale: - in_dim = 1 - else: - in_dim = 3 + in_dim = 1 if grayscale else 3 super(ResNet, self).__init__() self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3, bias=False) @@ -141,12 +138,9 @@ def _make_layer(self, block, planes, blocks, stride=1): nn.BatchNorm2d(planes * block.expansion), ) - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample)) + layers = [block(self.inplanes, planes, stride, downsample)] self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(self.inplanes, planes)) - + layers.extend(block(self.inplanes, planes) for _ in range(1, blocks)) return nn.Sequential(*layers) def forward(self, x): @@ -170,11 +164,12 @@ def forward(self, x): def resnet18(num_classes): """Constructs a ResNet-18 model.""" - model = ResNet(block=BasicBlock, - layers=[2, 2, 2, 2], - num_classes=NUM_CLASSES, - grayscale=GRAYSCALE) - return model + return ResNet( + block=BasicBlock, + layers=[2, 2, 2, 2], + num_classes=NUM_CLASSES, + grayscale=GRAYSCALE, + ) torch.manual_seed(RANDOM_SEED) @@ -186,8 +181,7 @@ def resnet18(num_classes): def compute_accuracy(model, data_loader, device): correct_pred, num_examples = 0, 0 - for i, (features, targets) in enumerate(data_loader): - + for features, targets in data_loader: features = features.to(device) targets = targets.to(device)