-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathlinear_probing.py
More file actions
118 lines (90 loc) · 3.7 KB
/
linear_probing.py
File metadata and controls
118 lines (90 loc) · 3.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Linear probing / fine-tuning of a pretrained MAE encoder on CIFAR-10,
# with per-epoch metrics logged to Weights & Biases.
# NOTE(review): the original two comments here ("More hidden layers and
# non-linearities for probing/fine-tuning", "Screen sessions") read as stale
# scratch notes unrelated to the code below.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import tqdm
from models import *  # project-local: provides get_mae_small and MAE_Classifier
import wandb
# Start a W&B run; train/test loss and accuracy are logged each epoch below.
wandb.init(project="ssl-vision", entity="aryan9101")
# Pick the compute device: first CUDA GPU when one is present, else CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")
print("Running on a GPU" if use_gpu else "Running on a CPU")

# Per-channel CIFAR-10 normalization statistics (RGB mean / std).
cifar10_mean = torch.tensor([0.49139968, 0.48215827, 0.44653124])
cifar10_std = torch.tensor([0.24703233, 0.24348505, 0.26158768])
class Cifar10Dataset(Dataset):
    """CIFAR-10 split wrapped with ToTensor + per-channel normalization."""

    def __init__(self, train):
        # Download (if needed) and load the requested split.
        self.dataset = torchvision.datasets.CIFAR10(root='./SSL-Vision/data',
                                                    train=train,
                                                    download=True)
        # PIL image -> float tensor in [0, 1] -> normalized per channel.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cifar10_mean, cifar10_std),
        ])

    def __len__(self):
        """Number of samples in the wrapped split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return (normalized image tensor, integer class label) at *idx*."""
        image, target = self.dataset[idx]
        return self.transform(image), target
# Data pipeline: CIFAR-10 train/test loaders with normalization applied.
batch_size = 128
trainset = Cifar10Dataset(True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testset = Cifar10Dataset(False)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# CIFAR-10 class names, indexed by integer label.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Restore the pretrained MAE encoder. map_location=device makes a checkpoint
# saved on GPU loadable on a CPU-only machine (torch.load would otherwise
# raise when deserializing CUDA tensors without a CUDA device).
checkpoint = torch.load("./SSL-Vision/mae_timm.pth", map_location=device)
mae = get_mae_small().to(device)
mae.load_state_dict(checkpoint['mae_state_dict'])
# Classification head on top of the encoder; 384 is presumably the MAE-small
# embedding dimension — TODO confirm against MAE_Classifier in models.py.
mae_classifier = MAE_Classifier(mae, 384, len(classes)).to(device)

# Optimization setup: AdamW over all classifier parameters.
learning_rate = 1e-3
num_epochs = 10
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(mae_classifier.parameters(), lr=learning_rate)
# Per-epoch loss history (kept for post-hoc inspection).
train_losses = []
test_losses = []

for epoch in range(num_epochs):
    # ---- training pass ----
    mae_classifier.train()
    running_loss, running_correct, seen = 0.0, 0.0, 0
    for inputs, labels in tqdm.tqdm(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        _, outputs = mae_classifier(inputs)
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()
        # Accumulate sample-weighted loss and correct-prediction counts.
        n = inputs.shape[0]
        running_loss += loss.item() * n
        running_correct += torch.sum((torch.argmax(outputs, dim=1) == labels)).item()
        seen += n
    train_loss = running_loss / seen
    train_acc = running_correct / seen
    train_losses.append(train_loss)

    # ---- evaluation pass (no gradients) ----
    mae_classifier.eval()
    running_loss, running_correct, seen = 0.0, 0.0, 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            _, outputs = mae_classifier(inputs)
            loss = criterion(outputs, labels.long())
            n = inputs.shape[0]
            running_loss += loss.item() * n
            running_correct += torch.sum((torch.argmax(outputs, dim=1) == labels)).item()
            seen += n
    test_loss = running_loss / seen
    test_acc = running_correct / seen
    test_losses.append(test_loss)

    # Report and log this epoch's metrics.
    print(f'[{epoch + 1:2d}] train loss: {train_loss:.3f} | train accuracy: {train_acc:.3f} | test_loss: {test_loss:.3f} | test_accuracy: {test_acc:.3f}')
    wandb.log({"train_loss": train_loss, "test_loss": test_loss, "train_accuracy": train_acc, "test_accuracy": test_acc, "epoch": epoch + 1})

print('Finished Training')