Update symbol table to support operator overloading (#1154)
[cvc5.git] / src / expr / symbol_table.cpp
1 /********************* */
2 /*! \file symbol_table.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Morgan Deters, Christopher L. Conway, Francois Bobot
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Convenience class for scoping variable and type
13 ** declarations (implementation)
14 **
15 ** Convenience class for scoping variable and type declarations
16 ** (implementation).
17 **/
18
#include "expr/symbol_table.h"

#include <algorithm>
#include <iterator>
#include <map>
#include <new>
#include <ostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "context/cdhashmap.h"
#include "context/cdhashset.h"
#include "context/context.h"
#include "expr/expr.h"
#include "expr/expr_manager_scope.h"
#include "expr/type.h"
32
33 namespace CVC4 {
34
35 using ::CVC4::context::CDHashMap;
36 using ::CVC4::context::CDHashSet;
37 using ::CVC4::context::Context;
38 using ::std::copy;
39 using ::std::endl;
40 using ::std::ostream_iterator;
41 using ::std::pair;
42 using ::std::string;
43 using ::std::vector;
44
45 // This data structure stores a trie of expressions with
46 // the same name, and must be distinguished by their argument types.
47 // It is context-dependent.
48 class OverloadedTypeTrie
49 {
50 public:
51 OverloadedTypeTrie(Context * c ) :
52 d_overloaded_symbols(new (true) CDHashSet<Expr, ExprHashFunction>(c)) {
53 }
54 ~OverloadedTypeTrie() {
55 d_overloaded_symbols->deleteSelf();
56 }
57 /** is this function overloaded? */
58 bool isOverloadedFunction(Expr fun) const;
59
60 /** Get overloaded constant for type.
61 * If possible, it returns a defined symbol with name
62 * that has type t. Otherwise returns null expression.
63 */
64 Expr getOverloadedConstantForType(const std::string& name, Type t) const;
65
66 /**
67 * If possible, returns a defined function for a name
68 * and a vector of expected argument types. Otherwise returns
69 * null expression.
70 */
71 Expr getOverloadedFunctionForTypes(const std::string& name, const std::vector< Type >& argTypes) const;
72 /** called when obj is bound to name, and prev_bound_obj was already bound to name
73 * Returns false if the binding is invalid.
74 */
75 bool bind(const string& name, Expr prev_bound_obj, Expr obj);
76 private:
77 /** Marks expression obj with name as overloaded.
78 * Adds relevant information to the type arg trie data structure.
79 * It returns false if there is already an expression bound to that name
80 * whose type expects the same arguments as the type of obj but is not identical
81 * to the type of obj. For example, if we declare :
82 *
83 * (declare-datatypes () ((List (cons (hd Int) (tl List)) (nil))))
84 * (declare-fun cons (Int List) List)
85 *
86 * cons : constructor_type( Int, List, List )
87 * cons : function_type( Int, List, List )
88 *
89 * These are put in the same place in the trie but do not have identical type,
90 * hence we return false.
91 */
92 bool markOverloaded(const string& name, Expr obj);
93 /** the null expression */
94 Expr d_nullExpr;
95 // The (context-independent) trie storing that maps expected argument
96 // vectors to symbols. All expressions stored in d_symbols are only
97 // interpreted as active if they also appear in the context-dependent
98 // set d_overloaded_symbols.
99 class TypeArgTrie {
100 public:
101 // children of this node
102 std::map< Type, TypeArgTrie > d_children;
103 // symbols at this node
104 std::map< Type, Expr > d_symbols;
105 };
106 /** for each string with operator overloading, this stores the data structure above. */
107 std::unordered_map< std::string, TypeArgTrie > d_overload_type_arg_trie;
108 /** The set of overloaded symbols. */
109 CDHashSet<Expr, ExprHashFunction>* d_overloaded_symbols;
110 };
111
112 bool OverloadedTypeTrie::isOverloadedFunction(Expr fun) const {
113 return d_overloaded_symbols->find(fun)!=d_overloaded_symbols->end();
114 }
115
116 Expr OverloadedTypeTrie::getOverloadedConstantForType(const std::string& name, Type t) const {
117 std::unordered_map< std::string, TypeArgTrie >::const_iterator it = d_overload_type_arg_trie.find(name);
118 if(it!=d_overload_type_arg_trie.end()) {
119 std::map< Type, Expr >::const_iterator its = it->second.d_symbols.find(t);
120 if(its!=it->second.d_symbols.end()) {
121 Expr expr = its->second;
122 // must be an active symbol
123 if(isOverloadedFunction(expr)) {
124 return expr;
125 }
126 }
127 }
128 return d_nullExpr;
129 }
130
131 Expr OverloadedTypeTrie::getOverloadedFunctionForTypes(const std::string& name,
132 const std::vector< Type >& argTypes) const {
133 std::unordered_map< std::string, TypeArgTrie >::const_iterator it = d_overload_type_arg_trie.find(name);
134 if(it!=d_overload_type_arg_trie.end()) {
135 const TypeArgTrie * tat = &it->second;
136 for(unsigned i=0; i<argTypes.size(); i++) {
137 std::map< Type, TypeArgTrie >::const_iterator itc = tat->d_children.find(argTypes[i]);
138 if(itc!=tat->d_children.end()) {
139 tat = &itc->second;
140 }else{
141 // no functions match
142 return d_nullExpr;
143 }
144 }
145 // now, we must ensure that there is *only* one active symbol at this node
146 Expr retExpr;
147 for(std::map< Type, Expr >::const_iterator its = tat->d_symbols.begin(); its != tat->d_symbols.end(); ++its) {
148 Expr expr = its->second;
149 if(isOverloadedFunction(expr)) {
150 if(retExpr.isNull()) {
151 retExpr = expr;
152 }else{
153 // multiple functions match
154 return d_nullExpr;
155 }
156 }
157 }
158 return retExpr;
159 }
160 return d_nullExpr;
161 }
162
163 bool OverloadedTypeTrie::bind(const string& name, Expr prev_bound_obj, Expr obj) {
164 bool retprev = true;
165 if(!isOverloadedFunction(prev_bound_obj)) {
166 // mark previous as overloaded
167 retprev = markOverloaded(name, prev_bound_obj);
168 }
169 // mark this as overloaded
170 bool retobj = markOverloaded(name, obj);
171 return retprev && retobj;
172 }
173
174 bool OverloadedTypeTrie::markOverloaded(const string& name, Expr obj) {
175 Trace("parser-overloading") << "Overloaded function : " << name;
176 Trace("parser-overloading") << " with type " << obj.getType() << std::endl;
177 // get the argument types
178 Type t = obj.getType();
179 Type rangeType = t;
180 std::vector< Type > argTypes;
181 if(t.isFunction()) {
182 argTypes = static_cast<FunctionType>(t).getArgTypes();
183 rangeType = static_cast<FunctionType>(t).getRangeType();
184 }else if(t.isConstructor()) {
185 argTypes = static_cast<ConstructorType>(t).getArgTypes();
186 rangeType = static_cast<ConstructorType>(t).getRangeType();
187 }else if(t.isTester()) {
188 argTypes.push_back( static_cast<TesterType>(t).getDomain() );
189 rangeType = static_cast<TesterType>(t).getRangeType();
190 }else if(t.isSelector()) {
191 argTypes.push_back( static_cast<SelectorType>(t).getDomain() );
192 rangeType = static_cast<SelectorType>(t).getRangeType();
193 }
194 // add to the trie
195 TypeArgTrie * tat = &d_overload_type_arg_trie[name];
196 for(unsigned i=0; i<argTypes.size(); i++) {
197 tat = &(tat->d_children[argTypes[i]]);
198 }
199
200 // types can be identical but vary on the kind of the type, thus we must distinguish based on this
201 std::map< Type, Expr >::iterator it = tat->d_symbols.find( rangeType );
202 if( it!=tat->d_symbols.end() ){
203 Expr prev_obj = it->second;
204 // if there is already an active function with the same name and expects the same argument types
205 if( isOverloadedFunction(prev_obj) ){
206 if( prev_obj.getType()==obj.getType() ){
207 //types are identical, simply ignore it
208 return true;
209 }else{
210 //otherwise there is no way to distinguish these types, we return an error
211 return false;
212 }
213 }
214 }
215
216 // otherwise, update the symbols
217 d_overloaded_symbols->insert(obj);
218 tat->d_symbols[rangeType] = obj;
219 return true;
220 }
221
222
223 class SymbolTable::Implementation {
224 public:
225 Implementation()
226 : d_context(),
227 d_exprMap(new (true) CDHashMap<string, Expr>(&d_context)),
228 d_typeMap(new (true) TypeMap(&d_context)),
229 d_functions(new (true) CDHashSet<Expr, ExprHashFunction>(&d_context)){
230 d_overload_trie = new OverloadedTypeTrie(&d_context);
231 }
232
233 ~Implementation() {
234 d_exprMap->deleteSelf();
235 d_typeMap->deleteSelf();
236 d_functions->deleteSelf();
237 delete d_overload_trie;
238 }
239
240 bool bind(const string& name, Expr obj, bool levelZero, bool doOverload) throw();
241 bool bindDefinedFunction(const string& name, Expr obj,
242 bool levelZero, bool doOverload) throw();
243 void bindType(const string& name, Type t, bool levelZero = false) throw();
244 void bindType(const string& name, const vector<Type>& params, Type t,
245 bool levelZero = false) throw();
246 bool isBound(const string& name) const throw();
247 bool isBoundDefinedFunction(const string& name) const throw();
248 bool isBoundDefinedFunction(Expr func) const throw();
249 bool isBoundType(const string& name) const throw();
250 Expr lookup(const string& name) const throw();
251 Type lookupType(const string& name) const throw();
252 Type lookupType(const string& name, const vector<Type>& params) const throw();
253 size_t lookupArity(const string& name);
254 void popScope() throw(ScopeException);
255 void pushScope() throw();
256 size_t getLevel() const throw();
257 void reset();
258 //------------------------ operator overloading
259 /** implementation of function from header */
260 bool isOverloadedFunction(Expr fun) const;
261
262 /** implementation of function from header */
263 Expr getOverloadedConstantForType(const std::string& name, Type t) const;
264
265 /** implementation of function from header */
266 Expr getOverloadedFunctionForTypes(const std::string& name, const std::vector< Type >& argTypes) const;
267 //------------------------ end operator overloading
268 private:
269 /** The context manager for the scope maps. */
270 Context d_context;
271
272 /** A map for expressions. */
273 CDHashMap<string, Expr>* d_exprMap;
274
275 /** A map for types. */
276 using TypeMap = CDHashMap<string, std::pair<vector<Type>, Type>>;
277 TypeMap* d_typeMap;
278
279 /** A set of defined functions. */
280 CDHashSet<Expr, ExprHashFunction>* d_functions;
281
282 //------------------------ operator overloading
283 // the null expression
284 Expr d_nullExpr;
285 // overloaded type trie, stores all information regarding overloading
286 OverloadedTypeTrie * d_overload_trie;
287 /** bind with overloading
288 * This is called whenever obj is bound to name where overloading symbols is allowed.
289 * If a symbol is previously bound to that name, it marks both as overloaded.
290 * Returns false if the binding was invalid.
291 */
292 bool bindWithOverloading(const string& name, Expr obj);
293 //------------------------ end operator overloading
294 }; /* SymbolTable::Implementation */
295
296 bool SymbolTable::Implementation::bind(const string& name, Expr obj,
297 bool levelZero, bool doOverload) throw() {
298 PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null Expr");
299 ExprManagerScope ems(obj);
300 if (doOverload) {
301 if( !bindWithOverloading(name, obj) ){
302 return false;
303 }
304 }
305 if (levelZero) {
306 d_exprMap->insertAtContextLevelZero(name, obj);
307 } else {
308 d_exprMap->insert(name, obj);
309 }
310 return true;
311 }
312
313 bool SymbolTable::Implementation::bindDefinedFunction(const string& name,
314 Expr obj,
315 bool levelZero,
316 bool doOverload) throw() {
317 PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null Expr");
318 ExprManagerScope ems(obj);
319 if (doOverload) {
320 if( !bindWithOverloading(name, obj) ){
321 return false;
322 }
323 }
324 if (levelZero) {
325 d_exprMap->insertAtContextLevelZero(name, obj);
326 d_functions->insertAtContextLevelZero(obj);
327 } else {
328 d_exprMap->insert(name, obj);
329 d_functions->insert(obj);
330 }
331 return true;
332 }
333
334 bool SymbolTable::Implementation::isBound(const string& name) const throw() {
335 return d_exprMap->find(name) != d_exprMap->end();
336 }
337
338 bool SymbolTable::Implementation::isBoundDefinedFunction(
339 const string& name) const throw() {
340 CDHashMap<string, Expr>::iterator found = d_exprMap->find(name);
341 return found != d_exprMap->end() && d_functions->contains((*found).second);
342 }
343
344 bool SymbolTable::Implementation::isBoundDefinedFunction(Expr func) const
345 throw() {
346 return d_functions->contains(func);
347 }
348
349 Expr SymbolTable::Implementation::lookup(const string& name) const throw() {
350 Expr expr = (*d_exprMap->find(name)).second;
351 if(isOverloadedFunction(expr)) {
352 return d_nullExpr;
353 }else{
354 return expr;
355 }
356 }
357
358 void SymbolTable::Implementation::bindType(const string& name, Type t,
359 bool levelZero) throw() {
360 if (levelZero) {
361 d_typeMap->insertAtContextLevelZero(name, make_pair(vector<Type>(), t));
362 } else {
363 d_typeMap->insert(name, make_pair(vector<Type>(), t));
364 }
365 }
366
367 void SymbolTable::Implementation::bindType(const string& name,
368 const vector<Type>& params, Type t,
369 bool levelZero) throw() {
370 if (Debug.isOn("sort")) {
371 Debug("sort") << "bindType(" << name << ", [";
372 if (params.size() > 0) {
373 copy(params.begin(), params.end() - 1,
374 ostream_iterator<Type>(Debug("sort"), ", "));
375 Debug("sort") << params.back();
376 }
377 Debug("sort") << "], " << t << ")" << endl;
378 }
379 if (levelZero) {
380 d_typeMap->insertAtContextLevelZero(name, make_pair(params, t));
381 } else {
382 d_typeMap->insert(name, make_pair(params, t));
383 }
384 }
385
386 bool SymbolTable::Implementation::isBoundType(const string& name) const
387 throw() {
388 return d_typeMap->find(name) != d_typeMap->end();
389 }
390
391 Type SymbolTable::Implementation::lookupType(const string& name) const throw() {
392 pair<vector<Type>, Type> p = (*d_typeMap->find(name)).second;
393 PrettyCheckArgument(p.first.size() == 0, name,
394 "type constructor arity is wrong: "
395 "`%s' requires %u parameters but was provided 0",
396 name.c_str(), p.first.size());
397 return p.second;
398 }
399
400 Type SymbolTable::Implementation::lookupType(const string& name,
401 const vector<Type>& params) const
402 throw() {
403 pair<vector<Type>, Type> p = (*d_typeMap->find(name)).second;
404 PrettyCheckArgument(p.first.size() == params.size(), params,
405 "type constructor arity is wrong: "
406 "`%s' requires %u parameters but was provided %u",
407 name.c_str(), p.first.size(), params.size());
408 if (p.first.size() == 0) {
409 PrettyCheckArgument(p.second.isSort(), name.c_str());
410 return p.second;
411 }
412 if (p.second.isSortConstructor()) {
413 if (Debug.isOn("sort")) {
414 Debug("sort") << "instantiating using a sort constructor" << endl;
415 Debug("sort") << "have formals [";
416 copy(p.first.begin(), p.first.end() - 1,
417 ostream_iterator<Type>(Debug("sort"), ", "));
418 Debug("sort") << p.first.back() << "]" << endl << "parameters [";
419 copy(params.begin(), params.end() - 1,
420 ostream_iterator<Type>(Debug("sort"), ", "));
421 Debug("sort") << params.back() << "]" << endl
422 << "type ctor " << name << endl
423 << "type is " << p.second << endl;
424 }
425
426 Type instantiation = SortConstructorType(p.second).instantiate(params);
427
428 Debug("sort") << "instance is " << instantiation << endl;
429
430 return instantiation;
431 } else if (p.second.isDatatype()) {
432 PrettyCheckArgument(DatatypeType(p.second).isParametric(), name,
433 "expected parametric datatype");
434 return DatatypeType(p.second).instantiate(params);
435 } else {
436 if (Debug.isOn("sort")) {
437 Debug("sort") << "instantiating using a sort substitution" << endl;
438 Debug("sort") << "have formals [";
439 copy(p.first.begin(), p.first.end() - 1,
440 ostream_iterator<Type>(Debug("sort"), ", "));
441 Debug("sort") << p.first.back() << "]" << endl << "parameters [";
442 copy(params.begin(), params.end() - 1,
443 ostream_iterator<Type>(Debug("sort"), ", "));
444 Debug("sort") << params.back() << "]" << endl
445 << "type ctor " << name << endl
446 << "type is " << p.second << endl;
447 }
448
449 Type instantiation = p.second.substitute(p.first, params);
450
451 Debug("sort") << "instance is " << instantiation << endl;
452
453 return instantiation;
454 }
455 }
456
457 size_t SymbolTable::Implementation::lookupArity(const string& name) {
458 pair<vector<Type>, Type> p = (*d_typeMap->find(name)).second;
459 return p.first.size();
460 }
461
462 void SymbolTable::Implementation::popScope() throw(ScopeException) {
463 if (d_context.getLevel() == 0) {
464 throw ScopeException();
465 }
466 d_context.pop();
467 }
468
469 void SymbolTable::Implementation::pushScope() throw() { d_context.push(); }
470
471 size_t SymbolTable::Implementation::getLevel() const throw() {
472 return d_context.getLevel();
473 }
474
475 void SymbolTable::Implementation::reset() {
476 this->SymbolTable::Implementation::~Implementation();
477 new (this) SymbolTable::Implementation();
478 }
479
480 bool SymbolTable::Implementation::isOverloadedFunction(Expr fun) const {
481 return d_overload_trie->isOverloadedFunction(fun);
482 }
483
484 Expr SymbolTable::Implementation::getOverloadedConstantForType(const std::string& name, Type t) const {
485 return d_overload_trie->getOverloadedConstantForType(name, t);
486 }
487
488 Expr SymbolTable::Implementation::getOverloadedFunctionForTypes(const std::string& name,
489 const std::vector< Type >& argTypes) const {
490 return d_overload_trie->getOverloadedFunctionForTypes(name, argTypes);
491 }
492
493 bool SymbolTable::Implementation::bindWithOverloading(const string& name, Expr obj){
494 CDHashMap<string, Expr>::const_iterator it = d_exprMap->find(name);
495 if(it != d_exprMap->end()) {
496 const Expr& prev_bound_obj = (*it).second;
497 if(prev_bound_obj!=obj) {
498 return d_overload_trie->bind(name, prev_bound_obj, obj);
499 }
500 }
501 return true;
502 }
503
504 bool SymbolTable::isOverloadedFunction(Expr fun) const {
505 return d_implementation->isOverloadedFunction(fun);
506 }
507
508 Expr SymbolTable::getOverloadedConstantForType(const std::string& name, Type t) const {
509 return d_implementation->getOverloadedConstantForType(name, t);
510 }
511
512 Expr SymbolTable::getOverloadedFunctionForTypes(const std::string& name,
513 const std::vector< Type >& argTypes) const {
514 return d_implementation->getOverloadedFunctionForTypes(name, argTypes);
515 }
516
517 SymbolTable::SymbolTable()
518 : d_implementation(new SymbolTable::Implementation()) {}
519
520 SymbolTable::~SymbolTable() {}
521
522 bool SymbolTable::bind(const string& name, Expr obj, bool levelZero, bool doOverload) throw() {
523 return d_implementation->bind(name, obj, levelZero, doOverload);
524 }
525
526 bool SymbolTable::bindDefinedFunction(const string& name, Expr obj,
527 bool levelZero, bool doOverload) throw() {
528 return d_implementation->bindDefinedFunction(name, obj, levelZero, doOverload);
529 }
530
531 void SymbolTable::bindType(const string& name, Type t, bool levelZero) throw() {
532 d_implementation->bindType(name, t, levelZero);
533 }
534
535 void SymbolTable::bindType(const string& name, const vector<Type>& params,
536 Type t, bool levelZero) throw() {
537 d_implementation->bindType(name, params, t, levelZero);
538 }
539
540 bool SymbolTable::isBound(const string& name) const throw() {
541 return d_implementation->isBound(name);
542 }
543
544 bool SymbolTable::isBoundDefinedFunction(const string& name) const throw() {
545 return d_implementation->isBoundDefinedFunction(name);
546 }
547
548 bool SymbolTable::isBoundDefinedFunction(Expr func) const throw() {
549 return d_implementation->isBoundDefinedFunction(func);
550 }
551 bool SymbolTable::isBoundType(const string& name) const throw() {
552 return d_implementation->isBoundType(name);
553 }
554 Expr SymbolTable::lookup(const string& name) const throw() {
555 return d_implementation->lookup(name);
556 }
557 Type SymbolTable::lookupType(const string& name) const throw() {
558 return d_implementation->lookupType(name);
559 }
560
561 Type SymbolTable::lookupType(const string& name,
562 const vector<Type>& params) const throw() {
563 return d_implementation->lookupType(name, params);
564 }
565 size_t SymbolTable::lookupArity(const string& name) {
566 return d_implementation->lookupArity(name);
567 }
568 void SymbolTable::popScope() throw(ScopeException) {
569 d_implementation->popScope();
570 }
571
572 void SymbolTable::pushScope() throw() { d_implementation->pushScope(); }
573 size_t SymbolTable::getLevel() const throw() {
574 return d_implementation->getLevel();
575 }
576 void SymbolTable::reset() { d_implementation->reset(); }
577
578 } // namespace CVC4