From 868a62d550966065f8afdfc5b39715ca6c06314a Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Thu, 6 Jan 2022 13:09:10 -0800 Subject: [PATCH] Improve theory combination in the presence of real algebraic numbers (#7883) This PR changes how we handle real algebraic numbers in theory combination and model construction. The goal is to improve getEqualityStatus() and produce proper models more often. We now use a RAN-aware evaluator for getEqualityStatus() and change the way how the nonlinear extension finalizes its model. --- src/CMakeLists.txt | 2 + src/theory/arith/arith_evaluator.cpp | 94 +++++++++++++++++++++ src/theory/arith/arith_evaluator.h | 25 ++++++ src/theory/arith/nl/cad_solver.cpp | 36 ++++---- src/theory/arith/nl/cad_solver.h | 6 ++ src/theory/arith/nl/nonlinear_extension.cpp | 43 ++++++++-- src/theory/arith/nl/nonlinear_extension.h | 9 +- src/theory/arith/theory_arith.cpp | 20 +++-- test/regress/CMakeLists.txt | 1 + test/regress/regress0/nl/combined-uf.smt2 | 11 +++ 10 files changed, 216 insertions(+), 31 deletions(-) create mode 100644 src/theory/arith/arith_evaluator.cpp create mode 100644 src/theory/arith/arith_evaluator.h create mode 100644 test/regress/regress0/nl/combined-uf.smt2 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7a62a327a..830c70ca9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -352,6 +352,8 @@ libcvc5_add_sources( smt_util/boolean_simplification.h theory/arith/approx_simplex.cpp theory/arith/approx_simplex.h + theory/arith/arith_evaluator.cpp + theory/arith/arith_evaluator.h theory/arith/arith_ite_utils.cpp theory/arith/arith_ite_utils.h theory/arith/arith_msum.cpp diff --git a/src/theory/arith/arith_evaluator.cpp b/src/theory/arith/arith_evaluator.cpp new file mode 100644 index 000000000..0fe045a61 --- /dev/null +++ b/src/theory/arith/arith_evaluator.cpp @@ -0,0 +1,94 @@ +#include "theory/arith/arith_evaluator.h" + +#include "theory/arith/nl/poly_conversion.h" +#include "theory/rewriter.h" +#include "theory/theory.h" +#include "util/real_algebraic_number.h" + +namespace cvc5 { +namespace theory { +namespace arith { + +namespace { + +RealAlgebraicNumber evaluate(TNode expr, + const std::map& rans) +{ + switch (expr.getKind()) + { + case Kind::PLUS: + { + RealAlgebraicNumber aggr; + for (const auto& n : expr) + { + aggr += evaluate(n, rans); + } + return aggr; + } + case Kind::MULT: + case Kind::NONLINEAR_MULT: + { + RealAlgebraicNumber aggr(Integer(1)); + for (const auto& n : expr) + { + aggr *= evaluate(n, rans); + } + return aggr; + } + case Kind::MINUS: + Assert(expr.getNumChildren() == 2); + return evaluate(expr[0], rans) - evaluate(expr[1], rans); + case Kind::UMINUS: return -evaluate(expr[0], rans); + case Kind::CONST_RATIONAL: + return RealAlgebraicNumber(expr.getConst()); + default: + auto it = rans.find(expr); + if (it != rans.end()) + { + return it->second; + } + Assert(false) << "Unsupported expression kind for RAN evaluation: " + << expr.getKind(); + return RealAlgebraicNumber(); + } +} + +} // namespace + +bool isExpressionZero(Env& env, Node expr, const std::map& model) +{ + // Substitute constants and rewrite + expr = env.getRewriter()->rewrite(expr); + if (expr.isConst()) + { + return expr.getConst().isZero(); + } + std::map rans; + std::vector nodes; + std::vector repls; + for (const auto& [node, repl] : model) + { + if (repl.getType().isRealOrInt() + && Theory::isLeafOf(repl, TheoryId::THEORY_ARITH)) + { + nodes.emplace_back(node); + repls.emplace_back(repl); + } + else + { + rans.emplace(node, nl::node_to_ran(repl, node)); + } + } + expr = + expr.substitute(nodes.begin(), nodes.end(), repls.begin(), repls.end()); + expr = env.getRewriter()->rewrite(expr); + if (expr.isConst()) + { + return expr.getConst().isZero(); + } + return isZero(evaluate(expr, rans)); +} + +} // namespace arith +} // namespace theory +} // namespace cvc5 diff --git a/src/theory/arith/arith_evaluator.h b/src/theory/arith/arith_evaluator.h new file mode 100644 index 000000000..cc50c670c --- /dev/null +++ b/src/theory/arith/arith_evaluator.h @@ -0,0 +1,25 @@ + +#include "cvc5_private.h" + +#ifndef CVC5__THEORY__ARITH__ARITH_EVALUATOR_H +#define CVC5__THEORY__ARITH__ARITH_EVALUATOR_H + +#include "expr/node.h" +#include "smt/env.h" + +namespace cvc5 { +namespace theory { +namespace arith { + +/** + * Check if the expression `expr` is zero over the given model. + * The model may contain real algebraic numbers in standard witness form. + * The environment is used for rewriting. + */ +bool isExpressionZero(Env& env, Node expr, const std::map& model); + +} +} // namespace theory +} // namespace cvc5 + +#endif \ No newline at end of file diff --git a/src/theory/arith/nl/cad_solver.cpp b/src/theory/arith/nl/cad_solver.cpp index f4582ac20..7ecbccf6d 100644 --- a/src/theory/arith/nl/cad_solver.cpp +++ b/src/theory/arith/nl/cad_solver.cpp @@ -206,11 +206,6 @@ bool CadSolver::constructModelIfAvailable(std::vector& assertions) return false; } bool foundNonVariable = false; - for (const auto& sub: d_eqsubs.getSubstitutions()) - { - d_model.addSubstitution(sub.first, sub.second); - Trace("nl-cad") << "-> " << sub.first << " = " << sub.second << std::endl; - } for (const auto& v : d_CAC.getVariableOrdering()) { Node variable = d_CAC.getConstraints().varMapper()(v); @@ -219,16 +214,14 @@ bool CadSolver::constructModelIfAvailable(std::vector& assertions) Trace("nl-cad") << "Not a variable: " << variable << std::endl; foundNonVariable = true; } - Node value = value_to_node(d_CAC.getModel().get(v), d_ranVariable); - if (value.isConst()) - { - d_model.addSubstitution(variable, value); - } - else - { - d_model.addWitness(variable, value); - } - Trace("nl-cad") << "-> " << v << " = " << value << std::endl; + Node value = value_to_node(d_CAC.getModel().get(v), variable); + addToModel(variable, value); + } + for (const auto& sub : d_eqsubs.getSubstitutions()) + { + Trace("nl-cad") << "EqSubs: " << sub.first << " -> " << sub.second + << std::endl; + addToModel(sub.first, sub.second); } if (foundNonVariable) { @@ -249,6 +242,19 @@ bool CadSolver::constructModelIfAvailable(std::vector& assertions) #endif } +void CadSolver::addToModel(TNode var, TNode value) const +{ + Trace("nl-cad") << "-> " << var << " = " << value << std::endl; + if (value.getType().isRealOrInt()) + { + d_model.addSubstitution(var, value); + } + else + { + d_model.addWitness(var, value); + } +} + } // namespace nl } // namespace arith } // namespace theory diff --git a/src/theory/arith/nl/cad_solver.h b/src/theory/arith/nl/cad_solver.h index 73d09378b..d72c92a8a 100644 --- a/src/theory/arith/nl/cad_solver.h +++ b/src/theory/arith/nl/cad_solver.h @@ -82,6 +82,12 @@ class CadSolver: protected EnvObj bool constructModelIfAvailable(std::vector& assertions); private: + /** + * Add the variable assignment `var = value` to the nonlinear model. + * Depending on `value`, it is either added as substitution or witness. + */ + void addToModel(TNode var, TNode value) const; + /** * The variable used to encode real algebraic numbers to nodes. */ diff --git a/src/theory/arith/nl/nonlinear_extension.cpp b/src/theory/arith/nl/nonlinear_extension.cpp index 3f60f8596..b57c0d1db 100644 --- a/src/theory/arith/nl/nonlinear_extension.cpp +++ b/src/theory/arith/nl/nonlinear_extension.cpp @@ -283,28 +283,55 @@ void NonlinearExtension::checkFullEffort(std::map& arithModel, d_approximations, d_witnesses, options().smt.modelWitnessValue); + for (auto& am : arithModel) + { + Node val = getModelValue(am.first); + if (!val.isNull()) + { + am.second = val; + } + } } } -void NonlinearExtension::finalizeModel(TheoryModel* tm) +Node NonlinearExtension::getModelValue(TNode var) const { - Trace("nl-ext") << "NonlinearExtension::finalizeModel" << std::endl; + if (auto it = d_approximations.find(var); it != d_approximations.end()) + { + if (it->second.second.isNull()) + { + return it->second.first; + } + return Node::null(); + } + if (auto it = d_witnesses.find(var); it != d_witnesses.end()) + { + return it->second; + } + return Node::null(); +} - for (std::pair>& a : d_approximations) +bool NonlinearExtension::assertModel(TheoryModel* tm, TNode var) const +{ + if (auto it = d_approximations.find(var); it != d_approximations.end()) { - if (a.second.second.isNull()) + const auto& approx = it->second; + if (approx.second.isNull()) { - tm->recordApproximation(a.first, a.second.first); + tm->recordApproximation(var, approx.first); } else { - tm->recordApproximation(a.first, a.second.first, a.second.second); + tm->recordApproximation(var, approx.first, approx.second); } + return true; } - for (const auto& vw : d_witnesses) + if (auto it = d_witnesses.find(var); it != d_witnesses.end()) { - tm->recordApproximation(vw.first, vw.second); + tm->recordApproximation(var, it->second); + return true; } + return false; } Result::Sat NonlinearExtension::modelBasedRefinement(const std::set& termSet) diff --git a/src/theory/arith/nl/nonlinear_extension.h b/src/theory/arith/nl/nonlinear_extension.h index 53e0db90e..390dd72a3 100644 --- a/src/theory/arith/nl/nonlinear_extension.h +++ b/src/theory/arith/nl/nonlinear_extension.h @@ -112,9 +112,14 @@ class NonlinearExtension : EnvObj const std::set& termSet); /** - * Finalize the given model by adding approximations and witnesses. + * Retrieve the model value for the given variable. It may be either an + * arithmetic term or a witness. */ - void finalizeModel(TheoryModel* tm); + Node getModelValue(TNode var) const; + /** + * Assert the model for the given variable to the theory model. + */ + bool assertModel(TheoryModel* tm, TNode var) const; /** Does this class need a call to check(...) at last call effort? */ bool hasNlTerms() const { return d_hasNlTerms; } diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index c5f0620f9..899bbfe0e 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -19,6 +19,7 @@ #include "proof/proof_checker.h" #include "proof/proof_rule.h" #include "smt/smt_statistics_registry.h" +#include "theory/arith/arith_evaluator.h" #include "theory/arith/arith_rewriter.h" #include "theory/arith/equality_solver.h" #include "theory/arith/infer_bounds.h" @@ -175,7 +176,6 @@ void TheoryArith::postCheck(Effort level) d_im.doPendingPhaseRequirements(); return; } - d_nonlinearExtension->finalizeModel(getValuation().getModel()); } return; } @@ -290,6 +290,13 @@ bool TheoryArith::collectModelValues(TheoryModel* m, { continue; } + if (d_nonlinearExtension != nullptr) + { + if (d_nonlinearExtension->assertModel(m, p.first)) + { + continue; + } + } // maps to constant of comparable type Assert(p.first.getType().isComparableTo(p.second.getType())); if (m->assertEquality(p.first, p.second, true)) @@ -327,15 +334,16 @@ void TheoryArith::presolve(){ EqualityStatus TheoryArith::getEqualityStatus(TNode a, TNode b) { Debug("arith") << "TheoryArith::getEqualityStatus(" << a << ", " << b << ")" << std::endl; + if (a == b) + { + return EQUALITY_TRUE_IN_MODEL; + } if (d_arithModelCache.empty()) { return d_internal->getEqualityStatus(a,b); } - Node aval = - rewrite(a.substitute(d_arithModelCache.begin(), d_arithModelCache.end())); - Node bval = - rewrite(b.substitute(d_arithModelCache.begin(), d_arithModelCache.end())); - if (aval == bval) + Node diff = d_env.getNodeManager()->mkNode(Kind::MINUS, a, b); + if (isExpressionZero(d_env, diff, d_arithModelCache)) { return EQUALITY_TRUE_IN_MODEL; } diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index b3cc02f68..bc084714d 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -739,6 +739,7 @@ set(regress_0_tests regress0/named-expr-use.smt2 regress0/nl/all-logic.smt2 regress0/nl/coeff-sat.smt2 + regress0/nl/combined-uf.smt2 regress0/nl/iand-no-init.smt2 regress0/nl/issue3003.smt2 regress0/nl/issue3407.smt2 diff --git a/test/regress/regress0/nl/combined-uf.smt2 b/test/regress/regress0/nl/combined-uf.smt2 new file mode 100644 index 000000000..ac0a39de4 --- /dev/null +++ b/test/regress/regress0/nl/combined-uf.smt2 @@ -0,0 +1,11 @@ +; EXPECT: unsat +(set-logic QF_UFNRA) +(declare-fun a () Real) +(declare-fun b () Real) +(declare-fun f (Real) Real) +(assert (= (* a a) 2)) +(assert (> a 0)) +(assert (= (* b b b b) 4)) +(assert (< b 0)) +(assert (not (= (f (* a a)) (f (* b b))))) +(check-sat) -- 2.30.2