Pytorch实现CIFAR10之读取模型训练本地图片

示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/9/14 12:51
# @Author : Seven
# @Site :
# @File : test.py
# @Software: PyCharm

import torch
import numpy as np
from PIL import Image

# 读取模型
model = torch.load('LeNet.pkl')
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def local_photos():
# input
im = Image.open('plane.jpg')
# im = im.convert('L')

im = im.resize((32, 32))
# im.show()
im = np.array(im).astype(np.float32)
im = np.reshape(im, [-1, 32*32*3])
im = (im - (255 / 2.0)) / 255

batch_xs = np.reshape(im, [-1, 3, 32, 32])
batch_xs = torch.FloatTensor(batch_xs)

# 预测
pred_y, _ = model(batch_xs)
pred_y = torch.max(pred_y, 1)[1].data.numpy().squeeze()

print("The predict is : ", classes[pred_y])


local_photos()

测试图片:

images

测试结果:

1
The predict is :  plane

转载请注明:Seven的博客

本文标题:Pytorch实现CIFAR10之读取模型训练本地图片

文章作者:Seven

发布时间:2018年09月15日 - 00:00:00

最后更新:2018年12月11日 - 22:14:45

原始链接:http://yoursite.com/2018/09/15/2018-09-15-Pytorch-Cifar10-test/

许可协议: 署名-非商业性使用-禁止演绎 4.0 国际 转载请保留原文链接及作者。

------ 本文结束------
坚持原创技术分享,您的支持将鼓励我继续创作!
0%