Improve ezsat onehot encoding scheme
authorClaire Wolf <claire@symbioticeda.com>
Thu, 2 Apr 2020 10:22:28 +0000 (12:22 +0200)
committerClaire Wolf <claire@symbioticeda.com>
Thu, 2 Apr 2020 10:22:28 +0000 (12:22 +0200)
Signed-off-by: Claire Wolf <claire@symbioticeda.com>
libs/ezsat/ezsat.cc

index 47fdb8efe5a890780915adfa56c69bbf901d886d..8c666ca1f979cc9ae2cde2206649842f3bb90ea8 100644 (file)
@@ -1371,20 +1371,34 @@ int ezSAT::onehot(const std::vector<int> &vec, bool max_only)
        if (max_only == false)
                formula.push_back(expression(OpOr, vec));
 
-       // create binary vector
-       int num_bits = clog2(vec.size());
-       std::vector<int> bits;
-       for (int k = 0; k < num_bits; k++)
-               bits.push_back(literal());
-
-       // add at-most-one clauses using binary encoding
-       for (size_t i = 0; i < vec.size(); i++)
-               for (int k = 0; k < num_bits; k++) {
-                       std::vector<int> clause;
-                       clause.push_back(NOT(vec[i]));
-                       clause.push_back((i & (1 << k)) != 0 ? bits[k] : NOT(bits[k]));
-                       formula.push_back(expression(OpOr, clause));
-               }
+       if (vec.size() < 8)
+       {
+               // fall-back to simple O(n^2) solution for small cases
+               for (size_t i = 0; i < vec.size(); i++)
+                       for (size_t j = i+1; j < vec.size(); j++) {
+                               std::vector<int> clause;
+                               clause.push_back(NOT(vec[i]));
+                               clause.push_back(NOT(vec[j]));
+                               formula.push_back(expression(OpOr, clause));
+                       }
+       }
+       else
+       {
+               // create binary vector
+               int num_bits = clog2(vec.size());
+               std::vector<int> bits;
+               for (int k = 0; k < num_bits; k++)
+                       bits.push_back(literal());
+
+               // add at-most-one clauses using binary encoding
+               for (size_t i = 0; i < vec.size(); i++)
+                       for (int k = 0; k < num_bits; k++) {
+                               std::vector<int> clause;
+                               clause.push_back(NOT(vec[i]));
+                               clause.push_back((i & (1 << k)) != 0 ? bits[k] : NOT(bits[k]));
+                               formula.push_back(expression(OpOr, clause));
+                       }
+       }
 
        return expression(OpAnd, formula);
 }