How do I use promotion rules in Julia?

I am trying to write a struct that computes gradients (following https://www.youtube.com/watch?v=rZS2LGiurKY). This is what I have so far:

struct GRAD{F <: Array{Float64,2},∇F <:Array{Float64,2}}
    f::F
    ∇f::∇F
end

begin 
    import Base: +,*,-,^,/,convert,promote_rule,size,reshape,promote
    # addition rule 
    +(x::GRAD,y::GRAD) = GRAD(x.f+y.f,x.∇f+y.∇f) 
    -(x::GRAD,y::GRAD) = GRAD(x.f-y.f,x.∇f-y.∇f) 

    # multiplying by scalar
    *(y::Real,x::GRAD) = GRAD(x.f.*y,x.∇f.*y)
    *(x::GRAD,y::Real) = *(y::Real,x::GRAD) 
    # product rule 
    *(x::GRAD,y::GRAD)  = GRAD(x.f.*y.f,x.f.*y.∇f+ x.∇f.*y.f)

    convert(::Type{GRAD},x::Array) = GRAD(x,zero(x))    
    size(x::GRAD) = size(x.f)   
    Base.promote_rule(::Type{GRAD{F,∇F}},x::Type{<:Array}) = GRAD # bug is here!! 
end
A = rand(5,5)
r = rand(5,1)
b = rand(5,1)
g = GRAD(r,zeros(5,1) + [1 for i=1:5])

I want to compute the gradient of A*g (which should be A*ones()), but when I do this I get:

> A*g
MethodError: no method matching *(::Array{Float64,2}, ::Main.workspace2861.GRAD{Array{Float64,2},Array{Float64,2}})
Closest candidates are:
*(::Any,::Any,!Matched::Any,!Matched::Any...) at operators.jl:538
*(!Matched::Real,::Main.workspace2861.GRAD) at /var/folders/2s/p1vy6rx91lsfh9ltgzz6j_lmb6r7gr/T/Unexpected invention.jl#==#c23631c4-0646-11eb-13be-3b5fa3514823:6
*(::Union{StridedArray{T,LinearAlgebra.Adjoint{var"#s828",var"#s827"} where var"#s827"<:Union{StridedArray{T,LinearAlgebra.LowerTriangular{T,S} where S<:AbstractArray{T,LinearAlgebra.UnitLowerTriangular{T,LinearAlgebra.UnitUpperTriangular{T,LinearAlgebra.UpperTriangular{T,2}} where var"#s828",LinearAlgebra.Transpose{var"#s826",var"#s825"} where var"#s825"<:Union{StridedArray{T,2}} where var"#s826",2}} where T,!Matched::LinearAlgebra.Adjoint{var"#s828",var"#s827"} where var"#s827"<:SparseArrays.AbstractSparseMatrixCSC where var"#s828") at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/SparseArrays/src/linalg.jl:147

However, convert(GRAD,A) * g gives the correct result.

What am I doing wrong?

Solution

I believe Base only falls back to promote for arguments that are subtypes of Number. For other types, you have to do the dispatch manually, e.g.

struct GRAD{F <: Array{Float64,2},∇F <:Array{Float64,2}}
    f::F
    ∇f::∇F
end

begin 
    import Base: +,*,-,^,/,convert,promote_rule,size,reshape,promote
    # addition rule 
    +(x::GRAD,y::GRAD) = GRAD(x.f+y.f,x.∇f+y.∇f) 
    -(x::GRAD,y::GRAD) = GRAD(x.f-y.f,x.∇f-y.∇f) 

    # multiplying by scalar
    *(y::Real,x::GRAD) = GRAD(x.f.*y,x.∇f.*y)
    *(x::GRAD,y::Real) = *(y::Real,x::GRAD) 
    # product rule 
    *(x::GRAD,y::GRAD)  = GRAD(x.f.*y.f,x.f.*y.∇f+ x.∇f.*y.f)
    *(x::GRAD,y::AbstractArray) = *(promote(x,y)...) #manually implement promotion
    *(x::AbstractArray,y::GRAD) = *(promote(x,y)...)

    convert(::Type{GRAD},x::Array) = GRAD(x,zero(x))    
    size(x::GRAD) = size(x.f)   
    Base.promote_rule(::Type{<:GRAD},x::Type{<:Array}) = GRAD #fixed
end
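
As an aside, the two forwarding methods above are needed because GRAD is not a Number. For comparison, here is a minimal sketch of the case where Base's fallback does apply. The scalar type `Dual` is hypothetical and only serves as an illustration; it is not part of the question's code. Because it subtypes Number, defining convert and promote_rule should be enough for mixed arithmetic with plain reals:

```julia
# Hypothetical scalar dual number (name and fields are assumptions for illustration).
struct Dual <: Number
    f::Float64    # value
    ∇f::Float64   # derivative
end

# Only the Dual * Dual product rule is written by hand.
Base.:*(x::Dual, y::Dual) = Dual(x.f * y.f, x.f * y.∇f + x.∇f * y.f)

# A plain real is a constant, so its derivative is zero.
Base.convert(::Type{Dual}, x::Real) = Dual(x, 0.0)
Base.promote_rule(::Type{Dual}, ::Type{<:Real}) = Dual

d = Dual(3.0, 1.0)

# No *(::Int, ::Dual) method exists; Base's generic *(::Number, ::Number)
# fallback promotes both arguments and then calls *(::Dual, ::Dual).
p = 2 * d
```

Here `p` is again a `Dual`. For a type that does not subtype Number, such as GRAD, no such generic fallback is reached, which is why the forwarding through promote has to be spelled out.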

A = rand(5,5)
r = rand(5,1)
b = rand(5,1)
g = GRAD(r,zeros(5,1) + [1 for i=1:5])
julia> A*g
GRAD{Array{Float64,2},Array{Float64,2}}([0.22826090714985026 0.3029960652217887 … 0.04569934008285687 0.3480034221401326; 0.2263393729468651 0.09785205038459334 … 0.2354369234901423 0.03963994636800826; … ; 0.2465774394414207 0.04853374224132803 … 0.1316815422172956 0.41189932434750665; 0.07773901558602414 0.3714828548333624 … 0.07235526901207193 0.38751984258803623],[0.46212899620837633 0.6134351660317792 … 0.09252127498998997 0.7045554762696247; 0.7634551330528128 0.33006033892034314 … 0.7941416705740725 0.13370771569513296; … ; 0.40923528629708694 0.08054962346187167 … 0.21854689444181385 0.6836137900806378; 0.16866950942083414 0.8060023710186879 … 0.15698845214696422 0.8407976515709032])

julia> (A*g).∇f
5×5 Array{Float64,2}:
 0.462129  0.613435   0.833935  0.0925213  0.704555
 0.763455  0.33006    0.354147  0.794142   0.133708
 0.774017  0.347564   0.255648  0.725451   0.629586
 0.409235  0.0805496  0.1764    0.218547   0.683614
 0.16867   0.806002   0.21655   0.156988   0.840798
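
Note that the product rule for two GRAD values uses the elementwise `.*`, so after promotion A*g is a broadcast product of a 5×5 and a 5×1 array, which is why the gradient printed above is 5×5. If the matrix product is what is wanted (so that the gradient comes out as A*ones(5,1)), one option is a dedicated method instead of promotion. The sketch below is an assumption about the intended behavior, not part of the original answer, and uses a renamed copy of the struct (`GRADM`, a hypothetical name) so it can be run on its own:

```julia
using LinearAlgebra  # matrix-matrix multiplication

# Renamed copy of the struct from the question (GRADM is a hypothetical name).
struct GRADM{F<:Array{Float64,2}, ∇F<:Array{Float64,2}}
    f::F
    ∇f::∇F
end

# A constant matrix acts linearly, so both the value and the
# gradient are multiplied by A from the left.
Base.:*(A::Matrix{Float64}, x::GRADM) = GRADM(A * x.f, A * x.∇f)

A = rand(5, 5)
g = GRADM(rand(5, 1), ones(5, 1))

h = A * g   # h.f is A * g.f and h.∇f is A * ones(5, 1), both 5×1
```

With this method no promote_rule is involved for A*g at all, since dispatch finds the (Matrix, GRADM) method directly.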
