微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Python - 不同的常规/分析函数

如何解决Python - 不同的常规/分析函数

为了执行导数,我开发了以下代码

import matplotlib.pyplot as plt
import numpy as np
from math import *

xi = jnp.linspace(-3,3)

def f(x):
  a = x**3+5
  return a


g1i = jax.vmap(jax.grad(f))(xi)
g2i = jax.vmap(jax.grad(jax.grad(f)))(xi)
g3i = jax.vmap(jax.grad(jax.grad(jax.grad(f))))(xi)
plt.plot(xi,yi,label = "f")
plt.plot(xi,g1i,label = "f'")
plt.plot(xi,g2i,label = "f''")
plt.plot(xi,g3i,label = "f'''")
plt.legend()

这段代码有效,但现在我有兴趣应用以下代码来计算看涨价格的一阶导数,相对于基础资产(即 delta),尝试使用以下代码,但它没有作品:

import scipy.stats as si
import sympy as sy
import sys
xi = jnp.linspace(1,1.5)
def analytical_call(s0):
    T=1.
    q=0.
    r=0.
    k=1.
    sigma=0.4
    Kt = k*exp((q-r)*T)
    d = (log(Kt/s0)+(sigma**2)/2*T)/sigma
    result = (Kt * si.norm.cdf((d / sqrt(T)),0.0,1.0)  - s0 * si.norm.cdf(((d - sigma * T) / sqrt(T)),1.0)  ) * exp(-q * T) + exp(-q * T) * (s0 - Kt)
    return result
print(analytical_call(1))

g1i = jax.vmap(jax.grad(analytical_call))(xi)
g2i = jax.vmap(jax.grad(jax.grad(analytical_call)))(xi)
plt.plot(xi,label = "f'")
plt.legend()

你有什么提示吗?提前致谢!

解决方法

正如评论中已经提到的,您不能使用 jax 库之外的方法,例如 scipy.stats.norm.cdf。请改用 jax.scipy.stats。同样,将 expsqrt 替换为它们的 jax 等价物 jnp.expjnp.sqrt

from jax import jit,grad,vmap
import jax.numpy as jnp
from jax.scipy.stats.norm import cdf

def analytical_call(s0):
    T,q,r,k,sigma = 1.0,0.0,1.0,0.4
    Kt = k*jnp.exp((q-r)*T)
    d = (jnp.log(Kt/s0)+(sigma**2)/2*T)/sigma
    result = (Kt * cdf((d / jnp.sqrt(T)),1.0)  - s0 * cdf(((d - sigma * T) / jnp.sqrt(T)),1.0)  ) * jnp.exp(-q * T) + jnp.exp(-q * T) * (s0 - Kt)
    return result

g = vmap(grad(analytical_call))
h = vmap(grad(grad(analytical_call)))
xi = jnp.linspace(1,1.5)

然后,您可以计算 g(xi)h(xi)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。