Idealisan

图像分类任务数据集模板



from torch.utils.data import Dataset

# from data import ImageDataset
import os
from PIL import Image
import shutil


class ClassificationSet(Dataset):
    """Lazily-loaded image-classification dataset (one sub-folder per class).

    Expected directory layout::

        data_path/
            <class_name_a>/  file1, file2, ...
            <class_name_b>/  ...

    Class folder names are sorted and mapped to integer labels ``0..K-1`` in
    that order. Only file paths are indexed at construction time; each image
    is opened, converted, resized and transformed when it is indexed.

    Args:
        data_path: Root directory containing one sub-directory per class.
        transform: Optional callable applied to each image after
            ``convert``/``resize``; its return value is what ``__getitem__``
            yields as the sample.
        size: Size tuple passed to ``PIL.Image.Image.resize``.
        channel: PIL mode string passed to ``PIL.Image.Image.convert``
            (e.g. ``'L'`` or ``'RGB'``).
    """

    def __init__(self, data_path, transform=None, size=(256, 256), channel='L'):
        self.data_path = data_path
        self.transform = transform
        self.data = []  # file paths, parallel to self.label
        self.label = []  # integer class label of each path in self.data
        self.label_dict = {}  # class folder name -> integer label
        self.label_dict_reverse = {}  # integer label -> class folder name
        self.size = size
        self.channel = channel
        self.load_data()
        print("Data loaded in lazy mode.")

    def load_data(self):
        """(Re)build the sample index from the class folders under ``data_path``.

        Only sub-directories of ``data_path`` are treated as classes, and only
        regular files inside them become samples, so stray files (for example
        ``.DS_Store``) do not break the scan. Both listings are sorted so that
        label assignment and sample order are reproducible. Any previously
        loaded index is discarded, so this may be called again after
        ``set_data_path``.
        """
        self.data = []
        self.label = []
        self.label_dict = {}
        self.label_dict_reverse = {}
        categories = sorted(
            name for name in os.listdir(self.data_path)
            if os.path.isdir(os.path.join(self.data_path, name))
        )
        # enumerate() supplies the integer label for each sorted class name.
        for idx, class_name in enumerate(categories):
            self.label_dict[class_name] = idx
            self.label_dict_reverse[idx] = class_name
            class_dir = os.path.join(self.data_path, class_name)
            for file_name in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, file_name)
                if os.path.isfile(path):
                    self.data.append(path)
                    self.label.append(idx)

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        The file is opened, converted to ``channel``, resized to ``size`` and
        passed through ``transform`` (if set) on every call.
        """
        # The context manager closes the file handle; convert() returns a new
        # image whose pixels are already loaded, so it outlives the file.
        with Image.open(self.data[index]) as src:
            img = src.convert(self.channel)
        img = img.resize(self.size)
        if self.transform is not None:
            img = self.transform(img)
        return img, self.label[index]

    def __len__(self):
        return len(self.data)

    def get_label_dict(self):
        """Return the ``class name -> integer label`` mapping."""
        return self.label_dict

    def get_label_dict_reverse(self):
        """Return the ``integer label -> class name`` mapping."""
        return self.label_dict_reverse

    def get_label(self):
        """Return the list of integer labels, parallel to ``get_data()``."""
        return self.label

    def get_data(self):
        """Return the list of sample file paths, parallel to ``get_label()``."""
        return self.data

    def get_data_path(self):
        """Return the root directory this dataset was built from."""
        return self.data_path

    def get_transform(self):
        """Return the transform applied in ``__getitem__`` (may be ``None``)."""
        return self.transform

    def set_transform(self, transform):
        """Replace the transform used by subsequent ``__getitem__`` calls."""
        self.transform = transform

    def set_data_path(self, data_path):
        """Set the root directory; call ``load_data()`` to actually rescan it."""
        self.data_path = data_path

    def set_label_dict(self, label_dict):
        """Overwrite the ``class name -> label`` mapping (no consistency checks)."""
        self.label_dict = label_dict

    def set_label_dict_reverse(self, label_dict_reverse):
        """Overwrite the ``label -> class name`` mapping (no consistency checks)."""
        self.label_dict_reverse = label_dict_reverse

    def set_label(self, label):
        """Overwrite the label list; it must stay parallel to the data list."""
        self.label = label

    def set_data(self, data):
        """Overwrite the path list; it must stay parallel to the label list."""
        self.data = data

    def get_data_by_label(self, label):
        """Return every file path whose integer label equals ``label``."""
        return [path for path, lbl in zip(self.data, self.label) if lbl == label]

    def get_label_by_data(self, data):
        """Return the labels of every entry whose path equals ``data``."""
        return [lbl for path, lbl in zip(self.data, self.label) if path == data]

    def get_data_by_label_dict(self, label_dict):
        """Alias of ``get_data_by_label`` kept for backward compatibility.

        Despite the parameter name, ``label_dict`` is compared directly
        against the integer labels, not looked up in ``label_dict``.
        """
        return self.get_data_by_label(label_dict)

    def get_label_by_data_dict(self, data_dict):
        """Alias of ``get_label_by_data`` kept for backward compatibility.

        Despite the parameter name, ``data_dict`` is compared directly
        against the stored file paths.
        """
        return self.get_label_by_data(data_dict)



class FastClassificationSet(Dataset):
    """Eager image-classification dataset that decodes every image up front.

    Uses the same folder-per-class layout as a lazy loader would::

        data_path/
            <class_name_a>/  file1, file2, ...
            <class_name_b>/  ...

    Class folder names are sorted and mapped to integer labels ``0..K-1``.
    At construction time every file is opened, converted to ``channel``,
    resized to ``size`` and passed through ``transform``; ``__getitem__``
    then just returns the cached result, trading memory for speed.

    Args:
        data_path: Root directory containing one sub-directory per class.
        transform: Optional callable applied once per image during loading.
        size: Size tuple passed to ``PIL.Image.Image.resize``.
        channel: PIL mode string passed to ``PIL.Image.Image.convert``.
        verbose: If true, print progress lines while prefetching.
    """

    def __init__(self, data_path, transform=None, size=(256, 256), channel='L', verbose=True):
        self.data_path = data_path
        self.transform = transform
        self.data = []  # file paths, parallel to self.label / self.prefetched_imgs
        self.label = []  # integer class label of each path in self.data
        self.label_dict = {}  # class folder name -> integer label
        self.label_dict_reverse = {}  # integer label -> class folder name
        self.size = size
        self.channel = channel
        self.verbose = verbose

        self.prefetched_imgs = []  # decoded (and transformed) samples
        self.load_data()

    def load_data(self):
        """(Re)build the index and decode every image into memory.

        Only sub-directories of ``data_path`` are treated as classes and only
        regular files inside them become samples, so stray files do not break
        the scan. Listings are sorted so labels and order are reproducible.
        Previously loaded state is discarded, so this may be called again
        (e.g. after ``set_data_path`` or ``set_transform``).
        """
        self.data = []
        self.label = []
        self.label_dict = {}
        self.label_dict_reverse = {}
        self.prefetched_imgs = []
        width = shutil.get_terminal_size().columns
        categories = sorted(
            name for name in os.listdir(self.data_path)
            if os.path.isdir(os.path.join(self.data_path, name))
        )
        # enumerate() supplies the integer label for each sorted class name.
        for idx, class_name in enumerate(categories):
            self.label_dict[class_name] = idx
            self.label_dict_reverse[idx] = class_name
            class_dir = os.path.join(self.data_path, class_name)
            for file_name in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, file_name)
                if not os.path.isfile(path):
                    continue
                # Decode before appending so the three lists stay parallel
                # even if an image fails to load.
                img = self._load_image(path)
                self.data.append(path)
                self.label.append(idx)
                self.prefetched_imgs.append(img)
                if self.verbose:
                    # Pad to terminal width and end with '\r' so each progress
                    # line overwrites the previous one.
                    print(f"Prefetched image: {path}".ljust(width), end='\r')
        if self.verbose:
            print("Prefetched images.".ljust(width))

    def _load_image(self, path):
        """Open ``path``, convert/resize it and apply ``transform`` if set."""
        # The context manager closes the file handle; convert() returns a new
        # image whose pixels are already loaded, so it outlives the file.
        with Image.open(path) as src:
            img = src.convert(self.channel)
        img = img.resize(self.size)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __getitem__(self, index):
        """Return the cached ``(image, label)`` pair for sample ``index``."""
        return self.prefetched_imgs[index], self.label[index]

    def __len__(self):
        return len(self.data)

    def get_label_dict(self):
        """Return the ``class name -> integer label`` mapping."""
        return self.label_dict

    def get_label_dict_reverse(self):
        """Return the ``integer label -> class name`` mapping."""
        return self.label_dict_reverse

    def get_label(self):
        """Return the list of integer labels, parallel to ``get_data()``."""
        return self.label

    def get_data(self):
        """Return the list of sample file paths, parallel to ``get_label()``."""
        return self.data

    def get_data_path(self):
        """Return the root directory this dataset was built from."""
        return self.data_path

    def get_transform(self):
        """Return the transform used during prefetching (may be ``None``)."""
        return self.transform

    def set_transform(self, transform):
        """Replace the transform.

        Already-prefetched images were transformed at load time and are not
        affected; call ``load_data()`` to re-decode with the new transform.
        """
        self.transform = transform

    def set_data_path(self, data_path):
        """Set the root directory; call ``load_data()`` to actually rescan it."""
        self.data_path = data_path

    def set_label_dict(self, label_dict):
        """Overwrite the ``class name -> label`` mapping (no consistency checks)."""
        self.label_dict = label_dict

    def set_label_dict_reverse(self, label_dict_reverse):
        """Overwrite the ``label -> class name`` mapping (no consistency checks)."""
        self.label_dict_reverse = label_dict_reverse

    def set_label(self, label):
        """Overwrite the label list; it must stay parallel to the cached images."""
        self.label = label

    def set_data(self, data):
        """Overwrite the path list.

        ``prefetched_imgs`` is not updated, yet ``__len__`` reads ``data``
        while ``__getitem__`` reads ``prefetched_imgs`` -- keep them in sync.
        """
        self.data = data

    def get_data_by_label(self, label):
        """Return every file path whose integer label equals ``label``."""
        return [path for path, lbl in zip(self.data, self.label) if lbl == label]

    def get_label_by_data(self, data):
        """Return the labels of every entry whose path equals ``data``."""
        return [lbl for path, lbl in zip(self.data, self.label) if path == data]

    def get_data_by_label_dict(self, label_dict):
        """Alias of ``get_data_by_label`` kept for backward compatibility.

        Despite the parameter name, ``label_dict`` is compared directly
        against the integer labels.
        """
        return self.get_data_by_label(label_dict)

    def get_label_by_data_dict(self, data_dict):
        """Alias of ``get_label_by_data`` kept for backward compatibility.

        Despite the parameter name, ``data_dict`` is compared directly
        against the stored file paths.
        """
        return self.get_label_by_data(data_dict)
分类

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注