训练好的模型都需要保存,下面将举例演示如何保存和载入模型。
1.保存模型
首先需要建立一个saver,然后在session中通过saver的save即可将模型保存起来。代码如下:
# 之前是各种构建模型graph的操作(矩阵相乘,sigmoid等)
saver = tf.train.Saver() #生成saver
with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) #先对模型初始化
# 然后将数据丢入模型进行训练blablabla
# 训练完以后,使用saver.save 来保存
saver.save(sess, "save_path/file_name")
# file_name如果不存在,会自动创建
2.载入模型
将模型保存好以后,载入也比较方便。在session中通过调用saver的restore()
函数,会从指定的路径找到模型文件,并覆盖到相关参数中。代码如下:
saver = tf.train.Saver()
with tf.Session() as sess:
#参数可以进行初始化,也可不进行初始化。即使初始化了,初始化的值也会被restore的值给覆盖
sess.run(tf.global_variables_initializer())
saver.restore(sess, "save_path/file_name") #会将已经保存的变量值resotre到变量中。
实例:保存/载入线性回归模型
实例描述
在之前的例子基础中,添加模型的保存及载入功能。
通过扩展上一章的例子,来演示一下模型的保存及载入。在代码“3-1线性回归.py”文件中生成模拟数据之后,加入对图变量的重置,在session创建之前定义saver及保存路径,在session中训练结束后,保存模型。
代码4-4 线性回归模型保存及载入
01 import tensorflow as tf
02 import numpy as np
03 import matplotlib.pyplot as plt
04
05 #模拟数据
06 ……
07 plt.plot(train_X, train_Y, 'ro', label='Original data')
08 plt.legend()
09 plt.show()
10
11 #重置图
12 tf.reset_default_graph()
13
14 #初始话等操作
15 ……
16 display_step = 2
17
18 saver = tf.train.Saver() #生成saver
19 savedir = "log/" #生成模型的路径
20
21 #启动session
22 with tf.Session() as sess:
23 sess.run(init)
24 #在这里添加Sess中的训练代码
25 ……
26 print (" Finished!")
27 saver.save(sess, savedir+"linermodel.cpkt") #保存模型
28 print ("cost=", sess.run(cost, feed_dict=
{X: train_X, Y: train_Y}),"W=", sess.run(W), "b=", sess.run(b))
29 #其他代码
30 ……
运行上面代码可以看到,在代码的同级目下log文件夹里生成了几个文件,如图所示。
再重启一个session,并命名为sess2,在代码里通过使用saver的restore函数将模型载入。
代码4-4 线性回归模型保存及载入(续)
31 with tf.Session()as sess2:
32 sess2.run(tf.global_variables_initializer())
33 saver.restore(sess2,savedir+"linermodel.cpkt")
34 print ("x=0.2,z=", sess2.run(z, feed_dict={X: 0.2}))
为了测试效果,可以将前面一个session注释掉,运行之后可以看到如下输出:
INFO:tensorflow:Restoring parameters from log/linermodel.cpkt
x=0.2,z= [ 0.42615247]
表明模型已经成功载入,并计算出正确的值了。
实例:分析模型内容,演示模型的其他保存方法
下面再来详细介绍下关于模型保存的其他细节。
实例描述
将4.1.10节生成的模型里面的内容打印出来,观察其存放的具体数据方式。同时演示如何将指定内容保存到模型文件中。
1.模型内容
虽然模型已经保存了,但是仍然对我们不透明。下面通过编写代码将模型里的内容打印出来,看看到底保存了哪些东西,都是什么样的。
代码4-5 模型内容
01 from tensorflow.python.tools.inspect_checkpoint import print_tensors_
in_checkpoint_file
02 savedir = "log/"
03 print_tensors_in_checkpoint_file(savedir+"linermodel.cpkt", None, True)
运行代码,打印如下信息:
tensor_name: bias
[ 0.01919404]
tensor_name: weight
[ 2.03479218]
可以看到,tensor_name:后面跟的就是创建的变量名,接着是它的数值。
2.保存模型的其他方法
前面的例子中Saver的创建比较简单,其实tf.train.Saver函数里面还可以放参数来实现更高级的功能,可以指定存储变量名字与变量的对应关系。可以写成这样:
saver = tf.train.Saver({'weight': W, 'bias': b})
代表将w变量的值放到weight名字中。类似的写法还有以下两种:
saver = tf.train.Saver([W, b]) #放到一个list里
saver = tf.train.Saver({v.op.name: v for v in [W, b]}) #将op的名字当作key
下面扩展上述的例子,给b和w分别指定一个固定值,并将它们颠倒放置。
代码4-5 模型内容(续)
03 W = tf.Variable(1.0, name="weight")
04 b = tf.Variable(2.0, name="bias")
05
06 #放到一个字典里
07 saver = tf.train.Saver({'weight': b, 'bias': W})
08
09 with tf.Session() as sess:
10 tf.global_variables_initializer().run()
11 saver.save(sess, savedir+"linermodel.cpkt")
12
13 print_tensors_in_checkpoint_file(savedir+"linermodel.cpkt", None, True)
运行上面代码,输出如下信息:
tensor_name: bias
1.0
tensor_name: weight
2.0
例子中,W值设为1.0,b的值设为2.0。在创建saver时将它们颠倒,保存的模型打印出来之后可以看到,bias变成了1.0,而weight变成了2.0。