如何解决python - lambda 表达式中未使用参数的解释
问题
请帮助理解这个 lambda 定义是什么,其中参数 w
没有出现在表达式部分。
loss_w = lambda w: loss(x,t) # <---- parameter w is not used in the expression
gradient(loss_w,W)
loss_w 被称为 f(arg)
,只有一个参数,而 loss
有两个参数。
def gradient(f,arg):
...
fh2: float = f(arg) # <--- how come 'loss' can take one argument as f(W)?
据我所知,该参数将用于表达式中,例如lambda w: w**2
。
lambda_expr ::= "lambda" [parameter_list] ":" expression
lambda_expr_nocond ::= "lambda" [parameter_list] ":" expression_nocond
Lambda expressions (sometimes called lambda forms) are used to create anonymous functions.
The expression lambda parameters: expression yields a function object. The unnamed object
behaves like a function object defined with:
def <lambda>(parameters):
return expression
代码
- cs231 Gradient Checks
import numpy as np
W = 0.01 * np.random.randn(2,3)
def relu(x):
return np.maximum(0,x)
def loss(x,t): # <---------- 'loss' function
global W
a = relu(np.matmul(X,W.T))
a:float = a - np.max(a,axis=-1,keepdims=True)
p:float = np.exp(a) / np.sum(np.exp(a),keepdims=True)
batch_size = p.shape[0]
return -np.sum(np.log(p[np.arange(batch_size),t] + 1e-7)) / batch_size
def gradient(f,arg): # <---------- 'loss' function is passed as f
h:float = 1e-4 # 0.0001
grad = np.zeros_like(arg,dtype=float)
it = np.nditer(arg,flags=['multi_index'],op_flags=['readwrite'])
while not it.finished:
idx = it.multi_index
tmp_val = arg[idx]
# f(x+h)
arg[idx] = tmp_val + h
fh1: float = f(arg) # <--- why loss(x,t) can only take one argument?
# f(x-h)
arg[idx] = tmp_val - h
fh2: float = f(arg)
grad[idx] = (fh1 - fh2) / (2*h)
arg[idx] = tmp_val
it.iternext()
return grad
def numerical_gradient(x,t):
t = t.reshape(1,t.size) if t.ndim == 1 else t
x = x.reshape(1,y.size) if x.ndim == 1 else x
loss_w = lambda w: loss(x,t) # <----- What is this w?
global W
return gradient(loss_w,W)
原始代码是实现 cs231 的 two_layer_net.py。
更新
现在我确信 W
是一个冗余参数。 f(x)
的原始意图是明确的,因此尝试将权重参数 W
作为 f(W)
传递。或许应该是lambda w: loss(w,t)
。
解决方法
通常,您可以编写一个根本不使用其参数的 lambda 函数。例如:
>>> a = lambda x,y:1+2
>>> a(1,2)
3
>>> a(None,4)
3
>>> a()
Traceback (most recent call last):
File "<stdin>",line 1,in <module>
TypeError: <lambda>() missing 2 required positional arguments: 'x' and 'y'
在您的问题中,您问过“为什么 loss(x,t) 只能采用一个参数?”。因为 lambda 函数只接受一个参数,甚至不使用它,而是将其他参数提供给 loss
方法。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。