# unit.jl — HMMBase.jl test suite (331 lines, 270 loc).
# NOTE(review): the original scrape carried GitHub page chrome and a column of
# bare line numbers here; that residue was not part of the source file.
using Distributions
using JSON
using HMMBase
using LinearAlgebra
using Test
using Random
using HMMBase: from_dict, issquare
# Seed the global RNG once so the stochastic testsets below (rand, fit_mle, ...)
# are reproducible across runs.
Random.seed!(2019)
@testset "Base" begin
    # One univariate and one bivariate model, each with two states.
    uni = HMM([0.9 0.1; 0.1 0.9], [Normal(0, 1), Normal(10, 1)])
    multi = HMM(
        [0.9 0.1; 0.1 0.9],
        [MvNormal([0.0, 0.0], [1.0, 1.0]), MvNormal([10.0, 10.0], [1.0, 1.0])],
    )
    # size(hmm) == (number of states, observation dimension).
    @test size(uni) == (2, 1)
    @test size(multi) == (2, 2)
    # Both models report 3 free parameters.
    @test nparams(uni) == 3
    @test nparams(multi) == 3
end
@testset "Base (2)" begin
    uni = HMM([0.9 0.1; 0.1 0.9], [Normal(0, 1), Normal(10, 1)])
    multi = HMM(
        [0.9 0.1; 0.1 0.9],
        [MvNormal([0.0, 0.0], [1.0, 1.0]), MvNormal([10.0, 10.0], [1.0, 1.0])],
    )
    # The transition matrix is square and passes the transition-matrix check.
    @test issquare(uni.A)
    @test istransmat(uni.A)
    # Structural equality vs. object identity: a copy compares equal (==)
    # but is a distinct object (!==).
    @test uni != multi
    @test uni == copy(uni)
    @test uni !== copy(uni)
end
@testset "Base (3)" begin
    hmm = HMM([0.9 0.1; 0.1 0.9], [Normal(0, 1), Normal(10, 1)])
    # The identity permutation leaves the model unchanged.
    identity_perm = [1, 2]
    @test permute(hmm, identity_perm) == hmm
    # Swapping the two states permutes the emission vector and the
    # diagonal of the transition matrix accordingly.
    swap = [2, 1]
    swapped = permute(hmm, swap)
    @test swapped != hmm
    @test swapped.B == hmm.B[swap]
    @test diag(swapped.A) == diag(hmm.A)[swap]
end
@testset "Base (4)" begin
    uni = HMM([0.9 0.1; 0.1 0.9], [Normal(0, 1), Normal(10, 1)])
    bi = HMM(
        [0.9 0.1; 0.1 0.9],
        [MvNormal([0.0, 0.0], [1.0, 1.0]), MvNormal([10.0, 10.0], [1.0, 1.0])],
    )
    tri = HMM([0.9 0.1; 0.1 0.9], [MvNormal(ones(3)), MvNormal(ones(3))])

    # Univariate HMMs should return observation vectors
    # (consistent with Distributions.jl).
    states_u, obs_u = rand(uni, 1000, seq = true)
    obs_u2 = rand(uni, states_u)
    @test size(states_u) == (1000,)
    @test size(obs_u) == (1000,)
    @test size(obs_u2) == size(obs_u)

    # Multivariate HMMs should return a `TxK` matrix
    # (different from Distributions.jl, which returns `KxT`).
    states_b, obs_b = rand(bi, 1000, seq = true)
    obs_b2 = rand(bi, states_b)
    @test size(states_b) == (1000,)
    @test size(obs_b) == (1000, 2)
    @test size(obs_b2) == size(obs_b)

    # rand called with T < 1 should return empty arrays.
    states_e, obs_e = rand(bi, 0, seq = true)
    obs_e2 = rand(bi, states_e)
    @test size(states_e) == (0,)
    @test size(obs_e) == (0, 2)
    @test size(obs_e2) == size(obs_e)

    # A multivariate HMM should work with more observation dimensions
    # than states, see https://github.com/maxmouchet/HMMBase.jl/issues/12
    @test size(rand(tri, 1000)) == (1000, 3)
end
@testset "Base (5)" begin
    # The emission-matrix constructor must agree with building the same model
    # from an explicit vector of Categorical distributions.
    from_matrix = HMM([0.9 0.1; 0.1 0.9], [0.0 0.5 0.5; 0.25 0.25 0.5])
    from_vector = HMM(
        [0.9 0.1; 0.1 0.9],
        [Categorical([0.0, 0.5, 0.5]), Categorical([0.25, 0.25, 0.5])],
    )
    @test from_matrix == from_vector
end
@testset "Constructors" begin
    # Every invalid argument combination must raise an ArgumentError.

    # Rows of the transition matrix do not sum to one.
    @test_throws ArgumentError HMM(ones(2, 2), [Normal(); Normal()])

    # Non-square transition matrix.
    rect = [0.8 0.1 0.1; 0.1 0.1 0.8]
    @test_throws ArgumentError HMM(rect, [Normal(0, 1), Normal(10, 1)])

    # Number of emission distributions differs from the number of states.
    A = [0.8 0.2; 0.1 0.9]
    @test_throws ArgumentError HMM(A, [Normal(0, 1), Normal(10, 1), Normal()])

    # Emission distributions with inconsistent dimensions.
    @test_throws ArgumentError HMM(A, [MvNormal(randn(3)), MvNormal(randn(10))])

    # Initial state distribution does not sum to one.
    @test_throws ArgumentError HMM(
        [0.1; 0.1],
        [0.9 0.1; 0.1 0.9],
        [Normal(0, 1), Normal(10, 1)],
    )

    # Initial state distribution has the wrong length.
    @test_throws ArgumentError HMM(
        [0.1; 0.1; 0.8],
        [0.9 0.1; 0.1 0.9],
        [Normal(0, 1), Normal(10, 1)],
    )
end
# NOTE(review): testset renamed from "Stationnary" (typo) to "Stationary".
@testset "Stationary Distributions" begin
    # hmm1: strictly positive transition matrix -> a single stationary distribution.
    # hmm2: identity transition matrix (both states absorbing) -> two of them.
    # hmm3: 3-state chain with zeros on the diagonal -> a single one.
    hmm1 = HMM([0.9 0.1; 0.1 0.9], [Normal(0, 1), Normal(10, 1)])
    hmm2 = HMM([1.0 0.0; 0.0 1.0], [Normal(0, 1), Normal(10, 1)])
    hmm3 = HMM([0.0 0.8 0.2; 0.6 0.0 0.4; 0.0 1.0 0.0], [Normal(), Normal(), Normal()])
    dists1 = statdists(hmm1)
    dists2 = statdists(hmm2)
    dists3 = statdists(hmm3)
    @test length(dists1) == 1
    @test length(dists2) == 2
    @test length(dists3) == 1
    # A stationary distribution is a fixed point of the chain: it must match
    # the long-run behaviour obtained by powering the transition matrix.
    @test permutedims(dists2[1]) ≈ [1.0 0.0] * (hmm2.A^1000)
    @test permutedims(dists2[2]) ≈ [0.0 1.0] * (hmm2.A^1000)
    # Closed-form stationary distribution of hmm3.
    @test dists3[1] ≈ [15 / 53, 25 / 53, 13 / 53]
end
@testset "Messages (1)" begin
    # Worked example from
    # https://en.wikipedia.org/wiki/Forward%E2%80%93backward_algorithm
    trans = [0.7 0.3; 0.3 0.7]
    emis = [Categorical([0.9, 0.1]), Categorical([0.2, 0.8])]
    hmm = HMM(trans, emis)
    obs = [1, 1, 2, 1, 1]

    alpha, ll_forward = forward(hmm, obs)
    beta, ll_backward = backward(hmm, obs)
    gamma = posteriors(hmm, obs)

    # Round to 4 digits to compare against the published reference values.
    @test round.(alpha, digits = 4) == [
        0.8182 0.1818
        0.8834 0.1166
        0.1907 0.8093
        0.7308 0.2692
        0.8673 0.1327
    ]
    @test round.(beta, digits = 4) == [
        0.5923 0.4077
        0.3763 0.6237
        0.6533 0.3467
        0.6273 0.3727
        1.0 1.0
    ]
    @test round.(gamma, digits = 4) == [
        0.8673 0.1327
        0.8204 0.1796
        0.3075 0.6925
        0.8204 0.1796
        0.8673 0.1327
    ]
    # Forward and backward must agree on the total log-likelihood.
    @test ll_forward ≈ ll_backward
end
@testset "Messages (3)" begin
    # Degenerate emissions: a Normal with zero std. has likelihood 0 or +Inf
    # (-Inf/+Inf in the log domain), which makes the plain forward/backward
    # recursions return NaNs. Those cases are marked broken, since it is
    # unclear whether they can be fixed; `robust = true` is the workaround.
    hmm = HMM([0.9 0.1; 0.1 0.9], [Normal(0, 0), Normal(10, 0)])
    obs = rand(hmm, 1000)
    _, fwd_ll = forward(hmm, obs)
    _, bwd_ll = backward(hmm, obs)
    @test_broken !isnan(fwd_ll)
    @test_broken !isnan(bwd_ll)
    @test_nowarn viterbi(hmm, obs)
    # With `robust = true` the log-likelihoods stay finite and agree.
    _, fwd_ll_robust = forward(hmm, obs, robust = true)
    _, bwd_ll_robust = backward(hmm, obs, robust = true)
    @test !isnan(fwd_ll_robust)
    @test !isnan(bwd_ll_robust)
    @test fwd_ll_robust ≈ bwd_ll_robust
end
@testset "Viterbi" begin
    # With emissions 100 standard deviations apart, the MAP path returned by
    # viterbi must recover the true state sequence exactly.
    hmm = HMM([0.9 0.1; 0.1 0.9], [Normal(0, 1), Normal(100, 1)])
    states, obs = rand(hmm, 1000, seq = true)
    @test viterbi(hmm, obs) == states
end
@testset "MLE" begin
    hmm = HMM([0.9 0.1; 0.1 0.9], [Normal(0, 1), Normal(10, 1)])
    obs = rand(hmm, 1000)
    # The log-likelihood must not decrease across EM iterations
    # (rounded to absorb floating-point noise).
    _, history = fit_mle(hmm, obs)
    @test issorted(round.(history.logtots, digits = 9))
    _, history = fit_mle(hmm, obs, robust = true)
    @test issorted(round.(history.logtots, digits = 9))
    # maxiter = 0 performs no iteration and therefore cannot converge.
    _, history = fit_mle(hmm, obs, maxiter = 0)
    @test history.iterations == 0
    @test !history.converged
end
@testset "Utilities (1)" begin
    # Make sure that we do not relabel the states if they are in 1...K.
    # NOTE(review): the original asserted `mapping[k] == v` inside
    # `for (k, v) in mapping`, which is a tautology (always true when
    # iterating a mapping); the intended identity check is `k == v`.
    mapping, _ = gettransmat([3, 3, 1, 1, 2, 2], relabel = true)
    for (k, v) in mapping
        @test k == v
    end
    # Sparse labels {3, 8} are remapped onto 1..2, and the empirical
    # transition matrix is computed from the relabelled sequence.
    mapping, transmat = gettransmat([3, 3, 8, 8, 3, 3], relabel = true)
    @test mapping[3] == 1
    @test mapping[8] == 2
    @test transmat == [2 / 3 1 / 3; 1 / 2 1 / 2]
    # Random transition matrices must be square and row-stochastic.
    transmat = randtransmat(10)
    @test issquare(transmat)
    @test istransmat(transmat)
end
@testset "Utilities (2)" begin
    # remapseq relabels a state sequence to best match a reference sequence.
    reference = [1, 1, 2, 2, 3, 3]
    @test remapseq([2, 2, 3, 3, 1, 1], reference) == reference
    @test remapseq([1, 1, 1, 1, 2, 2], reference) == [1, 1, 1, 1, 3, 3]
end
@testset "Reproducibility" begin
    # Two identically-seeded RNGs must produce identical draws, while the
    # global RNG must produce different ones.
    hmm = HMM([0.9 0.1; 0.1 0.9], [Normal(0, 1), Normal(10, 1)])
    za, ya = rand(MersenneTwister(0), hmm, 1000, seq = true)
    zb, yb = rand(MersenneTwister(0), hmm, 1000, seq = true)
    zc, yc = rand(hmm, 1000, seq = true)
    @test za == zb != zc
    @test ya == yb != yc
    # Same contract for randtransmat with a Dirichlet prior ...
    Ad1 = randtransmat(MersenneTwister(0), Dirichlet(4, 1.0))
    Ad2 = randtransmat(MersenneTwister(0), Dirichlet(4, 1.0))
    Ad3 = randtransmat(Dirichlet(4, 1.0))
    @test Ad1 == Ad2 != Ad3
    # ... and with a plain size argument.
    As1 = randtransmat(MersenneTwister(0), 4)
    As2 = randtransmat(MersenneTwister(0), 4)
    As3 = randtransmat(4)
    @test As1 == As2 != As3
end
@testset "Experimental" begin
    # Round-trip a univariate HMM through JSON and from_dict(...).
    hmm1 = HMM([0.9 0.1; 0.1 0.9], [Normal(0, 1), Normal(10, 1)])
    d1 = JSON.parse(json(hmm1))
    @test from_dict(HMM{Univariate,Float64}, Normal, d1) == hmm1
    # /!\ The multivariate round-trip fails with Distributions <= v0.23.4,
    # and works otherwise — hence @test_broken.
    hmm2 = HMM(
        [0.9 0.1; 0.1 0.9],
        [MvNormal([0.0, 0.0], [1.0, 1.0]), MvNormal([10.0, 10.0], [1.0, 1.0])],
    )
    d2 = JSON.parse(json(hmm2))
    @test_broken from_dict(HMM{Multivariate,Float64}, MvNormal, d2) == hmm2
    # Mixture-model round-trip, kept disabled:
    # hmm3 = HMM(
    #     [0.9 0.1; 0.1 0.9],
    #     [
    #         MixtureModel([Normal(0, 1)]),
    #         MixtureModel([Normal(5, 2), Normal(10, 1)], [0.25, 0.75]),
    #     ],
    # )
    # d = JSON.parse(json(hmm3))
    # @test_broken from_dict(
    #     HMM{Univariate,Float64},
    #     MixtureModel{Univariate,Continuous,Normal,Float64},
    #     d,
    # ) == hmm3
    # MixtureModel <-> HMM (stationary distribution).
    weights = [0.4, 0.6]
    components = [Normal(0, 1), Exponential(2)]
    mix = MixtureModel(components, weights)
    @test MixtureModel(HMM(mix)).prior == mix.prior
    @test MixtureModel(HMM(mix)).components == mix.components
    # TODO: Assert error if #distns != 1
end