Module monk.pytorch.datasets.csv_dataset
Expand source code
from pytorch.datasets.imports import *
from system.imports import *
class DatasetCustom(Dataset):
'''
Class for single label CSV dataset
Args:
img_list (str): List of images
label_list (str): List of labels in the same order as images
prefix (str): Path to folder containing images
transform (torchvision transforms): List of compiled transforms
'''
def __init__(self, img_list, label_list, prefix, transform=None):
self.img_list = img_list;
self.label_list = label_list;
self.transform = transform;
self.prefix = prefix;
def __len__(self):
'''
Returns length of images in dataset
Args:
None
Returns:
int: Length of images in dataset
'''
return len(self.img_list)
def __getitem__(self, index):
'''
Returns transformed image and label as per index
Args:
None
Returns:
pytorch tensor: Image loaded as pytorch tensor
int: Class ID
'''
image_name = self.prefix + "/" + self.img_list[index];
image = Image.open(image_name).convert('RGB');
label = int(self.label_list[index]);
if self.transform is not None:
image = self.transform(image);
return image, label
Classes
class DatasetCustom (img_list, label_list, prefix, transform=None)
-
Class for single label CSV dataset
Args
img_list
:str
- List of images
label_list
:str
- List of labels in the same order as images
prefix
:str
- Path to folder containing images
transform
:torchvision
transforms
- List of compiled transforms
Expand source code
class DatasetCustom(Dataset): ''' Class for single label CSV dataset Args: img_list (str): List of images label_list (str): List of labels in the same order as images prefix (str): Path to folder containing images transform (torchvision transforms): List of compiled transforms ''' def __init__(self, img_list, label_list, prefix, transform=None): self.img_list = img_list; self.label_list = label_list; self.transform = transform; self.prefix = prefix; def __len__(self): ''' Returns length of images in dataset Args: None Returns: int: Length of images in dataset ''' return len(self.img_list) def __getitem__(self, index): ''' Returns transformed image and label as per index Args: None Returns: pytorch tensor: Image loaded as pytorch tensor int: Class ID ''' image_name = self.prefix + "/" + self.img_list[index]; image = Image.open(image_name).convert('RGB'); label = int(self.label_list[index]); if self.transform is not None: image = self.transform(image); return image, label
Ancestors
- torch.utils.data.dataset.Dataset