Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast concatentation #283

Merged
merged 1 commit into from
Jul 20, 2023
Merged

Fast concatentation #283

merged 1 commit into from
Jul 20, 2023

Conversation

haakon-e
Copy link
Member

@haakon-e haakon-e commented May 5, 2023

closes #274

@haakon-e
Copy link
Member Author

haakon-e commented May 5, 2023

Looking at the actual cat-statements we have in the code, I realized good solutions actually are different depending on details about the vectors/matrices are being reduced. So I will test a bit for the different types to verify which patterns repeat themselves.

@haakon-e haakon-e force-pushed the new-cat branch 7 times, most recently from 5c95b94 to c97f0d2 Compare June 20, 2023 01:44
Copy link
Member Author

@haakon-e haakon-e left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@odunbar I have added explanations to all my changes as review comments. Feel free to reply directly to each one you may have questions about.

For every change I made, I tested that the refactored function is faster and allocates less than the baseline. In many cases, the speedup is very large. In other cases, it's minimal.
Note: I've only tested this using structs I found in the unit/integration tests, which are typically very small. So it's conceivable, although unlikely, that the improvement for "real world" examples (i.e. very large distributions) may be different.

src/ParameterDistributions.jl Show resolved Hide resolved
src/ParameterDistributions.jl Show resolved Hide resolved
src/ParameterDistributions.jl Show resolved Hide resolved
src/ParameterDistributions.jl Show resolved Hide resolved
src/ParameterDistributions.jl Show resolved Hide resolved
src/ParameterDistributions.jl Show resolved Hide resolved
src/ParameterDistributions.jl Show resolved Hide resolved
src/ParameterDistributions.jl Outdated Show resolved Hide resolved
src/ParameterDistributions.jl Outdated Show resolved Hide resolved
src/ParameterDistributions.jl Outdated Show resolved Hide resolved
Copy link
Collaborator

@odunbar odunbar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for this! It looks greatly improved. I do not think any of these changes will haunt us computationally compared with the use of cat(...).

I've left a couple of small comments regarding using the ndims for function distributions. An once resolved I'm happy!

src/ParameterDistributions.jl Show resolved Hide resolved
src/ParameterDistributions.jl Outdated Show resolved Hide resolved
src/ParameterDistributions.jl Outdated Show resolved Hide resolved
src/ParameterDistributions.jl Outdated Show resolved Hide resolved
@haakon-e
Copy link
Member Author

haakon-e commented Jul 5, 2023

@odunbar made the requested changes (I kept the relevant discussions "unresolved"). Let me know what you think. Once approved, I will squash before bors.

Copy link
Collaborator

@odunbar odunbar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @haakon-e these changes look good! I see a test failed though.

Could this be resolved - then I'm happy for merge.

@haakon-e
Copy link
Member Author

Hi @haakon-e these changes look good! I see a test failed though.

Could this be resolved - then I'm happy for merge.

@odunbar It appears the test only fails on v1.6.
I tried to go about reproducing locally, but I am for some reason not able to even instantiate EKP.jl on v1.6.7 (I get error ERROR: expected package `Distributions [31c24e10]` to be registered), so I'm at a loss for how start addressing this.

@haakon-e
Copy link
Member Author

haakon-e commented Jul 20, 2023

It turns out that on julia <= 1.8 the method mapslices, which we use to transform columns of a matrix, errors if the function that is applied to a column returns a n x 1 matrix instead of a n-vector. For example:

# julia v1.8.4
julia> xx = zeros(4,5);

julia> mapslices(x -> reshape(x, 4, 1), xx; dims=1)
ERROR: DimensionMismatch: tried to assign 2 elements to 1 destinations
Stacktrace:
 [1] throw_setindex_mismatch(X::Vector{Base.OneTo{Int64}}, I::Tuple{Int64})
   @ Base ./indices.jl:191
 [2] setindex_shape_check
   @ ./indices.jl:245 [inlined]
 [3] setindex!(A::Vector{Base.OneTo{Int64}}, X::Vector{Base.OneTo{Int64}}, I::Vector{Int64})
   @ Base ./array.jl:973
 [4] mapslices(f::var"#5#6", A::Matrix{Float64}; dims::Int64)
   @ Base ./abstractarray.jl:2873
 [5] top-level scope
   @ REPL[5]:1

In contrast, this works in julia v1.9:

# julia v1.9.2
julia> xx = zeros(4,5);

julia> mapslices(x -> reshape(x, 4, 1), xx; dims=1)
4×5 Matrix{Float64}:
 0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0

I have updated my implementation of transform_(un)constrained_to_(un)constrained to return a vector if a vector is given, and a matrix if a matrix is given.
This does two things:

  1. It fixes the aforementioned issue since mapslices passes the slices as vectors to the function
  2. this improves type inference on both julia v1.8 and julia v1.9. In 1.8, julia failed to infer the return type of the method (@code_warntype indicated Any). In 1.9, julia was only able to say Union{Matrix, Vector} for matrix or vector input types. With these changes, the compiler correctly infers the type of the output if the input is a vector or is a matrix.

All of this being said, as far as I can see the use of mapslices in the first place is not necessary. For example, the following calls are equivalent (see callsite in tests)

x_real_constrained1 = mapslices(x -> transform_unconstrained_to_constrained(u1, x), x_unbd[1:4, :]; dims = 1)
x_real_constrained1 == transform_unconstrained_to_constrained(u1, x_unbd[1:4, :])
true

but this could be changed in a different PR

replace various instances of "cat" + "..."
with faster (and often less allocating) alternatives
such as reduce(vcat, [iterable]).
Sometimes, the size and type of the output array can
be predicted. this information is now exploited.
For `transform_(un)cons_to_(un)cons`-methods, optimizations
meant Vectors and Matrices could be handled by the same function
without the need for multiple dispatch.
@haakon-e
Copy link
Member Author

bors r+

@bors
Copy link
Contributor

bors bot commented Jul 20, 2023

Build succeeded!

The publicly hosted instance of bors-ng is deprecated and will go away soon.

If you want to self-host your own instance, instructions are here.
For more help, visit the forum.

If you want to switch to GitHub's built-in merge queue, visit their help page.

@bors bors bot merged commit 87064e5 into main Jul 20, 2023
13 checks passed
@bors bors bot deleted the new-cat branch July 20, 2023 17:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

remove cat(X...) statements
2 participants