Fix for corner case of higher-order matching (#1708)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 4 Apr 2018 21:17:35 +0000 (16:17 -0500)
committerGitHub <noreply@github.com>
Wed, 4 Apr 2018 21:17:35 +0000 (16:17 -0500)
src/theory/quantifiers/ematching/ho_trigger.cpp
src/theory/quantifiers/term_database.cpp
test/regress/Makefile.tests
test/regress/regress0/ho/ho-match-fun-suffix.smt2 [new file with mode: 0644]

index 0e095511935f5a0695a52cfec86ef0cbea2f8e03..fad6334b846caf681b347028915fa66ed37e4d64 100644 (file)
@@ -236,7 +236,7 @@ bool HigherOrderTrigger::sendInstantiation(InstMatch& m)
     {
       TNode var = ha.first;
       unsigned vnum = var.getAttribute(InstVarNumAttribute());
-      Node value = m.d_vals[vnum];
+      TNode value = m.d_vals[vnum];
       Trace("ho-unif-debug") << "  val[" << var << "] = " << value << std::endl;
 
       Trace("ho-unif-debug2") << "initialize lambda information..."
@@ -256,7 +256,17 @@ bool HigherOrderTrigger::sendInstantiation(InstMatch& m)
       for (unsigned i = 0; i < ha.second.size(); i++)
       {
         std::vector<TNode> args;
-        Node f = uf::TheoryUfRewriter::decomposeHoApply(ha.second[i], args);
+        // must substitute the operator we matched with the original
+        // higher-order variable (var) that matched it. This ensures that the
+        // argument vector (args) below is of the proper length. This handles,
+        // for example, matches like:
+        //   (@ x y) with (@ (@ k1 k2) k3)
+        // where k3 but not k2 should be an argument of the match.
+        Node hmatch = ha.second[i];
+        Trace("ho-unif-debug2") << "Match is " << hmatch << std::endl;
+        hmatch = hmatch.substitute(value, var);
+        Trace("ho-unif-debug2") << "Pre-subs match is " << hmatch << std::endl;
+        Node f = uf::TheoryUfRewriter::decomposeHoApply(hmatch, args);
         // Assert( f==value );
         for (unsigned k = 0, size = args.size(); k < size; k++)
         {
@@ -348,7 +358,9 @@ bool HigherOrderTrigger::sendInstantiation(InstMatch& m)
       Trace("ho-unif-debug2") << "finished." << std::endl;
     }
 
-    return sendInstantiation(m, 0);
+    bool ret = sendInstantiation(m, 0);
+    Trace("ho-unif-debug") << "Finished, success = " << ret << std::endl;
+    return ret;
   }
   else
   {
@@ -361,6 +373,8 @@ bool HigherOrderTrigger::sendInstantiation(InstMatch& m)
 // occurring as pattern operators (very small)
 bool HigherOrderTrigger::sendInstantiation(InstMatch& m, unsigned var_index)
 {
+  Trace("ho-unif-debug2") << "send inst " << var_index << " / "
+                          << d_ho_var_list.size() << std::endl;
   if (var_index == d_ho_var_list.size())
   {
     // we now have an instantiation to try
@@ -394,13 +408,18 @@ bool HigherOrderTrigger::sendInstantiationArg(InstMatch& m,
                                               Node lbvl,
                                               bool arg_changed)
 {
+  Trace("ho-unif-debug2") << "send inst arg " << arg_index << " / "
+                          << lbvl.getNumChildren() << std::endl;
   if (arg_index == lbvl.getNumChildren())
   {
     // construct the lambda
     if (arg_changed)
     {
+      Trace("ho-unif-debug2")
+          << "  make lambda from children: " << d_lchildren[vnum] << std::endl;
       Node body =
           NodeManager::currentNM()->mkNode(kind::APPLY_UF, d_lchildren[vnum]);
+      Trace("ho-unif-debug2") << "  got " << body << std::endl;
       Node lam = NodeManager::currentNM()->mkNode(kind::LAMBDA, lbvl, body);
       m.d_vals[vnum] = lam;
       Trace("ho-unif-debug2") << "  try " << vnum << " -> " << lam << std::endl;
@@ -438,35 +457,58 @@ bool HigherOrderTrigger::sendInstantiationArg(InstMatch& m,
 
 int HigherOrderTrigger::addHoTypeMatchPredicateLemmas()
 {
+  if (d_ho_var_types.empty())
+  {
+    return 0;
+  }
+  Trace("ho-quant-trigger") << "addHoTypeMatchPredicateLemmas..." << std::endl;
   unsigned numLemmas = 0;
-  if (!d_ho_var_types.empty())
+  // this forces expansion of APPLY_UF terms to curried HO_APPLY chains
+  unsigned size = d_quantEngine->getTermDatabase()->getNumOperators();
+  quantifiers::TermUtil* tutil = d_quantEngine->getTermUtil();
+  NodeManager* nm = NodeManager::currentNM();
+  for (unsigned j = 0; j < size; j++)
   {
-    // this forces expansion of APPLY_UF terms to curried HO_APPLY chains
-    unsigned size = d_quantEngine->getTermDatabase()->getNumOperators();
-    for (unsigned j = 0; j < size; j++)
+    Node f = d_quantEngine->getTermDatabase()->getOperator(j);
+    if (f.isVar())
     {
-      Node f = d_quantEngine->getTermDatabase()->getOperator(j);
-      if (f.isVar())
+      TypeNode tn = f.getType();
+      if (tn.isFunction())
       {
-        TypeNode tn = f.getType();
-        if (d_ho_var_types.find(tn) != d_ho_var_types.end())
+        std::vector<TypeNode> argTypes = tn.getArgTypes();
+        Assert(argTypes.size() > 0);
+        TypeNode range = tn.getRangeType();
+        // for each function type suffix of the type of f, for example if
+        // f : (Int -> (Int -> Int))
+        // we iterate with stn = (Int -> (Int -> Int)) and (Int -> Int)
+        for (unsigned a = 0, size = argTypes.size(); a < size; a++)
         {
-          Node u = d_quantEngine->getTermUtil()->getHoTypeMatchPredicate(tn);
-          Node au = NodeManager::currentNM()->mkNode(kind::APPLY_UF, u, f);
-          if (d_quantEngine->addLemma(au))
+          std::vector<TypeNode> sargts;
+          sargts.insert(sargts.begin(), argTypes.begin() + a, argTypes.end());
+          Assert(sargts.size() > 0);
+          TypeNode stn = nm->mkFunctionType(sargts, range);
+          Trace("ho-quant-trigger-debug")
+              << "For " << f << ", check " << stn << "..." << std::endl;
+          // if a variable of this type occurs in this trigger
+          if (d_ho_var_types.find(stn) != d_ho_var_types.end())
           {
-            // this forces f to be a first-class member of the quantifier-free
-            // equality engine,
-            //  which in turn forces the quantifier-free theory solver to expand
-            //  it to HO_APPLY
-            Trace("ho-quant") << "Added ho match predicate lemma : " << au
-                              << std::endl;
-            numLemmas++;
+            Node u = tutil->getHoTypeMatchPredicate(tn);
+            Node au = nm->mkNode(kind::APPLY_UF, u, f);
+            if (d_quantEngine->addLemma(au))
+            {
+              // this forces f to be a first-class member of the quantifier-free
+              // equality engine, which in turn forces the quantifier-free
+              // theory solver to expand it to an HO_APPLY chain.
+              Trace("ho-quant")
+                  << "Added ho match predicate lemma : " << au << std::endl;
+              numLemmas++;
+            }
           }
         }
       }
     }
   }
+
   return numLemmas;
 }
 
index 8e22b2ced8b4dba88f2e1a3535fae2d7fa258cbb..2013bff5d521aec208d73e79228b180581c57e00 100644 (file)
@@ -308,6 +308,7 @@ void TermDb::computeUfTerms( TNode f ) {
     if( options::ufHo() ){
       ops.insert( ops.end(), d_ho_op_rep_slaves[f].begin(), d_ho_op_rep_slaves[f].end() );
     }
+    Trace("term-db-debug") << "computeUfTerms for " << f << std::endl;
     unsigned congruentCount = 0;
     unsigned nonCongruentCount = 0;
     unsigned alreadyCongruentCount = 0;
@@ -318,7 +319,8 @@ void TermDb::computeUfTerms( TNode f ) {
       //Assert( !options::ufHo() || ee->areEqual( ff, f ) );
       std::map< Node, std::vector< Node > >::iterator it = d_op_map.find( ff );
       if( it!=d_op_map.end() ){
-        Trace("term-db-debug") << "Adding terms for operator " << f << std::endl;
+        Trace("term-db-debug")
+            << "Adding terms for operator " << ff << std::endl;
         for( unsigned i=0; i<it->second.size(); i++ ){
           Node n = it->second[i];
           //to be added to term index, term must be relevant, and exist in EE
@@ -856,14 +858,22 @@ bool TermDb::reset( Theory::Effort effort ){
         eq::EqClassIterator eqc_i = eq::EqClassIterator( r, ee );
         while( !eqc_i.isFinished() ){
           TNode n = (*eqc_i);
-          if( d_op_map.find( n )!=d_op_map.end() ){
-            if( first.isNull() ){
-              first = n;
-              d_ho_op_rep[n] = n;
-            }else{
-              Trace("quant-ho") << "  have : " << n << " == " << first << ", type = " << n.getType() << std::endl;
-              d_ho_op_rep[n] = first;
-              d_ho_op_rep_slaves[first].push_back( n );
+          if (n.isVar())
+          {
+            if (d_op_map.find(n) != d_op_map.end())
+            {
+              if (first.isNull())
+              {
+                first = n;
+                d_ho_op_rep[n] = n;
+              }
+              else
+              {
+                Trace("quant-ho") << "  have : " << n << " == " << first
+                                  << ", type = " << n.getType() << std::endl;
+                d_ho_op_rep[n] = first;
+                d_ho_op_rep_slaves[first].push_back(n);
+              }
             }
           }
           ++eqc_i;
index 4fb9065c6af9843116217d13a5a343ce5e8520dd..de80368a483f003221ba8dcdd5f48ba189ff1a64 100644 (file)
@@ -443,6 +443,7 @@ REG0_TESTS = \
        regress0/ho/ext-ho.smt2 \
        regress0/ho/ext-sat-partial-eval.smt2 \
        regress0/ho/ext-sat.smt2 \
+       regress0/ho/ho-match-fun-suffix.smt2 \
        regress0/ho/ho-matching-enum.smt2 \
        regress0/ho/ho-matching-nested-app.smt2 \
        regress0/ho/ite-apply-eq.smt2 \
diff --git a/test/regress/regress0/ho/ho-match-fun-suffix.smt2 b/test/regress/regress0/ho/ho-match-fun-suffix.smt2
new file mode 100644 (file)
index 0000000..1e4ad24
--- /dev/null
@@ -0,0 +1,13 @@
+; COMMAND-LINE: --uf-ho
+; EXPECT: unsat
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun f (Int Int) Int)
+(declare-fun a () Int)
+(declare-fun b () Int)
+
+(assert (forall ((x (-> Int Int)) (y Int)) (not (= (x y) 0))))
+
+(assert (= (f a b) 0))
+
+(check-sat)
\ No newline at end of file