Further standardization of datatypes (#5076)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 17 Sep 2020 05:22:20 +0000 (00:22 -0500)
committerGitHub <noreply@github.com>
Thu, 17 Sep 2020 05:22:20 +0000 (22:22 -0700)
We now have no custom calls to equality engine explain, and only 2 manual calls to equality engine (in its entailment check). This also updates the notify class to the standard one.

This commit makes datatypes ready to start work on proofs.

src/theory/datatypes/theory_datatypes.cpp
src/theory/datatypes/theory_datatypes.h

index afa640650c8aacc302ae784ee689317c8c261ae2..3f75958200763fbb1f7554cbadae49f4e0c5e812 100644 (file)
@@ -48,7 +48,6 @@ TheoryDatatypes::TheoryDatatypes(Context* c,
                                  ProofNodeManager* pnm)
     : Theory(THEORY_DATATYPES, c, u, out, valuation, logicInfo, pnm),
       d_term_sk(u),
-      d_notify(*this),
       d_labels(c),
       d_selector_apps(c),
       d_collectTermsCache(c),
@@ -58,7 +57,8 @@ TheoryDatatypes::TheoryDatatypes(Context* c,
       d_lemmas_produced_c(u),
       d_sygusExtension(nullptr),
       d_state(c, u, valuation),
-      d_im(*this, d_state, pnm)
+      d_im(*this, d_state, pnm),
+      d_notify(d_im, *this)
 {
 
   d_true = NodeManager::currentNM()->mkConst( true );
@@ -598,90 +598,11 @@ TrustNode TheoryDatatypes::ppRewrite(TNode in)
   return TrustNode::null();
 }
 
-bool TheoryDatatypes::propagateLit(TNode literal)
-{
-  Debug("dt::propagate") << "TheoryDatatypes::propagateLit(" << literal << ")"
-                         << std::endl;
-  return d_im.propagateLit(literal);
-}
-
-void TheoryDatatypes::addAssumptions( std::vector<TNode>& assumptions, std::vector<TNode>& tassumptions ) {
-  std::vector<TNode> ntassumptions;
-  for( unsigned i=0; i<tassumptions.size(); i++ ){
-    //flatten AND
-    if( tassumptions[i].getKind()==AND ){
-      for( unsigned j=0; j<tassumptions[i].getNumChildren(); j++ ){
-        explain( tassumptions[i][j], ntassumptions );
-      }
-    }else{
-      if( std::find( assumptions.begin(), assumptions.end(), tassumptions[i] )==assumptions.end() ){
-        assumptions.push_back( tassumptions[i] );
-      }
-    }
-  }
-  if( !ntassumptions.empty() ){
-    addAssumptions( assumptions, ntassumptions );
-  }
-}
-
-void TheoryDatatypes::explainEquality( TNode a, TNode b, bool polarity, std::vector<TNode>& assumptions ) {
-  if( a!=b ){
-    std::vector<TNode> tassumptions;
-    d_equalityEngine->explainEquality(a, b, polarity, tassumptions);
-    addAssumptions( assumptions, tassumptions );
-  }
-}
-
-void TheoryDatatypes::explainPredicate( TNode p, bool polarity, std::vector<TNode>& assumptions ) {
-  std::vector<TNode> tassumptions;
-  d_equalityEngine->explainPredicate(p, polarity, tassumptions);
-  addAssumptions( assumptions, tassumptions );
-}
-
-/** explain */
-void TheoryDatatypes::explain(TNode literal, std::vector<TNode>& assumptions){
-  Debug("datatypes-explain") << "Explain " << literal << std::endl;
-  bool polarity = literal.getKind() != kind::NOT;
-  TNode atom = polarity ? literal : literal[0];
-  if (atom.getKind() == kind::EQUAL) {
-    explainEquality( atom[0], atom[1], polarity, assumptions );
-  } else if( atom.getKind() == kind::AND && polarity ){
-    for( unsigned i=0; i<atom.getNumChildren(); i++ ){
-      explain( atom[i], assumptions );
-    }
-  } else {
-    Assert(atom.getKind() != kind::AND);
-    explainPredicate( atom, polarity, assumptions );
-  }
-}
-
 TrustNode TheoryDatatypes::explain(TNode literal)
 {
   return d_im.explainLit(literal);
 }
 
-Node TheoryDatatypes::explainLit(TNode literal)
-{
-  std::vector< TNode > assumptions;
-  explain( literal, assumptions );
-  return mkAnd( assumptions );
-}
-
-Node TheoryDatatypes::explain( std::vector< Node >& lits ) {
-  std::vector< TNode > assumptions;
-  for( unsigned i=0; i<lits.size(); i++ ){
-    explain( lits[i], assumptions );
-  }
-  return mkAnd( assumptions );
-}
-
-/** Conflict when merging two constants */
-void TheoryDatatypes::conflict(TNode a, TNode b){
-  Trace("dt-conflict") << "CONFLICT: Eq engine conflict merge : " << a
-                       << " == " << b << std::endl;
-  d_im.conflictEqConstantMerge(a, b);
-}
-
 /** called when a new equivalance class is created */
 void TheoryDatatypes::eqNotifyNewClass(TNode t){
   if( t.getKind()==APPLY_CONSTRUCTOR ){
@@ -1686,7 +1607,7 @@ void TheoryDatatypes::checkCycles() {
     printModelDebug("dt-cdt-debug");
     Trace("dt-cdt-debug") << "Process " << cdt_eqc.size() << " co-datatypes" << std::endl;
     std::vector< std::vector< Node > > part_out;
-    std::vector< TNode > exp;
+    std::vector<Node> exp;
     std::map< Node, Node > cn;
     std::map< Node, std::map< Node, int > > dni;
     for( unsigned i=0; i<cdt_eqc.size(); i++ ){
@@ -1718,7 +1639,7 @@ void TheoryDatatypes::checkCycles() {
           }
           Trace("dt-cdt") << std::endl;
           Node eq = part_out[i][0].eqNode( part_out[i][j] );
-          Node eqExp = mkAnd( exp );
+          Node eqExp = NodeManager::currentNM()->mkAnd(exp);
           d_im.addPendingInference(eq, eqExp);
           Trace("datatypes-infer") << "DtInfer : cdt-bisimilar : " << eq << " by " << eqExp << std::endl;
         }
@@ -1728,10 +1649,15 @@ void TheoryDatatypes::checkCycles() {
 }
 
 //everything is in terms of representatives
-void TheoryDatatypes::separateBisimilar( std::vector< Node >& part, std::vector< std::vector< Node > >& part_out,
-                                         std::vector< TNode >& exp,
-                                         std::map< Node, Node >& cn,
-                                         std::map< Node, std::map< Node, int > >& dni, int dniLvl, bool mkExp ){
+void TheoryDatatypes::separateBisimilar(
+    std::vector<Node>& part,
+    std::vector<std::vector<Node> >& part_out,
+    std::vector<Node>& exp,
+    std::map<Node, Node>& cn,
+    std::map<Node, std::map<Node, int> >& dni,
+    int dniLvl,
+    bool mkExp)
+{
   if( !mkExp ){
     Trace("dt-cdt-debug") << "Separate bisimilar : " << std::endl;
     for( unsigned i=0; i<part.size(); i++ ){
@@ -1758,7 +1684,7 @@ void TheoryDatatypes::separateBisimilar( std::vector< Node >& part, std::vector<
           Node cc = ncons.getOperator();
           cn_cons[part[j]] = ncons;
           if( mkExp ){
-            explainEquality( c, ncons, true, exp );
+            exp.push_back(c.eqNode(ncons));
           }
           new_part[cc].push_back( part[j] );
           if( !mkExp ){ Trace("dt-cdt-debug") << "  - " << part[j] << " is datatype " << ncons << "." << std::endl; }
@@ -1818,7 +1744,7 @@ void TheoryDatatypes::separateBisimilar( std::vector< Node >& part, std::vector<
             Node n = split_new_part[j][k];
             cn[n] = getRepresentative( cn_cons[n][cindex] );
             if( mkExp ){
-              explainEquality( cn[n], cn_cons[n][cindex], true, exp );
+              exp.push_back(cn[n].eqNode(cn_cons[n][cindex]));
             }
           }
           std::vector< std::vector< Node > > c_part_out;
@@ -2012,16 +1938,6 @@ void TheoryDatatypes::printModelDebug( const char* c ){
   }
 }
 
-Node TheoryDatatypes::mkAnd( std::vector< TNode >& assumptions ) {
-  if( assumptions.empty() ){
-    return d_true;
-  }else if( assumptions.size()==1 ){
-    return assumptions[0];
-  }else{
-    return NodeManager::currentNM()->mkNode( AND, assumptions );
-  }
-}
-
 void TheoryDatatypes::computeRelevantTerms(std::set<Node>& termSet)
 {
   Trace("dt-cmi") << "Have " << termSet.size() << " relevant terms..."
@@ -2081,16 +1997,18 @@ std::pair<bool, Node> TheoryDatatypes::entailmentCheck(TNode lit)
       Trace("dt-entail") << "  Tester indices are " << t_index << " and " << l_index << std::endl;
       if( l_index!=-1 && (l_index==t_index)==pol ){
         std::vector< TNode > exp_c;
+        Node eqToExplain;
         if( ei && !ei->d_constructor.get().isNull() ){
-          explainEquality( n, ei->d_constructor.get(), true, exp_c );
+          eqToExplain = n.eqNode(ei->d_constructor.get());
         }else{
           Node lbl = getLabel( n );
           Assert(!lbl.isNull());
           exp_c.push_back( lbl );
           Assert(areEqual(n, lbl[0]));
-          explainEquality( n, lbl[0], true, exp_c );
+          eqToExplain = n.eqNode(lbl[0]);
         }
-        Node exp = mkAnd( exp_c );
+        d_equalityEngine->explainLit(eqToExplain, exp_c);
+        Node exp = NodeManager::currentNM()->mkAnd(exp_c);
         Trace("dt-entail") << "  entailed, explanation is " << exp << std::endl;
         return make_pair(true, exp);
       }
index d34390a5f85f241ac1a5eb1453424a7c16d8d7f7..caea035ce867e366d16c7b521060837535f5e7bc 100644 (file)
@@ -29,6 +29,7 @@
 #include "theory/datatypes/inference_manager.h"
 #include "theory/datatypes/sygus_extension.h"
 #include "theory/theory.h"
+#include "theory/theory_eq_notify.h"
 #include "theory/uf/equality_engine.h"
 #include "util/hash.h"
 
@@ -44,46 +45,20 @@ class TheoryDatatypes : public Theory {
   typedef context::CDHashMap<Node, bool, NodeHashFunction> BoolMap;
   typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeMap;
 
-  Node d_true;
-  Node d_zero;
-  /** mkAnd */
-  Node mkAnd(std::vector<TNode>& assumptions);
-
  private:
   //notification class for equality engine
-  class NotifyClass : public eq::EqualityEngineNotify {
+  class NotifyClass : public TheoryEqNotifyClass
+  {
     TheoryDatatypes& d_dt;
   public:
-    NotifyClass(TheoryDatatypes& dt): d_dt(dt) {}
-    bool eqNotifyTriggerPredicate(TNode predicate, bool value) override
-    {
-      Debug("dt") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false") << ")" << std::endl;
-      if (value) {
-        return d_dt.propagateLit(predicate);
-      }
-      return d_dt.propagateLit(predicate.notNode());
-    }
-    bool eqNotifyTriggerTermEquality(TheoryId tag,
-                                     TNode t1,
-                                     TNode t2,
-                                     bool value) override
-    {
-      AlwaysAssert(tag == THEORY_DATATYPES);
-      Debug("dt") << "NotifyClass::eqNotifyTriggerTermMerge(" << tag << ", " << t1 << ", " << t2 << ")" << std::endl;
-      if (value) {
-        return d_dt.propagateLit(t1.eqNode(t2));
-      }
-      return d_dt.propagateLit(t1.eqNode(t2).notNode());
-    }
-    void eqNotifyConstantTermMerge(TNode t1, TNode t2) override
-    {
-      Debug("dt") << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << ")" << std::endl;
-      d_dt.conflict(t1, t2);
-    }
-    void eqNotifyNewClass(TNode t) override
-    {
-      Debug("dt") << "NotifyClass::eqNotifyNewClass(" << t << ")" << std::endl;
-      d_dt.eqNotifyNewClass(t);
+   NotifyClass(TheoryInferenceManager& im, TheoryDatatypes& dt)
+       : TheoryEqNotifyClass(im), d_dt(dt)
+   {
+   }
+   void eqNotifyNewClass(TNode t) override
+   {
+     Debug("dt") << "NotifyClass::eqNotifyNewClass(" << t << ")" << std::endl;
+     d_dt.eqNotifyNewClass(t);
     }
     void eqNotifyMerge(TNode t1, TNode t2) override
     {
@@ -91,9 +66,6 @@ class TheoryDatatypes : public Theory {
                   << std::endl;
       d_dt.eqNotifyMerge(t1, t2);
     }
-    void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override
-    {
-    }
   };/* class TheoryDatatypes::NotifyClass */
 private:
   /** equivalence class info
@@ -129,8 +101,6 @@ private:
   NodeMap d_term_sk;
   Node getTermSkolemFor( Node n );
 private:
-  /** The notify class */
-  NotifyClass d_notify;
   /** information necessary for equivalence classes */
   std::map< Node, EqcInfo* > d_eqc_info;
   /** map from nodes to their instantiated equivalent for each constructor type */
@@ -243,13 +213,7 @@ private:
   /** Conflict when merging two constants */
   void conflict(TNode a, TNode b);
   /** explain */
-  void addAssumptions( std::vector<TNode>& assumptions, std::vector<TNode>& tassumptions );
-  void explainEquality( TNode a, TNode b, bool polarity, std::vector<TNode>& assumptions );
-  void explainPredicate( TNode p, bool polarity, std::vector<TNode>& assumptions );
-  void explain( TNode literal, std::vector<TNode>& assumptions );
   TrustNode explain(TNode literal) override;
-  Node explainLit(TNode literal);
-  Node explain( std::vector< Node >& lits );
   /** called when a new equivalance class is created */
   void eqNotifyNewClass(TNode t);
   /** called when two equivalance classes have merged */
@@ -300,10 +264,13 @@ private:
                       std::vector<Node>& explanation,
                       bool firstTime = true);
   /** for checking whether two codatatype terms must be equal */
-  void separateBisimilar( std::vector< Node >& part, std::vector< std::vector< Node > >& part_out,
-                          std::vector< TNode >& exp,
-                          std::map< Node, Node >& cn,
-                          std::map< Node, std::map< Node, int > >& dni, int dniLvl, bool mkExp );
+  void separateBisimilar(std::vector<Node>& part,
+                         std::vector<std::vector<Node> >& part_out,
+                         std::vector<Node>& exp,
+                         std::map<Node, Node>& cn,
+                         std::map<Node, std::map<Node, int> >& dni,
+                         int dniLvl,
+                         bool mkExp);
   /** build model */
   Node getCodatatypesValue( Node n, std::map< Node, Node >& eqc_cons, std::map< Node, int >& vmap, int depth );
   /** get singleton lemma */
@@ -330,16 +297,19 @@ private:
    * equivalence classes.
    */
   void computeRelevantTerms(std::set<Node>& termSet) override;
-
+  /** Commonly used terms */
+  Node d_true;
+  Node d_zero;
   /** sygus symmetry breaking utility */
   std::unique_ptr<SygusExtension> d_sygusExtension;
-
   /** The theory rewriter for this theory. */
   DatatypesRewriter d_rewriter;
   /** A (default) theory state object */
   TheoryState d_state;
   /** The inference manager */
   InferenceManager d_im;
+  /** The notify class */
+  NotifyClass d_notify;
 };/* class TheoryDatatypes */
 
 }/* CVC4::theory::datatypes namespace */