使用Jax的偏导数?

如何解决使用Jax的偏导数?

我对 Jax 文档感到困惑,这是我想要做的:

def line(m,x,b):
  return m*x + b

grad(line)(1,2,3)

错误

---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-48-d14b17620b30> in <module>()
      3 
----> 4 grad(line)(1,3)

FilteredStackTrace: TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating),but got int32. If you want to use integer-valued inputs,use vjp or set allow_int to True.

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred,unmodified.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
6 frames
/usr/local/lib/python3.7/dist-packages/jax/api.py in _check_input_dtype_revderiv(name,holomorphic,allow_int,x)
    844   elif not allow_int and not (dtypes.issubdtype(aval.dtype,np.floating) or
    845                               dtypes.issubdtype(aval.dtype,np.complexfloating)):
--> 846     raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype that "
    847                     "is a sub-dtype of np.floating or np.complexfloating),"
    848                     f"but got {aval.dtype.name}. If you want to use integer-valued "

TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating),use vjp or set allow_int to True.

我引用了官方的 tutorial 代码

import jax.numpy as jnp
from jax import grad,jit,vmap
from jax import random

key = random.PRNGKey(0)

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W,b,inputs):
    return sigmoid(jnp.dot(inputs,W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52,1.12,0.77],[0.88,-1.08,0.15],[0.52,0.06,-1.30],[0.74,-2.49,1.39]])
targets = jnp.array([True,True,False,True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W,b):
    preds = predict(W,inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key,W_key,b_key = random.split(key,3)
W = random.normal(W_key,(3,))
b = random.normal(b_key,())

W_grad = grad(loss,argnums=0)(W,b)
print('W_grad',W_grad)

结果:

W_grad [-0.16965576 -0.8774648  -1.4901345 ]

在这里做错了什么?我认为 key 正在以某种重要的方式被使用,但我不知道为什么/如何需要它。要回答这个问题,请根据需要调整第一个块中的代码以消除错误

解决方法

Jax 告​​诉你它不喜欢整数。 grad(line)(1.,2.,3.)(使用浮动)修复了问题。

,

我认为这里的错误很明显:

TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating),but got int32. If you want to use integer-valued inputs,use vjp or set allow_int to True.

要将 grad(line)(1,2,3)Int32 一起使用,请将其更改为 grad(line,allow_int=True)(1,3)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?