From: Andrew Reynolds Date: Fri, 29 Sep 2017 04:58:03 +0000 (-0500) Subject: Update symbol table to support operator overloading (#1154) X-Git-Tag: cvc5-1.0.0~5600 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=821a9d90914fca4a13bc29f8ff15fb4220cbd1d4;p=cvc5.git Update symbol table to support operator overloading (#1154) --- diff --git a/src/expr/symbol_table.cpp b/src/expr/symbol_table.cpp index c760b3a80..b411d8dfb 100644 --- a/src/expr/symbol_table.cpp +++ b/src/expr/symbol_table.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include "context/cdhashmap.h" #include "context/cdhashset.h" @@ -41,23 +42,204 @@ using ::std::pair; using ::std::string; using ::std::vector; +// This data structure stores a trie of expressions with +// the same name, and must be distinguished by their argument types. +// It is context-dependent. +class OverloadedTypeTrie +{ +public: + OverloadedTypeTrie(Context * c ) : + d_overloaded_symbols(new (true) CDHashSet(c)) { + } + ~OverloadedTypeTrie() { + d_overloaded_symbols->deleteSelf(); + } + /** is this function overloaded? */ + bool isOverloadedFunction(Expr fun) const; + + /** Get overloaded constant for type. + * If possible, it returns a defined symbol with name + * that has type t. Otherwise returns null expression. + */ + Expr getOverloadedConstantForType(const std::string& name, Type t) const; + + /** + * If possible, returns a defined function for a name + * and a vector of expected argument types. Otherwise returns + * null expression. + */ + Expr getOverloadedFunctionForTypes(const std::string& name, const std::vector< Type >& argTypes) const; + /** called when obj is bound to name, and prev_bound_obj was already bound to name + * Returns false if the binding is invalid. + */ + bool bind(const string& name, Expr prev_bound_obj, Expr obj); +private: + /** Marks expression obj with name as overloaded. + * Adds relevant information to the type arg trie data structure. + * It returns false if there is already an expression bound to that name + * whose type expects the same arguments as the type of obj but is not identical + * to the type of obj. For example, if we declare : + * + * (declare-datatypes () ((List (cons (hd Int) (tl List)) (nil)))) + * (declare-fun cons (Int List) List) + * + * cons : constructor_type( Int, List, List ) + * cons : function_type( Int, List, List ) + * + * These are put in the same place in the trie but do not have identical type, + * hence we return false. + */ + bool markOverloaded(const string& name, Expr obj); + /** the null expression */ + Expr d_nullExpr; + // The (context-independent) trie storing that maps expected argument + // vectors to symbols. All expressions stored in d_symbols are only + // interpreted as active if they also appear in the context-dependent + // set d_overloaded_symbols. + class TypeArgTrie { + public: + // children of this node + std::map< Type, TypeArgTrie > d_children; + // symbols at this node + std::map< Type, Expr > d_symbols; + }; + /** for each string with operator overloading, this stores the data structure above. */ + std::unordered_map< std::string, TypeArgTrie > d_overload_type_arg_trie; + /** The set of overloaded symbols. */ + CDHashSet* d_overloaded_symbols; +}; + +bool OverloadedTypeTrie::isOverloadedFunction(Expr fun) const { + return d_overloaded_symbols->find(fun)!=d_overloaded_symbols->end(); +} + +Expr OverloadedTypeTrie::getOverloadedConstantForType(const std::string& name, Type t) const { + std::unordered_map< std::string, TypeArgTrie >::const_iterator it = d_overload_type_arg_trie.find(name); + if(it!=d_overload_type_arg_trie.end()) { + std::map< Type, Expr >::const_iterator its = it->second.d_symbols.find(t); + if(its!=it->second.d_symbols.end()) { + Expr expr = its->second; + // must be an active symbol + if(isOverloadedFunction(expr)) { + return expr; + } + } + } + return d_nullExpr; +} + +Expr OverloadedTypeTrie::getOverloadedFunctionForTypes(const std::string& name, + const std::vector< Type >& argTypes) const { + std::unordered_map< std::string, TypeArgTrie >::const_iterator it = d_overload_type_arg_trie.find(name); + if(it!=d_overload_type_arg_trie.end()) { + const TypeArgTrie * tat = &it->second; + for(unsigned i=0; i::const_iterator itc = tat->d_children.find(argTypes[i]); + if(itc!=tat->d_children.end()) { + tat = &itc->second; + }else{ + // no functions match + return d_nullExpr; + } + } + // now, we must ensure that there is *only* one active symbol at this node + Expr retExpr; + for(std::map< Type, Expr >::const_iterator its = tat->d_symbols.begin(); its != tat->d_symbols.end(); ++its) { + Expr expr = its->second; + if(isOverloadedFunction(expr)) { + if(retExpr.isNull()) { + retExpr = expr; + }else{ + // multiple functions match + return d_nullExpr; + } + } + } + return retExpr; + } + return d_nullExpr; +} + +bool OverloadedTypeTrie::bind(const string& name, Expr prev_bound_obj, Expr obj) { + bool retprev = true; + if(!isOverloadedFunction(prev_bound_obj)) { + // mark previous as overloaded + retprev = markOverloaded(name, prev_bound_obj); + } + // mark this as overloaded + bool retobj = markOverloaded(name, obj); + return retprev && retobj; +} + +bool OverloadedTypeTrie::markOverloaded(const string& name, Expr obj) { + Trace("parser-overloading") << "Overloaded function : " << name; + Trace("parser-overloading") << " with type " << obj.getType() << std::endl; + // get the argument types + Type t = obj.getType(); + Type rangeType = t; + std::vector< Type > argTypes; + if(t.isFunction()) { + argTypes = static_cast(t).getArgTypes(); + rangeType = static_cast(t).getRangeType(); + }else if(t.isConstructor()) { + argTypes = static_cast(t).getArgTypes(); + rangeType = static_cast(t).getRangeType(); + }else if(t.isTester()) { + argTypes.push_back( static_cast(t).getDomain() ); + rangeType = static_cast(t).getRangeType(); + }else if(t.isSelector()) { + argTypes.push_back( static_cast(t).getDomain() ); + rangeType = static_cast(t).getRangeType(); + } + // add to the trie + TypeArgTrie * tat = &d_overload_type_arg_trie[name]; + for(unsigned i=0; id_children[argTypes[i]]); + } + + // types can be identical but vary on the kind of the type, thus we must distinguish based on this + std::map< Type, Expr >::iterator it = tat->d_symbols.find( rangeType ); + if( it!=tat->d_symbols.end() ){ + Expr prev_obj = it->second; + // if there is already an active function with the same name and expects the same argument types + if( isOverloadedFunction(prev_obj) ){ + if( prev_obj.getType()==obj.getType() ){ + //types are identical, simply ignore it + return true; + }else{ + //otherwise there is no way to distinguish these types, we return an error + return false; + } + } + } + + // otherwise, update the symbols + d_overloaded_symbols->insert(obj); + tat->d_symbols[rangeType] = obj; + return true; +} + + class SymbolTable::Implementation { public: Implementation() : d_context(), d_exprMap(new (true) CDHashMap(&d_context)), d_typeMap(new (true) TypeMap(&d_context)), - d_functions(new (true) CDHashSet(&d_context)) {} + d_functions(new (true) CDHashSet(&d_context)){ + d_overload_trie = new OverloadedTypeTrie(&d_context); + } ~Implementation() { d_exprMap->deleteSelf(); d_typeMap->deleteSelf(); d_functions->deleteSelf(); + delete d_overload_trie; } - void bind(const string& name, Expr obj, bool levelZero) throw(); - void bindDefinedFunction(const string& name, Expr obj, - bool levelZero) throw(); + bool bind(const string& name, Expr obj, bool levelZero, bool doOverload) throw(); + bool bindDefinedFunction(const string& name, Expr obj, + bool levelZero, bool doOverload) throw(); void bindType(const string& name, Type t, bool levelZero = false) throw(); void bindType(const string& name, const vector& params, Type t, bool levelZero = false) throw(); @@ -73,7 +255,16 @@ class SymbolTable::Implementation { void pushScope() throw(); size_t getLevel() const throw(); void reset(); - + //------------------------ operator overloading + /** implementation of function from header */ + bool isOverloadedFunction(Expr fun) const; + + /** implementation of function from header */ + Expr getOverloadedConstantForType(const std::string& name, Type t) const; + + /** implementation of function from header */ + Expr getOverloadedFunctionForTypes(const std::string& name, const std::vector< Type >& argTypes) const; + //------------------------ end operator overloading private: /** The context manager for the scope maps. */ Context d_context; @@ -87,24 +278,49 @@ class SymbolTable::Implementation { /** A set of defined functions. */ CDHashSet* d_functions; + + //------------------------ operator overloading + // the null expression + Expr d_nullExpr; + // overloaded type trie, stores all information regarding overloading + OverloadedTypeTrie * d_overload_trie; + /** bind with overloading + * This is called whenever obj is bound to name where overloading symbols is allowed. + * If a symbol is previously bound to that name, it marks both as overloaded. + * Returns false if the binding was invalid. + */ + bool bindWithOverloading(const string& name, Expr obj); + //------------------------ end operator overloading }; /* SymbolTable::Implementation */ -void SymbolTable::Implementation::bind(const string& name, Expr obj, - bool levelZero) throw() { +bool SymbolTable::Implementation::bind(const string& name, Expr obj, + bool levelZero, bool doOverload) throw() { PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null Expr"); ExprManagerScope ems(obj); + if (doOverload) { + if( !bindWithOverloading(name, obj) ){ + return false; + } + } if (levelZero) { d_exprMap->insertAtContextLevelZero(name, obj); } else { d_exprMap->insert(name, obj); } + return true; } -void SymbolTable::Implementation::bindDefinedFunction(const string& name, +bool SymbolTable::Implementation::bindDefinedFunction(const string& name, Expr obj, - bool levelZero) throw() { + bool levelZero, + bool doOverload) throw() { PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null Expr"); ExprManagerScope ems(obj); + if (doOverload) { + if( !bindWithOverloading(name, obj) ){ + return false; + } + } if (levelZero) { d_exprMap->insertAtContextLevelZero(name, obj); d_functions->insertAtContextLevelZero(obj); @@ -112,6 +328,7 @@ void SymbolTable::Implementation::bindDefinedFunction(const string& name, d_exprMap->insert(name, obj); d_functions->insert(obj); } + return true; } bool SymbolTable::Implementation::isBound(const string& name) const throw() { @@ -130,7 +347,12 @@ bool SymbolTable::Implementation::isBoundDefinedFunction(Expr func) const } Expr SymbolTable::Implementation::lookup(const string& name) const throw() { - return (*d_exprMap->find(name)).second; + Expr expr = (*d_exprMap->find(name)).second; + if(isOverloadedFunction(expr)) { + return d_nullExpr; + }else{ + return expr; + } } void SymbolTable::Implementation::bindType(const string& name, Type t, @@ -255,18 +477,55 @@ void SymbolTable::Implementation::reset() { new (this) SymbolTable::Implementation(); } +bool SymbolTable::Implementation::isOverloadedFunction(Expr fun) const { + return d_overload_trie->isOverloadedFunction(fun); +} + +Expr SymbolTable::Implementation::getOverloadedConstantForType(const std::string& name, Type t) const { + return d_overload_trie->getOverloadedConstantForType(name, t); +} + +Expr SymbolTable::Implementation::getOverloadedFunctionForTypes(const std::string& name, + const std::vector< Type >& argTypes) const { + return d_overload_trie->getOverloadedFunctionForTypes(name, argTypes); +} + +bool SymbolTable::Implementation::bindWithOverloading(const string& name, Expr obj){ + CDHashMap::const_iterator it = d_exprMap->find(name); + if(it != d_exprMap->end()) { + const Expr& prev_bound_obj = (*it).second; + if(prev_bound_obj!=obj) { + return d_overload_trie->bind(name, prev_bound_obj, obj); + } + } + return true; +} + +bool SymbolTable::isOverloadedFunction(Expr fun) const { + return d_implementation->isOverloadedFunction(fun); +} + +Expr SymbolTable::getOverloadedConstantForType(const std::string& name, Type t) const { + return d_implementation->getOverloadedConstantForType(name, t); +} + +Expr SymbolTable::getOverloadedFunctionForTypes(const std::string& name, + const std::vector< Type >& argTypes) const { + return d_implementation->getOverloadedFunctionForTypes(name, argTypes); +} + SymbolTable::SymbolTable() : d_implementation(new SymbolTable::Implementation()) {} SymbolTable::~SymbolTable() {} -void SymbolTable::bind(const string& name, Expr obj, bool levelZero) throw() { - d_implementation->bind(name, obj, levelZero); +bool SymbolTable::bind(const string& name, Expr obj, bool levelZero, bool doOverload) throw() { + return d_implementation->bind(name, obj, levelZero, doOverload); } -void SymbolTable::bindDefinedFunction(const string& name, Expr obj, - bool levelZero) throw() { - d_implementation->bindDefinedFunction(name, obj, levelZero); +bool SymbolTable::bindDefinedFunction(const string& name, Expr obj, + bool levelZero, bool doOverload) throw() { + return d_implementation->bindDefinedFunction(name, obj, levelZero, doOverload); } void SymbolTable::bindType(const string& name, Type t, bool levelZero) throw() { diff --git a/src/expr/symbol_table.h b/src/expr/symbol_table.h index e64488563..b6ca7a76f 100644 --- a/src/expr/symbol_table.h +++ b/src/expr/symbol_table.h @@ -43,33 +43,58 @@ class CVC4_PUBLIC SymbolTable { ~SymbolTable(); /** - * Bind an expression to a name in the current scope level. If - * name is already bound to an expression in the current + * Bind an expression to a name in the current scope level. + * + * When doOverload is false: + * if name is already bound to an expression in the current * level, then the binding is replaced. If name is bound * in a previous level, then the binding is "covered" by this one - * until the current scope is popped. If levelZero is true the name - * shouldn't be already bound. + * until the current scope is popped. + * If levelZero is true the name shouldn't be already bound. + * + * When doOverload is true: + * if name is already bound to an expression in the current + * level, then we mark the previous bound expression and obj as overloaded + * functions. * * @param name an identifier * @param obj the expression to bind to name * @param levelZero set if the binding must be done at level 0 + * @param doOverload set if the binding can overload the function name. + * + * Returns false if the binding was invalid. */ - void bind(const std::string& name, Expr obj, bool levelZero = false) throw(); + bool bind(const std::string& name, Expr obj, bool levelZero = false, + bool doOverload = false) throw(); /** - * Bind a function body to a name in the current scope. If - * name is already bound to an expression in the current + * Bind a function body to a name in the current scope. + * + * When doOverload is false: + * if name is already bound to an expression in the current * level, then the binding is replaced. If name is bound * in a previous level, then the binding is "covered" by this one - * until the current scope is popped. Same as bind() but registers - * this as a function (so that isBoundDefinedFunction() returns true). + * until the current scope is popped. + * If levelZero is true the name shouldn't be already bound. + * + * When doOverload is true: + * if name is already bound to an expression in the current + * level, then we mark the previous bound expression and obj as overloaded + * functions. + * + * Same as bind() but registers this as a function (so that + * isBoundDefinedFunction() returns true). * * @param name an identifier * @param obj the expression to bind to name * @param levelZero set if the binding must be done at level 0 + * @param doOverload set if the binding can overload the function name. + * + * Returns false if the binding was invalid. */ - void bindDefinedFunction(const std::string& name, Expr obj, - bool levelZero = false) throw(); + bool bindDefinedFunction(const std::string& name, Expr obj, + bool levelZero = false, + bool doOverload = false) throw(); /** * Bind a type to a name in the current scope. If name @@ -133,7 +158,9 @@ class CVC4_PUBLIC SymbolTable { * Lookup a bound expression. * * @param name the identifier to lookup - * @returns the expression bound to name in the current scope. + * @returns the unique expression bound to name in the current scope. + * It returns the null expression if there is not a unique expression bound to + * name in the current scope (i.e. if there is not exactly one). */ Expr lookup(const std::string& name) const throw(); @@ -178,7 +205,31 @@ class CVC4_PUBLIC SymbolTable { /** Reset everything. */ void reset(); - + + //------------------------ operator overloading + /** is this function overloaded? */ + bool isOverloadedFunction(Expr fun) const; + + /** Get overloaded constant for type. + * If possible, it returns the defined symbol with name + * that has type t. Otherwise returns null expression. + */ + Expr getOverloadedConstantForType(const std::string& name, Type t) const; + + /** + * If possible, returns the unique defined function for a name + * that expects arguments with types "argTypes". + * For example, if argTypes = ( T1, ..., Tn ), then this may return + * an expression with type function( T1, ..., Tn ), or constructor( T1, ...., Tn ). + * + * If there is not a unique defined function for the name and argTypes, + * this returns the null expression. This can happen either if there are + * no functions with name and expected argTypes, or alternatively there is + * more than one function with name and expected argTypes. + */ + Expr getOverloadedFunctionForTypes(const std::string& name, const std::vector< Type >& argTypes) const; + //------------------------ end operator overloading + private: // Copying and assignment have not yet been implemented. SymbolTable(const SymbolTable&);