Skip to content

Commit

Permalink
Type conversions for remaining Datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
RafaelArutjunjan committed Jul 11, 2024
1 parent d9c6c7f commit daf0857
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/DataStructures/DataSetExact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ struct DataSetExact{XdistType<:Distribution, YdistType<:Distribution} <: Abstrac
end


function (::Type{T})(DSE::DataSetExact; kwargs...) where T<:Number
DataSetExact(ConvertDist(xdist(DSE),T), ConvertDist(ydist(DSE), T), dims(DSE), T.(yInvCov(DSE)),
(isnothing(DSE.WoundX) ? nothing : [SVector{xdim(DSE)}(Z) for Z in Windup(T.(xdata(DSE)), xdim(DSE))]);
xnames=xnames(DSE), ynames=ynames(DSE), name=name(DSE), kwargs...)
end


# For SciMLBase.remake
DataSetExact(;
xdist::Distribution=Normal(0,1),
Expand Down Expand Up @@ -140,12 +147,6 @@ ynames(DSE::DataSetExact) = DSE.ynames
name(DSE::DataSetExact) = DSE.name |> string


function Base.BigFloat(DSE::DataSetExact; kwargs...)
BigDist(D::MultivariateNormal) = MvNormal(BigFloat.(mean(D)), BigFloat.(cov(D)))
remake(DSE; xdist=BigDist(xdist(DSE)), ydist=BigDist(ydist(DSE)), InvCov=BigFloat.(yInvCov(DSE)),
WoundX=(isnothing(DSE.WoundX) ? nothing : Windup(BigFloat.(xdata(DSE)), xdim(DSE))), kwargs...)
end


# function InformNames(DS::DataSetExact, xnames::AbstractVector{String}, ynames::AbstractVector{String})
# @assert length(xnames) == xdim(DS) && length(ynames) == ydim(DS)
Expand Down
4 changes: 4 additions & 0 deletions src/DataStructures/DataSetUncertain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ struct DataSetUncertain <: AbstractUnknownUncertaintyDataSet
end
end

function (::Type{T})(DS::DataSetUncertain; kwargs...) where T<:Number
DataSetUncertain(T.(xdata(DS)), T.(ydata(DS)), dims(DS), yinverrormodel(DS), SplitErrorParams(DS), T.(DS.testp), xnames(DS), ynames(DS), name(DS))
end

# For SciMLBase.remake
DataSetUncertain(;
x::AbstractVector=[0.],
Expand Down
16 changes: 16 additions & 0 deletions src/DataStructures/DistributionTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,24 @@ Distributions.cov(d::InformationGeometry.Dirac) = Diagonal(zeros(length(d)))
Distributions.invcov(d::InformationGeometry.Dirac) = Diagonal([Inf for i in 1:length(d)])
Distributions.pdf(d::InformationGeometry.Dirac, x::AbstractVector{<:Number}) = x == mean(d) ? 1.0 : 0.0
Distributions.logpdf(d::InformationGeometry.Dirac, x::AbstractVector{<:Number}) = log(pdf(d, x))
Distributions.params(d::InformationGeometry.Dirac) = (d.μ,)


# Fix gradlogpdf for Cauchy distribution and product distributions in general
Distributions.gradlogpdf(P::Cauchy,x::Real) = gradlogpdf(TDist(1), (x - P.μ) / P.σ) / P.σ
Distributions.gradlogpdf(P::Product,x::AbstractVector) = [gradlogpdf(P.v[i],x[i]) for i in 1:length(x)]



# Get Symbol for everything before {} in type.
UnparametrizeType(D) = (S=string(typeof(D)); Symbol(S[1:findfirst('{',S)-1]))

## Change Number Type of distributions
ConvertDist(D::UnivariateDistribution, ::Type{T}) where T<:Number = eval(quote $(UnparametrizeType(D)){$T}($(T).($(params(D)))...) end)
ConvertDist(D::InformationGeometry.Dirac, ::Type{T}) where T<:Number = InformationGeometry.Dirac(T.(D.μ))
ConvertDist(D::MultivariateNormal, ::Type{T}) where T<:Number = MvNormal(T.(mean(D)), T.(cov(D)))

function ConvertDist(D::Union{Distributions.Product, InformationGeometry.GeneralProduct}, ::Type{T}) where T<:Number
@assert eltype(D.v) <: Distribution{Univariate, Continuous}
product_distribution(broadcast(x->ConvertDist(x,T), D.v))
end
6 changes: 6 additions & 0 deletions src/DataStructures/GeneralizedDataSet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ struct GeneralizedDataSet{DistType<:Distribution} <: AbstractFixedUncertaintyDat
end
end

function (::Type{T})(DS::GeneralizedDataSet; kwargs...) where T<:Number
GeneralizedDataSet(ConvertDist(dist(DS),T), dims(DS),
(isnothing(DS.WoundX) ? nothing : [SVector{xdim(DS)}(Z) for Z in Windup(T.(xdata(DS)), xdim(DS))]);
xnames=xnames(DS), ynames=ynames(DS), name=name(DS), kwargs...)
end

# For SciMLBase.remake
begin
GeneralizedDataSet(;
Expand Down

0 comments on commit daf0857

Please sign in to comment.