Enable scalar/broadcast operation for LazyPropagation #167

Open · wants to merge 6 commits into master
1 change: 1 addition & 0 deletions src/TimeModeling/LinearOperators/lazy.jl
@@ -130,6 +130,7 @@ end
size(jA::jAdjoint) = (jA.op.n, jA.op.m)
display(P::jAdjoint) = println("Adjoint($(P.op))")
display(P::judiProjection{D}) where D = println("JUDI projection operator $(repr(P.n)) -> $(repr(P.m))")
display(P::judiWavelet{T}) where T = println("JUDI wavelet injected at every grid point")
Member:
That's not quite true; it's just a container for a single time trace. There is nothing about "everywhere in space" in it.
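
A possible rewording along that line (a sketch, not code from this PR):

display(P::judiWavelet{T}) where T = println("JUDI wavelet (a single source time trace)")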


############################################################################################################################
# Indexing
10 changes: 9 additions & 1 deletion src/TimeModeling/Types/abstract.jl
@@ -73,7 +73,15 @@ vec(x::judiMultiSourceVector) = vcat(vec.(x.data)...)

time_sampling(ms::judiMultiSourceVector) = [1 for i=1:ms.nsrc]

reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N = reshape(vec(ms), dims)
function reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N
Member Author:
This is for the following failing example:

using JUDI
using Flux
using ArgParse, Test, Printf, Aqua
using SegyIO, LinearAlgebra, Distributed, JOLI
using TimerOutputs: TimerOutputs, @timeit

Flux.Random.seed!(2022)

### Model
tti = false
viscoacoustic = false

nsrc = 1
dt = 1f0
include(joinpath(JUDIPATH, "../test/seismic_utils.jl"))
model, model0, dm = setup_model(tti, viscoacoustic, 4)
m, m0 = model.m.data, model0.m.data
q, srcGeometry, recGeometry, f0 = setup_geom(model; nsrc=nsrc, dt=dt)

# Common op
Pr = judiProjection(recGeometry)
Ps = judiProjection(srcGeometry)

opt = Options(sum_padding=true, f0=f0)
A_inv = judiModeling(model; options=opt)
A_inv0 = judiModeling(model0; options=opt)

# Operators
F = Pr*A_inv*adjoint(Ps)
J = judiJacobian(F,q)
dm = vec(m-m0)

gs_inv = gradient(q -> norm(J(q)*dm), q)
ERROR: LoadError: DimensionMismatch: new dimensions (1,) must be consistent with array size 1501
Stacktrace:
  [1] (::Base.var"#throw_dmrsa#289")(dims::Tuple{Int64}, len::Int64)
    @ Base ./reshapedarray.jl:41
  [2] reshape
    @ ./reshapedarray.jl:45 [inlined]
  [3] reshape
    @ ~/.julia/dev/JUDI/src/TimeModeling/Types/abstract.jl:76 [inlined]
  [4] reshape(parent::judiVector{Float32, Matrix{Float32}}, shp::Tuple{Base.OneTo{Int64}})
    @ Base ./reshapedarray.jl:111
  [5] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::JUDI.LazyPropagation)
    @ JUDI ~/.julia/dev/JUDI/src/rrules.jl:142
  [6] _project
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:184 [inlined]
  [7] map(f::typeof(Zygote._project), t::Tuple{judiVector{Float32, Matrix{Float32}}}, s::Tuple{JUDI.LazyPropagation})
    @ Base ./tuple.jl:246
  [8] gradient(f::Function, args::judiVector{Float32, Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:98
  [9] top-level scope
    @ ~/.julia/dev/JUDI/test/MFE.jl:33
 [10] include(fname::String)
    @ Base.MainInclude ./client.jl:476
 [11] top-level scope
    @ REPL[1]:1

Member:

It wasn't failing before; what changed?

Member Author:
The example above fails on the master branch.

Member Author:
To be more specific:

julia> gs_inv = gradient(() -> norm(J(q)*dm), Flux.params(q))
Operator `born` ran in 0.75 s
Grads(...)

this doesn't fail, but

julia> gs_inv = gradient(q -> norm(J(q)*dm), q)
Operator `born` ran in 0.72 s
Operator `born` ran in 0.73 s
ERROR: DimensionMismatch: new dimensions (1,) must be consistent with array size 1501
Stacktrace:
  [1] (::Base.var"#throw_dmrsa#289")(dims::Tuple{Int64}, len::Int64)
    @ Base ./reshapedarray.jl:41
  [2] reshape
    @ ./reshapedarray.jl:45 [inlined]
  [3] reshape
    @ ~/.julia/dev/JUDI/src/TimeModeling/Types/abstract.jl:76 [inlined]
  [4] reshape(parent::judiVector{Float32, Matrix{Float32}}, shp::Tuple{Base.OneTo{Int64}})
    @ Base ./reshapedarray.jl:111
  [5] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::JUDI.LazyPropagation)
    @ JUDI ~/.julia/dev/JUDI/src/rrules.jl:142
  [6] _project
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:184 [inlined]
  [7] map(f::typeof(Zygote._project), t::Tuple{judiVector{Float32, Matrix{Float32}}}, s::Tuple{JUDI.LazyPropagation})
    @ Base ./tuple.jl:246
  [8] gradient(f::Function, args::judiVector{Float32, Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:98
  [9] top-level scope
    @ REPL[25]:1
 [10] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52

this fails.

Member:

Hmm, OK. Then split into Dims{N} and Dims{1} and just use an if/else; these try/catch blocks are really a bad idea anywhere near Zygote.

    try
        return reshape(vec(ms), dims)
    catch e
        @assert dims[1] == ms.nsrc ### during AD, size(ms::judiVector) = ms.nsrc
        return ms
    end
end
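
A minimal sketch of the dispatch-based alternative the reviewer suggests, splitting the Dims{1} case out instead of using try/catch (hypothetical, not the code in this PR):

# During AD, size(ms::judiVector) is (ms.nsrc,), so a Dims{1} reshape
# matching nsrc should be a no-op instead of an error.
function reshape(ms::judiMultiSourceVector, dims::Dims{1})
    dims[1] == ms.nsrc && return ms
    return reshape(vec(ms), dims)
end
# All other shapes go through the dense vector as before.
reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N = reshape(vec(ms), dims)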

############################################################################################################################
# Linear algebra `*`
(msv::judiMultiSourceVector{mT})(x::AbstractVector{T}) where {mT, T<:Number} = x
31 changes: 28 additions & 3 deletions src/rrules.jl
@@ -18,14 +18,19 @@ Parameters
* `F`: the JUDI propagator
* `q`: The source to compute F*q
"""
struct LazyPropagation
mutable struct LazyPropagation
    post::Function
    F::judiPropagator
    q
    val # store F * q
end

eval_prop(F::LazyPropagation) = F.post(F.F * F.q)
function eval_prop(F::LazyPropagation)
    isnothing(F.val) && (F.val = F.F * F.q)  # run the propagation once and cache it
    return F.post(F.val)
end
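
For illustration, a conceptual sketch of what the new cache buys (hypothetical `F::LazyPropagation` and `x::AbstractArray`; `norm` and `dot` below forward to `eval_prop`):

norm(F)    # first read: runs F.F * F.q once and stores it in F.val
dot(x, F)  # later reads reuse F.val; no second propagation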
Base.collect(F::LazyPropagation) = eval_prop(F)
LazyPropagation(post, F::judiPropagator, q) = LazyPropagation(post, F, q, nothing)
Member:
::Function
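
That is, presumably annotating the argument in this constructor (sketch of the suggestion):

LazyPropagation(post::Function, F::judiPropagator, q) = LazyPropagation(post, F, q, nothing)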

LazyPropagation(F::judiPropagator, q) = LazyPropagation(identity, F, q)
Member:
Constructor: `LazyPropagation(identity, F, q, nothing)`; the extra function call is not needed.
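
That is, the convenience constructor could call the four-argument constructor directly, avoiding the intermediate call (sketch of the suggestion):

LazyPropagation(F::judiPropagator, q) = LazyPropagation(identity, F, q, nothing)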


# Only a few arithmetic operations are supported
@@ -45,15 +50,35 @@ for op in [:+, :-, :*, :/]
end
end

for op in [:*, :/]
    @eval begin
        $(op)(F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, $(op)(F.q, y), isnothing(F.val) ? nothing : $(op)(F.val, y))
        $(op)(y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, $(op)(y, F.q), isnothing(F.val) ? nothing : $(op)(y, F.val))
        broadcasted(::typeof($op), F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), F.q, y), isnothing(F.val) ? nothing : broadcasted($(op), F.val, y))
        broadcasted(::typeof($op), y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), y, F.q), isnothing(F.val) ? nothing : broadcasted($(op), y, F.val))
    end
end

for op in [:+, :-]
    @eval begin
        $(op)(F::LazyPropagation, y::T) where T <: Number = $(op)(eval_prop(F), y)
        $(op)(y::T, F::LazyPropagation) where T <: Number = $(op)(y, eval_prop(F))
        broadcasted(::typeof($op), F::LazyPropagation, y::T) where T <: Number = broadcasted($(op), eval_prop(F), y)
        broadcasted(::typeof($op), y::T, F::LazyPropagation) where T <: Number = broadcasted($(op), y, eval_prop(F))
    end
end
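
In short, scalar `*` and `/` can stay lazy by scaling the stored source (and the cached value), while `+` and `-` have no such factorization and must evaluate first. A conceptual sketch (hypothetical `F::LazyPropagation`):

G = 2f0 * F    # still lazy: wraps 2f0 * F.q (and 2f0 * F.val if cached)
H = F .- 1f0   # cannot stay lazy: evaluates F, then broadcasts the subtraction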

broadcasted(::typeof(^), y::LazyPropagation, p::Real) = eval_prop(y).^(p)
*(F::judiPropagator, q::LazyPropagation) = F*eval_prop(q)

reshape(F::LazyPropagation, dims...) = LazyPropagation(x->reshape(x, dims...), F.F, Q.q)
reshape(F::LazyPropagation, dims...) = LazyPropagation(x->reshape(x, dims...), F.F, F.q, F.val)
vec(F::LazyPropagation) = LazyPropagation(vec, F.F, F.q, F.val)
copyto!(x::AbstractArray, F::LazyPropagation) = copyto!(x, eval_prop(F))
dot(x::AbstractArray, F::LazyPropagation) = dot(x, eval_prop(F))
dot(F::LazyPropagation, x::AbstractArray) = dot(x, F)
norm(F::LazyPropagation, p::Real=2) = norm(eval_prop(F), p)
adjoint(F::JUDI.LazyPropagation) = F
length(F::JUDI.LazyPropagation) = size(F.F, 1)

############################ Two params rules ############################################
function rrule(F::judiPropagator{T, O}, m::AbstractArray{T}, q::AbstractArray{T}) where {T, O}
2 changes: 1 addition & 1 deletion test/test_rrules.jl
@@ -32,7 +32,7 @@ perturb(x::judiVector) = judiVector(x.geometry, [randx(x.data[i]) for i=1:x.nsrc]
reverse(x::judiVector) = judiVector(x.geometry, [x.data[i][end:-1:1, :] for i=1:x.nsrc])

misfit_objective_2p(d_obs, q0, m0, F) = .5f0*norm(F(m0, q0) - d_obs)^2
misfit_objective_1p(d_obs, q0, m0, F) = .5f0*norm(F(m0)*q0 - d_obs)^2
misfit_objective_1p(d_obs, q0, m0, F) = .5f0*norm(F(1f0*m0)*q0 - d_obs)^2  # 1f0*m0 exercises the new scalar * path during AD

function loss(misfit, d_obs, q0, m0, F)
local ϕ