土拨鼠

查看: 26|回复: 0

tf.gradients()求导错误

[复制链接]

19

主题

22

帖子

206

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
206
发表于 2019-12-12 20:36:17 | 显示全部楼层 |阅读模式
最近在使用tf.gradients()进行求导时发现当使用tf.Variable() 会出现求导错误的情况。
例如以下代码:
  1. import tensorflow as tf

  2. w1 = tf.Variable([[1,2]])
  3. w2 = tf.Variable([[3,4]])

  4. res = tf.matmul(w1, [[2],[1]])

  5. grads = tf.gradients(res,[w1])

  6. with tf.Session() as sess:
  7.     tf.global_variables_initializer().run()
  8.     re = sess.run(grads)
  9.     print(re)
  10. #  [array([[2, 1]], dtype=int32)]
复制代码
此时grads的值为None,当执行re = sess.run(grads)会直接报错:TypeError: Fetch argument None has invalid type <class 'NoneType'>


但是当你修改一下代码:
  1. import tensorflow as tf

  2. w1 = tf.Variable([[1.0,2]])
  3. w2 = tf.Variable([[3,4]])

  4. res = tf.matmul(w1, [[2],[1]])

  5. grads = tf.gradients(res,[w1])

  6. with tf.Session() as sess:
  7. tf.global_variables_initializer().run()
  8. re = sess.run(grads)
  9. print(re)
  10. # [array([[2.0, 1.0]], dtype=float32)]
复制代码
在这里我将w1 = tf.Variable([[1,2]])修改为w1 = tf.Variable([[1.0,2]]) 使w1变为了浮点类型,此时再运行将不会报错
ps:w1 = tf.Variable([[1,2]])修改为w1 = tf.Variable([[1,2]],dtype=tf.flaot32)或w1 = tf.Variable([[1,2]],dtype=tf.flaot64)   只要是浮点类型均可。



TensorFlow1.9.0-1.15.0版本好像均存在这个问题
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表