forked from kyk120/fpet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvtab.py
More file actions
117 lines (94 loc) · 3.41 KB
/
Copy pathvtab.py
File metadata and controls
117 lines (94 loc) · 3.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch.utils.data as data
from PIL import Image
import os
import os.path
from torchvision import transforms
import torch
# Short names of the supported VTAB-1k tasks. These tuples are positional:
# entry i of _DATASET_NAME pairs with entry i of _CLASSES_NUM, and
# get_classes_name() indexes _DATASET_NAME directly, so never reorder one
# tuple without reordering the other in exactly the same way.
_DATASET_NAME = (
    "cifar",
    "caltech101",
    "dtd",
    "oxford_flowers102",
    "oxford_iiit_pet",
    "svhn",
    "sun397",
    "patch_camelyon",
    "eurosat",
    "resisc45",
    "diabetic_retinopathy",
    "clevr_count",
    "clevr_dist",
    "dmlab",
    "kitti",
    "dsprites_loc",
    "dsprites_ori",
    "smallnorb_azi",
    "smallnorb_ele",
)

# Number of target classes per task, aligned with _DATASET_NAME above.
_CLASSES_NUM = (
    100,  # cifar
    102,  # caltech101
    47,   # dtd
    102,  # oxford_flowers102
    37,   # oxford_iiit_pet
    10,   # svhn
    397,  # sun397
    2,    # patch_camelyon
    10,   # eurosat
    45,   # resisc45
    5,    # diabetic_retinopathy
    8,    # clevr_count
    6,    # clevr_dist
    6,    # dmlab
    4,    # kitti
    16,   # dsprites_loc
    16,   # dsprites_ori
    18,   # smallnorb_azi
    9,    # smallnorb_ele
)
def get_classes_num(dataset_name):
    """Return the number of target classes for the task ``dataset_name``.

    The name is looked up in ``_DATASET_NAME`` and paired positionally with
    ``_CLASSES_NUM``.

    Raises:
        KeyError: if ``dataset_name`` is not one of ``_DATASET_NAME``.
    """
    lookup = dict(zip(_DATASET_NAME, _CLASSES_NUM))
    return lookup[dataset_name]
def get_classes_name(idx):
    """Return the task name stored at position ``idx`` of ``_DATASET_NAME``.

    Despite the function name, this yields a dataset/task name, not a class
    label. Standard sequence indexing applies: negative indices count from
    the end and an out-of-range index raises ``IndexError``.
    """
    names = _DATASET_NAME
    return names[idx]
def default_loader(path):
    """Load the image at ``path`` and return it as an RGB ``PIL.Image``.

    The file is opened explicitly in a ``with`` block so its OS handle is
    released as soon as the pixels are decoded. Passing a path straight to
    ``Image.open`` can leave the underlying file open (e.g. for multi-frame
    formats), which leaks descriptors over many DataLoader iterations.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        # convert() decodes the pixel data while ``f`` is still open and
        # returns a new image, so nothing needs the file after this point.
        return img.convert('RGB')
def default_flist_reader(flist):
    """Parse a file list into ``(relative_path, label)`` pairs.

    Each non-blank line has the form ``<path> <integer label>``. The label is
    the last whitespace-separated field, so paths that themselves contain
    spaces are accepted. Blank or whitespace-only lines (such as a trailing
    empty line) are skipped rather than crashing the parse.

    Args:
        flist: path of the text file to read.

    Returns:
        A list of ``(str, int)`` tuples, in file order.

    Raises:
        ValueError: if a non-blank line has no label field, or the label is
            not an integer.
    """
    imlist = []
    with open(flist, 'r', encoding='utf-8') as rf:
        for lineno, line in enumerate(rf, start=1):
            # rsplit with maxsplit=1 keeps any spaces inside the path intact.
            fields = line.strip().rsplit(None, 1)
            if not fields:
                continue
            if len(fields) != 2:
                raise ValueError(
                    f"{flist}:{lineno}: expected '<path> <label>', "
                    f"got {line.rstrip()!r}")
            impath, imlabel = fields
            imlist.append((impath, int(imlabel)))
    return imlist
class ImageFilelist(data.Dataset):
    """Map-style dataset over the images named in a file list.

    Args:
        root: directory the listed relative image paths are joined onto.
        flist: path of the list file, handed unchanged to ``flist_reader``.
        transform: optional callable applied to each loaded image.
        target_transform: optional callable applied to each label.
        flist_reader: callable turning ``flist`` into a sequence of
            ``(relative_path, label)`` pairs.
        loader: callable that loads one image from its joined path.
    """

    def __init__(self, root, flist, transform=None, target_transform=None,
                 flist_reader=default_flist_reader, loader=default_loader):
        # The entire list is read once, eagerly, when the dataset is built.
        self.imlist = flist_reader(flist)
        self.root = root
        self.loader = loader
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        rel_path, label = self.imlist[index]
        sample = self.loader(os.path.join(self.root, rel_path))
        # Each optional transform is applied only when one was supplied.
        sample = sample if self.transform is None else self.transform(sample)
        label = (label if self.target_transform is None
                 else self.target_transform(label))
        return sample, label

    def __len__(self):
        return len(self.imlist)
def _build_transform(normalize):
    """Return the resize(224x224, bicubic) -> ToTensor [-> Normalize] pipeline."""
    steps = [
        # InterpolationMode.BICUBIC is the enum equivalent of the PIL integer
        # constant 3 that was previously passed here.
        transforms.Resize((224, 224),
                          interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
    ]
    if normalize:
        # Standard ImageNet per-channel mean/std.
        steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225]))
    return transforms.Compose(steps)


def _build_loader(root, list_name, transform, batch_size, train, num_workers):
    """Wrap the file list ``root/list_name`` in an ImageFilelist DataLoader.

    Training loaders shuffle and drop the last incomplete batch; evaluation
    loaders keep file order and every sample.
    """
    dataset = ImageFilelist(root=root, flist=os.path.join(root, list_name),
                            transform=transform)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=train, drop_last=train,
        num_workers=num_workers, pin_memory=True)


def get_data(name, normalize=True, batch_size=64, evaluate=True,
             data_root='/your/path/fpet/data/vtab-1k/', eval_batch_size=256,
             num_workers=4):
    """Build the train and evaluation DataLoaders for one VTAB-1k task.

    Args:
        name: task sub-directory under ``data_root`` (e.g. ``'cifar'``).
        normalize: if True, normalize tensors with the ImageNet mean/std
            after ``ToTensor``.
        batch_size: batch size of the training loader.
        evaluate: if True, train on ``train800val200.txt`` and evaluate on
            ``test.txt``; if False, train on ``train800.txt`` and evaluate on
            ``val200.txt``.
        data_root: directory holding one sub-directory per task. The default
            is the original placeholder path and must be replaced (or passed
            explicitly) for a real run.
        eval_batch_size: batch size of the evaluation loader.
        num_workers: worker processes used by both loaders.

    Returns:
        ``(train_loader, val_loader)``.
    """
    root = os.path.join(data_root, name)
    transform = _build_transform(normalize)
    if evaluate:
        train_list, eval_list = 'train800val200.txt', 'test.txt'
    else:
        train_list, eval_list = 'train800.txt', 'val200.txt'
    train_loader = _build_loader(root, train_list, transform, batch_size,
                                 train=True, num_workers=num_workers)
    val_loader = _build_loader(root, eval_list, transform, eval_batch_size,
                               train=False, num_workers=num_workers)
    return train_loader, val_loader