博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
【TensorFlow系列】【三】冻结模型文件并做inference
阅读量:6247 次
发布时间:2019-06-22

本文共 4567 字,大约阅读时间需要 15 分钟。

hot3.png

本文基于mnist与lenet,讲述如下两个问题:

1.如何将训练好的网络模型冻结,形成net.pb文件?

2.如何将net.pb文件部署到TensorFlow中做inference?

pb文件保存的步骤

1.需要给input与最终的预测值取个名字,便于部署时输入数据并输出数据
2.利用graph_util.convert_variables_to_constants将网络中模型参数变量转换为常量
3.利用tf.gfile.FastGFile将模型参数序列化后的数据写入文件。

pb文件部署步骤:

1.利用tf.gfile.FastGFile读取pb文件,并将文件中存储的graph导入到TensorFlow中。
2.从graph中获取input与output变量,传入图片数据,做inference

 

【基于mnist与lenet,保存pb文件】

import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datafrom tensorflow.python.framework import graph_utilmnist = input_data.read_data_sets(train_dir=r"E:\mnist_data",one_hot=True)#定义输入数据mnist图片大小28*28*1=784,None表示batch_sizex = tf.placeholder(dtype=tf.float32,shape=[None,28*28],name="input")#定义标签数据,mnist共10类y_ = tf.placeholder(dtype=tf.float32,shape=[None,10],name="y_")#将数据调整为二维数据,w*H*c---> 28*28*1,-1表示N张image = tf.reshape(x,shape=[-1,28,28,1])#第一层,卷积核={5*5*1*32},池化核={2*2*1,1*2*2*1}w1 = tf.Variable(initial_value=tf.random_normal(shape=[5,5,1,32],stddev=0.1,dtype=tf.float32,name="w1"))b1= tf.Variable(initial_value=tf.zeros(shape=[32]))conv1 = tf.nn.conv2d(input=image,filter=w1,strides=[1,1,1,1],padding="SAME",name="conv1")relu1 = tf.nn.relu(tf.nn.bias_add(conv1,b1),name="relu1")pool1 = tf.nn.max_pool(value=relu1,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")#shape={None,14,14,32}#第二层,卷积核={5*5*32*64},池化核={2*2*1,1*2*2*1}w2 = tf.Variable(initial_value=tf.random_normal(shape=[5,5,32,64],stddev=0.1,dtype=tf.float32,name="w2"))b2 = tf.Variable(initial_value=tf.zeros(shape=[64]))conv2 = tf.nn.conv2d(input=pool1,filter=w2,strides=[1,1,1,1],padding="SAME")relu2 = tf.nn.relu(tf.nn.bias_add(conv2,b2),name="relu2")pool2 = tf.nn.max_pool(value=relu2,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME",name="pool2")#shape={None,7,7,64}#FC1w3 = tf.Variable(initial_value=tf.random_normal(shape=[7*7*64,1024],stddev=0.1,dtype=tf.float32,name="w3"))b3 = tf.Variable(initial_value=tf.zeros(shape=[1024]))#关键,进行reshapeinput3 = tf.reshape(pool2,shape=[-1,7*7*64],name="input3")fc1 = tf.nn.relu(tf.nn.bias_add(value=tf.matmul(input3,w3),bias=b3),name="fc1")#shape={None,1024}#FC2w4 = tf.Variable(initial_value=tf.random_normal(shape=[1024,10],stddev=0.1,dtype=tf.float32,name="w4"))b4 = tf.Variable(initial_value=tf.zeros(shape=[10]))fc2 = tf.nn.bias_add(value=tf.matmul(fc1,w4),bias=b4)#shape={None,10}#定义交叉熵损失# 使用softmax将NN计算输出值表示为概率y = tf.nn.softmax(fc2,name="out")# 定义交叉熵损失函数cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=fc2,labels=y_)loss = tf.reduce_mean(cross_entropy)#定义solvertrain = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss)#定义正确值,判断二者下标index是否相等correct_predict = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))#定义如何计算准确率accuracy = tf.reduce_mean(tf.cast(correct_predict,dtype=tf.float32),name="accuracy")#定义初始化opinit = tf.global_variables_initializer()#训练NNwith tf.Session() as session:    session.run(fetches=init)    for i in range(0,1000):        xs, ys = mnist.train.next_batch(100)        session.run(fetches=train,feed_dict={x:xs,y_:ys})        if i%100 == 0:            train_accuracy = session.run(fetches=accuracy,feed_dict={x:xs,y_:ys})            print(i,"accuracy=",train_accuracy)    #训练完成后,将网络中的权值转化为常量,形成常量graph    constant_graph = graph_util.convert_variables_to_constants(sess=session,                                                            input_graph_def=session.graph_def,                                                            output_node_names=['out'])    #将带权值的graph序列化,写成pb文件存储起来    with tf.gfile.FastGFile("lenet.pb", mode='wb') as f:        f.write(constant_graph.SerializeToString())

【将pb文件部署到TensorFlow中并做inference】

import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataimport numpy as npmnist = input_data.read_data_sets(train_dir=r"E:\mnist_data",one_hot=True)pb_path = r"lenet.pb"#导入pb文件到graph中with tf.gfile.FastGFile(pb_path,'rb') as f:    # 复制定义好的计算图到新的图中,先创建一个空的图.    graph_def = tf.GraphDef()    # 加载proto-buf中的模型    graph_def.ParseFromString(f.read())    # 最后复制pre-def图的到默认图中.    _ = tf.import_graph_def(graph_def, name='')with tf.Session() as session:    #获取输入tensor    input = tf.get_default_graph().get_tensor_by_name("input:0")    #获取预测tensor    output = tf.get_default_graph().get_tensor_by_name("out:0")    #取第100张图片测试    one_image = np.reshape(mnist.test.images[100], [-1, 784])    #将测试图片传入nn中,做inference    out = session.run(output,feed_dict={input:one_image})    pre_label = np.argmax(out,1)    print("pre_label=",pre_label)    print('true label:', np.argmax(mnist.test.labels[100],0))

测试结果如下图:

144014_NVvM_3800567.png

 

转载于:https://my.oschina.net/u/3800567/blog/1637829

你可能感兴趣的文章
词法分析
查看>>
安装laravel框架
查看>>
Linux 目录结构
查看>>
第二次实验
查看>>
R中,求五数,最小值、下四分位数、中位数、上四分位数、最大值
查看>>
【python-Day3】
查看>>
接上一篇——上海有哪些值得加入的互联网公司
查看>>
VFS相关内容
查看>>
【转载】同步和互斥的POSIX支持(互斥锁,条件变量,自旋锁)
查看>>
+load和+initialize的区别
查看>>
hdu 1319 Prime Cuts
查看>>
Effective_STL 学习笔记(二十四) 当关乎效率时应该在 map::operator[] 和 map-insert 之间仔细选择...
查看>>
Linux课程---7、shell技巧(获取帮助命令)
查看>>
写一个类似淘宝的ios app需要用到哪些技术?
查看>>
#505. 「LibreOJ β Round」ZQC 的游戏
查看>>
#iOS问题记录# UITextview富文本链接,禁止长按事件
查看>>
深度网络实现手写体识别
查看>>
Python Module_subprocess_调用 Powershell
查看>>
MVC原理图解
查看>>
c基础
查看>>