Minor cleaning of instantiation utilities (#7334)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 12 Oct 2021 15:29:38 +0000 (10:29 -0500)
committerGitHub <noreply@github.com>
Tue, 12 Oct 2021 15:29:38 +0000 (15:29 +0000)
Also fixes a bug in the quantifiers macro utility which did not compute the instantiation constant body of a quantified formula properly.

This is work towards a major refactoring of conflict-based instantiation / entailment checks.

src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp
src/theory/quantifiers/cegqi/inst_strategy_cegqi.h
src/theory/quantifiers/instantiate.cpp
src/theory/quantifiers/instantiate.h
src/theory/quantifiers/quant_module.cpp
src/theory/quantifiers/quant_module.h
src/theory/quantifiers/quantifiers_macros.cpp
src/theory/quantifiers/quantifiers_registry.cpp
src/theory/quantifiers/quantifiers_registry.h

index 0337d8959502c528e0c7bca7145e6adb09314399..339524d942c283a0d9e27d910401c7f049475e06 100644 (file)
@@ -40,10 +40,8 @@ InstRewriterCegqi::InstRewriterCegqi(InstStrategyCegqi* p)
 {
 }
 
-TrustNode InstRewriterCegqi::rewriteInstantiation(Node q,
-                                                  std::vector<Node>& terms,
-                                                  Node inst,
-                                                  bool doVts)
+TrustNode InstRewriterCegqi::rewriteInstantiation(
+    Node q, const std::vector<Node>& terms, Node inst, bool doVts)
 {
   return d_parent->rewriteInstantiation(q, terms, inst, doVts);
 }
@@ -344,10 +342,8 @@ void InstStrategyCegqi::preRegisterQuantifier(Node q)
     }
   }
 }
-TrustNode InstStrategyCegqi::rewriteInstantiation(Node q,
-                                                  std::vector<Node>& terms,
-                                                  Node inst,
-                                                  bool doVts)
+TrustNode InstStrategyCegqi::rewriteInstantiation(
+    Node q, const std::vector<Node>& terms, Node inst, bool doVts)
 {
   Node prevInst = inst;
   if (doVts)
index a568b0b4d7c1cb79c8a484561c25f67df2c590c7..5a886e28ddcd3dda2b83dfe543c97b1c1cf3bd80 100644 (file)
@@ -49,7 +49,7 @@ class InstRewriterCegqi : public InstantiationRewriter
    * corresponding to the rewrite and its proof generator.
    */
   TrustNode rewriteInstantiation(Node q,
-                                 std::vector<Node>& terms,
+                                 const std::vector<Node>& terms,
                                  Node inst,
                                  bool doVts) override;
 
@@ -116,7 +116,7 @@ class InstStrategyCegqi : public QuantifiersModule
    * proof generator.
    */
   TrustNode rewriteInstantiation(Node q,
-                                 std::vector<Node>& terms,
+                                 const std::vector<Node>& terms,
                                  Node inst,
                                  bool doVts);
   /** get the instantiation rewriter object */
index 0807188d5a9b9d6f4d09189b0c4a2ca77474c1cd..f756fcfd159da8d74bca76b92e2274b2754f79f0 100644 (file)
@@ -485,7 +485,7 @@ bool Instantiate::addInstantiationExpFail(Node q,
 }
 
 void Instantiate::recordInstantiation(Node q,
-                                      std::vector<Node>& terms,
+                                      const std::vector<Node>& terms,
                                       bool doVts)
 {
   Trace("inst-debug") << "Record instantiation for " << q << std::endl;
@@ -497,7 +497,7 @@ void Instantiate::recordInstantiation(Node q,
 }
 
 bool Instantiate::existsInstantiation(Node q,
-                                      std::vector<Node>& terms,
+                                      const std::vector<Node>& terms,
                                       bool modEq)
 {
   if (options::incrementalSolving())
@@ -520,8 +520,8 @@ bool Instantiate::existsInstantiation(Node q,
 }
 
 Node Instantiate::getInstantiation(Node q,
-                                   std::vector<Node>& vars,
-                                   std::vector<Node>& terms,
+                                   const std::vector<Node>& vars,
+                                   const std::vector<Node>& terms,
                                    InferenceId id,
                                    Node pfArg,
                                    bool doVts,
@@ -576,14 +576,17 @@ Node Instantiate::getInstantiation(Node q,
   return body;
 }
 
-Node Instantiate::getInstantiation(Node q, std::vector<Node>& terms, bool doVts)
+Node Instantiate::getInstantiation(Node q,
+                                   const std::vector<Node>& terms,
+                                   bool doVts)
 {
   Assert(d_qreg.d_vars.find(q) != d_qreg.d_vars.end());
   return getInstantiation(
       q, d_qreg.d_vars[q], terms, InferenceId::UNKNOWN, Node::null(), doVts);
 }
 
-bool Instantiate::recordInstantiationInternal(Node q, std::vector<Node>& terms)
+bool Instantiate::recordInstantiationInternal(Node q,
+                                              const std::vector<Node>& terms)
 {
   if (options::incrementalSolving())
   {
@@ -601,7 +604,8 @@ bool Instantiate::recordInstantiationInternal(Node q, std::vector<Node>& terms)
   return d_inst_match_trie[q].addInstMatch(d_qstate, q, terms);
 }
 
-bool Instantiate::removeInstantiationInternal(Node q, std::vector<Node>& terms)
+bool Instantiate::removeInstantiationInternal(Node q,
+                                              const std::vector<Node>& terms)
 {
   if (options::incrementalSolving())
   {
index 753213f35319e65d0749f102f2ac6e0ed8f1d0ed..b4972a7b6f97d6a2250c2b5d96f524c9a4f52ae6 100644 (file)
@@ -65,7 +65,7 @@ class InstantiationRewriter
    * and its proof generator.
    */
   virtual TrustNode rewriteInstantiation(Node q,
-                                         std::vector<Node>& terms,
+                                         const std::vector<Node>& terms,
                                          Node inst,
                                          bool doVts) = 0;
 };
@@ -203,7 +203,7 @@ class Instantiate : public QuantifiersUtil
    * but does not enqueue an instantiation lemma.
    */
   void recordInstantiation(Node q,
-                           std::vector<Node>& terms,
+                           const std::vector<Node>& terms,
                            bool doVts = false);
   /** exists instantiation
    *
@@ -212,7 +212,7 @@ class Instantiate : public QuantifiersUtil
    *   modEq : whether to check for duplication modulo equality
    */
   bool existsInstantiation(Node q,
-                           std::vector<Node>& terms,
+                           const std::vector<Node>& terms,
                            bool modEq = false);
   //--------------------------------------general utilities
   /** get instantiation
@@ -225,8 +225,8 @@ class Instantiate : public QuantifiersUtil
    * single INSTANTIATE step concluding the instantiated body of q from q.
    */
   Node getInstantiation(Node q,
-                        std::vector<Node>& vars,
-                        std::vector<Node>& terms,
+                        const std::vector<Node>& vars,
+                        const std::vector<Node>& terms,
                         InferenceId id = InferenceId::UNKNOWN,
                         Node pfArg = Node::null(),
                         bool doVts = false,
@@ -235,7 +235,9 @@ class Instantiate : public QuantifiersUtil
    *
    * Same as above but with vars equal to the bound variables of q.
    */
-  Node getInstantiation(Node q, std::vector<Node>& terms, bool doVts = false);
+  Node getInstantiation(Node q,
+                        const std::vector<Node>& terms,
+                        bool doVts = false);
   //--------------------------------------end general utilities
 
   /**
@@ -297,9 +299,9 @@ class Instantiate : public QuantifiersUtil
 
  private:
   /** record instantiation, return true if it was not a duplicate */
-  bool recordInstantiationInternal(Node q, std::vector<Node>& terms);
+  bool recordInstantiationInternal(Node q, const std::vector<Node>& terms);
   /** remove instantiation from the cache */
-  bool removeInstantiationInternal(Node q, std::vector<Node>& terms);
+  bool removeInstantiationInternal(Node q, const std::vector<Node>& terms);
   /**
    * Ensure that n has type tn, return a term equivalent to it for that type
    * if possible.
index 8fb37c54826f3cf7eb5107215ffc35a44adfd531..db4eb9d34f669576479c4be6612c4c2c7205733d 100644 (file)
@@ -76,5 +76,10 @@ quantifiers::QuantifiersRegistry& QuantifiersModule::getQuantifiersRegistry()
   return d_qreg;
 }
 
+quantifiers::TermRegistry& QuantifiersModule::getTermRegistry()
+{
+  return d_treg;
+}
+
 }  // namespace theory
 }  // namespace cvc5
index 639f9c2b41bf20a0da35091cd43335f118c1bd9a..ce6c1b04cb22aa21bf2798457513b5cff68075d5 100644 (file)
@@ -167,6 +167,8 @@ class QuantifiersModule : protected EnvObj
   quantifiers::QuantifiersInferenceManager& getInferenceManager();
   /** get the quantifiers registry */
   quantifiers::QuantifiersRegistry& getQuantifiersRegistry();
+  /** get the term registry */
+  quantifiers::TermRegistry& getTermRegistry();
   //----------------------------end general queries
  protected:
   /** Reference to the state of the quantifiers engine */
index a3bdf10adca589b11633f1312ff30a8bd4f0b4e6..9b9580b02e5112b599cdc287a4754339d6925449 100644 (file)
@@ -88,7 +88,7 @@ Node QuantifiersMacros::solve(Node lit, bool reqGround)
             << "...does not contain bad (recursive) operator." << std::endl;
         // must be ground UF term if mode is GROUND_UF
         if (options::macrosQuantMode() != options::MacrosQuantMode::GROUND_UF
-            || preservesTriggerVariables(body, n_def))
+            || preservesTriggerVariables(lit, n_def))
         {
           Trace("macros-debug")
               << "...respects ground-uf constraint." << std::endl;
@@ -139,6 +139,7 @@ bool QuantifiersMacros::containsBadOp(Node n, Node op, bool reqGround)
 
 bool QuantifiersMacros::preservesTriggerVariables(Node q, Node n)
 {
+  Assert(q.getKind() == FORALL) << "Expected quantified formula, got " << q;
   Node icn = d_qreg.substituteBoundVariablesToInstConstants(n, q);
   Trace("macros-debug2") << "Get free variables in " << icn << std::endl;
   std::vector<Node> var;
index 6d5e003635cbe9c64ea11174f70e7978512f5cb8..487bcc0fef40be8a2c74433dc851ad2f2a9d60c3 100644 (file)
@@ -38,6 +38,7 @@ void QuantifiersRegistry::registerQuantifier(Node q)
   {
     return;
   }
+  Assert(q.getKind() == kind::FORALL);
   NodeManager* nm = NodeManager::currentNM();
   Debug("quantifiers-engine")
       << "Instantiation constants for " << q << " : " << std::endl;
@@ -144,42 +145,42 @@ Node QuantifiersRegistry::substituteBoundVariablesToInstConstants(Node n,
                                                                   Node q)
 {
   registerQuantifier(q);
-  return n.substitute(d_vars[q].begin(),
-                      d_vars[q].end(),
-                      d_inst_constants[q].begin(),
-                      d_inst_constants[q].end());
+  std::vector<Node>& vars = d_vars.at(q);
+  std::vector<Node>& consts = d_inst_constants.at(q);
+  Assert(vars.size() == q[0].getNumChildren());
+  Assert(vars.size() == consts.size());
+  return n.substitute(vars.begin(), vars.end(), consts.begin(), consts.end());
 }
 
 Node QuantifiersRegistry::substituteInstConstantsToBoundVariables(Node n,
                                                                   Node q)
 {
   registerQuantifier(q);
-  return n.substitute(d_inst_constants[q].begin(),
-                      d_inst_constants[q].end(),
-                      d_vars[q].begin(),
-                      d_vars[q].end());
+  std::vector<Node>& vars = d_vars.at(q);
+  std::vector<Node>& consts = d_inst_constants.at(q);
+  Assert(vars.size() == q[0].getNumChildren());
+  Assert(vars.size() == consts.size());
+  return n.substitute(consts.begin(), consts.end(), vars.begin(), vars.end());
 }
 
-Node QuantifiersRegistry::substituteBoundVariables(Node n,
-                                                   Node q,
-                                                   std::vector<Node>& terms)
+Node QuantifiersRegistry::substituteBoundVariables(
+    Node n, Node q, const std::vector<Node>& terms)
 {
   registerQuantifier(q);
-  Assert(d_vars[q].size() == terms.size());
-  return n.substitute(
-      d_vars[q].begin(), d_vars[q].end(), terms.begin(), terms.end());
+  std::vector<Node>& vars = d_vars.at(q);
+  Assert(vars.size() == q[0].getNumChildren());
+  Assert(vars.size() == terms.size());
+  return n.substitute(vars.begin(), vars.end(), terms.begin(), terms.end());
 }
 
-Node QuantifiersRegistry::substituteInstConstants(Node n,
-                                                  Node q,
-                                                  std::vector<Node>& terms)
+Node QuantifiersRegistry::substituteInstConstants(
+    Node n, Node q, const std::vector<Node>& terms)
 {
   registerQuantifier(q);
-  Assert(d_inst_constants[q].size() == terms.size());
-  return n.substitute(d_inst_constants[q].begin(),
-                      d_inst_constants[q].end(),
-                      terms.begin(),
-                      terms.end());
+  std::vector<Node>& consts = d_inst_constants.at(q);
+  Assert(consts.size() == q[0].getNumChildren());
+  Assert(consts.size() == terms.size());
+  return n.substitute(consts.begin(), consts.end(), terms.begin(), terms.end());
 }
 
 QuantAttributes& QuantifiersRegistry::getQuantAttributes()
index 559939bbee8389c8918383448f9f533e688994f9..858a85cae45efdc1267cf047786a6f4835b4a4e9 100644 (file)
@@ -84,9 +84,9 @@ class QuantifiersRegistry : public QuantifiersUtil
    */
   Node substituteInstConstantsToBoundVariables(Node n, Node q);
   /** substitute { variables of q -> terms } in n */
-  Node substituteBoundVariables(Node n, Node q, std::vector<Node>& terms);
+  Node substituteBoundVariables(Node n, Node q, const std::vector<Node>& terms);
   /** substitute { instantiation constants of q -> terms } in n */
-  Node substituteInstConstants(Node n, Node q, std::vector<Node>& terms);
+  Node substituteInstConstants(Node n, Node q, const std::vector<Node>& terms);
   //----------------------------- end instantiation constants
   /** Get quantifiers attributes utility class */
   QuantAttributes& getQuantAttributes();