Move expand definition from Theory to TheoryRewriter (#6408)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 22 Apr 2021 02:42:08 +0000 (21:42 -0500)
committerGitHub <noreply@github.com>
Thu, 22 Apr 2021 02:42:08 +0000 (02:42 +0000)
This is work towards eliminating global calls to getCurrentSmtEngine()->expandDefinition.

The next step will be to add Rewriter::expandDefinition.

34 files changed:
src/smt/expand_definitions.cpp
src/theory/arith/arith_preprocess.cpp
src/theory/arith/arith_preprocess.h
src/theory/arith/arith_rewriter.cpp
src/theory/arith/arith_rewriter.h
src/theory/arith/theory_arith.cpp
src/theory/arith/theory_arith.h
src/theory/arrays/theory_arrays.cpp
src/theory/arrays/theory_arrays.h
src/theory/arrays/theory_arrays_rewriter.cpp
src/theory/arrays/theory_arrays_rewriter.h
src/theory/bags/bags_rewriter.h
src/theory/bags/theory_bags.cpp
src/theory/bags/theory_bags.h
src/theory/bv/theory_bv.cpp
src/theory/bv/theory_bv.h
src/theory/bv/theory_bv_rewriter.cpp
src/theory/bv/theory_bv_rewriter.h
src/theory/datatypes/datatypes_rewriter.cpp
src/theory/datatypes/datatypes_rewriter.h
src/theory/datatypes/theory_datatypes.cpp
src/theory/datatypes/theory_datatypes.h
src/theory/fp/theory_fp.cpp
src/theory/fp/theory_fp.h
src/theory/fp/theory_fp_rewriter.h
src/theory/sets/theory_sets.cpp
src/theory/sets/theory_sets.h
src/theory/strings/sequences_rewriter.cpp
src/theory/strings/sequences_rewriter.h
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h
src/theory/theory.h
src/theory/theory_rewriter.cpp
src/theory/theory_rewriter.h

index c5080db81dc9c4de4bb371f8be47b621471f63e6..d331e8e78252df53ae4f0f4e770e679ab520553d 100644 (file)
@@ -255,9 +255,10 @@ TrustNode ExpandDefs::expandDefinitions(
         // do not do any theory stuff if expandOnly is true
 
         theory::Theory* t = d_smt.getTheoryEngine()->theoryOf(node);
+        theory::TheoryRewriter* tr = t->getTheoryRewriter();
 
         Assert(t != NULL);
-        TrustNode trn = t->expandDefinition(n);
+        TrustNode trn = tr->expandDefinition(n);
         if (!trn.isNull())
         {
           node = trn.getNode();
index a33d802f1418edcf9d8a41029c20bd7c8109a152..d5533de24f12fa949eb5e7ae72314d15152b0a33 100644 (file)
@@ -26,8 +26,8 @@ namespace arith {
 ArithPreprocess::ArithPreprocess(ArithState& state,
                                  InferenceManager& im,
                                  ProofNodeManager* pnm,
-                                 const LogicInfo& info)
-    : d_im(im), d_opElim(pnm, info), d_reduced(state.getUserContext())
+                                 OperatorElim& oe)
+    : d_im(im), d_opElim(oe), d_reduced(state.getUserContext())
 {
 }
 TrustNode ArithPreprocess::eliminate(TNode n,
index ea24e5066def6b8ab657a1c74356f9e50abb5c7d..63b4515e7b80f109e70ff629f9e59772843fa0aa 100644 (file)
@@ -31,6 +31,7 @@ namespace arith {
 
 class ArithState;
 class InferenceManager;
+class OperatorElim;
 
 /**
  * This module can be used for (on demand) elimination of extended arithmetic
@@ -45,7 +46,7 @@ class ArithPreprocess
   ArithPreprocess(ArithState& state,
                   InferenceManager& im,
                   ProofNodeManager* pnm,
-                  const LogicInfo& info);
+                  OperatorElim& oe);
   ~ArithPreprocess() {}
   /**
    * Call eliminate operators on formula n, return the resulting trust node,
@@ -80,7 +81,7 @@ class ArithPreprocess
   /** Reference to the inference manager */
   InferenceManager& d_im;
   /** The operator elimination utility */
-  OperatorElim d_opElim;
+  OperatorElim& d_opElim;
   /** The set of assertions that were reduced */
   context::CDHashMap<Node, bool, NodeHashFunction> d_reduced;
 };
index 83aaaadd8907e3c13884b4879be4d05f67594221..b8135127de4f38379db20ec2b06db0101a440baf 100644 (file)
@@ -26,6 +26,7 @@
 #include "theory/arith/arith_rewriter.h"
 #include "theory/arith/arith_utilities.h"
 #include "theory/arith/normal_form.h"
+#include "theory/arith/operator_elim.h"
 #include "theory/theory.h"
 #include "util/iand.h"
 
@@ -33,6 +34,8 @@ namespace cvc5 {
 namespace theory {
 namespace arith {
 
+ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
+
 bool ArithRewriter::isAtom(TNode n) {
   Kind k = n.getKind();
   return arith::isRelationOperator(k) || k == kind::IS_INTEGER
@@ -893,6 +896,15 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
   return RewriteResponse(REWRITE_DONE, t);
 }
 
+TrustNode ArithRewriter::expandDefinition(Node node)
+{
+  // call eliminate operators, to eliminate partial operators only
+  std::vector<SkolemLemma> lems;
+  TrustNode ret = d_opElim.eliminate(node, lems, true);
+  Assert(lems.empty());
+  return ret;
+}
+
 RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
 {
   Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
index e476fbd620ba6769eaf80c72ddaeb02e2d775c20..6a92ba1ccef6ee06fc46ce570682a980a5fb4d59 100644 (file)
@@ -28,11 +28,19 @@ namespace cvc5 {
 namespace theory {
 namespace arith {
 
+class OperatorElim;
+
 class ArithRewriter : public TheoryRewriter
 {
  public:
+  ArithRewriter(OperatorElim& oe);
   RewriteResponse preRewrite(TNode n) override;
   RewriteResponse postRewrite(TNode n) override;
+  /**
+   * Expand definition, which eliminates extended operators like div/mod in
+   * the given node.
+   */
+  TrustNode expandDefinition(Node node) override;
 
  private:
   static Node makeSubtractionNode(TNode l, TNode r);
@@ -70,6 +78,8 @@ class ArithRewriter : public TheoryRewriter
   }
   /** return rewrite */
   static RewriteResponse returnRewrite(TNode t, Node ret, Rewrite r);
+  /** The operator elimination utility */
+  OperatorElim& d_opElim;
 }; /* class ArithRewriter */
 
 }  // namespace arith
index 181a816c28c4a58a7e9fc4e1672df8f2021a19e8..1843ddb8a590aa22d249df6cfc36059426336e99 100644 (file)
@@ -49,7 +49,9 @@ TheoryArith::TheoryArith(context::Context* c,
       d_astate(*d_internal, c, u, valuation),
       d_im(*this, d_astate, pnm),
       d_nonlinearExtension(nullptr),
-      d_arithPreproc(d_astate, d_im, pnm, logicInfo)
+      d_opElim(pnm, logicInfo),
+      d_arithPreproc(d_astate, d_im, pnm, d_opElim),
+      d_rewriter(d_opElim)
 {
   // indicate we are using the theory state object and inference manager
   d_theoryState = &d_astate;
@@ -103,15 +105,6 @@ void TheoryArith::preRegisterTerm(TNode n)
   d_internal->preRegisterTerm(n);
 }
 
-TrustNode TheoryArith::expandDefinition(Node node)
-{
-  // call eliminate operators, to eliminate partial operators only
-  std::vector<SkolemLemma> lems;
-  TrustNode ret = d_arithPreproc.eliminate(node, lems, true);
-  Assert(lems.empty());
-  return ret;
-}
-
 void TheoryArith::notifySharedTerm(TNode n) { d_internal->notifySharedTerm(n); }
 
 TrustNode TheoryArith::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems)
index 26a33e2477745e441dab13b5ee05f20828c4aff3..43c962e301f6787685e46de30a66ad770986ac63 100644 (file)
@@ -74,11 +74,6 @@ class TheoryArith : public Theory {
   /** finish initialization */
   void finishInit() override;
   //--------------------------------- end initialization
-  /**
-   * Expand definition, which eliminates extended operators like div/mod in
-   * the given node.
-   */
-  TrustNode expandDefinition(Node node) override;
   /**
    * Does non-context dependent setup for a node connected to a theory.
    */
@@ -158,6 +153,8 @@ class TheoryArith : public Theory {
    * arithmetic.
    */
   std::unique_ptr<nl::NonlinearExtension> d_nonlinearExtension;
+  /** The operator elimination utility */
+  OperatorElim d_opElim;
   /** The preprocess utility */
   ArithPreprocess d_arithPreproc;
   /** The theory rewriter for this theory. */
index 1a1090f689b48794cfb5858277c5b540ad82c88e..e887feccbe5d7d57970ba604acfcee8c3b703779 100644 (file)
@@ -300,7 +300,7 @@ Node TheoryArrays::solveWrite(TNode term, bool solve1, bool solve2, bool ppCheck
 TrustNode TheoryArrays::ppRewrite(TNode term, std::vector<SkolemLemma>& lems)
 {
   // first, see if we need to expand definitions
-  TrustNode texp = expandDefinition(term);
+  TrustNode texp = d_rewriter.expandDefinition(term);
   if (!texp.isNull())
   {
     return texp;
@@ -2068,62 +2068,6 @@ std::string TheoryArrays::TheoryArraysDecisionStrategy::identify() const
   return std::string("th_arrays_dec");
 }
 
-TrustNode TheoryArrays::expandDefinition(Node node)
-{
-  NodeManager* nm = NodeManager::currentNM();
-  Kind kind = node.getKind();
-
-  /* Expand
-   *
-   *   (eqrange a b i j)
-   *
-   * to
-   *
-   *  forall k . i <= k <= j => a[k] = b[k]
-   *
-   */
-  if (kind == kind::EQ_RANGE)
-  {
-    TNode a = node[0];
-    TNode b = node[1];
-    TNode i = node[2];
-    TNode j = node[3];
-    Node k = nm->mkBoundVar(i.getType());
-    Node bvl = nm->mkNode(kind::BOUND_VAR_LIST, k);
-    TypeNode type = k.getType();
-
-    Kind kle;
-    Node range;
-    if (type.isBitVector())
-    {
-      kle = kind::BITVECTOR_ULE;
-    }
-    else if (type.isFloatingPoint())
-    {
-      kle = kind::FLOATINGPOINT_LEQ;
-    }
-    else if (type.isInteger() || type.isReal())
-    {
-      kle = kind::LEQ;
-    }
-    else
-    {
-      Unimplemented() << "Type " << type << " is not supported for predicate "
-                      << kind;
-    }
-
-    range = nm->mkNode(kind::AND, nm->mkNode(kle, i, k), nm->mkNode(kle, k, j));
-
-    Node eq = nm->mkNode(kind::EQUAL,
-                         nm->mkNode(kind::SELECT, a, k),
-                         nm->mkNode(kind::SELECT, b, k));
-    Node implies = nm->mkNode(kind::IMPLIES, range, eq);
-    Node ret = nm->mkNode(kind::FORALL, bvl, implies);
-    return TrustNode::mkTrustRewrite(node, ret, nullptr);
-  }
-  return TrustNode::null();
-}
-
 void TheoryArrays::computeRelevantTerms(std::set<Node>& termSet)
 {
   NodeManager* nm = NodeManager::currentNM();
index 7cf8d52e3e2d61ac3d4c2a48bf65ce5d78061fdf..f9813cd3f563de9b082ddbc57a2c7e2529118da3 100644 (file)
@@ -158,8 +158,6 @@ class TheoryArrays : public Theory {
 
   std::string identify() const override { return std::string("TheoryArrays"); }
 
-  TrustNode expandDefinition(Node node) override;
-
   /////////////////////////////////////////////////////////////////////////////
   // PREPROCESSING
   /////////////////////////////////////////////////////////////////////////////
index 323dd004607221e0dfa3681bb88c82de361cee0b..6269cb5ddad903dbd152752dca3241080aa236c6 100644 (file)
@@ -45,6 +45,622 @@ void setMostFrequentValueCount(TNode store, uint64_t count) {
   return store.setAttribute(ArrayConstantMostFrequentValueCountAttr(), count);
 }
 
+Node TheoryArraysRewriter::normalizeConstant(TNode node)
+{
+  return normalizeConstant(node, node[1].getType().getCardinality());
+}
+
+// this function is called by printers when using the option "--model-u-dt-enum"
+Node TheoryArraysRewriter::normalizeConstant(TNode node, Cardinality indexCard)
+{
+  TNode store = node[0];
+  TNode index = node[1];
+  TNode value = node[2];
+
+  std::vector<TNode> indices;
+  std::vector<TNode> elements;
+
+  // Normal form for nested stores is just ordering by index - but also need
+  // to check if we are writing to default value
+
+  // Go through nested stores looking for where to insert index
+  // Also check whether we are replacing an existing store
+  TNode replacedValue;
+  uint32_t depth = 1;
+  uint32_t valCount = 1;
+  while (store.getKind() == kind::STORE)
+  {
+    if (index == store[1])
+    {
+      replacedValue = store[2];
+      store = store[0];
+      break;
+    }
+    else if (index >= store[1])
+    {
+      break;
+    }
+    if (value == store[2])
+    {
+      valCount += 1;
+    }
+    depth += 1;
+    indices.push_back(store[1]);
+    elements.push_back(store[2]);
+    store = store[0];
+  }
+  Node n = store;
+
+  // Get the default value at the bottom of the nested stores
+  while (store.getKind() == kind::STORE)
+  {
+    if (value == store[2])
+    {
+      valCount += 1;
+    }
+    depth += 1;
+    store = store[0];
+  }
+  Assert(store.getKind() == kind::STORE_ALL);
+  ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
+  Node defaultValue = storeAll.getValue();
+  NodeManager* nm = NodeManager::currentNM();
+
+  // Check if we are writing to default value - if so the store
+  // to index can be ignored
+  if (value == defaultValue)
+  {
+    if (replacedValue.isNull())
+    {
+      // Quick exit - if writing to default value and nothing was
+      // replaced, we can just return node[0]
+      return node[0];
+    }
+    // else rebuild the store without the replaced write and then exit
+  }
+  else
+  {
+    n = nm->mkNode(kind::STORE, n, index, value);
+  }
+
+  // Build the rest of the store after inserting/deleting
+  while (!indices.empty())
+  {
+    n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
+    indices.pop_back();
+    elements.pop_back();
+  }
+
+  // Ready to exit if write was to the default value (see previous comment)
+  if (value == defaultValue)
+  {
+    return n;
+  }
+
+  if (indexCard.isInfinite())
+  {
+    return n;
+  }
+
+  // When index sort is finite, we have to check whether there is any value
+  // that is written to more than the default value.  If so, it must become
+  // the new default value
+
+  TNode mostFrequentValue;
+  uint32_t mostFrequentValueCount = 0;
+  store = node[0];
+  if (store.getKind() == kind::STORE)
+  {
+    mostFrequentValue = getMostFrequentValue(store);
+    mostFrequentValueCount = getMostFrequentValueCount(store);
+  }
+
+  // Compute the most frequently written value for n
+  if (valCount > mostFrequentValueCount
+      || (valCount == mostFrequentValueCount && value < mostFrequentValue))
+  {
+    mostFrequentValue = value;
+    mostFrequentValueCount = valCount;
+  }
+
+  // Need to make sure the default value count is larger, or the same and the
+  // default value is expression-order-less-than nextValue
+  Cardinality::CardinalityComparison compare =
+      indexCard.compare(mostFrequentValueCount + depth);
+  Assert(compare != Cardinality::UNKNOWN);
+  if (compare == Cardinality::GREATER
+      || (compare == Cardinality::EQUAL && (defaultValue < mostFrequentValue)))
+  {
+    return n;
+  }
+
+  // Bad case: have to recompute value counts and/or possibly switch out
+  // default value
+  store = n;
+  std::unordered_set<TNode, TNodeHashFunction> indexSet;
+  std::unordered_map<TNode, uint32_t, TNodeHashFunction> elementsMap;
+  std::unordered_map<TNode, uint32_t, TNodeHashFunction>::iterator it;
+  uint32_t count;
+  uint32_t max = 0;
+  TNode maxValue;
+  while (store.getKind() == kind::STORE)
+  {
+    indices.push_back(store[1]);
+    indexSet.insert(store[1]);
+    elements.push_back(store[2]);
+    it = elementsMap.find(store[2]);
+    if (it != elementsMap.end())
+    {
+      (*it).second = (*it).second + 1;
+      count = (*it).second;
+    }
+    else
+    {
+      elementsMap[store[2]] = 1;
+      count = 1;
+    }
+    if (count > max || (count == max && store[2] < maxValue))
+    {
+      max = count;
+      maxValue = store[2];
+    }
+    store = store[0];
+  }
+
+  Assert(depth == indices.size());
+  compare = indexCard.compare(max + depth);
+  Assert(compare != Cardinality::UNKNOWN);
+  if (compare == Cardinality::GREATER
+      || (compare == Cardinality::EQUAL && (defaultValue < maxValue)))
+  {
+    Assert(!replacedValue.isNull() && mostFrequentValue == replacedValue);
+    return n;
+  }
+
+  // Out of luck: have to swap out default value
+
+  // Enumerate values from index type into newIndices and sort
+  std::vector<Node> newIndices;
+  TypeEnumerator te(index.getType());
+  bool needToSort = false;
+  uint32_t numTe = 0;
+  while (!te.isFinished()
+         && (!indexCard.isFinite()
+             || numTe < indexCard.getFiniteCardinality().toUnsignedInt()))
+  {
+    if (indexSet.find(*te) == indexSet.end())
+    {
+      if (!newIndices.empty() && (!(newIndices.back() < (*te))))
+      {
+        needToSort = true;
+      }
+      newIndices.push_back(*te);
+    }
+    ++numTe;
+    ++te;
+  }
+  Assert(indexCard.compare(newIndices.size() + depth) == Cardinality::EQUAL);
+  if (needToSort)
+  {
+    std::sort(newIndices.begin(), newIndices.end());
+  }
+
+  n = nm->mkConst(ArrayStoreAll(node.getType(), maxValue));
+  std::vector<Node>::iterator itNew = newIndices.begin(),
+                              it_end = newIndices.end();
+  while (itNew != it_end || !indices.empty())
+  {
+    if (itNew != it_end && (indices.empty() || (*itNew) < indices.back()))
+    {
+      n = nm->mkNode(kind::STORE, n, (*itNew), defaultValue);
+      ++itNew;
+    }
+    else if (itNew == it_end || indices.back() < (*itNew))
+    {
+      if (elements.back() != maxValue)
+      {
+        n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
+      }
+      indices.pop_back();
+      elements.pop_back();
+    }
+  }
+  return n;
+}
+
+RewriteResponse TheoryArraysRewriter::postRewrite(TNode node)
+{
+  Trace("arrays-postrewrite")
+      << "Arrays::postRewrite start " << node << std::endl;
+  switch (node.getKind())
+  {
+    case kind::SELECT:
+    {
+      TNode store = node[0];
+      TNode index = node[1];
+      Node n;
+      bool val;
+      while (store.getKind() == kind::STORE)
+      {
+        if (index == store[1])
+        {
+          val = true;
+        }
+        else if (index.isConst() && store[1].isConst())
+        {
+          val = false;
+        }
+        else
+        {
+          n = Rewriter::rewrite(mkEqNode(store[1], index));
+          if (n.getKind() != kind::CONST_BOOLEAN)
+          {
+            break;
+          }
+          val = n.getConst<bool>();
+        }
+        if (val)
+        {
+          // select(store(a,i,v),j) = v if i = j
+          Trace("arrays-postrewrite")
+              << "Arrays::postRewrite returning " << store[2] << std::endl;
+          return RewriteResponse(REWRITE_DONE, store[2]);
+        }
+        // select(store(a,i,v),j) = select(a,j) if i /= j
+        store = store[0];
+      }
+      if (store.getKind() == kind::STORE_ALL)
+      {
+        // select(store_all(v),i) = v
+        ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
+        n = storeAll.getValue();
+        Trace("arrays-postrewrite")
+            << "Arrays::postRewrite returning " << n << std::endl;
+        Assert(n.isConst());
+        return RewriteResponse(REWRITE_DONE, n);
+      }
+      else if (store != node[0])
+      {
+        n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index);
+        Trace("arrays-postrewrite")
+            << "Arrays::postRewrite returning " << n << std::endl;
+        return RewriteResponse(REWRITE_DONE, n);
+      }
+      break;
+    }
+    case kind::STORE:
+    {
+      TNode store = node[0];
+      TNode value = node[2];
+      // store(a,i,select(a,i)) = a
+      if (value.getKind() == kind::SELECT && value[0] == store
+          && value[1] == node[1])
+      {
+        Trace("arrays-postrewrite")
+            << "Arrays::postRewrite returning " << store << std::endl;
+        return RewriteResponse(REWRITE_DONE, store);
+      }
+      TNode index = node[1];
+      if (store.isConst() && index.isConst() && value.isConst())
+      {
+        // normalize constant
+        Node n = normalizeConstant(node);
+        Assert(n.isConst());
+        Trace("arrays-postrewrite")
+            << "Arrays::postRewrite returning " << n << std::endl;
+        return RewriteResponse(REWRITE_DONE, n);
+      }
+      if (store.getKind() == kind::STORE)
+      {
+        // store(store(a,i,v),j,w)
+        bool val;
+        if (index == store[1])
+        {
+          val = true;
+        }
+        else if (index.isConst() && store[1].isConst())
+        {
+          val = false;
+        }
+        else
+        {
+          Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index));
+          if (eqRewritten.getKind() != kind::CONST_BOOLEAN)
+          {
+            Trace("arrays-postrewrite")
+                << "Arrays::postRewrite returning " << node << std::endl;
+            return RewriteResponse(REWRITE_DONE, node);
+          }
+          val = eqRewritten.getConst<bool>();
+        }
+        NodeManager* nm = NodeManager::currentNM();
+        if (val)
+        {
+          // store(store(a,i,v),i,w) = store(a,i,w)
+          Node result = nm->mkNode(kind::STORE, store[0], index, value);
+          Trace("arrays-postrewrite")
+              << "Arrays::postRewrite returning " << result << std::endl;
+          return RewriteResponse(REWRITE_AGAIN, result);
+        }
+        else if (index < store[1])
+        {
+          // store(store(a,i,v),j,w) = store(store(a,j,w),i,v)
+          //    IF i != j and j comes before i in the ordering
+          std::vector<TNode> indices;
+          std::vector<TNode> elements;
+          indices.push_back(store[1]);
+          elements.push_back(store[2]);
+          store = store[0];
+          Node n;
+          while (store.getKind() == kind::STORE)
+          {
+            if (index == store[1])
+            {
+              val = true;
+            }
+            else if (index.isConst() && store[1].isConst())
+            {
+              val = false;
+            }
+            else
+            {
+              n = Rewriter::rewrite(mkEqNode(store[1], index));
+              if (n.getKind() != kind::CONST_BOOLEAN)
+              {
+                break;
+              }
+              val = n.getConst<bool>();
+            }
+            if (val)
+            {
+              store = store[0];
+              break;
+            }
+            else if (!(index < store[1]))
+            {
+              break;
+            }
+            indices.push_back(store[1]);
+            elements.push_back(store[2]);
+            store = store[0];
+          }
+          if (value.getKind() == kind::SELECT && value[0] == store
+              && value[1] == index)
+          {
+            n = store;
+          }
+          else
+          {
+            n = nm->mkNode(kind::STORE, store, index, value);
+          }
+          while (!indices.empty())
+          {
+            n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
+            indices.pop_back();
+            elements.pop_back();
+          }
+          Assert(n != node);
+          Trace("arrays-postrewrite")
+              << "Arrays::postRewrite returning " << n << std::endl;
+          return RewriteResponse(REWRITE_AGAIN, n);
+        }
+      }
+      break;
+    }
+    case kind::EQUAL:
+    {
+      if (node[0] == node[1])
+      {
+        Trace("arrays-postrewrite")
+            << "Arrays::postRewrite returning true" << std::endl;
+        return RewriteResponse(REWRITE_DONE,
+                               NodeManager::currentNM()->mkConst(true));
+      }
+      else if (node[0].isConst() && node[1].isConst())
+      {
+        Trace("arrays-postrewrite")
+            << "Arrays::postRewrite returning false" << std::endl;
+        return RewriteResponse(REWRITE_DONE,
+                               NodeManager::currentNM()->mkConst(false));
+      }
+      if (node[0] > node[1])
+      {
+        Node newNode =
+            NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]);
+        Trace("arrays-postrewrite")
+            << "Arrays::postRewrite returning " << newNode << std::endl;
+        return RewriteResponse(REWRITE_DONE, newNode);
+      }
+      break;
+    }
+    default: break;
+  }
+  Trace("arrays-postrewrite")
+      << "Arrays::postRewrite returning " << node << std::endl;
+  return RewriteResponse(REWRITE_DONE, node);
+}
+
+RewriteResponse TheoryArraysRewriter::preRewrite(TNode node)
+{
+  Trace("arrays-prerewrite")
+      << "Arrays::preRewrite start " << node << std::endl;
+  switch (node.getKind())
+  {
+    case kind::SELECT:
+    {
+      TNode store = node[0];
+      TNode index = node[1];
+      Node n;
+      bool val;
+      while (store.getKind() == kind::STORE)
+      {
+        if (index == store[1])
+        {
+          val = true;
+        }
+        else if (index.isConst() && store[1].isConst())
+        {
+          val = false;
+        }
+        else
+        {
+          n = Rewriter::rewrite(mkEqNode(store[1], index));
+          if (n.getKind() != kind::CONST_BOOLEAN)
+          {
+            break;
+          }
+          val = n.getConst<bool>();
+        }
+        if (val)
+        {
+          // select(store(a,i,v),j) = v if i = j
+          Trace("arrays-prerewrite")
+              << "Arrays::preRewrite returning " << store[2] << std::endl;
+          return RewriteResponse(REWRITE_AGAIN, store[2]);
+        }
+        // select(store(a,i,v),j) = select(a,j) if i /= j
+        store = store[0];
+      }
+      if (store.getKind() == kind::STORE_ALL)
+      {
+        // select(store_all(v),i) = v
+        ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
+        n = storeAll.getValue();
+        Trace("arrays-prerewrite")
+            << "Arrays::preRewrite returning " << n << std::endl;
+        Assert(n.isConst());
+        return RewriteResponse(REWRITE_DONE, n);
+      }
+      else if (store != node[0])
+      {
+        n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index);
+        Trace("arrays-prerewrite")
+            << "Arrays::preRewrite returning " << n << std::endl;
+        return RewriteResponse(REWRITE_DONE, n);
+      }
+      break;
+    }
+    case kind::STORE:
+    {
+      TNode store = node[0];
+      TNode value = node[2];
+      // store(a,i,select(a,i)) = a
+      if (value.getKind() == kind::SELECT && value[0] == store
+          && value[1] == node[1])
+      {
+        Trace("arrays-prerewrite")
+            << "Arrays::preRewrite returning " << store << std::endl;
+        return RewriteResponse(REWRITE_AGAIN, store);
+      }
+      if (store.getKind() == kind::STORE)
+      {
+        // store(store(a,i,v),j,w)
+        TNode index = node[1];
+        bool val;
+        if (index == store[1])
+        {
+          val = true;
+        }
+        else if (index.isConst() && store[1].isConst())
+        {
+          val = false;
+        }
+        else
+        {
+          Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index));
+          if (eqRewritten.getKind() != kind::CONST_BOOLEAN)
+          {
+            break;
+          }
+          val = eqRewritten.getConst<bool>();
+        }
+        NodeManager* nm = NodeManager::currentNM();
+        if (val)
+        {
+          // store(store(a,i,v),i,w) = store(a,i,w)
+          Node newNode = nm->mkNode(kind::STORE, store[0], index, value);
+          Trace("arrays-prerewrite")
+              << "Arrays::preRewrite returning " << newNode << std::endl;
+          return RewriteResponse(REWRITE_DONE, newNode);
+        }
+      }
+      break;
+    }
+    case kind::EQUAL:
+    {
+      if (node[0] == node[1])
+      {
+        Trace("arrays-prerewrite")
+            << "Arrays::preRewrite returning true" << std::endl;
+        return RewriteResponse(REWRITE_DONE,
+                               NodeManager::currentNM()->mkConst(true));
+      }
+      break;
+    }
+    default: break;
+  }
+
+  Trace("arrays-prerewrite")
+      << "Arrays::preRewrite returning " << node << std::endl;
+  return RewriteResponse(REWRITE_DONE, node);
+}
+
+TrustNode TheoryArraysRewriter::expandDefinition(Node node)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  Kind kind = node.getKind();
+
+  /* Expand
+   *
+   *   (eqrange a b i j)
+   *
+   * to
+   *
+   *  forall k . i <= k <= j => a[k] = b[k]
+   *
+   */
+  if (kind == kind::EQ_RANGE)
+  {
+    TNode a = node[0];
+    TNode b = node[1];
+    TNode i = node[2];
+    TNode j = node[3];
+    Node k = nm->mkBoundVar(i.getType());
+    Node bvl = nm->mkNode(kind::BOUND_VAR_LIST, k);
+    TypeNode type = k.getType();
+
+    Kind kle;
+    Node range;
+    if (type.isBitVector())
+    {
+      kle = kind::BITVECTOR_ULE;
+    }
+    else if (type.isFloatingPoint())
+    {
+      kle = kind::FLOATINGPOINT_LEQ;
+    }
+    else if (type.isInteger() || type.isReal())
+    {
+      kle = kind::LEQ;
+    }
+    else
+    {
+      Unimplemented() << "Type " << type << " is not supported for predicate "
+                      << kind;
+    }
+
+    range = nm->mkNode(kind::AND, nm->mkNode(kle, i, k), nm->mkNode(kle, k, j));
+
+    Node eq = nm->mkNode(kind::EQUAL,
+                         nm->mkNode(kind::SELECT, a, k),
+                         nm->mkNode(kind::SELECT, b, k));
+    Node implies = nm->mkNode(kind::IMPLIES, range, eq);
+    Node ret = nm->mkNode(kind::FORALL, bvl, implies);
+    return TrustNode::mkTrustRewrite(node, ret, nullptr);
+  }
+  return TrustNode::null();
+}
+
 }  // namespace arrays
 }  // namespace theory
 }  // namespace cvc5
index 0bbfc08467c9f9c9385cca4658c25407ec7620a8..498266ce3376310c093e13c12b11b981679f19b2 100644 (file)
@@ -43,459 +43,21 @@ static inline Node mkEqNode(Node a, Node b) {
 
 class TheoryArraysRewriter : public TheoryRewriter
 {
-  static Node normalizeConstant(TNode node) {
-    return normalizeConstant(node, node[1].getType().getCardinality());
-  }
+  /**
+   * Puts array constant node into normal form. This is so that array constants
+   * that are distinct nodes are semantically disequal.
+   */
+  static Node normalizeConstant(TNode node);
 
  public:
-  //this function is called by printers when using the option "--model-u-dt-enum"
-  static Node normalizeConstant(TNode node, Cardinality indexCard) {
-    TNode store = node[0];
-    TNode index = node[1];
-    TNode value = node[2];
+  /** Normalize a constant whose index type has cardinality indexCard */
+  static Node normalizeConstant(TNode node, Cardinality indexCard);
 
-    std::vector<TNode> indices;
-    std::vector<TNode> elements;
+  RewriteResponse postRewrite(TNode node) override;
 
-    // Normal form for nested stores is just ordering by index - but also need
-    // to check if we are writing to default value
+  RewriteResponse preRewrite(TNode node) override;
 
-    // Go through nested stores looking for where to insert index
-    // Also check whether we are replacing an existing store
-    TNode replacedValue;
-    unsigned depth = 1;
-    unsigned valCount = 1;
-    while (store.getKind() == kind::STORE) {
-      if (index == store[1]) {
-        replacedValue = store[2];
-        store = store[0];
-        break;
-      }
-      else if (!(index < store[1])) {
-        break;
-      }
-      if (value == store[2]) {
-        valCount += 1;
-      }
-      depth += 1;
-      indices.push_back(store[1]);
-      elements.push_back(store[2]);
-      store = store[0];
-    }
-    Node n = store;
-
-    // Get the default value at the bottom of the nested stores
-    while (store.getKind() == kind::STORE) {
-      if (value == store[2]) {
-        valCount += 1;
-      }
-      depth += 1;
-      store = store[0];
-    }
-    Assert(store.getKind() == kind::STORE_ALL);
-    ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
-    Node defaultValue = storeAll.getValue();
-    NodeManager* nm = NodeManager::currentNM();
-
-    // Check if we are writing to default value - if so the store
-    // to index can be ignored
-    if (value == defaultValue) {
-      if (replacedValue.isNull()) {
-        // Quick exit - if writing to default value and nothing was
-        // replaced, we can just return node[0]
-        return node[0];
-      }
-      // else rebuild the store without the replaced write and then exit
-    }
-    else {
-      n = nm->mkNode(kind::STORE, n, index, value);
-    }
-
-    // Build the rest of the store after inserting/deleting
-    while (!indices.empty()) {
-      n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
-      indices.pop_back();
-      elements.pop_back();
-    }
-
-    // Ready to exit if write was to the default value (see previous comment)
-    if (value == defaultValue) {
-      return n;
-    }
-
-    if (indexCard.isInfinite()) {
-      return n;
-    }
-
-    // When index sort is finite, we have to check whether there is any value
-    // that is written to more than the default value.  If so, it must become
-    // the new default value
-
-    TNode mostFrequentValue;
-    unsigned mostFrequentValueCount = 0;
-    store = node[0];
-    if (store.getKind() == kind::STORE) {
-      mostFrequentValue = getMostFrequentValue(store);
-      mostFrequentValueCount = getMostFrequentValueCount(store);
-    }
-
-    // Compute the most frequently written value for n
-    if (valCount > mostFrequentValueCount ||
-        (valCount == mostFrequentValueCount && value < mostFrequentValue)) {
-      mostFrequentValue = value;
-      mostFrequentValueCount = valCount;
-    }
-
-    // Need to make sure the default value count is larger, or the same and the default value is expression-order-less-than nextValue
-    Cardinality::CardinalityComparison compare = indexCard.compare(mostFrequentValueCount + depth);
-    Assert(compare != Cardinality::UNKNOWN);
-    if (compare == Cardinality::GREATER ||
-        (compare == Cardinality::EQUAL && (defaultValue < mostFrequentValue))) {
-      return n;
-    }
-
-    // Bad case: have to recompute value counts and/or possibly switch out
-    // default value
-    store = n;
-    std::unordered_set<TNode, TNodeHashFunction> indexSet;
-    std::unordered_map<TNode, unsigned, TNodeHashFunction> elementsMap;
-    std::unordered_map<TNode, unsigned, TNodeHashFunction>::iterator it;
-    unsigned count;
-    unsigned max = 0;
-    TNode maxValue;
-    while (store.getKind() == kind::STORE) {
-      indices.push_back(store[1]);
-      indexSet.insert(store[1]);
-      elements.push_back(store[2]);
-      it = elementsMap.find(store[2]);
-      if (it != elementsMap.end()) {
-        (*it).second = (*it).second + 1;
-        count = (*it).second;
-      }
-      else {
-        elementsMap[store[2]] = 1;
-        count = 1;
-      }
-      if (count > max ||
-          (count == max && store[2] < maxValue)) {
-        max = count;
-        maxValue = store[2];
-      }
-      store = store[0];
-    }
-
-    Assert(depth == indices.size());
-    compare = indexCard.compare(max + depth);
-    Assert(compare != Cardinality::UNKNOWN);
-    if (compare == Cardinality::GREATER ||
-        (compare == Cardinality::EQUAL && (defaultValue < maxValue))) {
-      Assert(!replacedValue.isNull() && mostFrequentValue == replacedValue);
-      return n;
-    }
-
-    // Out of luck: have to swap out default value
-
-    // Enumerate values from index type into newIndices and sort
-    std::vector<Node> newIndices;
-    TypeEnumerator te(index.getType());
-    bool needToSort = false;
-    unsigned numTe = 0;
-    while (!te.isFinished() && (!indexCard.isFinite() || numTe<indexCard.getFiniteCardinality().toUnsignedInt())) {
-      if (indexSet.find(*te) == indexSet.end()) {
-        if (!newIndices.empty() && (!(newIndices.back() < (*te)))) {
-          needToSort = true;
-        }
-        newIndices.push_back(*te);
-      }
-      ++numTe;
-      ++te;
-    }
-    Assert(indexCard.compare(newIndices.size() + depth) == Cardinality::EQUAL);
-    if (needToSort) {
-      std::sort(newIndices.begin(), newIndices.end());
-    }
-
-    n = nm->mkConst(ArrayStoreAll(node.getType(), maxValue));
-    std::vector<Node>::iterator itNew = newIndices.begin(), it_end = newIndices.end();
-    while (itNew != it_end || !indices.empty()) {
-      if (itNew != it_end && (indices.empty() || (*itNew) < indices.back())) {
-        n = nm->mkNode(kind::STORE, n, (*itNew), defaultValue);
-        ++itNew;
-      }
-      else if (itNew == it_end || indices.back() < (*itNew)) {
-        if (elements.back() != maxValue) {
-          n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
-        }
-        indices.pop_back();
-        elements.pop_back();
-      }
-    }
-    return n;
-  }
-
- public:
-  RewriteResponse postRewrite(TNode node) override
-  {
-    Trace("arrays-postrewrite") << "Arrays::postRewrite start " << node << std::endl;
-    switch (node.getKind()) {
-      case kind::SELECT: {
-        TNode store = node[0];
-        TNode index = node[1];
-        Node n;
-        bool val;
-        while (store.getKind() == kind::STORE) {
-          if (index == store[1]) {
-            val = true;
-          }
-          else if (index.isConst() && store[1].isConst()) {
-            val = false;
-          }
-          else {
-            n = Rewriter::rewrite(mkEqNode(store[1], index));
-            if (n.getKind() != kind::CONST_BOOLEAN) {
-              break;
-            }
-            val = n.getConst<bool>();
-          }
-          if (val) {
-            // select(store(a,i,v),j) = v if i = j
-            Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << store[2] << std::endl;
-            return RewriteResponse(REWRITE_DONE, store[2]);
-          }
-          // select(store(a,i,v),j) = select(a,j) if i /= j
-          store = store[0];
-        }
-        if (store.getKind() == kind::STORE_ALL) {
-          // select(store_all(v),i) = v
-          ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
-          n = storeAll.getValue();
-          Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl;
-          Assert(n.isConst());
-          return RewriteResponse(REWRITE_DONE, n);
-        }
-        else if (store != node[0]) {
-          n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index);
-          Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl;
-          return RewriteResponse(REWRITE_DONE, n);
-        }
-        break;
-      }
-      case kind::STORE: {
-        TNode store = node[0];
-        TNode value = node[2];
-        // store(a,i,select(a,i)) = a
-        if (value.getKind() == kind::SELECT &&
-            value[0] == store &&
-            value[1] == node[1]) {
-          Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << store << std::endl;
-          return RewriteResponse(REWRITE_DONE, store);
-        }
-        TNode index = node[1];
-        if (store.isConst() && index.isConst() && value.isConst()) {
-          // normalize constant
-          Node n = normalizeConstant(node);
-          Assert(n.isConst());
-          Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl;
-          return RewriteResponse(REWRITE_DONE, n);
-        }
-        if (store.getKind() == kind::STORE) {
-          // store(store(a,i,v),j,w)
-          bool val;
-          if (index == store[1]) {
-            val = true;
-          }
-          else if (index.isConst() && store[1].isConst()) {
-            val = false;
-          }
-          else {
-            Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index));
-            if (eqRewritten.getKind() != kind::CONST_BOOLEAN) {
-              Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << node << std::endl;
-              return RewriteResponse(REWRITE_DONE, node);
-            }
-            val = eqRewritten.getConst<bool>();
-          }
-          NodeManager* nm = NodeManager::currentNM();
-          if (val) {
-            // store(store(a,i,v),i,w) = store(a,i,w)
-            Node result = nm->mkNode(kind::STORE, store[0], index, value);
-            Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << result << std::endl;
-            return RewriteResponse(REWRITE_AGAIN, result);
-          }
-          else if (index < store[1]) {
-            // store(store(a,i,v),j,w) = store(store(a,j,w),i,v)
-            //    IF i != j and j comes before i in the ordering
-            std::vector<TNode> indices;
-            std::vector<TNode> elements;
-            indices.push_back(store[1]);
-            elements.push_back(store[2]);
-            store = store[0];
-            Node n;
-            while (store.getKind() == kind::STORE) {
-              if (index == store[1]) {
-                val = true;
-              }
-              else if (index.isConst() && store[1].isConst()) {
-                val = false;
-              }
-              else {
-                n = Rewriter::rewrite(mkEqNode(store[1], index));
-                if (n.getKind() != kind::CONST_BOOLEAN) {
-                  break;
-                }
-                val = n.getConst<bool>();
-              }
-              if (val) {
-                store = store[0];
-                break;
-              }
-              else if (!(index < store[1])) {
-                break;
-              }
-              indices.push_back(store[1]);
-              elements.push_back(store[2]);
-              store = store[0];
-            }
-            if (value.getKind() == kind::SELECT &&
-                value[0] == store &&
-                value[1] == index) {
-              n = store;
-            }
-            else {
-              n = nm->mkNode(kind::STORE, store, index, value);
-            }
-            while (!indices.empty()) {
-              n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
-              indices.pop_back();
-              elements.pop_back();
-            }
-            Assert(n != node);
-            Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl;
-            return RewriteResponse(REWRITE_AGAIN, n);
-          }
-        }
-        break;
-      }
-      case kind::EQUAL:{
-        if(node[0] == node[1]) {
-          Trace("arrays-postrewrite") << "Arrays::postRewrite returning true" << std::endl;
-          return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
-        }
-        else if (node[0].isConst() && node[1].isConst()) {
-          Trace("arrays-postrewrite") << "Arrays::postRewrite returning false" << std::endl;
-          return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(false));
-        }
-        if (node[0] > node[1]) {
-          Node newNode = NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]);
-          Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << newNode << std::endl;
-          return RewriteResponse(REWRITE_DONE, newNode);
-        }
-        break;
-      }
-      default:
-        break;
-    }
-    Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << node << std::endl;
-    return RewriteResponse(REWRITE_DONE, node);
-  }
-
-  RewriteResponse preRewrite(TNode node) override
-  {
-    Trace("arrays-prerewrite") << "Arrays::preRewrite start " << node << std::endl;
-    switch (node.getKind()) {
-      case kind::SELECT: {
-        TNode store = node[0];
-        TNode index = node[1];
-        Node n;
-        bool val;
-        while (store.getKind() == kind::STORE) {
-          if (index == store[1]) {
-            val = true;
-          }
-          else if (index.isConst() && store[1].isConst()) {
-            val = false;
-          }
-          else {
-            n = Rewriter::rewrite(mkEqNode(store[1], index));
-            if (n.getKind() != kind::CONST_BOOLEAN) {
-              break;
-            }
-            val = n.getConst<bool>();
-          }
-          if (val) {
-            // select(store(a,i,v),j) = v if i = j
-            Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << store[2] << std::endl;
-            return RewriteResponse(REWRITE_AGAIN, store[2]);
-          }
-          // select(store(a,i,v),j) = select(a,j) if i /= j
-          store = store[0];
-        }
-        if (store.getKind() == kind::STORE_ALL) {
-          // select(store_all(v),i) = v
-          ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
-          n = storeAll.getValue();
-          Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << n << std::endl;
-          Assert(n.isConst());
-          return RewriteResponse(REWRITE_DONE, n);
-        }
-        else if (store != node[0]) {
-          n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index);
-          Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << n << std::endl;
-          return RewriteResponse(REWRITE_DONE, n);
-        }
-        break;
-      }
-      case kind::STORE: {
-        TNode store = node[0];
-        TNode value = node[2];
-        // store(a,i,select(a,i)) = a
-        if (value.getKind() == kind::SELECT &&
-            value[0] == store &&
-            value[1] == node[1]) {
-          Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << store << std::endl;
-          return RewriteResponse(REWRITE_AGAIN, store);
-        }
-        if (store.getKind() == kind::STORE) {
-          // store(store(a,i,v),j,w)
-          TNode index = node[1];
-          bool val;
-          if (index == store[1]) {
-            val = true;
-          }
-          else if (index.isConst() && store[1].isConst()) {
-            val = false;
-          }
-          else {
-            Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index));
-            if (eqRewritten.getKind() != kind::CONST_BOOLEAN) {
-              break;
-            }
-            val = eqRewritten.getConst<bool>();
-          }
-          NodeManager* nm = NodeManager::currentNM();
-          if (val) {
-            // store(store(a,i,v),i,w) = store(a,i,w)
-            Node newNode = nm->mkNode(kind::STORE, store[0], index, value);
-            Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << newNode << std::endl;
-            return RewriteResponse(REWRITE_DONE, newNode);
-          }
-        }
-        break;
-      }
-      case kind::EQUAL:{
-        if(node[0] == node[1]) {
-          Trace("arrays-prerewrite") << "Arrays::preRewrite returning true" << std::endl;
-          return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
-        }
-        break;
-      }
-      default:
-        break;
-    }
-
-    Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << node << std::endl;
-    return RewriteResponse(REWRITE_DONE, node);
-  }
+  TrustNode expandDefinition(Node node) override;
 
   static inline void init() {}
   static inline void shutdown() {}
index 54c1d7253e1eb9bf22c9eb1085b395b8a88dbbf6..83f364f9d3dd3eaed2c534ee4ccc04c503f4f471 100644 (file)
@@ -56,7 +56,6 @@ class BagsRewriter : public TheoryRewriter
    * See the rewrite rules for these kinds below.
    */
   RewriteResponse preRewrite(TNode n) override;
-
  private:
   /**
    * rewrites for n include:
index 89312f99aa71d117679d75848e5b8dc4239a74f9..d5917f91c2c063e4b4416feb401992a2f89b68d8 100644 (file)
@@ -226,12 +226,6 @@ void TheoryBags::preRegisterTerm(TNode n)
   }
 }
 
-TrustNode TheoryBags::expandDefinition(Node n)
-{
-  // TODO(projects#224): add choose and is_singleton here
-  return TrustNode::null();
-}
-
 void TheoryBags::presolve() {}
 
 /**************************** eq::NotifyClass *****************************/
index 7b9299f54b2802109b8f4c8dd513fe6c53176c85..4ed131e64ec19bd909d90d6501b7c65fc9f93e3e 100644 (file)
@@ -72,7 +72,6 @@ class TheoryBags : public Theory
   Node getModelValue(TNode) override;
   std::string identify() const override { return "THEORY_BAGS"; }
   void preRegisterTerm(TNode n) override;
-  TrustNode expandDefinition(Node n) override;
   void presolve() override;
 
  private:
index 06f837c7f0147f818b0687e19b27f9cc943dc7ad..a0f3f28f776c927110ccd8c9095063931acc14d1 100644 (file)
@@ -134,29 +134,6 @@ void TheoryBV::finishInit()
   }
 }
 
-TrustNode TheoryBV::expandDefinition(Node node)
-{
-  Debug("bitvector-expandDefinition")
-      << "TheoryBV::expandDefinition(" << node << ")" << std::endl;
-
-  Node ret;
-  switch (node.getKind())
-  {
-    case kind::BITVECTOR_SDIV:
-    case kind::BITVECTOR_SREM:
-    case kind::BITVECTOR_SMOD:
-      ret = TheoryBVRewriter::eliminateBVSDiv(node);
-      break;
-
-    default: break;
-  }
-  if (!ret.isNull() && node != ret)
-  {
-    return TrustNode::mkTrustRewrite(node, ret, nullptr);
-  }
-  return TrustNode::null();
-}
-
 void TheoryBV::preRegisterTerm(TNode node)
 {
   d_internal->preRegisterTerm(node);
@@ -211,7 +188,7 @@ Theory::PPAssertStatus TheoryBV::ppAssert(
 TrustNode TheoryBV::ppRewrite(TNode t, std::vector<SkolemLemma>& lems)
 {
   // first, see if we need to expand definitions
-  TrustNode texp = expandDefinition(t);
+  TrustNode texp = d_rewriter.expandDefinition(t);
   if (!texp.isNull())
   {
     return texp;
index 0546e83ac8a008c05ca67e2c823b67b779e5afdf..1f14f05b00d0fdb85c91156ad0cb55090e5646a6 100644 (file)
@@ -63,8 +63,6 @@ class TheoryBV : public Theory
 
   void finishInit() override;
 
-  TrustNode expandDefinition(Node node) override;
-
   void preRegisterTerm(TNode n) override;
 
   bool preCheck(Effort e) override;
index 076ea67a9872c778f026fca643839458b262634f..9b3fde6fb2a6891a71b5fdd1f2c95238cbe92af6 100644 (file)
@@ -54,6 +54,26 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) {
   return res; 
 }
 
+TrustNode TheoryBVRewriter::expandDefinition(Node node)
+{
+  Debug("bitvector-expandDefinition")
+      << "TheoryBV::expandDefinition(" << node << ")" << std::endl;
+  Node ret;
+  switch (node.getKind())
+  {
+    case kind::BITVECTOR_SDIV:
+    case kind::BITVECTOR_SREM:
+    case kind::BITVECTOR_SMOD: ret = eliminateBVSDiv(node); break;
+
+    default: break;
+  }
+  if (!ret.isNull() && node != ret)
+  {
+    return TrustNode::mkTrustRewrite(node, ret, nullptr);
+  }
+  return TrustNode::null();
+}
+
 RewriteResponse TheoryBVRewriter::RewriteBitOf(TNode node, bool prerewrite)
 {
   Node resultNode = LinearRewriteStrategy<RewriteRule<BitOfConst>>::apply(node);
index dca6bc9b4055edc9a4e93c3ab28e323c7d8a2e31..e351840846ca17f9c6008b3e455dbf53ee252929 100644 (file)
@@ -47,6 +47,8 @@ class TheoryBVRewriter : public TheoryRewriter
   RewriteResponse postRewrite(TNode node) override;
   RewriteResponse preRewrite(TNode node) override;
 
+  TrustNode expandDefinition(Node node) override;
+
  private:
   static RewriteResponse IdentityRewrite(TNode node, bool prerewrite = false);
   static RewriteResponse UndefinedRewrite(TNode node, bool prerewrite = false);
index b85ac44bf2edffb808f79685e73c9567559f4e63..a6d3a45bc5dee3a8e41b9049f429af03becb6bd0 100644 (file)
@@ -18,6 +18,7 @@
 #include "expr/dtype.h"
 #include "expr/dtype_cons.h"
 #include "expr/node_algorithm.h"
+#include "expr/skolem_manager.h"
 #include "expr/sygus_datatype.h"
 #include "options/datatypes_options.h"
 #include "theory/datatypes/sygus_datatype_utils.h"
@@ -793,6 +794,111 @@ Node DatatypesRewriter::replaceDebruijn(Node n,
   return n;
 }
 
+TrustNode DatatypesRewriter::expandDefinition(Node n)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  TypeNode tn = n.getType();
+  Node ret;
+  switch (n.getKind())
+  {
+    case kind::APPLY_SELECTOR:
+    {
+      Node selector = n.getOperator();
+      // APPLY_SELECTOR always applies to an external selector, cindexOf is
+      // legal here
+      size_t cindex = utils::cindexOf(selector);
+      const DType& dt = utils::datatypeOf(selector);
+      const DTypeConstructor& c = dt[cindex];
+      Node selector_use;
+      TypeNode ndt = n[0].getType();
+      if (options::dtSharedSelectors())
+      {
+        size_t selectorIndex = utils::indexOf(selector);
+        Trace("dt-expand") << "...selector index = " << selectorIndex
+                           << std::endl;
+        Assert(selectorIndex < c.getNumArgs());
+        selector_use = c.getSelectorInternal(ndt, selectorIndex);
+      }
+      else
+      {
+        selector_use = selector;
+      }
+      Node sel = nm->mkNode(kind::APPLY_SELECTOR_TOTAL, selector_use, n[0]);
+      if (options::dtRewriteErrorSel())
+      {
+        ret = sel;
+      }
+      else
+      {
+        Node tester = c.getTester();
+        Node tst = nm->mkNode(APPLY_TESTER, tester, n[0]);
+        SkolemManager* sm = nm->getSkolemManager();
+        TypeNode tnw = nm->mkFunctionType(ndt, n.getType());
+        Node f =
+            sm->mkSkolemFunction(SkolemFunId::SELECTOR_WRONG, tnw, selector);
+        Node sk = nm->mkNode(kind::APPLY_UF, f, n[0]);
+        ret = nm->mkNode(kind::ITE, tst, sel, sk);
+        Trace("dt-expand") << "Expand def : " << n << " to " << ret
+                           << std::endl;
+      }
+    }
+    break;
+    case TUPLE_UPDATE:
+    case RECORD_UPDATE:
+    {
+      Assert(tn.isDatatype());
+      const DType& dt = tn.getDType();
+      NodeBuilder b(APPLY_CONSTRUCTOR);
+      b << dt[0].getConstructor();
+      size_t size, updateIndex;
+      if (n.getKind() == TUPLE_UPDATE)
+      {
+        Assert(tn.isTuple());
+        size = tn.getTupleLength();
+        updateIndex = n.getOperator().getConst<TupleUpdate>().getIndex();
+      }
+      else
+      {
+        Assert(tn.isRecord());
+        const DTypeConstructor& recCons = dt[0];
+        size = recCons.getNumArgs();
+        // get the index for the name
+        updateIndex = recCons.getSelectorIndexForName(
+            n.getOperator().getConst<RecordUpdate>().getField());
+      }
+      Debug("tuprec") << "expr is " << n << std::endl;
+      Debug("tuprec") << "updateIndex is " << updateIndex << std::endl;
+      Debug("tuprec") << "t is " << tn << std::endl;
+      Debug("tuprec") << "t has arity " << size << std::endl;
+      for (size_t i = 0; i < size; ++i)
+      {
+        if (i == updateIndex)
+        {
+          b << n[1];
+          Debug("tuprec") << "arg " << i << " gets updated to " << n[1]
+                          << std::endl;
+        }
+        else
+        {
+          b << nm->mkNode(
+              APPLY_SELECTOR_TOTAL, dt[0].getSelectorInternal(tn, i), n[0]);
+          Debug("tuprec") << "arg " << i << " copies "
+                          << b[b.getNumChildren() - 1] << std::endl;
+        }
+      }
+      ret = b;
+      Debug("tuprec") << "return " << ret << std::endl;
+    }
+    break;
+    default: break;
+  }
+  if (!ret.isNull() && n != ret)
+  {
+    return TrustNode::mkTrustRewrite(n, ret, nullptr);
+  }
+  return TrustNode::null();
+}
+
 }  // namespace datatypes
 }  // namespace theory
 }  // namespace cvc5
index 3b9b14fb7b44fc7a6430c2b0629d54ee20c6c5e9..c9a40ff7bf46fd7fad203e42557bc03d98882467 100644 (file)
@@ -48,6 +48,8 @@ class DatatypesRewriter : public TheoryRewriter
    * on all top-level codatatype subterms of n.
    */
   static Node normalizeConstant(Node n);
+  /** expand defintions */
+  TrustNode expandDefinition(Node n) override;
 
  private:
   /** rewrite constructor term in */
index 01ef77172fc0645b14c170b68cab08d58442151f..f9d08dfc2f38b1650df7eb1644f4a690c017e00f 100644 (file)
@@ -482,127 +482,11 @@ void TheoryDatatypes::preRegisterTerm(TNode n)
   d_im.process();
 }
 
-TrustNode TheoryDatatypes::expandDefinition(Node n)
-{
-  NodeManager* nm = NodeManager::currentNM();
-  TypeNode tn = n.getType();
-  Node ret;
-  switch (n.getKind())
-  {
-    case kind::APPLY_SELECTOR:
-    {
-      Node selector = n.getOperator();
-      // APPLY_SELECTOR always applies to an external selector, cindexOf is
-      // legal here
-      size_t cindex = utils::cindexOf(selector);
-      const DType& dt = utils::datatypeOf(selector);
-      const DTypeConstructor& c = dt[cindex];
-      Node selector_use;
-      TypeNode ndt = n[0].getType();
-      if (options::dtSharedSelectors())
-      {
-        size_t selectorIndex = utils::indexOf(selector);
-        Trace("dt-expand") << "...selector index = " << selectorIndex
-                           << std::endl;
-        Assert(selectorIndex < c.getNumArgs());
-        selector_use = c.getSelectorInternal(ndt, selectorIndex);
-      }else{
-        selector_use = selector;
-      }
-      Node sel = nm->mkNode(kind::APPLY_SELECTOR_TOTAL, selector_use, n[0]);
-      if (options::dtRewriteErrorSel())
-      {
-        ret = sel;
-      }
-      else
-      {
-        Node tester = c.getTester();
-        Node tst = nm->mkNode(APPLY_TESTER, tester, n[0]);
-        tst = Rewriter::rewrite(tst);
-        if (tst == d_true)
-        {
-          ret = sel;
-        }else{
-          SkolemManager* sm = nm->getSkolemManager();
-          TypeNode tnw = nm->mkFunctionType(ndt, n.getType());
-          Node f =
-              sm->mkSkolemFunction(SkolemFunId::SELECTOR_WRONG, tnw, selector);
-          Node sk = nm->mkNode(kind::APPLY_UF, f, n[0]);
-          if (tst == nm->mkConst(false))
-          {
-            ret = sk;
-          }
-          else
-          {
-            ret = nm->mkNode(kind::ITE, tst, sel, sk);
-          }
-        }
-        Trace("dt-expand") << "Expand def : " << n << " to " << ret
-                           << std::endl;
-      }
-    }
-    break;
-    case TUPLE_UPDATE:
-    case RECORD_UPDATE:
-    {
-      Assert(tn.isDatatype());
-      const DType& dt = tn.getDType();
-      NodeBuilder b(APPLY_CONSTRUCTOR);
-      b << dt[0].getConstructor();
-      size_t size, updateIndex;
-      if (n.getKind() == TUPLE_UPDATE)
-      {
-        Assert(tn.isTuple());
-        size = tn.getTupleLength();
-        updateIndex = n.getOperator().getConst<TupleUpdate>().getIndex();
-      }
-      else
-      {
-        Assert(tn.isRecord());
-        const DTypeConstructor& recCons = dt[0];
-        size = recCons.getNumArgs();
-        // get the index for the name
-        updateIndex = recCons.getSelectorIndexForName(
-            n.getOperator().getConst<RecordUpdate>().getField());
-      }
-      Debug("tuprec") << "expr is " << n << std::endl;
-      Debug("tuprec") << "updateIndex is " << updateIndex << std::endl;
-      Debug("tuprec") << "t is " << tn << std::endl;
-      Debug("tuprec") << "t has arity " << size << std::endl;
-      for (size_t i = 0; i < size; ++i)
-      {
-        if (i == updateIndex)
-        {
-          b << n[1];
-          Debug("tuprec") << "arg " << i << " gets updated to " << n[1]
-                          << std::endl;
-        }
-        else
-        {
-          b << nm->mkNode(
-              APPLY_SELECTOR_TOTAL, dt[0].getSelectorInternal(tn, i), n[0]);
-          Debug("tuprec") << "arg " << i << " copies "
-                          << b[b.getNumChildren() - 1] << std::endl;
-        }
-      }
-      ret = b;
-      Debug("tuprec") << "return " << ret << std::endl;
-    }
-    break;
-    default: break;
-  }
-  if (!ret.isNull() && n != ret)
-  {
-    return TrustNode::mkTrustRewrite(n, ret, nullptr);
-  }
-  return TrustNode::null();
-}
-
 TrustNode TheoryDatatypes::ppRewrite(TNode in, std::vector<SkolemLemma>& lems)
 {
   Debug("tuprec") << "TheoryDatatypes::ppRewrite(" << in << ")" << endl;
   // first, see if we need to expand definitions
-  TrustNode texp = expandDefinition(in);
+  TrustNode texp = d_rewriter.expandDefinition(in);
   if (!texp.isNull())
   {
     return texp;
index eb55ce6b014551c9445fd0ab9b005eb3b07ec02e..1ae122f5e2f801d6d9a002464568d4f4b16bfb10 100644 (file)
@@ -227,7 +227,6 @@ private:
   void notifyFact(TNode atom, bool pol, TNode fact, bool isInternal) override;
   //--------------------------------- end standard check
   void preRegisterTerm(TNode n) override;
-  TrustNode expandDefinition(Node n) override;
   TrustNode ppRewrite(TNode n, std::vector<SkolemLemma>& lems) override;
   EqualityStatus getEqualityStatus(TNode a, TNode b) override;
   std::string identify() const override
index 6629a839d6d2bcba2580c98f6f90bc0dc2af63d5..01dace411086739bb00b46aa7746dea62e82bf24 100644 (file)
@@ -681,11 +681,6 @@ void TheoryFp::preRegisterTerm(TNode node)
   return;
 }
 
-TrustNode TheoryFp::expandDefinition(Node node)
-{
-  return d_rewriter.expandDefinition(node);
-}
-
 void TheoryFp::handleLemma(Node node, InferenceId id)
 {
   Trace("fp") << "TheoryFp::handleLemma(): asserting " << node << std::endl;
index 78791b9b4db4566426c076e659b7e36c272d2595..8cf4c4cc55340c1ad00e795e27c296ddda8733ef 100644 (file)
@@ -62,7 +62,6 @@ class TheoryFp : public Theory
   //--------------------------------- end initialization
 
   void preRegisterTerm(TNode node) override;
-  TrustNode expandDefinition(Node node) override;
   TrustNode ppRewrite(TNode node, std::vector<SkolemLemma>& lems) override;
 
   //--------------------------------- standard check
index 97c0e216b61abe14d9ee5a7e9e19c5767eb7638f..027dd9819cf543f0aadbf8e04135ac23f847ebc9 100644 (file)
@@ -46,8 +46,8 @@ class TheoryFpRewriter : public TheoryRewriter
     // often this will suffice
     return postRewrite(equality).d_node;
   }
-  /** Expand definitions in node. */
-  TrustNode expandDefinition(Node node);
+  /** Expand definitions in node */
+  TrustNode expandDefinition(Node node) override;
 
  protected:
   /** TODO: document (projects issue #265) */
index fdb744d67ff03feb99f4a15a6c373adca5a35e9c..8406bd14ab03a6f6d10f70a56502a19bf11ec6bf 100644 (file)
@@ -131,12 +131,6 @@ void TheorySets::preRegisterTerm(TNode node)
   d_internal->preRegisterTerm(node);
 }
 
-TrustNode TheorySets::expandDefinition(Node n)
-{
-  // we currently do not expand any set operators
-  return TrustNode::null();
-}
-
 TrustNode TheorySets::ppRewrite(TNode n, std::vector<SkolemLemma>& lems)
 {
   Kind nk = n.getKind();
index bb8741e356d0840568a3493810986377cbe0294b..e99d25d36789076a1f64cdf5f44df6bb3a0f00db 100644 (file)
@@ -78,8 +78,6 @@ class TheorySets : public Theory
   Node getModelValue(TNode) override;
   std::string identify() const override { return "THEORY_SETS"; }
   void preRegisterTerm(TNode node) override;
-  /**  Expand partial operators (choose) from n. */
-  TrustNode expandDefinition(Node n) override;
   /**
    * If the sets-ext option is not set and we have an extended operator,
    * we throw an exception. Additionally, we expand operators like choose
index 431f488a57a1c1952ab2455d17eedfadbfe29b25..84127e8e378d63bbeaa8fab43a99863eea152525 100644 (file)
@@ -21,6 +21,7 @@
 #include "theory/rewriter.h"
 #include "theory/strings/arith_entail.h"
 #include "theory/strings/regexp_entail.h"
+#include "theory/strings/skolem_cache.h"
 #include "theory/strings/strings_rewriter.h"
 #include "theory/strings/theory_strings_utils.h"
 #include "theory/strings/word.h"
@@ -1514,6 +1515,30 @@ RewriteResponse SequencesRewriter::preRewrite(TNode node)
   return RewriteResponse(REWRITE_DONE, node);
 }
 
+TrustNode SequencesRewriter::expandDefinition(Node node)
+{
+  Trace("strings-exp-def") << "SequencesRewriter::expandDefinition : " << node
+                           << std::endl;
+
+  if (node.getKind() == kind::SEQ_NTH)
+  {
+    NodeManager* nm = NodeManager::currentNM();
+    Node s = node[0];
+    Node n = node[1];
+    // seq.nth(s, n) --> ite(0 <= n < len(s), seq.nth_total(s,n), Uf(s, n))
+    Node cond = nm->mkNode(AND,
+                           nm->mkNode(LEQ, nm->mkConst(Rational(0)), n),
+                           nm->mkNode(LT, n, nm->mkNode(STRING_LENGTH, s)));
+    Node ss = nm->mkNode(SEQ_NTH_TOTAL, s, n);
+    Node uf = SkolemCache::mkSkolemSeqNth(s.getType(), "Uf");
+    Node u = nm->mkNode(APPLY_UF, uf, s, n);
+    Node ret = nm->mkNode(ITE, cond, ss, u);
+    Trace("strings-exp-def") << "...return " << ret << std::endl;
+    return TrustNode::mkTrustRewrite(node, ret, nullptr);
+  }
+  return TrustNode::null();
+}
+
 Node SequencesRewriter::rewriteSeqNth(Node node)
 {
   Assert(node.getKind() == SEQ_NTH || node.getKind() == SEQ_NTH_TOTAL);
index 97db2c7f4ca99d50c9d46e54f527e73ad2016668..7af24596ac2d6dff464c008907e99e48e1b8a7aa 100644 (file)
@@ -127,6 +127,8 @@ class SequencesRewriter : public TheoryRewriter
  public:
   RewriteResponse postRewrite(TNode node) override;
   RewriteResponse preRewrite(TNode node) override;
+  /** Expand definition */
+  TrustNode expandDefinition(Node n) override;
 
   /** rewrite equality
    *
index 0ed003cc7297ffa26c973578a5141f8235b26b74..956f2148c43be5eb78d746d278f6fabd34806305 100644 (file)
@@ -553,29 +553,6 @@ void TheoryStrings::preRegisterTerm(TNode n)
   d_extTheory.registerTerm(n);
 }
 
-TrustNode TheoryStrings::expandDefinition(Node node)
-{
-  Trace("strings-exp-def") << "TheoryStrings::expandDefinition : " << node << std::endl;
-
-  if (node.getKind() == kind::SEQ_NTH)
-  {
-    NodeManager* nm = NodeManager::currentNM();
-    Node s = node[0];
-    Node n = node[1];
-    // seq.nth(s, n) --> ite(0 <= n < len(s), seq.nth_total(s,n), Uf(s, n))
-    Node cond = nm->mkNode(AND,
-                           nm->mkNode(LEQ, d_zero, n),
-                           nm->mkNode(LT, n, nm->mkNode(STRING_LENGTH, s)));
-    Node ss = nm->mkNode(SEQ_NTH_TOTAL, s, n);
-    Node uf = SkolemCache::mkSkolemSeqNth(s.getType(), "Uf");
-    Node u = nm->mkNode(APPLY_UF, uf, s, n);
-    Node ret = nm->mkNode(ITE, cond, ss, u);
-    Trace("strings-exp-def") << "...return " << ret << std::endl;
-    return TrustNode::mkTrustRewrite(node, ret, nullptr);
-  }
-  return TrustNode::null();
-}
-
 bool TheoryStrings::preNotifyFact(
     TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal)
 {
index fb6df80c7f7fd9e7fc8e1ea39de8fd57725a165e..01111880dff3c735dd7fdf027eb9eb2b7989c1fc 100644 (file)
@@ -94,8 +94,6 @@ class TheoryStrings : public Theory {
   void shutdown() override {}
   /** preregister term */
   void preRegisterTerm(TNode n) override;
-  /** Expand definition */
-  TrustNode expandDefinition(Node n) override;
   //--------------------------------- standard check
   /** Do we need a check call at last call effort? */
   bool needsCheckLastEffort() override;
index 247ebcf46edfa3658fdd5ef59a07f3d53e1f5a4b..9cf663a4faf0d85960d82893b3e2c25c955c5b79 100644 (file)
@@ -497,39 +497,6 @@ class Theory {
    */
   TheoryInferenceManager* getInferenceManager() { return d_inferManager; }
 
-  /**
-   * Expand definitions in the term node. This returns a term that is
-   * equivalent to node. It wraps this term in a TrustNode of kind
-   * TrustNodeKind::REWRITE. If node is unchanged by this method, the
-   * null TrustNode may be returned. This is an optimization to avoid
-   * constructing the trivial equality (= node node) internally within
-   * TrustNode.
-   *
-   * The purpose of this method is typically to eliminate the operators in node
-   * that are syntax sugar that cannot otherwise be eliminated during rewriting.
-   * For example, division relies on the introduction of an uninterpreted
-   * function for the divide-by-zero case, which we do not introduce with
-   * the rewriter, since this function may be cached in a non-global fashion.
-   *
-   * Some theories have kinds that are effectively definitions and should be
-   * expanded before they are handled.  Definitions allow a much wider range of
-   * actions than the normal forms given by the rewriter. However no
-   * assumptions can be made about subterms having been expanded or rewritten.
-   * Where possible rewrite rules should be used, definitions should only be
-   * used when rewrites are not possible, for example in handling
-   * under-specified operations using partially defined functions.
-   *
-   * Some theories like sets use expandDefinition as a "context
-   * independent preRegisterTerm".  This is required for cases where
-   * a theory wants to be notified about a term before preprocessing
-   * and simplification but doesn't necessarily want to rewrite it.
-   */
-  virtual TrustNode expandDefinition(Node node)
-  {
-    // by default, do nothing
-    return TrustNode::null();
-  }
-
   /**
    * Pre-register a term.  Done one time for a Node per SAT context level.
    */
index 75bcbff0ef37210f2e5203168df18e7561025afa..42e9148c2f0230146715d43dd88a9c2c8db280cb 100644 (file)
@@ -60,5 +60,11 @@ TrustNode TheoryRewriter::rewriteEqualityExtWithProof(Node node)
   return TrustNode::null();
 }
 
+TrustNode TheoryRewriter::expandDefinition(Node node)
+{
+  // no expansion
+  return TrustNode::null();
+}
+
 }  // namespace theory
 }  // namespace cvc5
index 2477de51e58a3f00431e2d5df5afa9a94822149b..031e32db4f5a80d80baf113c178db3de12ebac1e 100644 (file)
@@ -138,6 +138,30 @@ class TheoryRewriter
    * node if no rewrites are applied.
    */
   virtual TrustNode rewriteEqualityExtWithProof(Node node);
+
+  /**
+   * Expand definitions in the term node. This returns a term that is
+   * equivalent to node. It wraps this term in a TrustNode of kind
+   * TrustNodeKind::REWRITE. If node is unchanged by this method, the
+   * null TrustNode may be returned. This is an optimization to avoid
+   * constructing the trivial equality (= node node) internally within
+   * TrustNode.
+   *
+   * The purpose of this method is typically to eliminate the operators in node
+   * that are syntax sugar that cannot otherwise be eliminated during rewriting.
+   * For example, division relies on the introduction of an uninterpreted
+   * function for the divide-by-zero case, which we do not introduce with
+   * the standard rewrite methods.
+   *
+   * Some theories have kinds that are effectively definitions and should be
+   * expanded before they are handled.  Definitions allow a much wider range of
+   * actions than the normal forms given by the rewriter. However no
+   * assumptions can be made about subterms having been expanded or rewritten.
+   * Where possible rewrite rules should be used, definitions should only be
+   * used when rewrites are not possible, for example in handling
+   * under-specified operations using partially defined functions.
+   */
+  virtual TrustNode expandDefinition(Node node);
 };
 
 }  // namespace theory