050e6baed305c5135d696b645eb2b11ba68de639
[cvc5.git] / src / theory / arith / nl / iand_table.cpp
1 /********************* */
2 /*! \file iand_table.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Yoni Zohar
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Utilities to maintain finite tables that represent
13 ** the value of iand.
14 **/
15
16 #include "theory/arith/nl/iand_table.h"
17
18 #include <cmath>
19
20 #include "cvc4_private.h"
21 #include "theory/arith/nl/nl_model.h"
22 namespace CVC4 {
23 namespace theory {
24 namespace arith {
25 namespace nl {
26
27 static Rational intpow2(uint64_t b)
28 {
29 return Rational(Integer(2).pow(b), Integer(1));
30 }
31
32 Node pow2(uint64_t k)
33 {
34 Assert(k >= 0);
35 NodeManager* nm = NodeManager::currentNM();
36 return nm->mkConst<Rational>(intpow2(k));
37 }
38
/** Conjunction of two single bits, each represented as a bool. */
bool oneBitAnd(bool a, bool b)
{
  if (!a)
  {
    return false;
  }
  return b;
}
40
41 // computes (bv_to_int ((_ extract i+size-1 i) (int_to_bv x))))
42 Node intExtract(Node x, uint64_t i, uint64_t size)
43 {
44 Assert(size > 0);
45 NodeManager* nm = NodeManager::currentNM();
46 // extract definition in integers is:
47 // (mod (div a (two_to_the j)) (two_to_the (+ (- i j) 1))))
48 Node extract =
49 nm->mkNode(kind::INTS_MODULUS_TOTAL,
50 nm->mkNode(kind::INTS_DIVISION_TOTAL, x, pow2(i * size)),
51 pow2(size));
52 return extract;
53 }
54
55 Node IAndTable::createITEFromTable(
56 Node x,
57 Node y,
58 uint64_t granularity,
59 const std::map<std::pair<int64_t, int64_t>, uint64_t>& table)
60 {
61 NodeManager* nm = NodeManager::currentNM();
62 Assert(granularity <= 8);
63 uint64_t num_of_values = ((uint64_t)pow(2, granularity));
64 // The table represents a function from pairs of integers to integers, where
65 // all integers are between 0 (inclusive) and num_of_values (exclusive).
66 // additionally, there is a default value (-1, -1).
67 Assert(table.size() == 1 + ((uint64_t)(num_of_values * num_of_values)));
68 // start with the default, most common value.
69 // this value is represented in the table by (-1, -1).
70 Node ite = nm->mkConst<Rational>(table.at(std::make_pair(-1, -1)));
71 for (uint64_t i = 0; i < num_of_values; i++)
72 {
73 for (uint64_t j = 0; j < num_of_values; j++)
74 {
75 // skip the most common value, as it was already stored.
76 if (table.at(std::make_pair(i, j)) == table.at(std::make_pair(-1, -1)))
77 {
78 continue;
79 }
80 // append the current value to the ite.
81 ite = nm->mkNode(
82 kind::ITE,
83 nm->mkNode(kind::AND,
84 nm->mkNode(kind::EQUAL, x, nm->mkConst<Rational>(i)),
85 nm->mkNode(kind::EQUAL, y, nm->mkConst<Rational>(j))),
86 nm->mkConst<Rational>(table.at(std::make_pair(i, j))),
87 ite);
88 }
89 }
90 return ite;
91 }
92
93 Node IAndTable::createBitwiseNode(Node x,
94 Node y,
95 uint64_t bvsize,
96 uint64_t granularity)
97 {
98 NodeManager* nm = NodeManager::currentNM();
99 Assert(0 < granularity && granularity <= 8);
100 // Standardize granularity.
101 // If it is greater than bvsize, it is set to bvsize.
102 // Otherwise, it is set to the closest (going down) divider of bvsize.
103 if (granularity > bvsize)
104 {
105 granularity = bvsize;
106 }
107 else
108 {
109 while (bvsize % granularity != 0)
110 {
111 granularity = granularity - 1;
112 }
113 }
114
115 // Create the sum.
116 // For granularity 1, the sum has bvsize elements.
117 // In contrast, if bvsize = granularity, sum has one element.
118 // Each element in the sum is an ite that corresponds to the generated table,
119 // multiplied by the appropriate power of two.
120 // More details are in bv_to_int.h .
121
122 // number of elements in the sum expression
123 uint64_t sumSize = bvsize / granularity;
124 // initialize the sum
125 Node sumNode = nm->mkConst<Rational>(0);
126 // compute the table for the current granularity if needed
127 if (d_bvandTable.find(granularity) == d_bvandTable.end())
128 {
129 computeAndTable(granularity);
130 }
131 const std::map<std::pair<int64_t, int64_t>, uint64_t>& table =
132 d_bvandTable[granularity];
133 for (uint64_t i = 0; i < sumSize; i++)
134 {
135 // compute the current blocks of x and y
136 Node xExtract = intExtract(x, i, granularity);
137 Node yExtract = intExtract(y, i, granularity);
138 // compute the ite for this part
139 Node sumPart = createITEFromTable(xExtract, yExtract, granularity, table);
140 // append the current block to the sum
141 sumNode =
142 nm->mkNode(kind::PLUS,
143 sumNode,
144 nm->mkNode(kind::MULT, pow2(i * granularity), sumPart));
145 }
146 return sumNode;
147 }
148
149 void IAndTable::computeAndTable(uint64_t granularity)
150 {
151 Assert(d_bvandTable.find(granularity) == d_bvandTable.end());
152 // the table was not yet computed
153 std::map<std::pair<int64_t, int64_t>, uint64_t> table;
154 uint64_t num_of_values = ((uint64_t)pow(2, granularity));
155 // populate the table with all the values
156 for (uint64_t i = 0; i < num_of_values; i++)
157 {
158 for (uint64_t j = 0; j < num_of_values; j++)
159 {
160 // compute
161 // (bv_to_int (bvand ((int_to_bv granularity) i) ((int_to_bv granularity)
162 // j)))
163 int64_t sum = 0;
164 for (uint64_t n = 0; n < granularity; n++)
165 {
166 // b is the result of f on the current bit
167 bool b = oneBitAnd((((i >> n) & 1) == 1), (((j >> n) & 1) == 1));
168 // add the corresponding power of 2 only if the result is 1
169 if (b)
170 {
171 sum += 1 << n;
172 }
173 }
174 table[std::make_pair(i, j)] = sum;
175 }
176 }
177 // optimize the table by identifying and adding the default value
178 addDefaultValue(table, num_of_values);
179 Assert(table.size() == 1 + (static_cast<uint64_t>(num_of_values * num_of_values)));
180 // store the table in the cache and return it
181 d_bvandTable[granularity] = table;
182 }
183
184 void IAndTable::addDefaultValue(
185 std::map<std::pair<int64_t, int64_t>, uint64_t>& table,
186 uint64_t num_of_values)
187 {
188 // map each result to the number of times it occurs
189 std::map<uint64_t, uint64_t> counters;
190 for (uint64_t i = 0; i <= num_of_values; i++)
191 {
192 counters[i] = 0;
193 }
194 for (const std::pair<std::pair<int64_t, int64_t>, uint64_t>& element : table)
195 {
196 uint64_t result = element.second;
197 counters[result]++;
198 }
199
200 // compute the most common result
201 uint64_t most_common_result = 0;
202 uint64_t max_num_of_occ = 0;
203 for (uint64_t i = 0; i <= num_of_values; i++)
204 {
205 if (counters[i] >= max_num_of_occ)
206 {
207 max_num_of_occ = counters[i];
208 most_common_result = i;
209 }
210 }
211 // sanity check: some value appears at least once.
212 Assert(max_num_of_occ != 0);
213
214 // -1 is the default case of the table.
215 // add it to the table
216 table[std::make_pair(-1, -1)] = most_common_result;
217 }
218
219 } // namespace nl
220 } // namespace arith
221 } // namespace theory
222 } // namespace CVC4