Make term database optionally SAT-context-dependent (#5877)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 9 Feb 2021 21:35:40 +0000 (15:35 -0600)
committerGitHub <noreply@github.com>
Tue, 9 Feb 2021 21:35:40 +0000 (15:35 -0600)
This makes the terms registered to the term database (those considered by E-matching) optionally stored in a SAT-context-dependent manner. The motivation is to have a more flexible/fine-grained set of terms considered by E-matching, e.g. if preregistration becomes lazier in the future.

This uncovered 2 issues:

The induction techniques in "conjecture generator" were using private interfaces, this PR removes the friend relaionship and cleans the code
The conflict-based instantiation module was accessing the signature tables for BOUND_VARIABLES when an operator of an APPLY_UF was a BOUND_VARIABLE. This is possible when options::ufHo is enabled. This makes conflict-based instantiation skip such terms.

src/options/quantifiers_options.toml
src/theory/quantifiers/conjecture_generator.cpp
src/theory/quantifiers/quant_conflict_find.cpp
src/theory/quantifiers/term_database.cpp
src/theory/quantifiers/term_database.h
src/theory/quantifiers_engine.cpp

index fd781ab2b109a1f9fbabe6c23e664c6bb907e2ff..6fb3dd676f51a6aa23acf9ec07f355245fb19323 100644 (file)
@@ -219,6 +219,14 @@ header = "options/quantifiers_options.h"
   name = "relevant"
   help = "Quantifiers module considers only ground terms connected to current assertions."
 
+[[option]]
+  name       = "termDbCd"
+  category   = "regular"
+  long       = "term-db-cd"
+  type       = "bool"
+  default    = "true"
+  help       = "register terms in term database based on the SAT context"
+
 [[option]]
   name       = "registerQuantBodyTerms"
   category   = "regular"
index bfaf0b83cb3b87df89c4d2686237f71c9e275b5c..02b9ca6358dd3d7c5c773df5251c206b8a8212f4 100644 (file)
@@ -409,8 +409,12 @@ void ConjectureGenerator::check(Theory::Effort e, QEffort quant_e)
           Trace("sg-proc-debug") << "......term : " << n << std::endl;
           if( getTermDatabase()->hasTermCurrent( n ) ){
             if( isHandledTerm( n ) ){
-              getTermDatabase()->computeArgReps( n );
-              d_op_arg_index[r].addTerm( getTermDatabase()->d_arg_reps[n], n );
+              std::vector<TNode> areps;
+              for (const Node& nc : n)
+              {
+                areps.push_back(d_qstate.getRepresentative(nc));
+              }
+              d_op_arg_index[r].addTerm(areps, n);
             }
           }
           ++ieqc_i;
@@ -474,9 +478,9 @@ void ConjectureGenerator::check(Theory::Effort e, QEffort quant_e)
               }
               if( n.hasOperator() ){
                 Trace("sg-gen-eqc") << "   (" << n.getOperator();
-                getTermDatabase()->computeArgReps( n );
-                for (TNode ar : getTermDatabase()->d_arg_reps[n])
+                for (const Node& nc : n)
                 {
+                  TNode ar = d_qstate.getRepresentative(nc);
                   Trace("sg-gen-eqc") << " e" << d_em[ar];
                 }
                 Trace("sg-gen-eqc") << ") :: " << n << std::endl;
@@ -1800,9 +1804,14 @@ void TermGenEnv::collectSignatureInformation() {
   d_func_kind.clear();
   d_func_args.clear();
   TypeNode tnull;
-  for( std::map< Node, std::vector< Node > >::iterator it = getTermDatabase()->d_op_map.begin(); it != getTermDatabase()->d_op_map.end(); ++it ){
-    if( !it->second.empty() ){
-      Node nn = it->second[0];
+  TermDb* tdb = getTermDatabase();
+  for (size_t i = 0, nops = tdb->getNumOperators(); i < nops; i++)
+  {
+    Node op = tdb->getOperator(i);
+    DbList* dbl = tdb->getOrMkDbListForOp(op);
+    if (!dbl->d_list.empty())
+    {
+      Node nn = dbl->d_list[0];
       Trace("sg-rel-sig-debug") << "Check in signature : " << nn << std::endl;
       if( d_cg->isHandledTerm( nn ) && nn.getKind()!=APPLY_SELECTOR_TOTAL && !nn.getType().isBoolean() ){
         bool do_enum = true;
@@ -1815,16 +1824,20 @@ void TermGenEnv::collectSignatureInformation() {
         }
         if( do_enum ){
           Trace("sg-rel-sig-debug") << "Set enumeration..." << std::endl;
-          d_funcs.push_back( it->first );
-          for( unsigned i=0; i<nn.getNumChildren(); i++ ){
-            d_func_args[it->first].push_back( nn[i].getType() );
+          d_funcs.push_back(op);
+          for (const Node& nnc : nn)
+          {
+            d_func_args[op].push_back(nnc.getType());
           }
-          d_func_kind[it->first] = nn.getKind();
-          d_typ_tg_funcs[tnull].push_back( it->first );
-          d_typ_tg_funcs[nn.getType()].push_back( it->first );
-          d_tg_func_param[it->first] = ( nn.getMetaKind() == kind::metakind::PARAMETERIZED );
-          Trace("sg-rel-sig") << "Will enumerate function applications of : " << it->first << ", #args = " << d_func_args[it->first].size() << ", kind = " << nn.getKind() << std::endl;
-          //getTermDatabase()->computeUfEqcTerms( it->first );
+          d_func_kind[op] = nn.getKind();
+          d_typ_tg_funcs[tnull].push_back(op);
+          d_typ_tg_funcs[nn.getType()].push_back(op);
+          d_tg_func_param[op] =
+              (nn.getMetaKind() == kind::metakind::PARAMETERIZED);
+          Trace("sg-rel-sig")
+              << "Will enumerate function applications of : " << op
+              << ", #args = " << d_func_args[op].size()
+              << ", kind = " << nn.getKind() << std::endl;
         }
       }
       Trace("sg-rel-sig-debug") << "Done check in signature : " << nn << std::endl;
index 433de2f85e2820e76ae19125ca29fa213996d933..0829ccd9ca9c2ae37e07323a5809efd791bd90e2 100644 (file)
@@ -18,6 +18,7 @@
 #include "expr/node_algorithm.h"
 #include "options/quantifiers_options.h"
 #include "options/theory_options.h"
+#include "options/uf_options.h"
 #include "smt/smt_statistics_registry.h"
 #include "theory/quantifiers/ematching/trigger_term_info.h"
 #include "theory/quantifiers/first_order_model.h"
@@ -980,7 +981,11 @@ MatchGen::MatchGen( QuantInfo * qi, Node n, bool isVar )
   d_qni_size = 0;
   if( isVar ){
     Assert(qi->d_var_num.find(n) != qi->d_var_num.end());
-    if( n.getKind()==ITE ){
+    // rare case where we have a free variable in an operator, we are invalid
+    if (n.getKind() == ITE
+        || (options::ufHo() && n.getKind() == APPLY_UF
+            && expr::hasFreeVar(n.getOperator())))
+    {
       d_type = typ_invalid;
     }else{
       d_type = isHandledUfTerm( n ) ? typ_var : typ_tsym;
index 56f85e99e8ae816e07892d58b1c7107c55576cb3..b83cdf3e54d9341459929692d86e67da378d05b0 100644 (file)
@@ -40,11 +40,24 @@ TermDb::TermDb(QuantifiersState& qs,
     : d_quantEngine(qe),
       d_qstate(qs),
       d_qim(qim),
+      d_termsContext(),
+      d_termsContextUse(options::termDbCd() ? qs.getSatContext()
+                                            : &d_termsContext),
+      d_processed(d_termsContextUse),
+      d_typeMap(d_termsContextUse),
+      d_ops(d_termsContextUse),
+      d_opMap(d_termsContextUse),
       d_inactive_map(qs.getSatContext())
 {
   d_consistent_ee = true;
   d_true = NodeManager::currentNM()->mkConst(true);
   d_false = NodeManager::currentNM()->mkConst(false);
+  if (!options::termDbCd())
+  {
+    // when not maintaining terms in a context-dependent manner, we clear during
+    // each presolve, which requires maintaining a single outermost level
+    d_termsContext.push();
+  }
 }
 
 TermDb::~TermDb(){
@@ -60,78 +73,70 @@ void TermDb::registerQuantifier( Node q ) {
   }
 }
 
-unsigned TermDb::getNumOperators() { return d_ops.size(); }
-Node TermDb::getOperator(unsigned i)
+size_t TermDb::getNumOperators() const { return d_ops.size(); }
+
+Node TermDb::getOperator(size_t i) const
 {
   Assert(i < d_ops.size());
   return d_ops[i];
 }
 
 /** ground terms */
-unsigned TermDb::getNumGroundTerms(Node f) const
+size_t TermDb::getNumGroundTerms(Node f) const
 {
-  std::map<Node, std::vector<Node> >::const_iterator it = d_op_map.find(f);
-  if( it!=d_op_map.end() ){
-    return it->second.size();
-  }else{
-    return 0;
+  NodeDbListMap::const_iterator it = d_opMap.find(f);
+  if (it != d_opMap.end())
+  {
+    return it->second->d_list.size();
   }
+  return 0;
 }
 
-Node TermDb::getGroundTerm(Node f, unsigned i) const
+Node TermDb::getGroundTerm(Node f, size_t i) const
 {
-  std::map<Node, std::vector<Node> >::const_iterator it = d_op_map.find(f);
-  if (it != d_op_map.end())
-  {
-    Assert(i < it->second.size());
-    return it->second[i];
-  }
-  else
+  NodeDbListMap::const_iterator it = d_opMap.find(f);
+  if (it != d_opMap.end())
   {
-    Assert(false);
-    return Node::null();
+    Assert(i < it->second->d_list.size());
+    return it->second->d_list[i];
   }
+  Assert(false);
+  return Node::null();
 }
 
-unsigned TermDb::getNumTypeGroundTerms(TypeNode tn) const
+size_t TermDb::getNumTypeGroundTerms(TypeNode tn) const
 {
-  std::map<TypeNode, std::vector<Node> >::const_iterator it =
-      d_type_map.find(tn);
-  if( it!=d_type_map.end() ){
-    return it->second.size();
-  }else{
-    return 0;
+  TypeNodeDbListMap::const_iterator it = d_typeMap.find(tn);
+  if (it != d_typeMap.end())
+  {
+    return it->second->d_list.size();
   }
+  return 0;
 }
 
-Node TermDb::getTypeGroundTerm(TypeNode tn, unsigned i) const
+Node TermDb::getTypeGroundTerm(TypeNode tn, size_t i) const
 {
-  std::map<TypeNode, std::vector<Node> >::const_iterator it =
-      d_type_map.find(tn);
-  if (it != d_type_map.end())
-  {
-    Assert(i < it->second.size());
-    return it->second[i];
-  }
-  else
+  TypeNodeDbListMap::const_iterator it = d_typeMap.find(tn);
+  if (it != d_typeMap.end())
   {
-    Assert(false);
-    return Node::null();
+    Assert(i < it->second->d_list.size());
+    return it->second->d_list[i];
   }
+  Assert(false);
+  return Node::null();
 }
 
 Node TermDb::getOrMakeTypeGroundTerm(TypeNode tn, bool reqVar)
 {
-  std::map<TypeNode, std::vector<Node> >::const_iterator it =
-      d_type_map.find(tn);
-  if (it != d_type_map.end())
+  TypeNodeDbListMap::const_iterator it = d_typeMap.find(tn);
+  if (it != d_typeMap.end())
   {
-    Assert(!it->second.empty());
+    Assert(!it->second->d_list.empty());
     if (!reqVar)
     {
-      return it->second[0];
+      return it->second->d_list[0];
     }
-    for (const Node& v : it->second)
+    for (const Node& v : it->second->d_list)
     {
       if (v.isVar())
       {
@@ -208,7 +213,8 @@ void TermDb::addTerm(Node n)
   if (!TermUtil::hasInstConstAttr(n))
   {
     Trace("term-db-debug") << "register term : " << n << std::endl;
-    d_type_map[n.getType()].push_back(n);
+    DbList* dlt = getOrMkDbListForType(n.getType());
+    dlt->d_list.push_back(n);
     // if this is an atomic trigger, consider adding it
     if (inst::TriggerTermInfo::isAtomicTrigger(n))
     {
@@ -216,11 +222,8 @@ void TermDb::addTerm(Node n)
 
       Node op = getMatchOperator(n);
       Trace("term-db-debug") << "  match operator is : " << op << std::endl;
-      if (d_op_map.find(op) == d_op_map.end())
-      {
-        d_ops.push_back(op);
-      }
-      d_op_map[op].push_back(n);
+      DbList* dlo = getOrMkDbListForOp(op);
+      dlo->d_list.push_back(n);
       // If we are higher-order, we may need to register more terms.
       if (options::ufHo())
       {
@@ -241,6 +244,32 @@ void TermDb::addTerm(Node n)
   }
 }
 
+DbList* TermDb::getOrMkDbListForType(TypeNode tn)
+{
+  TypeNodeDbListMap::iterator it = d_typeMap.find(tn);
+  if (it != d_typeMap.end())
+  {
+    return it->second.get();
+  }
+  std::shared_ptr<DbList> dl = std::make_shared<DbList>(d_termsContextUse);
+  d_typeMap.insert(tn, dl);
+  return dl.get();
+}
+
+DbList* TermDb::getOrMkDbListForOp(TNode op)
+{
+  NodeDbListMap::iterator it = d_opMap.find(op);
+  if (it != d_opMap.end())
+  {
+    return it->second.get();
+  }
+  std::shared_ptr<DbList> dl = std::make_shared<DbList>(d_termsContextUse);
+  d_opMap.insert(op, dl);
+  Assert(op.getKind() != BOUND_VARIABLE);
+  d_ops.push_back(op);
+  return dl.get();
+}
+
 void TermDb::computeArgReps( TNode n ) {
   if (d_arg_reps.find(n) == d_arg_reps.end())
   {
@@ -270,7 +299,8 @@ void TermDb::computeUfEqcTerms( TNode f ) {
   eq::EqualityEngine* ee = d_qstate.getEqualityEngine();
   for (TNode ff : ops)
   {
-    for (const Node& n : d_op_map[ff])
+    DbList* dbl = getOrMkDbListForOp(ff);
+    for (const Node& n : dbl->d_list)
     {
       if (hasTermCurrent(n) && isTermActive(n))
       {
@@ -305,14 +335,14 @@ void TermDb::computeUfTerms( TNode f ) {
   NodeManager* nm = NodeManager::currentNM();
   for (TNode ff : ops)
   {
-    std::map<Node, std::vector<Node> >::iterator it = d_op_map.find(ff);
-    if (it == d_op_map.end())
+    NodeDbListMap::iterator it = d_opMap.find(ff);
+    if (it == d_opMap.end())
     {
       // no terms for this operator
       continue;
     }
     Trace("term-db-debug") << "Adding terms for operator " << ff << std::endl;
-    for (const Node& n : it->second)
+    for (const Node& n : it->second->d_list)
     {
       // to be added to term index, term must be relevant, and exist in EE
       if (!hasTermCurrent(n) || !d_qstate.hasTerm(n))
@@ -422,7 +452,8 @@ void TermDb::computeUfTerms( TNode f ) {
                    << (nonCongruentCount + congruentCount
                        + alreadyCongruentCount)
                    << " / ";
-      Trace("tdb") << relevantCount << " / " << it->second.size() << std::endl;
+      Trace("tdb") << relevantCount << " / " << it->second->d_list.size()
+                   << std::endl;
     }
   }
 }
@@ -466,7 +497,8 @@ void TermDb::addTermHo(Node n)
       Trace("term-db") << "register term in db (via purify) " << p_n
                        << std::endl;
       // also add this one internally
-      d_op_map[psk].push_back(p_n);
+      DbList* dblp = getOrMkDbListForOp(psk);
+      dblp->d_list.push_back(p_n);
       // maintain backwards mapping
       d_ho_purify_to_term[p_n] = n;
     }
@@ -960,13 +992,10 @@ void TermDb::setHasTerm( Node n ) {
 }
 
 void TermDb::presolve() {
-  if( options::incrementalSolving() ){
-    // reset the caches that are SAT context-independent but user
-    // context-dependent
-    d_ops.clear();
-    d_op_map.clear();
-    d_type_map.clear();
-    d_processed.clear();
+  if (options::incrementalSolving() && !options::termDbCd())
+  {
+    d_termsContext.pop();
+    d_termsContext.push();
   }
 }
 
@@ -1102,7 +1131,7 @@ bool TermDb::reset( Theory::Effort effort ){
           }
           Trace("quant-ho") << "  - process " << n_use << ", from " << n
                             << std::endl;
-          if (!n_use.isNull() && d_op_map.find(n_use) != d_op_map.end())
+          if (!n_use.isNull() && d_opMap.find(n_use) != d_opMap.end())
           {
             if (first.isNull())
             {
@@ -1125,15 +1154,6 @@ bool TermDb::reset( Theory::Effort effort ){
     Trace("quant-ho") << "...finished compute equal functions." << std::endl;
   }
 
-/*
-  //rebuild d_func/pred_map_trie for each operation, this will calculate all congruent terms
-  for( std::map< Node, std::vector< Node > >::iterator it = d_op_map.begin(); it != d_op_map.end(); ++it ){
-    computeUfTerms( it->first );
-    if( !d_consistent_ee ){
-      return false;
-    }
-  }
-*/  
   return true;
 }
 
index 6a695e70e29d27825dd91825dcf7a0ab02b97744..b2f964a3a2f0f92912f427b8938405086178d42e 100644 (file)
@@ -36,6 +36,15 @@ namespace quantifiers {
 class ConjectureGenerator;
 class TermGenEnv;
 
+/** Context-dependent list of nodes */
+class DbList
+{
+ public:
+  DbList(context::Context* c) : d_list(c) {}
+  /** The list */
+  context::CDList<Node> d_list;
+};
+
 /** Term Database
  *
  * This class is a key utility used by
@@ -56,11 +65,13 @@ class TermGenEnv;
  */
 class TermDb : public QuantifiersUtil {
   friend class ::CVC4::theory::QuantifiersEngine;
-  // TODO: eliminate these
-  friend class ::CVC4::theory::quantifiers::ConjectureGenerator;
-  friend class ::CVC4::theory::quantifiers::TermGenEnv;
-  typedef context::CDHashMap<Node, int, NodeHashFunction> NodeIntMap;
-  typedef context::CDHashMap<Node, bool, NodeHashFunction> NodeBoolMap;
+  using NodeBoolMap = context::CDHashMap<Node, bool, NodeHashFunction>;
+  using NodeList = context::CDList<Node>;
+  using NodeSet = context::CDHashSet<Node, NodeHashFunction>;
+  using TypeNodeDbListMap = context::
+      CDHashMap<TypeNode, std::shared_ptr<DbList>, TypeNodeHashFunction>;
+  using NodeDbListMap =
+      context::CDHashMap<Node, std::shared_ptr<DbList>, NodeHashFunction>;
 
  public:
   TermDb(QuantifiersState& qs,
@@ -76,26 +87,26 @@ class TermDb : public QuantifiersUtil {
   /** identify */
   std::string identify() const override { return "TermDb"; }
   /** get number of operators */
-  unsigned getNumOperators();
+  size_t getNumOperators() const;
   /** get operator at index i */
-  Node getOperator(unsigned i);
+  Node getOperator(size_t i) const;
   /** ground terms for operator
   * Get the number of ground terms with operator f that have been added to the
   * database
   */
-  unsigned getNumGroundTerms(Node f) const;
+  size_t getNumGroundTerms(Node f) const;
   /** get ground term for operator
   * Get the i^th ground term with operator f that has been added to the database
   */
-  Node getGroundTerm(Node f, unsigned i) const;
+  Node getGroundTerm(Node f, size_t i) const;
   /** get num type terms
   * Get the number of ground terms of tn that have been added to the database
   */
-  unsigned getNumTypeGroundTerms(TypeNode tn) const;
+  size_t getNumTypeGroundTerms(TypeNode tn) const;
   /** get type ground term
   * Returns the i^th ground term of type tn
   */
-  Node getTypeGroundTerm(TypeNode tn, unsigned i) const;
+  Node getTypeGroundTerm(TypeNode tn, size_t i) const;
   /** get or make ground term
    *
    * Returns the first ground term of type tn, or makes one if none exist. If
@@ -113,6 +124,10 @@ class TermDb : public QuantifiersUtil {
    * matched with via E-matching, and can be used in entailment tests below.
    */
   void addTerm(Node n);
+  /** Get the currently added ground terms of the given type */
+  DbList* getOrMkDbListForType(TypeNode tn);
+  /** Get the currently added ground terms for the given operator */
+  DbList* getOrMkDbListForOp(TNode op);
   /** get match operator for term n
   *
   * If n has a kind that we index, this function will
@@ -280,8 +295,18 @@ class TermDb : public QuantifiersUtil {
   QuantifiersState& d_qstate;
   /** The quantifiers inference manager */
   QuantifiersInferenceManager& d_qim;
+  /** A context for the data structures below, when not context-dependent */
+  context::Context d_termsContext;
+  /** The context we are using for the data structures below */
+  context::Context* d_termsContextUse;
   /** terms processed */
-  std::unordered_set< Node, NodeHashFunction > d_processed;
+  NodeSet d_processed;
+  /** map from types to ground terms for that type */
+  TypeNodeDbListMap d_typeMap;
+  /** list of all operators */
+  NodeList d_ops;
+  /** map from operators to ground terms for that operator */
+  NodeDbListMap d_opMap;
   /** select op map */
   std::map< Node, std::map< TypeNode, Node > > d_par_op_map;
   /** whether master equality engine is UF-inconsistent */
@@ -289,12 +314,6 @@ class TermDb : public QuantifiersUtil {
   /** boolean terms */
   Node d_true;
   Node d_false;
-  /** list of all operators */
-  std::vector<Node> d_ops;
-  /** map from operators to ground terms for that operator */
-  std::map< Node, std::vector< Node > > d_op_map;
-  /** map from type nodes to terms of that type */
-  std::map< TypeNode, std::vector< Node > > d_type_map;
   /** map from type nodes to a fresh variable we introduced */
   std::unordered_map<TypeNode, Node, TypeNodeHashFunction> d_type_fv;
   /** inactive map */
index bdc12cdaa07a5e88122fd753196e53203e2b47d1..33ec3cbf8476d6b45b0d41c89b9b9673ce16b7c6 100644 (file)
@@ -263,7 +263,8 @@ void QuantifiersEngine::presolve() {
   d_term_db->presolve();
   d_presolve = false;
   //add all terms to database
-  if( options::incrementalSolving() ){
+  if (options::incrementalSolving() && !options::termDbCd())
+  {
     Trace("quant-engine-proc") << "Add presolve cache " << d_presolve_cache.size() << std::endl;
     for (const Node& t : d_presolve_cache)
     {
@@ -762,14 +763,16 @@ void QuantifiersEngine::addTermToDatabase(Node n, bool withinQuant)
   {
     return;
   }
-  if( options::incrementalSolving() ){
+  if (options::incrementalSolving() && !options::termDbCd())
+  {
     if( d_presolve_in.find( n )==d_presolve_in.end() ){
       d_presolve_in.insert( n );
       d_presolve_cache.push_back( n );
     }
   }
   //only wait if we are doing incremental solving
-  if( !d_presolve || !options::incrementalSolving() ){
+  if (!d_presolve || !options::incrementalSolving() || options::termDbCd())
+  {
     d_term_db->addTerm(n);
     if (d_sygus_tdb && options::sygusEvalUnfold())
     {