-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_maxsat_instances.py
201 lines (176 loc) · 7.98 KB
/
generate_maxsat_instances.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import argparse
import copy
import math
from pathlib import Path
from typing import Counter, Dict, List, Optional, Sequence, Tuple
import chainer
import numpy as np
import bnn
import datasets
import encoder
import visualize
def add_norm(enc: encoder.Encoder, norm: str,
             mod: Sequence[Tuple[encoder.Lit, Optional[int]]]) -> None:
    """Add soft clauses that penalize perturbing the input under a norm.

    Each ``(lit, w)`` in *mod* means: satisfying ``lit`` keeps the pixel on
    its original side of the binarization threshold; violating it flips the
    pixel at per-pixel magnitude ``abs(w)``.  ``w is None`` marks a pixel
    that must not change, encoded as a hard clause.

    norm '0'   -- every flipped pixel costs 1 (Hamming / L0).
    norm '1'   -- a flipped pixel costs its magnitude |w| (L1).
    norm '2'   -- a flipped pixel costs w**2 (squared L2).
    norm 'inf' -- the cost is the maximum magnitude over flipped pixels,
                  encoded with a ladder of relaxation variables.

    Raises RuntimeError for any other *norm* value.
    """
    if norm == '0':
        for lit, w in mod:
            enc.add_clause([lit], None if w is None else 1)
    elif norm == '1':
        for lit, w in mod:
            enc.add_clause([lit], None if w is None else abs(w))
    elif norm == '2':
        for lit, w in mod:
            enc.add_clause([lit], None if w is None else w*w)
    elif norm == 'inf':
        # Group the soft literals by magnitude; hard literals go in directly.
        d: Dict[int, List[encoder.Lit]] = {}
        for lit, w in mod:
            if w is None:
                enc.add_clause([lit])
            else:
                d.setdefault(abs(w), []).append(lit)
        # Ladder encoding of max: relax_w is forced whenever any pixel of
        # magnitude w flips, and relax_w implies all lower rungs, so the
        # total cost paid is exactly the largest activated magnitude.
        c_prev = 0
        relax_prev = None
        for w in sorted(d.keys()):
            relax = enc.new_var()
            # Incremental cost: paying every rung up to w sums to w.
            enc.add_clause([-relax], cost=w - c_prev)
            if relax_prev is not None:
                enc.add_clause([-relax, relax_prev])  # relax → relax_prev
            for lit in d[w]:
                enc.add_clause([relax, lit])  # ¬lit → relax
            c_prev = w
            relax_prev = relax
    else:
        raise RuntimeError("unknown norm: " + str(norm))
# Command-line interface.
# BUG FIX: the --norm help string was a copy-paste of --card's
# ("encoding of cardinality constraints"); it actually selects which
# perturbation norms are used as the MaxSAT objective.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, choices=['mnist', 'mnist_back_image', 'mnist_rot'], default='mnist', help='dataset name')
parser.add_argument('--model', type=str, default=None, help='model file (*.npz)')
parser.add_argument('-o', '--output-dir', type=str, default="instances/maxsat", help='output directory')
parser.add_argument('--format', type=str, choices=["wbo", "wcnf"], help='file format')
parser.add_argument('--norm', type=str, choices=['0', '1', '2', 'inf'], nargs='*', default=['0', '1', '2', 'inf'], help='perturbation norm(s) to use as the optimization objective')
parser.add_argument('--card', type=str, choices=["sequential", "parallel", "totalizer"], default="parallel", help='encoding of cardinality constraints')
parser.add_argument('--target', type=str, default="adversarial", choices=['adversarial', 'truelabel'], help='target label')
parser.add_argument('--instance-no', type=int, default=None, help='specify instance number')
parser.add_argument('--instances-per-class', type=int, default=None, help='number of instances to generate per class')
parser.add_argument('--debug-sat', action='store_true', help='produce CNF or OPB for debug')
parser.add_argument('--ratio', type=float, nargs='*', default=[1.0], help='restrict search space to most salient pixels')
args = parser.parse_args()

# Output directory for the generated MaxSAT instances and reference images.
result_dir = Path(args.output_dir)
result_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): only the test split is used below; `train` is unused here.
train, test = datasets.get_dataset(args.dataset)
if args.model is None:
    weights_filename = f"models/{args.dataset}.npz"
else:
    weights_filename = args.model

# Fixed fully-connected BNN architecture: 28*28 inputs -> 4 hidden layers -> 10 logits.
neurons = [28 * 28, 200, 100, 100, 100, 10]
model = bnn.BNN(neurons, stochastic_activation=True)
chainer.serializers.load_npz(weights_filename, model)

# Restrict to the first 1000 test images; keep both the scaled float pixels
# and their 8-bit integer counterparts (round-trip via *255).
orig_image_scaled = test._datasets[0]
orig_image_scaled = orig_image_scaled[:1000]
orig_image = np.round(orig_image_scaled * 255).astype(np.uint8)

with chainer.using_config("train", False), chainer.using_config("enable_backprop", False):
    # Binarized inputs (sign after the input batch-norm) and the raw logits.
    orig_image_bin = bnn.bin(model.input_bn(orig_image_scaled)).array > 0
    orig_logits = model(orig_image_scaled).array
# NOTE(review): "predicated" is a typo for "predicted"; kept unchanged
# because later code references this name.
predicated_label = np.argmax(orig_logits, axis=1)
# Base encoder shared by every instance: WBO uses pseudo-Boolean constraints
# directly, WCNF encodes cardinality constraints with the chosen counter.
if args.format == "wbo":
    enc_base = encoder.BNNEncoder(cnf=False)
elif args.format == "wcnf":
    enc_base = encoder.BNNEncoder(cnf=True, counter=args.card)
else:
    raise RuntimeError("unknown ext: " + args.format)

# One Boolean variable per binarized input pixel and per output class,
# then encode the network once; per-instance constraints are added to copies.
inputs = enc_base.new_vars(784)
outputs = enc_base.new_vars(10)
enc_base.encode_bin_input(model, inputs, outputs)
# Count how many instances were emitted per true label (for --instances-per-class).
counter = Counter[int]()

for instance_no, (x, true_label) in enumerate(test):
    # Selection filters: a single requested instance, a per-class quota,
    # and only correctly-classified images are encoded.
    if args.instance_no is not None and instance_no != args.instance_no:
        continue
    if args.instances_per_class is not None and counter[true_label] >= args.instances_per_class:
        continue
    print(f"dataset={args.dataset}; instance={instance_no}; true_label={true_label} predicted_label={predicated_label[instance_no]}")
    if predicated_label[instance_no] != true_label:
        continue
    counter[true_label] += 1

    # Save the reference image once.
    img = visualize.to_image(args.dataset, orig_image[instance_no])
    fname = result_dir / f"bnn_{args.dataset}_{instance_no}_label{true_label}.png"
    if not fname.exists():
        img.save(fname)

    # Saliency needs gradients, so backprop is enabled here only.
    with chainer.using_config("train", False), chainer.using_config("enable_backprop", True):
        saliency_map = model.saliency_map(x, true_label)

    # Per-instance encoder: the target label constraint on top of enc_base.
    enc = copy.copy(enc_base)
    if args.target == "truelabel":
        enc.add_clause([outputs[true_label]])
    else:
        enc.add_clause([-outputs[true_label]])

    # Input batch-norm parameters; each pixel binarizes by comparison with
    # the threshold C derived from them below.
    input_bn = model.input_bn
    mu = input_bn.avg_mean
    sigma = np.sqrt(input_bn.avg_var + input_bn.eps)
    gamma = input_bn.gamma.array
    beta = input_bn.beta.array

    # Build `mod`: one (literal, weight) per pixel where satisfying the
    # literal keeps the original binarization and the weight is the minimal
    # integer pixel change needed to flip it (None = cannot flip).
    numerically_unstable = False
    mod: List[Tuple[encoder.Lit, Optional[int]]] = []
    for j, pixel in enumerate(orig_image[instance_no]):
        # C_frac = 255 * (- beta[j] * sigma[j] / gamma[j] + mu[j])
        C_frac = (- beta[j] * sigma[j] / gamma[j] + mu[j]) / np.float32(1 / 255.0)
        if gamma[j] >= 0:
            # x ≥ ⌈255 (- βσ/γ + μ)⌉ = C
            C = int(math.ceil(C_frac))
            # Cross-check against the network's own binarization; a mismatch
            # means the integer threshold disagrees with the float forward
            # pass (numerical instability), so the instance is skipped.
            if orig_image_bin[instance_no, j] != (pixel >= C):
                numerically_unstable = True
                break
            #assert orig_image_bin[instance_no, j] == (pixel >= C)
            if pixel < C:
                mod.append((- inputs[j], C - pixel))
            elif C == 0:
                mod.append((inputs[j], None))  # impossible to change
            else:
                mod.append((inputs[j], (C - 1) - pixel))
        else:
            # x ≤ ⌊255 (- βσ/γ + μ)⌋ = C
            C = int(math.floor(C_frac))
            #assert orig_image_bin[instance_no, j] == (pixel <= C)
            if orig_image_bin[instance_no, j] != (pixel <= C):
                numerically_unstable = True
                break
            if pixel > C:
                mod.append((- inputs[j], C - pixel))
            elif C == 255:
                mod.append((inputs[j], None))  # impossible to change
            else:
                mod.append((inputs[j], (C + 1) - pixel))
    if numerically_unstable:
        print("numerically unstable")
        continue

    # debug: emit a satisfiability instance with every pixel clamped to
    # its original value (the model's own prediction must be recoverable).
    if args.debug_sat:
        enc2 = copy.copy(enc)
        for lit, w in mod:
            enc2.add_clause([lit])
        if args.format == "wcnf":
            fname = result_dir / f"bnn_{args.dataset}_{instance_no}_label{true_label}_{args.target}_{args.card}_debug.cnf"
        elif args.format == "wbo":
            fname = result_dir / f"bnn_{args.dataset}_{instance_no}_label{true_label}_{args.target}_debug.opb"
        else:
            raise RuntimeError("unknown ext: " + args.format)
        enc2.write_to_file(fname)

    for ratio in args.ratio:
        if ratio == 1.0:
            mod2 = mod
            ratio_str = ""
        else:
            ratio_str = f"{int(ratio * 100)}p"
            # Most salient `ratio` fraction of pixels (descending |saliency|).
            important_pixels = set(list(reversed(np.argsort(np.abs(saliency_map))))[:int(len(saliency_map) * ratio)])
            # BUG FIX: was `inputs[instance_no] for i in important_pixels`,
            # which ignored the loop variable and froze (almost) every pixel;
            # the intended set is the input variable of each salient pixel.
            important_variables = set(inputs[i] for i in important_pixels)
            # Non-salient pixels become hard clauses (weight None = frozen).
            mod2 = [(lit, w if abs(lit) in important_variables else None) for lit, w in mod]
        for norm in args.norm:
            # File name: bnn_<dataset>_<no>_label<l>_<target>_norm_<n>[_<ratio>][_<card>].<fmt>
            xs: List[str] = [
                "bnn", args.dataset, str(instance_no), f"label{true_label}",
                args.target, "norm_" + str(norm), ratio_str
            ]
            if args.format == "wcnf":
                xs.append(args.card)
            fname = result_dir / ('_'.join([s for s in xs if len(s) > 0]) + "." + args.format)
            enc2 = copy.copy(enc)
            add_norm(enc2, norm, mod2)
            enc2.write_to_file_opt(fname)