说说tensorflow使用Saver踩过的坑

tf.train.Saver的一种尝试

背景

来自于一个神奇的bug,模型第一次训练的时候能正常跑,但是通过tf.train.Saver 的restore恢复checkpoint之后就抛出Attempting to use uninitialized value xxxxxx,最后发现原来是saver定义时候的位置出现问题。

代码实验

以下代码模拟了一次模型保存和恢复的使用情况。首先是定义了计算图,初始化后保存模型,再清除计算图,之后再恢复checkpoint。

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Created by Ross on 18-10-31
import os

import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


def create_graph():
    global a
    global saver
    global b
    a = tf.get_variable('a', shape=[2, 2], initializer=tf.initializers.random_normal())
    saver = tf.train.Saver(name='saver')  # 定义在a的后面,b的前面
    b = tf.get_variable('b', shape=[3, 3], initializer=tf.initializers.random_normal())


if __name__ == '__main__':
    with tf.Session() as sess:
        create_graph()  # 定义计算图
        sess.run(tf.global_variables_initializer())  # 初始化计算图

        if not os.path.exists('tmp'):
            os.mkdir('tmp')
        saver.save(sess, 'tmp/tmp.ckpt')  # 保存checkpoint

    tf.reset_default_graph()  # 清除计算图

    with tf.Session() as sess:
        create_graph()  # 再定义计算图
        print_tensors_in_checkpoint_file('tmp/tmp.ckpt', tensor_name='', all_tensors=True)
        saver.restore(sess, 'tmp/tmp.ckpt')  # 恢复checkpoint
        print(sess.run(b))  # 尝试运行b计算

运行后报错

tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value b
     [[{{node b/_2}} = _Send[T=DT_FLOAT, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_4_b", _device="/job:localhost/replica:0/task:0/device:GPU:0"](b)]]
     [[{{node b/_3}} = _Recv[_start_time=0, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_4_b", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

分析

原因是saver定义在了a之后和b之前,saver定义的时候默认会将已有的计算图保存,等到使用saver.save()的时候会将saver对象内保存的计算图保存到磁盘中,因此saver保存的计算图中没有b变量,所以恢复checkpoint的时候不会恢复b变量,因此运行计算b的时候会抛出变量未初始化的错误。