Rename namespace CVC5 to cvc5. (#6258)
[cvc5.git] / src / expr / symbol_table.cpp
index 33046be7aca1eabf886068cdc1dff9de519211a1..1c513fea4d7fcd6b51151ea27e5295a19bbf1381 100644 (file)
@@ -2,10 +2,10 @@
 /*! \file symbol_table.cpp
  ** \verbatim
  ** Top contributors (to current version):
- **   Morgan Deters, Christopher L. Conway, Francois Bobot
+ **   Andrew Reynolds, Tim King, Morgan Deters
  ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
+ ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
  ** All rights reserved.  See the file COPYING in the top-level source
  ** directory for licensing information.\endverbatim
  **
 #include <unordered_map>
 #include <utility>
 
+#include "api/cvc4cpp.h"
 #include "context/cdhashmap.h"
 #include "context/cdhashset.h"
 #include "context/context.h"
-#include "expr/expr.h"
-#include "expr/expr_manager_scope.h"
-#include "expr/type.h"
 
-namespace CVC4 {
+namespace cvc5 {
 
-using ::CVC4::context::CDHashMap;
-using ::CVC4::context::CDHashSet;
-using ::CVC4::context::Context;
+using ::cvc5::context::CDHashMap;
+using ::cvc5::context::CDHashSet;
+using ::cvc5::context::Context;
 using ::std::copy;
 using ::std::endl;
 using ::std::ostream_iterator;
@@ -42,35 +40,86 @@ using ::std::pair;
 using ::std::string;
 using ::std::vector;
 
-// This data structure stores a trie of expressions with
-// the same name, and must be distinguished by their argument types.
-// It is context-dependent.
+/** Overloaded type trie.
+ *
+ * This data structure stores a trie of expressions with
+ * the same name, and must be distinguished by their argument types.
+ * It is context-dependent.
+ *
+ * Using the argument allowFunVariants,
+ * it may either be configured to allow function variants or not,
+ * where a function variant is function that expects the same
+ * argument types as another.
+ *
+ * For example, the following definitions introduce function
+ * variants for the symbol f:
+ *
+ * 1. (declare-fun f (Int) Int) and
+ *    (declare-fun f (Int) Bool)
+ *
+ * 2. (declare-fun f (Int) Int) and
+ *    (declare-fun f (Int) Int)
+ *
+ * 3. (declare-datatypes ((Tup 0)) ((f (data Int)))) and
+ *    (declare-fun f (Int) Tup)
+ *
+ * 4. (declare-datatypes ((Tup 0)) ((mkTup (f Int)))) and
+ *    (declare-fun f (Tup) Bool)
+ * 
+ * If function variants is set to true, we allow function variants
+ * but not function redefinition. In examples 2 and 3, f is 
+ * declared twice as a symbol of identical argument and range
+ * types. We never accept these definitions. However, we do
+ * allow examples 1 and 4 above when allowFunVariants is true.
+ * 
+ * For 0-argument functions (constants), we always allow
+ * function variants.  That is, we always accept these examples:
+ * 
+ * 5.  (declare-fun c () Int)
+ *     (declare-fun c () Bool)
+ * 
+ * 6.  (declare-datatypes ((Enum 0)) ((c)))
+ *     (declare-fun c () Int)
+ * 
+ * and always reject constant redefinition such as:
+ * 
+ * 7. (declare-fun c () Int)
+ *    (declare-fun c () Int)
+ * 
+ * 8. (declare-datatypes ((Enum 0)) ((c))) and
+ *    (declare-fun c () Enum)
+ */
 class OverloadedTypeTrie {
  public:
-  OverloadedTypeTrie(Context* c)
-      : d_overloaded_symbols(new (true) CDHashSet<Expr, ExprHashFunction>(c)) {}
+  OverloadedTypeTrie(Context* c, bool allowFunVariants = false)
+      : d_overloaded_symbols(
+            new (true) CDHashSet<api::Term, api::TermHashFunction>(c)),
+        d_allowFunctionVariants(allowFunVariants)
+  {
+  }
   ~OverloadedTypeTrie() { d_overloaded_symbols->deleteSelf(); }
 
   /** is this function overloaded? */
-  bool isOverloadedFunction(Expr fun) const;
+  bool isOverloadedFunction(api::Term fun) const;
 
   /** Get overloaded constant for type.
    * If possible, it returns a defined symbol with name
    * that has type t. Otherwise returns null expression.
    */
-  Expr getOverloadedConstantForType(const std::string& name, Type t) const;
+  api::Term getOverloadedConstantForType(const std::string& name,
+                                         api::Sort t) const;
 
   /**
    * If possible, returns a defined function for a name
    * and a vector of expected argument types. Otherwise returns
    * null expression.
    */
-  Expr getOverloadedFunctionForTypes(const std::string& name,
-                                     const std::vector<Type>& argTypes) const;
+  api::Term getOverloadedFunctionForTypes(
+      const std::string& name, const std::vector<api::Sort>& argTypes) const;
   /** called when obj is bound to name, and prev_bound_obj was already bound to
    * name Returns false if the binding is invalid.
    */
-  bool bind(const string& name, Expr prev_bound_obj, Expr obj);
+  bool bind(const string& name, api::Term prev_bound_obj, api::Term obj);
 
  private:
   /** Marks expression obj with name as overloaded.
@@ -88,9 +137,9 @@ class OverloadedTypeTrie {
    * These are put in the same place in the trie but do not have identical type,
    * hence we return false.
    */
-  bool markOverloaded(const string& name, Expr obj);
+  bool markOverloaded(const string& name, api::Term obj);
   /** the null expression */
-  Expr d_nullExpr;
+  api::Term d_nullTerm;
   // The (context-independent) trie storing that maps expected argument
   // vectors to symbols. All expressions stored in d_symbols are only
   // interpreted as active if they also appear in the context-dependent
@@ -98,75 +147,84 @@ class OverloadedTypeTrie {
   class TypeArgTrie {
    public:
     // children of this node
-    std::map<Type, TypeArgTrie> d_children;
+    std::map<api::Sort, TypeArgTrie> d_children;
     // symbols at this node
-    std::map<Type, Expr> d_symbols;
+    std::map<api::Sort, api::Term> d_symbols;
   };
   /** for each string with operator overloading, this stores the data structure
    * above. */
   std::unordered_map<std::string, TypeArgTrie> d_overload_type_arg_trie;
   /** The set of overloaded symbols. */
-  CDHashSet<Expr, ExprHashFunction>* d_overloaded_symbols;
+  CDHashSet<api::Term, api::TermHashFunction>* d_overloaded_symbols;
+  /** allow function variants
+   * This is true if we allow overloading (non-constant) functions that expect
+   * the same argument types.
+   */
+  bool d_allowFunctionVariants;
+  /** get unique overloaded function
+  * If tat->d_symbols contains an active overloaded function, it
+  * returns that function, where that function must be unique 
+  * if reqUnique=true.
+  * Otherwise, it returns the null expression.
+  */
+  api::Term getOverloadedFunctionAt(const TypeArgTrie* tat,
+                                    bool reqUnique = true) const;
 };
 
-bool OverloadedTypeTrie::isOverloadedFunction(Expr fun) const {
+bool OverloadedTypeTrie::isOverloadedFunction(api::Term fun) const
+{
   return d_overloaded_symbols->find(fun) != d_overloaded_symbols->end();
 }
 
-Expr OverloadedTypeTrie::getOverloadedConstantForType(const std::string& name,
-                                                      Type t) const {
+api::Term OverloadedTypeTrie::getOverloadedConstantForType(
+    const std::string& name, api::Sort t) const
+{
   std::unordered_map<std::string, TypeArgTrie>::const_iterator it =
       d_overload_type_arg_trie.find(name);
   if (it != d_overload_type_arg_trie.end()) {
-    std::map<Type, Expr>::const_iterator its = it->second.d_symbols.find(t);
+    std::map<api::Sort, api::Term>::const_iterator its =
+        it->second.d_symbols.find(t);
     if (its != it->second.d_symbols.end()) {
-      Expr expr = its->second;
+      api::Term expr = its->second;
       // must be an active symbol
       if (isOverloadedFunction(expr)) {
         return expr;
       }
     }
   }
-  return d_nullExpr;
+  return d_nullTerm;
 }
 
-Expr OverloadedTypeTrie::getOverloadedFunctionForTypes(
-    const std::string& name, const std::vector<Type>& argTypes) const {
+api::Term OverloadedTypeTrie::getOverloadedFunctionForTypes(
+    const std::string& name, const std::vector<api::Sort>& argTypes) const
+{
   std::unordered_map<std::string, TypeArgTrie>::const_iterator it =
       d_overload_type_arg_trie.find(name);
   if (it != d_overload_type_arg_trie.end()) {
     const TypeArgTrie* tat = &it->second;
     for (unsigned i = 0; i < argTypes.size(); i++) {
-      std::map<Type, TypeArgTrie>::const_iterator itc =
+      std::map<api::Sort, TypeArgTrie>::const_iterator itc =
           tat->d_children.find(argTypes[i]);
       if (itc != tat->d_children.end()) {
         tat = &itc->second;
       } else {
-        // no functions match
-        return d_nullExpr;
-      }
-    }
-    // now, we must ensure that there is *only* one active symbol at this node
-    Expr retExpr;
-    for (std::map<Type, Expr>::const_iterator its = tat->d_symbols.begin();
-         its != tat->d_symbols.end(); ++its) {
-      Expr expr = its->second;
-      if (isOverloadedFunction(expr)) {
-        if (retExpr.isNull()) {
-          retExpr = expr;
-        } else {
-          // multiple functions match
-          return d_nullExpr;
-        }
+        Trace("parser-overloading")
+            << "Could not find overloaded function " << name << std::endl;
+
+          // no functions match
+        return d_nullTerm;
       }
     }
-    return retExpr;
+    // we ensure that there is *only* one active symbol at this node
+    return getOverloadedFunctionAt(tat);
   }
-  return d_nullExpr;
+  return d_nullTerm;
 }
 
-bool OverloadedTypeTrie::bind(const string& name, Expr prev_bound_obj,
-                              Expr obj) {
+bool OverloadedTypeTrie::bind(const string& name,
+                              api::Term prev_bound_obj,
+                              api::Term obj)
+{
   bool retprev = true;
   if (!isOverloadedFunction(prev_bound_obj)) {
     // mark previous as overloaded
@@ -177,25 +235,33 @@ bool OverloadedTypeTrie::bind(const string& name, Expr prev_bound_obj,
   return retprev && retobj;
 }
 
-bool OverloadedTypeTrie::markOverloaded(const string& name, Expr obj) {
+bool OverloadedTypeTrie::markOverloaded(const string& name, api::Term obj)
+{
   Trace("parser-overloading") << "Overloaded function : " << name;
-  Trace("parser-overloading") << " with type " << obj.getType() << std::endl;
+  Trace("parser-overloading") << " with type " << obj.getSort() << std::endl;
   // get the argument types
-  Type t = obj.getType();
-  Type rangeType = t;
-  std::vector<Type> argTypes;
-  if (t.isFunction()) {
-    argTypes = static_cast<FunctionType>(t).getArgTypes();
-    rangeType = static_cast<FunctionType>(t).getRangeType();
-  } else if (t.isConstructor()) {
-    argTypes = static_cast<ConstructorType>(t).getArgTypes();
-    rangeType = static_cast<ConstructorType>(t).getRangeType();
-  } else if (t.isTester()) {
-    argTypes.push_back(static_cast<TesterType>(t).getDomain());
-    rangeType = static_cast<TesterType>(t).getRangeType();
-  } else if (t.isSelector()) {
-    argTypes.push_back(static_cast<SelectorType>(t).getDomain());
-    rangeType = static_cast<SelectorType>(t).getRangeType();
+  api::Sort t = obj.getSort();
+  api::Sort rangeType = t;
+  std::vector<api::Sort> argTypes;
+  if (t.isFunction())
+  {
+    argTypes = t.getFunctionDomainSorts();
+    rangeType = t.getFunctionCodomainSort();
+  }
+  else if (t.isConstructor())
+  {
+    argTypes = t.getConstructorDomainSorts();
+    rangeType = t.getConstructorCodomainSort();
+  }
+  else if (t.isTester())
+  {
+    argTypes.push_back(t.getTesterDomainSort());
+    rangeType = t.getTesterCodomainSort();
+  }
+  else if (t.isSelector())
+  {
+    argTypes.push_back(t.getSelectorDomainSort());
+    rangeType = t.getSelectorCodomainSort();
   }
   // add to the trie
   TypeArgTrie* tat = &d_overload_type_arg_trie[name];
@@ -203,24 +269,33 @@ bool OverloadedTypeTrie::markOverloaded(const string& name, Expr obj) {
     tat = &(tat->d_children[argTypes[i]]);
   }
 
-  // types can be identical but vary on the kind of the type, thus we must
-  // distinguish based on this
-  std::map<Type, Expr>::iterator it = tat->d_symbols.find(rangeType);
-  if (it != tat->d_symbols.end()) {
-    Expr prev_obj = it->second;
-    // if there is already an active function with the same name and expects the
-    // same argument types
-    if (isOverloadedFunction(prev_obj)) {
-      if (prev_obj.getType() == obj.getType()) {
-        // types are identical, simply ignore it
-        return true;
-      } else {
-        // otherwise there is no way to distinguish these types, we return an
-        // error
+  // check if function variants are allowed here
+  if (d_allowFunctionVariants || argTypes.empty())
+  {
+    // they are allowed, check for redefinition
+    std::map<api::Sort, api::Term>::iterator it =
+        tat->d_symbols.find(rangeType);
+    if (it != tat->d_symbols.end())
+    {
+      api::Term prev_obj = it->second;
+      // if there is already an active function with the same name and expects
+      // the same argument types and has the same return type, we reject the 
+      // re-declaration here.
+      if (isOverloadedFunction(prev_obj))
+      {
         return false;
       }
     }
   }
+  else
+  {
+    // they are not allowed, we cannot have any function defined here.
+    api::Term existingFun = getOverloadedFunctionAt(tat, false);
+    if (!existingFun.isNull())
+    {
+      return false;
+    }
+  }
 
   // otherwise, update the symbols
   d_overloaded_symbols->insert(obj);
@@ -228,175 +303,186 @@ bool OverloadedTypeTrie::markOverloaded(const string& name, Expr obj) {
   return true;
 }
 
+api::Term OverloadedTypeTrie::getOverloadedFunctionAt(
+    const OverloadedTypeTrie::TypeArgTrie* tat, bool reqUnique) const
+{
+  api::Term retExpr;
+  for (std::map<api::Sort, api::Term>::const_iterator its =
+           tat->d_symbols.begin();
+       its != tat->d_symbols.end();
+       ++its)
+  {
+    api::Term expr = its->second;
+    if (isOverloadedFunction(expr))
+    {
+      if (retExpr.isNull())
+      {
+        if (!reqUnique) 
+        {
+          return expr;
+        }
+        else 
+        {
+          retExpr = expr;
+        }
+      }
+      else
+      {
+        // multiple functions match
+        return d_nullTerm;
+      }
+    }
+  }
+  return retExpr;
+}
+
 class SymbolTable::Implementation {
  public:
   Implementation()
       : d_context(),
-        d_exprMap(new (true) CDHashMap<string, Expr>(&d_context)),
-        d_typeMap(new (true) TypeMap(&d_context)),
-        d_functions(new (true) CDHashSet<Expr, ExprHashFunction>(&d_context)) {
-    d_overload_trie = new OverloadedTypeTrie(&d_context);
+        d_exprMap(&d_context),
+        d_typeMap(&d_context),
+        d_overload_trie(&d_context)
+  {
+    // use an outermost push, to be able to clear definitions not at level zero
+    d_context.push();
   }
 
-  ~Implementation() {
-    d_exprMap->deleteSelf();
-    d_typeMap->deleteSelf();
-    d_functions->deleteSelf();
-    delete d_overload_trie;
-  }
+  ~Implementation() { d_context.pop(); }
 
-  bool bind(const string& name, Expr obj, bool levelZero, bool doOverload);
-  bool bindDefinedFunction(const string& name, Expr obj, bool levelZero,
-                           bool doOverload);
-  void bindType(const string& name, Type t, bool levelZero = false);
-  void bindType(const string& name, const vector<Type>& params, Type t,
+  bool bind(const string& name, api::Term obj, bool levelZero, bool doOverload);
+  void bindType(const string& name, api::Sort t, bool levelZero = false);
+  void bindType(const string& name,
+                const vector<api::Sort>& params,
+                api::Sort t,
                 bool levelZero = false);
   bool isBound(const string& name) const;
-  bool isBoundDefinedFunction(const string& name) const;
-  bool isBoundDefinedFunction(Expr func) const;
   bool isBoundType(const string& name) const;
-  Expr lookup(const string& name) const;
-  Type lookupType(const string& name) const;
-  Type lookupType(const string& name, const vector<Type>& params) const;
+  api::Term lookup(const string& name) const;
+  api::Sort lookupType(const string& name) const;
+  api::Sort lookupType(const string& name,
+                       const vector<api::Sort>& params) const;
   size_t lookupArity(const string& name);
   void popScope();
   void pushScope();
   size_t getLevel() const;
   void reset();
+  void resetAssertions();
   //------------------------ operator overloading
   /** implementation of function from header */
-  bool isOverloadedFunction(Expr fun) const;
+  bool isOverloadedFunction(api::Term fun) const;
 
   /** implementation of function from header */
-  Expr getOverloadedConstantForType(const std::string& name, Type t) const;
+  api::Term getOverloadedConstantForType(const std::string& name,
+                                         api::Sort t) const;
 
   /** implementation of function from header */
-  Expr getOverloadedFunctionForTypes(const std::string& name,
-                                     const std::vector<Type>& argTypes) const;
+  api::Term getOverloadedFunctionForTypes(
+      const std::string& name, const std::vector<api::Sort>& argTypes) const;
   //------------------------ end operator overloading
  private:
   /** The context manager for the scope maps. */
   Context d_context;
 
   /** A map for expressions. */
-  CDHashMap<string, Expr>* d_exprMap;
+  CDHashMap<string, api::Term> d_exprMap;
 
   /** A map for types. */
-  using TypeMap = CDHashMap<string, std::pair<vector<Type>, Type>>;
-  TypeMap* d_typeMap;
-
-  /** A set of defined functions. */
-  CDHashSet<Expr, ExprHashFunction>* d_functions;
+  using TypeMap = CDHashMap<string, std::pair<vector<api::Sort>, api::Sort>>;
+  TypeMap d_typeMap;
 
   //------------------------ operator overloading
   // the null expression
-  Expr d_nullExpr;
+  api::Term d_nullTerm;
   // overloaded type trie, stores all information regarding overloading
-  OverloadedTypeTrie* d_overload_trie;
+  OverloadedTypeTrie d_overload_trie;
   /** bind with overloading
    * This is called whenever obj is bound to name where overloading symbols is
    * allowed. If a symbol is previously bound to that name, it marks both as
    * overloaded. Returns false if the binding was invalid.
    */
-  bool bindWithOverloading(const string& name, Expr obj);
+  bool bindWithOverloading(const string& name, api::Term obj);
   //------------------------ end operator overloading
 }; /* SymbolTable::Implementation */
 
-bool SymbolTable::Implementation::bind(const string& name, Expr obj,
-                                       bool levelZero, bool doOverload) {
-  PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null Expr");
-  ExprManagerScope ems(obj);
-  if (doOverload) {
-    if (!bindWithOverloading(name, obj)) {
-      return false;
-    }
-  }
-  if (levelZero) {
-    d_exprMap->insertAtContextLevelZero(name, obj);
-  } else {
-    d_exprMap->insert(name, obj);
-  }
-  return true;
-}
-
-bool SymbolTable::Implementation::bindDefinedFunction(const string& name,
-                                                      Expr obj, bool levelZero,
-                                                      bool doOverload) {
-  PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null Expr");
-  ExprManagerScope ems(obj);
+bool SymbolTable::Implementation::bind(const string& name,
+                                       api::Term obj,
+                                       bool levelZero,
+                                       bool doOverload)
+{
+  PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null api::Term");
+  Trace("sym-table") << "SymbolTable: bind " << name
+                     << ", levelZero=" << levelZero
+                     << ", doOverload=" << doOverload << std::endl;
   if (doOverload) {
     if (!bindWithOverloading(name, obj)) {
       return false;
     }
   }
   if (levelZero) {
-    d_exprMap->insertAtContextLevelZero(name, obj);
-    d_functions->insertAtContextLevelZero(obj);
+    d_exprMap.insertAtContextLevelZero(name, obj);
   } else {
-    d_exprMap->insert(name, obj);
-    d_functions->insert(obj);
+    d_exprMap.insert(name, obj);
   }
   return true;
 }
 
 bool SymbolTable::Implementation::isBound(const string& name) const {
-  return d_exprMap->find(name) != d_exprMap->end();
+  return d_exprMap.find(name) != d_exprMap.end();
 }
 
-bool SymbolTable::Implementation::isBoundDefinedFunction(
-    const string& name) const {
-  CDHashMap<string, Expr>::iterator found = d_exprMap->find(name);
-  return found != d_exprMap->end() && d_functions->contains((*found).second);
-}
-
-bool SymbolTable::Implementation::isBoundDefinedFunction(Expr func) const {
-  return d_functions->contains(func);
-}
-
-Expr SymbolTable::Implementation::lookup(const string& name) const {
+api::Term SymbolTable::Implementation::lookup(const string& name) const
+{
   Assert(isBound(name));
-  Expr expr = (*d_exprMap->find(name)).second;
+  api::Term expr = (*d_exprMap.find(name)).second;
   if (isOverloadedFunction(expr)) {
-    return d_nullExpr;
+    return d_nullTerm;
   } else {
     return expr;
   }
 }
 
-void SymbolTable::Implementation::bindType(const string& name, Type t,
-                                           bool levelZero) {
+void SymbolTable::Implementation::bindType(const string& name,
+                                           api::Sort t,
+                                           bool levelZero)
+{
   if (levelZero) {
-    d_typeMap->insertAtContextLevelZero(name, make_pair(vector<Type>(), t));
+    d_typeMap.insertAtContextLevelZero(name, make_pair(vector<api::Sort>(), t));
   } else {
-    d_typeMap->insert(name, make_pair(vector<Type>(), t));
+    d_typeMap.insert(name, make_pair(vector<api::Sort>(), t));
   }
 }
 
 void SymbolTable::Implementation::bindType(const string& name,
-                                           const vector<Type>& params, Type t,
-                                           bool levelZero) {
+                                           const vector<api::Sort>& params,
+                                           api::Sort t,
+                                           bool levelZero)
+{
   if (Debug.isOn("sort")) {
     Debug("sort") << "bindType(" << name << ", [";
     if (params.size() > 0) {
-      copy(params.begin(), params.end() - 1,
-           ostream_iterator<Type>(Debug("sort"), ", "));
+      copy(params.begin(),
+           params.end() - 1,
+           ostream_iterator<api::Sort>(Debug("sort"), ", "));
       Debug("sort") << params.back();
     }
     Debug("sort") << "], " << t << ")" << endl;
   }
   if (levelZero) {
-    d_typeMap->insertAtContextLevelZero(name, make_pair(params, t));
+    d_typeMap.insertAtContextLevelZero(name, make_pair(params, t));
   } else {
-    d_typeMap->insert(name, make_pair(params, t));
+    d_typeMap.insert(name, make_pair(params, t));
   }
 }
 
 bool SymbolTable::Implementation::isBoundType(const string& name) const {
-  return d_typeMap->find(name) != d_typeMap->end();
+  return d_typeMap.find(name) != d_typeMap.end();
 }
 
-Type SymbolTable::Implementation::lookupType(const string& name) const {
-  pair<vector<Type>, Type> p = (*d_typeMap->find(name)).second;
+api::Sort SymbolTable::Implementation::lookupType(const string& name) const
+{
+  std::pair<std::vector<api::Sort>, api::Sort> p =
+      (*d_typeMap.find(name)).second;
   PrettyCheckArgument(p.first.size() == 0, name,
                       "type constructor arity is wrong: "
                       "`%s' requires %u parameters but was provided 0",
@@ -404,69 +490,62 @@ Type SymbolTable::Implementation::lookupType(const string& name) const {
   return p.second;
 }
 
-Type SymbolTable::Implementation::lookupType(const string& name,
-                                             const vector<Type>& params) const {
-  pair<vector<Type>, Type> p = (*d_typeMap->find(name)).second;
+api::Sort SymbolTable::Implementation::lookupType(
+    const string& name, const vector<api::Sort>& params) const
+{
+  std::pair<std::vector<api::Sort>, api::Sort> p =
+      (*d_typeMap.find(name)).second;
   PrettyCheckArgument(p.first.size() == params.size(), params,
                       "type constructor arity is wrong: "
                       "`%s' requires %u parameters but was provided %u",
                       name.c_str(), p.first.size(), params.size());
   if (p.first.size() == 0) {
-    PrettyCheckArgument(p.second.isSort(), name.c_str());
+    PrettyCheckArgument(p.second.isUninterpretedSort(), name.c_str());
     return p.second;
   }
-  if (p.second.isSortConstructor()) {
-    if (Debug.isOn("sort")) {
-      Debug("sort") << "instantiating using a sort constructor" << endl;
-      Debug("sort") << "have formals [";
-      copy(p.first.begin(), p.first.end() - 1,
-           ostream_iterator<Type>(Debug("sort"), ", "));
-      Debug("sort") << p.first.back() << "]" << endl << "parameters   [";
-      copy(params.begin(), params.end() - 1,
-           ostream_iterator<Type>(Debug("sort"), ", "));
-      Debug("sort") << params.back() << "]" << endl
-                    << "type ctor    " << name << endl
-                    << "type is      " << p.second << endl;
-    }
-
-    Type instantiation = SortConstructorType(p.second).instantiate(params);
-
-    Debug("sort") << "instance is  " << instantiation << endl;
-
-    return instantiation;
-  } else if (p.second.isDatatype()) {
-    PrettyCheckArgument(DatatypeType(p.second).isParametric(), name,
-                        "expected parametric datatype");
-    return DatatypeType(p.second).instantiate(params);
-  } else {
-    if (Debug.isOn("sort")) {
-      Debug("sort") << "instantiating using a sort substitution" << endl;
-      Debug("sort") << "have formals [";
-      copy(p.first.begin(), p.first.end() - 1,
-           ostream_iterator<Type>(Debug("sort"), ", "));
-      Debug("sort") << p.first.back() << "]" << endl << "parameters   [";
-      copy(params.begin(), params.end() - 1,
-           ostream_iterator<Type>(Debug("sort"), ", "));
-      Debug("sort") << params.back() << "]" << endl
-                    << "type ctor    " << name << endl
-                    << "type is      " << p.second << endl;
-    }
-
-    Type instantiation = p.second.substitute(p.first, params);
+  if (p.second.isDatatype())
+  {
+    PrettyCheckArgument(
+        p.second.isParametricDatatype(), name, "expected parametric datatype");
+    return p.second.instantiate(params);
+  }
+  bool isSortConstructor = p.second.isSortConstructor();
+  if (Debug.isOn("sort"))
+  {
+    Debug("sort") << "instantiating using a sort "
+                  << (isSortConstructor ? "constructor" : "substitution")
+                  << std::endl;
+    Debug("sort") << "have formals [";
+    copy(p.first.begin(),
+         p.first.end() - 1,
+         ostream_iterator<api::Sort>(Debug("sort"), ", "));
+    Debug("sort") << p.first.back() << "]" << std::endl << "parameters   [";
+    copy(params.begin(),
+         params.end() - 1,
+         ostream_iterator<api::Sort>(Debug("sort"), ", "));
+    Debug("sort") << params.back() << "]" << endl
+                  << "type ctor    " << name << std::endl
+                  << "type is      " << p.second << std::endl;
+  }
+  api::Sort instantiation = isSortConstructor
+                                ? p.second.instantiate(params)
+                                : p.second.substitute(p.first, params);
 
-    Debug("sort") << "instance is  " << instantiation << endl;
+  Debug("sort") << "instance is  " << instantiation << std::endl;
 
-    return instantiation;
-  }
+  return instantiation;
 }
 
 size_t SymbolTable::Implementation::lookupArity(const string& name) {
-  pair<vector<Type>, Type> p = (*d_typeMap->find(name)).second;
+  std::pair<std::vector<api::Sort>, api::Sort> p =
+      (*d_typeMap.find(name)).second;
   return p.first.size();
 }
 
 void SymbolTable::Implementation::popScope() {
-  if (d_context.getLevel() == 0) {
+  // should not pop beyond level one
+  if (d_context.getLevel() == 1)
+  {
     throw ScopeException();
   }
   d_context.pop();
@@ -479,111 +558,125 @@ size_t SymbolTable::Implementation::getLevel() const {
 }
 
 void SymbolTable::Implementation::reset() {
+  Trace("sym-table") << "SymbolTable: reset" << std::endl;
   this->SymbolTable::Implementation::~Implementation();
   new (this) SymbolTable::Implementation();
 }
 
-bool SymbolTable::Implementation::isOverloadedFunction(Expr fun) const {
-  return d_overload_trie->isOverloadedFunction(fun);
+void SymbolTable::Implementation::resetAssertions()
+{
+  Trace("sym-table") << "SymbolTable: resetAssertions" << std::endl;
+  // pop all contexts
+  while (d_context.getLevel() > 0)
+  {
+    d_context.pop();
+  }
+  d_context.push();
+}
+
+bool SymbolTable::Implementation::isOverloadedFunction(api::Term fun) const
+{
+  return d_overload_trie.isOverloadedFunction(fun);
 }
 
-Expr SymbolTable::Implementation::getOverloadedConstantForType(
-    const std::string& name, Type t) const {
-  return d_overload_trie->getOverloadedConstantForType(name, t);
+api::Term SymbolTable::Implementation::getOverloadedConstantForType(
+    const std::string& name, api::Sort t) const
+{
+  return d_overload_trie.getOverloadedConstantForType(name, t);
 }
 
-Expr SymbolTable::Implementation::getOverloadedFunctionForTypes(
-    const std::string& name, const std::vector<Type>& argTypes) const {
-  return d_overload_trie->getOverloadedFunctionForTypes(name, argTypes);
+api::Term SymbolTable::Implementation::getOverloadedFunctionForTypes(
+    const std::string& name, const std::vector<api::Sort>& argTypes) const
+{
+  return d_overload_trie.getOverloadedFunctionForTypes(name, argTypes);
 }
 
 bool SymbolTable::Implementation::bindWithOverloading(const string& name,
-                                                      Expr obj) {
-  CDHashMap<string, Expr>::const_iterator it = d_exprMap->find(name);
-  if (it != d_exprMap->end()) {
-    const Expr& prev_bound_obj = (*it).second;
+                                                      api::Term obj)
+{
+  CDHashMap<string, api::Term>::const_iterator it = d_exprMap.find(name);
+  if (it != d_exprMap.end())
+  {
+    const api::Term& prev_bound_obj = (*it).second;
     if (prev_bound_obj != obj) {
-      return d_overload_trie->bind(name, prev_bound_obj, obj);
+      return d_overload_trie.bind(name, prev_bound_obj, obj);
     }
   }
   return true;
 }
 
-bool SymbolTable::isOverloadedFunction(Expr fun) const {
+bool SymbolTable::isOverloadedFunction(api::Term fun) const
+{
   return d_implementation->isOverloadedFunction(fun);
 }
 
-Expr SymbolTable::getOverloadedConstantForType(const std::string& name,
-                                               Type t) const {
+api::Term SymbolTable::getOverloadedConstantForType(const std::string& name,
+                                                    api::Sort t) const
+{
   return d_implementation->getOverloadedConstantForType(name, t);
 }
 
-Expr SymbolTable::getOverloadedFunctionForTypes(
-    const std::string& name, const std::vector<Type>& argTypes) const {
+api::Term SymbolTable::getOverloadedFunctionForTypes(
+    const std::string& name, const std::vector<api::Sort>& argTypes) const
+{
   return d_implementation->getOverloadedFunctionForTypes(name, argTypes);
 }
 
-SymbolTable::SymbolTable()
-    : d_implementation(new SymbolTable::Implementation()) {}
+SymbolTable::SymbolTable() : d_implementation(new SymbolTable::Implementation())
+{
+}
 
 SymbolTable::~SymbolTable() {}
-
-bool SymbolTable::bind(const string& name, Expr obj, bool levelZero,
-                       bool doOverload) throw() {
+bool SymbolTable::bind(const string& name,
+                       api::Term obj,
+                       bool levelZero,
+                       bool doOverload)
+{
   return d_implementation->bind(name, obj, levelZero, doOverload);
 }
 
-bool SymbolTable::bindDefinedFunction(const string& name, Expr obj,
-                                      bool levelZero, bool doOverload) throw() {
-  return d_implementation->bindDefinedFunction(name, obj, levelZero,
-                                               doOverload);
-}
-
-void SymbolTable::bindType(const string& name, Type t, bool levelZero) throw() {
+void SymbolTable::bindType(const string& name, api::Sort t, bool levelZero)
+{
   d_implementation->bindType(name, t, levelZero);
 }
 
-void SymbolTable::bindType(const string& name, const vector<Type>& params,
-                           Type t, bool levelZero) throw() {
+void SymbolTable::bindType(const string& name,
+                           const vector<api::Sort>& params,
+                           api::Sort t,
+                           bool levelZero)
+{
   d_implementation->bindType(name, params, t, levelZero);
 }
 
-bool SymbolTable::isBound(const string& name) const throw() {
+bool SymbolTable::isBound(const string& name) const
+{
   return d_implementation->isBound(name);
 }
-
-bool SymbolTable::isBoundDefinedFunction(const string& name) const throw() {
-  return d_implementation->isBoundDefinedFunction(name);
-}
-
-bool SymbolTable::isBoundDefinedFunction(Expr func) const throw() {
-  return d_implementation->isBoundDefinedFunction(func);
-}
-bool SymbolTable::isBoundType(const string& name) const throw() {
+bool SymbolTable::isBoundType(const string& name) const
+{
   return d_implementation->isBoundType(name);
 }
-Expr SymbolTable::lookup(const string& name) const throw() {
+api::Term SymbolTable::lookup(const string& name) const
+{
   return d_implementation->lookup(name);
 }
-Type SymbolTable::lookupType(const string& name) const throw() {
+api::Sort SymbolTable::lookupType(const string& name) const
+{
   return d_implementation->lookupType(name);
 }
 
-Type SymbolTable::lookupType(const string& name,
-                             const vector<Type>& params) const throw() {
+api::Sort SymbolTable::lookupType(const string& name,
+                                  const vector<api::Sort>& params) const
+{
   return d_implementation->lookupType(name, params);
 }
 size_t SymbolTable::lookupArity(const string& name) {
   return d_implementation->lookupArity(name);
 }
-void SymbolTable::popScope() throw(ScopeException) {
-  d_implementation->popScope();
-}
-
-void SymbolTable::pushScope() throw() { d_implementation->pushScope(); }
-size_t SymbolTable::getLevel() const throw() {
-  return d_implementation->getLevel();
-}
+void SymbolTable::popScope() { d_implementation->popScope(); }
+void SymbolTable::pushScope() { d_implementation->pushScope(); }
+size_t SymbolTable::getLevel() const { return d_implementation->getLevel(); }
 void SymbolTable::reset() { d_implementation->reset(); }
+void SymbolTable::resetAssertions() { d_implementation->resetAssertions(); }
 
-}  // namespace CVC4
+}  // namespace cvc5