Reference code:
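```python
import tensorflow as tf

a = tf.random.uniform((2, 2))      # random floats in [0, 1)
b = tf.zeros((2, 2))
mask = tf.constant([[True, False], [False, True]])
c = tf.where(mask, a, b)           # take a where mask is True, otherwise b
print(c)
# Example output (the values from tf.random.uniform differ on every run):
# tf.Tensor(
# [[0.9454906 0.       ]
#  [0.        0.8972126]], shape=(2, 2), dtype=float32)
```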
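The three-argument form of `tf.where` is an element-wise select: each output element comes from `a` where the matching `mask` element is `True` and from `b` where it is `False`. The pure-Python sketch below spells out that rule for 2-D nested lists so the result can be checked by hand. `where_2d` is a hypothetical helper written for illustration, not a TensorFlow API; it assumes all three inputs share the same shape and does no broadcasting. Fixed values replace the random ones so the output is predictable.

```python
from typing import List, Sequence


def where_2d(mask: Sequence[Sequence[bool]],
             x: Sequence[Sequence[float]],
             y: Sequence[Sequence[float]]) -> List[List[float]]:
    """Hypothetical helper: pick x[i][j] where mask[i][j] is True, else y[i][j].

    Illustrates the element-wise selection rule only; assumes all three
    inputs have the same 2-D shape (no broadcasting).
    """
    if not (len(mask) == len(x) == len(y)):
        raise ValueError("row counts differ")
    result = []
    for m_row, x_row, y_row in zip(mask, x, y):
        if not (len(m_row) == len(x_row) == len(y_row)):
            raise ValueError("column counts differ")
        # Select from x or y element by element according to the mask.
        result.append([xv if m else yv for m, xv, yv in zip(m_row, x_row, y_row)])
    return result


# Same mask as the TensorFlow snippet, with fixed values instead of random ones.
a = [[0.9, 0.2], [0.4, 0.8]]
b = [[0.0, 0.0], [0.0, 0.0]]
mask = [[True, False], [False, True]]
print(where_2d(mask, a, b))  # [[0.9, 0.0], [0.0, 0.8]]
```

Only the diagonal entries of `a` survive because those are the positions where `mask` is `True`, which matches the pattern in the TensorFlow output.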