授课语音

使用预训练的 ResNet 网络进行图片分类

在深度学习中,ResNet(Residual Networks)是一种非常强大的卷积神经网络,它通过引入跳跃连接(skip connections)来解决深度神经网络训练中遇到的梯度消失问题。在 PyTorch 中,我们可以使用预训练的 ResNet 网络来进行图片分类任务。

下面将介绍如何使用预训练的 ResNet 网络进行图片分类的步骤。

1. 环境准备

确保已经安装了 PyTorch 和 torchvision(用于计算机视觉任务)。如果没有安装,可以使用以下命令安装:

pip install torch torchvision matplotlib

2. 加载预训练的 ResNet 网络

PyTorch 提供了多个预训练的网络模型,包括 ResNet。这里我们使用 ResNet18(可以选择其他版本,如 ResNet34、ResNet50 等)。

import torch
import torchvision
from torchvision import datasets, transforms
from torch import nn, optim
import matplotlib.pyplot as plt
from PIL import Image

# 加载预训练的 ResNet18 模型
model = torchvision.models.resnet18(pretrained=True)

# 将模型设置为评估模式
model.eval()

3. 准备数据集

假设我们要分类的数据是图片格式。首先,我们需要对图片进行预处理(如缩放、裁剪、归一化等),然后加载图片。

我们将使用 PyTorch 的 torchvision.transforms 来处理图片。

3.1 定义数据预处理操作

# 定义图像预处理操作
transform = transforms.Compose([
    transforms.Resize(256),  # 将图片缩放为 256x256
    transforms.CenterCrop(224),  # 裁剪中心区域,大小为 224x224
    transforms.ToTensor(),  # 将图片转换为 Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 对图片进行归一化处理
])

3.2 加载图片

# 加载一张图片
img_path = 'path_to_your_image.jpg'  # 替换为图片路径
img = Image.open(img_path)

# 应用预处理操作
img_tensor = transform(img)

# 添加一个批次维度
img_tensor = img_tensor.unsqueeze(0)

4. 进行预测

将预处理后的图片输入到预训练的 ResNet 网络中进行分类。网络的输出是一个类别的概率分布。

# 将图片传入模型进行预测
with torch.no_grad():  # 在预测时不需要计算梯度
    output = model(img_tensor)

# 获取预测结果,输出最大概率对应的类别索引
_, predicted_class = torch.max(output, 1)

# 打印预测的类别索引
print(f"Predicted class index: {predicted_class.item()}")

5. 加载标签文件

PyTorch 提供了一个预定义的标签文件,可以将预测的类别索引映射为实际的类别名称。我们可以使用 imagenet_class_index.json 文件来获取这些标签。

# 加载 ImageNet 类别名称
import json
LABELS_PATH = 'https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json'
import requests

# 获取 ImageNet 标签
response = requests.get(LABELS_PATH)
labels = response.json()

# 输出预测类别的名称
predicted_class_name = labels[str(predicted_class.item())][1]
print(f"Predicted class name: {predicted_class_name}")

6. 显示图片与预测结果

# 显示图片
plt.imshow(img)
plt.title(f"Predicted: {predicted_class_name}")
plt.axis('off')  # 不显示坐标轴
plt.show()

7. 完整代码

以下是完整的代码流程,结合了图片的加载、预处理、预测和结果展示。

import torch
import torchvision
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import requests

# 加载预训练的 ResNet18 模型
model = torchvision.models.resnet18(pretrained=True)
model.eval()  # 将模型设置为评估模式

# 定义图像预处理操作
transform = transforms.Compose([
    transforms.Resize(256),  # 缩放图片
    transforms.CenterCrop(224),  # 裁剪中心区域
    transforms.ToTensor(),  # 转换为 Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化
])

# 加载并预处理图片
img_path = 'path_to_your_image.jpg'  # 替换为你的图片路径
img = Image.open(img_path)
img_tensor = transform(img).unsqueeze(0)  # 添加批次维度

# 进行预测
with torch.no_grad():  # 关闭梯度计算
    output = model(img_tensor)

# 获取预测结果
_, predicted_class = torch.max(output, 1)

# 获取 ImageNet 类别名称
LABELS_PATH = 'https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json'
response = requests.get(LABELS_PATH)
labels = response.json()

# 输出预测的类别名称
predicted_class_name = labels[str(predicted_class.item())][1]
print(f"Predicted class index: {predicted_class.item()}")
print(f"Predicted class name: {predicted_class_name}")

# 显示图片与预测结果
plt.imshow(img)
plt.title(f"Predicted: {predicted_class_name}")
plt.axis('off')
plt.show()

8. 总结

通过以上步骤,我们成功地使用了预训练的 ResNet 网络进行图片分类。这个流程包括了图片的预处理、加载预训练的网络、进行推理和输出分类结果。通过这种方式,我们可以非常方便地使用深度学习模型对图像进行分类,即使我们没有大量的训练数据。

去1:1私密咨询

系列课程: