微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

为什么当我的函数 jax.grad 中有 np.power 时不能给我导数?

如何解决为什么当我的函数 jax.grad 中有 np.power 时不能给我导数?

我想训练一个简单的线性模型。 x 和 y 下面的这些是我的数据。

import numpy as np
x = np.linspace(0,1,100)
y = 2 * x + 3 + np.random.randn(100)

f 是计算所有数据均方误差的函数

def f(params,x,y):
  return np.mean(np.power((params['w'] * x + params['b'])-y,2))

from jax import grad
df = grad(f)
params = dict()
#initialize parameters
params['w'] = 2.4
params['b'] = 10.
df(params,y) # I will do this in a loop (implementing gradient decent part

这给了我一个错误

FilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray

当我清除 np.power 代码时有效。为什么?

解决方法

JAX 无法计算 numpy 函数的梯度,但可以计算 jax.numpy 函数的梯度。如果您根据 jax.numpy 重写代码,它应该适合您:

import numpy as np
x = np.linspace(0,1,100)
y = 2 * x + 3 + np.random.randn(100)

import jax.numpy as jnp
def f(params,x,y):
  return jnp.mean(jnp.power((params['w'] * x + params['b'])-y,2))

from jax import grad
df = grad(f)
params = dict()

params['w'] = 2.4
params['b'] = 10.
df(params,y)
# {'b': DeviceArray(14.661432,dtype=float32),#  'w': DeviceArray(7.3792152,dtype=float32)}

您可以在 TracerArrayConversionError documentation page 中阅读更多详细信息。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。