Refactor and update copyright headers. (#6316)
[cvc5.git] / src / theory / arith / tableau.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Tim King
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
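 * Implementation of the Tableau used by the arithmetic theory.
 *
 * A Tableau is a Matrix<Rational> in which each row is owned by a
 * distinguished basic variable. This file implements pivoting, adding and
 * removing basic rows, and statistics on the complexity of row coefficients.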
16 * \todo document this file
17 */
18
19 #include "base/output.h"
20 #include "theory/arith/tableau.h"
21
22 using namespace std;
23 namespace cvc5 {
24 namespace theory {
25 namespace arith {
26
27
28 void Tableau::pivot(ArithVar oldBasic, ArithVar newBasic, CoefficientChangeCallback& cb){
29 Assert(isBasic(oldBasic));
30 Assert(!isBasic(newBasic));
31 Assert(d_mergeBuffer.empty());
32
33 Debug("tableau") << "Tableau::pivot(" << oldBasic <<", " << newBasic <<")" << endl;
34
35 RowIndex ridx = basicToRowIndex(oldBasic);
36
37 rowPivot(oldBasic, newBasic, cb);
38 Assert(ridx == basicToRowIndex(newBasic));
39
40 loadRowIntoBuffer(ridx);
41
42 ColIterator colIter = colIterator(newBasic);
43 while(!colIter.atEnd()){
44 EntryID id = colIter.getID();
45 Entry& entry = d_entries.get(id);
46
47 ++colIter; //needs to be incremented before the variable is removed
48 if(entry.getRowIndex() == ridx){ continue; }
49
50 RowIndex to = entry.getRowIndex();
51 Rational coeff = entry.getCoefficient();
52 if(cb.canUseRow(to)){
53 rowPlusBufferTimesConstant(to, coeff, cb);
54 }else{
55 rowPlusBufferTimesConstant(to, coeff);
56 }
57 }
58 clearBuffer();
59
60 //Clear the column for used for this variable
61
62 Assert(d_mergeBuffer.empty());
63 Assert(!isBasic(oldBasic));
64 Assert(isBasic(newBasic));
65 Assert(getColLength(newBasic) == 1);
66 }
67
68 /**
69 * Changes basic to newbasic (a variable on the row).
70 */
71 void Tableau::rowPivot(ArithVar basicOld, ArithVar basicNew, CoefficientChangeCallback& cb){
72 Assert(isBasic(basicOld));
73 Assert(!isBasic(basicNew));
74
75 RowIndex rid = basicToRowIndex(basicOld);
76
77 EntryID newBasicID = findOnRow(rid, basicNew);
78
79 Assert(newBasicID != ENTRYID_SENTINEL);
80
81 Tableau::Entry& newBasicEntry = d_entries.get(newBasicID);
82 const Rational& a_rs = newBasicEntry.getCoefficient();
83 int a_rs_sgn = a_rs.sgn();
84 Rational negInverseA_rs = -(a_rs.inverse());
85
86 for(RowIterator i = basicRowIterator(basicOld); !i.atEnd(); ++i){
87 EntryID id = i.getID();
88 Tableau::Entry& entry = d_entries.get(id);
89
90 entry.getCoefficient() *= negInverseA_rs;
91 }
92
93 d_basic2RowIndex.remove(basicOld);
94 d_basic2RowIndex.set(basicNew, rid);
95 d_rowIndex2basic.set(rid, basicNew);
96
97 cb.multiplyRow(rid, -a_rs_sgn);
98 }
99
100 void Tableau::addRow(ArithVar basic,
101 const std::vector<Rational>& coefficients,
102 const std::vector<ArithVar>& variables)
103 {
104 Assert(basic < getNumColumns());
105 Assert(debugIsASet(variables));
106 Assert(coefficients.size() == variables.size());
107 Assert(!isBasic(basic));
108
109 RowIndex newRow = Matrix<Rational>::addRow(coefficients, variables);
110 addEntry(newRow, basic, Rational(-1));
111
112 Assert(!d_basic2RowIndex.isKey(basic));
113 Assert(!d_rowIndex2basic.isKey(newRow));
114
115 d_basic2RowIndex.set(basic, newRow);
116 d_rowIndex2basic.set(newRow, basic);
117
118
119 if(Debug.isOn("matrix")){ printMatrix(); }
120
121 NoEffectCCCB noeffect;
122 NoEffectCCCB* nep = &noeffect;
123 CoefficientChangeCallback* cccb = static_cast<CoefficientChangeCallback*>(nep);
124
125 vector<Rational>::const_iterator coeffIter = coefficients.begin();
126 vector<ArithVar>::const_iterator varsIter = variables.begin();
127 vector<ArithVar>::const_iterator varsEnd = variables.end();
128 for(; varsIter != varsEnd; ++coeffIter, ++varsIter){
129 ArithVar var = *varsIter;
130
131 if(isBasic(var)){
132 Rational coeff = *coeffIter;
133
134 RowIndex ri = basicToRowIndex(var);
135
136 loadRowIntoBuffer(ri);
137 rowPlusBufferTimesConstant(newRow, coeff, *cccb);
138 clearBuffer();
139 }
140 }
141
142 if(Debug.isOn("matrix")) { printMatrix(); }
143
144 Assert(debugNoZeroCoefficients(newRow));
145 Assert(debugMatchingCountsForRow(newRow));
146 Assert(getColLength(basic) == 1);
147 }
148
149 void Tableau::removeBasicRow(ArithVar basic){
150 RowIndex rid = basicToRowIndex(basic);
151
152 removeRow(rid);
153 d_basic2RowIndex.remove(basic);
154 d_rowIndex2basic.remove(rid);
155 }
156
157 void Tableau::substitutePlusTimesConstant(ArithVar to, ArithVar from, const Rational& mult, CoefficientChangeCallback& cb){
158 if(!mult.isZero()){
159 RowIndex to_idx = basicToRowIndex(to);
160 addEntry(to_idx, from, mult); // Add an entry to be cancelled out
161 RowIndex from_idx = basicToRowIndex(from);
162
163 cb.update(to_idx, from, 0, mult.sgn());
164
165 loadRowIntoBuffer(from_idx);
166 rowPlusBufferTimesConstant(to_idx, mult, cb);
167 clearBuffer();
168 }
169 }
170
171 uint32_t Tableau::rowComplexity(ArithVar basic) const{
172 uint32_t complexity = 0;
173 for(RowIterator i = basicRowIterator(basic); !i.atEnd(); ++i){
174 const Entry& e = *i;
175 complexity += e.getCoefficient().complexity();
176 }
177 return complexity;
178 }
179
180 double Tableau::avgRowComplexity() const{
181 double sum = 0;
182 uint32_t rows = 0;
183 for(BasicIterator i = beginBasic(), i_end = endBasic(); i != i_end; ++i){
184 sum += rowComplexity(*i);
185 rows++;
186 }
187 return (rows == 0) ? 0 : (sum/rows);
188 }
189
190 void Tableau::printBasicRow(ArithVar basic, std::ostream& out){
191 printRow(basicToRowIndex(basic), out);
192 }
193
194 } // namespace arith
195 } // namespace theory
196 } // namespace cvc5