深度学习数据集加载
【代码】深度学习数据集加载。
·
数据集结构
E:\Mytest\test20250622\pythonProject\dataset
├── rose
│ ├── rose1.jpg
│ ├── rose2.jpg
│ └── ...
└── sunflower
    ├── sunflower1.jpg
    ├── sunflower2.jpg
    └── ...
主要用到的两个类(Dataset 和 DataLoader)以及 transforms 模块
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
加载示例
data.py
import os
from torch.utils.data import Dataset
from PIL import Image
class MyDataset(Dataset):
    """Image-classification dataset backed by one sub-directory per class.

    Expected layout::

        root_dir/
        ├── rose/
        │   ├── rose1.jpg
        │   ├── rose2.jpg
        │   └── ...
        └── sunflower/
            ├── sunflower1.jpg
            ├── sunflower2.jpg
            └── ...

    Every sub-directory of ``root_dir`` is treated as one class. Class names
    are sorted before labels are assigned, so the name -> label mapping is
    the same on every run and every machine.

    Args:
        root_dir: Path of the directory that contains the class folders.
        transform: Optional callable applied to each PIL image before it is
            returned from ``__getitem__``.

    Attributes:
        images: Paths of all image files found, in label order.
        labels: Integer class label for each entry of ``images``.
        classes: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to its integer label.
    """

    # Lower-case file extensions that are accepted as images. Anything else
    # found in a class folder (e.g. Thumbs.db, notes.txt) is skipped instead
    # of failing later inside Image.open().
    IMG_EXTENSIONS = (
        ".jpg", ".jpeg", ".png", ".bmp", ".gif",
        ".ppm", ".pgm", ".tif", ".tiff", ".webp",
    )

    def __init__(self, root_dir, transform=None):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []

        # os.listdir() returns entries in arbitrary order and also lists
        # plain files. Keep directories only and sort them, so that labels
        # are contiguous (0..num_classes-1) and reproducible.
        self.classes = sorted(
            name
            for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        self.class_to_idx = {
            name: index for index, name in enumerate(self.classes)
        }

        for class_name, index in self.class_to_idx.items():
            class_dir = os.path.join(root_dir, class_name)
            # Sorted so the sample order is deterministic as well.
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                if not os.path.isfile(img_path):
                    continue
                if not img_name.lower().endswith(self.IMG_EXTENSIONS):
                    continue
                self.images.append(img_path)
                self.labels.append(index)

    def __len__(self):
        """Return the number of image files discovered under ``root_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Position of the sample in ``self.images``.

        Returns:
            A ``(image, label)`` tuple. ``image`` is the RGB PIL image, or the
            result of ``transform(image)`` when a transform was given;
            ``label`` is the integer class index.
        """
        # Image.open() is lazy and keeps the file open; convert() forces the
        # pixel data to be read, so the file can be closed right away.
        # Converting to RGB gives every sample the same channel count, which
        # the default batch collation needs (grayscale / RGBA files would
        # otherwise produce tensors of a different shape).
        with Image.open(self.images[idx]) as img:
            image = img.convert("RGB")
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label
main.py
from data import MyDataset
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
def main():
    """Build the dataset and DataLoader, then print the first few batches."""
    # Image preprocessing: resize every image to 224 x 224, then convert it
    # to a tensor.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Instantiate the custom Dataset.
    dataset = MyDataset(root_dir='E:\\Mytest\\test20250622\\pythonProject\\dataset', transform=transform)

    # Wrap it in a DataLoader: batches of 3 samples, reshuffled on each pass.
    dataloader = DataLoader(dataset, batch_size=3, shuffle=True)

    # Iterate over the DataLoader.
    for i, (images, labels) in enumerate(dataloader):
        print(f"Batch {i + 1}, Image shape: {images.shape}, Labels: {labels}")
        # Sample output for a dataset of 5 images:
        #   Batch 1, Image shape: torch.Size([3, 3, 224, 224]), Labels: tensor([1, 0, 1])
        #   Batch 2, Image shape: torch.Size([2, 3, 224, 224]), Labels: tensor([1, 0])

        # Optional visualization. NOTE: as written this only works with
        # batch_size=1, because squeeze(0) removes the batch dimension only
        # when it has size 1 and labels.item() needs a single-element tensor.
        # Convert the tensor to a displayable layout (H x W x C):
        # image = images.squeeze(0).permute(1, 2, 0)  # drop batch dim, channels last
        # Show the image:
        # plt.imshow(image)
        # plt.title(f"Label: {labels.item()}")
        # plt.axis('off')
        # plt.show()

        # Optional: stop after the first five batches.
        if i >= 4:
            break


# Guard the entry point so that importing this module does not start loading
# data; PyTorch also requires this on Windows once DataLoader worker processes
# (num_workers > 0) are used.
if __name__ == "__main__":
    main()
更多推荐


所有评论(0)