保存和载入模型的方法介绍 - Tensorflow

半兽人 发表于: 2019-04-13   最后更新时间: 2019-04-13 23:19:22  
{{totalSubscript}} 订阅, 3,593 游览

训练好的模型都需要保存,下面将举例演示如何保存和载入模型。

1.保存模型

首先需要建立一个saver,然后在session中通过saver的save即可将模型保存起来。代码如下:

# 之前是各种构建模型graph的操作(矩阵相乘,sigmoid等)
saver = tf.train.Saver() #生成saver
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) #先对模型初始化
    # 然后将数据丢入模型进行训练blablabla
    # 训练完以后,使用saver.save 来保存
    saver.save(sess, "save_path/file_name")
# file_name如果不存在,会自动创建

2.载入模型

将模型保存好以后,载入也比较方便。在session中通过调用saver的restore()函数,会从指定的路径找到模型文件,并覆盖到相关参数中。代码如下:

saver = tf.train.Saver()
with tf.Session() as sess:
    #参数可以进行初始化,也可不进行初始化。即使初始化了,初始化的值也会被restore的值给覆盖
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, "save_path/file_name") #会将已经保存的变量值resotre到变量中。

实例:保存/载入线性回归模型

实例描述

在之前的例子基础中,添加模型的保存及载入功能。

通过扩展上一章的例子,来演示一下模型的保存及载入。在代码“3-1线性回归.py”文件中生成模拟数据之后,加入对图变量的重置,在session创建之前定义saver及保存路径,在session中训练结束后,保存模型。

代码4-4 线性回归模型保存及载入

01 import tensorflow as tf
02 import numpy as np
03 import matplotlib.pyplot as plt
04
05 #模拟数据
06 ……
07 plt.plot(train_X, train_Y, 'ro', label='Original data')
08 plt.legend()
09 plt.show()
10
11 #重置图
12 tf.reset_default_graph()
13
14 #初始话等操作
15 ……
16 display_step = 2
17
18 saver = tf.train.Saver() #生成saver
19 savedir = "log/" #生成模型的路径
20
21 #启动session
22 with tf.Session() as sess:
23     sess.run(init)
24     #在这里添加Sess中的训练代码
25     ……
26     print (" Finished!")
27     saver.save(sess, savedir+"linermodel.cpkt") #保存模型
28     print ("cost=", sess.run(cost, feed_dict=
{X:     train_X, Y: train_Y}),"W=", sess.run(W), "b=", sess.run(b))
29 #其他代码
30 ……

运行上面代码可以看到,在代码的同级目下log文件夹里生成了几个文件,如图所示。

screenshot

再重启一个session,并命名为sess2,在代码里通过使用saver的restore函数将模型载入。

代码4-4 线性回归模型保存及载入(续)

31 with tf.Session()as sess2:
32 sess2.run(tf.global_variables_initializer())
33 saver.restore(sess2,savedir+"linermodel.cpkt")
34 print ("x=0.2,z=", sess2.run(z, feed_dict={X: 0.2}))

为了测试效果,可以将前面一个session注释掉,运行之后可以看到如下输出:

INFO:tensorflow:Restoring parameters from log/linermodel.cpkt
x=0.2,z= [ 0.42615247]

表明模型已经成功载入,并计算出正确的值了。

实例:分析模型内容,演示模型的其他保存方法

下面再来详细介绍下关于模型保存的其他细节。

实例描述

将4.1.10节生成的模型里面的内容打印出来,观察其存放的具体数据方式。同时演示如何将指定内容保存到模型文件中。

1.模型内容

虽然模型已经保存了,但是仍然对我们不透明。下面通过编写代码将模型里的内容打印出来,看看到底保存了哪些东西,都是什么样的。

代码4-5 模型内容

01 from tensorflow.python.tools.inspect_checkpoint import print_tensors_
in_checkpoint_file
02 savedir = "log/"
03 print_tensors_in_checkpoint_file(savedir+"linermodel.cpkt", None, True)

运行代码,打印如下信息:

tensor_name: bias
[ 0.01919404]
tensor_name: weight
[ 2.03479218]

可以看到,tensor_name:后面跟的就是创建的变量名,接着是它的数值。

2.保存模型的其他方法

前面的例子中Saver的创建比较简单,其实tf.train.Saver函数里面还可以放参数来实现更高级的功能,可以指定存储变量名字与变量的对应关系。可以写成这样:

saver = tf.train.Saver({'weight': W, 'bias': b})

代表将w变量的值放到weight名字中。类似的写法还有以下两种:

saver = tf.train.Saver([W, b]) #放到一个list里
saver = tf.train.Saver({v.op.name: v for v in [W, b]}) #将op的名字当作key

下面扩展上述的例子,给b和w分别指定一个固定值,并将它们颠倒放置。

代码4-5 模型内容(续)

03 W = tf.Variable(1.0, name="weight")
04 b = tf.Variable(2.0, name="bias")
05
06 #放到一个字典里
07 saver = tf.train.Saver({'weight': b, 'bias': W})
08
09 with tf.Session() as sess:
10 tf.global_variables_initializer().run()
11 saver.save(sess, savedir+"linermodel.cpkt")
12
13 print_tensors_in_checkpoint_file(savedir+"linermodel.cpkt", None, True)

运行上面代码,输出如下信息:

tensor_name: bias
1.0
tensor_name: weight
2.0

例子中,W值设为1.0,b的值设为2.0。在创建saver时将它们颠倒,保存的模型打印出来之后可以看到,bias变成了1.0,而weight变成了2.0。

更新于 2019-04-13
在线,2小时前登录

查看TensorFlow更多相关的文章或提一个关于TensorFlow的问题,也可以与我们一起分享文章