From 16c2fe5ec2ebb29da131aa590a4a0b79b1e94dc9 Mon Sep 17 00:00:00 2001 From: Aina Niemetz Date: Mon, 19 Aug 2019 15:10:04 -0700 Subject: [PATCH] New C++ API: Add checks for Solver::checkValid and Solver::checkValidAssuming. (#3197) --- src/api/cvc4cpp.cpp | 48 +++++++++++++++--- src/smt/smt_engine.h | 3 ++ test/unit/api/solver_black.h | 98 ++++++++++++++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 8 deletions(-) diff --git a/src/api/cvc4cpp.cpp b/src/api/cvc4cpp.cpp index b29d6a26f..904da0f10 100644 --- a/src/api/cvc4cpp.cpp +++ b/src/api/cvc4cpp.cpp @@ -38,12 +38,14 @@ #include "base/cvc4_check.h" #include "expr/expr.h" #include "expr/expr_manager.h" +#include "expr/expr_manager_scope.h" #include "expr/kind.h" #include "expr/metakind.h" #include "expr/node_manager.h" #include "expr/type.h" #include "options/main_options.h" #include "options/options.h" +#include "options/smt_options.h" #include "smt/model.h" #include "smt/smt_engine.h" #include "theory/logic_info.h" @@ -3084,33 +3086,61 @@ Term Solver::simplify(const Term& t) { CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_ARG_CHECK_NOT_NULL(t); + return d_smtEngine->simplify(*t.d_expr); + CVC4_API_SOLVER_TRY_CATCH_END; } Result Solver::checkValid(void) const { - // CHECK: - // if d_queryMade -> incremental enabled + CVC4_API_SOLVER_TRY_CATCH_BEGIN; + CVC4::ExprManagerScope exmgrs(*(d_exprMgr.get())); + CVC4_API_CHECK(!d_smtEngine->isQueryMade() + || CVC4::options::incrementalSolving()) + << "Cannot make multiple queries unless incremental solving is enabled " + "(try --incremental)"; + CVC4::Result r = d_smtEngine->query(); return Result(r); + + CVC4_API_SOLVER_TRY_CATCH_END; } Result Solver::checkValidAssuming(Term assumption) const { - // CHECK: - // if assumptions.size() > 0: incremental enabled? + CVC4_API_SOLVER_TRY_CATCH_BEGIN; + CVC4::ExprManagerScope exmgrs(*(d_exprMgr.get())); + CVC4_API_CHECK(!d_smtEngine->isQueryMade() + || CVC4::options::incrementalSolving()) + << "Cannot make multiple queries unless incremental solving is enabled " + "(try --incremental)"; + CVC4_API_ARG_CHECK_NOT_NULL(assumption); + CVC4::Result r = d_smtEngine->query(*assumption.d_expr); return Result(r); + + CVC4_API_SOLVER_TRY_CATCH_END; } Result Solver::checkValidAssuming(const std::vector& assumptions) const { - // CHECK: - // if assumptions.size() > 0: incremental enabled? + CVC4_API_SOLVER_TRY_CATCH_BEGIN; + CVC4::ExprManagerScope exmgrs(*(d_exprMgr.get())); + CVC4_API_CHECK(!d_smtEngine->isQueryMade() + || CVC4::options::incrementalSolving()) + << "Cannot make multiple queries unless incremental solving is enabled " + "(try --incremental)"; + for (const Term& assumption : assumptions) + { + CVC4_API_ARG_CHECK_NOT_NULL(assumption); + } + std::vector eassumptions = termVectorToExprs(assumptions); CVC4::Result r = d_smtEngine->query(eassumptions); return Result(r); + + CVC4_API_SOLVER_TRY_CATCH_END; } /* SMT-LIB commands */ @@ -3556,7 +3586,8 @@ std::vector Solver::getValue(const std::vector& terms) const void Solver::pop(uint32_t nscopes) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - CVC4_API_CHECK(d_smtEngine->getOption("incremental").toString() == "true") + CVC4::ExprManagerScope exmgrs(*(d_exprMgr.get())); + CVC4_API_CHECK(CVC4::options::incrementalSolving()) << "Cannot pop when not solving incrementally (use --incremental)"; CVC4_API_CHECK(nscopes <= d_smtEngine->getNumUserLevels()) << "Cannot pop beyond first pushed context"; @@ -3581,7 +3612,8 @@ void Solver::printModel(std::ostream& out) const void Solver::push(uint32_t nscopes) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - CVC4_API_CHECK(d_smtEngine->getOption("incremental").toString() == "true") + CVC4::ExprManagerScope exmgrs(*(d_exprMgr.get())); + CVC4_API_CHECK(CVC4::options::incrementalSolving()) << "Cannot push when not solving incrementally (use --incremental)"; for (uint32_t n = 0; n < nscopes; ++n) diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index 4ac21d392..120c712ea 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -154,6 +154,9 @@ class CVC4_PUBLIC SmtEngine */ bool isFullyInited() { return d_fullyInited; } + /** Return true if a query() or checkSat() has already been made. */ + bool isQueryMade() { return d_queryMade; } + /** Return the user context level. */ size_t getNumUserLevels() { return d_userLevels.size(); } diff --git a/test/unit/api/solver_black.h b/test/unit/api/solver_black.h index 3a90f7ed2..835ecd880 100644 --- a/test/unit/api/solver_black.h +++ b/test/unit/api/solver_black.h @@ -89,7 +89,12 @@ class SolverBlack : public CxxTest::TestSuite void testPop1(); void testPop2(); void testPop3(); + void testSimplify(); + void testCheckValid1(); + void testCheckValid2(); + void testCheckValidAssuming1(); + void testCheckValidAssuming2(); void testSetInfo(); void testSetLogic(); @@ -1027,6 +1032,99 @@ void SolverBlack::testSimplify() TS_ASSERT_THROWS_NOTHING(d_solver->simplify(f2)); } +void SolverBlack::testCheckValid1() +{ + d_solver->setOption("incremental", "false"); + TS_ASSERT_THROWS_NOTHING(d_solver->checkValid()); + TS_ASSERT_THROWS(d_solver->checkValid(), CVC4ApiException&); +} + +void SolverBlack::testCheckValid2() +{ + d_solver->setOption("incremental", "true"); + + Sort realSort = d_solver->getRealSort(); + Sort intSort = d_solver->getIntegerSort(); + + // Constants + Term x = d_solver->mkConst(intSort, "x"); + Term y = d_solver->mkConst(realSort, "y"); + // Values + Term three = d_solver->mkReal(3); + Term neg2 = d_solver->mkReal(-2); + Term two_thirds = d_solver->mkReal(2, 3); + // Terms + Term three_y = d_solver->mkTerm(MULT, three, y); + Term diff = d_solver->mkTerm(MINUS, y, x); + // Formulas + Term x_geq_3y = d_solver->mkTerm(GEQ, x, three_y); + Term x_leq_y = d_solver->mkTerm(LEQ, x, y); + Term neg2_lt_x = d_solver->mkTerm(LT, neg2, x); + // Assertions + Term assertions = d_solver->mkTerm(AND, x_geq_3y, x_leq_y, neg2_lt_x); + + TS_ASSERT_THROWS_NOTHING(d_solver->checkValid()); + d_solver->assertFormula(assertions); + TS_ASSERT_THROWS_NOTHING(d_solver->checkValid()); +} + +void SolverBlack::testCheckValidAssuming1() +{ + d_solver->setOption("incremental", "false"); + TS_ASSERT_THROWS_NOTHING(d_solver->checkValidAssuming(d_solver->mkTrue())); + TS_ASSERT_THROWS(d_solver->checkValidAssuming(d_solver->mkTrue()), + CVC4ApiException&); +} + +void SolverBlack::testCheckValidAssuming2() +{ + d_solver->setOption("incremental", "true"); + + Sort uSort = d_solver->mkUninterpretedSort("u"); + Sort intSort = d_solver->getIntegerSort(); + Sort boolSort = d_solver->getBooleanSort(); + Sort uToIntSort = d_solver->mkFunctionSort(uSort, intSort); + Sort intPredSort = d_solver->mkFunctionSort(intSort, boolSort); + + Term n = Term(); + // Constants + Term x = d_solver->mkConst(uSort, "x"); + Term y = d_solver->mkConst(uSort, "y"); + // Functions + Term f = d_solver->mkConst(uToIntSort, "f"); + Term p = d_solver->mkConst(intPredSort, "p"); + // Values + Term zero = d_solver->mkReal(0); + Term one = d_solver->mkReal(1); + // Terms + Term f_x = d_solver->mkTerm(APPLY_UF, f, x); + Term f_y = d_solver->mkTerm(APPLY_UF, f, y); + Term sum = d_solver->mkTerm(PLUS, f_x, f_y); + Term p_0 = d_solver->mkTerm(APPLY_UF, p, zero); + Term p_f_y = d_solver->mkTerm(APPLY_UF, p, f_y); + // Assertions + Term assertions = + d_solver->mkTerm(AND, + std::vector{ + d_solver->mkTerm(LEQ, zero, f_x), // 0 <= f(x) + d_solver->mkTerm(LEQ, zero, f_y), // 0 <= f(y) + d_solver->mkTerm(LEQ, sum, one), // f(x) + f(y) <= 1 + p_0.notTerm(), // not p(0) + p_f_y // p(f(y)) + }); + + TS_ASSERT_THROWS_NOTHING(d_solver->checkValidAssuming(d_solver->mkTrue())); + d_solver->assertFormula(assertions); + TS_ASSERT_THROWS_NOTHING( + d_solver->checkValidAssuming(d_solver->mkTerm(DISTINCT, x, y))); + TS_ASSERT_THROWS_NOTHING(d_solver->checkValidAssuming( + {d_solver->mkFalse(), d_solver->mkTerm(DISTINCT, x, y)})); + TS_ASSERT_THROWS(d_solver->checkValidAssuming(n), CVC4ApiException&); + TS_ASSERT_THROWS( + d_solver->checkValidAssuming({n, d_solver->mkTerm(DISTINCT, x, y)}), + CVC4ApiException&); +} + void SolverBlack::testSetLogic() { TS_ASSERT_THROWS_NOTHING(d_solver->setLogic("AUFLIRA")); -- 2.30.2