Connecting the core array solver in strings (#7800)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 14 Dec 2021 19:35:09 +0000 (13:35 -0600)
committerGitHub <noreply@github.com>
Tue, 14 Dec 2021 19:35:09 +0000 (19:35 +0000)
This PR takes most of the remaining changes from the seqArray branch apart from the extension to model construction.

Notably it connects the core array solver to the array solver in strings.

src/expr/skolem_manager.cpp
src/expr/skolem_manager.h
src/smt/proof_post_processor.cpp
src/theory/strings/array_core_solver.cpp
src/theory/strings/array_core_solver.h
src/theory/strings/array_solver.cpp
src/theory/strings/array_solver.h
src/theory/strings/rewrites.cpp
src/theory/strings/rewrites.h
src/theory/strings/sequences_rewriter.cpp

index 206ebb9ce8f1d691c92970c2a63051c83c41ecf6..f08ffc5f44fe59bf9029a7afd32a9f702b212348 100644 (file)
@@ -67,6 +67,7 @@ const char* toString(SkolemFunId id)
     case SkolemFunId::SK_FIRST_MATCH: return "SK_FIRST_MATCH";
     case SkolemFunId::SK_FIRST_MATCH_POST: return "SK_FIRST_MATCH_POST";
     case SkolemFunId::RE_UNFOLD_POS_COMPONENT: return "RE_UNFOLD_POS_COMPONENT";
+    case SkolemFunId::SEQ_MODEL_BASE_ELEMENT: return "SEQ_MODEL_BASE_ELEMENT";
     case SkolemFunId::BAGS_CARD_CARDINALITY: return "BAGS_CARD_CARDINALITY";
     case SkolemFunId::BAGS_CARD_ELEMENTS: return "BAGS_CARD_ELEMENTS";
     case SkolemFunId::BAGS_CARD_N: return "BAGS_CARD_N";
index 93b26b6cb84f3ba20060c9138c785656c79f4d45..cca28ccf047f683dadc6b553425262f44b0a54a2 100644 (file)
@@ -112,6 +112,8 @@ enum class SkolemFunId
    * i = 0, ..., n.
    */
   RE_UNFOLD_POS_COMPONENT,
+  /** Sequence model construction, element for base */
+  SEQ_MODEL_BASE_ELEMENT,
   BAGS_CARD_CARDINALITY,
   BAGS_CARD_ELEMENTS,
   BAGS_CARD_N,
index 90f0a48bffcadfa76c994398318140ba19cd4f3f..167a82e26c7eb14e5d45a1baf156e19c2bacf567 100644 (file)
@@ -422,6 +422,7 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     // not eliminated
     return Node::null();
   }
+  Trace("smt-proof-pp-debug") << "Expand macro " << id << std::endl;
   // macro elimination
   if (id == PfRule::MACRO_SR_EQ_INTRO)
   {
index ed369006832c40b6c084050a02fdd2b060da983f..3b8fdeff4ee7f92d31ce5ceac12869116bef3512 100644 (file)
@@ -70,7 +70,7 @@ void ArrayCoreSolver::checkNth(const std::vector<Node>& nthTerms)
       // (seq.extract A i l) ^ (<= 0 i) ^ (< i (str.len A)) --> (seq.unit
       // (seq.nth A i))
       std::vector<Node> exp;
-      Node cond1 = nm->mkNode(LEQ, nm->mkConst(Rational(0)), n[1]);
+      Node cond1 = nm->mkNode(LEQ, nm->mkConstInt(Rational(0)), n[1]);
       Node cond2 = nm->mkNode(LT, n[1], nm->mkNode(STRING_LENGTH, n[0]));
       Node cond = nm->mkNode(AND, cond1, cond2);
       Node body1 = nm->mkNode(
@@ -115,7 +115,7 @@ void ArrayCoreSolver::checkUpdate(const std::vector<Node>& updateTerms)
     // n[2][0]
     Node left = nm->mkNode(SEQ_NTH, termProxy, n[1]);
     Node right =
-        nm->mkNode(SEQ_NTH, n[2], nm->mkConst(Rational(0)));  // n[2][0]
+        nm->mkNode(SEQ_NTH, n[2], nm->mkConstInt(Rational(0)));  // n[2][0]
     right = Rewriter::rewrite(right);
     Node lem = nm->mkNode(EQUAL, left, right);
     Trace("seq-array-debug") << "enter" << std::endl;
@@ -211,10 +211,10 @@ void ArrayCoreSolver::check(const std::vector<Node>& nthTerms,
       Node i = n[1];
       Node sLen = nm->mkNode(STRING_LENGTH, s);
       Node iRev = nm->mkNode(
-          MINUS, sLen, nm->mkNode(PLUS, i, nm->mkConst(Rational(1))));
+          MINUS, sLen, nm->mkNode(PLUS, i, nm->mkConstInt(Rational(1))));
 
       std::vector<Node> nexp;
-      nexp.push_back(nm->mkNode(LEQ, nm->mkConst(Rational(0)), i));
+      nexp.push_back(nm->mkNode(LEQ, nm->mkConstInt(Rational(0)), i));
       nexp.push_back(nm->mkNode(LT, i, sLen));
 
       // 0 <= i ^ i < len(s) => seq.nth(seq.rev(s), i) = seq.nth(s, len(s) - i -
index 7101da6259832fa7ee80477cf1d62aa4540a575a..3873f6a691964bcd267be8834f9d51196638fb62 100644 (file)
@@ -62,7 +62,7 @@ class ArrayCoreSolver : protected EnvObj
   const std::map<Node, Node>& getWriteModel(Node eqc);
 
   /**
-   * Get connected sequences
+   * Get connected sequences, see documentation of computeConnected.
    * @return a map M such that sequence equivalence class representatives x and
    * y are connected if an only if M[x] = M[y].
    */
@@ -93,7 +93,18 @@ class ArrayCoreSolver : protected EnvObj
    */
   void checkUpdate(const std::vector<Node>& updateTerms);
 
-  // TODO: document
+  /**
+   * Given the current set of update terms, this computes the connected
+   * sequences implied by the current equality information + this set of terms.
+   * Connected sequences is a reflexive transitive relation where additionally
+   * a and b are connected if there exists an update term (seq.update a n x)
+   * that is currently equal to b.
+   *
+   * This method runs a union find algorithm to compute all connected sequences.
+   *
+   * As a result of running this method, the map d_connectedSeq is populated
+   * with information regarding which sequences are connected.
+   */
   void computeConnected(const std::vector<Node>& updateTerms);
 
   /** The solver state object */
@@ -110,7 +121,15 @@ class ArrayCoreSolver : protected EnvObj
   ExtTheory& d_extt;
   /** The write model */
   std::map<Node, std::map<Node, Node>> d_writeModel;
-  /** Connected */
+  /**
+   * Map from sequences to their "connected representative". Two sequences are
+   * connected (based on the definition described in computeConnected) iff they
+   * have the same connected representative. Sequences that do not occur in
+   * this map are assumed to be their own connected representative.
+   *
+   * This map is only valid after running computeConnected, and is valid
+   * only during model building.
+   */
   std::map<Node, Node> d_connectedSeq;
   /** The set of lemmas been sent */
   context::CDHashSet<Node> d_lem;
index c04bfe9184c3865a2b86a9035a3f0acbcdb09f7c..672ca8b765cbf5f22f166b41924f8e39a0ef0e62 100644 (file)
@@ -41,6 +41,7 @@ ArraySolver::ArraySolver(Env& env,
       d_termReg(tr),
       d_csolver(cs),
       d_esolver(es),
+      d_coreSolver(env, s, im, tr, cs, es, extt),
       d_eqProc(context())
 {
   NodeManager* nm = NodeManager::currentNM();
@@ -63,6 +64,32 @@ void ArraySolver::checkArrayConcat()
   checkTerms(SEQ_NTH);
 }
 
+void ArraySolver::checkArray()
+{
+  if (!d_termReg.hasSeqUpdate())
+  {
+    Trace("seq-array") << "No seq.update/seq.nth terms, skipping check..."
+                       << std::endl;
+    return;
+  }
+  Trace("seq-array") << "ArraySolver::checkArray..." << std::endl;
+  d_coreSolver.check(d_currTerms[SEQ_NTH], d_currTerms[STRING_UPDATE]);
+}
+
+void ArraySolver::checkArrayEager()
+{
+  if (!d_termReg.hasSeqUpdate())
+  {
+    Trace("seq-array") << "No seq.update/seq.nth terms, skipping check..."
+                       << std::endl;
+    return;
+  }
+  Trace("seq-array") << "ArraySolver::checkArray..." << std::endl;
+  std::vector<Node> nthTerms = d_esolver.getActive(SEQ_NTH);
+  std::vector<Node> updateTerms = d_esolver.getActive(STRING_UPDATE);
+  d_coreSolver.check(nthTerms, updateTerms);
+}
+
 void ArraySolver::checkTerms(Kind k)
 {
   Assert(k == STRING_UPDATE || k == SEQ_NTH);
@@ -271,6 +298,16 @@ void ArraySolver::checkTerms(Kind k)
   }
 }
 
+const std::map<Node, Node>& ArraySolver::getWriteModel(Node eqc)
+{
+  return d_coreSolver.getWriteModel(eqc);
+}
+
+const std::map<Node, Node>& ArraySolver::getConnectedSequences()
+{
+  return d_coreSolver.getConnectedSequences();
+}
+
 }  // namespace strings
 }  // namespace theory
 }  // namespace cvc5
index 941061e9e23aebda52fd10a5e6ba34b1f8a3d3c9..23bacd118886ac7acf745ec9048b6138532453fc 100644 (file)
@@ -19,6 +19,7 @@
 #define CVC5__THEORY__STRINGS__ARRAY_SOLVER_H
 
 #include "context/cdhashset.h"
+#include "theory/strings/array_core_solver.h"
 #include "theory/strings/core_solver.h"
 #include "theory/strings/extf_solver.h"
 #include "theory/strings/inference_manager.h"
@@ -54,6 +55,32 @@ class ArraySolver : protected EnvObj
    * their application to concatenation terms.
    */
   void checkArrayConcat();
+  /**
+   * Perform reasoning about seq.nth and seq.update operations (lazily), which
+   * calls the core sequences-array solver for the set of nth/update terms over atomic
+   * equivalence classes.
+   */
+  void checkArray();
+  /**
+   * Same as `checkArray`, but called eagerly, and for all nth/update terms, not just
+   * those over atomic equivalence classes.
+   */
+  void checkArrayEager();
+
+  /**
+   * @param eqc The sequence equivalence class representative. We can assume
+   * the equivalence class of eqc contains no concatenation terms.
+   * @return the map corresponding to the model for eqc. The domain of
+   * the returned map should be in distinct integer equivalence classes of the
+   * equality engine of strings theory. The model assigned to eqc will be
+   * a skeleton constructed via seq.++ where the components take values from
+   * this map.
+   */
+  const std::map<Node, Node>& getWriteModel(Node eqc);
+  /**
+   * Get connected sequences from the core array solver.
+   */
+  const std::map<Node, Node>& getConnectedSequences();
 
  private:
   /** check terms of given kind */
@@ -72,6 +99,8 @@ class ArraySolver : protected EnvObj
   std::map<Kind, std::vector<Node> > d_currTerms;
   /** Common constants */
   Node d_zero;
+  /** The core array solver */
+  ArrayCoreSolver d_coreSolver;
   /** Equalities we have processed in the current context */
   NodeSet d_eqProc;
 };
index 4da6e5600d678d00a7065536fe7dff45212faa90..bfe9021aaf39b5ca389053f6f215d761d11df632 100644 (file)
@@ -154,6 +154,7 @@ const char* toString(Rewrite r)
     case Rewrite::UPD_CONST_INDEX_MAX_OOB: return "UPD_CONST_INDEX_MAX_OOB";
     case Rewrite::UPD_CONST_INDEX_NEG: return "UPD_CONST_INDEX_NEG";
     case Rewrite::UPD_CONST_INDEX_OOB: return "UPD_CONST_INDEX_OOB";
+    case Rewrite::UPD_REV: return "UPD_REV";
     case Rewrite::STOI_CONCAT_NONNUM: return "STOI_CONCAT_NONNUM";
     case Rewrite::STOI_EVAL: return "STOI_EVAL";
     case Rewrite::STR_CONV_CONST: return "STR_CONV_CONST";
@@ -223,6 +224,7 @@ const char* toString(Rewrite r)
     case Rewrite::SEQ_UNIT_EVAL: return "SEQ_UNIT_EVAL";
     case Rewrite::SEQ_NTH_EVAL: return "SEQ_NTH_EVAL";
     case Rewrite::SEQ_NTH_TOTAL_OOB: return "SEQ_NTH_TOTAL_OOB";
+    case Rewrite::SEQ_NTH_UNIT: return "SEQ_NTH_UNIT";
     default: return "?";
   }
 }
index c96dffcdebf47f25b600053202b2a6c63e44858b..b57c5f2765041963f6834857a28657a965a3e29b 100644 (file)
@@ -155,6 +155,7 @@ enum class Rewrite : uint32_t
   UPD_CONST_INDEX_MAX_OOB,
   UPD_CONST_INDEX_NEG,
   UPD_CONST_INDEX_OOB,
+  UPD_REV,
   STOI_CONCAT_NONNUM,
   STOI_EVAL,
   STR_CONV_CONST,
@@ -223,7 +224,8 @@ enum class Rewrite : uint32_t
   CHARAT_ELIM,
   SEQ_UNIT_EVAL,
   SEQ_NTH_EVAL,
-  SEQ_NTH_TOTAL_OOB
+  SEQ_NTH_TOTAL_OOB,
+  SEQ_NTH_UNIT
 };
 
 /**
index 7670c0b70439f6635fd8b7cc876a5ae3ead12716..1ccb67490e3bef36d62e53df83fc45ee95bfa1e6 100644 (file)
@@ -1759,15 +1759,15 @@ Node SequencesRewriter::rewriteSeqNth(Node node)
       Node ret = nm->mkGroundValue(s.getType().getSequenceElementType());
       return returnRewrite(node, ret, Rewrite::SEQ_NTH_TOTAL_OOB);
     }
-    else
-    {
-      return node;
-    }
   }
-  else
+
+  if (s.getKind() == SEQ_UNIT && i.isConst() && i.getConst<Rational>().isZero())
   {
-    return node;
+    Node ret = s[0];
+    return returnRewrite(node, ret, Rewrite::SEQ_NTH_UNIT);
   }
+
+  return node;
 }
 
 Node SequencesRewriter::rewriteCharAt(Node node)
@@ -2045,6 +2045,8 @@ Node SequencesRewriter::rewriteUpdate(Node node)
 {
   Assert(node.getKind() == kind::STRING_UPDATE);
   Node s = node[0];
+  Node i = node[1];
+  Node x = node[2];
   if (s.isConst())
   {
     if (Word::isEmpty(s))
@@ -2082,6 +2084,16 @@ Node SequencesRewriter::rewriteUpdate(Node node)
     }
   }
 
+  if (s.getKind() == STRING_REV)
+  {
+    NodeManager* nm = NodeManager::currentNM();
+    Node idx = nm->mkNode(MINUS,
+                          nm->mkNode(STRING_LENGTH, s),
+                          nm->mkNode(PLUS, i, nm->mkConst(Rational(1))));
+    Node ret = nm->mkNode(STRING_REV, nm->mkNode(STRING_UPDATE, s, idx, x));
+    return returnRewrite(node, ret, Rewrite::UPD_REV);
+  }
+
   return node;
 }