🎨关于 Transforms.Compose 的小问题

  • 今天简单的玩了一下pytorch自带的CIFAR10数据集构建多分类预测模型,使用torchvision.transforms.Compose时,发现了一个小小的BUG。

  • 我已经训练好了CIFAR10的预测模型,而我要对模型进行简单的验证。

导入相关模块和已经构建好的神经网络结构模型

1
2
3
4
5
6
import torch
import torchvision
from PIL import Image

from nn_model import Net

编写主程序代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
image_path = './image/img.png'
image = Image.open(image_path)

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
torchvision.transforms.Resize((32, 32))])
image = transform(image)
image = torch.reshape(image, [1, 3, 32, 32])

model = Net()
model.load_state_dict(torch.load('nn_cifar10.pth', map_location=torch.device('cuda')))

with torch.no_grad():
output = model(image)
print(output.argmax(1))

'''img.png即为要进行验证模型的图片,nn_cifar10.pth就是已经训练好的预测模型'''
  • 当我找了一张小狗的图片进行验证时,结果是 airplane ,什么!!!飞机?我的天~
dog airplane
  • 我的第一反应是我的预测模型训练迭代不好造成的,于是我对CIFAR10模型训练增长到35次,这次应该没问题了吧。但是这次竟然给我预测成了 frog,不应该呀!
frog cat
  • 于是我又找了一张小猫的图片,竟然也是 frog ,这么一来说明我验证的全不对了。怎么回事呢?

🐎去看官方文档,在Compose并没有找到答案,在我捣鼓了半天后,突然发现了一个问题!!!

1
2
3
4
5
6
7
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
torchvision.transforms.Resize((32, 32))])

'''将上面的ToTensor与Resize互换位置'''

transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
torchvision.transforms.ToTensor()])
  • 当我将ToTensorResize两个 object 互换位置后,发现预测的结果正确了!而且再试了多个不同图片也预测正确了!😵🤔

🐟晕,究竟是BUG还是其它,还得慢慢查找原因。