Move proxy variables to InferenceManager in strings (#3758)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sat, 15 Feb 2020 22:38:23 +0000 (16:38 -0600)
committerGitHub <noreply@github.com>
Sat, 15 Feb 2020 22:38:23 +0000 (16:38 -0600)
src/theory/strings/inference_manager.cpp
src/theory/strings/inference_manager.h
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h

index 2b5338a6a85ca73b2eebb842656ad2417edd62fc..67ba2d5a37fe0f3365b4e6e0eaca79f4f6346afd 100644 (file)
@@ -34,8 +34,16 @@ InferenceManager::InferenceManager(TheoryStrings& p,
                                    context::Context* c,
                                    context::UserContext* u,
                                    SolverState& s,
+                                   SkolemCache& skc,
                                    OutputChannel& out)
-    : d_parent(p), d_state(s), d_out(out), d_keep(c), d_lengthLemmaTermsCache(u)
+    : d_parent(p),
+      d_state(s),
+      d_skCache(skc),
+      d_out(out),
+      d_keep(c),
+      d_proxyVar(u),
+      d_proxyVarToLength(u),
+      d_lengthLemmaTermsCache(u)
 {
   NodeManager* nm = NodeManager::currentNM();
   d_zero = nm->mkConst(Rational(0));
@@ -284,6 +292,129 @@ void InferenceManager::sendPhaseRequirement(Node lit, bool pol)
   d_pendingReqPhase[lit] = pol;
 }
 
+Node InferenceManager::getProxyVariableFor(Node n) const
+{
+  NodeNodeMap::const_iterator it = d_proxyVar.find(n);
+  if (it != d_proxyVar.end())
+  {
+    return (*it).second;
+  }
+  return Node::null();
+}
+
+Node InferenceManager::getSymbolicDefinition(Node n,
+                                             std::vector<Node>& exp) const
+{
+  if (n.getNumChildren() == 0)
+  {
+    Node pn = getProxyVariableFor(n);
+    if (pn.isNull())
+    {
+      return Node::null();
+    }
+    Node eq = n.eqNode(pn);
+    eq = Rewriter::rewrite(eq);
+    if (std::find(exp.begin(), exp.end(), eq) == exp.end())
+    {
+      exp.push_back(eq);
+    }
+    return pn;
+  }
+  std::vector<Node> children;
+  if (n.getMetaKind() == metakind::PARAMETERIZED)
+  {
+    children.push_back(n.getOperator());
+  }
+  for (const Node& nc : n)
+  {
+    if (n.getType().isRegExp())
+    {
+      children.push_back(nc);
+    }
+    else
+    {
+      Node ns = getSymbolicDefinition(nc, exp);
+      if (ns.isNull())
+      {
+        return Node::null();
+      }
+      else
+      {
+        children.push_back(ns);
+      }
+    }
+  }
+  return NodeManager::currentNM()->mkNode(n.getKind(), children);
+}
+
+void InferenceManager::registerLength(Node n)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  // register length information:
+  //  for variables, split on empty vs positive length
+  //  for concat/const/replace, introduce proxy var and state length relation
+  Node lsum;
+  if (n.getKind() != STRING_CONCAT && n.getKind() != CONST_STRING)
+  {
+    Node lsumb = nm->mkNode(STRING_LENGTH, n);
+    lsum = Rewriter::rewrite(lsumb);
+    // can register length term if it does not rewrite
+    if (lsum == lsumb)
+    {
+      registerLength(n, LENGTH_SPLIT);
+      return;
+    }
+  }
+  Node sk = d_skCache.mkSkolemCached(n, SkolemCache::SK_PURIFY, "lsym");
+  StringsProxyVarAttribute spva;
+  sk.setAttribute(spva, true);
+  Node eq = Rewriter::rewrite(sk.eqNode(n));
+  Trace("strings-lemma") << "Strings::Lemma LENGTH Term : " << eq << std::endl;
+  d_proxyVar[n] = sk;
+  // If we are introducing a proxy for a constant or concat term, we do not
+  // need to send lemmas about its length, since its length is already
+  // implied.
+  if (n.isConst() || n.getKind() == STRING_CONCAT)
+  {
+    // do not send length lemma for sk.
+    registerLength(sk, LENGTH_IGNORE);
+  }
+  Trace("strings-assert") << "(assert " << eq << ")" << std::endl;
+  d_out.lemma(eq);
+  Node skl = nm->mkNode(STRING_LENGTH, sk);
+  if (n.getKind() == STRING_CONCAT)
+  {
+    std::vector<Node> nodeVec;
+    for (const Node& nc : n)
+    {
+      if (nc.getAttribute(StringsProxyVarAttribute()))
+      {
+        Assert(d_proxyVarToLength.find(nc) != d_proxyVarToLength.end());
+        nodeVec.push_back(d_proxyVarToLength[nc]);
+      }
+      else
+      {
+        Node lni = nm->mkNode(STRING_LENGTH, nc);
+        nodeVec.push_back(lni);
+      }
+    }
+    lsum = nm->mkNode(PLUS, nodeVec);
+    lsum = Rewriter::rewrite(lsum);
+  }
+  else if (n.getKind() == CONST_STRING)
+  {
+    lsum = nm->mkConst(Rational(n.getConst<String>().size()));
+  }
+  Assert(!lsum.isNull());
+  d_proxyVarToLength[sk] = lsum;
+  Node ceq = Rewriter::rewrite(skl.eqNode(lsum));
+  Trace("strings-lemma") << "Strings::Lemma LENGTH : " << ceq << std::endl;
+  Trace("strings-lemma-debug")
+      << "  prerewrite : " << skl.eqNode(lsum) << std::endl;
+  Trace("strings-assert") << "(assert " << ceq << ")" << std::endl;
+  d_out.lemma(ceq);
+}
+
 void InferenceManager::registerLength(Node n, LengthStatus s)
 {
   if (d_lengthLemmaTermsCache.find(n) != d_lengthLemmaTermsCache.end())
@@ -480,7 +611,7 @@ void InferenceManager::inferSubstitutionProxyVars(
         }
         else if (ns[i].isConst())
         {
-          ss = d_parent.getProxyVariableFor(ns[i]);
+          ss = getProxyVariableFor(ns[i]);
         }
         if (!ss.isNull())
         {
index 819e4b98f30a2db86ec9d88f23c2dbf849cf6215..50cfdb6fb609b626ba0db44461ce0004efb0290c 100644 (file)
@@ -25,6 +25,7 @@
 #include "expr/node.h"
 #include "theory/output_channel.h"
 #include "theory/strings/infer_info.h"
+#include "theory/strings/skolem_cache.h"
 #include "theory/strings/solver_state.h"
 #include "theory/uf/equality_engine.h"
 
@@ -67,12 +68,14 @@ class TheoryStrings;
 class InferenceManager
 {
   typedef context::CDHashSet<Node, NodeHashFunction> NodeSet;
+  typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeNodeMap;
 
  public:
   InferenceManager(TheoryStrings& p,
                    context::Context* c,
                    context::UserContext* u,
                    SolverState& s,
+                   SkolemCache& skc,
                    OutputChannel& out);
   ~InferenceManager() {}
 
@@ -164,6 +167,32 @@ class InferenceManager
    * decided with polarity pol.
    */
   void sendPhaseRequirement(Node lit, bool pol);
+
+  //---------------------------- proxy variables and length elaboration
+  /** Get symbolic definition
+   *
+   * This method returns the "symbolic definition" of n, call it n', and
+   * populates the vector exp with an explanation such that exp => n = n'.
+   *
+   * The symbolic definition of n is the term where (maximal) subterms of n
+   * are replaced by their proxy variables. For example, if we introduced
+   * proxy variable v for x ++ y, then given input x ++ y = w, this method
+   * returns v = w and adds v = x ++ y to exp.
+   */
+  Node getSymbolicDefinition(Node n, std::vector<Node>& exp) const;
+  /** Get proxy variable
+   *
+   * If this method returns the proxy variable for (string) term n if one
+   * exists, otherwise it returns null.
+   */
+  Node getProxyVariableFor(Node n) const;
+  /** register length
+   *
+   * This method is called on non-constant string terms n. It sends a lemma
+   * on the output channel that ensures that the length n satisfies its assigned
+   * status (given by argument s).
+   */
+  void registerLength(Node n);
   /** register length
    *
    * This method is called on non-constant string terms n. It sends a lemma
@@ -186,6 +215,7 @@ class InferenceManager
    * channel instead of adding them to pending lists.
    */
   void registerLength(Node n, LengthStatus s);
+  //---------------------------- end proxy variables and length elaboration
 
   //----------------------------constructing antecedants
   /**
@@ -290,6 +320,8 @@ class InferenceManager
    * This is a reference to the solver state of the theory of strings.
    */
   SolverState& d_state;
+  /** cache of all skolems */
+  SkolemCache& d_skCache;
   /** the output channel
    *
    * This is a reference to the output channel of the theory of strings.
@@ -316,6 +348,22 @@ class InferenceManager
    * SAT-context-dependent.
    */
   NodeSet d_keep;
+  /**
+   * Map string terms to their "proxy variables". Proxy variables are used are
+   * intermediate variables so that length information can be communicated for
+   * constants. For example, to communicate that "ABC" has length 3, we
+   * introduce a proxy variable v_{"ABC"} for "ABC", and assert:
+   *   v_{"ABC"} = "ABC" ^ len( v_{"ABC"} ) = 3
+   * Notice this is required since we cannot directly write len( "ABC" ) = 3,
+   * which rewrites to 3 = 3.
+   * In the above example, we store "ABC" -> v_{"ABC"} in this map.
+   */
+  NodeNodeMap d_proxyVar;
+  /**
+   * Map from proxy variables to their normalized length. In the above example,
+   * we store "ABC" -> 3.
+   */
+  NodeNodeMap d_proxyVarToLength;
   /** List of terms that we have register length for */
   NodeSet d_lengthLemmaTermsCache;
   /** infer substitution proxy vars
index 197f7ac4c4e4c5612f312f4cf0db4c41e444507d..23a41a0bb9eacc8868e336176a4ac8940357d89f 100644 (file)
@@ -71,13 +71,11 @@ TheoryStrings::TheoryStrings(context::Context* c,
       d_notify(*this),
       d_equalityEngine(d_notify, c, "theory::strings", true),
       d_state(c, d_equalityEngine, d_valuation),
-      d_im(*this, c, u, d_state, out),
+      d_im(*this, c, u, d_state, d_sk_cache, out),
       d_pregistered_terms_cache(u),
       d_registered_terms_cache(u),
       d_preproc(&d_sk_cache, u),
       d_extf_infer_cache(c),
-      d_proxy_var(u),
-      d_proxy_var_to_length(u),
       d_functionsTerms(c),
       d_has_extf(c, false),
       d_has_str_code(false),
@@ -1231,7 +1229,7 @@ void TheoryStrings::checkExtfEval( int effort ) {
           // only use symbolic definitions if option is set
           if (options::stringInferSym())
           {
-            nrs = getSymbolicDefinition(sn, exps);
+            nrs = d_im.getSymbolicDefinition(sn, exps);
           }
           if( !nrs.isNull() ){
             Trace("strings-extf-debug") << "  rewrite " << nrs << "..." << std::endl;
@@ -1531,51 +1529,6 @@ void TheoryStrings::checkExtfInference( Node n, Node nr, ExtfInfoTmp& in, int ef
   }
 }
 
-Node TheoryStrings::getProxyVariableFor(Node n) const
-{
-  NodeNodeMap::const_iterator it = d_proxy_var.find(n);
-  if (it != d_proxy_var.end())
-  {
-    return (*it).second;
-  }
-  return Node::null();
-}
-Node TheoryStrings::getSymbolicDefinition(Node n, std::vector<Node>& exp) const
-{
-  if( n.getNumChildren()==0 ){
-    Node pn = getProxyVariableFor(n);
-    if (pn.isNull())
-    {
-      return Node::null();
-    }
-    Node eq = n.eqNode(pn);
-    eq = Rewriter::rewrite(eq);
-    if (std::find(exp.begin(), exp.end(), eq) == exp.end())
-    {
-      exp.push_back(eq);
-    }
-    return pn;
-  }else{
-    std::vector< Node > children;
-    if (n.getMetaKind() == kind::metakind::PARAMETERIZED) {
-      children.push_back( n.getOperator() );
-    }
-    for( unsigned i=0; i<n.getNumChildren(); i++ ){
-      if( n.getKind()==kind::STRING_IN_REGEXP && i==1 ){
-        children.push_back( n[i] );
-      }else{
-        Node ns = getSymbolicDefinition( n[i], exp );
-        if( ns.isNull() ){
-          return Node::null();
-        }else{
-          children.push_back( ns );
-        }
-      }
-    }
-    return NodeManager::currentNM()->mkNode( n.getKind(), children );
-  }
-}
-
 void TheoryStrings::checkRegisterTermsPreNormalForm()
 {
   const std::vector<Node>& seqc = d_bsolver.getStringEqc();
@@ -1619,7 +1572,7 @@ void TheoryStrings::checkCodes()
         Node cc = nm->mkNode(kind::STRING_CODE, c);
         cc = Rewriter::rewrite(cc);
         Assert(cc.isConst());
-        Node cp = getProxyVariableFor(c);
+        Node cp = d_im.getProxyVariableFor(c);
         AlwaysAssert(!cp.isNull());
         Node vc = nm->mkNode(STRING_CODE, cp);
         if (!d_state.areEqual(cc, vc))
@@ -1701,68 +1654,7 @@ void TheoryStrings::registerTerm(Node n, int effort)
     // register length information:
     //  for variables, split on empty vs positive length
     //  for concat/const/replace, introduce proxy var and state length relation
-    Node lsum;
-    if (n.getKind() != STRING_CONCAT && n.getKind() != CONST_STRING)
-    {
-      Node lsumb = nm->mkNode(STRING_LENGTH, n);
-      lsum = Rewriter::rewrite(lsumb);
-      // can register length term if it does not rewrite
-      if (lsum == lsumb)
-      {
-        d_im.registerLength(n, LENGTH_SPLIT);
-        return;
-      }
-    }
-    Node sk = d_sk_cache.mkSkolemCached(n, SkolemCache::SK_PURIFY, "lsym");
-    StringsProxyVarAttribute spva;
-    sk.setAttribute(spva, true);
-    Node eq = Rewriter::rewrite(sk.eqNode(n));
-    Trace("strings-lemma") << "Strings::Lemma LENGTH Term : " << eq
-                           << std::endl;
-    d_proxy_var[n] = sk;
-    // If we are introducing a proxy for a constant or concat term, we do not
-    // need to send lemmas about its length, since its length is already
-    // implied.
-    if (n.isConst() || n.getKind() == STRING_CONCAT)
-    {
-      // do not send length lemma for sk.
-      d_im.registerLength(sk, LENGTH_IGNORE);
-    }
-    Trace("strings-assert") << "(assert " << eq << ")" << std::endl;
-    d_out->lemma(eq);
-    Node skl = nm->mkNode(STRING_LENGTH, sk);
-    if (n.getKind() == STRING_CONCAT)
-    {
-      std::vector<Node> node_vec;
-      for (unsigned i = 0; i < n.getNumChildren(); i++)
-      {
-        if (n[i].getAttribute(StringsProxyVarAttribute()))
-        {
-          Assert(d_proxy_var_to_length.find(n[i])
-                 != d_proxy_var_to_length.end());
-          node_vec.push_back(d_proxy_var_to_length[n[i]]);
-        }
-        else
-        {
-          Node lni = nm->mkNode(STRING_LENGTH, n[i]);
-          node_vec.push_back(lni);
-        }
-      }
-      lsum = nm->mkNode(PLUS, node_vec);
-      lsum = Rewriter::rewrite(lsum);
-    }
-    else if (n.getKind() == CONST_STRING)
-    {
-      lsum = nm->mkConst(Rational(n.getConst<String>().size()));
-    }
-    Assert(!lsum.isNull());
-    d_proxy_var_to_length[sk] = lsum;
-    Node ceq = Rewriter::rewrite(skl.eqNode(lsum));
-    Trace("strings-lemma") << "Strings::Lemma LENGTH : " << ceq << std::endl;
-    Trace("strings-lemma-debug")
-        << "  prerewrite : " << skl.eqNode(lsum) << std::endl;
-    Trace("strings-assert") << "(assert " << ceq << ")" << std::endl;
-    d_out->lemma(ceq);
+    d_im.registerLength(n);
   }
   else if (n.getKind() == STRING_CODE)
   {
index 960d3ceaaf29579ebba024fbde25efffa785392a..67b7482ca2f0877c9995545c19a88cac1893a369 100644 (file)
@@ -274,22 +274,6 @@ private:
   EqualityStatus getEqualityStatus(TNode a, TNode b) override;
 
  private:
-  /**
-   * Map string terms to their "proxy variables". Proxy variables are used are
-   * intermediate variables so that length information can be communicated for
-   * constants. For example, to communicate that "ABC" has length 3, we
-   * introduce a proxy variable v_{"ABC"} for "ABC", and assert:
-   *   v_{"ABC"} = "ABC" ^ len( v_{"ABC"} ) = 3
-   * Notice this is required since we cannot directly write len( "ABC" ) = 3,
-   * which rewrites to 3 = 3.
-   * In the above example, we store "ABC" -> v_{"ABC"} in this map.
-   */
-  NodeNodeMap d_proxy_var;
-  /**
-   * Map from proxy variables to their normalized length. In the above example,
-   * we store "ABC" -> 3.
-   */
-  NodeNodeMap d_proxy_var_to_length;
   /** All the function terms that the theory has seen */
   context::CDList<TNode> d_functionsTerms;
 private:
@@ -309,24 +293,6 @@ private:
   /** cache of all skolems */
   SkolemCache d_sk_cache;
 
-  /** Get proxy variable
-   *
-   * If this method returns the proxy variable for (string) term n if one
-   * exists, otherwise it returns null.
-   */
-  Node getProxyVariableFor(Node n) const;
-  /** Get symbolic definition
-   *
-   * This method returns the "symbolic definition" of n, call it n', and
-   * populates the vector exp with an explanation such that exp => n = n'.
-   *
-   * The symbolic definition of n is the term where (maximal) subterms of n
-   * are replaced by their proxy variables. For example, if we introduced
-   * proxy variable v for x ++ y, then given input x ++ y = w, this method
-   * returns v = w and adds v = x ++ y to exp.
-   */
-  Node getSymbolicDefinition(Node n, std::vector<Node>& exp) const;
-
   //--------------------------for checkExtfEval
   /**
    * Non-static information about an extended function t. This information is