Use Env class in nonlinear extension (#7039)
[cvc5.git] / src / theory / arith / branch_and_bound.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * Branch and bound for arithmetic
14 */
15
16 #include "theory/arith/branch_and_bound.h"
17
18 #include "options/arith_options.h"
19 #include "proof/eager_proof_generator.h"
20 #include "proof/proof_node.h"
21 #include "theory/arith/arith_utilities.h"
22 #include "theory/rewriter.h"
23 #include "theory/theory.h"
24
25 using namespace cvc5::kind;
26
27 namespace cvc5 {
28 namespace theory {
29 namespace arith {
30
31 BranchAndBound::BranchAndBound(ArithState& s,
32 InferenceManager& im,
33 PreprocessRewriteEq& ppre,
34 ProofNodeManager* pnm)
35 : d_astate(s),
36 d_im(im),
37 d_ppre(ppre),
38 d_pfGen(new EagerProofGenerator(pnm, s.getUserContext())),
39 d_pnm(pnm)
40 {
41 }
42
43 TrustNode BranchAndBound::branchIntegerVariable(TNode var, Rational value)
44 {
45 TrustNode lem = TrustNode::null();
46 NodeManager* nm = NodeManager::currentNM();
47 Integer floor = value.floor();
48 if (d_astate.options().arith.brabTest)
49 {
50 Trace("integers") << "branch-round-and-bound enabled" << std::endl;
51 Integer ceil = value.ceiling();
52 Rational f = value - floor;
53 // Multiply by -1 to get abs value.
54 Rational c = (value - ceil) * (-1);
55 Integer nearest = (c > f) ? floor : ceil;
56
57 // Prioritize trying a simple rounding of the real solution first,
58 // it that fails, fall back on original branch and bound strategy.
59 Node ub =
60 Rewriter::rewrite(nm->mkNode(LEQ, var, mkRationalNode(nearest - 1)));
61 Node lb =
62 Rewriter::rewrite(nm->mkNode(GEQ, var, mkRationalNode(nearest + 1)));
63 Node right = nm->mkNode(OR, ub, lb);
64 Node rawEq = nm->mkNode(EQUAL, var, mkRationalNode(nearest));
65 Node eq = Rewriter::rewrite(rawEq);
66 // Also preprocess it before we send it out. This is important since
67 // arithmetic may prefer eliminating equalities.
68 TrustNode teq;
69 if (Theory::theoryOf(eq) == THEORY_ARITH)
70 {
71 teq = d_ppre.ppRewriteEq(eq);
72 eq = teq.isNull() ? eq : teq.getNode();
73 }
74 Node literal = d_astate.getValuation().ensureLiteral(eq);
75 Trace("integers") << "eq: " << eq << "\nto: " << literal << std::endl;
76 d_im.requirePhase(literal, true);
77 Node l = nm->mkNode(OR, literal, right);
78 Trace("integers") << "l: " << l << std::endl;
79 if (proofsEnabled())
80 {
81 Node less = nm->mkNode(LT, var, mkRationalNode(nearest));
82 Node greater = nm->mkNode(GT, var, mkRationalNode(nearest));
83 // TODO (project #37): justify. Thread proofs through *ensureLiteral*.
84 Debug("integers::pf") << "less: " << less << std::endl;
85 Debug("integers::pf") << "greater: " << greater << std::endl;
86 Debug("integers::pf") << "literal: " << literal << std::endl;
87 Debug("integers::pf") << "eq: " << eq << std::endl;
88 Debug("integers::pf") << "rawEq: " << rawEq << std::endl;
89 Pf pfNotLit = d_pnm->mkAssume(literal.negate());
90 // rewrite notLiteral to notRawEq, using teq.
91 Pf pfNotRawEq =
92 literal == rawEq
93 ? pfNotLit
94 : d_pnm->mkNode(
95 PfRule::MACRO_SR_PRED_TRANSFORM,
96 {pfNotLit,
97 teq.getGenerator()->getProofFor(teq.getProven())},
98 {rawEq.negate()});
99 Pf pfBot = d_pnm->mkNode(
100 PfRule::CONTRA,
101 {d_pnm->mkNode(PfRule::ARITH_TRICHOTOMY,
102 {d_pnm->mkAssume(less.negate()), pfNotRawEq},
103 {greater}),
104 d_pnm->mkAssume(greater.negate())},
105 {});
106 std::vector<Node> assumptions = {
107 literal.negate(), less.negate(), greater.negate()};
108 // Proof of (not (and (not (= v i)) (not (< v i)) (not (> v i))))
109 Pf pfNotAnd = d_pnm->mkScope(pfBot, assumptions);
110 Pf pfL = d_pnm->mkNode(PfRule::MACRO_SR_PRED_TRANSFORM,
111 {d_pnm->mkNode(PfRule::NOT_AND, {pfNotAnd}, {})},
112 {l});
113 lem = d_pfGen->mkTrustNode(l, pfL);
114 }
115 else
116 {
117 lem = TrustNode::mkTrustLemma(l, nullptr);
118 }
119 }
120 else
121 {
122 Node ub = Rewriter::rewrite(nm->mkNode(LEQ, var, mkRationalNode(floor)));
123 Node lb = ub.notNode();
124 if (proofsEnabled())
125 {
126 lem =
127 d_pfGen->mkTrustNode(nm->mkNode(OR, ub, lb), PfRule::SPLIT, {}, {ub});
128 }
129 else
130 {
131 lem = TrustNode::mkTrustLemma(nm->mkNode(OR, ub, lb), nullptr);
132 }
133 }
134
135 Trace("integers") << "integers: branch & bound: " << lem << std::endl;
136 return lem;
137 }
138
139 bool BranchAndBound::proofsEnabled() const { return d_pnm != nullptr; }
140
141 } // namespace arith
142 } // namespace theory
143 } // namespace cvc5