Improve theory combination in the presence of real algebraic numbers (#7883)
authorGereon Kremer <gkremer@stanford.edu>
Thu, 6 Jan 2022 21:09:10 +0000 (13:09 -0800)
committerGitHub <noreply@github.com>
Thu, 6 Jan 2022 21:09:10 +0000 (21:09 +0000)
This PR changes how we handle real algebraic numbers in theory combination and model construction.
The goal is to improve getEqualityStatus() and produce proper models more often.
We now use a RAN-aware evaluator for getEqualityStatus() and change the way how the nonlinear extension finalizes its model.

src/CMakeLists.txt
src/theory/arith/arith_evaluator.cpp [new file with mode: 0644]
src/theory/arith/arith_evaluator.h [new file with mode: 0644]
src/theory/arith/nl/cad_solver.cpp
src/theory/arith/nl/cad_solver.h
src/theory/arith/nl/nonlinear_extension.cpp
src/theory/arith/nl/nonlinear_extension.h
src/theory/arith/theory_arith.cpp
test/regress/CMakeLists.txt
test/regress/regress0/nl/combined-uf.smt2 [new file with mode: 0644]

index 7a62a327aad0ee4002e017d0c412c70a507a32ef..830c70ca9855cee69b101ec91c078e10a1574b51 100644 (file)
@@ -352,6 +352,8 @@ libcvc5_add_sources(
   smt_util/boolean_simplification.h
   theory/arith/approx_simplex.cpp
   theory/arith/approx_simplex.h
+  theory/arith/arith_evaluator.cpp
+  theory/arith/arith_evaluator.h
   theory/arith/arith_ite_utils.cpp
   theory/arith/arith_ite_utils.h
   theory/arith/arith_msum.cpp
diff --git a/src/theory/arith/arith_evaluator.cpp b/src/theory/arith/arith_evaluator.cpp
new file mode 100644 (file)
index 0000000..0fe045a
--- /dev/null
@@ -0,0 +1,94 @@
+#include "theory/arith/arith_evaluator.h"
+
+#include "theory/arith/nl/poly_conversion.h"
+#include "theory/rewriter.h"
+#include "theory/theory.h"
+#include "util/real_algebraic_number.h"
+
+namespace cvc5 {
+namespace theory {
+namespace arith {
+
+namespace {
+
+RealAlgebraicNumber evaluate(TNode expr,
+                             const std::map<Node, RealAlgebraicNumber>& rans)
+{
+  switch (expr.getKind())
+  {
+    case Kind::PLUS:
+    {
+      RealAlgebraicNumber aggr;
+      for (const auto& n : expr)
+      {
+        aggr += evaluate(n, rans);
+      }
+      return aggr;
+    }
+    case Kind::MULT:
+    case Kind::NONLINEAR_MULT:
+    {
+      RealAlgebraicNumber aggr(Integer(1));
+      for (const auto& n : expr)
+      {
+        aggr *= evaluate(n, rans);
+      }
+      return aggr;
+    }
+    case Kind::MINUS:
+      Assert(expr.getNumChildren() == 2);
+      return evaluate(expr[0], rans) - evaluate(expr[1], rans);
+    case Kind::UMINUS: return -evaluate(expr[0], rans);
+    case Kind::CONST_RATIONAL:
+      return RealAlgebraicNumber(expr.getConst<Rational>());
+    default:
+      auto it = rans.find(expr);
+      if (it != rans.end())
+      {
+        return it->second;
+      }
+      Assert(false) << "Unsupported expression kind for RAN evaluation: "
+                    << expr.getKind();
+      return RealAlgebraicNumber();
+  }
+}
+
+}  // namespace
+
+bool isExpressionZero(Env& env, Node expr, const std::map<Node, Node>& model)
+{
+  // Substitute constants and rewrite
+  expr = env.getRewriter()->rewrite(expr);
+  if (expr.isConst())
+  {
+    return expr.getConst<Rational>().isZero();
+  }
+  std::map<Node, RealAlgebraicNumber> rans;
+  std::vector<TNode> nodes;
+  std::vector<TNode> repls;
+  for (const auto& [node, repl] : model)
+  {
+    if (repl.getType().isRealOrInt()
+        && Theory::isLeafOf(repl, TheoryId::THEORY_ARITH))
+    {
+      nodes.emplace_back(node);
+      repls.emplace_back(repl);
+    }
+    else
+    {
+      rans.emplace(node, nl::node_to_ran(repl, node));
+    }
+  }
+  expr =
+      expr.substitute(nodes.begin(), nodes.end(), repls.begin(), repls.end());
+  expr = env.getRewriter()->rewrite(expr);
+  if (expr.isConst())
+  {
+    return expr.getConst<Rational>().isZero();
+  }
+  return isZero(evaluate(expr, rans));
+}
+
+}  // namespace arith
+}  // namespace theory
+}  // namespace cvc5
diff --git a/src/theory/arith/arith_evaluator.h b/src/theory/arith/arith_evaluator.h
new file mode 100644 (file)
index 0000000..cc50c67
--- /dev/null
@@ -0,0 +1,25 @@
+
+#include "cvc5_private.h"
+
+#ifndef CVC5__THEORY__ARITH__ARITH_EVALUATOR_H
+#define CVC5__THEORY__ARITH__ARITH_EVALUATOR_H
+
+#include "expr/node.h"
+#include "smt/env.h"
+
+namespace cvc5 {
+namespace theory {
+namespace arith {
+
+/**
+ * Check if the expression `expr` is zero over the given model.
+ * The model may contain real algebraic numbers in standard witness form.
+ * The environment is used for rewriting.
+ */
+bool isExpressionZero(Env& env, Node expr, const std::map<Node, Node>& model);
+
+}
+}  // namespace theory
+}  // namespace cvc5
+
+#endif
\ No newline at end of file
index f4582ac2017d3a7b6cac8e43944cddfe65e74400..7ecbccf6db8b7bcbd400729f8b02955cfa6d495a 100644 (file)
@@ -206,11 +206,6 @@ bool CadSolver::constructModelIfAvailable(std::vector<Node>& assertions)
     return false;
   }
   bool foundNonVariable = false;
-  for (const auto& sub: d_eqsubs.getSubstitutions())
-  {
-    d_model.addSubstitution(sub.first, sub.second);
-    Trace("nl-cad") << "-> " << sub.first << " = " << sub.second << std::endl;
-  }
   for (const auto& v : d_CAC.getVariableOrdering())
   {
     Node variable = d_CAC.getConstraints().varMapper()(v);
@@ -219,16 +214,14 @@ bool CadSolver::constructModelIfAvailable(std::vector<Node>& assertions)
       Trace("nl-cad") << "Not a variable: " << variable << std::endl;
       foundNonVariable = true;
     }
-    Node value = value_to_node(d_CAC.getModel().get(v), d_ranVariable);
-    if (value.isConst())
-    {
-      d_model.addSubstitution(variable, value);
-    }
-    else
-    {
-      d_model.addWitness(variable, value);
-    }
-    Trace("nl-cad") << "-> " << v << " = " << value << std::endl;
+    Node value = value_to_node(d_CAC.getModel().get(v), variable);
+    addToModel(variable, value);
+  }
+  for (const auto& sub : d_eqsubs.getSubstitutions())
+  {
+    Trace("nl-cad") << "EqSubs: " << sub.first << " -> " << sub.second
+                    << std::endl;
+    addToModel(sub.first, sub.second);
   }
   if (foundNonVariable)
   {
@@ -249,6 +242,19 @@ bool CadSolver::constructModelIfAvailable(std::vector<Node>& assertions)
 #endif
 }
 
+void CadSolver::addToModel(TNode var, TNode value) const
+{
+  Trace("nl-cad") << "-> " << var << " = " << value << std::endl;
+  if (value.getType().isRealOrInt())
+  {
+    d_model.addSubstitution(var, value);
+  }
+  else
+  {
+    d_model.addWitness(var, value);
+  }
+}
+
 }  // namespace nl
 }  // namespace arith
 }  // namespace theory
index 73d09378bd0e5ad0e8f3a782d9196edc9327c35b..d72c92a8a59e531033b14c2125a1afb64db9f626 100644 (file)
@@ -82,6 +82,12 @@ class CadSolver: protected EnvObj
   bool constructModelIfAvailable(std::vector<Node>& assertions);
 
  private:
+  /**
+   * Add the variable assignment `var = value` to the nonlinear model.
+   * Depending on `value`, it is either added as substitution or witness.
+   */
+  void addToModel(TNode var, TNode value) const;
+
   /**
    * The variable used to encode real algebraic numbers to nodes.
    */
index 3f60f859649b958510996914172295ff0448b806..b57c0d1dbb574a09e094ce1307a000cee29256cf 100644 (file)
@@ -283,28 +283,55 @@ void NonlinearExtension::checkFullEffort(std::map<Node, Node>& arithModel,
                                 d_approximations,
                                 d_witnesses,
                                 options().smt.modelWitnessValue);
+    for (auto& am : arithModel)
+    {
+      Node val = getModelValue(am.first);
+      if (!val.isNull())
+      {
+        am.second = val;
+      }
+    }
   }
 }
 
-void NonlinearExtension::finalizeModel(TheoryModel* tm)
+Node NonlinearExtension::getModelValue(TNode var) const
 {
-  Trace("nl-ext") << "NonlinearExtension::finalizeModel" << std::endl;
+  if (auto it = d_approximations.find(var); it != d_approximations.end())
+  {
+    if (it->second.second.isNull())
+    {
+      return it->second.first;
+    }
+    return Node::null();
+  }
+  if (auto it = d_witnesses.find(var); it != d_witnesses.end())
+  {
+    return it->second;
+  }
+  return Node::null();
+}
 
-  for (std::pair<const Node, std::pair<Node, Node>>& a : d_approximations)
+bool NonlinearExtension::assertModel(TheoryModel* tm, TNode var) const
+{
+  if (auto it = d_approximations.find(var); it != d_approximations.end())
   {
-    if (a.second.second.isNull())
+    const auto& approx = it->second;
+    if (approx.second.isNull())
     {
-      tm->recordApproximation(a.first, a.second.first);
+      tm->recordApproximation(var, approx.first);
     }
     else
     {
-      tm->recordApproximation(a.first, a.second.first, a.second.second);
+      tm->recordApproximation(var, approx.first, approx.second);
     }
+    return true;
   }
-  for (const auto& vw : d_witnesses)
+  if (auto it = d_witnesses.find(var); it != d_witnesses.end())
   {
-    tm->recordApproximation(vw.first, vw.second);
+    tm->recordApproximation(var, it->second);
+    return true;
   }
+  return false;
 }
 
 Result::Sat NonlinearExtension::modelBasedRefinement(const std::set<Node>& termSet)
index 53e0db90efe94e002ce409dec050184f52929458..390dd72a36591464e851292c80241793b78e1343 100644 (file)
@@ -112,9 +112,14 @@ class NonlinearExtension : EnvObj
                        const std::set<Node>& termSet);
 
   /**
-   * Finalize the given model by adding approximations and witnesses.
+   * Retrieve the model value for the given variable. It may be either an
+   * arithmetic term or a witness.
    */
-  void finalizeModel(TheoryModel* tm);
+  Node getModelValue(TNode var) const;
+  /**
+   * Assert the model for the given variable to the theory model.
+   */
+  bool assertModel(TheoryModel* tm, TNode var) const;
 
   /** Does this class need a call to check(...) at last call effort? */
   bool hasNlTerms() const { return d_hasNlTerms; }
index c5f0620f91f3132fc0990d1c3b4249c5842ac465..899bbfe0ee014fc71d5b662f29530bcab3468188 100644 (file)
@@ -19,6 +19,7 @@
 #include "proof/proof_checker.h"
 #include "proof/proof_rule.h"
 #include "smt/smt_statistics_registry.h"
+#include "theory/arith/arith_evaluator.h"
 #include "theory/arith/arith_rewriter.h"
 #include "theory/arith/equality_solver.h"
 #include "theory/arith/infer_bounds.h"
@@ -175,7 +176,6 @@ void TheoryArith::postCheck(Effort level)
         d_im.doPendingPhaseRequirements();
         return;
       }
-      d_nonlinearExtension->finalizeModel(getValuation().getModel());
     }
     return;
   }
@@ -290,6 +290,13 @@ bool TheoryArith::collectModelValues(TheoryModel* m,
     {
       continue;
     }
+    if (d_nonlinearExtension != nullptr)
+    {
+      if (d_nonlinearExtension->assertModel(m, p.first))
+      {
+        continue;
+      }
+    }
     // maps to constant of comparable type
     Assert(p.first.getType().isComparableTo(p.second.getType()));
     if (m->assertEquality(p.first, p.second, true))
@@ -327,15 +334,16 @@ void TheoryArith::presolve(){
 
 EqualityStatus TheoryArith::getEqualityStatus(TNode a, TNode b) {
   Debug("arith") << "TheoryArith::getEqualityStatus(" << a << ", " << b << ")" << std::endl;
+  if (a == b)
+  {
+    return EQUALITY_TRUE_IN_MODEL;
+  }
   if (d_arithModelCache.empty())
   {
     return d_internal->getEqualityStatus(a,b);
   }
-  Node aval =
-      rewrite(a.substitute(d_arithModelCache.begin(), d_arithModelCache.end()));
-  Node bval =
-      rewrite(b.substitute(d_arithModelCache.begin(), d_arithModelCache.end()));
-  if (aval == bval)
+  Node diff = d_env.getNodeManager()->mkNode(Kind::MINUS, a, b);
+  if (isExpressionZero(d_env, diff, d_arithModelCache))
   {
     return EQUALITY_TRUE_IN_MODEL;
   }
index b3cc02f6809f736a69695e4070296bb6f79b210b..bc084714defad0bca0d0c82e4d1adf2ef4f116d4 100644 (file)
@@ -739,6 +739,7 @@ set(regress_0_tests
   regress0/named-expr-use.smt2
   regress0/nl/all-logic.smt2
   regress0/nl/coeff-sat.smt2
+  regress0/nl/combined-uf.smt2
   regress0/nl/iand-no-init.smt2
   regress0/nl/issue3003.smt2
   regress0/nl/issue3407.smt2
diff --git a/test/regress/regress0/nl/combined-uf.smt2 b/test/regress/regress0/nl/combined-uf.smt2
new file mode 100644 (file)
index 0000000..ac0a39d
--- /dev/null
@@ -0,0 +1,11 @@
+; EXPECT: unsat
+(set-logic QF_UFNRA)
+(declare-fun a () Real)
+(declare-fun b () Real)
+(declare-fun f (Real) Real)
+(assert (= (* a a) 2))
+(assert (> a 0))
+(assert (= (* b b b b) 4))
+(assert (< b 0))
+(assert (not (= (f (* a a)) (f (* b b)))))
+(check-sat)