from torch.utils.data import Dataset
# from data import ImageDataset
import os
from PIL import Image
import shutil
class ClassificationSet(Dataset):
    """Lazy image-classification dataset backed by a ``root/<class>/<file>`` tree.

    Each immediate sub-directory of ``data_path`` is one class and every
    regular file inside it is one sample. Only file paths and integer labels
    are kept in memory; pixels are read on demand in ``__getitem__``.

    Args:
        data_path: Root directory containing one sub-directory per class.
        transform: Optional callable applied to each PIL image after it has
            been converted to ``channel`` and resized to ``size``.
        size: Size tuple passed to ``PIL.Image.Image.resize``.
        channel: PIL mode passed to ``PIL.Image.Image.convert`` (e.g. ``'L'``).
    """

    def __init__(self, data_path, transform=None, size=(256, 256), channel='L'):
        self.data_path = data_path
        self.transform = transform
        self.data = []                # image file paths
        self.label = []               # class index for each entry of self.data
        self.label_dict = {}          # class directory name -> class index
        self.label_dict_reverse = {}  # class index -> class directory name
        self.size = size
        self.channel = channel
        self.load_data()
        print("Data loaded in lazy mode.")

    def load_data(self):
        """Scan ``data_path`` and (re)build the paths, labels and label maps.

        Class indices follow the sorted order of the class directory names.
        Non-directory entries at the root and non-file entries inside a class
        directory are skipped. Files within a class are visited in sorted
        order, so the sample order does not depend on the file system.
        """
        # Reset so calling this again (e.g. after set_data_path) rebuilds the
        # index instead of appending duplicate entries.
        self.data = []
        self.label = []
        self.label_dict = {}
        self.label_dict_reverse = {}
        categories = sorted(
            name for name in os.listdir(self.data_path)
            if os.path.isdir(os.path.join(self.data_path, name))
        )
        # enumerate() supplies the index; unpacking "i, label" straight from
        # the list would try to split each directory-name string instead.
        for i, label in enumerate(categories):
            self.label_dict[label] = i
            self.label_dict_reverse[i] = label
            class_dir = os.path.join(self.data_path, label)
            for file in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, file)
                if os.path.isfile(path):
                    self.data.append(path)
                    self.label.append(i)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, label)``.

        The image is converted to ``self.channel``, resized to ``self.size`` and
        then passed through ``self.transform`` if one is set.
        """
        # The with-block releases the file handle once the pixels are loaded;
        # convert() returns a new, fully loaded image independent of the file.
        with Image.open(self.data[index]) as raw:
            img = raw.convert(self.channel).resize(self.size)
        if self.transform is not None:
            img = self.transform(img)
        return img, self.label[index]

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data)

    def get_label_dict(self):
        """Return the class-name -> index mapping."""
        return self.label_dict

    def get_label_dict_reverse(self):
        """Return the index -> class-name mapping."""
        return self.label_dict_reverse

    def get_label(self):
        """Return the list of per-sample class indices."""
        return self.label

    def get_data(self):
        """Return the list of per-sample image paths."""
        return self.data

    def get_data_path(self):
        """Return the dataset root directory."""
        return self.data_path

    def get_transform(self):
        """Return the transform applied in ``__getitem__``."""
        return self.transform

    def set_transform(self, transform):
        """Replace the transform applied in ``__getitem__``."""
        self.transform = transform

    def set_data_path(self, data_path):
        """Set the root directory; call ``load_data()`` to re-index it."""
        self.data_path = data_path

    def set_label_dict(self, label_dict):
        """Replace the class-name -> index mapping."""
        self.label_dict = label_dict

    def set_label_dict_reverse(self, label_dict_reverse):
        """Replace the index -> class-name mapping."""
        self.label_dict_reverse = label_dict_reverse

    def set_label(self, label):
        """Replace the per-sample label list."""
        self.label = label

    def set_data(self, data):
        """Replace the per-sample path list."""
        self.data = data

    def get_data_by_label(self, label):
        """Return every image path whose class index equals ``label``."""
        return [path for path, lab in zip(self.data, self.label) if lab == label]

    def get_label_by_data(self, data):
        """Return the class index of every sample whose path equals ``data``."""
        return [lab for path, lab in zip(self.data, self.label) if path == data]

    def get_data_by_label_dict(self, label_dict):
        """Alias of ``get_data_by_label``: ``label_dict`` is compared as a label."""
        return self.get_data_by_label(label_dict)

    def get_label_by_data_dict(self, data_dict):
        """Alias of ``get_label_by_data``: ``data_dict`` is compared as a path."""
        return self.get_label_by_data(data_dict)
class FastClassificationSet(Dataset):
    """Eager image-classification dataset backed by a ``root/<class>/<file>`` tree.

    Like a lazy folder dataset, but every image is decoded, converted, resized
    and transformed once during construction and kept in memory, so
    ``__getitem__`` is a plain list lookup.

    Because ``transform`` is applied once at prefetch time, any randomness in
    it is sampled only once per image. Calling ``set_transform`` afterwards
    does not affect images that were already prefetched.

    Args:
        data_path: Root directory containing one sub-directory per class.
        transform: Optional callable applied to each PIL image after it has
            been converted to ``channel`` and resized to ``size``.
        size: Size tuple passed to ``PIL.Image.Image.resize``.
        channel: PIL mode passed to ``PIL.Image.Image.convert`` (e.g. ``'L'``).
        verbose: If true, print prefetch progress to the terminal.
    """

    def __init__(self, data_path, transform=None, size=(256, 256), channel='L', verbose=True):
        self.data_path = data_path
        self.transform = transform
        self.data = []                # image file paths
        self.label = []               # class index for each entry of self.data
        self.label_dict = {}          # class directory name -> class index
        self.label_dict_reverse = {}  # class index -> class directory name
        self.size = size
        self.channel = channel
        self.verbose = verbose
        self.prefetched_imgs = []     # loaded/transformed image per entry of self.data
        self.load_data()

    def _load_image(self, img_path):
        """Open, convert, resize and (optionally) transform one image file."""
        # The with-block releases the file handle once the pixels are loaded;
        # without it, prefetching a large folder keeps every file open.
        with Image.open(img_path) as raw:
            img = raw.convert(self.channel).resize(self.size)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def load_data(self):
        """Scan ``data_path``, (re)build the index and prefetch every image.

        Class indices follow the sorted order of the class directory names.
        Non-directory entries at the root and non-file entries inside a class
        directory are skipped. Files within a class are visited in sorted
        order, so the sample order does not depend on the file system.
        """
        # Reset so calling this again (e.g. after set_data_path) rebuilds the
        # index instead of appending duplicate entries.
        self.data = []
        self.label = []
        self.label_dict = {}
        self.label_dict_reverse = {}
        self.prefetched_imgs = []
        categories = sorted(
            name for name in os.listdir(self.data_path)
            if os.path.isdir(os.path.join(self.data_path, name))
        )
        # Query the terminal width once rather than once per image.
        width = shutil.get_terminal_size().columns
        # enumerate() supplies the index; unpacking "i, label" straight from
        # the list would try to split each directory-name string instead.
        for i, label in enumerate(categories):
            self.label_dict[label] = i
            self.label_dict_reverse[i] = label
            class_dir = os.path.join(self.data_path, label)
            for file in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, file)
                if not os.path.isfile(img_path):
                    continue
                self.data.append(img_path)
                self.label.append(i)
                self.prefetched_imgs.append(self._load_image(img_path))
                if self.verbose:
                    # '\r' rewrites the same line to show running progress.
                    print(f"Prefetched image: {img_path}".ljust(width), end='\r')
        if self.verbose:
            print("Prefetched images.".ljust(width))

    def __getitem__(self, index):
        """Return the prefetched ``(image, label)`` pair for ``index``."""
        return self.prefetched_imgs[index], self.label[index]

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data)

    def get_label_dict(self):
        """Return the class-name -> index mapping."""
        return self.label_dict

    def get_label_dict_reverse(self):
        """Return the index -> class-name mapping."""
        return self.label_dict_reverse

    def get_label(self):
        """Return the list of per-sample class indices."""
        return self.label

    def get_data(self):
        """Return the list of per-sample image paths."""
        return self.data

    def get_data_path(self):
        """Return the dataset root directory."""
        return self.data_path

    def get_transform(self):
        """Return the transform that was (or will be) applied at prefetch time."""
        return self.transform

    def set_transform(self, transform):
        """Replace the transform; already-prefetched images are not re-processed."""
        self.transform = transform

    def set_data_path(self, data_path):
        """Set the root directory; call ``load_data()`` to re-index it."""
        self.data_path = data_path

    def set_label_dict(self, label_dict):
        """Replace the class-name -> index mapping."""
        self.label_dict = label_dict

    def set_label_dict_reverse(self, label_dict_reverse):
        """Replace the index -> class-name mapping."""
        self.label_dict_reverse = label_dict_reverse

    def set_label(self, label):
        """Replace the per-sample label list."""
        self.label = label

    def set_data(self, data):
        """Replace the per-sample path list."""
        self.data = data

    def get_data_by_label(self, label):
        """Return every image path whose class index equals ``label``."""
        return [path for path, lab in zip(self.data, self.label) if lab == label]

    def get_label_by_data(self, data):
        """Return the class index of every sample whose path equals ``data``."""
        return [lab for path, lab in zip(self.data, self.label) if path == data]

    def get_data_by_label_dict(self, label_dict):
        """Alias of ``get_data_by_label``: ``label_dict`` is compared as a label."""
        return self.get_data_by_label(label_dict)

    def get_label_by_data_dict(self, data_dict):
        """Alias of ``get_label_by_data``: ``data_dict`` is compared as a path."""
        return self.get_label_by_data(data_dict)
# Post Views: 1,695