Pytorch实现CIFAR-10之数据预处理

示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/9/13 19:45
# @Author : Seven
# @Site :
# @File : train.py
# @Software: PyCharm

import torch
import os
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as Data
import torch.nn as nn
import torch.optim as optim

print('==> Preparing data..')
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


DOWNLOAD_MNIST = False

if not(os.path.exists('./data/')) or not os.listdir('./data/'): # 判断数据是否存在
DOWNLOAD_MNIST = True

trainset = torchvision.datasets.CIFAR10(root='./data',
train=True,
download=DOWNLOAD_MNIST,
transform=transform_train)

trainloader = torch.utils.data.DataLoader(trainset,
batch_size=128,
shuffle=True,
num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data',
train=False,
download=DOWNLOAD_MNIST,
transform=transform_test)

testloader = torch.utils.data.DataLoader(testset,
batch_size=100,
shuffle=False,
num_workers=0)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

转载请注明:Seven的博客

本文标题:Pytorch实现CIFAR-10之数据预处理

文章作者:Seven

发布时间:2018年09月15日 - 00:00:00

最后更新:2018年12月11日 - 22:14:33

原始链接:http://yoursite.com/2018/09/15/2018-09-15-Pytorch-cifar10-data/

许可协议: 署名-非商业性使用-禁止演绎 4.0 国际 转载请保留原文链接及作者。

------ 本文结束------
坚持原创技术分享,您的支持将鼓励我继续创作!
0%