From 13fe52801f8f6237b2ba3f57eb5a65943eb7f8fc Mon Sep 17 00:00:00 2001 From: Thesoul2 Date: Mon, 12 Aug 2024 20:55:29 +0800 Subject: [PATCH] Update AerialImageSementation. Added the index in python code. --- ...273\203\344\270\216\351\252\214\350\257\201.md" | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git "a/AerialImageSegmentation/Task5\357\274\232\346\250\241\345\236\213\350\256\255\347\273\203\344\270\216\351\252\214\350\257\201.md" "b/AerialImageSegmentation/Task5\357\274\232\346\250\241\345\236\213\350\256\255\347\273\203\344\270\216\351\252\214\350\257\201.md" index 9a8b8b9..82cc36a 100644 --- "a/AerialImageSegmentation/Task5\357\274\232\346\250\241\345\236\213\350\256\255\347\273\203\344\270\216\351\252\214\350\257\201.md" +++ "b/AerialImageSegmentation/Task5\357\274\232\346\250\241\345\236\213\350\256\255\347\273\203\344\270\216\351\252\214\350\257\201.md" @@ -80,15 +80,15 @@ criterion = nn.CrossEntropyLoss(size_average=False) optimizer = torch.optim.Adam(model.parameters(), 0.001) best_loss = 1000.0 for epoch in range(20): -print('Epoch: ', epoch) + print('Epoch: ', epoch) -train(train_loader, model, criterion, optimizer, epoch) -val_loss = validate(val_loader, model, criterion) + train(train_loader, model, criterion, optimizer, epoch) + val_loss = validate(val_loader, model, criterion) -# 记录下验证集精度 -if val_loss < best_loss: - best_loss = val_loss - torch.save(model.state_dict(), './model.pt') + # 记录下验证集精度 + if val_loss < best_loss: + best_loss = val_loss + torch.save(model.state_dict(), './model.pt') ``` 其中每个Epoch的训练代码如下: