diff --git "a/AerialImageSegmentation/Task5\357\274\232\346\250\241\345\236\213\350\256\255\347\273\203\344\270\216\351\252\214\350\257\201.md" "b/AerialImageSegmentation/Task5\357\274\232\346\250\241\345\236\213\350\256\255\347\273\203\344\270\216\351\252\214\350\257\201.md" index 9a8b8b9..82cc36a 100644 --- "a/AerialImageSegmentation/Task5\357\274\232\346\250\241\345\236\213\350\256\255\347\273\203\344\270\216\351\252\214\350\257\201.md" +++ "b/AerialImageSegmentation/Task5\357\274\232\346\250\241\345\236\213\350\256\255\347\273\203\344\270\216\351\252\214\350\257\201.md" @@ -80,15 +80,15 @@ criterion = nn.CrossEntropyLoss(size_average=False) optimizer = torch.optim.Adam(model.parameters(), 0.001) best_loss = 1000.0 for epoch in range(20): -print('Epoch: ', epoch) + print('Epoch: ', epoch) -train(train_loader, model, criterion, optimizer, epoch) -val_loss = validate(val_loader, model, criterion) + train(train_loader, model, criterion, optimizer, epoch) + val_loss = validate(val_loader, model, criterion) -# 记录下验证集精度 -if val_loss < best_loss: - best_loss = val_loss - torch.save(model.state_dict(), './model.pt') + # 记录下验证集精度 + if val_loss < best_loss: + best_loss = val_loss + torch.save(model.state_dict(), './model.pt') ``` 其中每个Epoch的训练代码如下: