TensorFlow--保存神经网络参数和加载神经网络参数

保存神经网络参数

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/9/4 16:11
# @Author : Seven
# @Site :
# @File : save.py
# @Software: PyCharm
import tensorflow as tf


# 保存神经网络参数
def save_para():
# 定义权重参数
W = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype = tf.float32, name = 'weights')
# 定义偏置参数
b = tf.Variable([[1, 2, 3]], dtype = tf.float32, name = 'biases')
# 参数初始化
init = tf.global_variables_initializer()
# 定义保存参数的saver
saver = tf.train.Saver()

with tf.Session() as sess:
sess.run(init)
# 保存session中的数据
save_path = saver.save(sess, './save_net.ckpt')
# 输出保存路径
print('Save to path: ', save_path)


save_para()

执行结果

1
Save to path:  ./save_net.ckpt

加载神经网络参数

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/9/4 16:14
# @Author : Seven
# @Site :
# @File : restore.py
# @Software: PyCharm
import tensorflow as tf
import numpy as np


# 恢复神经网络参数
def restore_para():
# 定义权重参数
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype = tf.float32, name = 'weights')
# 定义偏置参数
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype = tf.float32, name = 'biases')
# 定义提取参数的saver
saver = tf.train.Saver()

with tf.Session() as sess:
# 加载文件中的参数数据,会根据name加载数据并保存到变量W和b中
save_path = saver.restore(sess, 'save_net.ckpt')
# 输出保存路径
print('Weights: ', sess.run(W))
print('biases: ', sess.run(b))


restore_para()

执行结果

1
2
3
Weights:  [[1. 2. 3.]
[4. 5. 6.]]
biases: [[1. 2. 3.]]

转载请注明:Seven的博客

本文标题:TensorFlow--保存神经网络参数和加载神经网络参数

文章作者:Seven

发布时间:2018年09月04日 - 00:00:00

最后更新:2018年12月11日 - 22:12:10

原始链接:http://yoursite.com/2018/09/04/2018-09-04-TensorFlow-save-restore/

许可协议: 署名-非商业性使用-禁止演绎 4.0 国际 转载请保留原文链接及作者。

------ 本文结束------
坚持原创技术分享,您的支持将鼓励我继续创作!
0%