# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     notebook_metadata_filter: -jupytext.text_representation.jupytext_version
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#   kernelspec:
#     display_name: Python (sinn-full)
#     language: python
#     name: sinn-full
# ---
# %% [markdown]
# # Optimize workflow template
#
# Executing this notebook generates **1** _Task_, which can then either be executed or saved as a _task file_. A _task file_ is a JSON file containing all the information required to run a sequence of _Tasks_: the name of each _Task_ and its input parameters.
#
# :::{tip}
# To generate a list of tasks, use the [*Generate tasks* notebook](./Generate%20tasks.py). It will execute this notebook, allowing all parameters in the *parameters cell* to be changed, and save the result as a _task file_.
# :::
# %% [markdown]
# :::{figure-md} optimize-wf-flowchart
# <img src="optimize-wf-flowchart.svg" title="Flowchart – Optimize Task workflow">
#
# **Flowchart – Optimize Task workflow** Hexagonal nodes indicate steps executed as [Tasks](../tasks/index) during workflow _execution_. Elliptical nodes are implemented within the workflow itself and executed during its _creation_.
# <a href="https://mermaid.ink/img/eyJjb2RlIjoiZmxvd2NoYXJ0IFREXG4gICAgc3ViZ3JhcGggd29ya2Zsb3dbT3B0aW1pemUgVGFzayBXb3JrZmxvd11cbiAgICBkaXNrZGF0YVtvbi1kaXNrIGRhdGFdXG4gICAgc3ludGhkYXRhW3N5bnRoZXRpYyBkYXRhXVxuICAgIGNkYXRhe3tjcmVhdGUgZGF0YSBhY2Nlc3Nvcn19XG4gICAgY3NhbXBsZXJ7e2NyZWF0ZSBkYXRhXFxuc2VnbWVudCBzYW1wbGVyfX1cbiAgICBjbW9kZWx7e2NyZWF0ZSBtb2RlbH19XG4gICAgY3ByaW9yKFtjcmVhdGUgcHJpb3JdKVxuICAgIGNpbml0KFtjaG9vc2UgaW5pdGlhbCBwYXJhbXNdKVxuICAgIGNvYmooW2Nob29zZSBvYmplY3RpdmVdKVxuICAgIGNoeXBlcihbY2hvb3NlIGh5cGVycGFyYW10ZXJzXSlcbiAgICBjb3B0aW1pemVye3tjcmVhdGUgb3B0aW1pemVyfX1cbiAgICBjcmVjW2NyZWF0ZSByZWNvcmRlcnNdXG4gICAgY3Rlc3QoW2NyZWF0ZSBjb252ZXJnZW5jZSB0ZXN0c10pXG4gICAgY29wdGltaXple3tvcHRpbWl6ZSBtb2RlbH19XG4gICAgZGlza2RhdGEgLS4tPiBjZGF0YVxuICAgIHN5bnRoZGF0YSAtLi0-IGNkYXRhXG4gICAgY3ByaW9yIC0uLT4gc3ludGhkYXRhXG4gICAgY21vZGVsIC0uLT4gc3ludGhkYXRhXG4gICAgY3ByaW9yIC0tPiBjb3B0aW1pemVyXG4gICAgY2h5cGVyIC0uLT52aHlwZXIoW3ZhbGlkYXRlIGh5cGVycGFyYW1ldGVyc10pXG4gICAgY2RhdGEgLS0-IGNzYW1wbGVyXG4gICAgY3NhbXBsZXIgJiBjaW5pdCAmIGNtb2RlbCAmIGNvYmogJiB2aHlwZXIgLS0-IGNvcHRpbWl6ZXJcbiAgICBjb3B0aW1pemVyICYgY3JlYyAmIGN0ZXN0IC0tPiBjb3B0aW1pemVcbiAgICBlbmRcblxuICAgIHN0eWxlIGNvcHRpbWl6ZSBmaWxsOiNjZGU0ZmYsIHN0cm9rZTojMTQ3ZWZmLCBzdHJva2Utd2lkdGg6MnB4XG4gICAgc3R5bGUgd29ya2Zsb3cgZm9udC13ZWlnaHQ6Ym9sZFxuIiwibWVybWFpZCI6e30sInVwZGF0ZUVkaXRvciI6ZmFsc2V9">
# (edit)
# </a>
# :::
# %% tags=["remove-cell"]
import sinnfull
sinnfull.setup('theano')
import sinn
#sinn.config.trust_all_inputs = True # Allow deserialization of arbitrary function
# %% tags=["remove-cell"]
import logging
logging.basicConfig(level="INFO")
logger = logging.getLogger("sinnfull.optimize_template")
logger.setLevel(logging.DEBUG)
# %% tags=["remove-cell"]
import functools
from typing import List
import numpy as np
import pymc3 as pm
import pint
import smttask
import theano_shim as shim
from mackelab_toolbox.optimizers import Adam
# %% tags=["remove-cell"]
import warnings
with warnings.catch_warnings():
    warnings.simplefilter('ignore', UserWarning)
    smttask.config.record = False
# %% tags=["hide-input"]
from sinnfull.models import objectives, priors, Prior
import sinnfull.models
import sinnfull.optim
from sinnfull import ureg
from sinnfull.parameters import ParameterSet
from sinnfull.models import (TimeAxis, models, ObjectiveFunction,
                             get_objectives, get_prior, get_model_class)
#from sinnfull.data.synthetic import SyntheticDataAccessor
#from sinnfull.sampling import sample_baseline_segment
from sinnfull.tasks import (CreateSyntheticDataset, CreateOptimizer, CreateModel,
                            CreateFixedSegmentSampler, OptimizeModel)
from sinnfull.rng import get_np_rng, get_shim_rng, draw_model_sample
from sinnfull.optim import AlternatedSGD, Recorder
from sinnfull.optim.convergence_tests import ConstantCost, ConstantParams, DivergingCost
from sinnfull.utils import recursive_dict_update, get_scipy_dist
from sinnfull import projectdir
# %%
from pydantic import BaseModel
# %% [markdown]
# ## Workflow parameters
# %% [markdown]
# Modifiable parameters are set in the following cell. Because we execute the notebook with Papermill, there can be only one parameters cell (identified by giving it the tag “parameters”).
#
# - To fit parameters with ground-truth latents: \
# Move the entries in `latent_hists` to `observed_hists`.
# - The optimizer is currently hard-coded, but one could add a flag parameter to allow choosing between different optimizers.
# - Models are specified with _selectors_: sets of tags that [filter](tags-taming-model-proliferation) the model collection down to exactly one model.
# %% tags=["parameters"]
# This cell is tagged 'parameters' for Papermill
reason = None
Θ_init_key = 'ground truth'
nsteps = 5000
#fit_hyperθ_updates = {'params': {'b1': 1., 'b2': 1.}}
fit_hyperθ_updates = {}
task_save_location = 'tasklist'
step_kwargs = {}
model_rngkey = (1,)
optimizer_rngkey = (2,0)
param_rngkey = 3  # Base key: keys are generated as (param_rngkey, i)
sim_rngkey = 4    # Base key: keys are generated as (sim_rngkey, i)
sampler_rngkey = 5
init_discard = 1 * ureg.s # Amount of time to discard from the beginning of the data
# Values are tag selectors; selectors are always sets of strings.
# model_selector may be either a single selector (i.e. a set of tags)
# or a dictionary with entries matching submodels (+ __root__ and __connect__)
# If there are submodels, the __connect__ entry indicates how their histories
# relate. It should be a list of strings.
model_selector = {'__root__'    : {'DeterministicDynamics'},
                  'input'       : {'GaussianWhiteNoise'},
                  'dynamics'    : {'WilsonCowan'},
                  'observations': {'GaussObs'},
                  '__connect__' : ['GaussianWhiteNoise.ξ -> WilsonCowan.I',
                                   'WilsonCowan.u -> GaussObs.u']}
observed_hists=['observations.ubar'] # Use dotted names to denote histories in submodels
latent_hists =['input.ξ']
from sinnfull.optim import paramsets as optim_paramsets
default_hyperparams = optim_paramsets['WC'].default # Defined in [projectdir]/sinnfull/optim/paramsets
# Values are tag selectors.
# Selectors are always sets of strings;
# they can be within dicts (to apply them to submodels)
# or lists (to indicate multiple objectives).
# All objectives are ultimately summed together
# (Future: we may add notation for coefficients multiplying objectives)
default_objective = None
params_objective = {'input': {'GaussianWhiteNoise', 'log L'},
                    'observations': {'GaussObs', 'log L'}}
latents_objective = {'observations': {'GaussObs', 'log L'}}
prior_spec = ParameterSet(
    {'input': {'selector': {'GWN', 'default'},
               'kwds': dict(mu_mean=[-0.25, -0.5],
                            logsigma_mean=[-1., -1.],
                            M=2)},
     'dynamics': {'selector': {'WC', 'rich'}, 'kwds': dict(M=2)},
     'observations': {'selector': {'GaussObs', 'independent'},
                      'kwds': dict(logvar_mean=-4.,
                                   logvar_std=0.,
                                   C=2,
                                   M=2)}
    })
# NB: Different priors may have different parameters
# Draw data parameters from tighter distributions than the priors
# This reflects the practice of making the priors broader to avoid unduly
# influencing the fit (Gaussian priors tend to start shaping the posterior
# even for parameter values only moderately away from their mean).
synth_param_spec = prior_spec.copy() # Requires sinnfull.ParameterSet
synth_param_spec.update({'input.kwds.mu_std': 1.,        # Tip: Use dotted notation to avoid
                         'input.kwds.logsigma_std': 0.5, #      quashing other params
                         'dynamics.kwds.scale': 0.25,
                         'observations.kwds.logvar_std': 0.25})
exec_environment = "module" # Changed to 'papermill' by sinnfull.utils.generate_task_from_nb
# %% [markdown]
# **Conversion of papermill string arguments to objects**
# To avoid dealing with serialization, the papermill arguments which should be Python objects are instead passed as strings. But this means we have to decode them.
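#
# For instance (standard-library behaviour, shown purely for illustration):
#
# ```python
# from ast import literal_eval
# literal_eval("{'input': {'GaussianWhiteNoise'}}")
# # -> {'input': {'GaussianWhiteNoise'}}   (a dict containing a set)
# ```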
# %%
from ast import literal_eval
if Θ_init_key[0] == '(':
    Θ_init_key = literal_eval(Θ_init_key)
g = globals()
# Standard “just eval the string”
for param in ['default_hyperparams', 'fit_hyperθ_updates',
              'synth_param_spec', 'prior_spec',
              'model_selector',
              'default_objective', 'params_objective', 'latents_objective',
              'model_rngkey', 'optimizer_rngkey',
              'init_discard']:
    pval = g[param]
    if isinstance(pval, str):
        pval = literal_eval(pval)
    elif isinstance(pval, dict):
        for k, v in pval.items():
            if isinstance(v, str):
                pval[k] = literal_eval(v)
    g[param] = pval
# JSON converts sets into lists. Convert them back to sets/tuples.
for param in ['model_selector',
              'default_objective', 'params_objective', 'latents_objective',
              'synth_param_spec', 'prior_spec']:
    pval = g[param]
    if isinstance(pval, dict):
        pval = ParameterSet(pval)
        for k, v in pval.flat():
            if isinstance(v, str):
                v = literal_eval(v)
            if 'kwds' not in k and isinstance(v, list):
                v = tuple(v)
            pval[k] = v
    g[param] = pval
# %%
# These vars need to be fully deserialized because we don’t just use them as arguments
from mackelab_toolbox.typing import PintValue
init_discard = PintValue.json_decoder(init_discard, noerror=True)
# %% [markdown]
# An (experimental) alternative to using papermill to execute the notebook is to use the function
#
# ```python
# sinnfull.utils.run_as_script('sinnfull.workflows.Optimize_task_template',
# param1=value1, ...)
# ```
#
# This places parameter values in a global dictionary (`sinnfull.utils.script_args`), which the code below retrieves.
# %%
if __name__ != "__main__":
# Running within an import
# - if run through `utils.run_as_script`, there will be parameters in `utils.script_args`
# which should replace the current values
from sinnfull.utils import script_args
if __name__ in script_args:
g = globals()
# One of the script_args will set exec_environment = "script"
for k, v in script_args[__name__].items():
if k in g:
g[k] = v
# %% [markdown]
# Retrieve the model.
# %%
ModelClass = get_model_class(model_selector)
# %% [markdown]
# Retrieve all the objectives and sum them.
# %%
if default_objective is not None:
    default_objective = sum(get_objectives(default_objective))
if params_objective is not None:
    params_objective = sum(get_objectives(params_objective))
if latents_objective is not None:
    latents_objective = sum(get_objectives(latents_objective))
# %% [markdown]
# Retrieve the prior on the parameters. This is added to the objectives by the optimizer.
# %%
prior = get_prior(ModelClass, prior_spec)
# %% [markdown]
# Retrieve the parameter distribution used to generate synthetic data. This can be the same as the prior, but if the prior is broad, it can be a good idea to sample the data parameters from tighter distributions.
# %%
synth_param_dist = get_prior(ModelClass, synth_param_spec)
# %%
if __name__ == "__main__" and exec_environment == "module":
exec_environment = "notebook"
# Possible values of exec_environment: "notebook", "module", "papermill"
# %% tags=["remove-input"]
# Imports only used in the notebook
if exec_environment == "notebook":
from IPython.display import display
import textwrap
import holoviews as hv
hv.extension('bokeh')
# %% [markdown]
# ## Load hyperparameters
# Start by loading the file containing defaults, then replace those values given in `fit_hyperθ_updates`.
# %%
fit_hyperparams = ParameterSet(default_hyperparams,
                               basepath=sinnfull.optim.paramsets.basepath)
recursive_dict_update(fit_hyperparams, fit_hyperθ_updates, allow_new_keys=True) # Notebook parameter
fit_units = fit_hyperparams.units
fit_hyperparams.remove_units(fit_units)
# %%
T = fit_hyperparams.T * fit_units['[time]']
Δt = fit_hyperparams.Δt * fit_units['[time]']
time = TimeAxis(min=init_discard, max=init_discard+T, step=Δt, unit=fit_units['[time]'])
# %% [markdown]
# ## Define the synthetic data
# %% [markdown]
# The synthetic data accessor replaces actual data with model simulations. Thus it needs a _model_, as well as a _prior_ on the model's parameters. The prior is sampled to generate parameter sets. Here we use the same prior as for the inference, although this need not be the case.
#
# There should be as many `param_keys` as `sim_keys`. One may repeat `param_keys` (with different `sim_keys`), to simulate multiple trials with the same parameters. Duplicate `sim_keys` (with different `param_keys`) are also permitted, although a good use case is unclear.
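#
# For example, to simulate two trials with the same parameter set, one could pass (a sketch; the key values are illustrative):
#
# ```python
# param_keys=[(param_rngkey,0), (param_rngkey,0)],
# sim_keys  =[(sim_rngkey,0), (sim_rngkey,1)]
# ```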
#
# Note that `CreateSyntheticDataset` expects a model name rather than an instance; this is to avoid having to create throwaway variables just to instantiate a model. The model class is retrieved in the same way as `ModelClass` above.
# %%
synthdata_model = CreateModel(
    time           = time,
    model_selector = model_selector,
    params         = None,           # Just pick anything to instantiate the model (Uses model.get_test_parameters().)
    rng_key        = (sim_rngkey,0)  # Just use the first key; reseeds when sampling dataset anyway
)
# %%
## Instantiate data accessor ##
data = CreateSyntheticDataset(
    projectdir=projectdir,
    model     =synthdata_model,
    #model_name=model_name,
    #time      =time,
    prior     =synth_param_dist,
    param_keys=[(param_rngkey,i) for i in range(1)],
    sim_keys  =[(sim_rngkey,i) for i in range(1)],
    init_conds={'dynamics.u': 0}
)
# %% [markdown]
# The `SegmentSampler` creates an infinite iterator which provides a new segment on every call.
#
# - The `trial_filter` argument is passed to `data.sel(...)` and allows restricting segments to certain trials (see the sketch below). It can also be used to restrict time windows, but for this purpose `t0` and `T` are more convenient.
# - `t0` and `T` are used to select a fixed window to sample from.
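#
# For instance, assuming the dataset exposes a `trial` coordinate (the coordinate name here is hypothetical), sampling could be restricted to the first trial with:
#
# ```python
# segment_iterator = CreateFixedSegmentSampler(
#     data=data, trial_filter={'trial': 0},
#     t0=init_discard, T=fit_hyperparams.T*ureg.s,
#     rng_key=(sampler_rngkey,))
# ```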
# %%
segment_iterator = CreateFixedSegmentSampler(
    data=data,
    trial_filter={},
    t0=init_discard, T=fit_hyperparams.T*ureg.s,
    rng_key=(sampler_rngkey,)
)
# %% [markdown]
# It's a very good idea to have a look at the synthetic data before committing CPU time to fitting it. Once you are confident the generated data is as expected, replace `True` by `False` to avoid unnecessary plotting.
# %% [markdown] tags=["remove-cell"]
# hv.renderer('bokeh').theme = 'dark_minimal'
# %%
if True and exec_environment == "notebook":
seg_iter = segment_iterator.run()
sampled_segment = next(iter(seg_iter))[2]
from typing import Union
from sinn.models import IndexableNamespace
def get_param_div_str(params: Union[IndexableNamespace,dict]):
# Currently the formatting & line breaks are ignored, but according to the docs they should
# So leaving this as-is, and when holoviews fixes their bug it'll look nicer (https://github.com/holoviz/holoviews/issues/4743)
lines = []
if isinstance(params, dict):
params = params.items()
for name, value in params:
lines.append(f'<b>{name}</b><br>')
if isinstance(value, (IndexableNamespace, dict)):
lines.append(textwrap.indent(get_param_div_str(value),
" "))
elif isinstance(value, np.ndarray):
lines.append(textwrap.indent(np.array_repr(value, precision=4, suppress_small=True),
" ")
.replace("\n","<br>\n")+"<br><br>")
else:
lines.append(textwrap.indent(str(value), " ")
.replace("\n","<br>\n")+"<br><br>")
return "\n".join(lines)
param_vals = hv.Div(get_param_div_str(seg_iter.data.trials.trial.data[0].params))
curves = [hv.Curve(data_var.data[:,i],
kdims=list(data_var.coords), vdims=[f"{data_var_name}_{i}"])
for data_var_name, data_var in sampled_segment.data_vars.items()
for i in range(data_var.shape[1])]
data_panel = hv.Layout(curves).opts(hv.opts.Curve(height=150, width=200))
# FIXME: It should be possible to make a nested layout, so params appear
# to the side, but I can't figure out how
display((data_panel.cols(2) + param_vals).cols(2).opts(title="First training sample"))
## Alternative (less compact but currently more readable) param value list
# print(get_param_div_str(seg_iter.data.trials.trial.data[0].params)
# .replace(" ", " ")
# .replace("<b>", "").replace("</b>", "")
# .replace("<br>", "").replace("</br>", ""))
# %%
if True and exec_environment == "notebook":
panels = []
θvalset = synth_param_dist.random((param_rngkey,0), space='optim')
# Remove parameters that were sampled but are not actually optimized
prior_var_names = [θnm for θnm in θvalset if θnm in prior.optim_vars]
non_optim_vars = {θnm: synth_param_dist[θnm] for θnm in θvalset if θnm not in prior_var_names}
if non_optim_vars:
# For parameters that are sampled but not optimized, create a table comparing their sampled and true values
model_θvalset = synth_param_dist.random((param_rngkey,0), space='model')
non_optim_vals = {}
for θnm, optim_θ in non_optim_vars.items():
for nm, model_var in synth_param_dist.model_vars.items():
if optim_θ in shim.graph.symbolic_inputs(model_var):
break
else:
nm = model_var = None
if nm:
# Assumption: if we are not fitting θ, it is because it is fixed / deterministic
non_optim_vals[nm] = {'true value': model_θvalset[nm],
'prior value': prior[nm].eval()}
for θnm in prior_var_names:
if isinstance(prior[θnm], pm.model.DeterministicWrapper):
logger.info(f"Skipping {θnm}: not a random variable.")
continue
θvals = θvalset[θnm]
shape = tuple(θvals.shape)
for idx in np.ndindex(shape):
idx_pretty = str(idx).replace('(','').replace(')','').replace(',','')
θpretty = f"{θnm.split('.')[-1]}{idx_pretty}"
θval = θvals[idx]
# Get D for a specific index, in case domain depends on idx
# (which happens for ± connectivities)
D = get_scipy_dist(prior[θnm], idx=idx)
low, high = D.a, D.b # Scipy.stats stores domain bounds as `a`, `b` attributes
# If domain is unbounded, ignore 5% on each end
if low < -1e8: # Use numerical bounds, because in some cases inf is replaced with e.g 1e12
low = D.ppf(.05)
if high > 1e8:
high = D.ppf(.95)
# Ensure domain includes the actual sampled parameter
if θval < low:
low = θval - .05 * (high-θval)
if θval > high:
high = θval + .05 * (θval-low)
# Compute pdf
xarr = np.linspace(low, high, 100)
curve = hv.Curve(
zip(xarr, D.pdf(xarr)),
kdims=[θpretty], vdims=[f"p({θpretty})"], label="pdf")
θline = hv.VLine(θval, kdims=curve.kdims+curve.vdims,
label="sample")
# Append to list of plot panels
panel = (curve * θline).opts(framewise=True).relabel(θpretty)
panel.opts(legend_position='top')
panels.append(panel)
fig = hv.Layout(panels).cols(3) \
.opts(framewise=True, title="Data sample wrt prior") \
.opts(hv.opts.Curve(framewise=True, width=200, height=200),
hv.opts.VLine(color='orange'))
display(fig)
if non_optim_vars:
import pandas as pd
df = pd.DataFrame(non_optim_vals).T
print("Non-optimized parameters:")
display(df)
# %% [markdown]
# ## Set the model parameters
#
# Valid options:
# - integer tuple: Used as a key to sample initialization parameters from `prior`.
# - `'ground truth'`: Start the fit from the ground truth parameters.
# - set of tags: Select an initialization parameter set from `sinnfull.models.paramsets`.
# - file name: Load parameters from the provided file.
# %%
import os

valid_options = ['ground truth', 'file', 'test']
if isinstance(Θ_init_key, tuple):
    Θ_init = prior.random(Θ_init_key, space='optim')
    modelΘ_init = prior.backward_transform_params(Θ_init)
elif Θ_init_key == 'ground truth':
    Θ_init = synth_param_dist.random((param_rngkey,0), space='optim')
    modelΘ_init = synth_param_dist.backward_transform_params(Θ_init)
elif isinstance(Θ_init_key, set):
    # Parameter set selector
    from sinnfull.models import paramsets
    modelΘ_init = paramsets[Θ_init_key]
elif os.path.exists(Θ_init_key):
    # Untested
    modelΘ_init = ParameterSet(Θ_init_key)
else:
    raise ValueError(f"Unrecognized value for `Θ_init_key`: {Θ_init_key} (type: {type(Θ_init_key)})\n"
                     f"It should be either an RNG key (int tuple), or one of {valid_options}")
# %% [markdown]
# :::{note}
# :class: dropdown
# Setting the initialization with ground truth works because it is equivalent to
#
# - drawing values directly in the model space;
# - or drawing them in the optim space and backtransforming them to the model
#
# (in both cases, PyMC3 draws the optim variables and transforms them). One can check this with the following:
#
# ```python
# model_θ = prior.random((0,))
# backtransformed_θ = prior.backward_transform_params(
#     prior.random((0,), space='optim'))
# assert all(np.all(np.isclose(model_θ[nm], backtransformed_θ[nm]))
#            for nm in model_θ)
# ```
# :::
# %% [markdown]
# Store numeric values, so
# - future changes don't change stored initial values;
# - it doesn't matter _how_ we generated the initialization parameters, just what their values are.
# %%
if hasattr(modelΘ_init, 'get_values'):
    modelΘ_init = modelΘ_init.get_values()
# %% [markdown]
# It can happen that we want to fit only a subset of parameters; in that case, `prior` will have a `Deterministic` wrapping a `Constant` variable for the fixed values. We want the initial parameters to be equal to those fixed values, *not* what we sampled above, even if we initialize with ground truth.
# %%
for nm, v in prior.model_vars.items():
    if (nm in modelΘ_init
        and (isinstance(getattr(v, 'distribution', None), pm.Constant)
             or (isinstance(v, pm.model.DeterministicWrapper)
                 and set(pm.model._walk_up_rv(v)) == {'Constant'}))):
        target_v = modelΘ_init[nm]
        assert target_v.dtype == v.dtype
        modelΘ_init[nm] = v.tag.test_value
optim_var_names = set(prior.optim_vars.keys())
Θ_init = {k: v for k, v in Θ_init.items() if k in optim_var_names}
if optim_var_names != set(Θ_init):
    raise ValueError(
        "The set of initial parameter values (in optimization space) does not "
        "match the set of parameters to optimize as defined by the prior.\n"
        f"Prior parameters: {sorted(optim_var_names)}\n"
        f"Init parameters: {sorted(Θ_init)}\n"
        "When initializing with ground truth, the non-constant distributions "
        "of the prior must match those of the ground truth model. (E.g. a "
        "LogNormal prior can't be used for a parameter sampled from a Normal.)"
    )
# %% [markdown]
# ## Instantiate the model
#
# Having chosen initial parameters `Θ_init`, instantiate the model. The `rng_key` seeds the RNG used to integrate the model.
# %%
model = CreateModel(time           = time,
                    model_selector = model_selector,
                    params         = modelΘ_init,
                    rng_key        = model_rngkey)
                    #submodel_selectors= submodel_selectors,
                    #connect = submodel_connections)
# %% [markdown]
# ## Define the optimizer
# %% [markdown]
# ### Objective functions
#
# Objective functions are model-specific and defined in [*sinnfull.models*](../models). They can be retrieved by name from the `objectives` dictionary.
#
# The optimizer has multiple keyword arguments for objective functions, to allow specifying different objectives for the parameters vs the latents, or for the edges of the data segment.
#
# - `logp` : Default objective for both parameters and latents
# - `logp_params` : Default objective for the parameters
# - `logp_latents` : Default objective for the latents
# - `prior_params` : A PyMC3 model used as prior. Since they are mathematically equivalent, this can also be used to provide a regularizer.
# - `prior_latents` : Currently not used.
#
# It suffices to specify `logp` to get a valid optimizer; alternatively, `logp_params` and `logp_latents` may be specified together. In case of redundantly specified objectives, the more specific takes precedence.
#
# > **NOTE** `prior_params` should be an instance of [`Prior`](../models/base.py), which itself subclasses `PyMC3.Model`.
#
# > **NOTE** `prior_params` is used for two things:
# >
# > - Providing a prior / regularizer on the parameters.
# > - Specifying which parameters will be optimized.
# >
# > Specifically, the model parameters the optimizer will attempt to fit are exactly those returned by `prior_params.optim_vars`.
#
# The name _“logp”_ is borrowed from the name of the analogous function in [PyMC3](https://docs.pymc.io/Probability_Distributions.html). In fact any scalar objective function is acceptable; it doesn't have to be a likelihood, or even a probability.
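#
# As a sketch, the optimizer could equivalently be configured with a single default objective (the tag selector below is illustrative):
#
# ```python
# logp = sum(get_objectives({'observations': {'GaussObs', 'log L'}}))
# # ... then pass `logp=logp` to `CreateOptimizer` and omit `logp_params` / `logp_latents`.
# ```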
# %% [markdown]
# ### Updating hyperparameters during optimization
# `AlternatedOptimizer` provides a callback hook to update the hyperparameters on each pass. Here we use it to scale the latent learning rate by the standard deviation of the theoretical stationary distribution for the latents.
# This expects the model to define a `stationary_stats()` method, returning a nested dictionary of statistics (plain numbers or NumPy arrays).
#
# The first thing we do is check that the model indeed provides this method, and that the returned dictionary is in the expected format. Since this check runs during workflow creation rather than during task execution, it catches errors immediately, rather than after the task has been created, queued and executed.
#
# :::{note}
# The hyperparameter update function assumes that a) there exists an `input` submodel and that b) only this `input` submodel has latent variables. Thus stationary statistics are computed for that submodel only.
# :::
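#
# For reference, the checks below expect `stationary_stats_eval()` to return a nested dictionary of roughly this shape (the history name and values here are illustrative):
#
# ```python
# {'ξ': {'std': np.array([1., 1.])}}
# ```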
# %% [markdown]
# :::{margin} Standardize
# - Use the path through the 'input' submodel to select latent histories.
# :::
# %%
test_model = model.run(cache=False)
# FIXME: it should be possible to translate the history names
#        without instantiating a model, using `model_selector['__connect__']`
latent_hist_ids = {id(getattr(test_model, hname)) for hname in latent_hists}
input_hist_ids = {id(h): hname for hname, h in test_model.input.nested_histories.items()}
assert latent_hist_ids <= set(input_hist_ids), \
    "This workflow assumes that all latent histories are part of the 'input' submodel."
# Replace latent hist names by the names from the 'input' submodel which point to the same history
latent_hists = ['input.'+input_hist_ids[histid] for histid in latent_hist_ids]
# %% [markdown]
# :::{margin} Check that
# - `CreateModel` task returns a `Model`.
# - Model defines a valid `stationary_stats` method.
# :::
# %%
from collections.abc import Callable
import inspect
from numbers import Number
from sinn import History, Model
#test_model = model.run(cache=False)
required_stats = {'std'}   # The statistics used by update_hyperθ
hist_names = {h.name for h in test_model.history_set}
input_latent_hists = [hname.split('.',1)[1] for hname in latent_hists]
#input_latent_hists = list(input_hist_ids.values())
test_model = test_model.input
if not isinstance(test_model, Model):
    raise TypeError(f"`model.run()` should return a sinn Model, but instead returned a {type(test_model)}.\n"
                    f"model_selector: {model_selector}\nModelClass: {ModelClass}")
if not hasattr(test_model, 'stationary_stats_eval'):
    raise ValueError(f"{test_model.name} does not provide the required 'stationary_stats_eval' method.")
elif not isinstance(test_model.stationary_stats_eval, Callable):
    raise ValueError(f"{test_model.name}.stationary_stats_eval is not callable")
stats = test_model.stationary_stats_eval()
if not isinstance(stats, dict):
    raise ValueError(f"{test_model.name}.stationary_stats must return a dictionary. Returned: {stats} (type: {type(stats)}).")
non_hist_keys = [k for k in stats if k not in hist_names]
if non_hist_keys:
    raise ValueError(f"{test_model.name}.stationary_stats must return a dictionary where keys are strings matching history names. Offending keys: {non_hist_keys}.")
#stats = {h.name: v for h,v in stats.items()}
missing_hists = set(input_latent_hists) - set(stats)
if missing_hists:
    raise ValueError(f"{test_model.name}.stationary_stats needs to define statistics for all latent histories. Missing: {missing_hists}.")
not_a_dict = [h_name for h_name in input_latent_hists if not isinstance(stats[h_name], dict)]
if not_a_dict:
    raise ValueError(f"{test_model.name}.stationary_stats[hist name] must be a dictionary. Offending entries: {not_a_dict}.")
missing_stats = [h_name for h_name in input_latent_hists if not required_stats <= set(stats[h_name])]
if missing_stats:
    raise ValueError(f"{test_model.name}.stationary_stats must define the following statistics: {required_stats}. "
                     f"Some or all of these are missing for the following entries: {missing_stats}.")
return_vals = {f"{h_name} - {stat}": stats[h_name][stat]
               for h_name in input_latent_hists for stat in required_stats}
does_not_return_number = {k: f"{v} (type: {type(v)})" for k,v in return_vals.items()
                          if not isinstance(v, (Number, np.ndarray))}
if does_not_return_number:
    raise ValueError(f"{test_model.name}.stationary_stats must return a nested dictionary of plain numbers or Numpy arrays. "
                     f"Offending entries:\n{does_not_return_number}")
del stats, non_hist_keys, missing_hists, not_a_dict, missing_stats, return_vals, does_not_return_number
# %% [markdown]
# Having validated the model's `stationary_stats()` method, we use it to define a hyperparameter update function which scales the learning rate by each latent history's stationary standard deviation.
# %%
## Hyperparams update callback ##
def update_hyperθ(optimizer):
    """
    This function is called at the beginning of each `step` to update the hyperparameters.

    .. Note:: This function's code is serialized into the task description.
       Therefore it must not depend on any global variables or functions.

    .. Note:: This function assumes that a) there exists a submodel “input”
       and b) that this submodel is the only one with latent histories.

    .. Note:: The AlternatedSGD optimizer currently assumes that all latent
       histories have unique names
       (so they cannot be distinguished only by their submodel).

    :returns: A nested dictionary matching the structure of `fit_hyperparams`. Not all
       entries are required; those provided will replace the ones in `fit_hyperparams`.
    """
    λη = optimizer.orig_fit_hyperparams['latents']['λη']
    # Only the input has latent hists, therefore we only need its stationary stats
    stats = optimizer.model.input.stationary_stats_eval()
    updates = {'latents': {'λη': {}}}
    for hname in optimizer.latent_hists:
        assert hname.startswith('input.')
        updates['latents']['λη'][hname] = λη*stats[hname.split('.')[1]]['std']
    return updates
# %% [markdown]
# ### Create the optimizer
# The last step is to assemble everything into an optimizer.
# %%
## Instantiate the optimizer ##
optimizer = CreateOptimizer(
    model             =model,
    rng_key           =optimizer_rngkey,
    data_segments     =segment_iterator,
    observed_hists    =observed_hists,
    latent_hists      =latent_hists,
    prior_params      =prior,
    init_params       =Θ_init,
    fit_hyperparams   =fit_hyperparams,
    update_hyperparams=update_hyperθ,
    logp              =default_objective,
    logp_params       =params_objective,
    logp_latents      =latents_objective,
    #convergence_tests=[constant_cost, diverging_cost]
)
# %% [markdown]
# ## Optimization task
# %% [markdown]
# The final task, `OptimizeModel`, is the only recorded one (all the others are `@MemoizedTask`'s). Its definition is quite simple:
#
# - Attach the recorders to the optimizer.
# - Iterate the optimizer for the number of steps specified by `nsteps`.
#
# Because it is a `@RecordedIterativeTask`, it is able to continue a fit from a previous one with fewer steps. *Iterative* tasks define two additional attributes compared to recorded ones:
#
# - An *iteration parameter* (in this case, `nsteps`). This must be an integer, must be among both the inputs and outputs, and must increase by 1 for each iteration.
# - An *iteration map*: A dictionary mapping outputs of one iteration to the inputs of the next. In this case the map is simply `{'optimizer': 'optimizer', 'recorders': 'recorders'}`.
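#
# As an illustration only (not part of this workflow), a follow-up task continuing the fit created below could, as we understand the `RecordedIterativeTask` semantics, be obtained simply by requesting more steps with otherwise identical inputs:
#
# ```python
# optimize_more = OptimizeModel(reason=reason, nsteps=2*nsteps,
#                               optimizer=optimizer, step_kwargs=step_kwargs,
#                               recorders=[logL_recorder, Θ_recorder, latents_recorder])
# # (remaining arguments, e.g. `convergence_tests`, would be passed as in the cell below)
# ```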
# %% [markdown]
# ### Recorders
#
# Recorders are used by the `Optimize` task to record the optimizer's state during the optimization process. The recording frequency can be set independently for each recorder (so e.g. the expensive `latents_recorder` is executed less often).
#
# > Recorders are defined in [_optim/recorders_](../optim/recorders.py).
# %%
from sinnfull.optim.recorders import (
    LogpRecorder, ΘRecorder, LatentsRecorder)
# %%
logL_recorder = LogpRecorder()
Θ_recorder = ΘRecorder(keys=tuple(prior.optim_vars.keys()))
latents_recorder = LatentsRecorder(optimizer)
# %% [markdown]
# ### Early-stopping conditions
#
# We add two tests for stopping a fit early:
#
# - `constant_cost` will return `Converged` if the last $n$ recorded $\log L$ values are all within $tol$ of each other.<br>
#   (Equivalently: for any such pair, one parameter set is at most a factor $e^{tol}$ more likely than the other.)
# - `diverging_cost` will return `Failed` if the current cost value is lower than the initial value, or NaN.<br>
# This is based on the observation that randomly initialized fits almost always start with a sharp increase in the $\log L$. If further iterations annul this initial progress, the fit does not seem able to make stable progress.
# > Depending on circumstance, `diverging_cost` may be too strict for fits starting from ground truth parameters, since those start with a relatively high likelihood.
# %%
constant_cost = ConstantCost(cost_recorder='log L', tol=2**-7, n=12)
diverging_cost = DivergingCost(cost_recorder='log L', maximize=True)
constant_params = ConstantParams(param_recorder='Θ', rtol=2**-6, n=6)
# %% [markdown]
# ### `Optimize` Task
# %%
#for optimizer, nsteps in zip(optimizers, nsteps_list):
if reason is None:
    reason = ""
else:
    reason += "\n\n"
reason += \
f"""
- Init params: {Θ_init_key}
- {nsteps} passes
""".rstrip()
for k, v in fit_hyperθ_updates.items():
    if isinstance(v, pint.Quantity):
        reason += f"\n- {k}: {v:~}"
    else:
        reason += f"\n- {k}: {v}"
optimize = OptimizeModel(reason   =reason,
                         nsteps   =nsteps,
                         optimizer=optimizer,
                         step_kwargs=step_kwargs,
                         recorders=[logL_recorder, Θ_recorder, latents_recorder],
                         convergence_tests=[constant_cost & constant_params, diverging_cost]
                         )
# %%
# Test that task serialization is idempotent
# (If this fails, use diagnose/compare_tasks to identify where the change occurred)
if False:
    from smttask import Task
    import json
    optimize.save("test-exported-optimize-task1")
    task2 = Task.from_desc("test-exported-optimize-task1.taskdesc.json")
    task2.save("test-exported-optimize-task2")
    with open("test-exported-optimize-task1.taskdesc.json") as f:
        json1 = f.read()
    with open("test-exported-optimize-task2.taskdesc.json") as f:
        json2 = f.read()
    assert json1 == json2
# %% [markdown]
# ## Export or run the task
#
# At this point, we either
#
# - Save the created task to a file, so it can be run later.
# (If this code is being executed as a script.)
# - Execute the task.
# (If this code is being executed in a notebook.)
# %%
if exec_environment != "notebook":
    # Notebook is either being run as a script or through papermill
    optimize.save(task_save_location)
# %% [markdown]
# We can visualize the workflow by calling the `draw` method of the highest level task. It's a quick 'n dirty visualization, but still sometimes a useful way to see dependencies.
#
# ::::{margin}
# :::{hint}
# The Task nodes in this diagram correspond to the hexagonal nodes in the [flowchart](optimize-wf-flowchart) at the top.
# :::
# ::::
# %%
if True and exec_environment == "notebook":
    optimize.draw()
# %% [markdown]
# The call to `run()` will recurse through the workflow tree, executing all required tasks. Since `OptimizeModel` is a `RecordedTask`, its result is saved to disk so that it will not need to be run again in the future if called with the same parameters.
# (To force a rerun, e.g. if the model code changed, one can execute `optimize.run(recompute=True)`.)
# %% [markdown] tags=["remove-cell"]
# Debugging options:
# ```python
# import theano
# theano.config.compute_test_value = 'warn'
# theano.config.optimizer = 'fast_compile'
# theano.config.NanGuardMode__action = 'raise'
# import mackelab_toolbox.optimizers
# mackelab_toolbox.optimizers.debug_flags['print grads'] = True
# # (compile functions with `shim.graph.compile(…, mode='guard:nan,inf,big')`)
# ```
# %% tags=["skip-execution"]
if exec_environment == "notebook":
    result = optimize.run(record=False, recompute=True)
# %% [markdown]
# ---
# %% tags=["remove-cell"]
# %debug