Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable scalar/broadcast operation for LazyPropagation #167

Open
wants to merge 6 commits into
base: master
Choose a base branch
from

Conversation

ziyiyin97
Copy link
Member

@ziyiyin97 ziyiyin97 commented Jan 4, 2023

  1. Enable scalar/broadcast operation for LazyPropagation; add associated test (which won't pass with the current master)
  2. LazyPropagation now has an attribute val, which stores F * q if previously computed
  3. fix the reshape issue for multi source vector -- which can be in size of nsrc and also in size of nsrc * nt * nrec

@codecov
Copy link

codecov bot commented Jan 4, 2023

Codecov Report

Base: 81.88% // Head: 81.59% // Decreases project coverage by -0.29% ⚠️

Coverage data is based on head (171170f) compared to base (8f65ed4).
Patch coverage: 41.17% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #167      +/-   ##
==========================================
- Coverage   81.88%   81.59%   -0.30%     
==========================================
  Files          28       28              
  Lines        2186     2200      +14     
==========================================
+ Hits         1790     1795       +5     
- Misses        396      405       +9     
Impacted Files Coverage Δ
src/TimeModeling/LinearOperators/lazy.jl 83.72% <0.00%> (-0.99%) ⬇️
src/rrules.jl 62.02% <40.00%> (-5.14%) ⬇️
src/TimeModeling/Types/abstract.jl 77.41% <100.00%> (+0.18%) ⬆️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

@mloubout mloubout added the bug label Jan 5, 2023
src/rrules.jl Outdated
@@ -32,16 +32,12 @@ LazyPropagation(F::judiPropagator, q) = LazyPropagation(identity, F, q)

for op in [:+, :-, :*, :/]
@eval begin
$(op)(F::LazyPropagation, y::AbstractArray{T}) where T = $(op)(eval_prop(F), y)
$(op)(y::AbstractArray{T}, F::LazyPropagation) where T = $(op)(y, eval_prop(F))
$(op)(F::LazyPropagation, y::Union{AbstractArray{T}, T}) where T = $(op)(eval_prop(F), y)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are generalizing in a bad way. Unions aren't really nice for different abstract type. So Union of like Vector and Matrix is fine but for Scalar and AbstractArray it's not great.

You are also completely ignoring the setup here that make type separate through eval, that Union completely defeat the purpose.

Finally, you are again not considering what this type is about nor what you are trying to do and put here first way "that works" you could make up. The main point is to avoid evaluating PDEs when not necessary and you are not forcing potentially un-neccessary PDEs. a LazyPropagation is Linear, if you do a .* Lazy then it's the same as Lazy(L.F, a .* L.q) which does not evaluate anything.

@mloubout
Copy link
Member

mloubout commented Jan 5, 2023

Your change lead to ambiguities... please run these basic tests locally

src/rrules.jl Outdated
$(op)(y::Union{AbstractArray{T}, T}, F::LazyPropagation) where T = $(op)(y, eval_prop(F))
$(op)(F::LazyPropagation, y::AbstractArray{T}) where T = $(op)(eval_prop(F), y)
$(op)(y::AbstractArray{T}, F::LazyPropagation) where T = $(op)(y, eval_prop(F))
$(op)(F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, $(op)(F.q, y))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No this is only true for * and /

src/rrules.jl Outdated
end
end
@eval begin
broadcasted(::typeof($op), F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), eval_prop(F.q), y))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is eval_prop(F.q) ???

Again, only true for * and /

@ziyiyin97
Copy link
Member Author

On a side note: does it make sense to move the scalar operations (all of +-*/) into LazyPropagation.post?

@mloubout
Copy link
Member

mloubout commented Jan 5, 2023

nto LazyPropagation.post?

No because then it's not a linear operation anymore

@ziyiyin97 ziyiyin97 marked this pull request as draft January 5, 2023 17:32
@ziyiyin97
Copy link
Member Author

ziyiyin97 commented Jan 6, 2023

Hmm appreciate your @mloubout comment on this one: I am now on JUDI master and

julia> gs_inv = gradient(x -> norm(F(x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.50 s
Operator `gradient` ran in 0.34 s
(Float32[-0.081900775 0.07301128  6.170804f-6 7.20752f-6; 0.0637427 0.027981473  9.756089f-7 5.4272978f-6;  ; 0.06374304 0.027981216  9.755976f-7 5.4272914f-6; -0.08189945 0.07301152  6.170794f-6 7.2075245f-6],)

julia> gs_inv1 = gradient(x -> norm(F(1f0*x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.55 s
Operator `forward` ran in 0.49 s
Operator `gradient` ran in 0.34 s
ERROR: MethodError: no method matching *(::Float32, ::JUDI.LazyPropagation)
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  *(::T, ::T) where T<:Union{Float16, Float32, Float64} at float.jl:385
  *(::Union{Float16, Float32, Float64}, ::BigFloat) at mpfr.jl:414
  ...
Stacktrace:
  [1] (::ChainRules.var"#1490#1494"{JUDI.LazyPropagation, Float32, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}})()
    @ ChainRules ~/.julia/packages/ChainRules/ajkp7/src/rulesets/Base/arraymath.jl:111
  [2] unthunk
    @ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:204 [inlined]
  [3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1490#1494"{JUDI.LazyPropagation, Float32, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, ChainRules.var"#1489#1493"{JUDI.LazyPropagation, Float32}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:237
  [4] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:105 [inlined]
  [5] map
    @ ./tuple.jl:223 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:106 [inlined]
  [7] ZBack
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:206 [inlined]
  [8] Pullback
    @ ./REPL[26]:1 [inlined]
  [9] (::typeof((#10)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#60#61"{typeof((#10))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
 [11] gradient(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
 [12] top-level scope
    @ REPL[26]:1
 [13] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52

julia> import Base.*;

julia> *(y::Float32, F::JUDI.LazyPropagation) = JUDI.LazyPropagation(F.post, F.F, *(y, F.q));

julia> gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.56 s
Operator `forward` ran in 0.53 s
Operator `gradient` ran in 0.34 s
Operator `forward` ran in 0.43 s
Operator `gradient` ran in 0.35 s
(Float32[-0.081900775 0.07301128  6.170804f-6 7.20752f-6; 0.0637427 0.027981473  9.756089f-7 5.4272978f-6;  ; 0.06374304 0.027981216  9.755976f-7 5.4272914f-6; -0.08189945 0.07301152  6.170794f-6 7.2075245f-6],)

gs_inv performs a nonlinear forward modeling and an RTM. gs_inv1 fails because scalar multiplication is not defined yet. After the definition of multiplication, gs_inv2 did 2 evaluations on the LazyPropgation, which confuses me ... any idea why? Thanks

Full script below

using JUDI
using Flux
using ArgParse, Test, Printf, Aqua
using SegyIO, LinearAlgebra, Distributed, JOLI
using TimerOutputs: TimerOutputs, @timeit

Flux.Random.seed!(2022)

### Model
tti = false
viscoacoustic = false

nsrc = 1
dt = 1f0
include(joinpath(JUDIPATH, "../test/seismic_utils.jl"))
model, model0, dm = setup_model(tti, viscoacoustic, 4)
m, m0 = model.m.data, model0.m.data
q, srcGeometry, recGeometry, f0 = setup_geom(model; nsrc=nsrc, dt=dt)

# Common op
Pr = judiProjection(recGeometry)
Ps = judiProjection(srcGeometry)

ra = false
stype = "Point"
Pq = Ps

opt = Options(return_array=ra, sum_padding=true, f0=f0)
A_inv = judiModeling(model; options=opt)
A_inv0 = judiModeling(model0; options=opt)

# Operators
F = Pr*A_inv*adjoint(Pq)
F0 = Pr*A_inv0*adjoint(Pq)

gs_inv = gradient(x -> norm(F(x)*q), m0)

gs_inv1 = gradient(x -> norm(F(1f0*x)*q), m0)

import Base.*;
*(y::Float32, F::JUDI.LazyPropagation) = JUDI.LazyPropagation(F.post, F.F, *(y, F.q));
gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0)

@mloubout
Copy link
Member

mloubout commented Jan 6, 2023

That's quite curious indeed i'll see if can figure out what's going on

@mloubout
Copy link
Member

mloubout commented Jan 6, 2023

Well that's is baaaaaaaad, this is why people don't wanna use Julia for serious stuff.

When you do gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0) Zygote doesn't understand correctly that you want "only" the derivative w.r.t to m0, in part because it doesn't understand thinks. So it end up computing what you want, i.e d F(1*m0)*q / d m0 but because diff rules are defined for both left and right input for mul (and again since zygote always computes and evaluate everything) it also computes d F(1*m0)*q / d 1 which calls dot which calls eval_prop.

So there is not trivial way out of it except maybe having LazyPropagation store the result at its first evaluation so its only computed once (the above compute the same gradient twice)

@@ -73,7 +73,15 @@ vec(x::judiMultiSourceVector) = vcat(vec.(x.data)...)

time_sampling(ms::judiMultiSourceVector) = [1 for i=1:ms.nsrc]

reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N = reshape(vec(ms), dims)
function reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for

failing example
using JUDI
using Flux
using ArgParse, Test, Printf, Aqua
using SegyIO, LinearAlgebra, Distributed, JOLI
using TimerOutputs: TimerOutputs, @timeit

Flux.Random.seed!(2022)

### Model
tti = false
viscoacoustic = false

nsrc = 1
dt = 1f0
include(joinpath(JUDIPATH, "../test/seismic_utils.jl"))
model, model0, dm = setup_model(tti, viscoacoustic, 4)
m, m0 = model.m.data, model0.m.data
q, srcGeometry, recGeometry, f0 = setup_geom(model; nsrc=nsrc, dt=dt)

# Common op
Pr = judiProjection(recGeometry)
Ps = judiProjection(srcGeometry)

opt = Options(sum_padding=true, f0=f0)
A_inv = judiModeling(model; options=opt)
A_inv0 = judiModeling(model0; options=opt)

# Operators
F = Pr*A_inv*adjoint(Ps)
J = judiJacobian(F,q)
dm = vec(m-m0)

gs_inv = gradient(q -> norm(J(q)*dm), q)
ERROR: LoadError: DimensionMismatch: new dimensions (1,) must be consistent with array size 1501
Stacktrace:
  [1] (::Base.var"#throw_dmrsa#289")(dims::Tuple{Int64}, len::Int64)
    @ Base ./reshapedarray.jl:41
  [2] reshape
    @ ./reshapedarray.jl:45 [inlined]
  [3] reshape
    @ ~/.julia/dev/JUDI/src/TimeModeling/Types/abstract.jl:76 [inlined]
  [4] reshape(parent::judiVector{Float32, Matrix{Float32}}, shp::Tuple{Base.OneTo{Int64}})
    @ Base ./reshapedarray.jl:111
  [5] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::JUDI.LazyPropagation)
    @ JUDI ~/.julia/dev/JUDI/src/rrules.jl:142
  [6] _project
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:184 [inlined]
  [7] map(f::typeof(Zygote._project), t::Tuple{judiVector{Float32, Matrix{Float32}}}, s::Tuple{JUDI.LazyPropagation})
    @ Base ./tuple.jl:246
  [8] gradient(f::Function, args::judiVector{Float32, Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:98
  [9] top-level scope
    @ ~/.julia/dev/JUDI/test/MFE.jl:33
 [10] include(fname::String)
    @ Base.MainInclude ./client.jl:476
 [11] top-level scope
    @ REPL[1]:1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wasn't failing before what changed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The example above fails on master branch

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be more specific:

julia> gs_inv = gradient(() -> norm(J(q)*dm), Flux.params(q))
Operator `born` ran in 0.75 s
Grads(...)

this doesn't fail but

julia> gs_inv = gradient(q -> norm(J(q)*dm), q)
Operator `born` ran in 0.72 s
Operator `born` ran in 0.73 s
ERROR: DimensionMismatch: new dimensions (1,) must be consistent with array size 1501
Stacktrace:
  [1] (::Base.var"#throw_dmrsa#289")(dims::Tuple{Int64}, len::Int64)
    @ Base ./reshapedarray.jl:41
  [2] reshape
    @ ./reshapedarray.jl:45 [inlined]
  [3] reshape
    @ ~/.julia/dev/JUDI/src/TimeModeling/Types/abstract.jl:76 [inlined]
  [4] reshape(parent::judiVector{Float32, Matrix{Float32}}, shp::Tuple{Base.OneTo{Int64}})
    @ Base ./reshapedarray.jl:111
  [5] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::JUDI.LazyPropagation)
    @ JUDI ~/.julia/dev/JUDI/src/rrules.jl:142
  [6] _project
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:184 [inlined]
  [7] map(f::typeof(Zygote._project), t::Tuple{judiVector{Float32, Matrix{Float32}}}, s::Tuple{JUDI.LazyPropagation})
    @ Base ./tuple.jl:246
  [8] gradient(f::Function, args::judiVector{Float32, Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:98
  [9] top-level scope
    @ REPL[25]:1
 [10] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52

this fail

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hum ok, then split into Dims{N} and Dims{1} and just an if/else these try/catch are really a bad idea anywhere near Zygote

@ziyiyin97 ziyiyin97 marked this pull request as ready for review January 6, 2023 20:02
Copy link
Member

@mloubout mloubout left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few test would be appreciated considering the amount of changes

@@ -130,6 +130,7 @@ end
size(jA::jAdjoint) = (jA.op.n, jA.op.m)
display(P::jAdjoint) = println("Adjoint($(P.op))")
display(P::judiProjection{D}) where D = println("JUDI projection operator $(repr(P.n)) -> $(repr(P.m))")
display(P::judiWavelet{T}) where T = println("JUDI wavelet injected at every grid point")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not quite true, it's just a container for a single time trace there is nothing about "everywhere in space" in it.

@@ -73,7 +73,15 @@ vec(x::judiMultiSourceVector) = vcat(vec.(x.data)...)

time_sampling(ms::judiMultiSourceVector) = [1 for i=1:ms.nsrc]

reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N = reshape(vec(ms), dims)
function reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wasn't failing before what changed?

src/rrules.jl Outdated
Base.collect(F::LazyPropagation) = eval_prop(F)
LazyPropagation(post, F::judiPropagator, q) = LazyPropagation(post, F, q, nothing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

::Function

src/rrules.jl Outdated
Base.collect(F::LazyPropagation) = eval_prop(F)
LazyPropagation(post, F::judiPropagator, q) = LazyPropagation(post, F, q, nothing)
LazyPropagation(F::judiPropagator, q) = LazyPropagation(identity, F, q)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

COnstructor LazyPropagation(identity, F, q, nothing) extra function call not needed

@ziyiyin97
Copy link
Member Author

Could you enlighten me how (by code or something) you reach the conclusion here #167 (comment) ? I am experiencing issue below and would like to check what went wrong ...

julia> gs_inv = gradient(() -> norm(F(1f0*m)*q), Flux.params(m))
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.48 s
Operator `gradient` ran in 0.34 s
Grads(...)

julia> gs_inv = gradient(() -> norm(F(m*1f0)*q), Flux.params(m))
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.53 s
Operator `gradient` ran in 0.34 s
Operator `forward` ran in 0.49 s
Operator `gradient` ran in 0.34 s
Grads(...)

@mloubout
Copy link
Member

mloubout commented Jan 8, 2023

Debug every eval_prop to see which where it's called and what the inputs are. In that other case it was evaluated in dot then you can infer why and check that's undeed the gradient it computes by requesting it as a param

@mloubout
Copy link
Member

mloubout commented Jan 9, 2023

Not sure where you are in the debug, but I can tell you that's it's not super trivial and the fix will require some proper design to extend it cleanly to this type of case. But i'll leave it to you to at least find what the issue is as an exercise.

@ziyiyin97
Copy link
Member Author

Thanks! Yes I agree this is not simple. I will pick it up some time later this week

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants