From be1f03037110e8334bb2e73e9b6afb76eee959e2 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Thu, 15 Jul 2021 08:25:28 -0500 Subject: [PATCH] Arithmetic equality solver (#6876) This is work towards the central equality engine. This adds a module of arithmetic that uses the equality engine in the default way. This class will be incorporated into theory_arith.cpp. It will be the replacement for CongruenceManager when we use the central equality engine architecture. --- src/theory/arith/equality_solver.cpp | 128 +++++++++++++++++++++++++++ src/theory/arith/equality_solver.h | 115 ++++++++++++++++++++++++ 2 files changed, 243 insertions(+) create mode 100644 src/theory/arith/equality_solver.cpp create mode 100644 src/theory/arith/equality_solver.h diff --git a/src/theory/arith/equality_solver.cpp b/src/theory/arith/equality_solver.cpp new file mode 100644 index 000000000..58793c654 --- /dev/null +++ b/src/theory/arith/equality_solver.cpp @@ -0,0 +1,128 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Arithmetic equality solver + */ + +#include "theory/arith/equality_solver.h" + +#include "theory/arith/inference_manager.h" + +using namespace cvc5::kind; + +namespace cvc5 { +namespace theory { +namespace arith { + +EqualitySolver::EqualitySolver(ArithState& astate, InferenceManager& aim) + : d_astate(astate), + d_aim(aim), + d_notify(*this), + d_ee(nullptr), + d_propLits(astate.getSatContext()) +{ +} + +bool EqualitySolver::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = "arith::ee"; + return true; +} + +void EqualitySolver::finishInit() +{ + d_ee = d_astate.getEqualityEngine(); + // add the function kinds + d_ee->addFunctionKind(kind::NONLINEAR_MULT); + d_ee->addFunctionKind(kind::EXPONENTIAL); + d_ee->addFunctionKind(kind::SINE); + d_ee->addFunctionKind(kind::IAND); + d_ee->addFunctionKind(kind::POW2); +} + +bool EqualitySolver::preNotifyFact( + TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal) +{ + if (atom.getKind() != EQUAL) + { + // finished processing, since not beneficial to add non-equality facts + return true; + } + Trace("arith-eq-solver") << "EqualitySolver::preNotifyFact: " << fact + << std::endl; + // we will process + return false; +} + +TrustNode EqualitySolver::explain(TNode lit) +{ + Trace("arith-eq-solver-debug") << "explain " << lit << "?" << std::endl; + // check if we propagated it? + if (d_propLits.find(lit) == d_propLits.end()) + { + Trace("arith-eq-solver-debug") << "...did not propagate" << std::endl; + return TrustNode::null(); + } + Trace("arith-eq-solver-debug") + << "...explain via inference manager" << std::endl; + // if we did, explain with the arithmetic inference manager + return d_aim.explainLit(lit); +} +bool EqualitySolver::propagateLit(Node lit) +{ + // notice this is only used when ee-mode=distributed + // remember that this was a literal we propagated + Trace("arith-eq-solver-debug") << "propagate lit " << lit << std::endl; + d_propLits.insert(lit); + return d_aim.propagateLit(lit); +} +void EqualitySolver::conflictEqConstantMerge(TNode a, TNode b) +{ + d_aim.conflictEqConstantMerge(a, b); +} + +bool EqualitySolver::EqualitySolverNotify::eqNotifyTriggerPredicate( + TNode predicate, bool value) +{ + Trace("arith-eq-solver") << "...propagate (predicate) " << predicate << " -> " + << value << std::endl; + if (value) + { + return d_es.propagateLit(predicate); + } + return d_es.propagateLit(predicate.notNode()); +} + +bool EqualitySolver::EqualitySolverNotify::eqNotifyTriggerTermEquality( + TheoryId tag, TNode t1, TNode t2, bool value) +{ + Trace("arith-eq-solver") << "...propagate (term eq) " << t1.eqNode(t2) + << " -> " << value << std::endl; + if (value) + { + return d_es.propagateLit(t1.eqNode(t2)); + } + return d_es.propagateLit(t1.eqNode(t2).notNode()); +} + +void EqualitySolver::EqualitySolverNotify::eqNotifyConstantTermMerge(TNode t1, + TNode t2) +{ + Trace("arith-eq-solver") << "...conflict merge " << t1 << " " << t2 + << std::endl; + d_es.conflictEqConstantMerge(t1, t2); +} + +} // namespace arith +} // namespace theory +} // namespace cvc5 diff --git a/src/theory/arith/equality_solver.h b/src/theory/arith/equality_solver.h new file mode 100644 index 000000000..bce30e697 --- /dev/null +++ b/src/theory/arith/equality_solver.h @@ -0,0 +1,115 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Arithmetic equality solver + */ + +#include "cvc5_private.h" + +#ifndef CVC5__THEORY__ARITH__EQUALITY_SOLVER_H +#define CVC5__THEORY__ARITH__EQUALITY_SOLVER_H + +#include "context/cdhashset.h" +#include "expr/node.h" +#include "proof/trust_node.h" +#include "theory/arith/arith_state.h" +#include "theory/ee_setup_info.h" +#include "theory/uf/equality_engine.h" + +namespace cvc5 { +namespace theory { +namespace arith { + +class InferenceManager; + +/** + * The arithmetic equality solver. This class manages arithmetic equalities + * in the default way via an equality engine. + * + * Since arithmetic has multiple ways of propagating literals, it tracks + * the literals that it propagates and only explains the literals that + * originated from this class. + */ +class EqualitySolver +{ + using NodeSet = context::CDHashSet; + + public: + EqualitySolver(ArithState& astate, InferenceManager& aim); + ~EqualitySolver() {} + //--------------------------------- initialization + /** + * Returns true if we need an equality engine, see + * Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi); + /** + * Finish initialize + */ + void finishInit(); + //--------------------------------- end initialization + /** + * Pre-notify fact, return true if we are finished processing, false if + * we wish to assert the fact to the equality engine of this class. + */ + bool preNotifyFact( + TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal); + /** + * Return an explanation for the literal lit (which was previously propagated + * by this solver). + */ + TrustNode explain(TNode lit); + + private: + /** Notification class from the equality engine */ + class EqualitySolverNotify : public eq::EqualityEngineNotify + { + public: + EqualitySolverNotify(EqualitySolver& es) : d_es(es) {} + + bool eqNotifyTriggerPredicate(TNode predicate, bool value) override; + + bool eqNotifyTriggerTermEquality(TheoryId tag, + TNode t1, + TNode t2, + bool value) override; + + void eqNotifyConstantTermMerge(TNode t1, TNode t2) override; + void eqNotifyNewClass(TNode t) override {} + void eqNotifyMerge(TNode t1, TNode t2) override {} + void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override {} + + private: + /** reference to parent */ + EqualitySolver& d_es; + }; + /** Propagate literal */ + bool propagateLit(Node lit); + /** Conflict when two constants merge */ + void conflictEqConstantMerge(TNode a, TNode b); + /** reference to the state */ + ArithState& d_astate; + /** reference to parent */ + InferenceManager& d_aim; + /** Equality solver notify */ + EqualitySolverNotify d_notify; + /** Pointer to the equality engine */ + eq::EqualityEngine* d_ee; + /** The literals we have propagated */ + NodeSet d_propLits; +}; + +} // namespace arith +} // namespace theory +} // namespace cvc5 + +#endif -- 2.30.2