mapreduce with accumulation inside is broken #1448

Open
Red-Portal opened this issue Aug 17, 2023 · 10 comments
Labels
bug (Something isn't working), ChainRules (adjoint -> rrule, and further integration), compiler

Comments

@Red-Portal
Red-Portal commented Aug 17, 2023

Hi, the following use-case of mapreduce doesn't work:

gradient(randn(10)) do x
    y₀ = Float64[]
    ∑x = 0.0
    ys = mapreduce(vcat, x, 1:length(x); init = y₀) do xᵢ, r
        yᵢ = xᵢ.^2
        ∑x += xᵢ
        [yᵢ]
    end
    sum(ys) + ∑x
end

It seems the ∑x += xᵢ part is at fault here because with or without init it doesn't work:

(vcat, [[0.7008872619503351], [0.057800842475147274], [0.4508806424034738], [6.360461041381114], [8.229642138382558e-5], [0.43781177206525196], [1.7425577575168238], [0.8947064561514089], [0.678655434187004], [0.10421486484899199]])
(init = Float64[],)
ERROR: MethodError: no method matching iterate(::Nothing)

Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen})
   @ Base range.jl:880
  iterate(::Union{LinRange, StepRangeLen}, ::Integer)
   @ Base range.jl:880
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}}
   @ Base dict.jl:698
  ...

Stacktrace:
  [1] indexed_iterate(I::Nothing, i::Int64)
    @ Base ./tuple.jl:91
  [2] chain_rrule_kw(::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::Function, ::Function, ::Vararg{Any})
    @ Zygote ./REPL[7]:5
  [3] macro expansion
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::typeof(reduce), ::typeof(vcat), ::Vector{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101
  [5] _pullback
    @ ./reducedim.jl:359 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::Base.var"##mapreduce#801", ::Base.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:init,), Tuple{Vector{Float64}}}}, ::typeof(mapreduce), ::var"#24#26", ::typeof(vcat), ::Vector{Float64}, ::UnitRange{Int64})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
  [7] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
  [8] adjoint
    @ ~/.julia/packages/Zygote/4rucm/src/lib/lib.jl:203 [inlined]
  [9] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [10] _pullback
    @ ./reducedim.jl:359 [inlined]
 [11] _pullback
    @ ./REPL[8]:4 [inlined]
 [12] _pullback(ctx::Zygote.Context{false}, f::var"#23#25", args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
 [13] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:44
 [14] pullback
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:42 [inlined]
 [15] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:96
 [16] top-level scope
    @ REPL[8]:1

This used to work and broke at some point. Is this an rrule problem? The same code currently works without problems on ReverseDiff and ForwardDiff.

@Red-Portal
Author

Related issues in Bijectors.jl and Turing.jl

@ToucheSir added the ChainRules (adjoint -> rrule, and further integration) and compiler labels Aug 18, 2023
@ToucheSir
Member

That code path has failed in the past because there are ambiguities in which rrules might apply for a given call. In this case I'm not sure if that is the culprit, however. I believe the problem is that ChainRules does not have an rrule for reduce(vcat, ...; init=...), yet somehow the has_chain_rrule detection logic is reporting that it does.

@ToucheSir added the bug (Something isn't working) label Aug 18, 2023
@mcabbott
Member

mapreduce(f, vcat, x, 1:length(x); init = y₀) could probably be plumbed to reduce(vcat, foldl(f, x, 1:length(x); init = y₀)). Perhaps that would be one way to work around this.

Note also that reduce(vcat, xs; init) and mapreduce(f, vcat, xs) are always pairwise, they never hit the magic fast path of reduce(vcat, xs).
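
To make the distinction between these call forms concrete, here is a minimal sketch in plain Base Julia with no AD involved. The function `f` is a hypothetical stand-in for the do-block body in the original report, and `x` and `idx` are illustrative inputs. Frames [4] to [6] of the stack trace above show the multi-array `mapreduce` call ending up in `reduce(vcat, ::Vector{Vector{Float64}}; init = ...)`, which is the second form below. All three forms return the same values; they differ only in which `reduce` method handles the call.

```julia
# Hypothetical stand-in for the do-block body: returns a one-element vector.
f(xᵢ, i) = [xᵢ^2]

x = [1.0, 2.0, 3.0]
idx = 1:length(x)

# Form used in the report: map and reduce fused, with an explicit init.
with_init = mapreduce(f, vcat, x, idx; init = Float64[])

# Map first, then reduce with init. Per the comment above, this is reduced pairwise.
explicit = reduce(vcat, map(f, x, idx); init = Float64[])

# Map first, then reduce without init. Per the comment above, this is the
# form that can reach the specialized reduce(vcat, xs) method.
no_init = reduce(vcat, map(f, x, idx))
```

Whether Zygote differentiates any particular form correctly is not checked here; the sketch only shows that dropping `init` does not change the result for non-empty input.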

@Red-Portal
Author

@torfjelde Is there a reason we compute the first element first and then use that to initialize mapreduce in Stacked?

@torfjelde
Contributor

Is there a reason we compute the first element first and then use that to initialize mapreduce in Stacked?

Type-stability issues, in particular when combined with AD. Very often we'd run into instabilities without init, and so I believe this was a way to work around this (type-stability is quite crucial here, in particular with Zygote).

@ToucheSir
Member

About type stability, note that any call to a method with kwargs (whether they're provided in the call or not) will be type unstable unless there's an rrule defined for that particular method. In this case there is not.

@Red-Portal
Author

@torfjelde This issue is still persisting; any suggestions on how we should deal with this? Maybe just change the Stacked bijector implementation so that we don't hit this edge-case at all? I'm thinking computing the Jacobian and the transformation through two separate calls to mapreduce. Probably less efficient, but I don't see any other way unless this gets fixed. Also, we could expect the mapreduce(vcat, ...) fast path to kick in?
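
As a rough illustration of the "two separate calls" idea, the sketch below compares a single pass that mutates a captured accumulator with a version that computes the stacked output and the accumulated scalar separately. The names `transform`, `contribution`, `single_pass`, and `two_pass` are hypothetical and are not taken from Bijectors.jl; this is plain Base Julia and has not been run under Zygote.

```julia
# Hypothetical per-element pieces: a vector-valued transform and a scalar term.
transform(xᵢ) = [xᵢ^2]
contribution(xᵢ) = xᵢ

# Pattern from the report: one mapreduce that also mutates a captured variable.
function single_pass(x)
    acc = 0.0
    ys = mapreduce(vcat, x; init = Float64[]) do xᵢ
        acc += contribution(xᵢ)   # side effect inside the mapped closure
        transform(xᵢ)
    end
    return sum(ys) + acc
end

# Alternative: two independent passes, no mutation and no init keyword.
function two_pass(x)
    ys = reduce(vcat, map(transform, x))   # plain reduce(vcat, xs) form
    acc = sum(contribution, x)
    return sum(ys) + acc
end
```

The second version traverses the input twice, which matches the "probably less efficient" caveat, but it avoids both the closure mutation and the `init` keyword call that appears in the failing stack trace.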

@torfjelde
Contributor

Maybe just change the Stacked bijector implementation so that we don't hit this edge-case at all?

Yep, that's what we should do imo.

Is stack applicable here?

Also, we could expect the mapreduce(vcat, ...) fast path to kick in?

Does such a fast-path exist?

@Red-Portal
Author

Red-Portal commented Jun 4, 2024

Does such a fast-path exist?

Oh sorry, I meant reduce(vcat). For this, I'm quoting @mcabbott's reply:

note also that reduce(vcat, xs; init) and mapreduce(f, vcat, xs) are always pairwise, they never hit the magic fast path of reduce(vcat, xs).

@torfjelde
Contributor

But yeah, I'd recommend that we just work around it by implementing something more specialized 👍
