Inserting a known derivative: function needs to exclude dual numbers #179

kimauth opened this issue Mar 18, 2022 · 1 comment
kimauth (Member) commented Mar 18, 2022

I ran into the following problem today:

using Tensors

function tensor_exp(A::SymmetricTensor{2})
    E = eigen(A)
    A_exp = zero(A)
    for (i, λ) in enumerate(E.values)
        N = E.vectors[:,i]
        A_exp += exp(λ) * N ⊗ N
    end
    return A_exp
end

# dummy gradient function for the MWE; a real implementation would return
# the actual derivative in the tuple
tensor_exp_gradient(A::AbstractTensor{2}) = tensor_exp(A), tensor_exp(A)

@implement_gradient tensor_exp tensor_exp_gradient
julia> A = rand(SymmetricTensor{2,3}) + one(SymmetricTensor{2,3})
julia> gradient(tensor_exp, A)

ERROR: MethodError: no method matching precision(::Type{ForwardDiff.Dual{ForwardDiff.Tag{typeof(tensor_exp), SymmetricTensor{2, 3, Float64, 6}}, Float64, 6}})
Closest candidates are:
  precision(::Type{Float16}) at C:\Users\auth\AppData\Local\Programs\Julia-1.7.2\share\julia\base\float.jl:686
  precision(::Type{Float32}) at C:\Users\auth\AppData\Local\Programs\Julia-1.7.2\share\julia\base\float.jl:687
  precision(::Type{Float64}) at C:\Users\auth\AppData\Local\Programs\Julia-1.7.2\share\julia\base\float.jl:688
  ...
Stacktrace:
 [1] eigen(R::SymmetricTensor{2, 3, ForwardDiff.Dual{ForwardDiff.Tag{typeof(tensor_exp), SymmetricTensor{2, 3, Float64, 6}}, Float64, 6}, 6})
   @ Tensors C:\Users\auth\.julia\packages\Tensors\fjqpn\src\eigen.jl:137
 [2] tensor_exp(A::SymmetricTensor{2, 3, ForwardDiff.Dual{ForwardDiff.Tag{typeof(tensor_exp), SymmetricTensor{2, 3, Float64, 6}}, Float64, 6}, 6})
   @ Main c:\Users\auth\OneDrive - Chalmers\Documents\courses\Phd courses\Computational nonlinear mechanics\computer assignments\cass3\tensor_log_exp.jl:29
 [3] gradient(f::typeof(tensor_exp), v::SymmetricTensor{2, 3, Float64, 6})
   @ Tensors C:\Users\auth\.julia\packages\Tensors\fjqpn\src\automatic_differentiation.jl:455
 [4] top-level scope
   @ REPL[6]:1

The problem here is that the dual-valued call still reaches the original tensor_exp (instead of the method that @implement_gradient generates, which would call tensor_exp_gradient).
Looking at the methods of tensor_exp, we can see why:

julia> methods(tensor_exp)
# 2 methods for generic function "tensor_exp":
[1] tensor_exp(A::SymmetricTensor{2}) in Main at c:\Users\auth\OneDrive - Chalmers\Documents\courses\Phd courses\Computational nonlinear mechanics\computer assignments\cass3\tensor_log_exp.jl:28
[2] tensor_exp(x::Union{ForwardDiff.Dual, AbstractTensor{<:Any, <:Any, <:ForwardDiff.Dual}}) in Main at C:\Users\auth\.julia\packages\Tensors\fjqpn\src\automatic_differentiation.jl:253

The original tensor_exp is more specific than the method @implement_gradient defines for dual numbers, so dispatch always picks the original. A solution could be to exclude dual numbers from the original function altogether, e.g. by

function tensor_exp(A::SymmetricTensor{2,dim,Float64}) where {dim}

(This is of course not so nice if one doesn't own the function.)
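For instance, a variant of the signature that stays generic over float element types while still excluding dual numbers could look like the following (a sketch; it relies on ForwardDiff.Dual subtyping Real but not AbstractFloat):

function tensor_exp(A::SymmetricTensor{2,dim,T}) where {dim, T<:AbstractFloat}
    # T<:AbstractFloat excludes ForwardDiff.Dual, so dual-valued calls
    # fall through to the method generated by @implement_gradient
    E = eigen(A)
    A_exp = zero(A)
    for (i, λ) in enumerate(E.values)
        N = E.vectors[:,i]
        A_exp += exp(λ) * N ⊗ N
    end
    return A_exp
end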

Perhaps there is a better solution than restricting the number type of the Tensor. If there isn't, we should probably add a hint about this to the docs.

KnutAM (Member) commented Mar 19, 2022

One solution is to document and export the _propagate_gradient function used by @implement_gradient. It is not as sleek as the macro, but it lets the user solve such a problem manually, i.e.

tensor_exp(A::SymmetricTensor{2,dim,<:ForwardDiff.Dual}) where{dim} = Tensors._propagate_gradient(tensor_exp_gradient, A)

For this to work, the output of tensor_exp must be a symmetric tensor (N ⊗ N yields a general Tensor{2}, so the accumulated A_exp loses its SymmetricTensor type), i.e.

function tensor_exp(A::SymmetricTensor{2})
    E = eigen(A)
    A_exp = zero(A)
    for (i, λ) in enumerate(E.values)
        N = E.vectors[:,i]
        A_exp += exp(λ) * otimes(N)  # otimes(N) builds N ⊗ N as a SymmetricTensor
    end
    return A_exp
end
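Putting the pieces together, a minimal sketch of this manual route (it uses the internal, currently unexported Tensors._propagate_gradient as shown above, the dummy gradient function from the original post, and assumes the fixed tensor_exp definition just above):

using Tensors, ForwardDiff

# dummy stand-in from the MWE; a real implementation would also return
# the actual derivative
tensor_exp_gradient(A::AbstractTensor{2}) = tensor_exp(A), tensor_exp(A)

# manual override: dual-valued inputs use the supplied gradient function
# instead of differentiating through eigen
tensor_exp(A::SymmetricTensor{2,dim,<:ForwardDiff.Dual}) where {dim} =
    Tensors._propagate_gradient(tensor_exp_gradient, A)

A = rand(SymmetricTensor{2,3}) + one(SymmetricTensor{2,3})
gradient(tensor_exp, A)  # dual inputs now dispatch to the override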

KnutAM added a commit to KnutAM/Tensors.jl that referenced this issue Mar 26, 2022