Added ezSAT::keep_cnf() and ezSAT::non_incremental()
authorClifford Wolf <clifford@clifford.at>
Sun, 20 Jul 2014 23:49:59 +0000 (01:49 +0200)
committerClifford Wolf <clifford@clifford.at>
Mon, 21 Jul 2014 00:01:32 +0000 (02:01 +0200)
libs/ezsat/Makefile
libs/ezsat/ezminisat.cc
libs/ezsat/ezsat.cc
libs/ezsat/ezsat.h
libs/ezsat/testbench.cc

index 1dcb5d15b6ddcc6bff19fe55b4eefe57148ceb8f..b1f8641609395fa4d3d67cedfc6d9e37babda782 100644 (file)
@@ -18,7 +18,8 @@ test: all
        ./testbench
        ./demo_bit
        ./demo_vec
-       ./demo_cmp
+       # ./demo_cmp
+       # ./puzzle3d
 
 clean:
        rm -f demo_bit demo_vec demo_cmp testbench puzzle3d *.o *.d
index 6a6c075f55c70c999b1ecc92363c6152d2ee43d0..3f43f3ece25b30660f3391ac79adfac81c6836c3 100644 (file)
@@ -63,7 +63,8 @@ void ezMiniSAT::clear()
 #if EZMINISAT_SIMPSOLVER && EZMINISAT_INCREMENTAL
 void ezMiniSAT::freeze(int id)
 {
-       cnfFrozenVars.insert(bind(id));
+       if (!mode_non_incremental())
+               cnfFrozenVars.insert(bind(id));
 }
 
 bool ezMiniSAT::eliminated(int idx)
@@ -89,6 +90,8 @@ void ezMiniSAT::alarmHandler(int)
 
 bool ezMiniSAT::solver(const std::vector<int> &modelExpressions, std::vector<bool> &modelValues, const std::vector<int> &assumptions)
 {
+       preSolverCallback();
+
        solverTimoutStatus = false;
 
        if (0) {
index 6da363fc11eb28ffdf5c3cee7ded030e9793d6ab..4c0b624be852265f132a9101b24d99cf2b662059 100644 (file)
@@ -30,6 +30,11 @@ const int ezSAT::FALSE = 2;
 
 ezSAT::ezSAT()
 {
+       flag_keep_cnf = false;
+       flag_non_incremental = false;
+
+       non_incremental_solve_used_up = false;
+
        cnfConsumed = false;
        cnfVariableCount = 0;
        cnfClausesCount = 0;
@@ -588,19 +593,40 @@ int ezSAT::bind(int id, bool auto_freeze)
 
 void ezSAT::consumeCnf()
 {
-       cnfConsumed = true;
+       if (mode_keep_cnf())
+               cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end());
+       else
+               cnfConsumed = true;
        cnfClauses.clear();
 }
 
 void ezSAT::consumeCnf(std::vector<std::vector<int>> &cnf)
 {
-       cnfConsumed = true;
+       if (mode_keep_cnf())
+               cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end());
+       else
+               cnfConsumed = true;
        cnf.swap(cnfClauses);
        cnfClauses.clear();
 }
 
+void ezSAT::getFullCnf(std::vector<std::vector<int>> &full_cnf) const
+{
+       assert(full_cnf.empty());
+       full_cnf.insert(full_cnf.end(), cnfClausesBackup.begin(), cnfClausesBackup.end());
+       full_cnf.insert(full_cnf.end(), cnfClauses.begin(), cnfClauses.end());
+}
+
+void ezSAT::preSolverCallback()
+{
+       assert(!non_incremental_solve_used_up);
+       if (mode_non_incremental())
+               non_incremental_solve_used_up = true;
+}
+
 bool ezSAT::solver(const std::vector<int>&, std::vector<bool>&, const std::vector<int>&)
 {
+       preSolverCallback();
        fprintf(stderr, "************************************************************************\n");
        fprintf(stderr, "ERROR: You are trying to use the solve() method of the ezSAT base class!\n");
        fprintf(stderr, "Use a dervied class like ezMiniSAT instead.\n");
@@ -1081,16 +1107,26 @@ void ezSAT::printDIMACS(FILE *f, bool verbose) const
                        if (cnfExpressionVariables[i] != 0)
                                fprintf(f, "c %*d: %s\n", digits, cnfExpressionVariables[i], to_string(-i-1).c_str());
 
+               if (mode_keep_cnf()) {
+                       fprintf(f, "c\n");
+                       fprintf(f, "c %d clauses from backup, %d from current buffer\n",
+                                       int(cnfClausesBackup.size()), int(cnfClauses.size()));
+               }
+
                fprintf(f, "c\n");
        }
 
-       fprintf(f, "p cnf %d %d\n", cnfVariableCount, int(cnfClauses.size()));
+       std::vector<std::vector<int>> all_clauses;
+       getFullCnf(all_clauses);
+       assert(cnfClausesCount == int(all_clauses.size()));
+
+       fprintf(f, "p cnf %d %d\n", cnfVariableCount, cnfClausesCount);
        int maxClauseLen = 0;
-       for (auto &clause : cnfClauses)
+       for (auto &clause : all_clauses)
                maxClauseLen = std::max(int(clause.size()), maxClauseLen);
        if (!verbose)
                maxClauseLen = std::min(maxClauseLen, 3);
-       for (auto &clause : cnfClauses) {
+       for (auto &clause : all_clauses) {
                for (auto idx : clause)
                        fprintf(f, " %*d", digits, idx);
                if (maxClauseLen >= int(clause.size()))
index 85240556684e478ed293743ea0c11cce8e81c5ad..83e1b23c57a257b6e1ac3926ff48954f914dafee 100644 (file)
@@ -48,6 +48,11 @@ public:
        static const int FALSE;
 
 private:
+       bool flag_keep_cnf;
+       bool flag_non_incremental;
+
+       bool non_incremental_solve_used_up;
+
        std::map<std::string, int> literalsCache;
        std::vector<std::string> literals;
 
@@ -57,7 +62,7 @@ private:
        bool cnfConsumed;
        int cnfVariableCount, cnfClausesCount;
        std::vector<int> cnfLiteralVariables, cnfExpressionVariables;
-       std::vector<std::vector<int>> cnfClauses;
+       std::vector<std::vector<int>> cnfClauses, cnfClausesBackup;
 
        void add_clause(const std::vector<int> &args);
        void add_clause(const std::vector<int> &args, bool argsPolarity, int a = 0, int b = 0, int c = 0);
@@ -67,6 +72,9 @@ private:
        int bind_cnf_and(const std::vector<int> &args);
        int bind_cnf_or(const std::vector<int> &args);
 
+protected:
+       void preSolverCallback();
+
 public:
        int solverTimeout;
        bool solverTimoutStatus;
@@ -74,6 +82,12 @@ public:
        ezSAT();
        virtual ~ezSAT();
 
+       void keep_cnf() { flag_keep_cnf = true; }
+       void non_incremental() { flag_non_incremental = true; }
+
+       bool mode_keep_cnf() const { return flag_keep_cnf; }
+       bool mode_non_incremental() const { return flag_non_incremental; }
+
        // manage expressions
 
        int value(bool val);
@@ -155,6 +169,9 @@ public:
        void consumeCnf();
        void consumeCnf(std::vector<std::vector<int>> &cnf);
 
+       // use this function to get the full CNF in keep_cnf mode
+       void getFullCnf(std::vector<std::vector<int>> &full_cnf) const;
+
        std::string cnfLiteralInfo(int idx) const;
 
        // simple helpers for build expressions easily
index 8332ad919ef0233c79ece5c6aea00e9e6f60648b..d20258c379a5531ba5a8ff37cb6c0c30396c31c1 100644 (file)
@@ -64,6 +64,7 @@ void test_simple()
        printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
 
        ezMiniSAT sat;
+       sat.non_incremental();
        sat.assume(sat.OR("A", "B"));
        sat.assume(sat.NOT(sat.AND("A", "B")));
        test(sat);
@@ -121,6 +122,8 @@ void test_xorshift32()
        printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
 
        ezMiniSAT sat;
+       sat.keep_cnf();
+
        xorshift128 rng;
 
        std::vector<int> bits = sat.vec_var("i", 32);
@@ -137,6 +140,9 @@ void test_xorshift32()
        test_xorshift32_try(sat, rng());
        test_xorshift32_try(sat, rng());
        test_xorshift32_try(sat, rng());
+
+       sat.printDIMACS(stdout, true);
+       printf("\n");
 }
 
 // ------------------------------------------------------------------------------------------------------------