如何解决随机梯度下降算法的值误差
我无法运行我的 SGD 代码,我不知道问题出在哪里。如果你能帮助我,那就太好了。这是我的代码:
class logistic_regression:
def __init__(self,X,y_actual,alpha,max_iter,batch_size):
self.X = X
self.y_actual = y_actual
self.alpha = alpha
self.max_iter = max_iter
self.batch_size = batch_size
def sigmoid(self,z):
return 1/(1+np.exp(-z))
def predictor(self,theta,X):
predictions = np.matmul(X,theta)
sigmoidal_prediction = self.sigmoid(predictions)
return(sigmoidal_prediction)
def loss(self,h,y):
h = h + 1e-9
h = np.array(h,dtype=np.complex128)
y = np.array(y,dtype=np.complex128)
h = h.flatten()
y = y.flatten()
return (-((y*np.log(h))-((1-y)*np.log(1-h)))).mean()
def stochastic_gradient_descent(self):
X1 = np.matrix(sm.add_constant(self.X))
m,n = X1.shape
y_actual = self.y_actual.to_numpy().reshape(m,1)
Xy = np.c_[X1,y_actual]
# Initializing the random number generator
rng = np.random.default_rng(seed=123)
theta = np.ones((n,1))
predictions = None
for i in range(0,self.max_iter):
rng.shuffle(Xy) # Shuffle X and y
# Performing minibatch moves
for i in range(self.batch_size):
j = i + self.batch_size
X_batch,y_batch = Xy[i:j,:-1],Xy[i:j,-1:]
predictions = self.predictor(theta,X_batch)
gradient = np.matmul(np.transpose(X_batch),(predictions-y_batch))/self.batch_size
theta = theta - self.alpha*gradient
f1 = metrics.f1_score(y_actual,np.around(predictions),labels=None,pos_label=1,average='binary',sample_weight=None)
ceo = self.loss(predictions,y_actual)
print("\nCross Entropy: %f" % (ceo),"\nAlpha = %s" % self.alpha,"\nIterations: %s" % self.max_iter,"\nF1 Score: ",f1)
return(theta)
def classifier(self,threshold=0.5):
X1 = np.matrix(sm.add_constant(self.X))
theta = self.stochastic_gradient_descent()
return [1 if i >= threshold else 0 for i in self.predictor(theta,X1)]
我调用这个函数:
log_reg = logistic_regression(X_test_std,y_test,0.01,100,2)
print(log_reg.classifier())
但是出现了值错误:
ValueError: 发现输入变量的样本数量不一致:[1151,2]
尺寸问题位于 f1
中的 ceo
和 def stochastic_gradient_descent(self)
。但我不知道如何解决这个问题。你能给我一些提示吗?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。