Update copyright headers.
[cvc5.git] / src / expr / symbol_table.cpp
1 /********************* */
2 /*! \file symbol_table.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Tim King, Morgan Deters
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Convenience class for scoping variable and type
13 ** declarations (implementation)
14 **
15 ** Convenience class for scoping variable and type declarations
16 ** (implementation).
17 **/
18
19 #include "expr/symbol_table.h"
20
21 #include <ostream>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25
26 #include "context/cdhashmap.h"
27 #include "context/cdhashset.h"
28 #include "context/context.h"
29 #include "expr/expr.h"
30 #include "expr/expr_manager_scope.h"
31 #include "expr/type.h"
32
33 namespace CVC4 {
34
35 using ::CVC4::context::CDHashMap;
36 using ::CVC4::context::CDHashSet;
37 using ::CVC4::context::Context;
38 using ::std::copy;
39 using ::std::endl;
40 using ::std::ostream_iterator;
41 using ::std::pair;
42 using ::std::string;
43 using ::std::vector;
44
45 /** Overloaded type trie.
46 *
47 * This data structure stores a trie of expressions with
48 * the same name, and must be distinguished by their argument types.
49 * It is context-dependent.
50 *
51 * Using the argument allowFunVariants,
52 * it may either be configured to allow function variants or not,
53 * where a function variant is a function that expects the same
54 * argument types as another.
55 *
56 * For example, the following definitions introduce function
57 * variants for the symbol f:
58 *
59 * 1. (declare-fun f (Int) Int) and
60 * (declare-fun f (Int) Bool)
61 *
62 * 2. (declare-fun f (Int) Int) and
63 * (declare-fun f (Int) Int)
64 *
65 * 3. (declare-datatypes ((Tup 0)) ((f (data Int)))) and
66 * (declare-fun f (Int) Tup)
67 *
68 * 4. (declare-datatypes ((Tup 0)) ((mkTup (f Int)))) and
69 * (declare-fun f (Tup) Bool)
70 *
71 * If function variants is set to true, we allow function variants
72 * but not function redefinition. In examples 2 and 3, f is
73 * declared twice as a symbol of identical argument and range
74 * types. We never accept these definitions. However, we do
75 * allow examples 1 and 4 above when allowFunVariants is true.
76 *
77 * For 0-argument functions (constants), we always allow
78 * function variants. That is, we always accept these examples:
79 *
80 * 5. (declare-fun c () Int)
81 * (declare-fun c () Bool)
82 *
83 * 6. (declare-datatypes ((Enum 0)) ((c)))
84 * (declare-fun c () Int)
85 *
86 * and always reject constant redefinition such as:
87 *
88 * 7. (declare-fun c () Int)
89 * (declare-fun c () Int)
90 *
91 * 8. (declare-datatypes ((Enum 0)) ((c))) and
92 * (declare-fun c () Enum)
93 */
94 class OverloadedTypeTrie {
95 public:
96 OverloadedTypeTrie(Context* c, bool allowFunVariants = false)
97 : d_overloaded_symbols(new (true) CDHashSet<Expr, ExprHashFunction>(c)),
98 d_allowFunctionVariants(allowFunVariants)
99 {
100 }
101 ~OverloadedTypeTrie() { d_overloaded_symbols->deleteSelf(); }
102
103 /** is this function overloaded? */
104 bool isOverloadedFunction(Expr fun) const;
105
106 /** Get overloaded constant for type.
107 * If possible, it returns a defined symbol with name
108 * that has type t. Otherwise returns null expression.
109 */
110 Expr getOverloadedConstantForType(const std::string& name, Type t) const;
111
112 /**
113 * If possible, returns a defined function for a name
114 * and a vector of expected argument types. Otherwise returns
115 * null expression.
116 */
117 Expr getOverloadedFunctionForTypes(const std::string& name,
118 const std::vector<Type>& argTypes) const;
119 /** called when obj is bound to name, and prev_bound_obj was already bound to
120 * name Returns false if the binding is invalid.
121 */
122 bool bind(const string& name, Expr prev_bound_obj, Expr obj);
123
124 private:
125 /** Marks expression obj with name as overloaded.
126 * Adds relevant information to the type arg trie data structure.
127 * It returns false if there is already an expression bound to that name
128 * whose type expects the same arguments as the type of obj but is not
129 * identical to the type of obj. For example, if we declare :
130 *
131 * (declare-datatypes () ((List (cons (hd Int) (tl List)) (nil))))
132 * (declare-fun cons (Int List) List)
133 *
134 * cons : constructor_type( Int, List, List )
135 * cons : function_type( Int, List, List )
136 *
137 * These are put in the same place in the trie but do not have identical type,
138 * hence we return false.
139 */
140 bool markOverloaded(const string& name, Expr obj);
141 /** the null expression */
142 Expr d_nullExpr;
143 // The (context-independent) trie storing that maps expected argument
144 // vectors to symbols. All expressions stored in d_symbols are only
145 // interpreted as active if they also appear in the context-dependent
146 // set d_overloaded_symbols.
147 class TypeArgTrie {
148 public:
149 // children of this node
150 std::map<Type, TypeArgTrie> d_children;
151 // symbols at this node
152 std::map<Type, Expr> d_symbols;
153 };
154 /** for each string with operator overloading, this stores the data structure
155 * above. */
156 std::unordered_map<std::string, TypeArgTrie> d_overload_type_arg_trie;
157 /** The set of overloaded symbols. */
158 CDHashSet<Expr, ExprHashFunction>* d_overloaded_symbols;
159 /** allow function variants
160 * This is true if we allow overloading (non-constant) functions that expect
161 * the same argument types.
162 */
163 bool d_allowFunctionVariants;
164 /** get unique overloaded function
165 * If tat->d_symbols contains an active overloaded function, it
166 * returns that function, where that function must be unique
167 * if reqUnique=true.
168 * Otherwise, it returns the null expression.
169 */
170 Expr getOverloadedFunctionAt(const TypeArgTrie* tat, bool reqUnique=true) const;
171 };
172
173 bool OverloadedTypeTrie::isOverloadedFunction(Expr fun) const {
174 return d_overloaded_symbols->find(fun) != d_overloaded_symbols->end();
175 }
176
177 Expr OverloadedTypeTrie::getOverloadedConstantForType(const std::string& name,
178 Type t) const {
179 std::unordered_map<std::string, TypeArgTrie>::const_iterator it =
180 d_overload_type_arg_trie.find(name);
181 if (it != d_overload_type_arg_trie.end()) {
182 std::map<Type, Expr>::const_iterator its = it->second.d_symbols.find(t);
183 if (its != it->second.d_symbols.end()) {
184 Expr expr = its->second;
185 // must be an active symbol
186 if (isOverloadedFunction(expr)) {
187 return expr;
188 }
189 }
190 }
191 return d_nullExpr;
192 }
193
194 Expr OverloadedTypeTrie::getOverloadedFunctionForTypes(
195 const std::string& name, const std::vector<Type>& argTypes) const {
196 std::unordered_map<std::string, TypeArgTrie>::const_iterator it =
197 d_overload_type_arg_trie.find(name);
198 if (it != d_overload_type_arg_trie.end()) {
199 const TypeArgTrie* tat = &it->second;
200 for (unsigned i = 0; i < argTypes.size(); i++) {
201 std::map<Type, TypeArgTrie>::const_iterator itc =
202 tat->d_children.find(argTypes[i]);
203 if (itc != tat->d_children.end()) {
204 tat = &itc->second;
205 } else {
206 Trace("parser-overloading")
207 << "Could not find overloaded function " << name << std::endl;
208 // it may be a parametric datatype
209 TypeNode tna = TypeNode::fromType(argTypes[i]);
210 if (tna.isParametricDatatype())
211 {
212 Trace("parser-overloading")
213 << "Parametric overloaded datatype selector " << name << " "
214 << tna << std::endl;
215 DatatypeType tnd = static_cast<DatatypeType>(argTypes[i]);
216 const Datatype& dt = tnd.getDatatype();
217 // tng is the "generalized" version of the instantiated parametric
218 // type tna
219 Type tng = dt.getDatatypeType();
220 itc = tat->d_children.find(tng);
221 if (itc != tat->d_children.end())
222 {
223 tat = &itc->second;
224 }
225 }
226 if (tat == nullptr)
227 {
228 // no functions match
229 return d_nullExpr;
230 }
231 }
232 }
233 // we ensure that there is *only* one active symbol at this node
234 return getOverloadedFunctionAt(tat);
235 }
236 return d_nullExpr;
237 }
238
239 bool OverloadedTypeTrie::bind(const string& name, Expr prev_bound_obj,
240 Expr obj) {
241 bool retprev = true;
242 if (!isOverloadedFunction(prev_bound_obj)) {
243 // mark previous as overloaded
244 retprev = markOverloaded(name, prev_bound_obj);
245 }
246 // mark this as overloaded
247 bool retobj = markOverloaded(name, obj);
248 return retprev && retobj;
249 }
250
251 bool OverloadedTypeTrie::markOverloaded(const string& name, Expr obj) {
252 Trace("parser-overloading") << "Overloaded function : " << name;
253 Trace("parser-overloading") << " with type " << obj.getType() << std::endl;
254 // get the argument types
255 Type t = obj.getType();
256 Type rangeType = t;
257 std::vector<Type> argTypes;
258 if (t.isFunction()) {
259 argTypes = static_cast<FunctionType>(t).getArgTypes();
260 rangeType = static_cast<FunctionType>(t).getRangeType();
261 } else if (t.isConstructor()) {
262 argTypes = static_cast<ConstructorType>(t).getArgTypes();
263 rangeType = static_cast<ConstructorType>(t).getRangeType();
264 } else if (t.isTester()) {
265 argTypes.push_back(static_cast<TesterType>(t).getDomain());
266 rangeType = static_cast<TesterType>(t).getRangeType();
267 } else if (t.isSelector()) {
268 argTypes.push_back(static_cast<SelectorType>(t).getDomain());
269 rangeType = static_cast<SelectorType>(t).getRangeType();
270 }
271 // add to the trie
272 TypeArgTrie* tat = &d_overload_type_arg_trie[name];
273 for (unsigned i = 0; i < argTypes.size(); i++) {
274 tat = &(tat->d_children[argTypes[i]]);
275 }
276
277 // check if function variants are allowed here
278 if (d_allowFunctionVariants || argTypes.empty())
279 {
280 // they are allowed, check for redefinition
281 std::map<Type, Expr>::iterator it = tat->d_symbols.find(rangeType);
282 if (it != tat->d_symbols.end())
283 {
284 Expr prev_obj = it->second;
285 // if there is already an active function with the same name and expects
286 // the same argument types and has the same return type, we reject the
287 // re-declaration here.
288 if (isOverloadedFunction(prev_obj))
289 {
290 return false;
291 }
292 }
293 }
294 else
295 {
296 // they are not allowed, we cannot have any function defined here.
297 Expr existingFun = getOverloadedFunctionAt(tat, false);
298 if (!existingFun.isNull())
299 {
300 return false;
301 }
302 }
303
304 // otherwise, update the symbols
305 d_overloaded_symbols->insert(obj);
306 tat->d_symbols[rangeType] = obj;
307 return true;
308 }
309
310 Expr OverloadedTypeTrie::getOverloadedFunctionAt(
311 const OverloadedTypeTrie::TypeArgTrie* tat, bool reqUnique) const
312 {
313 Expr retExpr;
314 for (std::map<Type, Expr>::const_iterator its = tat->d_symbols.begin();
315 its != tat->d_symbols.end();
316 ++its)
317 {
318 Expr expr = its->second;
319 if (isOverloadedFunction(expr))
320 {
321 if (retExpr.isNull())
322 {
323 if (!reqUnique)
324 {
325 return expr;
326 }
327 else
328 {
329 retExpr = expr;
330 }
331 }
332 else
333 {
334 // multiple functions match
335 return d_nullExpr;
336 }
337 }
338 }
339 return retExpr;
340 }
341
342 class SymbolTable::Implementation {
343 public:
344 Implementation()
345 : d_context(),
346 d_exprMap(new (true) CDHashMap<string, Expr>(&d_context)),
347 d_typeMap(new (true) TypeMap(&d_context))
348 {
349 d_overload_trie = new OverloadedTypeTrie(&d_context);
350 }
351
352 ~Implementation() {
353 d_exprMap->deleteSelf();
354 d_typeMap->deleteSelf();
355 delete d_overload_trie;
356 }
357
358 bool bind(const string& name, Expr obj, bool levelZero, bool doOverload);
359 void bindType(const string& name, Type t, bool levelZero = false);
360 void bindType(const string& name, const vector<Type>& params, Type t,
361 bool levelZero = false);
362 bool isBound(const string& name) const;
363 bool isBoundType(const string& name) const;
364 Expr lookup(const string& name) const;
365 Type lookupType(const string& name) const;
366 Type lookupType(const string& name, const vector<Type>& params) const;
367 size_t lookupArity(const string& name);
368 void popScope();
369 void pushScope();
370 size_t getLevel() const;
371 void reset();
372 //------------------------ operator overloading
373 /** implementation of function from header */
374 bool isOverloadedFunction(Expr fun) const;
375
376 /** implementation of function from header */
377 Expr getOverloadedConstantForType(const std::string& name, Type t) const;
378
379 /** implementation of function from header */
380 Expr getOverloadedFunctionForTypes(const std::string& name,
381 const std::vector<Type>& argTypes) const;
382 //------------------------ end operator overloading
383 private:
384 /** The context manager for the scope maps. */
385 Context d_context;
386
387 /** A map for expressions. */
388 CDHashMap<string, Expr>* d_exprMap;
389
390 /** A map for types. */
391 using TypeMap = CDHashMap<string, std::pair<vector<Type>, Type>>;
392 TypeMap* d_typeMap;
393
394 //------------------------ operator overloading
395 // the null expression
396 Expr d_nullExpr;
397 // overloaded type trie, stores all information regarding overloading
398 OverloadedTypeTrie* d_overload_trie;
399 /** bind with overloading
400 * This is called whenever obj is bound to name where overloading symbols is
401 * allowed. If a symbol is previously bound to that name, it marks both as
402 * overloaded. Returns false if the binding was invalid.
403 */
404 bool bindWithOverloading(const string& name, Expr obj);
405 //------------------------ end operator overloading
406 }; /* SymbolTable::Implementation */
407
408 bool SymbolTable::Implementation::bind(const string& name, Expr obj,
409 bool levelZero, bool doOverload) {
410 PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null Expr");
411 ExprManagerScope ems(obj);
412 if (doOverload) {
413 if (!bindWithOverloading(name, obj)) {
414 return false;
415 }
416 }
417 if (levelZero) {
418 d_exprMap->insertAtContextLevelZero(name, obj);
419 } else {
420 d_exprMap->insert(name, obj);
421 }
422 return true;
423 }
424
425 bool SymbolTable::Implementation::isBound(const string& name) const {
426 return d_exprMap->find(name) != d_exprMap->end();
427 }
428
429 Expr SymbolTable::Implementation::lookup(const string& name) const {
430 Assert(isBound(name));
431 Expr expr = (*d_exprMap->find(name)).second;
432 if (isOverloadedFunction(expr)) {
433 return d_nullExpr;
434 } else {
435 return expr;
436 }
437 }
438
439 void SymbolTable::Implementation::bindType(const string& name, Type t,
440 bool levelZero) {
441 if (levelZero) {
442 d_typeMap->insertAtContextLevelZero(name, make_pair(vector<Type>(), t));
443 } else {
444 d_typeMap->insert(name, make_pair(vector<Type>(), t));
445 }
446 }
447
448 void SymbolTable::Implementation::bindType(const string& name,
449 const vector<Type>& params, Type t,
450 bool levelZero) {
451 if (Debug.isOn("sort")) {
452 Debug("sort") << "bindType(" << name << ", [";
453 if (params.size() > 0) {
454 copy(params.begin(), params.end() - 1,
455 ostream_iterator<Type>(Debug("sort"), ", "));
456 Debug("sort") << params.back();
457 }
458 Debug("sort") << "], " << t << ")" << endl;
459 }
460 if (levelZero) {
461 d_typeMap->insertAtContextLevelZero(name, make_pair(params, t));
462 } else {
463 d_typeMap->insert(name, make_pair(params, t));
464 }
465 }
466
467 bool SymbolTable::Implementation::isBoundType(const string& name) const {
468 return d_typeMap->find(name) != d_typeMap->end();
469 }
470
471 Type SymbolTable::Implementation::lookupType(const string& name) const {
472 pair<vector<Type>, Type> p = (*d_typeMap->find(name)).second;
473 PrettyCheckArgument(p.first.size() == 0, name,
474 "type constructor arity is wrong: "
475 "`%s' requires %u parameters but was provided 0",
476 name.c_str(), p.first.size());
477 return p.second;
478 }
479
480 Type SymbolTable::Implementation::lookupType(const string& name,
481 const vector<Type>& params) const {
482 pair<vector<Type>, Type> p = (*d_typeMap->find(name)).second;
483 PrettyCheckArgument(p.first.size() == params.size(), params,
484 "type constructor arity is wrong: "
485 "`%s' requires %u parameters but was provided %u",
486 name.c_str(), p.first.size(), params.size());
487 if (p.first.size() == 0) {
488 PrettyCheckArgument(p.second.isSort(), name.c_str());
489 return p.second;
490 }
491 if (p.second.isSortConstructor()) {
492 if (Debug.isOn("sort")) {
493 Debug("sort") << "instantiating using a sort constructor" << endl;
494 Debug("sort") << "have formals [";
495 copy(p.first.begin(), p.first.end() - 1,
496 ostream_iterator<Type>(Debug("sort"), ", "));
497 Debug("sort") << p.first.back() << "]" << endl << "parameters [";
498 copy(params.begin(), params.end() - 1,
499 ostream_iterator<Type>(Debug("sort"), ", "));
500 Debug("sort") << params.back() << "]" << endl
501 << "type ctor " << name << endl
502 << "type is " << p.second << endl;
503 }
504
505 Type instantiation = SortConstructorType(p.second).instantiate(params);
506
507 Debug("sort") << "instance is " << instantiation << endl;
508
509 return instantiation;
510 } else if (p.second.isDatatype()) {
511 PrettyCheckArgument(DatatypeType(p.second).isParametric(), name,
512 "expected parametric datatype");
513 return DatatypeType(p.second).instantiate(params);
514 } else {
515 if (Debug.isOn("sort")) {
516 Debug("sort") << "instantiating using a sort substitution" << endl;
517 Debug("sort") << "have formals [";
518 copy(p.first.begin(), p.first.end() - 1,
519 ostream_iterator<Type>(Debug("sort"), ", "));
520 Debug("sort") << p.first.back() << "]" << endl << "parameters [";
521 copy(params.begin(), params.end() - 1,
522 ostream_iterator<Type>(Debug("sort"), ", "));
523 Debug("sort") << params.back() << "]" << endl
524 << "type ctor " << name << endl
525 << "type is " << p.second << endl;
526 }
527
528 Type instantiation = p.second.substitute(p.first, params);
529
530 Debug("sort") << "instance is " << instantiation << endl;
531
532 return instantiation;
533 }
534 }
535
536 size_t SymbolTable::Implementation::lookupArity(const string& name) {
537 pair<vector<Type>, Type> p = (*d_typeMap->find(name)).second;
538 return p.first.size();
539 }
540
541 void SymbolTable::Implementation::popScope() {
542 if (d_context.getLevel() == 0) {
543 throw ScopeException();
544 }
545 d_context.pop();
546 }
547
548 void SymbolTable::Implementation::pushScope() { d_context.push(); }
549
550 size_t SymbolTable::Implementation::getLevel() const {
551 return d_context.getLevel();
552 }
553
554 void SymbolTable::Implementation::reset() {
555 this->SymbolTable::Implementation::~Implementation();
556 new (this) SymbolTable::Implementation();
557 }
558
559 bool SymbolTable::Implementation::isOverloadedFunction(Expr fun) const {
560 return d_overload_trie->isOverloadedFunction(fun);
561 }
562
563 Expr SymbolTable::Implementation::getOverloadedConstantForType(
564 const std::string& name, Type t) const {
565 return d_overload_trie->getOverloadedConstantForType(name, t);
566 }
567
568 Expr SymbolTable::Implementation::getOverloadedFunctionForTypes(
569 const std::string& name, const std::vector<Type>& argTypes) const {
570 return d_overload_trie->getOverloadedFunctionForTypes(name, argTypes);
571 }
572
573 bool SymbolTable::Implementation::bindWithOverloading(const string& name,
574 Expr obj) {
575 CDHashMap<string, Expr>::const_iterator it = d_exprMap->find(name);
576 if (it != d_exprMap->end()) {
577 const Expr& prev_bound_obj = (*it).second;
578 if (prev_bound_obj != obj) {
579 return d_overload_trie->bind(name, prev_bound_obj, obj);
580 }
581 }
582 return true;
583 }
584
585 bool SymbolTable::isOverloadedFunction(Expr fun) const {
586 return d_implementation->isOverloadedFunction(fun);
587 }
588
589 Expr SymbolTable::getOverloadedConstantForType(const std::string& name,
590 Type t) const {
591 return d_implementation->getOverloadedConstantForType(name, t);
592 }
593
594 Expr SymbolTable::getOverloadedFunctionForTypes(
595 const std::string& name, const std::vector<Type>& argTypes) const {
596 return d_implementation->getOverloadedFunctionForTypes(name, argTypes);
597 }
598
599 SymbolTable::SymbolTable()
600 : d_implementation(new SymbolTable::Implementation()) {}
601
602 SymbolTable::~SymbolTable() {}
603 bool SymbolTable::bind(const string& name,
604 Expr obj,
605 bool levelZero,
606 bool doOverload)
607 {
608 return d_implementation->bind(name, obj, levelZero, doOverload);
609 }
610
611 void SymbolTable::bindType(const string& name, Type t, bool levelZero)
612 {
613 d_implementation->bindType(name, t, levelZero);
614 }
615
616 void SymbolTable::bindType(const string& name,
617 const vector<Type>& params,
618 Type t,
619 bool levelZero)
620 {
621 d_implementation->bindType(name, params, t, levelZero);
622 }
623
624 bool SymbolTable::isBound(const string& name) const
625 {
626 return d_implementation->isBound(name);
627 }
628 bool SymbolTable::isBoundType(const string& name) const
629 {
630 return d_implementation->isBoundType(name);
631 }
632 Expr SymbolTable::lookup(const string& name) const
633 {
634 return d_implementation->lookup(name);
635 }
636 Type SymbolTable::lookupType(const string& name) const
637 {
638 return d_implementation->lookupType(name);
639 }
640
641 Type SymbolTable::lookupType(const string& name,
642 const vector<Type>& params) const
643 {
644 return d_implementation->lookupType(name, params);
645 }
646 size_t SymbolTable::lookupArity(const string& name) {
647 return d_implementation->lookupArity(name);
648 }
649 void SymbolTable::popScope() { d_implementation->popScope(); }
650 void SymbolTable::pushScope() { d_implementation->pushScope(); }
651 size_t SymbolTable::getLevel() const { return d_implementation->getLevel(); }
652 void SymbolTable::reset() { d_implementation->reset(); }
653
654 } // namespace CVC4