tensorflow2.0 where函数使用

参考代码:

import tensorflow as tf

a = tf.random.uniform((2, 2))
b = tf.zeros((2, 2))
mask = tf.constant([[True, False], [False, True]])
c = tf.where(mask, a, b)
print(c)
# tf.Tensor(
# [[0.9454906 0.       ]
#  [0.        0.8972126]], shape=(2, 2), dtype=float32)


标签: 、面试
  • 回复
隐藏