微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何避免自定义 Julia 迭代器中的内存分配?

如何解决如何避免自定义 Julia 迭代器中的内存分配?

考虑以下 Julia“复合”迭代器:它合并了两个迭代器,ab, 假设每一个都根据 order 排序到一个单一的有序 顺序:

struct MergeSorted{T,A,B,O}
    a::A
    b::B
    order::O

    MergeSorted(a::A,b::B,order::O=Base.Order.Forward) where {A,O} =
        new{promote_type(eltype(A),eltype(B)),O}(a,b,order)
end

Base.eltype(::Type{MergeSorted{T,O}}) where {T,O} = T

@inline function Base.iterate(self::MergeSorted{T},state=(iterate(self.a),iterate(self.b))) where T
    a_result,b_result = state
    if b_result === nothing
        a_result === nothing && return nothing
        a_curr,a_state = a_result
        return T(a_curr),(iterate(self.a,a_state),b_result)
    end

    b_curr,b_state = b_result
    if a_result !== nothing
        a_curr,a_state = a_result
        Base.Order.lt(self.order,a_curr,b_curr) &&
            return T(a_curr),b_result)
    end
    return T(b_curr),(a_result,iterate(self.b,b_state))
end

代码有效,但类型不稳定,因为 Julia 迭代工具本身就是如此。在大多数情况下,编译器可以自动解决这个问题,但是,在这里它不起作用:以下测试代码说明了临时文件的创建:

>>> x = MergeSorted([1,4,5,9,32,44],[0,7,24,134]);
>>> sum(x);
>>> @time sum(x);
0.000013 seconds (61 allocations: 2.312 KiB)

注意分配计数。

除了玩弄代码并希望编译器能够优化类型歧义之外,还有什么方法可以有效地调试这种情况?有谁知道在这种特定情况下有没有不创建临时文件解决方案?

解决方法

如何诊断问题?

答案:使用@code_warntype

运行:

julia> @code_warntype iterate(x,iterate(x)[2])
Variables
  #self#::Core.Const(iterate)
  self::MergeSorted{Int64,Vector{Int64},Base.Order.ForwardOrdering}
  state::Tuple{Tuple{Int64,Int64},Tuple{Int64,Int64}}
  @_4::Int64
  @_5::Int64
  @_6::Union{}
  @_7::Int64
  b_state::Int64
  b_curr::Int64
  a_state::Int64
  a_curr::Int64
  b_result::Tuple{Int64,Int64}
  a_result::Tuple{Int64,Int64}

Body::Tuple{Int64,Any}
1 ─       nothing
│         Core.NewvarNode(:(@_4))
│         Core.NewvarNode(:(@_5))
│         Core.NewvarNode(:(@_6))
│         Core.NewvarNode(:(b_state))
│         Core.NewvarNode(:(b_curr))
│         Core.NewvarNode(:(a_state))
│         Core.NewvarNode(:(a_curr))
│   %9  = Base.indexed_iterate(state,1)::Core.PartialStruct(Tuple{Tuple{Int64,Any[Tuple{Int64,Core.Const(2)])
│         (a_result = Core.getfield(%9,1))
│         (@_7 = Core.getfield(%9,2))
│   %12 = Base.indexed_iterate(state,2,@_7::Core.Const(2))::Core.PartialStruct(Tuple{Tuple{Int64,Core.Const(3)])
│         (b_result = Core.getfield(%12,1))
│   %14 = (b_result === Main.nothing)::Core.Const(false)
└──       goto #3 if not %14
2 ─       Core.Const(:(a_result === Main.nothing))
│         Core.Const(:(%16))
│         Core.Const(:(return Main.nothing))
│         Core.Const(:(Base.indexed_iterate(a_result,1)))
│         Core.Const(:(a_curr = Core.getfield(%19,1)))
│         Core.Const(:(@_6 = Core.getfield(%19,2)))
│         Core.Const(:(Base.indexed_iterate(a_result,@_6)))
│         Core.Const(:(a_state = Core.getfield(%22,1)))
│         Core.Const(:(($(Expr(:static_parameter,1)))(a_curr)))
│         Core.Const(:(Base.getproperty(self,:a)))
│         Core.Const(:(Main.iterate(%25,a_state)))
│         Core.Const(:(Core.tuple(%26,b_result)))
│         Core.Const(:(Core.tuple(%24,%27)))
└──       Core.Const(:(return %28))
3 ┄ %30 = Base.indexed_iterate(b_result,1)::Core.PartialStruct(Tuple{Int64,Any[Int64,Core.Const(2)])
│         (b_curr = Core.getfield(%30,1))
│         (@_5 = Core.getfield(%30,2))
│   %33 = Base.indexed_iterate(b_result,@_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64,Core.Const(3)])
│         (b_state = Core.getfield(%33,1))
│   %35 = (a_result !== Main.nothing)::Core.Const(true)
└──       goto #6 if not %35
4 ─ %37 = Base.indexed_iterate(a_result,Core.Const(2)])
│         (a_curr = Core.getfield(%37,1))
│         (@_4 = Core.getfield(%37,2))
│   %40 = Base.indexed_iterate(a_result,@_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64,Core.Const(3)])
│         (a_state = Core.getfield(%40,1))
│   %42 = Base.Order::Core.Const(Base.Order)
│   %43 = Base.getproperty(%42,:lt)::Core.Const(Base.Order.lt)
│   %44 = Base.getproperty(self,:order)::Core.Const(Base.Order.ForwardOrdering())
│   %45 = a_curr::Int64
│   %46 = (%43)(%44,%45,b_curr)::Bool
└──       goto #6 if not %46
5 ─ %48 = ($(Expr(:static_parameter,1)))(a_curr)::Int64
│   %49 = Base.getproperty(self,:a)::Vector{Int64}
│   %50 = Main.iterate(%49,a_state)::Union{Nothing,Int64}}
│   %51 = Core.tuple(%50,b_result)::Tuple{Union{Nothing,Int64}},Int64}}
│   %52 = Core.tuple(%48,%51)::Tuple{Int64,Tuple{Union{Nothing,Int64}}}
└──       return %52
6 ┄ %54 = ($(Expr(:static_parameter,1)))(b_curr)::Int64
│   %55 = a_result::Tuple{Int64,Int64}
│   %56 = Base.getproperty(self,:b)::Vector{Int64}
│   %57 = Main.iterate(%56,b_state)::Union{Nothing,Int64}}
│   %58 = Core.tuple(%55,%57)::Tuple{Tuple{Int64,Union{Nothing,Int64}}}
│   %59 = Core.tuple(%54,%58)::Tuple{Int64,Tuple{Tuple{Int64,Int64}}}}
└──       return %59

并且您看到返回值的类型太多,因此 Julia 放弃了专门化它们(并假设返回类型的第二个元素是 Any)。

如何解决问题?

答案:减少 iterate 的返回类型选项的数量。

这是一篇快速的文章(我并不认为它是最简洁的,也没有对其进行广泛的测试,所以可能存在一些错误,但它很简单,可以使用您的代码快速编写以展示如何解决您的问题; 请注意,当其中一个集合为空时,我会使用特殊分支,因为只迭代一个集合应该会更快):

struct MergeSorted{T,A,B,O,F1,F2}
    a::A
    b::B
    order::O
    fa::F1
    fb::F2
    function MergeSorted(a::A,b::B,order::O=Base.Order.Forward) where {A,O}
        fa,fb = iterate(a),iterate(b)
        F1 = typeof(fa)
        F2 = typeof(fb)
        new{promote_type(eltype(A),eltype(B)),F2}(a,b,order,fa,fb)
    end
end

Base.eltype(::Type{MergeSorted{T,O}}) where {T,O} = T

struct State{Ta,Tb}
    a::Union{Nothing,Ta}
    b::Union{Nothing,Tb}
end

function Base.iterate(self::MergeSorted{T,Nothing,Nothing}) where {T,O}
    return nothing
end

function Base.iterate(self::MergeSorted{T,F1}
    return self.fa
end

function Base.iterate(self::MergeSorted{T,Nothing},state) where {T,F1}
    return iterate(self.a,state)
end

function Base.iterate(self::MergeSorted{T,F2}) where {T,F2}
    return self.fb
end

function Base.iterate(self::MergeSorted{T,F2},F2}
    return iterate(self.b,state)
end

@inline function Base.iterate(self::MergeSorted{T,F2}
    a_result,b_result = self.fa,self.fb
    return iterate(self,State{F1,F2}(a_result,b_result))
end

@inline function Base.iterate(self::MergeSorted{T,state::State{F1,b_result = state.a,state.b

    if b_result === nothing
        a_result === nothing && return nothing
        a_curr,a_state = a_result
        return T(a_curr),F2}(iterate(self.a,a_state),b_result)
    end

    b_curr,b_state = b_result
    if a_result !== nothing
        a_curr,a_state = a_result
        Base.Order.lt(self.order,a_curr,b_curr) &&
            return T(a_curr),b_result)
    end
    return T(b_curr),iterate(self.b,b_state))
end

现在你有:

julia> x = MergeSorted([1,4,5,9,32,44],[0,7,24,134]);

julia> sum(x)
269

julia> @allocated sum(x)
0

julia> @code_warntype iterate(x,Base.Order.ForwardOrdering,Int64}}
  state::State{Tuple{Int64,Int64}}
  @_4::Int64
  @_5::Int64
  @_6::Int64
  b_state::Int64
  b_curr::Int64
  a_state::Int64
  a_curr::Int64
  b_result::Union{Nothing,Int64}}
  a_result::Union{Nothing,Int64}}

Body::Union{Nothing,State{Tuple{Int64,Int64}}}}
1 ─       nothing
│         Core.NewvarNode(:(@_4))
│         Core.NewvarNode(:(@_5))
│         Core.NewvarNode(:(@_6))
│         Core.NewvarNode(:(b_state))
│         Core.NewvarNode(:(b_curr))
│         Core.NewvarNode(:(a_state))
│         Core.NewvarNode(:(a_curr))
│   %9  = Base.getproperty(state,:a)::Union{Nothing,Int64}}
│   %10 = Base.getproperty(state,:b)::Union{Nothing,Int64}}
│         (a_result = %9)
│         (b_result = %10)
│   %13 = (b_result === Main.nothing)::Bool
└──       goto #5 if not %13
2 ─ %15 = (a_result === Main.nothing)::Bool
└──       goto #4 if not %15
3 ─       return Main.nothing
4 ─ %18 = Base.indexed_iterate(a_result::Tuple{Int64,Core.Const(2)])
│         (a_curr = Core.getfield(%18,1))
│         (@_6 = Core.getfield(%18,2))
│   %21 = Base.indexed_iterate(a_result::Tuple{Int64,@_6::Core.Const(2))::Core.PartialStruct(Tuple{Int64,Core.Const(3)])
│         (a_state = Core.getfield(%21,1))
│   %23 = ($(Expr(:static_parameter,1)))(a_curr)::Int64
│   %24 = Core.apply_type(Main.State,$(Expr(:static_parameter,5)),6)))::Core.Const(State{Tuple{Int64,Int64}})
│   %25 = Base.getproperty(self,:a)::Vector{Int64}
│   %26 = Main.iterate(%25,Int64}}
│   %27 = (%24)(%26,b_result::Core.Const(nothing))::State{Tuple{Int64,Int64}}
│   %28 = Core.tuple(%23,%27)::Tuple{Int64,Int64}}}
└──       return %28
5 ─ %30 = Base.indexed_iterate(b_result::Tuple{Int64,2))
│   %33 = Base.indexed_iterate(b_result::Tuple{Int64,1))
│   %35 = (a_result !== Main.nothing)::Bool
└──       goto #8 if not %35
6 ─ %37 = Base.indexed_iterate(a_result::Tuple{Int64,2))
│   %40 = Base.indexed_iterate(a_result::Tuple{Int64,b_curr)::Bool
└──       goto #8 if not %46
7 ─ %48 = ($(Expr(:static_parameter,1)))(a_curr)::Int64
│   %49 = Core.apply_type(Main.State,Int64}})
│   %50 = Base.getproperty(self,:a)::Vector{Int64}
│   %51 = Main.iterate(%50,Int64}}
│   %52 = (%49)(%51,b_result::Tuple{Int64,Int64})::State{Tuple{Int64,Int64}}
│   %53 = Core.tuple(%48,%52)::Tuple{Int64,Int64}}}
└──       return %53
8 ┄ %55 = ($(Expr(:static_parameter,1)))(b_curr)::Int64
│   %56 = Core.apply_type(Main.State,Int64}})
│   %57 = a_result::Union{Nothing,Int64}}
│   %58 = Base.getproperty(self,:b)::Vector{Int64}
│   %59 = Main.iterate(%58,Int64}}
│   %60 = (%56)(%57,%59)::State{Tuple{Int64,Int64}}
│   %61 = Core.tuple(%55,%60)::Tuple{Int64,Int64}}}
└──       return %61

编辑:现在我意识到我的实现并不完全正确,因为它假设 iterate 的返回值如果不是 nothing 是类型稳定的(它不必是)。但是如果它不是类型稳定的,那么编译器必须分配。因此,完全正确的解决方案将首先检查 iterate 是否类型稳定。如果是 - 使用我的解决方案,如果不是 - 使用例如您的解决方案。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。