如何为 softmax

如何解决如何为 softmax

为了理解 JAX 的反向模式自动差异,我尝试为 softmax 编写一个 custom_vjp,如下所示:

import jax
import jax.numpy as jnp
import numpy as np

@jax.custom_vjp
def stablesoftmax(x):
    print(f"input: {x} shape: {x.shape}")
    expc = jnp.exp(x - jnp.amax(x))
    return expc / jnp.sum(expc)

def ssm_fwd(x):
    s = stablesoftmax(x)
    return s,s

def ssm_bwd(acts,d_dacts):
    dacts_dinput = jnp.diag(acts) - jnp.outer(acts,acts)  # Jacobian
    d_dinput = jnp.dot(d_dacts,dacts_dinput)  # Vector-Jacobian product
    print(f"Saved activations:\n{acts} shape: {acts.shape}")
    print(f"d/d_acts:\n{d_dacts} shape: {d_dacts.shape}")
    print(f"d_acts/d_input (Jacobian of softmax):\n{dacts_dinput} shape: {dacts_dinput.shape}")
    print(f"d/d_input:\n{d_dinput} shape: {d_dinput.shape}")
    return d_dinput

stablesoftmax.defvjp(ssm_fwd,ssm_bwd)

print(f"JAX version: {jax.__version__}")
y = np.array([1.,2.,3.])
a = stablesoftmax(y)
softmax_jac_fun = jax.jacrev(stablesoftmax)
dsoftmax_dy = softmax_jac_fun(y)
print(f"softmax Jacobian: {dsoftmax_dy}")

但是当我调用 jacrev 时,我得到关于 VJP 结果的结构与 softmax 的输入结构不匹配的错误

JAX version: 0.2.13
input: [1. 2. 3.] shape: (3,)
WARNING:absl:No GPU/TPU found,falling back to cpu. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
input: [1. 2. 3.] shape: (3,)
Saved activations:
[0.09003057 0.24472848 0.66524094] shape: (3,)
d/d_acts:
Traced<ShapedArray(float32[3])>with<BatchTrace(level=1/0)>
  with val = array([[1.,0.,0.],[0.,1.,1.]],dtype=float32)
       batch_dim = 0 shape: (3,)
d_acts/d_input (Jacobian of softmax):
[[ 0.08192507 -0.02203305 -0.05989202]
 [-0.02203305  0.18483645 -0.1628034 ]
 [-0.05989202 -0.1628034   0.22269544]] shape: (3,3)
d/d_input:
Traced<ShapedArray(float32[3])>with<BatchTrace(level=1/0)>
  with val = DeviceArray([[ 0.08192507,-0.02203305,-0.05989202],[-0.02203305,0.18483645,-0.1628034 ],[-0.05989202,-0.1628034,0.22269544]],)
Traceback (most recent call last):
  File "analysis/vjp_test.py",line 30,in <module>
    dsoftmax_dy = softmax_jac_fun(y)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: Custom VJP rule must produce an output with the same container (pytree) structure as the args tuple of the primal function,and in particular must produce a tuple of length equal to the number of arguments to the primal function,but got VJP output structure PyTreeDef(*) for primal input structure PyTreeDef((*,)).

但是当我打印形状时您可以看到它们都有形状 (3,) 但 JAX 似乎不同意? (实际上,输入和输出是 3 x 3 矩阵,但这是因为 JAX 试图对 jacrev 中的 JVP 进行 vmap,因此一次性拉回 R(3) 的整个基础(即 3x3 单位矩阵)。

注意:如果我直接使用 jax.grad 或 jax.vjp,我会得到同样的错误

解决方法

根据custom_vjp docs

bwd 的输出必须是长度等于原始函数参数数量的元组

所以反向传递中的 return 语句应该是这样的:

def ssm_bwd(acts,d_dacts):
    ...
    return (d_dinput,)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?