Fix single invocation partition for higher-order (#7046)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 23 Aug 2021 20:45:38 +0000 (15:45 -0500)
committerGitHub <noreply@github.com>
Mon, 23 Aug 2021 20:45:38 +0000 (20:45 +0000)
It was not robust to cases where a function-to-synthesize occurred in a higher-order context.

Also does general clean up of the single invocation utility.

src/expr/subs.cpp
src/expr/subs.h
src/theory/quantifiers/single_inv_partition.cpp
src/theory/quantifiers/single_inv_partition.h
test/regress/CMakeLists.txt
test/regress/regress0/sygus/ho-occ-synth-fun.sy [new file with mode: 0644]

index 7e9c83d061b3b4c48266841ecc378f7cb067e0b0..b140a41909e5160567c76fd6caace8c011bf744f 100644 (file)
@@ -173,6 +173,12 @@ std::string Subs::toString() const
   return ss.str();
 }
 
+void Subs::clear()
+{
+  d_vars.clear();
+  d_subs.clear();
+}
+
 std::ostream& operator<<(std::ostream& out, const Subs& s)
 {
   out << s.toString();
index 56158d36cf6e5e5fb42af0864724b8d8086fbb13..afde63b6e8e4bbab4d775a6ca1529cb903ee5fa7 100644 (file)
@@ -67,6 +67,8 @@ class Subs
   std::map<Node, Node> toMap() const;
   /** Get string for this substitution */
   std::string toString() const;
+  /** clear the substitution */
+  void clear();
   /** The data */
   std::vector<Node> d_vars;
   std::vector<Node> d_subs;
index 05492b5b7124c09e45cf08b772fdcd32a928d029..73bcad535ba11d1f050b8ced4589136de877f6fa 100644 (file)
@@ -233,42 +233,37 @@ bool SingleInvocationPartition::init(std::vector<Node>& funcs,
       std::map<Node, bool> visited;
       // functions to arguments
       std::vector<Node> args;
-      std::vector<Node> terms;
-      std::vector<Node> subs;
+      Subs sb;
       bool singleInvocation = true;
       bool ngroundSingleInvocation = false;
-      if (processConjunct(cr, visited, args, terms, subs))
+      if (processConjunct(cr, visited, args, sb))
       {
-        for (unsigned j = 0; j < terms.size(); j++)
+        for (size_t j = 0, vsize = sb.d_vars.size(); j < vsize; j++)
         {
-          si_terms.push_back(subs[j]);
-          Node op = subs[j].hasOperator() ? subs[j].getOperator() : subs[j];
+          Node s = sb.d_subs[j];
+          si_terms.push_back(s);
+          Node op = s.hasOperator() ? s.getOperator() : s;
           Assert(d_func_fo_var.find(op) != d_func_fo_var.end());
           si_subs.push_back(d_func_fo_var[op]);
         }
         std::map<Node, Node> subs_map;
         std::map<Node, Node> subs_map_rev;
         // normalize the invocations
-        if (!terms.empty())
+        if (!sb.empty())
         {
-          Assert(terms.size() == subs.size());
-          cr = cr.substitute(
-              terms.begin(), terms.end(), subs.begin(), subs.end());
+          cr = sb.apply(cr);
         }
         std::vector<Node> children;
         children.push_back(cr);
-        terms.clear();
-        subs.clear();
+        sb.clear();
         Trace("si-prt") << "...single invocation, with arguments: "
                         << std::endl;
         for (unsigned j = 0; j < args.size(); j++)
         {
           Trace("si-prt") << args[j] << " ";
-          if (args[j].getKind() == BOUND_VARIABLE
-              && std::find(terms.begin(), terms.end(), args[j]) == terms.end())
+          if (args[j].getKind() == BOUND_VARIABLE && !sb.contains(args[j]))
           {
-            terms.push_back(args[j]);
-            subs.push_back(d_si_vars[j]);
+            sb.add(args[j], d_si_vars[j]);
           }
           else
           {
@@ -276,12 +271,8 @@ bool SingleInvocationPartition::init(std::vector<Node>& funcs,
           }
         }
         Trace("si-prt") << std::endl;
-        cr = children.size() == 1
-                 ? children[0]
-                 : NodeManager::currentNM()->mkNode(OR, children);
-        Assert(terms.size() == subs.size());
-        cr =
-            cr.substitute(terms.begin(), terms.end(), subs.begin(), subs.end());
+        cr = nm->mkOr(children);
+        cr = sb.apply(cr);
         Trace("si-prt-debug") << "...normalized invocations to " << cr
                               << std::endl;
         // now must check if it has other bound variables
@@ -417,8 +408,7 @@ bool SingleInvocationPartition::collectConjuncts(Node n,
 bool SingleInvocationPartition::processConjunct(Node n,
                                                 std::map<Node, bool>& visited,
                                                 std::vector<Node>& args,
-                                                std::vector<Node>& terms,
-                                                std::vector<Node>& subs)
+                                                Subs& sb)
 {
   std::map<Node, bool>::iterator it = visited.find(n);
   if (it != visited.end())
@@ -430,7 +420,7 @@ bool SingleInvocationPartition::processConjunct(Node n,
     bool ret = true;
     for (unsigned i = 0; i < n.getNumChildren(); i++)
     {
-      if (!processConjunct(n[i], visited, args, terms, subs))
+      if (!processConjunct(n[i], visited, args, sb))
       {
         ret = false;
       }
@@ -445,7 +435,20 @@ bool SingleInvocationPartition::processConjunct(Node n,
         if (std::find(d_input_funcs.begin(), d_input_funcs.end(), f)
             != d_input_funcs.end())
         {
-          success = true;
+          // If n is an application of a function-to-synthesize f, or is
+          // itself a function-to-synthesize, then n must be fully applied.
+          // This catches cases where n is a function-to-synthesize that occurs
+          // in a higher-order context.
+          // If the type of n is functional, then it is not fully applied.
+          if (n.getType().isFunction())
+          {
+            ret = false;
+            success = false;
+          }
+          else
+          {
+            success = true;
+          }
         }
       }
       else
@@ -458,7 +461,8 @@ bool SingleInvocationPartition::processConjunct(Node n,
       }
       if (success)
       {
-        if (std::find(terms.begin(), terms.end(), n) == terms.end())
+        Trace("si-prt-debug") << "Process " << n << std::endl;
+        if (!sb.contains(n))
         {
           // check if it matches the type requirement
           if (isAntiSkolemizableType(f))
@@ -487,8 +491,7 @@ bool SingleInvocationPartition::processConjunct(Node n,
             }
             if (ret)
             {
-              terms.push_back(n);
-              subs.push_back(d_func_inv[f]);
+              sb.add(n, d_func_inv[f]);
             }
           }
           else
@@ -500,7 +503,6 @@ bool SingleInvocationPartition::processConjunct(Node n,
         }
       }
     }
-    //}
     visited[n] = ret;
     return ret;
   }
index 7f8cfc3265d8682731d34765a1b5a6fc7606232a..1b4ea62b003ff2ac5b12cdd68ded42be926a81b4 100644 (file)
@@ -22,6 +22,7 @@
 #include <vector>
 
 #include "expr/node.h"
+#include "expr/subs.h"
 #include "expr/type_node.h"
 
 namespace cvc5 {
@@ -284,8 +285,7 @@ class SingleInvocationPartition
   bool processConjunct(Node n,
                        std::map<Node, bool>& visited,
                        std::vector<Node>& args,
-                       std::vector<Node>& terms,
-                       std::vector<Node>& subs);
+                       Subs& sb);
 
   /** get the and node corresponding to d_conjuncts[index] */
   Node getConjunct(int index);
index 483e3a01e4c5919537581cb33cd83b33e5cce8c2..d68400b662fd9016b8911d7a61a9a99badc4854b 100644 (file)
@@ -1243,6 +1243,7 @@ set(regress_0_tests
   regress0/sygus/dt-sel-parse1.sy
   regress0/sygus/General_plus10.sy
   regress0/sygus/hd-05-d1-prog-nogrammar.sy
+  regress0/sygus/ho-occ-synth-fun.sy
   regress0/sygus/inv-different-var-order.sy
   regress0/sygus/issue3356-syg-inf-usort.smt2
   regress0/sygus/issue3624.sy
diff --git a/test/regress/regress0/sygus/ho-occ-synth-fun.sy b/test/regress/regress0/sygus/ho-occ-synth-fun.sy
new file mode 100644 (file)
index 0000000..4bef02b
--- /dev/null
@@ -0,0 +1,9 @@
+; COMMAND-LINE: --sygus-out=status
+; EXPECT: unsat
+(set-logic HO_ALL)
+(synth-fun f ((x Int)) Int)
+(synth-fun g ((x Int)) Int)
+(declare-var P (-> (-> Int Int) Bool))
+(constraint (=> (P f) (P g)))
+; a trivial class of solutions is where f = g.
+(check-synth)