Add n-ary match trie utility (#6909)
[cvc5.git] / src / expr / nary_match_trie.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
 * Implements an n-ary match trie.
14 */
15
#include "expr/nary_match_trie.h"

#include <algorithm>
#include <map>
#include <sstream>
#include <tuple>
#include <utility>
#include <vector>

#include "expr/nary_term_util.h"
20
21 using namespace cvc5::kind;
22
23 namespace cvc5 {
24 namespace expr {
25
26 class NaryMatchFrame
27 {
28 public:
29 NaryMatchFrame(const std::vector<Node>& syms, const NaryMatchTrie* t)
30 : d_syms(syms), d_trie(t), d_index(0), d_variant(0), d_boundVar(false)
31 {
32 }
33 /** Symbols to match */
34 std::vector<Node> d_syms;
35 /** The match trie */
36 const NaryMatchTrie* d_trie;
37 /** The index we are considering, 0 = operator, n>0 = variable # (n-1) */
38 size_t d_index;
39 /** List length considering */
40 size_t d_variant;
41 /** Whether we just bound a variable */
42 bool d_boundVar;
43 };
44
45 bool NaryMatchTrie::getMatches(Node n, NotifyMatch* ntm) const
46 {
47 NodeManager* nm = NodeManager::currentNM();
48 std::vector<Node> vars;
49 std::vector<Node> subs;
50 std::map<Node, Node> smap;
51
52 std::map<Node, NaryMatchTrie>::const_iterator itc;
53
54 std::vector<NaryMatchFrame> visit;
55 visit.push_back(NaryMatchFrame({n}, this));
56
57 while (!visit.empty())
58 {
59 NaryMatchFrame& curr = visit.back();
60 // currently, copy the symbols from previous frame TODO: improve?
61 std::vector<Node> syms = curr.d_syms;
62 const NaryMatchTrie* mt = curr.d_trie;
63 if (syms.empty())
64 {
65 // if we matched, there must be a data member at this node
66 Assert(!mt->d_data.isNull());
67 // notify match?
68 Assert(n == expr::narySubstitute(mt->d_data, vars, subs));
69 Trace("match-debug") << "notify : " << mt->d_data << std::endl;
70 if (!ntm->notify(n, mt->d_data, vars, subs))
71 {
72 return false;
73 }
74 visit.pop_back();
75 continue;
76 }
77
78 // clean up if we previously bound a variable
79 if (curr.d_boundVar)
80 {
81 Assert(!vars.empty());
82 Assert(smap.find(vars.back()) != smap.end());
83 smap.erase(vars.back());
84 vars.pop_back();
85 subs.pop_back();
86 curr.d_boundVar = false;
87 }
88
89 if (curr.d_index == 0)
90 {
91 curr.d_index++;
92 // finished matching variables, try to match the operator
93 Node next = syms.back();
94 Node op =
95 (!next.isNull() && next.hasOperator()) ? next.getOperator() : next;
96 itc = mt->d_children.find(op);
97 if (itc != mt->d_children.end())
98 {
99 syms.pop_back();
100 // push the children + null termination marker, in reverse order
101 if (NodeManager::isNAryKind(next.getKind()))
102 {
103 syms.push_back(Node::null());
104 }
105 if (next.hasOperator())
106 {
107 syms.insert(syms.end(), next.rbegin(), next.rend());
108 }
109 // new frame
110 visit.push_back(NaryMatchFrame(syms, &itc->second));
111 }
112 }
113 else if (curr.d_index <= mt->d_vars.size())
114 {
115 // try to match the next (variable, length)
116 Node var;
117 Node next;
118 do
119 {
120 var = mt->d_vars[curr.d_index - 1];
121 Assert(mt->d_children.find(var) != mt->d_children.end());
122 std::vector<Node> currChildren;
123 if (isListVar(var))
124 {
125 // get the length of the list we want to consider
126 size_t l = curr.d_variant;
127 curr.d_variant++;
128 // match with l, or increment d_index otherwise
129 bool foundChildren = true;
130 // We are in a state where the children of an n-ary child
131 // have been pused to syms. We try to extract l children here. If
132 // we encounter the null symbol, then we do not have sufficient
133 // children to match for this variant and fail.
134 for (size_t i = 0; i < l; i++)
135 {
136 Assert(!syms.empty());
137 Node s = syms.back();
138 if (s.isNull())
139 {
140 foundChildren = false;
141 break;
142 }
143 currChildren.push_back(s);
144 syms.pop_back();
145 }
146 if (foundChildren)
147 {
148 // we are matching the next list
149 next = nm->mkNode(SEXPR, currChildren);
150 }
151 else
152 {
153 // otherwise, we have run out of variants, go to next variable
154 curr.d_index++;
155 curr.d_variant = 0;
156 }
157 }
158 else
159 {
160 next = syms.back();
161 curr.d_index++;
162 // we could be at the end of an n-ary operator, in which case we
163 // do not match
164 if (!next.isNull())
165 {
166 currChildren.push_back(next);
167 syms.pop_back();
168 // check subtyping in the (non-list) case
169 if (!var.getType().isSubtypeOf(next.getType()))
170 {
171 next = Node::null();
172 }
173 }
174 }
175 if (!next.isNull())
176 {
177 // check if it is already bound, do the binding if necessary
178 std::map<Node, Node>::iterator its = smap.find(var);
179 if (its != smap.end())
180 {
181 if (its->second != next)
182 {
183 // failed to match
184 next = Node::null();
185 }
186 // otherwise, successfully matched, nothing to do
187 }
188 else
189 {
190 // add to binding
191 vars.push_back(var);
192 subs.push_back(next);
193 smap[var] = next;
194 curr.d_boundVar = true;
195 }
196 }
197 if (next.isNull())
198 {
199 // if we failed, revert changes to syms
200 syms.insert(syms.end(), currChildren.rbegin(), currChildren.rend());
201 }
202 } while (next.isNull() && curr.d_index <= mt->d_vars.size());
203 if (next.isNull())
204 {
205 // we are out of variables to match, finished with this frame
206 visit.pop_back();
207 continue;
208 }
209 Trace("match-debug") << "recurse var : " << var << std::endl;
210 itc = mt->d_children.find(var);
211 Assert(itc != mt->d_children.end());
212 visit.push_back(NaryMatchFrame(syms, &itc->second));
213 }
214 else
215 {
216 // no variables to match, we are done
217 visit.pop_back();
218 }
219 }
220 return true;
221 }
222
223 void NaryMatchTrie::addTerm(Node n)
224 {
225 Assert(!n.isNull());
226 std::vector<Node> visit;
227 visit.push_back(n);
228 NaryMatchTrie* curr = this;
229 while (!visit.empty())
230 {
231 Node cn = visit.back();
232 visit.pop_back();
233 if (cn.isNull())
234 {
235 curr = &(curr->d_children[cn]);
236 }
237 else if (cn.hasOperator())
238 {
239 curr = &(curr->d_children[cn.getOperator()]);
240 // add null terminator if an n-ary kind
241 if (NodeManager::isNAryKind(cn.getKind()))
242 {
243 visit.push_back(Node::null());
244 }
245 // note children are processed left to right
246 visit.insert(visit.end(), cn.rbegin(), cn.rend());
247 }
248 else
249 {
250 if (cn.isVar()
251 && std::find(curr->d_vars.begin(), curr->d_vars.end(), cn)
252 == curr->d_vars.end())
253 {
254 curr->d_vars.push_back(cn);
255 }
256 curr = &(curr->d_children[cn]);
257 }
258 }
259 curr->d_data = n;
260 }
261
262 void NaryMatchTrie::clear()
263 {
264 d_children.clear();
265 d_vars.clear();
266 d_data = Node::null();
267 }
268
269 std::string NaryMatchTrie::debugPrint() const
270 {
271 std::stringstream ss;
272 std::vector<std::tuple<const NaryMatchTrie*, size_t, Node>> visit;
273 visit.emplace_back(this, 0, Node::null());
274 do
275 {
276 std::tuple<const NaryMatchTrie*, size_t, Node> curr = visit.back();
277 visit.pop_back();
278 size_t indent = std::get<1>(curr);
279 for (size_t i = 0; i < indent; i++)
280 {
281 ss << " ";
282 }
283 Node n = std::get<2>(curr);
284 if (indent == 0)
285 {
286 ss << ".";
287 }
288 else
289 {
290 ss << n;
291 }
292 ss << ((!n.isNull() && isListVar(n)) ? " [*]" : "");
293 ss << std::endl;
294 const NaryMatchTrie* mt = std::get<0>(curr);
295 for (const std::pair<const Node, NaryMatchTrie>& c : mt->d_children)
296 {
297 visit.emplace_back(&c.second, indent + 1, c.first);
298 }
299 } while (!visit.empty());
300 return ss.str();
301 }
302
303 } // namespace expr
304 } // namespace cvc5