1. 程式人生 > >tensorflow儲存載入多個模型

tensorflow儲存載入多個模型

#儲存載入過個模型時要注意必須指定Graph
class MLP(object):
    def __init__(self, id):
        if not os.path.exists('./' + id):
            os.makedirs('./' + id)
        self.id = id

        self.graph = tf.Graph()
        self.session_conf = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False)
        self.load_model()

    def init_net(self):
        # Placeholders for input, output and dropout
        self.input_x = tf.placeholder(tf.float32, [None, 1], name="input_x")
        self.input_y = tf.placeholder(tf.float32, [None, 1], name="input_y")

        with tf.name_scope('mlp1'):
            W = tf.Variable(tf.truncated_normal([1,50], stddev=0.1), name="W")
            b = tf.Variable(tf.constant(0.1, shape=[50]), name="b")
            self.mlp1 = tf.nn.xw_plus_b(self.input_x, W, b, name="xwb")

        with tf.name_scope('mlp2'):
            W1 = tf.Variable(tf.truncated_normal([50,1], stddev=0.1), name="W1")
            b1 = tf.Variable(tf.constant(0.1, shape=[1]), name="b1")
            self.mlp1 = tf.nn.xw_plus_b(self.mlp1, W1, b1, name="xwb1")
            self.prediction = tf.nn.sigmoid(self.mlp1)

        with tf.name_scope("loss"):
            losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.mlp1, labels=self.input_y)
            self.loss = tf.reduce_mean(losses)

        with tf.name_scope("optimizer"):
            self.global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(self.loss)
            self.train_op = optimizer.apply_gradients(grads_and_vars, global_step=self.global_step)

    def load_model(self):
        with self.graph.as_default():
            self.sess = tf.Session(graph=self.graph, config=self.session_conf)
            if os.path.exists('./' + self.id + '/model.meta'):
                self.init_net()
                self.saver = tf.train.Saver()
                self.saver.restore(self.sess, tf.train.latest_checkpoint('./' + self.id))
            else:
                self.init_net()
                self.sess.run(tf.global_variables_initializer())
                self.saver = tf.train.Saver()


    def train(self):
        print 'traning'
        with self.sess.as_default():
            for i in range(1000):
                x, y = generate_data(1000,self.id)
                loss,_ = self.sess.run([self.loss,self.train_op],feed_dict={self.input_x:x,self.input_y:y})
                x_test,y_test = generate_data(100,self.id)
                prediction = self.sess.run(self.prediction, feed_dict={self.input_x: x_test, self.input_y: y_test})
                acc = self.get_acc(prediction,y_test)
                print 'step:',i,'loss:',loss,'acc:',acc
            self.saver.save(self.sess, './' + self.id + '/model')

    def test(self):
        print 'testing'
        with self.sess.as_default():
            x_test, y_test = generate_data(1000,self.id)
            prediction = self.sess.run(self.prediction, feed_dict={self.input_x: x_test, self.input_y: y_test})
            acc = self.get_acc(prediction, y_test)
            print 'acc:', acc

相關推薦

tensorflow儲存載入模型

#儲存載入過個模型時要注意必須指定Graphclass MLP(object): def __init__(self, id): if not os.path.exists('./' + id): os.makedirs('.

Tensorflow 同時載入模型

這涉及到詞向量,具體看可以參考這篇文章:Word2vec 之 Skip-Gram 模型,下面只進行簡單的描述, 上圖的流程是把文章的單詞使用詞向量來表示。 (1)提取文章所有的單詞,把其按其出現的次數降許(這裡只取前5000個),比如單詞‘network’出現的次數最多,編號ID為0,依次類推… (2)

TensorFlow 載入模型的方法

採用 TensorFlow 的時候,有時候我們需要載入的不止是一個模型,那麼如何載入多個模型呢? 原文:bretahajek.com/2017/04/imp… 關於 TensorFlow 可以有很多東西可以說。但這次我只介紹如何匯入訓練好的模型(圖),因為我做不到匯入第二個模型並將它和第一個模型

TensorFlow學習系列(三):儲存/恢復和混合模型

這篇教程是翻譯Morgan寫的TensorFlow教程,作者已經授權翻譯,這是原文。 目錄 在學習這篇部落格之前,我希望你已經掌握了Tensorflow基本的操作。如果沒有,你可以閱讀這篇入門文章。 為什麼要

TensorFlow儲存/恢復和混合模型

這篇教程是翻譯Morgan寫的TensorFlow教程,作者已經授權翻譯,這是原文 目錄 在學習這篇部落格之前,我希望你已經掌握了Tensorflow基本的操作。如果沒有,你可以閱讀這篇入門文章。 為什麼要學習模型的儲存和恢復呢?因為這對於避免資料的混亂無序是至關重要的,特別是在你程式碼中的不同圖。

tensorflow 使用塊GPU同時訓練模型

轉自:http://stackoverflow.com/questions/34775522/tensorflow-mutiple-sessions-with-mutiple-gpus TensorFlow will attempt to use (an equal fraction of the me

Asp.net Mvc action返回模型實體給view

姓名 query ont info erb users box html asp 1、controller中action代碼: public class HomeController : Controller { public ActionResu

keras實現模型融合(非keras自帶模型,這裡以3自己的模型為例)

該程式碼可以實現類似圖片的效果,多個模型採用第一個輸入。 圖片來源:https://github.com/keras-team/keras/issues/4205   step 1:重新定義模型(這是我自己的模型,你們可以用你們自己的),與預訓練不一樣,這裡定義模型inp

Pycharm載入專案

Pycharm中的載入多個專案 使用Pycharm,總會建立幾個專案檔案,有時候又不想全部一個一個的開啟,所以這時候需要一個專案共存的方法,現在說一下怎麼專案共存。 中英文對照 英文:首先開啟setting介面: 中文首先開啟設定介面 然後就是選擇Project,下面

three.js 合併模型

      方法一:THREE.Geometry.merge()合併多個模型為一個 關鍵點:通過THREE.Geometry.merge()函式,你可以將多個幾何體合併起來建立一個聯合體 參考部落格:63 Three.js 將多個網格合併成一個網格

tensorflow enqueue_many傳入值的列表傳入異常問題————Shape () must have rank at least 1

tf 的佇列操作enqueue_many傳入的值是列表,但是放入[]列表拋異常 File "C:\Users\lihongjie\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\data_fl

在業務控制方法中寫入User,Admin模型收集參數

code prot style servle EDA dmi simple register ping 1) 可以在業務控制方法中書寫1個模型來收集客戶端的參數 2) 模型中的屬性名必須和客戶端參數名一一對應 3) 這裏說的模型不是Model對象,Model是向視

react native載入jsbundle(assets和其他目錄)

在使用ReactInstanceManager.Builder構建一個ReactInstanceManager例項的時候只能傳入一個bundle,setBundleAssetName和setJSBundleFile分別對應從assets和從一個檔案路徑載入Bundle。有時需要將業務程式碼和通用

tensorflow隨筆-讀取檔案

#!/usr/bin/env python2 # -*- coding: utf-8 -*- """ Created on Sat Sep 15 10:54:53 2018 @author: myhaspl @email:[email protect

【ARToolkit】小發現:可以在一個patt裡面畫模型

       無意中的一個小發現,我在draw函式裡面更改模型的平移,旋轉,虛擬物體型別的時候,本來是把茶壺函式   glutSolidTeapot( 50.0 )。更改為正方體 glutSolidCube(50.0) 的時候,忘記把茶壺函式註釋掉,然後就直接執行了,結果發現

Pycharm中的載入專案

使用Pycharm,總會建立幾個專案檔案,有時候又不想全部一個一個的開啟,所以這時候需要一個專案共存的方法,現在說一下怎麼專案共存。 中英文對照 英文:首先開啟setting介面: 中文首先開啟設定介面 然後就是選擇Project,下面的Project S

【Spring】例項化上下文物件及載入配置檔案

一、例項化上下文物件 ApplicationContext ctx = new ClassPathXmlApplicationContext("applicationContext.xml"); Car car = (Car) ctx.getBean("car");

tensorflow 批量讀取csv檔案

tensorflow 批量讀取多個csv檔案 #!/usr/bin/python # -*- coding:utf-8 -*- import tensorflow as tf import os def csvfile(fileist): file_queue=tf.train.str

使用nginx載入tomcat實現session共享(負載均衡)

需要用到:   memcached 官網:http://memcached.org/ 用memcached實現session共享  tomcat叢集     以我的為例,我用的tomcat版本是 apache-tomcat-7.0.68 現將tomcat資料夾複

iOS【完美解決SDWebImage載入圖片記憶體崩潰的問題】

SDWebImage大家肯定都恨熟悉了,國內外太多的App使用其進行圖片載入。但是最近在使用過程中發現,我用SDWebImage載入多個圖片,類似微博動態那種,在載入的過程中。我發現當圖片解析度比較大的時候(不是圖片大),載入幾張圖片就崩潰了。網上說可以每次載入圖片清空mem