Merge pull request #28 from kbansal/sets
[cvc5.git] / src / theory / arith / cut_log.h
1
2 #include "cvc4_private.h"
3
4 #pragma once
5
6 #include "expr/kind.h"
7 #include "util/statistics_registry.h"
8 #include "theory/arith/arithvar.h"
9 #include "theory/arith/constraint_forward.h"
10 #include "util/dense_map.h"
11 #include <vector>
12 #include <map>
13 #include <set>
14 #include <ext/hash_map>
15
16 namespace CVC4 {
17 namespace theory {
18 namespace arith {
19
/**
 * A low level vector of indexed doubles.
 *
 * The fields are public and accessed directly: `len` entries are described
 * by the parallel arrays `inds` (indices) and `coeffs` (coefficients).
 * NOTE(review): setup(l) presumably allocates the arrays for l entries and
 * clear() releases them; the exact sizing/indexing convention (0- vs
 * 1-based) lives in the implementation file -- confirm there.
 *
 * Copying is disabled: this struct declares a destructor and holds raw
 * array pointers, so the compiler-generated copy operations would copy the
 * pointers rather than the arrays, leaving two objects that both believe
 * they own the same storage.
 */
struct PrimitiveVec {
  int len;        /* number of entries described by inds/coeffs */
  int* inds;      /* index array */
  double* coeffs; /* coefficient array */

  PrimitiveVec();
  ~PrimitiveVec();

  /** Returns true if the arrays are currently set up. */
  bool initialized() const;

  /** Releases the arrays (if any). */
  void clear();

  /** Prepares the vector to hold l entries. */
  void setup(int l);

  /** Writes a human readable form of the vector to out. */
  void print(std::ostream& out) const;

private:
  /* Non-copyable: declared but intentionally never defined (pre-C++11 idiom). */
  PrimitiveVec(const PrimitiveVec&);
  PrimitiveVec& operator=(const PrimitiveVec&);
};
std::ostream& operator<<(std::ostream& os, const PrimitiveVec& pv);
33
34 struct DenseVector {
35 DenseMap<Rational> lhs;
36 Rational rhs;
37 void purge();
38 void print(std::ostream& os) const;
39
40 static void print(std::ostream& os, const DenseMap<Rational>& lhs);
41 };
42
/**
 * The different kinds of cuts: tags identifying which concrete kind of
 * CutInfo an object is (see CutInfo::getKlass()).
 */
enum CutInfoKlass {
  MirCutKlass,      /* NOTE(review): presumably a mixed-integer rounding cut */
  GmiCutKlass,      /* NOTE(review): presumably a Gomory mixed-integer cut */
  BranchCutKlass,   /* a branch (presumably used by BranchCutInfo) */
  RowsDeletedKlass, /* a row deletion (presumably used by RowsDeleted) */
  UnknownKlass      /* none of the above */
};
std::ostream& operator<<(std::ostream& os, CutInfoKlass kl);
48
49 /** A general class for describing a cut. */
50 class CutInfo {
51 protected:
52 CutInfoKlass d_klass;
53 int d_execOrd;
54
55 int d_poolOrd; /* cut's ordinal in the current node pool */
56 Kind d_cutType; /* Lowerbound, upperbound or undefined. */
57 double d_cutRhs; /* right hand side of the cut */
58 PrimitiveVec d_cutVec; /* vector of the cut */
59
60 /**
61 * The number of rows at the time the cut was made.
62 * This is required to descramble indices after the fact!
63 */
64 int d_mAtCreation;
65
66 /** This is the number of structural variables. */
67 int d_N;
68
69 /** if selected, make this non-zero */
70 int d_rowId;
71
72 /* If the cut has been successfully created,
73 * the cut is stored in exact precision in d_exactPrecision.
74 * If the cut has not yet been proven, this is null.
75 */
76 DenseVector* d_exactPrecision;
77
78 ConstraintCPVec* d_explanation;
79
80 public:
81 CutInfo(CutInfoKlass kl, int cutid, int ordinal);
82
83 virtual ~CutInfo();
84
85 int getId() const;
86
87 int getRowId() const;
88 void setRowId(int rid);
89
90 void print(std::ostream& out) const;
91 //void init_cut(int l);
92 PrimitiveVec& getCutVector();
93 const PrimitiveVec& getCutVector() const;
94
95 Kind getKind() const;
96 void setKind(Kind k);
97
98
99 void setRhs(double r);
100 double getRhs() const;
101
102 CutInfoKlass getKlass() const;
103 int poolOrdinal() const;
104
105 void setDimensions(int N, int M);
106 int getN() const;
107 int getMAtCreation() const;
108
109 bool operator<(const CutInfo& o) const;
110
111 /* Returns true if the cut was successfully made in exact precision.*/
112 bool reconstructed() const;
113
114 /* Returns true if the cut has an explanation. */
115 bool proven() const;
116
117 void setReconstruction(const DenseVector& ep);
118 void setExplanation(const ConstraintCPVec& ex);
119 void swapExplanation(ConstraintCPVec& ex);
120
121 const DenseVector& getReconstruction() const;
122 const ConstraintCPVec& getExplanation() const;
123
124 void clearReconstruction();
125 };
126 std::ostream& operator<<(std::ostream& os, const CutInfo& ci);
127
128 struct BranchCutInfo : public CutInfo {
129 BranchCutInfo(int execOrd, int br, Kind dir, double val);
130 };
131
132 struct RowsDeleted : public CutInfo {
133 RowsDeleted(int execOrd, int nrows, const int num[]);
134 };
135
136 class TreeLog;
137
138 class NodeLog {
139 private:
140 int d_nid;
141 NodeLog* d_parent; /* If null this is the root */
142 TreeLog* d_tl; /* TreeLog containing the node. */
143
144 struct CmpCutPointer{
145 int operator()(const CutInfo* a, const CutInfo* b) const{
146 return *a < *b;
147 }
148 };
149 typedef std::set<CutInfo*, CmpCutPointer> CutSet;
150 CutSet d_cuts;
151 std::map<int, int> d_rowIdsSelected;
152
153 enum Status {Open, Closed, Branched};
154 Status d_stat;
155
156 int d_brVar; // branching variable
157 double d_brVal;
158 int d_downId;
159 int d_upId;
160
161 public:
162 typedef __gnu_cxx::hash_map<int, ArithVar> RowIdMap;
163 private:
164 RowIdMap d_rowId2ArithVar;
165
166 public:
167 NodeLog(); /* default constructor. */
168 NodeLog(TreeLog* tl, int node, const RowIdMap& m); /* makes a root node. */
169 NodeLog(TreeLog* tl, NodeLog* parent, int node);/* makes a non-root node. */
170
171 ~NodeLog();
172
173 int getNodeId() const;
174 void addSelected(int ord, int sel);
175 void applySelected();
176 void addCut(CutInfo* ci);
177 void print(std::ostream& o) const;
178
179 bool isRoot() const;
180 const NodeLog& getParent() const;
181
182 void copyParentRowIds();
183
184 bool isBranch() const;
185 int branchVariable() const;
186 double branchValue() const;
187
188 typedef CutSet::const_iterator const_iterator;
189 const_iterator begin() const;
190 const_iterator end() const;
191
192 void setBranch(int br, double val, int dn, int up);
193 void closeNode();
194
195 int getDownId() const;
196 int getUpId() const;
197
198 /**
199 * Looks up a row id to the appropraite arith variable.
200 * Be careful these are deleted in context during replay!
201 * failure returns ARITHVAR_SENTINEL */
202 ArithVar lookupRowId(int rowId) const;
203
204 /**
205 * Maps a row id to an arithvar.
206 * Be careful these are deleted in context during replay!
207 */
208 void mapRowId(int rowid, ArithVar v);
209 void applyRowsDeleted(const RowsDeleted& rd);
210
211 };
212 std::ostream& operator<<(std::ostream& os, const NodeLog& nl);
213
214 class ApproximateSimplex;
215 class TreeLog {
216 private:
217 ApproximateSimplex* d_generator;
218
219 int next_exec_ord;
220 typedef std::map<int, NodeLog> ToNodeMap;
221 ToNodeMap d_toNode;
222 DenseMultiset d_branches;
223
224 uint32_t d_numCuts;
225
226 bool d_active;
227
228 public:
229 TreeLog();
230
231 NodeLog& getNode(int nid);
232 void branch(int nid, int br, double val, int dn, int up);
233 void close(int nid);
234
235 //void applySelected();
236 void print(std::ostream& o) const;
237
238 typedef ToNodeMap::const_iterator const_iterator;
239 const_iterator begin() const;
240 const_iterator end() const;
241
242 int getExecutionOrd();
243
244 void reset(const NodeLog::RowIdMap& m);
245
246 // Applies rd tp to the node with id nid
247 void applyRowsDeleted(int nid, const RowsDeleted& rd);
248
249 // Synonym for getNode(nid).mapRowId(ind, v)
250 void mapRowId(int nid, int ind, ArithVar v);
251
252 private:
253 void clear();
254
255 public:
256 void makeInactive();
257 void makeActive();
258
259 bool isActivelyLogging() const;
260
261 void addCut();
262 uint32_t cutCount() const;
263
264 void logBranch(uint32_t x);
265 uint32_t numBranches(uint32_t x);
266
267 int getRootId() const;
268
269 uint32_t numNodes() const{
270 return d_toNode.size();
271 }
272
273 NodeLog& getRootNode();
274 void printBranchInfo(std::ostream& os) const;
275 };
276
277
278
279 }/* CVC4::theory::arith namespace */
280 }/* CVC4::theory namespace */
281 }/* CVC4 namespace */