-
Notifications
You must be signed in to change notification settings - Fork 0
/
UCFT_CIFAR10.py
50 lines (38 loc) · 1.18 KB
/
UCFT_CIFAR10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import sys
import time
import argparse
import datetime
from torch.autograd import Variable
from typing import Optional
from art.estimators.classification.pytorch import PyTorchClassifier
from resnet import ResNet50 as resnet50
# Model factory called by the evaluation script to build the ART-wrapped classifier.
def get_art_model(
    model_kwargs: dict, wrapper_kwargs: dict, weights_path: Optional[str] = None
) -> PyTorchClassifier:
    """Build a CIFAR-10 ResNet-50 and wrap it in an ART ``PyTorchClassifier``.

    Args:
        model_kwargs: Accepted to match the evaluation script's model-factory
            interface; not used by this function.
        wrapper_kwargs: Extra keyword arguments forwarded to
            ``PyTorchClassifier``. They must not repeat the keywords set here
            (``loss``, ``optimizer``, ``input_shape``, ``nb_classes``,
            ``clip_values``), otherwise the call raises ``TypeError``.
        weights_path: Path to a checkpoint loadable with ``torch.load`` whose
            ``'model'`` entry is the ResNet-50 state dict. When ``None`` (or
            empty), no checkpoint is loaded and the model keeps its freshly
            initialised weights.

    Returns:
        The wrapped classifier, with the model moved to CUDA when available
        and to CPU otherwise.

    Side effects:
        Sets the global ``torch.backends.cudnn.benchmark`` flag to ``True``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = resnet50()
    if weights_path:
        # map_location lets a checkpoint that was saved on a GPU be loaded on
        # the CPU fallback path (and vice versa) instead of failing.
        checkpoint = torch.load(weights_path, map_location=device)
        model.load_state_dict(checkpoint["model"])
    model.to(device)

    cudnn.benchmark = True

    wrapped_model = PyTorchClassifier(
        model,
        loss=nn.CrossEntropyLoss(),
        # ART requires an optimizer object; the vanishingly small learning rate
        # presumably makes any accidental training call a no-op — confirm intent.
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-100),
        # NOTE(review): channels-last (H, W, C) shape; presumably the ResNet50
        # in resnet.py permutes inputs to NCHW — confirm.
        input_shape=(32, 32, 3),
        nb_classes=10,
        clip_values=(0.0, 1.0),
        **wrapper_kwargs,
    )
    return wrapped_model
# x = x.permute(0, 3, 1, 2)