Add tuple projection operator (#5904)
[cvc5.git] / src / parser / smt2 / smt2.cpp
index 1e5d2155a0ed1fafec20b3bb825f8c067b950efa..049bf8b4d090b00fd8a216cb896800703708225f 100644 (file)
@@ -5,7 +5,7 @@
  **   Andrew Reynolds, Andres Noetzli, Morgan Deters
  ** This file is part of the CVC4 project.
  ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
+ ** in the top-level source directory and their institutional affiliations.
  ** All rights reserved.  See the file COPYING in the top-level source
  ** directory for licensing information.\endverbatim
  **
@@ -18,7 +18,6 @@
 #include <algorithm>
 
 #include "base/check.h"
-#include "expr/type.h"
 #include "options/options.h"
 #include "parser/antlr_input.h"
 #include "parser/parser.h"
 namespace CVC4 {
 namespace parser {
 
-Smt2::Smt2(api::Solver* solver, Input* input, bool strictMode, bool parseOnly)
-    : Parser(solver, input, strictMode, parseOnly),
+Smt2::Smt2(api::Solver* solver,
+           SymbolManager* sm,
+           Input* input,
+           bool strictMode,
+           bool parseOnly)
+    : Parser(solver, sm, input, strictMode, parseOnly),
       d_logicSet(false),
       d_seenSetLogic(false)
 {
-  if (!strictModeEnabled())
-  {
-    addCoreSymbols();
-  }
 }
 
+Smt2::~Smt2() {}
+
 void Smt2::addArithmeticOperators() {
   addOperator(api::PLUS, "+");
   addOperator(api::MINUS, "-");
@@ -81,10 +82,6 @@ void Smt2::addTranscendentalOperators()
 
 void Smt2::addQuantifiersOperators()
 {
-  if (!strictModeEnabled())
-  {
-    addOperator(api::INST_CLOSURE, "inst-closure");
-  }
 }
 
 void Smt2::addBitvectorOperators() {
@@ -288,9 +285,9 @@ void Smt2::addSepOperators() {
 
 void Smt2::addCoreSymbols()
 {
-  defineType("Bool", d_solver->getBooleanSort());
-  defineVar("true", d_solver->mkTrue());
-  defineVar("false", d_solver->mkFalse());
+  defineType("Bool", d_solver->getBooleanSort(), true, true);
+  defineVar("true", d_solver->mkTrue(), true, true);
+  defineVar("false", d_solver->mkFalse(), true, true);
   addOperator(api::AND, "and");
   addOperator(api::DISTINCT, "distinct");
   addOperator(api::EQUAL, "=");
@@ -442,17 +439,16 @@ api::Term Smt2::bindDefineFunRec(
   api::Sort ft = mkFlatFunctionType(sorts, t, flattenVars);
 
   // allow overloading
-  return bindVar(fname, ft, ExprManager::VAR_FLAG_NONE, true);
+  return bindVar(fname, ft, false, true);
 }
 
 void Smt2::pushDefineFunRecScope(
     const std::vector<std::pair<std::string, api::Sort>>& sortedVarNames,
     api::Term func,
     const std::vector<api::Term>& flattenVars,
-    std::vector<api::Term>& bvs,
-    bool bindingLevel)
+    std::vector<api::Term>& bvs)
 {
-  pushScope(bindingLevel);
+  pushScope();
 
   // bound variables are those that are explicitly named in the preamble
   // of the define-fun(s)-rec command, we define them here
@@ -471,62 +467,6 @@ void Smt2::reset() {
   d_logic = LogicInfo();
   operatorKindMap.clear();
   d_lastNamedTerm = std::pair<api::Term, std::string>();
-  this->Parser::reset();
-
-  if( !strictModeEnabled() ) {
-    addCoreSymbols();
-  }
-}
-
-void Smt2::resetAssertions() {
-  // Remove all declarations except the ones at level 0.
-  while (this->scopeLevel() > 0) {
-    this->popScope();
-  }
-}
-
-Smt2::SynthFunFactory::SynthFunFactory(
-    Smt2* smt2,
-    const std::string& id,
-    bool isInv,
-    api::Sort range,
-    std::vector<std::pair<std::string, api::Sort>>& sortedVarNames)
-    : d_smt2(smt2), d_id(id), d_sort(range), d_isInv(isInv)
-{
-  if (range.isNull())
-  {
-    smt2->parseError("Must supply return type for synth-fun.");
-  }
-  if (range.isFunction())
-  {
-    smt2->parseError("Cannot use synth-fun with function return type.");
-  }
-
-  std::vector<api::Sort> varSorts;
-  for (const std::pair<std::string, api::Sort>& p : sortedVarNames)
-  {
-    varSorts.push_back(p.second);
-  }
-
-  api::Sort funSort = varSorts.empty()
-                          ? range
-                          : d_smt2->d_solver->mkFunctionSort(varSorts, range);
-
-  // we do not allow overloading for synth fun
-  d_fun = d_smt2->bindBoundVar(id, funSort);
-
-  Debug("parser-sygus") << "Define synth fun : " << id << std::endl;
-
-  d_smt2->pushScope(true);
-  d_sygusVars = d_smt2->bindBoundVars(sortedVarNames);
-}
-
-std::unique_ptr<Command> Smt2::SynthFunFactory::mkCommand(api::Grammar* grammar)
-{
-  Debug("parser-sygus") << "...read synth fun " << d_id << std::endl;
-  d_smt2->popScope();
-  return std::unique_ptr<Command>(new SynthFunCommand(
-      d_smt2->d_solver, d_id, d_fun, d_sygusVars, d_sort, d_isInv, grammar));
 }
 
 std::unique_ptr<Command> Smt2::invConstraint(
@@ -556,8 +496,7 @@ std::unique_ptr<Command> Smt2::invConstraint(
     terms.push_back(getVariable(name));
   }
 
-  return std::unique_ptr<Command>(
-      new SygusInvConstraintCommand(api::termVectorToExprs(terms)));
+  return std::unique_ptr<Command>(new SygusInvConstraintCommand(terms));
 }
 
 Command* Smt2::setLogic(std::string name, bool fromCommand)
@@ -605,17 +544,20 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
 
   if(d_logic.isTheoryEnabled(theory::THEORY_ARITH)) {
     if(d_logic.areIntegersUsed()) {
-      defineType("Int", d_solver->getIntegerSort());
+      defineType("Int", d_solver->getIntegerSort(), true, true);
       addArithmeticOperators();
-      addOperator(api::INTS_DIVISION, "div");
-      addOperator(api::INTS_MODULUS, "mod");
-      addOperator(api::ABS, "abs");
+      if (!strictModeEnabled() || !d_logic.isLinear())
+      {
+        addOperator(api::INTS_DIVISION, "div");
+        addOperator(api::INTS_MODULUS, "mod");
+        addOperator(api::ABS, "abs");
+      }
       addIndexedOperator(api::DIVISIBLE, api::DIVISIBLE, "divisible");
     }
 
     if (d_logic.areRealsUsed())
     {
-      defineType("Real", d_solver->getRealSort());
+      defineType("Real", d_solver->getRealSort(), true, true);
       addArithmeticOperators();
       addOperator(api::DIVISION, "/");
       if (!strictModeEnabled())
@@ -664,7 +606,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
 
   if(d_logic.isTheoryEnabled(theory::THEORY_DATATYPES)) {
     const std::vector<api::Sort> types;
-    defineType("Tuple", d_solver->mkTupleSort(types));
+    defineType("Tuple", d_solver->mkTupleSort(types), true, true);
     addDatatypesOperators();
   }
 
@@ -684,16 +626,35 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
     addOperator(api::CARD, "card");
     addOperator(api::COMPLEMENT, "complement");
     addOperator(api::CHOOSE, "choose");
+    addOperator(api::IS_SINGLETON, "is_singleton");
     addOperator(api::JOIN, "join");
     addOperator(api::PRODUCT, "product");
     addOperator(api::TRANSPOSE, "transpose");
     addOperator(api::TCLOSURE, "tclosure");
   }
 
+  if (d_logic.isTheoryEnabled(theory::THEORY_BAGS))
+  {
+    defineVar("emptybag", d_solver->mkEmptyBag(d_solver->getNullSort()));
+    addOperator(api::UNION_MAX, "union_max");
+    addOperator(api::UNION_DISJOINT, "union_disjoint");
+    addOperator(api::INTERSECTION_MIN, "intersection_min");
+    addOperator(api::DIFFERENCE_SUBTRACT, "difference_subtract");
+    addOperator(api::DIFFERENCE_REMOVE, "difference_remove");
+    addOperator(api::SUBBAG, "subbag");
+    addOperator(api::BAG_COUNT, "bag.count");
+    addOperator(api::DUPLICATE_REMOVAL, "duplicate_removal");
+    addOperator(api::MK_BAG, "bag");
+    addOperator(api::BAG_CARD, "bag.card");
+    addOperator(api::BAG_CHOOSE, "bag.choose");
+    addOperator(api::BAG_IS_SINGLETON, "bag.is_singleton");
+    addOperator(api::BAG_FROM_SET, "bag.from_set");
+    addOperator(api::BAG_TO_SET, "bag.to_set");
+  }
   if(d_logic.isTheoryEnabled(theory::THEORY_STRINGS)) {
-    defineType("String", d_solver->getStringSort());
-    defineType("RegLan", d_solver->getRegExpSort());
-    defineType("Int", d_solver->getIntegerSort());
+    defineType("String", d_solver->getStringSort(), true, true);
+    defineType("RegLan", d_solver->getRegExpSort(), true, true);
+    defineType("Int", d_solver->getIntegerSort(), true, true);
 
     if (getLanguage() == language::input::LANG_SMTLIB_V2_6
         || getLanguage() == language::input::LANG_SYGUS_V2)
@@ -718,11 +679,11 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
   }
 
   if (d_logic.isTheoryEnabled(theory::THEORY_FP)) {
-    defineType("RoundingMode", d_solver->getRoundingmodeSort());
-    defineType("Float16", d_solver->mkFloatingPointSort(5, 11));
-    defineType("Float32", d_solver->mkFloatingPointSort(8, 24));
-    defineType("Float64", d_solver->mkFloatingPointSort(11, 53));
-    defineType("Float128", d_solver->mkFloatingPointSort(15, 113));
+    defineType("RoundingMode", d_solver->getRoundingModeSort(), true, true);
+    defineType("Float16", d_solver->mkFloatingPointSort(5, 11), true, true);
+    defineType("Float32", d_solver->mkFloatingPointSort(8, 24), true, true);
+    defineType("Float64", d_solver->mkFloatingPointSort(11, 53), true, true);
+    defineType("Float128", d_solver->mkFloatingPointSort(15, 113), true, true);
 
     defineVar("RNE", d_solver->mkRoundingMode(api::ROUND_NEAREST_TIES_TO_EVEN));
     defineVar("roundNearestTiesToEven",
@@ -760,8 +721,8 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
 api::Grammar* Smt2::mkGrammar(const std::vector<api::Term>& boundVars,
                               const std::vector<api::Term>& ntSymbols)
 {
-  d_allocGrammars.emplace_back(new api::Grammar(
-      std::move(d_solver->mkSygusGrammar(boundVars, ntSymbols))));
+  d_allocGrammars.emplace_back(
+      new api::Grammar(d_solver->mkSygusGrammar(boundVars, ntSymbols)));
   return d_allocGrammars.back().get();
 }
 
@@ -776,14 +737,6 @@ bool Smt2::sygus_v2() const
   return getLanguage() == language::input::LANG_SYGUS_V2;
 }
 
-void Smt2::setInfo(const std::string& flag, const SExpr& sexpr) {
-  // TODO: ???
-}
-
-void Smt2::setOption(const std::string& flag, const SExpr& sexpr) {
-  // TODO: ???
-}
-
 void Smt2::checkThatLogicIsSet()
 {
   if (!logicIsSet())
@@ -820,7 +773,8 @@ void Smt2::checkLogicAllowsFreeSorts()
   if (!d_logic.isTheoryEnabled(theory::THEORY_UF)
       && !d_logic.isTheoryEnabled(theory::THEORY_ARRAYS)
       && !d_logic.isTheoryEnabled(theory::THEORY_DATATYPES)
-      && !d_logic.isTheoryEnabled(theory::THEORY_SETS))
+      && !d_logic.isTheoryEnabled(theory::THEORY_SETS)
+      && !d_logic.isTheoryEnabled(theory::THEORY_BAGS))
   {
     parseErrorLogic("Free sort symbols not allowed in ");
   }
@@ -1083,33 +1037,22 @@ api::Term Smt2::applyParseOp(ParseOp& p, std::vector<api::Term>& args)
       parseError("Too many arguments to array constant.");
     }
     api::Term constVal = args[0];
-    if (!constVal.isConst())
+
+    // To parse array constants taking reals whose values are specified by
+    // rationals, e.g. ((as const (Array Int Real)) (/ 1 3)), we must handle
+    // the fact that (/ 1 3) is the division of constants 1 and 3, and not
+    // the resulting constant rational value. Thus, we must construct the
+    // resulting rational here. This also is applied for integral real values
+    // like 5.0 which are converted to (/ 5 1) to distinguish them from
+    // integer constants. We must ensure numerator and denominator are
+    // constant and the denominator is non-zero.
+    if (constVal.getKind() == api::DIVISION)
     {
-      // To parse array constants taking reals whose values are specified by
-      // rationals, e.g. ((as const (Array Int Real)) (/ 1 3)), we must handle
-      // the fact that (/ 1 3) is the division of constants 1 and 3, and not
-      // the resulting constant rational value. Thus, we must construct the
-      // resulting rational here. This also is applied for integral real values
-      // like 5.0 which are converted to (/ 5 1) to distinguish them from
-      // integer constants. We must ensure numerator and denominator are
-      // constant and the denominator is non-zero.
-      if (constVal.getKind() == api::DIVISION && constVal[0].isConst()
-          && constVal[1].isConst()
-          && !constVal[1].getExpr().getConst<Rational>().isZero())
-      {
-        std::stringstream sdiv;
-        sdiv << constVal[0] << "/" << constVal[1];
-        constVal = d_solver->mkReal(sdiv.str());
-      }
-      if (!constVal.isConst())
-      {
-        std::stringstream ss;
-        ss << "expected constant term inside array constant, but found "
-           << "nonconstant term:" << std::endl
-           << "the term: " << constVal;
-        parseError(ss.str());
-      }
+      std::stringstream sdiv;
+      sdiv << constVal[0] << "/" << constVal[1];
+      constVal = d_solver->mkReal(sdiv.str());
     }
+
     if (!p.d_type.getArrayElementSort().isComparableTo(constVal.getSort()))
     {
       std::stringstream ss;
@@ -1127,12 +1070,11 @@ api::Term Smt2::applyParseOp(ParseOp& p, std::vector<api::Term>& args)
   else if (p.d_kind == api::APPLY_SELECTOR && !p.d_expr.isNull())
   {
     // tuple selector case
-    Integer x = p.d_expr.getExpr().getConst<Rational>().getNumerator();
-    if (!x.fitsUnsignedInt())
+    if (!p.d_expr.isUInt64())
     {
-      parseError("index of tupSel is larger than size of unsigned int");
+      parseError("index of tupSel is larger than size of uint64_t");
     }
-    unsigned int n = x.toUnsignedInt();
+    uint64_t n = p.d_expr.getUInt64();
     if (args.size() != 1)
     {
       parseError("tupSel should only be applied to one tuple argument");
@@ -1155,6 +1097,12 @@ api::Term Smt2::applyParseOp(ParseOp& p, std::vector<api::Term>& args)
     Debug("parser") << "applyParseOp: return selector " << ret << std::endl;
     return ret;
   }
+  else if (p.d_kind == api::TUPLE_PROJECT)
+  {
+    api::Term ret = d_solver->mkTerm(p.d_op, args[0]);
+    Debug("parser") << "applyParseOp: return projection " << ret << std::endl;
+    return ret;
+  }
   else if (p.d_kind != api::NULL_EXPR)
   {
     // it should not have an expression or type specified at this point
@@ -1200,6 +1148,12 @@ api::Term Smt2::applyParseOp(ParseOp& p, std::vector<api::Term>& args)
       parseError(
           "eqrange predicate requires option --arrays-exp to be enabled.");
     }
+    if (kind == api::SINGLETON && args.size() == 1)
+    {
+      api::Term ret = d_solver->mkTerm(api::SINGLETON, args[0]);
+      Debug("parser") << "applyParseOp: return singleton " << ret << std::endl;
+      return ret;
+    }
     api::Term ret = d_solver->mkTerm(kind, args);
     Debug("parser") << "applyParseOp: return default builtin " << ret
                     << std::endl;
@@ -1248,28 +1202,16 @@ api::Term Smt2::applyParseOp(ParseOp& p, std::vector<api::Term>& args)
   return ret;
 }
 
-api::Term Smt2::setNamedAttribute(api::Term& expr, const SExpr& sexpr)
+void Smt2::notifyNamedExpression(api::Term& expr, std::string name)
 {
-  if (!sexpr.isKeyword())
-  {
-    parseError("improperly formed :named annotation");
-  }
-  std::string name = sexpr.getValue();
   checkUserSymbol(name);
-  // ensure expr is a closed subterm
-  if (expr.getExpr().hasFreeVariable())
-  {
-    std::stringstream ss;
-    ss << ":named annotations can only name terms that are closed";
-    parseError(ss.str());
-  }
-  // check that sexpr is a fresh function symbol, and reserve it
-  reserveSymbolAtAssertionLevel(name);
-  // define it
-  api::Term func = bindVar(name, expr.getSort(), ExprManager::VAR_FLAG_DEFINED);
-  // remember the last term to have been given a :named attribute
+  // remember the expression name in the symbol manager
+  getSymbolManager()->setExpressionName(expr, name, false);
+  // define the variable
+  defineVar(name, expr);
+  // set the last named term, which ensures that we catch when assertions are
+  // named
   setLastNamedTerm(expr, name);
-  return func;
 }
 
 api::Term Smt2::mkAnd(const std::vector<api::Term>& es)