-
Notifications
You must be signed in to change notification settings - Fork 14
/
parameters.py
1405 lines (1228 loc) · 50.3 KB
/
parameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import copy
import itertools
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
import functools
from typing import Optional, Dict, List, Any, Union, Mapping
import warnings
import numpy as np
from marshmallow import ValidationError as MarshmallowValidationError
from paramtools import utils
from paramtools import contrib
from paramtools.schema import ParamToolsSchema
from paramtools.schema_factory import SchemaFactory
from paramtools.sorted_key_list import SortedKeyList
from paramtools.typing import ValueObject, FileDictStringLike
from paramtools.exceptions import (
ParamToolsError,
SparseValueObjectsException,
ValidationError,
InconsistentLabelsException,
collision_list,
ParameterNameCollisionException,
)
from paramtools.values import Values, union, intersection
class ParameterSlice:
    """
    Read-only, query-style access to a parameter's value objects via
    ``parameters.sel["param_name"]``.

    Caches the ``Values`` queryset per parameter name in ``_cache`` and the
    per-parameter key function in ``_key_cache`` so repeated lookups are cheap.
    """

    __slots__ = ("parameters", "_cache", "_key_cache")

    def __init__(self, parameters):
        self.parameters = parameters
        # parameter name -> Values queryset built from that parameter's data.
        self._cache = {}
        # parameter name -> keyfunc used to order/compare that field's values.
        self._key_cache = {}

    def __getitem__(self, parameter_or_values):
        """
        Return a ``Values`` queryset for a parameter name, or wrap an
        already-materialized list of value objects in a ``Values`` instance.

        **Raises**

        - ``ValueError`` if the name is not a known parameter.
        - ``ParamToolsError`` if the field's keyfunc cannot be resolved.
        """
        keyfuncs = dict(self.parameters.keyfuncs)
        if (
            isinstance(parameter_or_values, str)
            and parameter_or_values in self._cache
        ):
            return self._cache[parameter_or_values]
        elif isinstance(parameter_or_values, str):
            data = self.parameters._data.get(parameter_or_values)
            if data is None:
                raise ValueError(f"Unknown parameter: {parameter_or_values}.")
            try:
                keyfunc = self._key_cache.get(parameter_or_values, None)
                if keyfunc is None:
                    keyfunc = self.parameters._validator_schema.field_keyfunc(
                        parameter_or_values
                    )
                    self._key_cache[parameter_or_values] = keyfunc
                # BUG FIX: previously the keyfunc was also stored in
                # self._cache here. If Values() construction raised, a raw
                # keyfunc (not a Values queryset) would be returned by
                # subsequent cache hits. Keyfuncs belong only in _key_cache.
                keyfuncs["value"] = keyfunc
                values = Values(data["value"], keyfuncs=keyfuncs)
                self._cache[parameter_or_values] = values
                return values
            except contrib.validate.ValidationError as ve:
                raise ParamToolsError(
                    f"There was an error retrieving the field for {parameter_or_values}",
                    {},
                ) from ve
        else:
            # Non-string input: treat it as a list of value objects.
            return Values(parameter_or_values, keyfuncs=keyfuncs)
class Parameters:
    """
    Base class that deserializes, validates, queries, and adjusts parameter
    values against a project-defined ``defaults`` specification. Projects
    subclass this and set ``defaults``.
    """

    # Defaults specification consumed by SchemaFactory in __init__
    # (a dict, JSON string, or file path — see read_params).
    defaults = None
    # Operator: when True, parameter attributes are set from arrays
    # (applied via set_state; see also to_array). Presumably converts
    # value objects to NumPy arrays — confirm in _set_state (not shown).
    array_first: bool = False
    # Operator: label along which values are automatically extended
    # by extend(), e.g. a year label. None disables extend mode.
    label_to_extend: Optional[str] = None
    # Operator: whether extend() routes newly created values through
    # extend_func (used for indexed parameters).
    uses_extend_func: bool = False
    # Rates used by extend_func to index values along label_to_extend.
    index_rates: Dict = {}
    def __init__(
        self,
        initial_state: Optional[dict] = None,
        index_rates: Optional[dict] = None,
        sort_values: bool = True,
        **ops,
    ):
        """
        Build schemas from ``self.defaults``, initialize label metadata and
        state, resolve the operators (array_first, label_to_extend,
        uses_extend_func), and set the initial parameter attributes.

        **Parameters**

        - `initial_state`: label state to apply immediately (parsed with
          ``parse_labels``).
        - `index_rates`: overrides the class-level ``index_rates`` if given.
        - `sort_values`: sort value objects by label order after setup.
        - `ops`: operator overrides (highest precedence; see comment below).
        """
        schemafactory = SchemaFactory(self.defaults)
        (
            self._defaults_schema,
            self._validator_schema,
            self._schema,
            self._data,
        ) = schemafactory.schemas()
        self.label_validators = schemafactory.label_validators
        # Collect a comparison key function per label, where available.
        self.keyfuncs = {}
        for label, lv in self.label_validators.items():
            cmp_funcs = getattr(lv, "cmp_funcs", None)
            if cmp_funcs is not None:
                self.keyfuncs[label] = cmp_funcs()["key"]
        # Grid of all valid values per label, independent of current state.
        self._stateless_label_grid = OrderedDict()
        for name, v in self.label_validators.items():
            if hasattr(v, "grid"):
                self._stateless_label_grid[name] = v.grid()
            else:
                self._stateless_label_grid[name] = []
        self.label_grid = copy.deepcopy(self._stateless_label_grid)
        # The validator schema needs a back-reference to this instance.
        self._validator_schema.context["spec"] = self
        self._warnings = {}
        self._errors = {}
        self._defer_validation = False
        self._state = self.parse_labels(**(initial_state or {}))
        self.index_rates = index_rates or self.index_rates
        self.sel = ParameterSlice(self)
        # set operators in order of importance:
        # __init__ arg: most important
        # class attribute: middle importance
        # schema action: least important
        # default value if three above are not specified.
        default_ops = [
            ("array_first", False),
            ("label_to_extend", None),
            ("uses_extend_func", False),
        ]
        schema_ops = self._schema.get("operators", {})
        for name, default in default_ops:
            if name in ops:
                setattr(self, name, ops.get(name))
            elif getattr(self, name, None) != default:
                # Class attribute differs from the default: re-set it on the
                # instance so it shadows the class attribute.
                setattr(self, name, getattr(self, name))
            elif name in schema_ops:
                setattr(self, name, schema_ops[name])
            else:
                setattr(self, name, default)
        if self.label_to_extend:
            # Extend with array_first off, then re-apply array_first so the
            # array conversion sees the fully extended values.
            prev_array_first = self.array_first
            self.array_first = False
            self.set_state()
            self.extend()
            if prev_array_first:
                self.array_first = True
                self.set_state()
        else:
            self.set_state()
        # Record the resolved operators back onto the schema.
        if "operators" not in self._schema:
            self._schema["operators"] = {}
        self._schema["operators"].update(self.operators)
        if sort_values:
            self.sort_values()
def __getitem__(self, parameter):
raise AttributeError(
f'Use params.sel["{parameter}"] instead of params["{parameter}"].'
)
def set_state(self, **labels):
"""
Sets state for the Parameters instance. The `_state`, `label_grid`, and
parameter attributes are all updated with the new state.
Use the `view_state` method to inspect the current state of the instance,
and use the `clear_state` method to revert to the default state.
**Raises**
- `ValidationError` if the labels kwargs contain labels that are not
specified in schema.json or if the label values fail the
validator set for the corresponding label in schema.json.
"""
self._set_state(**labels)
def clear_state(self):
"""
Reset the state of the `Parameters` instance.
"""
self._state = {}
self.label_grid = copy.deepcopy(self._stateless_label_grid)
self.set_state()
def view_state(self):
"""
Access the label state of the ``Parameters`` instance.
"""
return {label: value for label, value in self._state.items()}
def read_params(
self,
params_or_path: FileDictStringLike,
storage_options: Optional[Dict[str, Any]] = None,
):
"""
Read JSON data of the form:
- Python `dict`.
- JSON string.
- Local file path.
- Any URL readable by fsspec. For example:
- s3: `s3://paramtools-test/defaults.json`
- gcs: `gs://paramtools-dev/defaults.json`
- http: `https://somedomain.com/defaults.json`
- github: `github://PSLmodels:ParamTools@master/paramtools/tests/defaults.json`
**Returns**
- `params`: Python Dict created from JSON file.
"""
return utils.read_json(params_or_path, storage_options)
def adjust(
self,
params_or_path: Union[str, Mapping[str, List[ValueObject]]],
ignore_warnings: bool = False,
raise_errors: bool = True,
extend_adj: bool = True,
clobber: bool = True,
):
"""
Deserialize and validate parameter adjustments. `params_or_path`
can be a file path or a `dict` that has not been fully deserialized.
The adjusted values replace the current values stored in the
corresponding parameter attributes.
If `clobber` is `True` and extend mode is on, then all future values
for a given parameter be replaced by the values in the adjustment.
If `clobber` is `False` and extend mode is on, then user-defined values
will not be replaced by values in this adjustment. Only values that
were added automatically via the extend method will be updated.
This simply calls a private method `_adjust` to do the upate. Creating
this layer on top of `_adjust` makes it easy to subclass `Parameters` and
implement custom `adjust` methods.
**Parameters**
- `params_or_path`: Adjustment that is either a `dict`, file path, or
JSON string.
- `ignore_warnings`: Whether to raise an error on warnings or ignore them.
- `raise_errors`: Either raise errors or simply store the error messages.
- `extend_adj`: If in extend mode, this is a flag indicating whether to
extend the adjustment values or not.
- `clobber`: If in extend mode, this is a flag indicating whether to
override all values, including user-defined values, or to only
override automatically created values.
**Returns**
- `params`: Parsed, validated parameters.
**Raises**
- `marshmallow.exceptions.ValidationError` if data is not valid.
- `ParameterUpdateException` if label values do not match at
least one existing value item's corresponding label values.
"""
return self._adjust(
params_or_path,
ignore_warnings=ignore_warnings,
raise_errors=raise_errors,
extend_adj=extend_adj,
clobber=clobber,
)
    def _adjust(
        self,
        params_or_path,
        ignore_warnings=False,
        raise_errors=True,
        extend_adj=True,
        deserialized=False,
        validate=True,
        clobber=True,
    ):
        """
        Internal method for performing adjustments.

        Deserializes and validates ``params_or_path``, then either applies
        the values directly or — in extend mode — deletes the values that the
        adjustment supersedes, applies the adjustment, and re-extends.
        Errors/warnings are accumulated in ``_errors``/``_warnings`` and
        raised at the end according to ``raise_errors``/``ignore_warnings``.
        """
        # Validate user adjustments.
        if deserialized:
            # Input is already parsed value objects; skip read_params.
            parsed_params = {}
            try:
                parsed_params = self._validator_schema.load(
                    params_or_path, ignore_warnings, deserialized=True
                )
            except MarshmallowValidationError as ve:
                self._parse_validation_messages(ve.messages, params_or_path)
        else:
            params = self.read_params(params_or_path)
            parsed_params = {}
            try:
                parsed_params = self._validator_schema.load(
                    params, ignore_warnings
                )
            except MarshmallowValidationError as ve:
                self._parse_validation_messages(ve.messages, params)
        if not self._errors:
            if self.label_to_extend is not None and extend_adj:
                extend_grid = self._stateless_label_grid[self.label_to_extend]
                # Existing value objects superseded by this adjustment.
                to_delete = defaultdict(list)
                # Snapshot of each adjusted parameter for rollback.
                backup = {}
                for param, vos in parsed_params.items():
                    for vo in utils.grid_sort(
                        vos, self.label_to_extend, extend_grid
                    ):
                        if self.label_to_extend in vo:
                            if clobber:
                                # All values are fair game for replacement.
                                queryset = self.sel[param]
                            else:
                                # Only values created automatically by extend.
                                queryset = self.sel[param]["_auto"] == True
                            # Restrict to values at/after this adjustment's
                            # position along label_to_extend.
                            queryset &= queryset.gt(
                                strict=False,
                                **{
                                    self.label_to_extend: vo[
                                        self.label_to_extend
                                    ]
                                },
                            )
                            other_labels = utils.filter_labels(
                                vo,
                                drop=[self.label_to_extend, "value", "_auto"],
                            )
                            if other_labels:
                                # Match the adjustment's remaining labels.
                                queryset &= intersection(
                                    queryset.eq(strict=False, **{label: value})
                                    for label, value in other_labels.items()
                                )
                            to_delete[param] += list(queryset)
                    # make copy of value objects since they
                    # are about to be modified
                    backup[param] = copy.deepcopy(self._data[param]["value"])
                try:
                    # Suspend array_first while mutating values.
                    array_first = self.array_first
                    self.array_first = False
                    # delete params that will be overwritten by extend.
                    self.delete(
                        to_delete,
                        extend_adj=False,
                        raise_errors=True,
                        ignore_warnings=ignore_warnings,
                    )
                    # set user adjustments.
                    self._adjust(
                        parsed_params,
                        extend_adj=False,
                        raise_errors=True,
                        ignore_warnings=ignore_warnings,
                    )
                    self.extend(
                        params=parsed_params.keys(),
                        ignore_warnings=ignore_warnings,
                        raise_errors=True,
                    )
                except ValidationError:
                    # Roll back the adjusted parameters; the recorded
                    # messages are surfaced via validation_error below.
                    for param in backup:
                        self._data[param]["value"] = backup[param]
                finally:
                    self.array_first = array_first
            else:
                for param, value in parsed_params.items():
                    self._update_param(param, value)
            self._validator_schema.context["spec"] = self
        has_errors = bool(self._errors.get("messages"))
        has_warnings = bool(self._warnings.get("messages"))
        # throw error if raise_errors is True or ignore_warnings is False
        if (raise_errors and has_errors) or (
            not ignore_warnings and has_warnings
        ):
            raise self.validation_error
        # Update attrs for params that were adjusted.
        self._set_state(params=parsed_params.keys())
        return parsed_params
@contextmanager
def transaction(
self, defer_validation=True, raise_errors=False, ignore_warnings=False
):
"""
Rollback any changes to parameter state after the context block closes.
.. code-block:: Python
import paramtools
class Params(paramtools.Parameters):
defaults = {
"min_param": {
"title": "Min param",
"description": "Must be less than 'max_param'",
"type": "int",
"value": 2,
"validators": {
"range": {"max": "max_param"}
}
},
"max_param": {
"title": "Max param",
"type": "int",
"value": 3
}
}
params = Params()
with params.transaction():
params.adjust({"min_param": 4})
params.adjust({"max_param": 5})
**Parameters:**
- `defer_validation`: Defer schema-level validation until the end of the block.
- `ignore_warnings`: Whether to raise an error on warnings or ignore them.
- `raise_errors`: Either raise errors or simply store the error messages.
"""
_data = copy.deepcopy(self._data)
_ops = dict(self.operators)
_state = dict(self.view_state())
try:
self._defer_validation = defer_validation
yield self
except Exception as e:
self._data = _data
raise e
finally:
self._state = _state
self._ops = _ops
self._defer_validation = False
if defer_validation:
self.validate(
self.specification(use_state=False, meta_data=False),
ignore_warnings=ignore_warnings,
raise_errors=raise_errors,
)
def validate(self, params, raise_errors=True, ignore_warnings=False):
"""
Validate parameter adjustment without modifying existing values.
For example, validate the current parameter values:
.. code-block:: Python
params.validate(
params.specification(use_state=False)
)
**Parameters:**
- `params`: Parameters to validate.
- `ignore_warnings`: Whether to raise an error on warnings or ignore them.
- `raise_errors`: Either raise errors or simply store the error messages.
"""
try:
self._validator_schema.load(
params, ignore_warnings, deserialized=True
)
except MarshmallowValidationError as ve:
self._parse_validation_messages(ve.messages, params)
has_errors = bool(self._errors.get("messages"))
has_warnings = bool(self._warnings.get("messages"))
if (raise_errors and has_errors) or (
not ignore_warnings and has_warnings
):
raise self.validation_error
def delete(
self,
params_or_path,
ignore_warnings=False,
raise_errors=True,
extend_adj=True,
):
"""
Delete value objects in params_or_path.
Returns: adjustment for deleting parameters.
Raises:
marshmallow.exceptions.ValidationError if data is not valid.
ParameterUpdateException if label values do not match at
least one existing value item's corresponding label values.
"""
return self._delete(
params_or_path,
ignore_warnings=ignore_warnings,
raise_errors=raise_errors,
extend_adj=extend_adj,
)
    def _delete(
        self,
        params_or_path,
        ignore_warnings=False,
        raise_errors=True,
        extend_adj=True,
    ):
        """
        Internal method that sets the 'value' member for all value objects
        to None. Value objects with 'value' set to None are deleted.
        """
        params = self.read_params(params_or_path)
        # Validate user adjustments.
        parsed_params = {}
        try:
            parsed_params = self._validator_schema.load(
                params, ignore_warnings=True
            )
        except MarshmallowValidationError as ve:
            self._parse_validation_messages(ve.messages, params)
        # Mark every matched value object for deletion (value=None) and
        # apply each parameter's deletion immediately.
        to_delete = {}
        for param, vos in parsed_params.items():
            to_delete[param] = [dict(vo, **{"value": None}) for vo in vos]
            self._update_param(param, to_delete[param])
        if self.label_to_extend is not None and extend_adj:
            # Re-extend so gaps left by deletion are filled back in.
            self.extend()
        self._validator_schema.context["spec"] = self
        has_errors = bool(self._errors.get("messages"))
        has_warnings = bool(self._warnings.get("messages"))
        # throw error if raise_errors is True or ignore_warnings is False
        if (raise_errors and has_errors) or (
            not ignore_warnings and has_warnings
        ):
            raise self.validation_error
        # Update attrs for params that were adjusted.
        self._set_state(params=to_delete.keys())
        return to_delete
@property
def errors(self):
if not self._errors:
return {}
return {
param: utils.ravel(messages)
for param, messages in self._errors["messages"].items()
}
@property
def warnings(self):
if not self._warnings:
return {}
return {
param: utils.ravel(messages)
for param, messages in self._warnings["messages"].items()
}
@property
def validation_error(self):
messages = {
"errors": self._errors.get("messages", {}),
"warnings": self._warnings.get("messages", {}),
}
labels = {
"errors": self._errors.get("labels", {}),
"warnings": self._warnings.get("labels", {}),
}
return ValidationError(messages=messages, labels=labels)
@property
def schema(self):
pre = dict(self._schema)
pre["operators"] = self.operators
return ParamToolsSchema().dump(pre)
@property
def operators(self):
return {
"array_first": self.array_first,
"label_to_extend": self.label_to_extend,
"uses_extend_func": self.uses_extend_func,
}
def dump(self, sort_values: bool = True, use_state: bool = True):
"""
Dump a representation of this instance to JSON. This makes it
possible to load this instance's data after sending the data
across the wire or from another programming language. The
dumped values will be queried using this instance's state.
"""
spec = self.specification(
meta_data=True,
include_empty=True,
serializable=True,
sort_values=sort_values,
use_state=use_state,
)
result = {"schema": self.schema}
result.update(spec)
return result
    def specification(
        self,
        use_state: bool = True,
        meta_data: bool = False,
        include_empty: bool = False,
        serializable: bool = False,
        sort_values: bool = False,
        **labels,
    ):
        """
        Query value(s) of all parameters along labels specified in
        `labels`.

        **Parameters**

        - `use_state`: Use the instance's state for the select operation.
        - `meta_data`: Include information like the parameter
          `description` and title.
        - `include_empty`: Include parameters that do not meet the label query.
        - `serializable`: Return data that is compatible with `json.dumps`.
        - `sort_values`: Sort values by the `label` order.

        **Returns**

        - `dict` of parameter names and data.
        """
        if use_state:
            # Instance state takes precedence over caller-supplied labels.
            labels.update(self._state)
        all_params = OrderedDict()
        for param in self._validator_schema.fields:
            result = self.select_eq(param, False, **labels)
            if sort_values and result:
                result = self.sort_values(
                    data={param: result}, has_meta_data=False
                )[param]
            if result or include_empty:
                if meta_data:
                    param_data = self._data[param]
                    result = dict(param_data, **{"value": result})
                # Add "value" key to match marshmallow schema format.
                elif serializable:
                    result = {"value": result}
                all_params[param] = result
        if serializable:
            ser = self._defaults_schema.dump(all_params)
            # Unpack the values after serialization if meta_data not specified.
            if not meta_data:
                ser = {param: value["value"] for param, value in ser.items()}
            return ser
        else:
            return all_params
    def to_array(self, param, **labels):
        """
        Convert a Value object to an n-labelal array. The list of Value
        objects must span the specified parameter space. The parameter space
        is defined by inspecting the label validators in schema.json
        and the state attribute of the Parameters instance.

        **Parameters**

        - `param`: Name of parameter that will be used to create array.
        - `labels`: Optionally, override instance state.

        **Returns**

        - `arr`: NumPy array created from list of value objects.

        **Raises**

        - `InconsistentLabelsException`: Value objects do not have consistent
          labels.
        - `SparseValueObjectsException`: Value object does not span the
          entire space specified by the Order object.
        - `ParamToolsError`: Parameter is an array type and has labels.
          This is not supported by ParamTools when using array_first.
        """
        label_grid = copy.deepcopy(self.label_grid)
        state = copy.deepcopy(self._state)
        if labels:
            # Caller-supplied labels override the instance state.
            parsed_labels = self.parse_labels(**labels)
            label_grid.update(parsed_labels)
            state.update(parsed_labels)
        if state:
            value_items = list(
                intersection(
                    self.sel[param].isin(strict=False, **{label: values})
                    for label, values in state.items()
                )
            )
        else:
            value_items = list(self.sel[param])
        if not value_items:
            return np.array([])
        label_order, value_order = self._resolve_order(
            param, value_items, label_grid
        )
        # Array shape: one axis per label, sized by that label's value count.
        shape = []
        for label in label_order:
            shape.append(len(value_order[label]))
        shape = tuple(shape)
        # Compare len value items with the expected length if they are full.
        # In the future, sparse objects should be supported by filling in the
        # unspecified labels.
        number_dims = self._data[param].get("number_dims", 0)
        if not shape and number_dims > 0:
            # No labels, array-typed parameter: the single value IS the array.
            return np.array(
                value_items[0]["value"], dtype=self._numpy_type(param)
            )
        elif shape and number_dims > 0:
            # Labels plus an array-typed parameter is unsupported here.
            raise ParamToolsError(
                f"\nParameter '{param}' is an array parameter with {number_dims} dimension(s) and "
                f"has labels: {', '.join(label_order)}.\n\nParamTools does not "
                f"support the use of 'array_first' with array parameters that use labels. "
                f"\nYou may be able to describe this parameter's values with additional "
                f"labels\nand the 'label_to_extend' operator."
            )
        elif not shape and number_dims == 0:
            # Scalar parameter without labels: return the bare (cast) value.
            data_type = self._numpy_type(param)
            value = value_items[0]["value"]
            if data_type == object:
                return value
            else:
                return data_type(value)
        exp_full_shape = functools.reduce(lambda x, y: x * y, shape)
        act_full_shape = len(value_items)
        if act_full_shape != exp_full_shape:
            # Build a detailed diagnostic: which label combinations are
            # missing, unexpected, or duplicated.
            # maintains label value order over value objects.
            exp_grid = list(itertools.product(*value_order.values()))
            # preserve label value order for each value object by
            # iterating over label_order.
            actual = list(
                [tuple(vo[d] for d in label_order) for vo in value_items]
            )
            missing = "\n\t".join(
                [str(d) for d in exp_grid if d not in actual]
            )
            counter = defaultdict(int)
            extra = []
            duplicates = []
            for comb in actual:
                counter[comb] += 1
                if counter[comb] > 1:
                    duplicates.append((comb, counter[comb]))
                if comb not in exp_grid:
                    extra.append(comb)
            msg = ""
            if missing:
                msg += f"Missing combinations:\n\t{missing}"
            if extra:
                msg += f"Extra combinations:\n\t{extra}"
            if duplicates:
                msg += f"Duplicate combinations:\n\t{duplicates}"
            raise SparseValueObjectsException(
                f"The Value objects for {param} do not span the specified "
                f"parameter space. {msg}"
            )

        def list_2_tuple(x):
            # NumPy fancy indexing needs tuples, not lists, for point indices.
            return tuple(x) if isinstance(x, list) else x

        arr = np.empty(shape, dtype=self._numpy_type(param))
        for vi in value_items:
            # ix stores the indices of `arr` that need to be filled in.
            ix = [[] for i in range(len(label_order))]
            for label_pos, label_name in enumerate(label_order):
                # assume value_items is dense in the sense that it spans
                # the label space.
                ix[label_pos].append(
                    value_order[label_name].index(vi[label_name])
                )
            ix = tuple(map(list_2_tuple, ix))
            arr[ix] = vi["value"]
        return arr
    def from_array(self, param, array=None, **labels):
        """
        Convert NumPy array to a Value object.

        **Parameters**

        - `param`: Name of parameter to convert to a list of value objects.
        - `array`: Optionally, provide a NumPy array to convert into a list
          of value objects. If not specified, the value at `self.param` will
          be used.
        - `labels`: Optionally, override instance state.

        **Returns**

        - List of `ValueObjects`

        **Raises**

        - `InconsistentLabelsException`: Value objects do not have consistent
          labels.
        """
        if array is None:
            array = getattr(self, param)
            if not isinstance(array, np.ndarray):
                raise TypeError(
                    "A NumPy Ndarray should be passed to this method "
                    "or the instance attribute should be an array."
                )
        label_grid = copy.deepcopy(self.label_grid)
        state = copy.deepcopy(self._state)
        if labels:
            # Caller-supplied labels override the instance state.
            parsed_labels = self.parse_labels(**labels)
            label_grid.update(parsed_labels)
            state.update(parsed_labels)
        if state:
            value_items = list(
                intersection(
                    self.sel[param].isin(strict=False, **{label: value})
                    for label, value in state.items()
                )
            )
        else:
            value_items = list(self.sel[param])
        label_order, value_order = self._resolve_order(
            param, value_items, label_grid
        )
        # Walk the label grid and the array's index space in lockstep: each
        # label-value combination maps to one array element.
        label_values = itertools.product(*value_order.values())
        label_indices = itertools.product(
            *map(lambda x: range(len(x)), value_order.values())
        )
        value_items = []
        for dv, di in zip(label_values, label_indices):
            vi = {label_order[j]: dv[j] for j in range(len(dv))}
            vi["value"] = array[di]
            value_items.append(vi)
        return value_items
    def extend(
        self,
        label: Optional[str] = None,
        label_values: Optional[List[Any]] = None,
        params: Optional[List[str]] = None,
        raise_errors: bool = True,
        ignore_warnings: bool = False,
    ):
        """
        Extend parameters along `label`.

        **Parameters**

        - `label`: Label to extend values along. By default, `label_to_extend`
          is used.
        - `label_values`: values of `label` to extend. By default, this is a grid
          created from the valid values of `label_to_extend`.
        - `params`: Parameters to extend. By default, all parameters are extended.
        - `raise_errors`: Whether `adjust` should raise or store errors.
        - `ignore_warnings`: Whether `adjust` should raise or ignore warnings.

        **Raises**

        - `InconsistentLabelsException`: Value objects do not have consistent
          labels.
        """
        if label is None:
            label = self.label_to_extend
        else:
            # NOTE(review): redundant no-op branch (`label = label`);
            # kept byte-identical per doc-only edit.
            label = label
        spec = self.specification(meta_data=True)
        if params is not None:
            # Restrict the extension to the requested parameters.
            spec = {
                param: self._data[param]
                for param, data in spec.items()
                if param in params
            }
        full_extend_grid = self._stateless_label_grid[label]
        if label_values is not None:
            labels = self.parse_labels(**{label: label_values})
            extend_grid = labels[label]
        else:
            extend_grid = self._stateless_label_grid[label]
        cmp_funcs = self.label_validators[label].cmp_funcs(choices=extend_grid)
        adjustment = defaultdict(list)
        for param, data in spec.items():
            # Skip parameters that do not use the extension label at all.
            if not any(label in vo for vo in data["value"]):
                continue
            # Tracks value objects already covered by an earlier extension.
            extended_vos = set()
            for vo in sorted(
                data["value"], key=lambda val: cmp_funcs["key"](val[label])
            ):
                hashable_vo = utils.hashable_value_object(vo)
                if hashable_vo in extended_vos:
                    continue
                else:
                    extended_vos.add(hashable_vo)
                # All existing values strictly after this one along `label`...
                queryset = self.sel[param].gt(
                    strict=False, **{label: vo[label]}
                )
                other_labels = utils.filter_labels(
                    vo, drop=["value", label, "_auto"]
                )
                if other_labels:
                    # ...that also match this value object's other labels.
                    queryset &= intersection(
                        queryset.eq(strict=False, **{oth_label: value})
                        for oth_label, value in other_labels.items()
                    )
                extended_vos.update(
                    map(utils.hashable_value_object, list(queryset))
                )
                values = queryset.as_values().add(values=[vo])
                # Grid points not yet covered by any known value object.
                defined_vals = {eq_vo[label] for eq_vo in queryset}
                missing_vals = sorted(
                    set(extend_grid) - defined_vals, key=cmp_funcs["key"]
                )
                if not missing_vals:
                    continue
                extended = defaultdict(list)
                for vo in values:
                    extended[vo[label]].append(vo)
                skl = SortedKeyList(extended.keys(), cmp_funcs["key"])
                for val in missing_vals:
                    # Fill each gap from the closest known value at or below
                    # it; fall back to the closest value above.
                    lte_val = skl.lte(val)
                    if lte_val is not None:
                        closest_val = lte_val.values[-1]
                    else:
                        closest_val = skl.gte(val).values[0]
                    if closest_val in extended:
                        value_objects = extended.pop(closest_val)
                    else:
                        value_objects = values.eq(
                            strict=False, **{label: closest_val}
                        )
                    # In practice, value_objects has length one.
                    # Theoretically, there could be multiple if the initial value
                    # object had less labels than later value objects and thus
                    # matched multiple value objects.
                    for value_object in value_objects:
                        ext = dict(value_object, **{label: val})
                        # Apply indexing (or project-specific logic) to the
                        # newly created value object.
                        ext = self.extend_func(
                            param, ext, value_object, full_extend_grid, label
                        )
                        extended_vos.add(
                            utils.hashable_value_object(value_object)
                        )
                        extended[val].append(ext)
                        skl.add(val)
                        # _auto marks values created here so that non-clobber
                        # adjustments can distinguish them from user values.
                        adjustment[param].append(OrderedDict(ext, _auto=True))
        # Ensure that the adjust method of paramtools.Parameters is used
        # in case the child class also implements adjust.
        return self._adjust(
            adjustment,
            extend_adj=False,
            ignore_warnings=ignore_warnings,
            raise_errors=raise_errors,
            deserialized=True,
        )
def extend_func(
self,
param: str,
extend_vo: ValueObject,
known_vo: ValueObject,
extend_grid: List,
label: str,
):
"""
Function for applying indexing rates to parameter values as they
are extended. Projects may implement their own `extend_func` by
overriding this one. Projects need to write their own `indexing_rate`
method for returning the correct indexing rate for a given parameter
and value of `label`.
**Returns**
- `extend_vo`: New `ValueObject`.
"""
if not self.uses_extend_func or not self._data[param].get(
"indexed", False
):
return extend_vo
known_val = known_vo[label]
known_ix = extend_grid.index(known_val)
toext_val = extend_vo[label]