sygusComp2018: Add evaluator (#2090)
authorAndres Noetzli <andres.noetzli@gmail.com>
Tue, 26 Jun 2018 23:09:03 +0000 (16:09 -0700)
committerGitHub <noreply@github.com>
Tue, 26 Jun 2018 23:09:03 +0000 (16:09 -0700)
This commit adds the Evaluator class that can be used to quickly
evaluate terms under a given substitution without going through the
rewriter. This has been shown to lead to significant performance
improvements on SyGuS PBE problems with a large number of inputs because
candidate solutions are evaluated on those inputs. With this commit, the
evaluator only gets called from the SyGuS solver but there are
potentially other places in the code that could profit from it.

src/Makefile.am
src/options/quantifiers_options.toml
src/theory/evaluator.cpp [new file with mode: 0644]
src/theory/evaluator.h [new file with mode: 0644]
src/theory/quantifiers/sygus/term_database_sygus.cpp
src/theory/quantifiers/sygus/term_database_sygus.h
test/unit/Makefile.am
test/unit/theory/evaluator_white.h [new file with mode: 0644]

index ce9f74d9e988c76969d271f2a525bf9f10cff0cc..b36c453e1fdcb4665e63830a1830c0851af9043d 100644 (file)
@@ -179,6 +179,8 @@ libcvc4_la_SOURCES = \
        theory/atom_requests.cpp \
        theory/atom_requests.h \
        theory/care_graph.h \
+       theory/evaluator.cpp \
+       theory/evaluator.h \
        theory/interrupted.h \
        theory/ite_utilities.cpp \
        theory/ite_utilities.h \
index 69868ad8d1297607ff751c964be79b10d5fb3ed5..be4c66b27398071f1632c258ebe68e41cc7f9d22 100644 (file)
@@ -1118,6 +1118,14 @@ header = "options/quantifiers_options.h"
   includes   = ["options/quantifiers_modes.h"]
   help       = "mode for using samples in the counterexample-guided inductive synthesis loop"
 
+[[option]]
+  name       = "sygusEvalOpt"
+  category   = "regular"
+  long       = "sygus-eval-opt"
+  type       = "bool"
+  default    = "true"
+  help       = "use optimized approach for evaluation in sygus"
+
 # Internal uses of sygus
 
 [[option]]
diff --git a/src/theory/evaluator.cpp b/src/theory/evaluator.cpp
new file mode 100644 (file)
index 0000000..ca2140e
--- /dev/null
@@ -0,0 +1,597 @@
+/*********************                                                        */
+/*! \file evaluator.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andres Noetzli
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The Evaluator class
+ **
+ ** The Evaluator class.
+ **/
+
+#include "theory/evaluator.h"
+
+#include "theory/bv/theory_bv_utils.h"
+#include "theory/theory.h"
+#include "util/integer.h"
+
+namespace CVC4 {
+namespace theory {
+
+EvalResult::EvalResult(const EvalResult& other)
+{
+  d_tag = other.d_tag;
+  switch (d_tag)
+  {
+    case BOOL: d_bool = other.d_bool; break;
+    case BITVECTOR:
+      new (&d_bv) BitVector;
+      d_bv = other.d_bv;
+      break;
+    case RATIONAL:
+      new (&d_rat) Rational;
+      d_rat = other.d_rat;
+      break;
+    case STRING:
+      new (&d_str) String;
+      d_str = other.d_str;
+      break;
+    case INVALID: break;
+  }
+}
+
+EvalResult& EvalResult::operator=(const EvalResult& other)
+{
+  if (this != &other)
+  {
+    d_tag = other.d_tag;
+    switch (d_tag)
+    {
+      case BOOL: d_bool = other.d_bool; break;
+      case BITVECTOR:
+        new (&d_bv) BitVector;
+        d_bv = other.d_bv;
+        break;
+      case RATIONAL:
+        new (&d_rat) Rational;
+        d_rat = other.d_rat;
+        break;
+      case STRING:
+        new (&d_str) String;
+        d_str = other.d_str;
+        break;
+      case INVALID: break;
+    }
+  }
+  return *this;
+}
+
+EvalResult::~EvalResult()
+{
+  switch (d_tag)
+  {
+    case BITVECTOR:
+    {
+      d_bv.~BitVector();
+      break;
+    }
+    case RATIONAL:
+    {
+      d_rat.~Rational();
+      break;
+    }
+    case STRING:
+    {
+      d_str.~String();
+      break;
+
+      default: break;
+    }
+  }
+}
+
+Node EvalResult::toNode() const
+{
+  NodeManager* nm = NodeManager::currentNM();
+  switch (d_tag)
+  {
+    case EvalResult::BOOL: return nm->mkConst(d_bool);
+    case EvalResult::BITVECTOR: return nm->mkConst(d_bv);
+    case EvalResult::RATIONAL: return nm->mkConst(d_rat);
+    case EvalResult::STRING: return nm->mkConst(d_str);
+    default:
+    {
+      Trace("evaluator") << "Missing conversion from " << d_tag << " to node"
+                         << std::endl;
+      return Node();
+    }
+  }
+
+  return Node();
+}
+
+Node Evaluator::eval(TNode n,
+                     const std::vector<Node>& args,
+                     const std::vector<Node>& vals)
+{
+  Trace("evaluator") << "Evaluating " << n << " under substitution " << args
+                     << " " << vals << std::endl;
+  return evalInternal(n, args, vals).toNode();
+}
+
+EvalResult Evaluator::evalInternal(TNode n,
+                                   const std::vector<Node>& args,
+                                   const std::vector<Node>& vals)
+{
+  std::unordered_map<TNode, EvalResult, TNodeHashFunction> results;
+  std::vector<TNode> queue;
+  queue.emplace_back(n);
+
+  while (queue.size() != 0)
+  {
+    TNode currNode = queue.back();
+
+    if (results.find(currNode) != results.end())
+    {
+      queue.pop_back();
+      continue;
+    }
+
+    bool doEval = true;
+    for (const auto& currNodeChild : currNode)
+    {
+      if (results.find(currNodeChild) == results.end())
+      {
+        queue.emplace_back(currNodeChild);
+        doEval = false;
+      }
+    }
+
+    if (doEval)
+    {
+      queue.pop_back();
+
+      Node currNodeVal = currNode;
+      if (currNode.isVar())
+      {
+        const auto& it = std::find(args.begin(), args.end(), currNode);
+
+        if (it == args.end())
+        {
+          return EvalResult();
+        }
+
+        ptrdiff_t pos = std::distance(args.begin(), it);
+        currNodeVal = vals[pos];
+      }
+      else if (currNode.getKind() == kind::APPLY_UF
+               && currNode.getOperator().getKind() == kind::LAMBDA)
+      {
+        // Create a copy of the current substitutions
+        std::vector<Node> lambdaArgs(args);
+        std::vector<Node> lambdaVals(vals);
+
+        // Add the values for the arguments of the lambda as substitutions at
+        // the beginning of the vector to shadow variables from outer scopes
+        // with the same name
+        Node op = currNode.getOperator();
+        for (const auto& lambdaArg : op[0])
+        {
+          lambdaArgs.insert(lambdaArgs.begin(), lambdaArg);
+        }
+
+        for (const auto& lambdaVal : currNode)
+        {
+          lambdaVals.insert(lambdaVals.begin(), results[lambdaVal].toNode());
+        }
+
+        // Lambdas are evaluated in a recursive fashion because each evaluation
+        // requires different substitutions
+        results[currNode] = evalInternal(op[1], lambdaArgs, lambdaVals);
+        continue;
+      }
+
+      switch (currNodeVal.getKind())
+      {
+        case kind::CONST_BOOLEAN:
+          results[currNode] = EvalResult(currNodeVal.getConst<bool>());
+          break;
+
+        case kind::NOT:
+        {
+          results[currNode] = EvalResult(!(results[currNode[0]].d_bool));
+          break;
+        }
+
+        case kind::AND:
+        {
+          bool res = results[currNode[0]].d_bool;
+          for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
+          {
+            res = res && results[currNode[i]].d_bool;
+          }
+          results[currNode] = EvalResult(res);
+          break;
+        }
+
+        case kind::OR:
+        {
+          bool res = results[currNode[0]].d_bool;
+          for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
+          {
+            res = res || results[currNode[i]].d_bool;
+          }
+          results[currNode] = EvalResult(res);
+          break;
+        }
+
+        case kind::CONST_RATIONAL:
+        {
+          const Rational& r = currNodeVal.getConst<Rational>();
+          results[currNode] = EvalResult(r);
+          break;
+        }
+
+        case kind::PLUS:
+        {
+          Rational res = results[currNode[0]].d_rat;
+          for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
+          {
+            res = res + results[currNode[i]].d_rat;
+          }
+          results[currNode] = EvalResult(res);
+          break;
+        }
+
+        case kind::MINUS:
+        {
+          const Rational& x = results[currNode[0]].d_rat;
+          const Rational& y = results[currNode[1]].d_rat;
+          results[currNode] = EvalResult(x - y);
+          break;
+        }
+
+        case kind::MULT:
+        {
+          Rational res = results[currNode[0]].d_rat;
+          for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
+          {
+            res = res * results[currNode[i]].d_rat;
+          }
+          results[currNode] = EvalResult(res);
+          break;
+        }
+
+        case kind::GEQ:
+        {
+          const Rational& x = results[currNode[0]].d_rat;
+          const Rational& y = results[currNode[1]].d_rat;
+          results[currNode] = EvalResult(x >= y);
+          break;
+        }
+
+        case kind::CONST_STRING:
+          results[currNode] = EvalResult(currNodeVal.getConst<String>());
+          break;
+
+        case kind::STRING_CONCAT:
+        {
+          String res = results[currNode[0]].d_str;
+          for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
+          {
+            res = res.concat(results[currNode[i]].d_str);
+          }
+          results[currNode] = EvalResult(res);
+          break;
+        }
+
+        case kind::STRING_LENGTH:
+        {
+          const String& s = results[currNode[0]].d_str;
+          results[currNode] = EvalResult(Rational(s.size()));
+          break;
+        }
+
+        case kind::STRING_SUBSTR:
+        {
+          const String& s = results[currNode[0]].d_str;
+          Integer s_len(s.size());
+          Integer i = results[currNode[1]].d_rat.getNumerator();
+          Integer j = results[currNode[2]].d_rat.getNumerator();
+
+          if (i.strictlyNegative() || j.strictlyNegative() || i >= s_len)
+          {
+            results[currNode] = EvalResult(String(""));
+          }
+          else if (i + j > s_len)
+          {
+            results[currNode] =
+                EvalResult(s.suffix((s_len - i).toUnsignedInt()));
+          }
+          else
+          {
+            results[currNode] =
+                EvalResult(s.substr(i.toUnsignedInt(), j.toUnsignedInt()));
+          }
+          break;
+        }
+
+        case kind::STRING_CHARAT:
+        {
+          const String& s = results[currNode[0]].d_str;
+          Integer s_len(s.size());
+          Integer i = results[currNode[1]].d_rat.getNumerator();
+          if (i.strictlyNegative() || i >= s_len)
+          {
+            results[currNode] = EvalResult(String(""));
+          }
+          else
+          {
+            results[currNode] = EvalResult(s.substr(i.toUnsignedInt(), 1));
+          }
+          break;
+        }
+
+        case kind::STRING_STRCTN:
+        {
+          const String& s = results[currNode[0]].d_str;
+          const String& t = results[currNode[1]].d_str;
+          results[currNode] = EvalResult(s.find(t) != std::string::npos);
+          break;
+        }
+
+        case kind::STRING_STRIDOF:
+        {
+          const String& s = results[currNode[0]].d_str;
+          Integer s_len(s.size());
+          const String& x = results[currNode[1]].d_str;
+          Integer i = results[currNode[2]].d_rat.getNumerator();
+
+          if (i.strictlyNegative() || i >= s_len)
+          {
+            results[currNode] = EvalResult(Rational(-1));
+          }
+          else
+          {
+            size_t r = s.find(x, i.toUnsignedInt());
+            if (r == std::string::npos)
+            {
+              results[currNode] = EvalResult(Rational(-1));
+            }
+            else
+            {
+              results[currNode] = EvalResult(Rational(r));
+            }
+          }
+          break;
+        }
+
+        case kind::STRING_STRREPL:
+        {
+          const String& s = results[currNode[0]].d_str;
+          const String& x = results[currNode[1]].d_str;
+          const String& y = results[currNode[2]].d_str;
+          results[currNode] = EvalResult(s.replace(x, y));
+          break;
+        }
+
+        case kind::STRING_PREFIX:
+        {
+          const String& t = results[currNode[0]].d_str;
+          const String& s = results[currNode[1]].d_str;
+          if (s.size() < t.size())
+          {
+            results[currNode] = EvalResult(false);
+          }
+          else
+          {
+            results[currNode] = EvalResult(s.prefix(t.size()) == t);
+          }
+          break;
+        }
+
+        case kind::STRING_SUFFIX:
+        {
+          const String& t = results[currNode[0]].d_str;
+          const String& s = results[currNode[1]].d_str;
+          if (s.size() < t.size())
+          {
+            results[currNode] = EvalResult(false);
+          }
+          else
+          {
+            results[currNode] = EvalResult(s.suffix(t.size()) == t);
+          }
+          break;
+        }
+
+        case kind::STRING_ITOS:
+        {
+          Integer i = results[currNode[0]].d_rat.getNumerator();
+          if (i.strictlyNegative())
+          {
+            results[currNode] = EvalResult(String(""));
+          }
+          else
+          {
+            results[currNode] = EvalResult(String(i.toString()));
+          }
+          break;
+        }
+
+        case kind::STRING_STOI:
+        {
+          const String& s = results[currNode[0]].d_str;
+          if (s.isNumber())
+          {
+            results[currNode] = EvalResult(Rational(-1));
+          }
+          else
+          {
+            results[currNode] = EvalResult(Rational(s.toNumber()));
+          }
+          break;
+        }
+
+        case kind::CONST_BITVECTOR:
+          results[currNode] = EvalResult(currNodeVal.getConst<BitVector>());
+          break;
+
+        case kind::BITVECTOR_NOT:
+          results[currNode] = EvalResult(~results[currNode[0]].d_bv);
+          break;
+
+        case kind::BITVECTOR_NEG:
+          results[currNode] = EvalResult(-results[currNode[0]].d_bv);
+          break;
+
+        case kind::BITVECTOR_EXTRACT:
+        {
+          unsigned lo = bv::utils::getExtractLow(currNodeVal);
+          unsigned hi = bv::utils::getExtractHigh(currNodeVal);
+          results[currNode] =
+              EvalResult(results[currNode[0]].d_bv.extract(hi, lo));
+          break;
+        }
+
+        case kind::BITVECTOR_CONCAT:
+        {
+          BitVector res = results[currNode[0]].d_bv;
+          for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
+          {
+            res = res.concat(results[currNode[i]].d_bv);
+          }
+          results[currNode] = EvalResult(res);
+          break;
+        }
+
+        case kind::BITVECTOR_PLUS:
+        {
+          BitVector res = results[currNode[0]].d_bv;
+          for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
+          {
+            res = res + results[currNode[i]].d_bv;
+          }
+          results[currNode] = EvalResult(res);
+          break;
+        }
+
+        case kind::BITVECTOR_MULT:
+        {
+          BitVector res = results[currNode[0]].d_bv;
+          for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
+          {
+            res = res * results[currNode[i]].d_bv;
+          }
+          results[currNode] = EvalResult(res);
+          break;
+        }
+        case kind::BITVECTOR_AND:
+        {
+          BitVector res = results[currNode[0]].d_bv;
+          for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
+          {
+            res = res & results[currNode[i]].d_bv;
+          }
+          results[currNode] = EvalResult(res);
+          break;
+        }
+
+        case kind::BITVECTOR_OR:
+        {
+          BitVector res = results[currNode[0]].d_bv;
+          for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
+          {
+            res = res | results[currNode[i]].d_bv;
+          }
+          results[currNode] = EvalResult(res);
+          break;
+        }
+
+        case kind::BITVECTOR_XOR:
+        {
+          BitVector res = results[currNode[0]].d_bv;
+          for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
+          {
+            res = res ^ results[currNode[i]].d_bv;
+          }
+          results[currNode] = EvalResult(res);
+          break;
+        }
+
+        case kind::EQUAL:
+        {
+          EvalResult lhs = results[currNode[0]];
+          EvalResult rhs = results[currNode[1]];
+
+          switch (lhs.d_tag)
+          {
+            case EvalResult::BOOL:
+            {
+              results[currNode] = EvalResult(lhs.d_bool == rhs.d_bool);
+              break;
+            }
+
+            case EvalResult::BITVECTOR:
+            {
+              results[currNode] = EvalResult(lhs.d_bv == rhs.d_bv);
+              break;
+            }
+
+            case EvalResult::RATIONAL:
+            {
+              results[currNode] = EvalResult(lhs.d_rat == rhs.d_rat);
+              break;
+            }
+
+            case EvalResult::STRING:
+            {
+              results[currNode] = EvalResult(lhs.d_str == rhs.d_str);
+              break;
+            }
+
+            default:
+            {
+              Trace("evaluator") << "Theory " << Theory::theoryOf(currNode[0])
+                                 << " not supported" << std::endl;
+              return EvalResult();
+              break;
+            }
+          }
+
+          break;
+        }
+
+        case kind::ITE:
+        {
+          if (results[currNode[0]].d_bool)
+          {
+            results[currNode] = results[currNode[1]];
+          }
+          else
+          {
+            results[currNode] = results[currNode[2]];
+          }
+          break;
+        }
+
+        default:
+        {
+          Trace("evaluator") << "Kind " << currNodeVal.getKind()
+                             << " not supported" << std::endl;
+          return EvalResult();
+        }
+      }
+    }
+  }
+
+  return results[n];
+}
+
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/evaluator.h b/src/theory/evaluator.h
new file mode 100644 (file)
index 0000000..0d7ddbe
--- /dev/null
@@ -0,0 +1,113 @@
+/*********************                                                        */
+/*! \file evaluator.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andres Noetzli
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The Evaluator class
+ **
+ ** The Evaluator class can be used to evaluate terms with constant leaves
+ ** quickly, without going through the rewriter.
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef __CVC4__THEORY__EVALUATOR_H
+#define __CVC4__THEORY__EVALUATOR_H
+
+#include <utility>
+#include <vector>
+
+#include "base/output.h"
+#include "expr/node.h"
+#include "util/bitvector.h"
+#include "util/rational.h"
+#include "util/regexp.h"
+
+namespace CVC4 {
+namespace theory {
+
+/**
+ * Struct that holds the result of an evaluation. The actual value is stored in
+ * a union to avoid the overhead of a class hierarchy with virtual methods.
+ */
+struct EvalResult
+{
+  /* Describes which type of result is being stored */
+  enum
+  {
+    BOOL,
+    BITVECTOR,
+    RATIONAL,
+    STRING,
+    INVALID
+  } d_tag;
+
+  /* Stores the actual result */
+  union
+  {
+    bool d_bool;
+    BitVector d_bv;
+    Rational d_rat;
+    String d_str;
+  };
+
+  EvalResult(const EvalResult& other);
+  EvalResult() : d_tag(INVALID) {}
+  EvalResult(bool b) : d_tag(BOOL), d_bool(b) {}
+  EvalResult(const BitVector& bv) : d_tag(BITVECTOR), d_bv(bv) {}
+  EvalResult(const Rational& i) : d_tag(RATIONAL), d_rat(i) {}
+  EvalResult(const String& str) : d_tag(STRING), d_str(str) {}
+
+  EvalResult& operator=(const EvalResult& other);
+
+  ~EvalResult();
+
+  /**
+   * Converts the result to a Node. If the result is not valid, this function
+   * returns the null node.
+   */
+  Node toNode() const;
+};
+
+/**
+ * The class that performs the actual evaluation of a term under a
+ * substitution. Right now, the class does not cache anything between different
+ * calls to `eval` but this might change in the future.
+ */
+class Evaluator
+{
+ public:
+  /**
+   * Evaluates node `n` under the substitution described by the variable names
+   * `args` and the corresponding values `vals`. The function returns a null
+   * node if there is a subterm that is not constant under the substitution or
+   * if an operator is not supported by the evaluator.
+   */
+  Node eval(TNode n,
+            const std::vector<Node>& args,
+            const std::vector<Node>& vals);
+
+ private:
+  /**
+   * Evaluates node `n` under the substitution described by the variable names
+   * `args` and the corresponding values `vals`. The internal version returns
+   * an EvalResult which has slightly less overhead for recursive calls. The
+   * function returns an invalid result if there is a subterm that is not
+   * constant under the substitution or if an operator is not supported by the
+   * evaluator.
+   */
+  EvalResult evalInternal(TNode n,
+                          const std::vector<Node>& args,
+                          const std::vector<Node>& vals);
+};
+
+}  // namespace theory
+}  // namespace CVC4
+
+#endif /* __CVC4__THEORY__EVALUATOR_H */
index d4349745b7cf7c5ffb8b545fb84943ed433897c4..26f26a14519a3625dc13e6f847b35441ed36e774 100644 (file)
@@ -22,6 +22,8 @@
 #include "theory/quantifiers/term_database.h"
 #include "theory/quantifiers/term_util.h"
 #include "theory/quantifiers_engine.h"
+#include "options/base_options.h"
+#include "printer/printer.h"
 
 using namespace CVC4::kind;
 
@@ -33,6 +35,7 @@ TermDbSygus::TermDbSygus(context::Context* c, QuantifiersEngine* qe)
     : d_quantEngine(qe),
       d_syexp(new SygusExplain(this)),
       d_ext_rw(new ExtendedRewriter(true)),
+      d_eval(new Evaluator),
       d_eval_unfold(new SygusEvalUnfold(this))
 {
   d_true = NodeManager::currentNM()->mkConst( true );
@@ -1548,13 +1551,39 @@ Node TermDbSygus::getEagerUnfold( Node n, std::map< Node, Node >& visited ) {
   }
 }
 
-
-Node TermDbSygus::evaluateBuiltin( TypeNode tn, Node bn, std::vector< Node >& args ) {
+Node TermDbSygus::evaluateBuiltin(TypeNode tn,
+                                  Node bn,
+                                  std::vector<Node>& args,
+                                  bool tryEval)
+{
   if( !args.empty() ){
     std::map< TypeNode, std::vector< Node > >::iterator it = d_var_list.find( tn );
     Assert( it!=d_var_list.end() );
     Assert( it->second.size()==args.size() );
-    return Rewriter::rewrite( bn.substitute( it->second.begin(), it->second.end(), args.begin(), args.end() ) );
+
+    Node res;
+    if (tryEval && options::sygusEvalOpt())
+    {
+      // Try evaluating, which is much faster than substitution+rewriting.
+      // This may fail if there is a subterm of bn under the
+      // substitution that is not constant, or if an operator in bn is not
+      // supported by the evaluator
+      res = d_eval->eval(bn, it->second, args);
+    }
+    if (!res.isNull())
+    {
+      Assert(res
+             == Rewriter::rewrite(bn.substitute(it->second.begin(),
+                                                it->second.end(),
+                                                args.begin(),
+                                                args.end())));
+      return res;
+    }
+    else
+    {
+      return Rewriter::rewrite(bn.substitute(
+          it->second.begin(), it->second.end(), args.begin(), args.end()));
+    }
   }else{
     return Rewriter::rewrite( bn );
   }
index be35d07f3c358780497660b00f2fb4056e4dbcb2..286533570e30b71d327004d676ab17f68ebf16f4 100644 (file)
@@ -19,6 +19,7 @@
 
 #include <unordered_set>
 
+#include "theory/evaluator.h"
 #include "theory/quantifiers/extended_rewrite.h"
 #include "theory/quantifiers/sygus/sygus_eval_unfold.h"
 #include "theory/quantifiers/sygus/sygus_explain.h"
@@ -53,6 +54,8 @@ class TermDbSygus {
   SygusExplain* getExplain() { return d_syexp.get(); }
   /** get the extended rewrite utility */
   ExtendedRewriter* getExtRewriter() { return d_ext_rw.get(); }
+  /** get the evaluator */
+  Evaluator* getEvaluator() { return d_eval.get(); }
   /** evaluation unfolding utility */
   SygusEvalUnfold* getEvalUnfold() { return d_eval_unfold.get(); }
   //------------------------------end utilities
@@ -182,7 +185,8 @@ class TermDbSygus {
    * form of bn [ args / vars(tn) ], where vars(tn) is the sygus variable
    * list for type tn (see Datatype::getSygusVarList).
    */
-  Node evaluateBuiltin(TypeNode tn, Node bn, std::vector<Node>& args);
+  Node evaluateBuiltin(TypeNode tn, Node bn, std::vector<Node>& args,
+bool tryEval = true);
   /** evaluate with unfolding
    *
    * n is any term that may involve sygus evaluation functions. This function
@@ -222,6 +226,8 @@ class TermDbSygus {
   std::unique_ptr<SygusExplain> d_syexp;
   /** extended rewriter */
   std::unique_ptr<ExtendedRewriter> d_ext_rw;
+  /** evaluator */
+  std::unique_ptr<Evaluator> d_eval;
   /** evaluation function unfolding utility */
   std::unique_ptr<SygusEvalUnfold> d_eval_unfold;
   //------------------------------end utilities
index cc9f6fb1b815882131cecb0f2fe2947d9cff1081..0ab3050398dd4a9db318ad7d140a99663d31e667 100644 (file)
@@ -6,6 +6,7 @@ UNIT_TESTS = \
        util/cardinality_public
 if WHITE_AND_BLACK_TESTS
 UNIT_TESTS += \
+       theory/evaluator_white \
        theory/logic_info_white \
        theory/theory_arith_white \
        theory/theory_black \
diff --git a/test/unit/theory/evaluator_white.h b/test/unit/theory/evaluator_white.h
new file mode 100644 (file)
index 0000000..4416ee0
--- /dev/null
@@ -0,0 +1,122 @@
+/*********************                                                        */
+/*! \file evaluator_white.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andres Noetzli
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief [[ Add one-line brief description here ]]
+ **
+ ** [[ Add lengthier description here ]]
+ ** \todo document this file
+ **/
+
+#include <cxxtest/TestSuite.h>
+#include <vector>
+
+#include "expr/node.h"
+#include "expr/node_manager.h"
+#include "smt/smt_engine.h"
+#include "smt/smt_engine_scope.h"
+#include "theory/bv/theory_bv_utils.h"
+#include "theory/evaluator.h"
+#include "theory/rewriter.h"
+#include "theory/theory_test_utils.h"
+
+using namespace CVC4;
+using namespace CVC4::smt;
+using namespace CVC4::theory;
+
+using namespace std;
+
+class TheoryEvaluatorWhite : public CxxTest::TestSuite
+{
+  ExprManager *d_em;
+  NodeManager *d_nm;
+  SmtEngine *d_smt;
+  SmtScope *d_scope;
+
+ public:
+  TheoryEvaluatorWhite() {}
+
+  void setUp()
+  {
+    Options opts;
+    opts.setOutputLanguage(language::output::LANG_SMTLIB_V2);
+    d_em = new ExprManager(opts);
+    d_nm = NodeManager::fromExprManager(d_em);
+    d_smt = new SmtEngine(d_em);
+    d_scope = new SmtScope(d_smt);
+  }
+
+  void tearDown()
+  {
+    delete d_scope;
+    delete d_smt;
+    delete d_em;
+  }
+
+  void testSimple()
+  {
+    TypeNode bv64Type = d_nm->mkBitVectorType(64);
+
+    Node w = d_nm->mkVar("w", bv64Type);
+    Node x = d_nm->mkVar("x", bv64Type);
+    Node y = d_nm->mkVar("y", bv64Type);
+    Node z = d_nm->mkVar("z", bv64Type);
+
+    Node zero = d_nm->mkConst(BitVector(64, (unsigned int)0));
+    Node one = d_nm->mkConst(BitVector(64, (unsigned int)1));
+    Node c1 = d_nm->mkConst(BitVector(
+        64,
+        (unsigned int)0b0000000100000101001110111001101000101110011101011011110011100111));
+    Node c2 = d_nm->mkConst(BitVector(
+        64,
+        (unsigned int)0b0000000100000101001110111001101000101110011101011011110011100111));
+
+    Node t = d_nm->mkNode(kind::ITE, d_nm->mkNode(kind::EQUAL, y, one), x, w);
+
+    std::vector<Node> args = {w, x, y, z};
+    std::vector<Node> vals = {c1, zero, one, c1};
+
+    Evaluator eval;
+    Node r = eval.eval(t, args, vals);
+    TS_ASSERT_EQUALS(r,
+                     Rewriter::rewrite(t.substitute(
+                         args.begin(), args.end(), vals.begin(), vals.end())));
+  }
+
+  void testLoop()
+  {
+    TypeNode bv64Type = d_nm->mkBitVectorType(64);
+
+    Node w = d_nm->mkBoundVar(bv64Type);
+    Node x = d_nm->mkVar("x", bv64Type);
+
+    Node zero = d_nm->mkConst(BitVector(1, (unsigned int)0));
+    Node one = d_nm->mkConst(BitVector(64, (unsigned int)1));
+    Node c = d_nm->mkConst(BitVector(
+        64,
+        (unsigned int)0b0001111000010111110000110110001101011110111001101100000101010100));
+
+    Node largs = d_nm->mkNode(kind::BOUND_VAR_LIST, w);
+    Node lbody = d_nm->mkNode(
+        kind::BITVECTOR_CONCAT, bv::utils::mkExtract(w, 62, 0), zero);
+    Node lambda = d_nm->mkNode(kind::LAMBDA, largs, lbody);
+    Node t = d_nm->mkNode(kind::BITVECTOR_AND,
+                          d_nm->mkNode(kind::APPLY_UF, lambda, one),
+                          d_nm->mkNode(kind::APPLY_UF, lambda, x));
+
+    std::vector<Node> args = {x};
+    std::vector<Node> vals = {c};
+    Evaluator eval;
+    Node r = eval.eval(t, args, vals);
+    TS_ASSERT_EQUALS(r,
+                     Rewriter::rewrite(t.substitute(
+                         args.begin(), args.end(), vals.begin(), vals.end())));
+  }
+};