
Image Classification in Deep Learning

Dive into Deep Learning

Posted by LXG on March 6, 2026

The Image Classification Dataset

The Fashion-MNIST Dataset

The Fashion-MNIST dataset is a classic benchmark for image classification, released by the e-commerce company Zalando. It was designed as a drop-in replacement for the traditional handwritten-digit dataset MNIST, offering a benchmark that is closer to real-world visual tasks while remaining moderate in difficulty.

Basic structure of the data

| Item | Value |
|------|-------|
| Data type | Grayscale images |
| Image size | 28 × 28 pixels |
| Training set | 60,000 images |
| Test set | 10,000 images |
| Number of classes | 10 clothing categories |
| Label format | Integers 0–9 |
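As a quick sanity check on the figures above: each flattened image has 28 × 28 = 784 pixel features, and at one byte per pixel the raw training images occupy roughly 47 MB.

```python
# Arithmetic check on the table above: feature count per image and the
# total raw byte size of the training images (1 byte per grayscale pixel).
h, w = 28, 28
n_train = 60_000

features_per_image = h * w
raw_train_bytes = n_train * features_per_image

print(features_per_image)  # 784
print(raw_train_bytes)     # 47040000
```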

The 10 class labels

| Label | Class |
|-------|-------|
| 0 | T-shirt / top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |
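When inspecting predictions it is handy to map integer labels back to text names. A minimal helper sketch, using the class list from the table above (the function and list names are our own choices, not part of torchvision):

```python
# Hypothetical helper: map Fashion-MNIST integer labels to class names.
# The class order follows the label table above.
FASHION_MNIST_CLASSES = [
    "T-shirt / top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def get_fashion_mnist_labels(labels):
    """Convert a sequence of integer labels to their text class names."""
    return [FASHION_MNIST_CLASSES[int(i)] for i in labels]

print(get_fashion_mnist_labels([9, 0, 6]))
# ['Ankle boot', 'T-shirt / top', 'Shirt']
```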

Reading the Dataset


import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

# Tell d2l to prefer SVG (vector) output for subsequent plots.
# Vector graphics stay sharp when zoomed in, which makes details easier to inspect.
d2l.use_svg_display()

# =========================
# 1. Read the Fashion-MNIST dataset
# =========================

# The raw files live under data/FashionMNIST/raw/ as four idx-format files:
# - train-images-idx3-ubyte / t10k-images-idx3-ubyte (images)
# - train-labels-idx1-ubyte / t10k-labels-idx1-ubyte (labels)
#
# The logical layout of an idx3 (image) file is:
# [magic number][image count][rows][cols][pixel byte stream...]
# - each image is a 28x28 grayscale image, 1 byte per pixel (0-255)
#
# The logical layout of an idx1 (label) file is:
# [magic number][label count][label byte stream...]
# - each label is 1 byte, encoding a class 0-9

# ToTensor() does two key things:
# 1) converts a PIL image / numpy array into a PyTorch tensor
# 2) rescales pixel values from [0, 255] to [0, 1]
# This benefits later neural-network training (more stable numerics).
#
# Concretely for Fashion-MNIST:
# - a raw sample (decoded, before the transform):
#   (PIL.Image(mode='L', size=(28, 28)), int label)
# - a single sample after ToTensor():
#   (torch.FloatTensor(shape=(1, 28, 28), values in [0, 1]), int label)
#   Note: the grayscale image gains a channel dimension, so (C, H, W) = (1, 28, 28)
trans = transforms.ToTensor()

# Download (or read the local cache of) the training set:
# - root="./data": directory where the data is stored
# - train=True: training set (60,000 images)
# - transform=trans: apply the ToTensor transform every time a sample is fetched
# - download=True: download automatically if the data is not present locally
#
# Because transform=trans is passed here, mnist_train[i] below has the structure:
# (torch.FloatTensor(shape=(1, 28, 28)), int label)
mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True)

# Test set (10,000 images); same arguments as above, except train=False.
# mnist_test[i] has the same structure as mnist_train[i].
mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)

# Print the complete first sample: the full image tensor plus its label
first_img, first_label = mnist_train[0]
print("first_img tensor:\n", first_img)
print("first_label:", first_label)
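The idx layout sketched in the comments above can be parsed with the standard struct module. The following is a self-contained illustration on a tiny fake idx3 stream built in memory (two 2x2 "images"; real image files use magic number 0x00000803), not a reader for the actual raw files:

```python
# Minimal sketch of parsing the idx3 layout: [magic][count][rows][cols][pixels...]
# All header fields are big-endian unsigned 32-bit integers.
import struct

def parse_idx3_header(buf):
    """Parse [magic][count][rows][cols] from a big-endian idx3 byte stream."""
    magic, count, rows, cols = struct.unpack(">IIII", buf[:16])
    return magic, count, rows, cols

# Fake in-memory stream: 16-byte header + 2 images of 2x2 = 8 pixel bytes.
fake = struct.pack(">IIII", 0x00000803, 2, 2, 2) + bytes(range(8))

magic, count, rows, cols = parse_idx3_header(fake)
pixels = fake[16:]
images = [pixels[i * rows * cols:(i + 1) * rows * cols] for i in range(count)]

print(count, rows, cols, list(images[0]))  # 2 2 2 [0, 1, 2, 3]
```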

Printing the first image


first_img tensor:
 tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0510,
          0.2863, 0.0000, 0.0000, 0.0039, 0.0157, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0039, 0.0039, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.1412, 0.5333,
          0.4980, 0.2431, 0.2118, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118,
          0.0157, 0.0000, 0.0000, 0.0118],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0000, 0.4000, 0.8000,
          0.6902, 0.5255, 0.5647, 0.4824, 0.0902, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0471, 0.0392, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6078, 0.9255,
          0.8118, 0.6980, 0.4196, 0.6118, 0.6314, 0.4275, 0.2510, 0.0902,
          0.3020, 0.5098, 0.2824, 0.0588],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.2706, 0.8118, 0.8745,
          0.8549, 0.8471, 0.8471, 0.6392, 0.4980, 0.4745, 0.4784, 0.5725,
          0.5529, 0.3451, 0.6745, 0.2588],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0039, 0.0039, 0.0039, 0.0000, 0.7843, 0.9098, 0.9098,
          0.9137, 0.8980, 0.8745, 0.8745, 0.8431, 0.8353, 0.6431, 0.4980,
          0.4824, 0.7686, 0.8980, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7176, 0.8824, 0.8471,
          0.8745, 0.8941, 0.9216, 0.8902, 0.8784, 0.8706, 0.8784, 0.8667,
          0.8745, 0.9608, 0.6784, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7569, 0.8941, 0.8549,
          0.8353, 0.7765, 0.7059, 0.8314, 0.8235, 0.8275, 0.8353, 0.8745,
          0.8627, 0.9529, 0.7922, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0039, 0.0118, 0.0000, 0.0471, 0.8588, 0.8627, 0.8314,
          0.8549, 0.7529, 0.6627, 0.8902, 0.8157, 0.8549, 0.8784, 0.8314,
          0.8863, 0.7725, 0.8196, 0.2039],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0235, 0.0000, 0.3882, 0.9569, 0.8706, 0.8627,
          0.8549, 0.7961, 0.7765, 0.8667, 0.8431, 0.8353, 0.8706, 0.8627,
          0.9608, 0.4667, 0.6549, 0.2196],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0157, 0.0000, 0.0000, 0.2157, 0.9255, 0.8941, 0.9020,
          0.8941, 0.9412, 0.9098, 0.8353, 0.8549, 0.8745, 0.9176, 0.8510,
          0.8510, 0.8196, 0.3608, 0.0000],
         [0.0000, 0.0000, 0.0039, 0.0157, 0.0235, 0.0275, 0.0078, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.9294, 0.8863, 0.8510, 0.8745,
          0.8706, 0.8588, 0.8706, 0.8667, 0.8471, 0.8745, 0.8980, 0.8431,
          0.8549, 1.0000, 0.3020, 0.0000],
         [0.0000, 0.0118, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.2431, 0.5686, 0.8000, 0.8941, 0.8118, 0.8353, 0.8667,
          0.8549, 0.8157, 0.8275, 0.8549, 0.8784, 0.8745, 0.8588, 0.8431,
          0.8784, 0.9569, 0.6235, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.1725, 0.3216, 0.4196,
          0.7412, 0.8941, 0.8627, 0.8706, 0.8510, 0.8863, 0.7843, 0.8039,
          0.8275, 0.9020, 0.8784, 0.9176, 0.6902, 0.7373, 0.9804, 0.9725,
          0.9137, 0.9333, 0.8431, 0.0000],
         [0.0000, 0.2235, 0.7333, 0.8157, 0.8784, 0.8667, 0.8784, 0.8157,
          0.8000, 0.8392, 0.8157, 0.8196, 0.7843, 0.6235, 0.9608, 0.7569,
          0.8078, 0.8745, 1.0000, 1.0000, 0.8667, 0.9176, 0.8667, 0.8275,
          0.8627, 0.9098, 0.9647, 0.0000],
         [0.0118, 0.7922, 0.8941, 0.8784, 0.8667, 0.8275, 0.8275, 0.8392,
          0.8039, 0.8039, 0.8039, 0.8627, 0.9412, 0.3137, 0.5882, 1.0000,
          0.8980, 0.8667, 0.7373, 0.6039, 0.7490, 0.8235, 0.8000, 0.8196,
          0.8706, 0.8941, 0.8824, 0.0000],
         [0.3843, 0.9137, 0.7765, 0.8235, 0.8706, 0.8980, 0.8980, 0.9176,
          0.9765, 0.8627, 0.7608, 0.8431, 0.8510, 0.9451, 0.2549, 0.2863,
          0.4157, 0.4588, 0.6588, 0.8588, 0.8667, 0.8431, 0.8510, 0.8745,
          0.8745, 0.8784, 0.8980, 0.1137],
         [0.2941, 0.8000, 0.8314, 0.8000, 0.7569, 0.8039, 0.8275, 0.8824,
          0.8471, 0.7255, 0.7725, 0.8078, 0.7765, 0.8353, 0.9412, 0.7647,
          0.8902, 0.9608, 0.9373, 0.8745, 0.8549, 0.8314, 0.8196, 0.8706,
          0.8627, 0.8667, 0.9020, 0.2627],
         [0.1882, 0.7961, 0.7176, 0.7608, 0.8353, 0.7725, 0.7255, 0.7451,
          0.7608, 0.7529, 0.7922, 0.8392, 0.8588, 0.8667, 0.8627, 0.9255,
          0.8824, 0.8471, 0.7804, 0.8078, 0.7294, 0.7098, 0.6941, 0.6745,
          0.7098, 0.8039, 0.8078, 0.4510],
         [0.0000, 0.4784, 0.8588, 0.7569, 0.7020, 0.6706, 0.7176, 0.7686,
          0.8000, 0.8235, 0.8353, 0.8118, 0.8275, 0.8235, 0.7843, 0.7686,
          0.7608, 0.7490, 0.7647, 0.7490, 0.7765, 0.7529, 0.6902, 0.6118,
          0.6549, 0.6941, 0.8235, 0.3608],
         [0.0000, 0.0000, 0.2902, 0.7412, 0.8314, 0.7490, 0.6863, 0.6745,
          0.6863, 0.7098, 0.7255, 0.7373, 0.7412, 0.7373, 0.7569, 0.7765,
          0.8000, 0.8196, 0.8235, 0.8235, 0.8275, 0.7373, 0.7373, 0.7608,
          0.7529, 0.8471, 0.6667, 0.0000],
         [0.0078, 0.0000, 0.0000, 0.0000, 0.2588, 0.7843, 0.8706, 0.9294,
          0.9373, 0.9490, 0.9647, 0.9529, 0.9569, 0.8667, 0.8627, 0.7569,
          0.7490, 0.7020, 0.7137, 0.7137, 0.7098, 0.6902, 0.6510, 0.6588,
          0.3882, 0.2275, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1569,
          0.2392, 0.1725, 0.2824, 0.1608, 0.1373, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000]]])
first_label: 9
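ToTensor simply divides each raw byte pixel by 255, which is exactly where the fractional values in the dump above come from. For example, raw bytes 13, 73, and 136 reproduce the values 0.0510, 0.2863, and 0.5333 that appear in the tensor:

```python
# ToTensor maps a raw pixel p in [0, 255] to p / 255 in [0, 1].
# Rounding to 4 decimals matches the values printed in the dump above.
for p in (13, 73, 136, 255):
    print(p, round(p / 255, 4))
```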

Reading Mini-batches


## Read mini-batches
batch_size = 4

def get_dataloader_workers():
    """Return the number of worker processes for the DataLoader."""
    # num_workers=4 usually speeds up loading (the CPU prefetches data).
    # If your machine has few cores, or you hit multiprocessing issues on
    # some platforms (e.g. Windows), set this to 0 while debugging.
    return 4

# Read the training set in batches via DataLoader:
# - batch_size=4: return 4 samples per iteration
# - shuffle=True: reshuffle before each epoch, so the model cannot memorize sample order
# - num_workers=4: load with 4 worker processes in parallel
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

# DataLoader for the test set. Shuffling is usually unnecessary here,
# since sample order does not matter during evaluation.
# - batch_size=4: return 4 samples per iteration
# - shuffle=False: keep the original order, which makes evaluation reproducible
test_iter = data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers())

# Fetch one batch and print every sample it contains.
for X, y in train_iter:
    print("batch X shape:", X.shape, "| y shape:", y.shape)

    # Configure tensor printing: elide with ellipses when there are many
    # elements, so the full 28x28 pixel grid does not flood the screen.
    torch.set_printoptions(edgeitems=2, threshold=50, linewidth=120)

    # Print each sample: image tensor (elided) + label
    for i in range(X.shape[0]):
        print(f"\n===== sample {i} =====")
        print("image tensor:\n", X[i])
        print("label:", int(y[i]))

    break  # only inspect the first batch

Output


batch X shape: torch.Size([4, 1, 28, 28]) | y shape: torch.Size([4])

===== sample 0 =====
image tensor:
 tensor([[[0., 0.,  ..., 0., 0.],
         [0., 0.,  ..., 0., 0.],
         ...,
         [0., 0.,  ..., 0., 0.],
         [0., 0.,  ..., 0., 0.]]])
label: 1

===== sample 1 =====
image tensor:
 tensor([[[0., 0.,  ..., 0., 0.],
         [0., 0.,  ..., 0., 0.],
         ...,
         [0., 0.,  ..., 0., 0.],
         [0., 0.,  ..., 0., 0.]]])
label: 7

===== sample 2 =====
image tensor:
 tensor([[[0., 0.,  ..., 0., 0.],
         [0., 0.,  ..., 0., 0.],
         ...,
         [0., 0.,  ..., 0., 0.],
         [0., 0.,  ..., 0., 0.]]])
label: 6

===== sample 3 =====
image tensor:
 tensor([[[0., 0.,  ..., 0., 0.],
         [0., 0.,  ..., 0., 0.],
         ...,
         [0., 0.,  ..., 0., 0.],
         [0., 0.,  ..., 0., 0.]]])
label: 4
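Ignoring parallel loading, the batching behavior shown above amounts to shuffling the sample indices once per epoch and slicing them into fixed-size chunks; DataLoader does this (plus worker-based prefetching) under the hood. A plain-Python sketch, with names of our own choosing:

```python
# Mimic DataLoader's index handling: shuffle once per epoch, then yield
# consecutive fixed-size chunks; the last chunk may be smaller.
import random

def iterate_minibatches(n_samples, batch_size, shuffle=True, seed=0):
    """Yield lists of sample indices, batch by batch."""
    idx = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(idx)
    for start in range(0, n_samples, batch_size):
        yield idx[start:start + batch_size]

batches = list(iterate_minibatches(10, 4))
print([len(b) for b in batches])  # [4, 4, 2]
```

Each index batch would then be used to gather the corresponding (image, label) pairs and stack them into tensors of shape (batch, 1, 28, 28) and (batch,).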