Add tuple projection operator (#5904)
[cvc5.git] / src / parser / smt2 / smt2.cpp
index a8a2eb27a5cc133a2a11bb0c42ba18217ff1d208..049bf8b4d090b00fd8a216cb896800703708225f 100644 (file)
@@ -18,7 +18,6 @@
 #include <algorithm>
 
 #include "base/check.h"
-#include "expr/type.h"
 #include "options/options.h"
 #include "parser/antlr_input.h"
 #include "parser/parser.h"
 namespace CVC4 {
 namespace parser {
 
-Smt2::Smt2(api::Solver* solver, Input* input, bool strictMode, bool parseOnly)
-    : Parser(solver, input, strictMode, parseOnly),
+Smt2::Smt2(api::Solver* solver,
+           SymbolManager* sm,
+           Input* input,
+           bool strictMode,
+           bool parseOnly)
+    : Parser(solver, sm, input, strictMode, parseOnly),
       d_logicSet(false),
       d_seenSetLogic(false)
 {
-  pushScope(true);
 }
 
-Smt2::~Smt2() { popScope(); }
+Smt2::~Smt2() {}
 
 void Smt2::addArithmeticOperators() {
   addOperator(api::PLUS, "+");
@@ -80,10 +82,6 @@ void Smt2::addTranscendentalOperators()
 
 void Smt2::addQuantifiersOperators()
 {
-  if (!strictModeEnabled())
-  {
-    addOperator(api::INST_CLOSURE, "inst-closure");
-  }
 }
 
 void Smt2::addBitvectorOperators() {
@@ -441,17 +439,16 @@ api::Term Smt2::bindDefineFunRec(
   api::Sort ft = mkFlatFunctionType(sorts, t, flattenVars);
 
   // allow overloading
-  return bindVar(fname, ft, ExprManager::VAR_FLAG_NONE, true);
+  return bindVar(fname, ft, false, true);
 }
 
 void Smt2::pushDefineFunRecScope(
     const std::vector<std::pair<std::string, api::Sort>>& sortedVarNames,
     api::Term func,
     const std::vector<api::Term>& flattenVars,
-    std::vector<api::Term>& bvs,
-    bool bindingLevel)
+    std::vector<api::Term>& bvs)
 {
-  pushScope(bindingLevel);
+  pushScope();
 
   // bound variables are those that are explicitly named in the preamble
   // of the define-fun(s)-rec command, we define them here
@@ -470,16 +467,6 @@ void Smt2::reset() {
   d_logic = LogicInfo();
   operatorKindMap.clear();
   d_lastNamedTerm = std::pair<api::Term, std::string>();
-  this->Parser::reset();
-  pushScope(true);
-}
-
-void Smt2::resetAssertions() {
-  // Remove all declarations except the ones at level 0.
-  while (this->scopeLevel() > 0) {
-    this->popScope();
-  }
-  pushScope(true);
 }
 
 std::unique_ptr<Command> Smt2::invConstraint(
@@ -559,9 +546,12 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
     if(d_logic.areIntegersUsed()) {
       defineType("Int", d_solver->getIntegerSort(), true, true);
       addArithmeticOperators();
-      addOperator(api::INTS_DIVISION, "div");
-      addOperator(api::INTS_MODULUS, "mod");
-      addOperator(api::ABS, "abs");
+      if (!strictModeEnabled() || !d_logic.isLinear())
+      {
+        addOperator(api::INTS_DIVISION, "div");
+        addOperator(api::INTS_MODULUS, "mod");
+        addOperator(api::ABS, "abs");
+      }
       addIndexedOperator(api::DIVISIBLE, api::DIVISIBLE, "divisible");
     }
 
@@ -651,7 +641,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
     addOperator(api::INTERSECTION_MIN, "intersection_min");
     addOperator(api::DIFFERENCE_SUBTRACT, "difference_subtract");
     addOperator(api::DIFFERENCE_REMOVE, "difference_remove");
-    addOperator(api::SUBBAG, "bag.is_included");
+    addOperator(api::SUBBAG, "subbag");
     addOperator(api::BAG_COUNT, "bag.count");
     addOperator(api::DUPLICATE_REMOVAL, "duplicate_removal");
     addOperator(api::MK_BAG, "bag");
@@ -747,14 +737,6 @@ bool Smt2::sygus_v2() const
   return getLanguage() == language::input::LANG_SYGUS_V2;
 }
 
-void Smt2::setInfo(const std::string& flag, const SExpr& sexpr) {
-  // TODO: ???
-}
-
-void Smt2::setOption(const std::string& flag, const SExpr& sexpr) {
-  // TODO: ???
-}
-
 void Smt2::checkThatLogicIsSet()
 {
   if (!logicIsSet())
@@ -1088,12 +1070,11 @@ api::Term Smt2::applyParseOp(ParseOp& p, std::vector<api::Term>& args)
   else if (p.d_kind == api::APPLY_SELECTOR && !p.d_expr.isNull())
   {
     // tuple selector case
-    Integer x = p.d_expr.getExpr().getConst<Rational>().getNumerator();
-    if (!x.fitsUnsignedInt())
+    if (!p.d_expr.isUInt64())
     {
-      parseError("index of tupSel is larger than size of unsigned int");
+      parseError("index of tupSel is larger than size of uint64_t");
     }
-    unsigned int n = x.toUnsignedInt();
+    uint64_t n = p.d_expr.getUInt64();
     if (args.size() != 1)
     {
       parseError("tupSel should only be applied to one tuple argument");
@@ -1116,6 +1097,12 @@ api::Term Smt2::applyParseOp(ParseOp& p, std::vector<api::Term>& args)
     Debug("parser") << "applyParseOp: return selector " << ret << std::endl;
     return ret;
   }
+  else if (p.d_kind == api::TUPLE_PROJECT)
+  {
+    api::Term ret = d_solver->mkTerm(p.d_op, args[0]);
+    Debug("parser") << "applyParseOp: return projection " << ret << std::endl;
+    return ret;
+  }
   else if (p.d_kind != api::NULL_EXPR)
   {
     // it should not have an expression or type specified at this point
@@ -1215,28 +1202,16 @@ api::Term Smt2::applyParseOp(ParseOp& p, std::vector<api::Term>& args)
   return ret;
 }
 
-api::Term Smt2::setNamedAttribute(api::Term& expr, const SExpr& sexpr)
+void Smt2::notifyNamedExpression(api::Term& expr, std::string name)
 {
-  if (!sexpr.isKeyword())
-  {
-    parseError("improperly formed :named annotation");
-  }
-  std::string name = sexpr.getValue();
   checkUserSymbol(name);
-  // ensure expr is a closed subterm
-  if (expr.getExpr().hasFreeVariable())
-  {
-    std::stringstream ss;
-    ss << ":named annotations can only name terms that are closed";
-    parseError(ss.str());
-  }
-  // check that sexpr is a fresh function symbol, and reserve it
-  reserveSymbolAtAssertionLevel(name);
-  // define it
-  api::Term func = bindVar(name, expr.getSort(), ExprManager::VAR_FLAG_DEFINED);
-  // remember the last term to have been given a :named attribute
+  // remember the expression name in the symbol manager
+  getSymbolManager()->setExpressionName(expr, name, false);
+  // define the variable
+  defineVar(name, expr);
+  // set the last named term, which ensures that we catch when assertions are
+  // named
   setLastNamedTerm(expr, name);
-  return func;
 }
 
 api::Term Smt2::mkAnd(const std::vector<api::Term>& es)