Update symbol table to support operator overloading (#1154)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 29 Sep 2017 04:58:03 +0000 (23:58 -0500)
committerAndres Noetzli <andres.noetzli@gmail.com>
Fri, 29 Sep 2017 04:58:03 +0000 (21:58 -0700)
src/expr/symbol_table.cpp
src/expr/symbol_table.h

index c760b3a802bee671b1fab24889a292c724ef02f6..b411d8dfbfa5bc974e79b6ee13c82f7ae3f978cb 100644 (file)
@@ -21,6 +21,7 @@
 #include <ostream>
 #include <string>
 #include <utility>
+#include <unordered_map>
 
 #include "context/cdhashmap.h"
 #include "context/cdhashset.h"
@@ -41,23 +42,204 @@ using ::std::pair;
 using ::std::string;
 using ::std::vector;
 
+// This data structure stores a trie of expressions with
+// the same name, and must be distinguished by their argument types.
+// It is context-dependent.
+class OverloadedTypeTrie
+{
+public:
+  OverloadedTypeTrie(Context * c ) :
+    d_overloaded_symbols(new (true) CDHashSet<Expr, ExprHashFunction>(c)) {
+  }  
+  ~OverloadedTypeTrie() {
+    d_overloaded_symbols->deleteSelf();
+  }
+  /** is this function overloaded? */
+  bool isOverloadedFunction(Expr fun) const;
+  
+  /** Get overloaded constant for type.
+   * If possible, it returns a defined symbol with name
+   * that has type t. Otherwise returns null expression.
+  */
+  Expr getOverloadedConstantForType(const std::string& name, Type t) const;
+  
+  /**
+   * If possible, returns a defined function for a name
+   * and a vector of expected argument types. Otherwise returns
+   * null expression.
+   */
+  Expr getOverloadedFunctionForTypes(const std::string& name, const std::vector< Type >& argTypes) const;
+  /** called when obj is bound to name, and prev_bound_obj was already bound to name 
+  * Returns false if the binding is invalid.
+  */
+  bool bind(const string& name, Expr prev_bound_obj, Expr obj);
+private:
+  /** Marks expression obj with name as overloaded. 
+   * Adds relevant information to the type arg trie data structure.
+   * It returns false if there is already an expression bound to that name
+   * whose type expects the same arguments as the type of obj but is not identical
+   * to the type of obj.  For example, if we declare :
+   *
+   * (declare-datatypes () ((List (cons (hd Int) (tl List)) (nil))))
+   * (declare-fun cons (Int List) List)
+   *
+   * cons : constructor_type( Int, List, List )
+   * cons : function_type( Int, List, List )
+   * 
+   * These are put in the same place in the trie but do not have identical type,
+   * hence we return false.
+  */
+  bool markOverloaded(const string& name, Expr obj);
+  /** the null expression */
+  Expr d_nullExpr;
+  // The (context-independent) trie storing that maps expected argument
+  // vectors to symbols. All expressions stored in d_symbols are only 
+  // interpreted as active if they also appear in the context-dependent
+  // set d_overloaded_symbols.
+  class TypeArgTrie {
+  public:
+    // children of this node
+    std::map< Type, TypeArgTrie > d_children;
+    // symbols at this node
+    std::map< Type, Expr > d_symbols;
+  };
+  /** for each string with operator overloading, this stores the data structure above. */
+  std::unordered_map< std::string, TypeArgTrie > d_overload_type_arg_trie;
+  /** The set of overloaded symbols. */
+  CDHashSet<Expr, ExprHashFunction>* d_overloaded_symbols;
+};
+
+bool OverloadedTypeTrie::isOverloadedFunction(Expr fun) const {
+  return d_overloaded_symbols->find(fun)!=d_overloaded_symbols->end();
+}
+
+Expr OverloadedTypeTrie::getOverloadedConstantForType(const std::string& name, Type t) const {
+  std::unordered_map< std::string, TypeArgTrie >::const_iterator it = d_overload_type_arg_trie.find(name);
+  if(it!=d_overload_type_arg_trie.end()) {
+    std::map< Type, Expr >::const_iterator its = it->second.d_symbols.find(t);
+    if(its!=it->second.d_symbols.end()) {
+      Expr expr = its->second;
+      // must be an active symbol
+      if(isOverloadedFunction(expr)) {
+        return expr;
+      }
+    }
+  }
+  return d_nullExpr;
+}
+
+Expr OverloadedTypeTrie::getOverloadedFunctionForTypes(const std::string& name, 
+                                                       const std::vector< Type >& argTypes) const {
+  std::unordered_map< std::string, TypeArgTrie >::const_iterator it = d_overload_type_arg_trie.find(name);
+  if(it!=d_overload_type_arg_trie.end()) {
+    const TypeArgTrie * tat = &it->second;
+    for(unsigned i=0; i<argTypes.size(); i++) {
+      std::map< Type, TypeArgTrie >::const_iterator itc = tat->d_children.find(argTypes[i]);
+      if(itc!=tat->d_children.end()) {
+        tat = &itc->second;
+      }else{
+        // no functions match 
+        return d_nullExpr;
+      }
+    }
+    // now, we must ensure that there is *only* one active symbol at this node
+    Expr retExpr;
+    for(std::map< Type, Expr >::const_iterator its = tat->d_symbols.begin(); its != tat->d_symbols.end(); ++its) {
+      Expr expr = its->second;
+      if(isOverloadedFunction(expr)) {
+        if(retExpr.isNull()) {
+          retExpr = expr;
+        }else{
+          // multiple functions match
+          return d_nullExpr;
+        }
+      }
+    }
+    return retExpr;
+  }
+  return d_nullExpr;
+}
+
+bool OverloadedTypeTrie::bind(const string& name, Expr prev_bound_obj, Expr obj) {
+  bool retprev = true;
+  if(!isOverloadedFunction(prev_bound_obj)) {
+    // mark previous as overloaded
+    retprev = markOverloaded(name, prev_bound_obj);
+  }
+  // mark this as overloaded
+  bool retobj = markOverloaded(name, obj);
+  return retprev && retobj;
+}
+
+bool OverloadedTypeTrie::markOverloaded(const string& name, Expr obj) {
+  Trace("parser-overloading") << "Overloaded function : " << name;
+  Trace("parser-overloading") << " with type " << obj.getType() << std::endl;
+  // get the argument types
+  Type t = obj.getType();
+  Type rangeType = t;
+  std::vector< Type > argTypes;
+  if(t.isFunction()) {
+    argTypes = static_cast<FunctionType>(t).getArgTypes();
+    rangeType = static_cast<FunctionType>(t).getRangeType();
+  }else if(t.isConstructor()) {
+    argTypes = static_cast<ConstructorType>(t).getArgTypes();
+    rangeType = static_cast<ConstructorType>(t).getRangeType();
+  }else if(t.isTester()) {
+    argTypes.push_back( static_cast<TesterType>(t).getDomain() );
+    rangeType = static_cast<TesterType>(t).getRangeType();
+  }else if(t.isSelector()) {
+    argTypes.push_back( static_cast<SelectorType>(t).getDomain() );
+    rangeType = static_cast<SelectorType>(t).getRangeType();
+  }
+  // add to the trie
+  TypeArgTrie * tat = &d_overload_type_arg_trie[name];
+  for(unsigned i=0; i<argTypes.size(); i++) {
+    tat = &(tat->d_children[argTypes[i]]);
+  }
+  
+  // types can be identical but vary on the kind of the type, thus we must distinguish based on this
+  std::map< Type, Expr >::iterator it = tat->d_symbols.find( rangeType );
+  if( it!=tat->d_symbols.end() ){
+    Expr prev_obj = it->second;
+    // if there is already an active function with the same name and expects the same argument types
+    if( isOverloadedFunction(prev_obj) ){
+      if( prev_obj.getType()==obj.getType() ){
+        //types are identical, simply ignore it
+        return true;
+      }else{
+        //otherwise there is no way to distinguish these types, we return an error
+        return false;
+      }
+    }
+  }
+  
+  // otherwise, update the symbols
+  d_overloaded_symbols->insert(obj);
+  tat->d_symbols[rangeType] = obj;
+  return true;
+}
+
+
 class SymbolTable::Implementation {
  public:
   Implementation()
       : d_context(),
         d_exprMap(new (true) CDHashMap<string, Expr>(&d_context)),
         d_typeMap(new (true) TypeMap(&d_context)),
-        d_functions(new (true) CDHashSet<Expr, ExprHashFunction>(&d_context)) {}
+        d_functions(new (true) CDHashSet<Expr, ExprHashFunction>(&d_context)){
+    d_overload_trie = new OverloadedTypeTrie(&d_context);
+  }
 
   ~Implementation() {
     d_exprMap->deleteSelf();
     d_typeMap->deleteSelf();
     d_functions->deleteSelf();
+    delete d_overload_trie;
   }
 
-  void bind(const string& name, Expr obj, bool levelZero) throw();
-  void bindDefinedFunction(const string& name, Expr obj,
-                           bool levelZero) throw();
+  bool bind(const string& name, Expr obj, bool levelZero, bool doOverload) throw();
+  bool bindDefinedFunction(const string& name, Expr obj,
+                           bool levelZero, bool doOverload) throw();
   void bindType(const string& name, Type t, bool levelZero = false) throw();
   void bindType(const string& name, const vector<Type>& params, Type t,
                 bool levelZero = false) throw();
@@ -73,7 +255,16 @@ class SymbolTable::Implementation {
   void pushScope() throw();
   size_t getLevel() const throw();
   void reset();
-
+  //------------------------ operator overloading
+  /** implementation of function from header */
+  bool isOverloadedFunction(Expr fun) const;
+  
+  /** implementation of function from header */
+  Expr getOverloadedConstantForType(const std::string& name, Type t) const;
+  
+  /** implementation of function from header */
+  Expr getOverloadedFunctionForTypes(const std::string& name, const std::vector< Type >& argTypes) const;
+  //------------------------ end operator overloading
  private:
   /** The context manager for the scope maps. */
   Context d_context;
@@ -87,24 +278,49 @@ class SymbolTable::Implementation {
 
   /** A set of defined functions. */
   CDHashSet<Expr, ExprHashFunction>* d_functions;
+  
+  //------------------------ operator overloading
+  // the null expression
+  Expr d_nullExpr;
+  // overloaded type trie, stores all information regarding overloading
+  OverloadedTypeTrie * d_overload_trie;
+  /** bind with overloading
+   * This is called whenever obj is bound to name where overloading symbols is allowed.
+   * If a symbol is previously bound to that name, it marks both as overloaded.
+   * Returns false if the binding was invalid.
+  */
+  bool bindWithOverloading(const string& name, Expr obj);
+  //------------------------ end operator overloading
 }; /* SymbolTable::Implementation */
 
-void SymbolTable::Implementation::bind(const string& name, Expr obj,
-                                       bool levelZero) throw() {
+bool SymbolTable::Implementation::bind(const string& name, Expr obj,
+                                       bool levelZero, bool doOverload) throw() {
   PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null Expr");
   ExprManagerScope ems(obj);
+  if (doOverload) {
+    if( !bindWithOverloading(name, obj) ){
+      return false;
+    }
+  }
   if (levelZero) {
     d_exprMap->insertAtContextLevelZero(name, obj);
   } else {
     d_exprMap->insert(name, obj);
   }
+  return true;
 }
 
-void SymbolTable::Implementation::bindDefinedFunction(const string& name,
+bool SymbolTable::Implementation::bindDefinedFunction(const string& name,
                                                       Expr obj,
-                                                      bool levelZero) throw() {
+                                                      bool levelZero, 
+                                                      bool doOverload) throw() {
   PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null Expr");
   ExprManagerScope ems(obj);
+  if (doOverload) {
+    if( !bindWithOverloading(name, obj) ){
+      return false;
+    }
+  }
   if (levelZero) {
     d_exprMap->insertAtContextLevelZero(name, obj);
     d_functions->insertAtContextLevelZero(obj);
@@ -112,6 +328,7 @@ void SymbolTable::Implementation::bindDefinedFunction(const string& name,
     d_exprMap->insert(name, obj);
     d_functions->insert(obj);
   }
+  return true;
 }
 
 bool SymbolTable::Implementation::isBound(const string& name) const throw() {
@@ -130,7 +347,12 @@ bool SymbolTable::Implementation::isBoundDefinedFunction(Expr func) const
 }
 
 Expr SymbolTable::Implementation::lookup(const string& name) const throw() {
-  return (*d_exprMap->find(name)).second;
+  Expr expr = (*d_exprMap->find(name)).second;
+  if(isOverloadedFunction(expr)) {
+    return d_nullExpr;
+  }else{
+    return expr;
+  }
 }
 
 void SymbolTable::Implementation::bindType(const string& name, Type t,
@@ -255,18 +477,55 @@ void SymbolTable::Implementation::reset() {
   new (this) SymbolTable::Implementation();
 }
 
+bool SymbolTable::Implementation::isOverloadedFunction(Expr fun) const {
+  return d_overload_trie->isOverloadedFunction(fun);
+}
+
+Expr SymbolTable::Implementation::getOverloadedConstantForType(const std::string& name, Type t) const {
+  return d_overload_trie->getOverloadedConstantForType(name, t);
+}
+
+Expr SymbolTable::Implementation::getOverloadedFunctionForTypes(const std::string& name, 
+                                                                const std::vector< Type >& argTypes) const {
+  return d_overload_trie->getOverloadedFunctionForTypes(name, argTypes);
+}
+
+bool SymbolTable::Implementation::bindWithOverloading(const string& name, Expr obj){
+  CDHashMap<string, Expr>::const_iterator it = d_exprMap->find(name);
+  if(it != d_exprMap->end()) {
+    const Expr& prev_bound_obj = (*it).second;
+    if(prev_bound_obj!=obj) {
+      return d_overload_trie->bind(name, prev_bound_obj, obj);
+    }
+  }
+  return true;
+}
+
+bool SymbolTable::isOverloadedFunction(Expr fun) const {
+  return d_implementation->isOverloadedFunction(fun);
+}
+
+Expr SymbolTable::getOverloadedConstantForType(const std::string& name, Type t) const {
+  return d_implementation->getOverloadedConstantForType(name, t);
+}
+
+Expr SymbolTable::getOverloadedFunctionForTypes(const std::string& name, 
+                                                const std::vector< Type >& argTypes) const {
+  return d_implementation->getOverloadedFunctionForTypes(name, argTypes);
+}
+
 SymbolTable::SymbolTable()
     : d_implementation(new SymbolTable::Implementation()) {}
 
 SymbolTable::~SymbolTable() {}
 
-void SymbolTable::bind(const string& name, Expr obj, bool levelZero) throw() {
-  d_implementation->bind(name, obj, levelZero);
+bool SymbolTable::bind(const string& name, Expr obj, bool levelZero, bool doOverload) throw() {
+  return d_implementation->bind(name, obj, levelZero, doOverload);
 }
 
-void SymbolTable::bindDefinedFunction(const string& name, Expr obj,
-                                      bool levelZero) throw() {
-  d_implementation->bindDefinedFunction(name, obj, levelZero);
+bool SymbolTable::bindDefinedFunction(const string& name, Expr obj,
+                                      bool levelZero, bool doOverload) throw() {
+  return d_implementation->bindDefinedFunction(name, obj, levelZero, doOverload);
 }
 
 void SymbolTable::bindType(const string& name, Type t, bool levelZero) throw() {
index e644885635414597ab26d61379057d7d56a093d3..b6ca7a76f4d3ffb44d48a8c3e862cc20f202cbb6 100644 (file)
@@ -43,33 +43,58 @@ class CVC4_PUBLIC SymbolTable {
   ~SymbolTable();
 
   /**
-   * Bind an expression to a name in the current scope level.  If
-   * <code>name</code> is already bound to an expression in the current
+   * Bind an expression to a name in the current scope level.  
+   *
+   * When doOverload is false:
+   * if <code>name</code> is already bound to an expression in the current
    * level, then the binding is replaced. If <code>name</code> is bound
    * in a previous level, then the binding is "covered" by this one
-   * until the current scope is popped. If levelZero is true the name
-   * shouldn't be already bound.
+   * until the current scope is popped. 
+   * If levelZero is true the name shouldn't be already bound.
+   *
+   * When doOverload is true:
+   * if <code>name</code> is already bound to an expression in the current
+   * level, then we mark the previous bound expression and obj as overloaded
+   * functions.
    *
    * @param name an identifier
    * @param obj the expression to bind to <code>name</code>
    * @param levelZero set if the binding must be done at level 0
+   * @param doOverload set if the binding can overload the function name.
+   *
+   * Returns false if the binding was invalid.
    */
-  void bind(const std::string& name, Expr obj, bool levelZero = false) throw();
+  bool bind(const std::string& name, Expr obj, bool levelZero = false, 
+            bool doOverload = false) throw();
 
   /**
-   * Bind a function body to a name in the current scope.  If
-   * <code>name</code> is already bound to an expression in the current
+   * Bind a function body to a name in the current scope.  
+   *
+   * When doOverload is false:
+   * if <code>name</code> is already bound to an expression in the current
    * level, then the binding is replaced. If <code>name</code> is bound
    * in a previous level, then the binding is "covered" by this one
-   * until the current scope is popped.  Same as bind() but registers
-   * this as a function (so that isBoundDefinedFunction() returns true).
+   * until the current scope is popped. 
+   * If levelZero is true the name shouldn't be already bound.
+   *
+   * When doOverload is true:
+   * if <code>name</code> is already bound to an expression in the current
+   * level, then we mark the previous bound expression and obj as overloaded
+   * functions.
+   *
+   * Same as bind() but registers this as a function (so that 
+   * isBoundDefinedFunction() returns true).
    *
    * @param name an identifier
    * @param obj the expression to bind to <code>name</code>
    * @param levelZero set if the binding must be done at level 0
+   * @param doOverload set if the binding can overload the function name.
+   *
+   * Returns false if the binding was invalid.
    */
-  void bindDefinedFunction(const std::string& name, Expr obj,
-                           bool levelZero = false) throw();
+  bool bindDefinedFunction(const std::string& name, Expr obj,
+                           bool levelZero = false, 
+                           bool doOverload = false) throw();
 
   /**
    * Bind a type to a name in the current scope.  If <code>name</code>
@@ -133,7 +158,9 @@ class CVC4_PUBLIC SymbolTable {
    * Lookup a bound expression.
    *
    * @param name the identifier to lookup
-   * @returns the expression bound to <code>name</code> in the current scope.
+   * @returns the unique expression bound to <code>name</code> in the current scope.
+   * It returns the null expression if there is not a unique expression bound to 
+   * <code>name</code> in the current scope (i.e. if there is not exactly one).
    */
   Expr lookup(const std::string& name) const throw();
 
@@ -178,7 +205,31 @@ class CVC4_PUBLIC SymbolTable {
 
   /** Reset everything. */
   void reset();
-
+  
+  //------------------------ operator overloading
+  /** is this function overloaded? */
+  bool isOverloadedFunction(Expr fun) const;
+  
+  /** Get overloaded constant for type.
+   * If possible, it returns the defined symbol with name
+   * that has type t. Otherwise returns null expression.
+  */
+  Expr getOverloadedConstantForType(const std::string& name, Type t) const;
+  
+  /**
+   * If possible, returns the unique defined function for a name
+   * that expects arguments with types "argTypes".
+   * For example, if argTypes = ( T1, ..., Tn ), then this may return 
+   * an expression with type function( T1, ..., Tn ), or constructor( T1, ...., Tn ).
+   * 
+   * If there is not a unique defined function for the name and argTypes,
+   * this returns the null expression. This can happen either if there are
+   * no functions with name and expected argTypes, or alternatively there is
+   * more than one function with name and expected argTypes.
+   */
+  Expr getOverloadedFunctionForTypes(const std::string& name, const std::vector< Type >& argTypes) const;
+  //------------------------ end operator overloading
+  
  private:
   // Copying and assignment have not yet been implemented.
   SymbolTable(const SymbolTable&);