如何解决如何从 JAX 中的数据集读取数据
我是 JAX 的新手。我在下面有这个代码,它有“特征矩阵”作为一个数组,“目标向量”作为一个数组。但我不希望程序读取这些数据数组。这些数组已经在代码中了。我想修改代码,以便它读取我导入的波士顿房价数据集。有人能告诉我我需要对这段代码做哪些修改才能使线性回归工作吗?
import jax.numpy as np
from jax import grad,jit
from sklearn.datasets import load_boston
import sklearn.linear_model as sk
boston = load_boston()
X = np.array(boston.data)
y = np.array(boston.target)
def J(X,w,b,y):
"""Cost function for a linear regression. A forward pass of our model.
Args:
X: a features matrix.
w: weights (a column vector).
b: a bias.
y: a target vector.
Returns:
scalar: a cost of this solution.
"""
y_hat = X.dot(w) + b # Predict values.
return ((y_hat - y)**2).mean() # Return cost.
# A features matrix.
X = np.array([
[4.,7.],[1.,8.],[-5.,-6.],[3.,-1.],[0.,9.]
])
# A target column vector.
y = np.array([
[37.],[24.],[-34.],[16.],[21.]
])
learning_rate = 0.01
w = np.zeros((2,1))
b = 0.
%timeit grad(J,argnums=1)(X,y)
%timeit grad(J,argnums=2)(X,y)
for i in range(100):
w -= learning_rate * grad(J,y)
b -= learning_rate * grad(J,y)
if i % 10 == 0:
print(J(X,y))
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。