Arithmetic equality solver (#6876)
[cvc5.git] / src / theory / arith / equality_solver.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * Arithmetic equality solver
14 */
15
16 #include "theory/arith/equality_solver.h"
17
18 #include "theory/arith/inference_manager.h"
19
20 using namespace cvc5::kind;
21
22 namespace cvc5 {
23 namespace theory {
24 namespace arith {
25
26 EqualitySolver::EqualitySolver(ArithState& astate, InferenceManager& aim)
27 : d_astate(astate),
28 d_aim(aim),
29 d_notify(*this),
30 d_ee(nullptr),
31 d_propLits(astate.getSatContext())
32 {
33 }
34
35 bool EqualitySolver::needsEqualityEngine(EeSetupInfo& esi)
36 {
37 esi.d_notify = &d_notify;
38 esi.d_name = "arith::ee";
39 return true;
40 }
41
42 void EqualitySolver::finishInit()
43 {
44 d_ee = d_astate.getEqualityEngine();
45 // add the function kinds
46 d_ee->addFunctionKind(kind::NONLINEAR_MULT);
47 d_ee->addFunctionKind(kind::EXPONENTIAL);
48 d_ee->addFunctionKind(kind::SINE);
49 d_ee->addFunctionKind(kind::IAND);
50 d_ee->addFunctionKind(kind::POW2);
51 }
52
53 bool EqualitySolver::preNotifyFact(
54 TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal)
55 {
56 if (atom.getKind() != EQUAL)
57 {
58 // finished processing, since not beneficial to add non-equality facts
59 return true;
60 }
61 Trace("arith-eq-solver") << "EqualitySolver::preNotifyFact: " << fact
62 << std::endl;
63 // we will process
64 return false;
65 }
66
67 TrustNode EqualitySolver::explain(TNode lit)
68 {
69 Trace("arith-eq-solver-debug") << "explain " << lit << "?" << std::endl;
70 // check if we propagated it?
71 if (d_propLits.find(lit) == d_propLits.end())
72 {
73 Trace("arith-eq-solver-debug") << "...did not propagate" << std::endl;
74 return TrustNode::null();
75 }
76 Trace("arith-eq-solver-debug")
77 << "...explain via inference manager" << std::endl;
78 // if we did, explain with the arithmetic inference manager
79 return d_aim.explainLit(lit);
80 }
81 bool EqualitySolver::propagateLit(Node lit)
82 {
83 // notice this is only used when ee-mode=distributed
84 // remember that this was a literal we propagated
85 Trace("arith-eq-solver-debug") << "propagate lit " << lit << std::endl;
86 d_propLits.insert(lit);
87 return d_aim.propagateLit(lit);
88 }
89 void EqualitySolver::conflictEqConstantMerge(TNode a, TNode b)
90 {
91 d_aim.conflictEqConstantMerge(a, b);
92 }
93
94 bool EqualitySolver::EqualitySolverNotify::eqNotifyTriggerPredicate(
95 TNode predicate, bool value)
96 {
97 Trace("arith-eq-solver") << "...propagate (predicate) " << predicate << " -> "
98 << value << std::endl;
99 if (value)
100 {
101 return d_es.propagateLit(predicate);
102 }
103 return d_es.propagateLit(predicate.notNode());
104 }
105
106 bool EqualitySolver::EqualitySolverNotify::eqNotifyTriggerTermEquality(
107 TheoryId tag, TNode t1, TNode t2, bool value)
108 {
109 Trace("arith-eq-solver") << "...propagate (term eq) " << t1.eqNode(t2)
110 << " -> " << value << std::endl;
111 if (value)
112 {
113 return d_es.propagateLit(t1.eqNode(t2));
114 }
115 return d_es.propagateLit(t1.eqNode(t2).notNode());
116 }
117
118 void EqualitySolver::EqualitySolverNotify::eqNotifyConstantTermMerge(TNode t1,
119 TNode t2)
120 {
121 Trace("arith-eq-solver") << "...conflict merge " << t1 << " " << t2
122 << std::endl;
123 d_es.conflictEqConstantMerge(t1, t2);
124 }
125
126 } // namespace arith
127 } // namespace theory
128 } // namespace cvc5