1. 程式人生 > >tensorflow條件語句-tf.case

tensorflow條件語句-tf.case

tf.case
tf.case(
    pred_fn_pairs,
    default=None,
    exclusive=False,
    strict=False,
    name='case'
)

建立case操作

pred_fn_pairs引數是大小為N的字典或pairs的列表。每對包含一個布林標量tensor和一個python可呼叫函式項,當條件為True將返回對應的函式項建立的tensors。在pred_fn_pairs對中的所有呼叫子以及預設值(如果提供的話)都應該返回相同數量和型別的張量。

如果exclusive==True,則計算所有的謂詞,如果多個謂詞計算為True,則引發異常。如果exclusive==False,則執行在求值為True的第一個謂詞處停止,並且立即返回由相應函式生成的張量(tensors)。如果沒有一個謂詞評估為true,則此操作返回預設生成的張量。

tf.case支援在tensorflow.python.util.nest中實現的巢狀結構。所有的呼叫都必須返回相同的(可能巢狀的)列表、元組和/或命名元組的值結構。單例列表和元組形成了唯一的例外:當由可呼叫程式返回時,它們被隱式地解包為單個值。通過通過傳遞strict=True禁用此行為。

如果無序字典在pred_fn_pairs使用,則條件測試的順序不能保證。不管怎麼樣,順序保證是確定的,以便在條件分支中變數按固定順序被建立。

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 27 11:16:32 2018
@author: myhaspl
"""

import tensorflow as tf
x = tf.constant(7)
y =  tf.constant(27)
f1 = lambda: tf.constant(17)
f2 = lambda: tf.constant(23)
r = tf.case([(tf.less(x, y), f1)], default=f2)
#if (x < y) return 17;
#else return 23;
sess=tf.Session()
with sess:

    print sess.run(r) 

17

Example 2:

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 27 11:16:32 2018
@author: myhaspl
"""

import tensorflow as tf
x = tf.constant(7)
y =  tf.constant(27)
z = tf.constant(21)
def f1(): 
    return tf.constant(17)
def f2(): 
    return tf.constant(23)
def f3(): 
    return tf.constant(-1)
r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},default=f3, exclusive=True);

sess=tf.Session()
with sess:

    print sess.run(r)

相當於:

if (x < y && x > z) raise OpError("Only one predicate may evaluate true");
if (x < y) return 17;
else if (x > z) return 23;
else return -1;

引數:

pred_fn_pairs: 字典或pairs的列表(由boolean標量和可呼叫函式返回張量列表)
default: 預設返回tensors列表
exclusive: 如果為True表示僅最多一個謂語為True 
strict: boolean開啟或關閉'strict' 模式
name: 操作的名字(可選)
返回:

第一個謂詞為True時執行返回的tensors,如果沒有謂詞為True,則返回default

Raises:

TypeError: pred_fn_pairs 不是一個列表或字典
TypeError: pred_fn_pairs是一個列表,但不包括2個元素的元組 
TypeError: fns[i] 不是任何i的呼叫,或default不是可呼叫的