Eliminate partial operators within lambdas during grammar normalization (#2570)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 3 Oct 2018 16:40:44 +0000 (11:40 -0500)
committerHaniel Barbosa <hanielbbarbosa@gmail.com>
Wed, 3 Oct 2018 16:40:44 +0000 (11:40 -0500)
src/theory/quantifiers/sygus/sygus_grammar_norm.cpp
src/theory/quantifiers/sygus/sygus_grammar_norm.h
test/regress/CMakeLists.txt
test/regress/Makefile.tests
test/regress/regress2/sygus/multi-udiv.sy [new file with mode: 0644]

index 3d066e8dd92238673eefaf8ea1e303888145107e..b014a30c6d5c1c3ef7e9d89689a8b0968dcb02dc 100644 (file)
@@ -67,6 +67,88 @@ bool OpPosTrie::getOrMakeType(TypeNode tn,
   return d_children[op_pos[ind]].getOrMakeType(tn, unres_tn, op_pos, ind + 1);
 }
 
+Kind SygusGrammarNorm::TypeObject::getEliminateKind(Kind ok)
+{
+  Kind nk = ok;
+  // We also must ensure that builtin operators which are eliminated
+  // during expand definitions are replaced by the proper operator.
+  if (ok == BITVECTOR_UDIV)
+  {
+    nk = BITVECTOR_UDIV_TOTAL;
+  }
+  else if (ok == BITVECTOR_UREM)
+  {
+    nk = BITVECTOR_UREM_TOTAL;
+  }
+  else if (ok == DIVISION)
+  {
+    nk = DIVISION_TOTAL;
+  }
+  else if (ok == INTS_DIVISION)
+  {
+    nk = INTS_DIVISION_TOTAL;
+  }
+  else if (ok == INTS_MODULUS)
+  {
+    nk = INTS_MODULUS_TOTAL;
+  }
+  return nk;
+}
+
+Node SygusGrammarNorm::TypeObject::eliminatePartialOperators(Node n)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  std::unordered_map<TNode, Node, TNodeHashFunction> visited;
+  std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
+  std::vector<TNode> visit;
+  TNode cur;
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+    it = visited.find(cur);
+
+    if (it == visited.end())
+    {
+      visited[cur] = Node::null();
+      visit.push_back(cur);
+      for (const Node& cn : cur)
+      {
+        visit.push_back(cn);
+      }
+    }
+    else if (it->second.isNull())
+    {
+      Node ret = cur;
+      bool childChanged = false;
+      std::vector<Node> children;
+      if (cur.getMetaKind() == metakind::PARAMETERIZED)
+      {
+        children.push_back(cur.getOperator());
+      }
+      for (const Node& cn : cur)
+      {
+        it = visited.find(cn);
+        Assert(it != visited.end());
+        Assert(!it->second.isNull());
+        childChanged = childChanged || cn != it->second;
+        children.push_back(it->second);
+      }
+      Kind ok = cur.getKind();
+      Kind nk = getEliminateKind(ok);
+      if (nk != ok || childChanged)
+      {
+        ret = nm->mkNode(nk, children);
+      }
+      visited[cur] = ret;
+    }
+  } while (!visit.empty());
+  Assert(visited.find(n) != visited.end());
+  Assert(!visited.find(n)->second.isNull());
+  return visited[n];
+}
+
 void SygusGrammarNorm::TypeObject::addConsInfo(SygusGrammarNorm* sygus_norm,
                                                const DatatypeConstructor& cons)
 {
@@ -74,36 +156,15 @@ void SygusGrammarNorm::TypeObject::addConsInfo(SygusGrammarNorm* sygus_norm,
   /* Recover the sygus operator to not lose reference to the original
    * operator (NOT, ITE, etc) */
   Node sygus_op = Node::fromExpr(cons.getSygusOp());
+  Trace("sygus-grammar-normalize-debug")
+      << ".....operator is " << sygus_op << std::endl;
   Node exp_sop_n = Node::fromExpr(
       smt::currentSmtEngine()->expandDefinitions(sygus_op.toExpr()));
+  // if it is a builtin operator, convert to total version if necessary
   if (exp_sop_n.getKind() == kind::BUILTIN)
   {
     Kind ok = NodeManager::operatorToKind(sygus_op);
-    Kind nk = ok;
-    Trace("sygus-grammar-normalize-debug")
-        << "...builtin operator is " << ok << std::endl;
-    // We also must ensure that builtin operators which are eliminated
-    // during expand definitions are replaced by the proper operator.
-    if (ok == kind::BITVECTOR_UDIV)
-    {
-      nk = kind::BITVECTOR_UDIV_TOTAL;
-    }
-    else if (ok == kind::BITVECTOR_UREM)
-    {
-      nk = kind::BITVECTOR_UREM_TOTAL;
-    }
-    else if (ok == kind::DIVISION)
-    {
-      nk = kind::DIVISION_TOTAL;
-    }
-    else if (ok == kind::INTS_DIVISION)
-    {
-      nk = kind::INTS_DIVISION_TOTAL;
-    }
-    else if (ok == kind::INTS_MODULUS)
-    {
-      nk = kind::INTS_MODULUS_TOTAL;
-    }
+    Kind nk = getEliminateKind(ok);
     if (nk != ok)
     {
       Trace("sygus-grammar-normalize-debug")
@@ -111,7 +172,21 @@ void SygusGrammarNorm::TypeObject::addConsInfo(SygusGrammarNorm* sygus_norm,
       exp_sop_n = NodeManager::currentNM()->operatorOf(nk);
     }
   }
-  d_ops.push_back(Rewriter::rewrite(exp_sop_n));
+  else
+  {
+    exp_sop_n = Rewriter::rewrite(exp_sop_n);
+    Trace("sygus-grammar-normalize-debug")
+        << ".....operator (post-rewrite) is " << exp_sop_n << std::endl;
+    // eliminate all partial operators from it
+    exp_sop_n = eliminatePartialOperators(exp_sop_n);
+    Trace("sygus-grammar-normalize-debug")
+        << ".....operator (eliminate partial operators) is " << exp_sop_n
+        << std::endl;
+    // rewrite again
+    exp_sop_n = Rewriter::rewrite(exp_sop_n);
+  }
+
+  d_ops.push_back(exp_sop_n);
   Trace("sygus-grammar-normalize-defs")
       << "\tOriginal op: " << cons.getSygusOp()
       << "\n\tExpanded one: " << exp_sop_n
index a0f81dcf3e04c17d219da6a58ed9071d5909c448..993d4166829c0df75500e6a221e2c9adcba6e7a3 100644 (file)
@@ -207,6 +207,16 @@ class SygusGrammarNorm
      */
     void addConsInfo(SygusGrammarNorm* sygus_norm,
                      const DatatypeConstructor& cons);
+    /**
+     * Returns the total version of Kind k if it is a partial operator, or
+     * otherwise k itself.
+     */
+    static Kind getEliminateKind(Kind k);
+    /**
+     * Returns a version of n where all partial functions such as bvudiv
+     * have been replaced by their total versions like bvudiv_total.
+     */
+    static Node eliminatePartialOperators(Node n);
 
     /** builds a datatype with the information in the type object
      *
index c798af3784c43d1cf6286c68fbadefa3ae38a532..bec5362e5a8ae29b571dd7a8f151c049d9bdd147 100644 (file)
@@ -1718,6 +1718,7 @@ set(regress_2_tests
   regress2/sygus/lustre-real.sy
   regress2/sygus/max2-univ.sy
   regress2/sygus/mpg_guard1-dd.sy
+  regress2/sygus/multi-udiv.sy
   regress2/sygus/nia-max-square.sy
   regress2/sygus/no-syntax-test-no-si.sy
   regress2/sygus/process-10-vars-2fun.sy
index deb4ad647ded140c902d4570acfa2e7122ba1f3d..37c911d415a258497a038e6f47fe2e426765443b 100644 (file)
@@ -1711,6 +1711,7 @@ REG2_TESTS = \
        regress2/sygus/lustre-real.sy \
        regress2/sygus/max2-univ.sy \
        regress2/sygus/mpg_guard1-dd.sy \
+       regress2/sygus/multi-udiv.sy \
        regress2/sygus/nia-max-square.sy \
        regress2/sygus/no-syntax-test-no-si.sy \
        regress2/sygus/process-10-vars-2fun.sy \
diff --git a/test/regress/regress2/sygus/multi-udiv.sy b/test/regress/regress2/sygus/multi-udiv.sy
new file mode 100644 (file)
index 0000000..6574175
--- /dev/null
@@ -0,0 +1,42 @@
+; EXPECT: unsat
+; COMMAND-LINE: --sygus-out=status
+  ( set-logic BV )
+  ( define-fun hd05  (    ( x  ( BitVec 32 ) ) )  ( BitVec 32 )  ( bvor     x  ( bvsub     x   #x00000001 ) ) )
+( synth-fun f  (    ( x  ( BitVec 32 ) ) )  ( BitVec 32 ) (
+       (Start  ( BitVec 32 ) (         #x00000001
+               #x00000000
+               #xffffffff
+               x
+               (bvsrem NT0 NT0)
+               (bvudiv NT0 NT0)
+               (bvsdiv NT0 NT0)
+               (bvurem NT0 NT0)
+               (bvsrem NT4 NT0)
+               (bvudiv NT4 NT0)
+               (bvurem NT4 NT0)
+               (bvsdiv NT4 NT0)
+               (bvnot NT0)
+               (bvneg NT0)
+               (bvadd NT0 NT0)
+               (bvor NT0 NT0)
+               (bvor NT4 NT0)
+               (bvadd NT4 NT0)
+))
+       (NT0  ( BitVec 32 ) (           #x00000001
+               #x00000000
+               #xffffffff
+               x
+))
+       (NT4  ( BitVec 32 ) (           (bvnot NT0)
+               (bvneg NT0)
+               (bvadd NT0 NT0)
+               (bvor NT0 NT0)
+               (bvsrem NT0 NT0)
+               (bvudiv NT0 NT0)
+               (bvsdiv NT0 NT0)
+               (bvurem NT0 NT0)
+))
+))
+  ( declare-var x  ( BitVec 32 ) )
+  ( constraint  ( =     ( hd05    x )  ( f    x ) ) )
+  ( check-synth )