Arithmetic equality solver (#6876)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 15 Jul 2021 13:25:28 +0000 (08:25 -0500)
committerGitHub <noreply@github.com>
Thu, 15 Jul 2021 13:25:28 +0000 (13:25 +0000)
This is work towards the central equality engine. This adds a module of arithmetic that uses the equality engine in the default way.

This class will be incorporated into theory_arith.cpp. It will be the replacement for CongruenceManager when we use the central equality engine architecture.

src/theory/arith/equality_solver.cpp [new file with mode: 0644]
src/theory/arith/equality_solver.h [new file with mode: 0644]

diff --git a/src/theory/arith/equality_solver.cpp b/src/theory/arith/equality_solver.cpp
new file mode 100644 (file)
index 0000000..58793c6
--- /dev/null
@@ -0,0 +1,128 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Arithmetic equality solver
+ */
+
+#include "theory/arith/equality_solver.h"
+
+#include "theory/arith/inference_manager.h"
+
+using namespace cvc5::kind;
+
+namespace cvc5 {
+namespace theory {
+namespace arith {
+
+EqualitySolver::EqualitySolver(ArithState& astate, InferenceManager& aim)
+    : d_astate(astate),
+      d_aim(aim),
+      d_notify(*this),
+      d_ee(nullptr),
+      d_propLits(astate.getSatContext())
+{
+}
+
+bool EqualitySolver::needsEqualityEngine(EeSetupInfo& esi)
+{
+  esi.d_notify = &d_notify;
+  esi.d_name = "arith::ee";
+  return true;
+}
+
+void EqualitySolver::finishInit()
+{
+  d_ee = d_astate.getEqualityEngine();
+  // add the function kinds
+  d_ee->addFunctionKind(kind::NONLINEAR_MULT);
+  d_ee->addFunctionKind(kind::EXPONENTIAL);
+  d_ee->addFunctionKind(kind::SINE);
+  d_ee->addFunctionKind(kind::IAND);
+  d_ee->addFunctionKind(kind::POW2);
+}
+
+bool EqualitySolver::preNotifyFact(
+    TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal)
+{
+  if (atom.getKind() != EQUAL)
+  {
+    // finished processing, since not beneficial to add non-equality facts
+    return true;
+  }
+  Trace("arith-eq-solver") << "EqualitySolver::preNotifyFact: " << fact
+                           << std::endl;
+  // we will process
+  return false;
+}
+
+TrustNode EqualitySolver::explain(TNode lit)
+{
+  Trace("arith-eq-solver-debug") << "explain " << lit << "?" << std::endl;
+  // check if we propagated it?
+  if (d_propLits.find(lit) == d_propLits.end())
+  {
+    Trace("arith-eq-solver-debug") << "...did not propagate" << std::endl;
+    return TrustNode::null();
+  }
+  Trace("arith-eq-solver-debug")
+      << "...explain via inference manager" << std::endl;
+  // if we did, explain with the arithmetic inference manager
+  return d_aim.explainLit(lit);
+}
+bool EqualitySolver::propagateLit(Node lit)
+{
+  // notice this is only used when ee-mode=distributed
+  // remember that this was a literal we propagated
+  Trace("arith-eq-solver-debug") << "propagate lit " << lit << std::endl;
+  d_propLits.insert(lit);
+  return d_aim.propagateLit(lit);
+}
+void EqualitySolver::conflictEqConstantMerge(TNode a, TNode b)
+{
+  d_aim.conflictEqConstantMerge(a, b);
+}
+
+bool EqualitySolver::EqualitySolverNotify::eqNotifyTriggerPredicate(
+    TNode predicate, bool value)
+{
+  Trace("arith-eq-solver") << "...propagate (predicate) " << predicate << " -> "
+                           << value << std::endl;
+  if (value)
+  {
+    return d_es.propagateLit(predicate);
+  }
+  return d_es.propagateLit(predicate.notNode());
+}
+
+bool EqualitySolver::EqualitySolverNotify::eqNotifyTriggerTermEquality(
+    TheoryId tag, TNode t1, TNode t2, bool value)
+{
+  Trace("arith-eq-solver") << "...propagate (term eq) " << t1.eqNode(t2)
+                           << " -> " << value << std::endl;
+  if (value)
+  {
+    return d_es.propagateLit(t1.eqNode(t2));
+  }
+  return d_es.propagateLit(t1.eqNode(t2).notNode());
+}
+
+void EqualitySolver::EqualitySolverNotify::eqNotifyConstantTermMerge(TNode t1,
+                                                                     TNode t2)
+{
+  Trace("arith-eq-solver") << "...conflict merge " << t1 << " " << t2
+                           << std::endl;
+  d_es.conflictEqConstantMerge(t1, t2);
+}
+
+}  // namespace arith
+}  // namespace theory
+}  // namespace cvc5
diff --git a/src/theory/arith/equality_solver.h b/src/theory/arith/equality_solver.h
new file mode 100644 (file)
index 0000000..bce30e6
--- /dev/null
@@ -0,0 +1,115 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Arithmetic equality solver
+ */
+
+#include "cvc5_private.h"
+
+#ifndef CVC5__THEORY__ARITH__EQUALITY_SOLVER_H
+#define CVC5__THEORY__ARITH__EQUALITY_SOLVER_H
+
+#include "context/cdhashset.h"
+#include "expr/node.h"
+#include "proof/trust_node.h"
+#include "theory/arith/arith_state.h"
+#include "theory/ee_setup_info.h"
+#include "theory/uf/equality_engine.h"
+
+namespace cvc5 {
+namespace theory {
+namespace arith {
+
+class InferenceManager;
+
+/**
+ * The arithmetic equality solver. This class manages arithmetic equalities
+ * in the default way via an equality engine.
+ *
+ * Since arithmetic has multiple ways of propagating literals, it tracks
+ * the literals that it propagates and only explains the literals that
+ * originated from this class.
+ */
+class EqualitySolver
+{
+  using NodeSet = context::CDHashSet<Node>;
+
+ public:
+  EqualitySolver(ArithState& astate, InferenceManager& aim);
+  ~EqualitySolver() {}
+  //--------------------------------- initialization
+  /**
+   * Returns true if we need an equality engine, see
+   * Theory::needsEqualityEngine.
+   */
+  bool needsEqualityEngine(EeSetupInfo& esi);
+  /**
+   * Finish initialize
+   */
+  void finishInit();
+  //--------------------------------- end initialization
+  /**
+   * Pre-notify fact, return true if we are finished processing, false if
+   * we wish to assert the fact to the equality engine of this class.
+   */
+  bool preNotifyFact(
+      TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal);
+  /**
+   * Return an explanation for the literal lit (which was previously propagated
+   * by this solver).
+   */
+  TrustNode explain(TNode lit);
+
+ private:
+  /** Notification class from the equality engine */
+  class EqualitySolverNotify : public eq::EqualityEngineNotify
+  {
+   public:
+    EqualitySolverNotify(EqualitySolver& es) : d_es(es) {}
+
+    bool eqNotifyTriggerPredicate(TNode predicate, bool value) override;
+
+    bool eqNotifyTriggerTermEquality(TheoryId tag,
+                                     TNode t1,
+                                     TNode t2,
+                                     bool value) override;
+
+    void eqNotifyConstantTermMerge(TNode t1, TNode t2) override;
+    void eqNotifyNewClass(TNode t) override {}
+    void eqNotifyMerge(TNode t1, TNode t2) override {}
+    void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override {}
+
+   private:
+    /** reference to parent */
+    EqualitySolver& d_es;
+  };
+  /** Propagate literal */
+  bool propagateLit(Node lit);
+  /** Conflict when two constants merge */
+  void conflictEqConstantMerge(TNode a, TNode b);
+  /** reference to the state */
+  ArithState& d_astate;
+  /** reference to parent */
+  InferenceManager& d_aim;
+  /** Equality solver notify */
+  EqualitySolverNotify d_notify;
+  /** Pointer to the equality engine */
+  eq::EqualityEngine* d_ee;
+  /** The literals we have propagated */
+  NodeSet d_propLits;
+};
+
+}  // namespace arith
+}  // namespace theory
+}  // namespace cvc5
+
+#endif