Merged the branch sparse-tableau into trunk.
[cvc5.git] / src / theory / arith / tableau.cpp
1 /********************* */
2 /*! \file tableau.cpp
3 ** \verbatim
4 ** Original author: taking
5 ** Major contributors: none
6 ** Minor contributors (to current version): none
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009, 2010 The Analysis of Computer Systems Group (ACSys)
9 ** Courant Institute of Mathematical Sciences
10 ** New York University
11 ** See the file COPYING in the top-level source directory for licensing
12 ** information.\endverbatim
13 **
14 ** \brief [[ Add one-line brief description here ]]
15 **
16 ** [[ Add lengthier description here ]]
17 ** \todo document this file
18 **/
19
20
21 #include "theory/arith/tableau.h"
22
23 using namespace std;
24 using namespace CVC4;
25 using namespace CVC4::theory;
26 using namespace CVC4::theory::arith;
27
28
29 /*
30 void Tableau::addRow(ArithVar basicVar,
31 const std::vector<Rational>& coeffs,
32 const std::vector<ArithVar>& variables){
33
34 Assert(coeffs.size() == variables.size());
35
36 //The new basic variable cannot already be a basic variable
37 Assert(!d_basicVariables.isMember(basicVar));
38 d_basicVariables.add(basicVar);
39 ReducedRowVector* row_current = new ReducedRowVector(basicVar,variables, coeffs,d_rowCount, d_columnMatrix);
40 d_rowsTable[basicVar] = row_current;
41
42 //A variable in the row may have been made non-basic already.
43 //If this is the case we fake pivoting this variable
44 vector<ArithVar>::const_iterator varsIter = variables.begin();
45 vector<ArithVar>::const_iterator varsEnd = variables.end();
46
47 for( ; varsIter != varsEnd; ++varsIter){
48 ArithVar var = *varsIter;
49
50 if(d_basicVariables.isMember(var)){
51 EntryID varID = find(basicVar, var);
52 TableauEntry& entry = d_entryManager.get(varID);
53 const Rational& coeff = entry.getCoefficient();
54
55 loadRowIntoMergeBuffer(var);
56 rowPlusRowTimesConstant(coeff, basicVar, var);
57 emptyRowFromMergeBuffer(var);
58 }
59 }
60 }
61 */
62
63 /*
64 ReducedRowVector* Tableau::removeRow(ArithVar basic){
65 Assert(isBasic(basic));
66
67 ReducedRowVector* row = d_rowsTable[basic];
68
69 d_basicVariables.remove(basic);
70 d_rowsTable[basic] = NULL;
71
72 return row;
73 }
74 */
75
/**
 * Exchanges the basic variable oldBasic for the non-basic variable newBasic.
 *
 * The row of oldBasic is first re-expressed as the row of newBasic
 * (rowPivot).  That row is then loaded into the merge buffer and added, with
 * the appropriate multiplier, to every other row that has an entry in
 * newBasic's column, which eliminates newBasic from those rows.
 *
 * Requires: oldBasic is basic, newBasic is not, the merge buffer is empty.
 * Ensures:  newBasic is basic, oldBasic is not, the merge buffer is empty,
 *           and newBasic's column holds exactly one entry (its own row's).
 */
void Tableau::pivot(ArithVar oldBasic, ArithVar newBasic){
  Assert(isBasic(oldBasic));
  Assert(!isBasic(newBasic));
  Assert(mergeBufferIsEmpty());

  //cout << oldBasic << "," << newBasic << endl;
  Debug("tableau") << "Tableau::pivot(" << oldBasic <<", " << newBasic <<")" << endl;

  rowPivot(oldBasic, newBasic);
  // Index the new row by column so each row operation below can merge it.
  loadRowIntoMergeBuffer(newBasic);

  ColIterator colIter = colIterator(newBasic);
  while(!colIter.atEnd()){
    EntryID id = colIter.getID();
    TableauEntry& entry = d_entryManager.get(id);

    // After rowPivot, newBasic's coefficient on its own row is -1, so the
    // row operation below cancels this entry (coeff + coeff * -1 == 0) and
    // removes it from the column; the iterator must already be past it.
    ++colIter; //needs to be incremented before the variable is removed
    if(entry.getRowVar() == newBasic){ continue; } // skip the pivot row itself

    ArithVar basicTo = entry.getRowVar();
    // Copied by value: the entry may be removed by the call below.
    Rational coeff = entry.getCoefficient();

    rowPlusRowTimesConstant(basicTo, coeff, newBasic);
  }
  emptyRowFromMergeBuffer(newBasic);

  // newBasic should now appear only in its own row (checked below).

  Assert(mergeBufferIsEmpty());
  Assert(!isBasic(oldBasic));
  Assert(isBasic(newBasic));
  Assert(getColLength(newBasic) == 1);
}
109
110 Tableau::~Tableau(){}
111
112 void Tableau::setColumnUnused(ArithVar v){
113 ColIterator colIter = colIterator(v);
114 while(!colIter.atEnd()){
115 ++colIter;
116 }
117 }
118 void Tableau::printTableau(){
119 Debug("tableau") << "Tableau::d_activeRows" << endl;
120
121 ArithVarSet::const_iterator basicIter = beginBasic(), endIter = endBasic();
122 for(; basicIter != endIter; ++basicIter){
123 ArithVar basic = *basicIter;
124 printRow(basic);
125 }
126 }
127
128 void Tableau::printRow(ArithVar basic){
129 Debug("tableau") << "{" << basic << ":";
130 for(RowIterator entryIter = rowIterator(basic); !entryIter.atEnd(); ++entryIter){
131 const TableauEntry& entry = *entryIter;
132 printEntry(entry);
133 Debug("tableau") << ",";
134 }
135 Debug("tableau") << "}" << endl;
136 }
137
138 void Tableau::printEntry(const TableauEntry& entry){
139 Debug("tableau") << entry.getColVar() << "*" << entry.getCoefficient();
140 }
141
142 uint32_t Tableau::numNonZeroEntriesByRow() const {
143 uint32_t rowSum = 0;
144 ArithVarSet::const_iterator i = d_basicVariables.begin(), end = d_basicVariables.end();
145 for(; i != end; ++i){
146 ArithVar basic = *i;
147 rowSum += getRowLength(basic);
148 }
149 return rowSum;
150 }
151
152 uint32_t Tableau::numNonZeroEntriesByCol() const {
153 uint32_t colSum = 0;
154 VectorSizeTable::const_iterator i = d_colLengths.begin();
155 VectorSizeTable::const_iterator end = d_colLengths.end();
156 for(; i != end; ++i){
157 colSum += *i;
158 }
159 return colSum;
160 }
161
162
163 EntryID Tableau::findOnRow(ArithVar row, ArithVar col){
164 for(RowIterator i = rowIterator(row); !i.atEnd(); ++i){
165 EntryID id = i.getID();
166 const TableauEntry& entry = *i;
167 ArithVar colVar = entry.getColVar();
168
169 if(colVar == col){
170 return id;
171 }
172 }
173 return ENTRYID_SENTINEL;
174 }
175
176 EntryID Tableau::findOnCol(ArithVar row, ArithVar col){
177 for(ColIterator i = colIterator(col); !i.atEnd(); ++i){
178 EntryID id = i.getID();
179 const TableauEntry& entry = *i;
180 ArithVar rowVar = entry.getRowVar();
181
182 if(rowVar == row){
183 return id;
184 }
185 }
186 return ENTRYID_SENTINEL;
187 }
188
189 const TableauEntry& Tableau::findEntry(ArithVar row, ArithVar col){
190 bool colIsShorter = getColLength(col) < getRowLength(row);
191 EntryID id = colIsShorter ? findOnCol(row,col) : findOnRow(row,col);
192 if(id == ENTRYID_SENTINEL){
193 return d_failedFind;
194 }else{
195 return d_entryManager.get(id);
196 }
197 }
198
199 void Tableau::removeRow(ArithVar basic){
200 RowIterator i = rowIterator(basic);
201 while(!i.atEnd()){
202 EntryID id = i.getID();
203 ++i;
204 removeEntry(id);
205 }
206 d_basicVariables.remove(basic);
207 }
208
209 void Tableau::loadRowIntoMergeBuffer(ArithVar basic){
210 Assert(mergeBufferIsEmpty());
211 for(RowIterator i = rowIterator(basic); !i.atEnd(); ++i){
212 EntryID id = i.getID();
213 const TableauEntry& entry = *i;
214 ArithVar colVar = entry.getColVar();
215 d_mergeBuffer[colVar] = make_pair(id, false);
216 }
217 }
218
219 void Tableau::emptyRowFromMergeBuffer(ArithVar basic){
220 Assert(isBasic(basic));
221 for(RowIterator i = rowIterator(basic); !i.atEnd(); ++i){
222 const TableauEntry& entry = *i;
223 ArithVar colVar = entry.getColVar();
224 Assert(d_mergeBuffer[colVar].first == i.getID());
225 d_mergeBuffer[colVar] = make_pair(ENTRYID_SENTINEL, false);
226 }
227
228 Assert(mergeBufferIsEmpty());
229 }
230
231
232 /**
233 * Changes basic to newbasic (a variable on the row).
234 */
235 void Tableau::rowPivot(ArithVar basicOld, ArithVar basicNew){
236 Assert(mergeBufferIsEmpty());
237 Assert(isBasic(basicOld));
238 Assert(!isBasic(basicNew));
239
240 EntryID newBasicID = findOnRow(basicOld, basicNew);
241
242 Assert(newBasicID != ENTRYID_SENTINEL);
243
244 TableauEntry& newBasicEntry = d_entryManager.get(newBasicID);
245 Rational negInverseA_rs = -(newBasicEntry.getCoefficient().inverse());
246
247 for(RowIterator i = rowIterator(basicOld); !i.atEnd(); ++i){
248 EntryID id = i.getID();
249 TableauEntry& entry = d_entryManager.get(id);
250
251 entry.getCoefficient() *= negInverseA_rs;
252 entry.setRowVar(basicNew);
253 }
254
255 d_rowHeads[basicNew] = d_rowHeads[basicOld];
256 d_rowHeads[basicOld] = ENTRYID_SENTINEL;
257
258 d_rowLengths[basicNew] = d_rowLengths[basicOld];
259 d_rowLengths[basicOld] = 0;
260
261 d_basicVariables.remove(basicOld);
262 d_basicVariables.add(basicNew);
263 }
264
/**
 * Creates a new entry with value coeff at (row, col) and links it in at the
 * head of both row's list and col's list, bumping the row/column lengths
 * and the in-use counter.
 *
 * coeff must be non-zero.  No check is made for an existing entry at
 * (row, col); callers must not create duplicates.
 */
void Tableau::addEntry(ArithVar row, ArithVar col, const Rational& coeff){
  Assert(coeff != 0);

  // newEntry() may push_back into the entry storage, so the reference is
  // only taken after it returns.
  EntryID newId = d_entryManager.newEntry();
  TableauEntry& newEntry = d_entryManager.get(newId);
  // Presumably the argument order is (row, col, nextRow, nextCol, prevRow,
  // prevCol, coeff): the current heads become this entry's successors and it
  // has no predecessors -- confirm against tableau.h.
  newEntry = TableauEntry( row, col,
                           d_rowHeads[row], d_colHeads[col],
                           ENTRYID_SENTINEL, ENTRYID_SENTINEL,
                           coeff);
  Assert(newEntry.getCoefficient() != 0);

  Debug("tableau") << "addEntry(" << row << "," << col <<"," << coeff << ")" << endl;

  ++d_entriesInUse;

  // Back-link the old heads to the new entry before replacing the heads.
  if(d_rowHeads[row] != ENTRYID_SENTINEL)
    d_entryManager.get(d_rowHeads[row]).setPrevRowID(newId);

  if(d_colHeads[col] != ENTRYID_SENTINEL)
    d_entryManager.get(d_colHeads[col]).setPrevColID(newId);

  d_rowHeads[row] = newId;
  d_colHeads[col] = newId;
  ++d_rowLengths[row];
  ++d_colLengths[col];
}
291
/**
 * Deletes the entry with the given ID from the tableau.
 *
 * Decrements the in-use counter and the row/column lengths, advances the
 * row and column heads if they pointed at this entry, splices the entry out
 * of both doubly linked lists, marks it blank, and hands the ID back to the
 * entry manager for reuse.
 */
void Tableau::removeEntry(EntryID id){
  Assert(d_entriesInUse > 0);
  --d_entriesInUse;

  TableauEntry& entry = d_entryManager.get(id);

  ArithVar row = entry.getRowVar();
  ArithVar col = entry.getColVar();

  Assert(d_rowLengths[row] > 0);
  Assert(d_colLengths[col] > 0);


  --d_rowLengths[row];
  --d_colLengths[col];

  // Read all four links now; markBlank() below clears the entry.
  EntryID prevRow = entry.getPrevRowID();
  EntryID prevCol = entry.getPrevColID();

  EntryID nextRow = entry.getNextRowID();
  EntryID nextCol = entry.getNextColID();

  if(d_rowHeads[row] == id){
    d_rowHeads[row] = nextRow;
  }
  if(d_colHeads[col] == id){
    d_colHeads[col] = nextCol;
  }

  entry.markBlank();

  // Splice out of the row list.
  if(prevRow != ENTRYID_SENTINEL){
    d_entryManager.get(prevRow).setNextRowID(nextRow);
  }
  if(nextRow != ENTRYID_SENTINEL){
    d_entryManager.get(nextRow).setPrevRowID(prevRow);
  }

  // Splice out of the column list.
  if(prevCol != ENTRYID_SENTINEL){
    d_entryManager.get(prevCol).setNextColID(nextCol);
  }
  if(nextCol != ENTRYID_SENTINEL){
    d_entryManager.get(nextCol).setPrevColID(prevCol);
  }

  // freeEntry asserts the entry is blank, so this must follow markBlank().
  d_entryManager.freeEntry(id);
}
339
/**
 * Row operation: row(basicTo) += c * row(basicFrom).
 *
 * The merge buffer must already hold row(basicFrom), indexed by column (see
 * loadRowIntoMergeBuffer).  Each slot is a pair (entry ID on basicFrom,
 * "already merged" flag).
 *  1. For each entry of basicTo whose column also occurs on basicFrom, the
 *     scaled coefficient is added in place.  The column is flagged and
 *     recorded in d_usedList, and the entry is removed if it becomes zero.
 *  2. For each entry of basicFrom whose column was not flagged, a new entry
 *     with value c * coefficient is added to basicTo.
 * The flags are reset by clearUsedList() before returning.  The merge buffer
 * still holds basicFrom's row, so the caller can reuse it or empty it.
 *
 * c must be non-zero, and both rows must be free of zero coefficients.
 */
void Tableau::rowPlusRowTimesConstant(ArithVar basicTo, const Rational& c, ArithVar basicFrom){

  Debug("tableau") << "rowPlusRowTimesConstant("
                   << basicTo << "," << c << "," << basicFrom << ")"
                   << endl;

  Assert(debugNoZeroCoefficients(basicTo));
  Assert(debugNoZeroCoefficients(basicFrom));

  Assert(c != 0);
  Assert(isBasic(basicTo));
  Assert(isBasic(basicFrom));
  Assert( d_usedList.empty() );


  // Pass 1: combine columns present on both rows.
  RowIterator i = rowIterator(basicTo);
  while(!i.atEnd()){
    EntryID id = i.getID();
    TableauEntry& entry = d_entryManager.get(id);
    ArithVar colVar = entry.getColVar();

    ++i; // advance first: removeEntry(id) below unlinks this entry
    if(bufferPairIsNotEmpty(d_mergeBuffer[colVar])){
      d_mergeBuffer[colVar].second = true;
      d_usedList.push_back(colVar);

      EntryID inOtherRow = d_mergeBuffer[colVar].first;
      const TableauEntry& other = d_entryManager.get(inOtherRow);
      entry.getCoefficient() += c * other.getCoefficient();

      if(entry.getCoefficient().sgn() == 0){
        removeEntry(id);
      }
    }
  }

  // Pass 2: copy columns that only basicFrom has.  Flagged columns were
  // already combined (possibly cancelled) in pass 1 and must be skipped.
  // newCoeff is non-zero because c != 0 and basicFrom has no zero entries,
  // which satisfies addEntry's precondition.
  for(RowIterator i = rowIterator(basicFrom); !i.atEnd(); ++i){
    const TableauEntry& entry = *i;
    ArithVar colVar = entry.getColVar();

    if(!(d_mergeBuffer[colVar]).second){
      Rational newCoeff = c * entry.getCoefficient();
      addEntry(basicTo, colVar, newCoeff);
    }
  }

  clearUsedList();

  if(Debug.isOn("tableau")) { printTableau(); }
}
390
391 void Tableau::clearUsedList(){
392 ArithVarArray::iterator i, end;
393 for(i = d_usedList.begin(), end = d_usedList.end(); i != end; ++i){
394 ArithVar pos = *i;
395 d_mergeBuffer[pos].second = false;
396 }
397 d_usedList.clear();
398 }
399
/**
 * Adds a new row making basic a basic variable.  It is stored as the entries
 *   -1 * basic + sum_i coefficients[i] * variables[i].
 *
 * Any variable in the list that is already basic is then substituted out.
 * That variable's own row carries it with coefficient -1, so adding
 * coeff * (its row) cancels its entry in the new row.
 *
 * Requires: basic is not yet basic; coefficients and variables have equal
 * length, with non-zero coefficients and (presumably) no repeated
 * variables, since addEntry does not merge duplicates.
 */
void Tableau::addRow(ArithVar basic,
                     const std::vector<Rational>& coefficients,
                     const std::vector<ArithVar>& variables)
{
  Assert(coefficients.size() == variables.size() );
  Assert(!isBasic(basic));

  d_basicVariables.add(basic);

  if(Debug.isOn("tableau")){ printTableau(); }

  addEntry(basic, basic, Rational(-1));

  vector<Rational>::const_iterator coeffIter = coefficients.begin();
  vector<ArithVar>::const_iterator varsIter = variables.begin();
  vector<ArithVar>::const_iterator varsEnd = variables.end();

  for(; varsIter != varsEnd; ++coeffIter, ++varsIter){
    const Rational& coeff = *coeffIter;
    ArithVar var_i = *varsIter;
    addEntry(basic, var_i, coeff);
  }

  // Second pass: eliminate variables that are already basic by merging in
  // their rows.  The original input coefficient is the right multiplier:
  // substituting one basic row does not alter another basic variable's
  // coefficient here.
  varsIter = variables.begin();
  coeffIter = coefficients.begin();
  for(; varsIter != varsEnd; ++coeffIter, ++varsIter){
    ArithVar var = *varsIter;

    if(isBasic(var)){
      Rational coeff = *coeffIter;

      loadRowIntoMergeBuffer(var);
      rowPlusRowTimesConstant(basic, coeff, var);
      emptyRowFromMergeBuffer(var);
    }
  }

  if(Debug.isOn("tableau")) { printTableau(); }

  Assert(debugNoZeroCoefficients(basic));

  Assert(debugMatchingCountsForRow(basic));
  Assert(getColLength(basic) == 1);
}
444
445 bool Tableau::debugNoZeroCoefficients(ArithVar basic){
446 for(RowIterator i=rowIterator(basic); !i.atEnd(); ++i){
447 const TableauEntry& entry = *i;
448 if(entry.getCoefficient() == 0){
449 return false;
450 }
451 }
452 return true;
453 }
454 bool Tableau::debugMatchingCountsForRow(ArithVar basic){
455 for(RowIterator i=rowIterator(basic); !i.atEnd(); ++i){
456 const TableauEntry& entry = *i;
457 ArithVar colVar = entry.getColVar();
458 uint32_t count = debugCountColLength(colVar);
459 Debug("tableau") << "debugMatchingCountsForRow "
460 << basic << ":" << colVar << " " << count
461 <<" "<< d_colLengths[colVar] << endl;
462 if( count != d_colLengths[colVar] ){
463 return false;
464 }
465 }
466 return true;
467 }
468
469
470 uint32_t Tableau::debugCountColLength(ArithVar var){
471 Debug("tableau") << var << " ";
472 uint32_t count = 0;
473 for(ColIterator i=colIterator(var); !i.atEnd(); ++i){
474 const TableauEntry& entry = *i;
475 Debug("tableau") << "(" << entry.getRowVar() << ", " << i.getID() << ") ";
476 ++count;
477 }
478 Debug("tableau") << endl;
479 return count;
480 }
481
482 uint32_t Tableau::debugCountRowLength(ArithVar var){
483 uint32_t count = 0;
484 for(RowIterator i=rowIterator(var); !i.atEnd(); ++i){
485 ++count;
486 }
487 return count;
488 }
489
490 /*
491 void ReducedRowVector::enqueueNonBasicVariablesAndCoefficients(std::vector< ArithVar >& variables,std::vector< Rational >& coefficients) const{
492 for(const_iterator i=begin(), endEntries=end(); i != endEntries; ++i){
493 ArithVar var = (*i).getArithVar();
494 const Rational& q = (*i).getCoefficient();
495 if(var != basic()){
496 variables.push_back(var);
497 coefficients.push_back(q);
498 }
499 }
500 }*/
501
502 Node Tableau::rowAsEquality(ArithVar basic, const ArithVarToNodeMap& map){
503 using namespace CVC4::kind;
504
505 Assert(getRowLength(basic) >= 2);
506
507 vector<Node> nonBasicPairs;
508 for(RowIterator i = rowIterator(basic); !i.atEnd(); ++i){
509 const TableauEntry& entry = *i;
510 ArithVar colVar = entry.getColVar();
511 if(colVar == basic) continue;
512 Node var = (map.find(colVar))->second;
513 Node coeff = mkRationalNode(entry.getCoefficient());
514
515 Node mult = NodeBuilder<2>(MULT) << coeff << var;
516 nonBasicPairs.push_back(mult);
517 }
518
519 Node sum = Node::null();
520 if(nonBasicPairs.size() == 1 ){
521 sum = nonBasicPairs.front();
522 }else{
523 Assert(nonBasicPairs.size() >= 2);
524 NodeBuilder<> sumBuilder(PLUS);
525 sumBuilder.append(nonBasicPairs);
526 sum = sumBuilder;
527 }
528 Node basicVar = (map.find(basic))->second;
529 return NodeBuilder<2>(EQUAL) << basicVar << sum;
530 }
531
532 double Tableau::densityMeasure() const{
533 Assert(numNonZeroEntriesByRow() == numNonZeroEntries());
534 Assert(numNonZeroEntriesByCol() == numNonZeroEntries());
535
536 uint32_t n = getNumRows();
537 if(n == 0){
538 return 1.0;
539 }else {
540 uint32_t s = numNonZeroEntries();
541 uint32_t m = d_colHeads.size();
542 uint32_t divisor = (n *(m - n + 1));
543
544 Assert(n >= 1);
545 Assert(m >= n);
546 Assert(divisor > 0);
547 Assert(divisor >= s);
548
549 return (double(s)) / divisor;
550 }
551 }
552
553 void TableauEntryManager::freeEntry(EntryID id){
554 Assert(get(id).blank());
555 Assert(d_size > 0);
556
557 d_freedEntries.push(id);
558 --d_size;
559 }
560
561 EntryID TableauEntryManager::newEntry(){
562 EntryID newId;
563 if(d_freedEntries.empty()){
564 newId = d_entries.size();
565 d_entries.push_back(TableauEntry());
566 }else{
567 newId = d_freedEntries.front();
568 d_freedEntries.pop();
569 }
570 ++d_size;
571 return newId;
572 }