如何避免自定义 Julia 迭代器中的内存分配?
考虑以下 Julia“复合”迭代器:它合并两个迭代器 a 和 b(假设二者各自已按 order 排好序),生成单一的有序序列:
"""
    MergeSorted(a, b, order=Base.Order.Forward)

Lazy iterator that merges the iterables `a` and `b` into a single sequence.
Both inputs are assumed to be already sorted according to `order`; this is not
checked. Elements are converted to `T = promote_type(eltype(a), eltype(b))`.

When the heads of `a` and `b` compare equal under `order`, the element of `b`
is emitted first.
"""
struct MergeSorted{T,A,B,O}
    a::A
    b::B
    order::O

    # All four type parameters must be supplied to `new`; `T` is derived from
    # the element types of the two inputs.
    MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O} =
        new{promote_type(eltype(A), eltype(B)),A,B,O}(a, b, order)
end

# Match every concrete `MergeSorted` regardless of the trailing parameters.
Base.eltype(::Type{<:MergeSorted{T}}) where {T} = T

# No `length` method is defined, so tell `collect` and friends not to ask for one.
Base.IteratorSize(::Type{<:MergeSorted}) = Base.SizeUnknown()

# The iteration state is the pair `(iterate result of a, iterate result of b)`;
# each component is either `nothing` (exhausted) or an `(element, state)` tuple.
# Because each component can independently be `nothing` or a tuple, this method
# can return states of several different tuple types.
@inline function Base.iterate(
    self::MergeSorted{T},
    state=(iterate(self.a), iterate(self.b)),
) where {T}
    a_result, b_result = state

    if b_result === nothing
        # `b` is exhausted: drain `a`, or stop if it is exhausted as well.
        a_result === nothing && return nothing
        a_curr, a_state = a_result
        return T(a_curr), (iterate(self.a, a_state), b_result)
    end

    b_curr, b_state = b_result
    if a_result !== nothing
        a_curr, a_state = a_result
        # Emit from `a` only when its head is strictly less than the head of `b`.
        Base.Order.lt(self.order, a_curr, b_curr) &&
            return T(a_curr), (iterate(self.a, a_state), b_result)
    end
    return T(b_curr), (a_result, iterate(self.b, b_state))
end
此代码可以正常工作,但它是类型不稳定的,因为 Julia 的迭代机制本身就是如此。在大多数情况下,编译器可以自动消除这种不稳定性,但在这里不行:以下测试代码显示确实产生了临时对象的分配:
>>> x = MergeSorted([1,4,5,9,32,44],[0,7,24,134]);
>>> sum(x);
>>> @time sum(x);
0.000013 seconds (61 allocations: 2.312 KiB)
注意分配计数。
除了反复修改代码并寄希望于编译器能够优化掉类型歧义之外,还有什么方法可以有效地调试这种情况?有谁知道在这种特定情况下,是否存在不产生临时对象的解决方案?
解决方法
如何诊断问题?
答案:使用@code_warntype
运行:
julia> @code_warntype iterate(x,iterate(x)[2])
Variables
#self#::Core.Const(iterate)
self::MergeSorted{Int64,Vector{Int64},Vector{Int64},Base.Order.ForwardOrdering}
state::Tuple{Tuple{Int64,Int64},Tuple{Int64,Int64}}
@_4::Int64
@_5::Int64
@_6::Union{}
@_7::Int64
b_state::Int64
b_curr::Int64
a_state::Int64
a_curr::Int64
b_result::Tuple{Int64,Int64}
a_result::Tuple{Int64,Int64}
Body::Tuple{Int64,Any}
1 ─ nothing
│ Core.NewvarNode(:(@_4))
│ Core.NewvarNode(:(@_5))
│ Core.NewvarNode(:(@_6))
│ Core.NewvarNode(:(b_state))
│ Core.NewvarNode(:(b_curr))
│ Core.NewvarNode(:(a_state))
│ Core.NewvarNode(:(a_curr))
│ %9 = Base.indexed_iterate(state,1)::Core.PartialStruct(Tuple{Tuple{Int64,Any[Tuple{Int64,Core.Const(2)])
│ (a_result = Core.getfield(%9,1))
│ (@_7 = Core.getfield(%9,2))
│ %12 = Base.indexed_iterate(state,2,@_7::Core.Const(2))::Core.PartialStruct(Tuple{Tuple{Int64,Core.Const(3)])
│ (b_result = Core.getfield(%12,1))
│ %14 = (b_result === Main.nothing)::Core.Const(false)
└── goto #3 if not %14
2 ─ Core.Const(:(a_result === Main.nothing))
│ Core.Const(:(%16))
│ Core.Const(:(return Main.nothing))
│ Core.Const(:(Base.indexed_iterate(a_result,1)))
│ Core.Const(:(a_curr = Core.getfield(%19,1)))
│ Core.Const(:(@_6 = Core.getfield(%19,2)))
│ Core.Const(:(Base.indexed_iterate(a_result,@_6)))
│ Core.Const(:(a_state = Core.getfield(%22,1)))
│ Core.Const(:(($(Expr(:static_parameter,1)))(a_curr)))
│ Core.Const(:(Base.getproperty(self,:a)))
│ Core.Const(:(Main.iterate(%25,a_state)))
│ Core.Const(:(Core.tuple(%26,b_result)))
│ Core.Const(:(Core.tuple(%24,%27)))
└── Core.Const(:(return %28))
3 ┄ %30 = Base.indexed_iterate(b_result,1)::Core.PartialStruct(Tuple{Int64,Any[Int64,Core.Const(2)])
│ (b_curr = Core.getfield(%30,1))
│ (@_5 = Core.getfield(%30,2))
│ %33 = Base.indexed_iterate(b_result,@_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64,Core.Const(3)])
│ (b_state = Core.getfield(%33,1))
│ %35 = (a_result !== Main.nothing)::Core.Const(true)
└── goto #6 if not %35
4 ─ %37 = Base.indexed_iterate(a_result,Core.Const(2)])
│ (a_curr = Core.getfield(%37,1))
│ (@_4 = Core.getfield(%37,2))
│ %40 = Base.indexed_iterate(a_result,@_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64,Core.Const(3)])
│ (a_state = Core.getfield(%40,1))
│ %42 = Base.Order::Core.Const(Base.Order)
│ %43 = Base.getproperty(%42,:lt)::Core.Const(Base.Order.lt)
│ %44 = Base.getproperty(self,:order)::Core.Const(Base.Order.ForwardOrdering())
│ %45 = a_curr::Int64
│ %46 = (%43)(%44,%45,b_curr)::Bool
└── goto #6 if not %46
5 ─ %48 = ($(Expr(:static_parameter,1)))(a_curr)::Int64
│ %49 = Base.getproperty(self,:a)::Vector{Int64}
│ %50 = Main.iterate(%49,a_state)::Union{Nothing,Int64}}
│ %51 = Core.tuple(%50,b_result)::Tuple{Union{Nothing,Int64}},Int64}}
│ %52 = Core.tuple(%48,%51)::Tuple{Int64,Tuple{Union{Nothing,Int64}}}
└── return %52
6 ┄ %54 = ($(Expr(:static_parameter,1)))(b_curr)::Int64
│ %55 = a_result::Tuple{Int64,Int64}
│ %56 = Base.getproperty(self,:b)::Vector{Int64}
│ %57 = Main.iterate(%56,b_state)::Union{Nothing,Int64}}
│ %58 = Core.tuple(%55,%57)::Tuple{Tuple{Int64,Union{Nothing,Int64}}}
│ %59 = Core.tuple(%54,%58)::Tuple{Int64,Tuple{Tuple{Int64,Int64}}}}
└── return %59
可以看到,返回值可能的类型太多,因此 Julia 放弃了对它们进行特化(并把返回类型的第二个元素推断为 Any)。
如何解决问题?
答案:减少 iterate 可能返回的类型的数量。
下面是一个快速写出的实现(我并不认为它是最简洁的,也没有对其进行广泛的测试,所以可能存在一些错误,但它足够简单,并且是基于您的代码快速改写而成,用来展示如何解决您的问题;请注意,当其中一个集合为空时,我使用了特殊分支,因为只迭代一个集合应该会更快):
"""
    MergeSorted(a, b, order=Base.Order.Forward)

Lazy iterator that merges the iterables `a` and `b` into a single sequence.
Both inputs are assumed to be already sorted according to `order`; this is not
checked.

The constructor eagerly calls `iterate(a)` and `iterate(b)` and stores the
results in `fa` and `fb`. Their types, `F1` and `F2`, become type parameters,
so whether an input is empty (`Nothing`) is known to dispatch.
"""
struct MergeSorted{T,A,B,O,F1,F2}
    a::A
    b::B
    order::O
    fa::F1
    fb::F2

    function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
        fa, fb = iterate(a), iterate(b)
        F1 = typeof(fa)
        F2 = typeof(fb)
        # All six type parameters must be supplied to `new`.
        return new{promote_type(eltype(A), eltype(B)),A,B,O,F1,F2}(a, b, order, fa, fb)
    end
end

# Match every concrete `MergeSorted` regardless of the trailing parameters.
Base.eltype(::Type{<:MergeSorted{T}}) where {T} = T

# No `length` method is defined, so tell `collect` and friends not to ask for one.
Base.IteratorSize(::Type{<:MergeSorted}) = Base.SizeUnknown()

# Iteration state for the general case. Each field holds either `nothing`
# (that input is exhausted) or the `(element, state)` tuple last returned by
# `iterate` on that input, so the state always has the single type
# `State{Ta,Tb}`.
struct State{Ta,Tb}
    a::Union{Nothing,Ta}
    b::Union{Nothing,Tb}
end

# Both inputs were empty at construction time: nothing to iterate.
function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,Nothing}) where {T,A,B,O}
    return nothing
end

# `b` was empty at construction time: iterate over `a` alone, reusing `a`'s own
# state. NOTE(review): on this fast path elements are returned as produced by
# `a`, without the `T(...)` conversion applied on the general path.
function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}) where {T,A,B,O,F1}
    return self.fa
end

function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}, state) where {T,A,B,O,F1}
    return iterate(self.a, state)
end

# `a` was empty at construction time: iterate over `b` alone (same caveat as
# above regarding the missing `T(...)` conversion).
function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}) where {T,A,B,O,F2}
    return self.fb
end

function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}, state) where {T,A,B,O,F2}
    return iterate(self.b, state)
end

# General case, first step: seed the state with the cached first results and
# delegate to the two-argument method.
@inline function Base.iterate(
    self::MergeSorted{T,A,B,O,F1,F2}
) where {T,A,B,O,F1,F2}
    a_result, b_result = self.fa, self.fb
    return iterate(self, State{F1,F2}(a_result, b_result))
end

# General case, subsequent steps. Every non-`nothing` return value is a
# `(T, State{F1,F2})` pair built through the same `State{F1,F2}` constructor.
@inline function Base.iterate(
    self::MergeSorted{T,A,B,O,F1,F2},
    state::State{F1,F2},
) where {T,A,B,O,F1,F2}
    a_result, b_result = state.a, state.b

    if b_result === nothing
        # `b` is exhausted: drain `a`, or stop if it is exhausted as well.
        a_result === nothing && return nothing
        a_curr, a_state = a_result
        return T(a_curr), State{F1,F2}(iterate(self.a, a_state), b_result)
    end

    b_curr, b_state = b_result
    if a_result !== nothing
        a_curr, a_state = a_result
        # Emit from `a` only when its head is strictly less than the head of `b`.
        Base.Order.lt(self.order, a_curr, b_curr) &&
            return T(a_curr), State{F1,F2}(iterate(self.a, a_state), b_result)
    end
    return T(b_curr), State{F1,F2}(a_result, iterate(self.b, b_state))
end
现在你有:
julia> x = MergeSorted([1,4,5,9,32,44],[0,7,24,134]);
julia> sum(x)
260
julia> @allocated sum(x)
0
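julia> @code_warntype iterate(x,iterate(x)[2])
Variables
#self#::Core.Const(iterate)
self::MergeSorted{Int64,Vector{Int64},Vector{Int64},Base.Order.ForwardOrdering,Tuple{Int64,Int64},Tuple{Int64,Int64}}
state::State{Tuple{Int64,Int64},Tuple{Int64,Int64}}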
@_4::Int64
@_5::Int64
@_6::Int64
b_state::Int64
b_curr::Int64
a_state::Int64
a_curr::Int64
b_result::Union{Nothing,Tuple{Int64,Int64}}
a_result::Union{Nothing,Tuple{Int64,Int64}}
Body::Union{Nothing,Tuple{Int64,State{Tuple{Int64,Int64},Tuple{Int64,Int64}}}}
1 ─ nothing
│ Core.NewvarNode(:(@_4))
│ Core.NewvarNode(:(@_5))
│ Core.NewvarNode(:(@_6))
│ Core.NewvarNode(:(b_state))
│ Core.NewvarNode(:(b_curr))
│ Core.NewvarNode(:(a_state))
│ Core.NewvarNode(:(a_curr))
│ %9 = Base.getproperty(state,:a)::Union{Nothing,Int64}}
│ %10 = Base.getproperty(state,:b)::Union{Nothing,Int64}}
│ (a_result = %9)
│ (b_result = %10)
│ %13 = (b_result === Main.nothing)::Bool
└── goto #5 if not %13
2 ─ %15 = (a_result === Main.nothing)::Bool
└── goto #4 if not %15
3 ─ return Main.nothing
4 ─ %18 = Base.indexed_iterate(a_result::Tuple{Int64,Core.Const(2)])
│ (a_curr = Core.getfield(%18,1))
│ (@_6 = Core.getfield(%18,2))
│ %21 = Base.indexed_iterate(a_result::Tuple{Int64,@_6::Core.Const(2))::Core.PartialStruct(Tuple{Int64,Core.Const(3)])
│ (a_state = Core.getfield(%21,1))
│ %23 = ($(Expr(:static_parameter,1)))(a_curr)::Int64
│ %24 = Core.apply_type(Main.State,$(Expr(:static_parameter,5)),6)))::Core.Const(State{Tuple{Int64,Int64}})
│ %25 = Base.getproperty(self,:a)::Vector{Int64}
│ %26 = Main.iterate(%25,Int64}}
│ %27 = (%24)(%26,b_result::Core.Const(nothing))::State{Tuple{Int64,Int64}}
│ %28 = Core.tuple(%23,%27)::Tuple{Int64,Int64}}}
└── return %28
5 ─ %30 = Base.indexed_iterate(b_result::Tuple{Int64,2))
│ %33 = Base.indexed_iterate(b_result::Tuple{Int64,1))
│ %35 = (a_result !== Main.nothing)::Bool
└── goto #8 if not %35
6 ─ %37 = Base.indexed_iterate(a_result::Tuple{Int64,2))
│ %40 = Base.indexed_iterate(a_result::Tuple{Int64,b_curr)::Bool
└── goto #8 if not %46
7 ─ %48 = ($(Expr(:static_parameter,1)))(a_curr)::Int64
│ %49 = Core.apply_type(Main.State,Int64}})
│ %50 = Base.getproperty(self,:a)::Vector{Int64}
│ %51 = Main.iterate(%50,Int64}}
│ %52 = (%49)(%51,b_result::Tuple{Int64,Int64})::State{Tuple{Int64,Int64}}
│ %53 = Core.tuple(%48,%52)::Tuple{Int64,Int64}}}
└── return %53
8 ┄ %55 = ($(Expr(:static_parameter,1)))(b_curr)::Int64
│ %56 = Core.apply_type(Main.State,Int64}})
│ %57 = a_result::Union{Nothing,Int64}}
│ %58 = Base.getproperty(self,:b)::Vector{Int64}
│ %59 = Main.iterate(%58,Int64}}
│ %60 = (%56)(%57,%59)::State{Tuple{Int64,Int64}}
│ %61 = Core.tuple(%55,%60)::Tuple{Int64,Int64}}}
└── return %61
编辑:现在我意识到我的实现并不完全正确,因为它假设 iterate 的返回值在不是 nothing 时是类型稳定的(实际上并不一定如此)。但如果它不是类型稳定的,那么编译器就必须分配内存。因此,完全正确的解决方案应当首先检查 iterate 是否类型稳定:如果是,就使用我的解决方案;如果不是,就使用例如您的解决方案。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。