如何解决Python - 不同的常规/分析函数
为了执行导数,我开发了以下代码:
import matplotlib.pyplot as plt
import numpy as np
from math import *
xi = jnp.linspace(-3,3)
def f(x):
a = x**3+5
return a
g1i = jax.vmap(jax.grad(f))(xi)
g2i = jax.vmap(jax.grad(jax.grad(f)))(xi)
g3i = jax.vmap(jax.grad(jax.grad(jax.grad(f))))(xi)
plt.plot(xi,yi,label = "f")
plt.plot(xi,g1i,label = "f'")
plt.plot(xi,g2i,label = "f''")
plt.plot(xi,g3i,label = "f'''")
plt.legend()
这段代码有效,但现在我有兴趣应用以下代码来计算看涨价格的一阶导数,相对于基础资产(即 delta),尝试使用以下代码,但它没有作品:
import scipy.stats as si
import sympy as sy
import sys
xi = jnp.linspace(1,1.5)
def analytical_call(s0):
T=1.
q=0.
r=0.
k=1.
sigma=0.4
Kt = k*exp((q-r)*T)
d = (log(Kt/s0)+(sigma**2)/2*T)/sigma
result = (Kt * si.norm.cdf((d / sqrt(T)),0.0,1.0) - s0 * si.norm.cdf(((d - sigma * T) / sqrt(T)),1.0) ) * exp(-q * T) + exp(-q * T) * (s0 - Kt)
return result
print(analytical_call(1))
g1i = jax.vmap(jax.grad(analytical_call))(xi)
g2i = jax.vmap(jax.grad(jax.grad(analytical_call)))(xi)
plt.plot(xi,label = "f'")
plt.legend()
你有什么提示吗?提前致谢!
解决方法
正如评论中已经提到的,您不能使用 jax 库之外的方法,例如 scipy.stats.norm.cdf
。请改用 jax.scipy.stats
。同样,将 exp
和 sqrt
替换为它们的 jax 等价物 jnp.exp
和 jnp.sqrt
:
from jax import jit,grad,vmap
import jax.numpy as jnp
from jax.scipy.stats.norm import cdf
def analytical_call(s0):
T,q,r,k,sigma = 1.0,0.0,1.0,0.4
Kt = k*jnp.exp((q-r)*T)
d = (jnp.log(Kt/s0)+(sigma**2)/2*T)/sigma
result = (Kt * cdf((d / jnp.sqrt(T)),1.0) - s0 * cdf(((d - sigma * T) / jnp.sqrt(T)),1.0) ) * jnp.exp(-q * T) + jnp.exp(-q * T) * (s0 - Kt)
return result
g = vmap(grad(analytical_call))
h = vmap(grad(grad(analytical_call)))
xi = jnp.linspace(1,1.5)
然后,您可以计算 g(xi)
和 h(xi)
。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。