Pytorch实现DenseNet

示例代码:
154
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/9/14 16:02
# @Author : Seven
# @Site :
# @File : DenseNet.py
# @Software: PyCharm
import math
import torch
import torch.nn as nn


class Bn_act_conv_drop(nn.Module):
    """Pre-activation conv unit: BatchNorm -> ReLU, then Conv -> ReLU -> Dropout.

    Args:
        inputs: number of input channels.
        outs: number of output channels.
        kernel_size: convolution kernel size.
        padding: convolution padding.
    """

    def __init__(self, inputs, outs, kernel_size, padding):
        super(Bn_act_conv_drop, self).__init__()
        # Pre-activation: normalize and rectify the incoming feature map.
        self.bn = nn.Sequential(nn.BatchNorm2d(inputs), nn.ReLU())
        # Convolution (stride 1), activation, and train-time dropout.
        self.conv = nn.Sequential(
            nn.Conv2d(inputs, outs, kernel_size=kernel_size, padding=padding, stride=1),
            nn.ReLU(),
            nn.Dropout(),
        )

    def forward(self, inputs):
        """Apply BN-ReLU followed by Conv-ReLU-Dropout."""
        return self.conv(self.bn(inputs))


class Transition(nn.Module):
    """Transition layer between dense blocks.

    A 1x1 pre-activation conv changes the channel count from ``inputs`` to
    ``outs``, then 2x2 average pooling halves the spatial resolution.
    """

    def __init__(self, inputs, outs):
        super(Transition, self).__init__()
        self.conv = Bn_act_conv_drop(inputs, outs, kernel_size=1, padding=0)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, inputs):
        """Compress channels, then downsample spatially by 2x."""
        return self.avgpool(self.conv(inputs))


class Block(nn.Module):
    """Bottleneck dense layer.

    A 1x1 conv expands to ``4 * growth`` channels, a 3x3 conv reduces to
    ``growth`` channels, and the result is concatenated with the input so
    the output has ``inputs + growth`` channels.
    """

    def __init__(self, inputs, growth):
        super(Block, self).__init__()
        self.conv1 = Bn_act_conv_drop(inputs, 4 * growth, kernel_size=1, padding=0)
        self.conv2 = Bn_act_conv_drop(4 * growth, growth, kernel_size=3, padding=1)

    def forward(self, inputs):
        """Compute new features and concatenate them before the input."""
        new_features = self.conv2(self.conv1(inputs))
        return torch.cat([new_features, inputs], 1)


class DenseNet(nn.Module):
    """DenseNet for small (CIFAR-style) images.

    Four dense blocks separated by transitions that halve both the channel
    count and the spatial resolution, followed by global average pooling
    and a linear classifier.

    Args:
        blocks: number of dense layers in each of the four dense blocks.
        growth: growth rate k (channels each dense layer adds).
        num_classes: size of the final classification layer (default 10,
            matching the original hard-coded value).
    """

    def __init__(self, blocks, growth, num_classes=10):
        super(DenseNet, self).__init__()
        num_planes = 2 * growth  # stem width
        inputs = 3  # RGB input

        # Stem: plain 3x3 conv, no downsampling (suited to 32x32 inputs).
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=inputs,
                out_channels=num_planes,
                kernel_size=3,
                padding=1),
            nn.ReLU(),
        )

        # Dense block 1 + transition (transition halves the channel count).
        self.block1 = self._block(blocks[0], num_planes, growth)
        num_planes += blocks[0] * growth
        out_planes = num_planes // 2  # same as int(math.floor(num_planes * 0.5))
        self.tran1 = Transition(inputs=num_planes, outs=out_planes)
        num_planes = out_planes

        # Dense block 2 + transition.
        self.block2 = self._block(blocks[1], num_planes, growth)
        num_planes += blocks[1] * growth
        out_planes = num_planes // 2
        self.tran2 = Transition(inputs=num_planes, outs=out_planes)
        num_planes = out_planes

        # Dense block 3 + transition.
        self.block3 = self._block(blocks[2], num_planes, growth)
        num_planes += blocks[2] * growth
        out_planes = num_planes // 2
        self.tran3 = Transition(inputs=num_planes, outs=out_planes)
        num_planes = out_planes

        # Dense block 4 (no transition after the last block).
        self.block4 = self._block(blocks[3], num_planes, growth)
        num_planes += blocks[3] * growth

        self.bn = nn.Sequential(
            nn.BatchNorm2d(num_planes),
            nn.ReLU()
        )
        # Generalizes the original AvgPool2d(kernel_size=4): identical for
        # 32x32 inputs (the feature map here is 4x4) and also handles other
        # input sizes instead of breaking the linear layer.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(num_planes, num_classes)

    def forward(self, inputs):
        """Return ``(logits, features)`` where ``features`` is the flattened
        pooled representation fed to the linear classifier."""
        network = self.conv(inputs)
        network = self.tran1(self.block1(network))
        network = self.tran2(self.block2(network))
        network = self.tran3(self.block3(network))
        network = self.bn(self.block4(network))
        network = self.avgpool(network)
        network = network.view(network.size(0), -1)
        out = self.linear(network)
        return out, network

    @staticmethod
    def _block(layers, inputs, growth):
        """Stack ``layers`` dense layers; each consumes ``inputs`` channels
        and the channel count grows by ``growth`` per layer."""
        block_layer = []
        for _ in range(layers):
            block_layer.append(Block(inputs, growth))
            inputs += growth
        return nn.Sequential(*block_layer)


def DenseNet121():
    """Build a DenseNet-121 (block sizes 6/12/24/16, growth rate 32)."""
    return DenseNet([6, 12, 24, 16], 32)


def DenseNet169():
    """Build a DenseNet-169 (block sizes 6/12/32/32, growth rate 32)."""
    return DenseNet([6, 12, 32, 32], 32)


def DenseNet201():
    """Build a DenseNet-201 (block sizes 6/12/48/32, growth rate 32)."""
    return DenseNet([6, 12, 48, 32], 32)


def DenseNet161():
    """Build a DenseNet-161 (block sizes 6/12/36/24, growth rate 48)."""
    return DenseNet([6, 12, 36, 24], 48)


def DenseNet_cifar():
    """Build a small DenseNet for CIFAR (block sizes 6/12/24/16, growth rate 12)."""
    return DenseNet([6, 12, 24, 16], 12)

转载请注明:Seven的博客

本文标题:Pytorch实现DenseNet

文章作者:Seven

发布时间:2018年09月15日 - 00:00:00

最后更新:2018年12月11日 - 22:15:13

原始链接:http://yoursite.com/2018/09/15/2018-09-15-Pytorch-DenseNet/

许可协议: 署名-非商业性使用-禁止演绎 4.0 国际 转载请保留原文链接及作者。

------ 本文结束------
坚持原创技术分享,您的支持将鼓励我继续创作!
0%