Bump version
[yosys.git] / libs / ezsat / testbench.cc
1 /*
2 * ezSAT -- A simple and easy to use CNF generator for SAT solvers
3 *
4 * Copyright (C) 2013 Claire Xenia Wolf <claire@yosyshq.com>
5 *
6 * Permission to use, copy, modify, and/or distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 *
18 */
19
#include "ezminisat.h"
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <utility>
22
// Marsaglia-style xorshift128 pseudo-random generator with a fixed seed, so the
// randomized tests in this file draw the same values on every run.
struct xorshift128 {
	uint32_t x, y, z, w;

	xorshift128() : x(123456789), y(362436069), z(521288629), w(88675123) { }

	// Advance the 128-bit state (x, y, z, w) by one step and return the new w.
	uint32_t operator()() {
		const uint32_t mixed = x ^ (x << 11);
		x = y;
		y = z;
		z = w;
		w = w ^ (w >> 19) ^ mixed ^ (mixed >> 8);
		return w;
	}
};
38
39 bool test(ezSAT &sat, int assumption = 0)
40 {
41 std::vector<int> modelExpressions;
42 std::vector<bool> modelValues;
43
44 for (int id = 1; id <= sat.numLiterals(); id++)
45 if (sat.bound(id))
46 modelExpressions.push_back(id);
47
48 if (sat.solve(modelExpressions, modelValues, assumption)) {
49 printf("satisfiable:");
50 for (int i = 0; i < int(modelExpressions.size()); i++)
51 printf(" %s=%d", sat.to_string(modelExpressions[i]).c_str(), int(modelValues[i]));
52 printf("\n\n");
53 return true;
54 } else {
55 printf("not satisfiable.\n\n");
56 return false;
57 }
58 }
59
60 // ------------------------------------------------------------------------------------------------------------
61
62 void test_simple()
63 {
64 printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
65
66 ezMiniSAT sat;
67 sat.non_incremental();
68 sat.assume(sat.OR("A", "B"));
69 sat.assume(sat.NOT(sat.AND("A", "B")));
70 test(sat);
71 }
72
73 // ------------------------------------------------------------------------------------------------------------
74
75 void test_xorshift32_try(ezSAT &sat, uint32_t input_pattern)
76 {
77 uint32_t output_pattern = input_pattern;
78 output_pattern ^= output_pattern << 13;
79 output_pattern ^= output_pattern >> 17;
80 output_pattern ^= output_pattern << 5;
81
82 std::vector<int> modelExpressions;
83 std::vector<int> forwardAssumptions, backwardAssumptions;
84 std::vector<bool> forwardModel, backwardModel;
85
86 sat.vec_append(modelExpressions, sat.vec_var("i", 32));
87 sat.vec_append(modelExpressions, sat.vec_var("o", 32));
88
89 sat.vec_append_unsigned(forwardAssumptions, sat.vec_var("i", 32), input_pattern);
90 sat.vec_append_unsigned(backwardAssumptions, sat.vec_var("o", 32), output_pattern);
91
92 if (!sat.solve(modelExpressions, backwardModel, backwardAssumptions)) {
93 printf("backward solving failed!\n");
94 abort();
95 }
96
97 if (!sat.solve(modelExpressions, forwardModel, forwardAssumptions)) {
98 printf("forward solving failed!\n");
99 abort();
100 }
101
102 printf("xorshift32 test with input pattern 0x%08x:\n", input_pattern);
103
104 printf("forward solution: input=0x%08x output=0x%08x\n",
105 (unsigned int)sat.vec_model_get_unsigned(modelExpressions, forwardModel, sat.vec_var("i", 32)),
106 (unsigned int)sat.vec_model_get_unsigned(modelExpressions, forwardModel, sat.vec_var("o", 32)));
107
108 printf("backward solution: input=0x%08x output=0x%08x\n",
109 (unsigned int)sat.vec_model_get_unsigned(modelExpressions, backwardModel, sat.vec_var("i", 32)),
110 (unsigned int)sat.vec_model_get_unsigned(modelExpressions, backwardModel, sat.vec_var("o", 32)));
111
112 if (forwardModel != backwardModel) {
113 printf("forward and backward results are inconsistend!\n");
114 abort();
115 }
116
117 printf("passed.\n\n");
118 }
119
120 void test_xorshift32()
121 {
122 printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
123
124 ezMiniSAT sat;
125 sat.keep_cnf();
126
127 xorshift128 rng;
128
129 std::vector<int> bits = sat.vec_var("i", 32);
130
131 bits = sat.vec_xor(bits, sat.vec_shl(bits, 13));
132 bits = sat.vec_xor(bits, sat.vec_shr(bits, 17));
133 bits = sat.vec_xor(bits, sat.vec_shl(bits, 5));
134
135 sat.vec_set(bits, sat.vec_var("o", 32));
136
137 test_xorshift32_try(sat, 0);
138 test_xorshift32_try(sat, 314159265);
139 test_xorshift32_try(sat, rng());
140 test_xorshift32_try(sat, rng());
141 test_xorshift32_try(sat, rng());
142 test_xorshift32_try(sat, rng());
143
144 sat.printDIMACS(stdout, true);
145 printf("\n");
146 }
147
148 // ------------------------------------------------------------------------------------------------------------
149
// CHECK(a, b) passes both values to check() together with their source text,
// so a report line shows exactly which expressions were compared.
#define CHECK(_expr1, _expr2) check(#_expr1, _expr1, #_expr2, _expr2)

// Print a report comparing two boolean results; abort() if they differ.
void check(const char *expr1_str, bool expr1, const char *expr2_str, bool expr2)
{
	const char *lhs_text = expr1 ? "true" : "false";
	const char *rhs_text = expr2 ? "true" : "false";

	if (expr1 != expr2) {
		printf("[ %s ] != [ %s ] .. ERROR (%s != %s)\n", expr1_str, expr2_str, lhs_text, rhs_text);
		abort();
	}

	printf("[ %s ] == [ %s ] .. ok (%s == %s)\n", expr1_str, expr2_str, lhs_text, rhs_text);
}
161
// Check ezSAT's signed 8-bit comparisons against native C++ comparisons.
// In each CHECK:
//  - the left side is the expected truth value computed in C++;
//  - the right side is whether ezSAT finds the matching constraint satisfiable.
// CHECK() aborts on the first mismatch.
// NOTE(review): the C++ side evaluates b+c and b-c in int, while the SAT side
// uses 8-bit vectors, so the two only agree when the results fit in int8_t
// (test_arith draws the operands from -10..8, which keeps them in range).
// The CHECK arguments are stringified into the printed report, so their exact
// spelling (including the local names av/bv/cv) is part of the output.
void test_signed(int8_t a, int8_t b, int8_t c)
{
	ezMiniSAT sat;

	// Operands as 8-bit signed constant vectors.
	std::vector<int> av = sat.vec_const_signed(a, 8);
	std::vector<int> bv = sat.vec_const_signed(b, 8);
	std::vector<int> cv = sat.vec_const_signed(c, 8);

	printf("Testing signed arithmetic using: a=%+d, b=%+d, c=%+d\n", int(a), int(b), int(c));

	CHECK(a < b+c, sat.solve(sat.vec_lt_signed(av, sat.vec_add(bv, cv))));
	CHECK(a <= b-c, sat.solve(sat.vec_le_signed(av, sat.vec_sub(bv, cv))));

	CHECK(a > b+c, sat.solve(sat.vec_gt_signed(av, sat.vec_add(bv, cv))));
	CHECK(a >= b-c, sat.solve(sat.vec_ge_signed(av, sat.vec_sub(bv, cv))));

	printf("\n");
}
180
181 void test_unsigned(uint8_t a, uint8_t b, uint8_t c)
182 {
183 ezMiniSAT sat;
184
185 if (b < c)
186 b ^= c, c ^= b, b ^= c;
187
188 std::vector<int> av = sat.vec_const_unsigned(a, 8);
189 std::vector<int> bv = sat.vec_const_unsigned(b, 8);
190 std::vector<int> cv = sat.vec_const_unsigned(c, 8);
191
192 printf("Testing unsigned arithmetic using: a=%d, b=%d, c=%d\n", int(a), int(b), int(c));
193
194 CHECK(a < b+c, sat.solve(sat.vec_lt_unsigned(av, sat.vec_add(bv, cv))));
195 CHECK(a <= b-c, sat.solve(sat.vec_le_unsigned(av, sat.vec_sub(bv, cv))));
196
197 CHECK(a > b+c, sat.solve(sat.vec_gt_unsigned(av, sat.vec_add(bv, cv))));
198 CHECK(a >= b-c, sat.solve(sat.vec_ge_unsigned(av, sat.vec_sub(bv, cv))));
199
200 printf("\n");
201 }
202
203 void test_count(uint32_t x)
204 {
205 ezMiniSAT sat;
206
207 int count = 0;
208 for (int i = 0; i < 32; i++)
209 if (((x >> i) & 1) != 0)
210 count++;
211
212 printf("Testing bit counting using x=0x%08x (%d set bits) .. ", x, count);
213
214 std::vector<int> v = sat.vec_const_unsigned(x, 32);
215
216 std::vector<int> cv6 = sat.vec_const_unsigned(count, 6);
217 std::vector<int> cv4 = sat.vec_const_unsigned(count <= 15 ? count : 15, 4);
218
219 if (cv6 != sat.vec_count(v, 6, false)) {
220 fprintf(stderr, "FAILED 6bit-no-clipping test!\n");
221 abort();
222 }
223
224 if (cv4 != sat.vec_count(v, 4, true)) {
225 fprintf(stderr, "FAILED 4bit-clipping test!\n");
226 abort();
227 }
228
229 printf("ok.\n");
230 }
231
232 void test_arith()
233 {
234 printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
235
236 xorshift128 rng;
237
238 for (int i = 0; i < 100; i++)
239 test_signed(rng() % 19 - 10, rng() % 19 - 10, rng() % 19 - 10);
240
241 for (int i = 0; i < 100; i++)
242 test_unsigned(rng() % 10, rng() % 10, rng() % 10);
243
244 test_count(0x00000000);
245 test_count(0xffffffff);
246 for (int i = 0; i < 30; i++)
247 test_count(rng());
248
249 printf("\n");
250 }
251
252 // ------------------------------------------------------------------------------------------------------------
253
254 void test_onehot()
255 {
256 printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
257 ezMiniSAT ez;
258
259 int a = ez.frozen_literal("a");
260 int b = ez.frozen_literal("b");
261 int c = ez.frozen_literal("c");
262 int d = ez.frozen_literal("d");
263
264 std::vector<int> abcd;
265 abcd.push_back(a);
266 abcd.push_back(b);
267 abcd.push_back(c);
268 abcd.push_back(d);
269
270 ez.assume(ez.onehot(abcd));
271
272 int solution_counter = 0;
273 while (1)
274 {
275 std::vector<bool> modelValues;
276 bool ok = ez.solve(abcd, modelValues);
277
278 if (!ok)
279 break;
280
281 printf("Solution: %d %d %d %d\n", int(modelValues[0]), int(modelValues[1]), int(modelValues[2]), int(modelValues[3]));
282
283 int count_hot = 0;
284 std::vector<int> sol;
285 for (int i = 0; i < 4; i++) {
286 if (modelValues[i])
287 count_hot++;
288 sol.push_back(modelValues[i] ? abcd[i] : ez.NOT(abcd[i]));
289 }
290 ez.assume(ez.NOT(ez.expression(ezSAT::OpAnd, sol)));
291
292 if (count_hot != 1) {
293 fprintf(stderr, "Wrong number of hot bits!\n");
294 abort();
295 }
296
297 solution_counter++;
298 }
299
300 if (solution_counter != 4) {
301 fprintf(stderr, "Wrong number of one-hot solutions!\n");
302 abort();
303 }
304
305 printf("\n");
306 }
307
308 void test_manyhot()
309 {
310 printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
311 ezMiniSAT ez;
312
313 int a = ez.frozen_literal("a");
314 int b = ez.frozen_literal("b");
315 int c = ez.frozen_literal("c");
316 int d = ez.frozen_literal("d");
317
318 std::vector<int> abcd;
319 abcd.push_back(a);
320 abcd.push_back(b);
321 abcd.push_back(c);
322 abcd.push_back(d);
323
324 ez.assume(ez.manyhot(abcd, 1, 2));
325
326 int solution_counter = 0;
327 while (1)
328 {
329 std::vector<bool> modelValues;
330 bool ok = ez.solve(abcd, modelValues);
331
332 if (!ok)
333 break;
334
335 printf("Solution: %d %d %d %d\n", int(modelValues[0]), int(modelValues[1]), int(modelValues[2]), int(modelValues[3]));
336
337 int count_hot = 0;
338 std::vector<int> sol;
339 for (int i = 0; i < 4; i++) {
340 if (modelValues[i])
341 count_hot++;
342 sol.push_back(modelValues[i] ? abcd[i] : ez.NOT(abcd[i]));
343 }
344 ez.assume(ez.NOT(ez.expression(ezSAT::OpAnd, sol)));
345
346 if (count_hot != 1 && count_hot != 2) {
347 fprintf(stderr, "Wrong number of hot bits!\n");
348 abort();
349 }
350
351 solution_counter++;
352 }
353
354 if (solution_counter != 4 + 4*3/2) {
355 fprintf(stderr, "Wrong number of one-hot solutions!\n");
356 abort();
357 }
358
359 printf("\n");
360 }
361
362 void test_ordered()
363 {
364 printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
365 ezMiniSAT ez;
366
367 int a = ez.frozen_literal("a");
368 int b = ez.frozen_literal("b");
369 int c = ez.frozen_literal("c");
370
371 int x = ez.frozen_literal("x");
372 int y = ez.frozen_literal("y");
373 int z = ez.frozen_literal("z");
374
375 std::vector<int> abc;
376 abc.push_back(a);
377 abc.push_back(b);
378 abc.push_back(c);
379
380 std::vector<int> xyz;
381 xyz.push_back(x);
382 xyz.push_back(y);
383 xyz.push_back(z);
384
385 ez.assume(ez.ordered(abc, xyz));
386
387 int solution_counter = 0;
388
389 while (1)
390 {
391 std::vector<int> modelVariables;
392 std::vector<bool> modelValues;
393
394 modelVariables.push_back(a);
395 modelVariables.push_back(b);
396 modelVariables.push_back(c);
397
398 modelVariables.push_back(x);
399 modelVariables.push_back(y);
400 modelVariables.push_back(z);
401
402 bool ok = ez.solve(modelVariables, modelValues);
403
404 if (!ok)
405 break;
406
407 printf("Solution: %d %d %d | %d %d %d\n",
408 int(modelValues[0]), int(modelValues[1]), int(modelValues[2]),
409 int(modelValues[3]), int(modelValues[4]), int(modelValues[5]));
410
411 std::vector<int> sol;
412 for (size_t i = 0; i < modelVariables.size(); i++)
413 sol.push_back(modelValues[i] ? modelVariables[i] : ez.NOT(modelVariables[i]));
414 ez.assume(ez.NOT(ez.expression(ezSAT::OpAnd, sol)));
415
416 solution_counter++;
417 }
418
419 if (solution_counter != 8+7+6+5+4+3+2+1) {
420 fprintf(stderr, "Wrong number of solutions!\n");
421 abort();
422 }
423
424 printf("\n");
425 }
426
427 // ------------------------------------------------------------------------------------------------------------
428
429
430 int main()
431 {
432 test_simple();
433 test_xorshift32();
434 test_arith();
435 test_onehot();
436 test_manyhot();
437 test_ordered();
438 printf("Passed all tests.\n\n");
439 return 0;
440 }
441