forked from borisdayma/dalle-mini
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
122 lines (102 loc) · 4.67 KB
/
Copy pathdataset.py
File metadata and controls
122 lines (102 loc) · 4.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
A PyTorch Dataset for loading image-caption pairs.
Luke Melas-Kyriazi, 2021
"""
import warnings
from typing import Optional, Callable
from pathlib import Path
import numpy as np
import torch
import pandas as pd
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from PIL import ImageFile
from PIL.Image import DecompressionBombWarning
# Let PIL load image files whose data is truncated instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Silence these warning categories for the whole process (not only this module):
# every UserWarning, plus PIL's DecompressionBombWarning.
for _ignored_category in (UserWarning, DecompressionBombWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category
class CaptionDataset(Dataset):
    """
    A PyTorch Dataset class for (image, texts) tasks. Note that by default this
    dataset returns the raw text rather than tokens. This is done on purpose,
    because it's easy to tokenize a batch of text after loading it from this
    dataset. Pass ``text_transform`` to transform each caption on load instead.
    """

    # Calling conventions accepted for ``image_transform_type``.
    _IMAGE_TRANSFORM_TYPES = ('torchvision', 'albumentations')

    def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
                 image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
                 include_captions: bool = True):
        """
        :param images_root: folder where images are stored
        :param captions_path: path to a tab-separated file with a header row, containing an
            `image_file` column (file names relative to `images_root`) and a `caption` column
        :param text_transform: optional callable applied to each raw caption when an item is loaded
        :param image_transform: optional image transform pipeline
        :param image_transform_type: image transform type, either `torchvision` or `albumentations`
            (case-insensitive)
        :param include_captions: return a dict with `image` and `text` if True; otherwise return just the image
        :raises ValueError: if `image_transform_type` is not a supported type
        """
        # Base path for images
        self.images_root = Path(images_root)
        # Load captions as a DataFrame; file names can look numeric, so force them to str
        self.captions = pd.read_csv(captions_path, delimiter='\t', header=0)
        self.captions['image_file'] = self.captions['image_file'].astype(str)
        # Transformation pipelines for the caption and the image (normalizing, etc.)
        self.text_transform = text_transform
        self.image_transform = image_transform
        self.image_transform_type = image_transform_type.lower()
        # Raise instead of `assert` so the check still runs under `python -O`.
        if self.image_transform_type not in self._IMAGE_TRANSFORM_TYPES:
            raise ValueError(
                f"image_transform_type must be one of {self._IMAGE_TRANSFORM_TYPES}, "
                f"got {image_transform_type!r}")
        # Total number of datapoints
        self.size = len(self.captions)
        # Return image+captions or just images
        self.include_captions = include_captions

    def verify_that_all_images_exist(self):
        """
        Print every image listed in the captions file that is not a file under `images_root`.

        :return: list of the missing image paths (empty if every image exists)
        """
        missing = []
        for image_file in self.captions['image_file']:
            p = self.images_root / image_file
            if not p.is_file():
                print(f'file does not exist: {p}')
                missing.append(p)
        return missing

    def _get_raw_image(self, i):
        """Load image number `i` (positional) from disk with torchvision's `default_loader`."""
        # Column-first lookup avoids materializing the whole row as a Series.
        image_path = self.images_root / self.captions['image_file'].iloc[i]
        return default_loader(image_path)

    def _get_raw_text(self, i):
        """Return the untransformed caption of item number `i` (positional)."""
        return self.captions['caption'].iloc[i]

    def _get_text(self, i):
        """Return caption `i`, passed through `text_transform` when one is set."""
        caption = self._get_raw_text(i)
        if self.text_transform is not None:
            caption = self.text_transform(caption)
        return caption

    def _transform_image(self, image):
        """Apply `image_transform` (if any) using the calling convention of `image_transform_type`."""
        if self.image_transform is None:
            return image
        if self.image_transform_type == 'torchvision':
            return self.image_transform(image)
        if self.image_transform_type == 'albumentations':
            # albumentations takes a keyword numpy array and returns a dict of outputs
            return self.image_transform(image=np.array(image))['image']
        # Reachable only if `image_transform_type` was changed after construction.
        raise NotImplementedError(f"{self.image_transform_type=}")

    def __getitem__(self, i):
        """
        :return: {'image': image, 'text': caption} if `include_captions`, otherwise just the image
        """
        image = self._transform_image(self._get_raw_image(i))
        if not self.include_captions:
            # Skip the caption (and a possibly expensive text_transform) when it is not returned.
            return image
        return {'image': image, 'text': self._get_text(i)}

    def __len__(self):
        return self.size
if __name__ == "__main__":
    # Smoke test: build the dataset from local files and print a summary of one collated batch.
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    from transformers import AutoTokenizer

    # Input locations
    root_dir = './images'
    tsv_path = './images-list-clean.tsv'

    # Text pipeline: fixed-length (32 tokens) tokenization returning PyTorch tensors
    tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')

    def tokenize_text(text):
        return tokenizer(text, max_length=32, truncation=True, return_tensors='pt', padding='max_length')

    # Image pipeline: resize + crop to 256x256, normalize with mean/std 0.5 per channel, convert to tensor
    albu_pipeline = A.Compose([
        A.Resize(256, 256),
        A.CenterCrop(256, 256),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ])

    dataset_kwargs = dict(
        images_root=root_dir,
        captions_path=tsv_path,
        image_transform=albu_pipeline,
        text_transform=tokenize_text,
        image_transform_type='albumentations',
    )
    dataset = CaptionDataset(**dataset_kwargs)

    # Load a single batch of two items and report tensor shapes (other values printed as-is)
    loader = torch.utils.data.DataLoader(dataset, batch_size=2)
    first_batch = next(iter(loader))

    def describe(value):
        if isinstance(value, torch.Tensor):
            return value.shape
        return value

    print({key: describe(value) for key, value in first_batch.items()})

    # To instead check that every image listed in the TSV exists on disk, build
    # CaptionDataset(images_root=root_dir, captions_path=tsv_path) and call
    # its verify_that_all_images_exist() method.