Fixes for relational triggers (#2967)
[cvc5.git] / src / theory / quantifiers / ematching / inst_match_generator.cpp
index 0a4386db9419dd012e82731200a4d13bfae074c0..9e76a6a31f7d90ee9e344bbdc7c72ec5e912471b 100644 (file)
@@ -98,28 +98,53 @@ void InstMatchGenerator::initialize( Node q, QuantifiersEngine* qe, std::vector<
   if( !d_pattern.isNull() ){
     Trace("inst-match-gen") << "Initialize, pattern term is " << d_pattern << std::endl;
     if( d_match_pattern.getKind()==NOT ){
+      Assert(d_pattern.getKind() == NOT);
       //we want to add the children of the NOT
-      d_match_pattern = d_pattern[0];
+      d_match_pattern = d_match_pattern[0];
+    }
+
+    if (d_pattern.getKind() == NOT && d_match_pattern.getKind() == EQUAL
+        && d_match_pattern[0].getKind() == INST_CONSTANT
+        && d_match_pattern[1].getKind() == INST_CONSTANT)
+    {
+      // special case: disequalities between variables x != y will match ground
+      // disequalities.
     }
-    if( d_match_pattern.getKind()==EQUAL || d_match_pattern.getKind()==GEQ ){
-      //make sure the matching portion of the equality is on the LHS of d_pattern
-      //  and record what d_match_pattern is
+    else if (d_match_pattern.getKind() == EQUAL
+             || d_match_pattern.getKind() == GEQ)
+    {
+      // We are one of the following cases:
+      //   f(x)~a, f(x)~y, x~a, x~y
+      // If we are the first or third case, we ensure that f(x)/x is on the left
+      // hand side of the relation d_pattern, d_match_pattern is f(x)/x and
+      // d_eq_class_rel (indicating the equivalence class that we are related
+      // to) is set to a.
       for( unsigned i=0; i<2; i++ ){
-        if( !quantifiers::TermUtil::hasInstConstAttr(d_match_pattern[i]) || d_match_pattern[i].getKind()==INST_CONSTANT ){
-          Node mp = d_match_pattern[1-i];
-          Node mpo = d_match_pattern[i];
-          if( mp.getKind()!=INST_CONSTANT ){
-            if( i==0 ){
-              if( d_match_pattern.getKind()==GEQ ){
-                d_pattern = NodeManager::currentNM()->mkNode( kind::GT, mp, mpo );
-                d_pattern = d_pattern.negate();
-              }else{
-                d_pattern = NodeManager::currentNM()->mkNode( d_match_pattern.getKind(), mp, mpo );
-              }
+        Node mp = d_match_pattern[i];
+        Node mpo = d_match_pattern[1 - i];
+        // If this side has free variables, and the other side does not or
+        // it is a free variable, then we will match on this side of the
+        // relation.
+        if (quantifiers::TermUtil::hasInstConstAttr(mp)
+            && (!quantifiers::TermUtil::hasInstConstAttr(mpo)
+                || mpo.getKind() == INST_CONSTANT))
+        {
+          if (i == 1)
+          {
+            if (d_match_pattern.getKind() == GEQ)
+            {
+              d_pattern = NodeManager::currentNM()->mkNode(kind::GT, mp, mpo);
+              d_pattern = d_pattern.negate();
+            }
+            else
+            {
+              d_pattern = NodeManager::currentNM()->mkNode(
+                  d_match_pattern.getKind(), mp, mpo);
             }
-            d_eq_class_rel = mpo;
-            d_match_pattern = mp;
           }
+          d_eq_class_rel = mpo;
+          d_match_pattern = mp;
+          // we won't find a term in the other direction
           break;
         }
       }
@@ -178,9 +203,7 @@ void InstMatchGenerator::initialize( Node q, QuantifiersEngine* qe, std::vector<
       {
         // 1-constructors have a trivial way of generating candidates in a
         // given equivalence class
-        const Datatype& dt =
-            static_cast<DatatypeType>(d_match_pattern.getType().toType())
-                .getDatatype();
+        const Datatype& dt = d_match_pattern.getType().getDatatype();
         if (dt.getNumConstructors() == 1)
         {
           d_cg = new inst::CandidateGeneratorConsExpand(qe, d_match_pattern);
@@ -188,14 +211,18 @@ void InstMatchGenerator::initialize( Node q, QuantifiersEngine* qe, std::vector<
       }
       if (d_cg == nullptr)
       {
-        // we will be scanning lists trying to find
-        // d_match_pattern.getOperator()
-        d_cg = new inst::CandidateGeneratorQE(qe, d_match_pattern);
-      }
-      //if matching on disequality, inform the candidate generator not to match on eqc
-      if( d_pattern.getKind()==NOT && d_pattern[0].getKind()==EQUAL ){
-        ((inst::CandidateGeneratorQE*)d_cg)->excludeEqc( d_eq_class_rel );
-        d_eq_class_rel = Node::null();
+        CandidateGeneratorQE* cg =
+            new CandidateGeneratorQE(qe, d_match_pattern);
+        // we will be scanning lists trying to find ground terms whose operator
+        // is the same as d_match_operator's.
+        d_cg = cg;
+        // if matching on disequality, inform the candidate generator not to
+        // match on eqc
+        if (d_pattern.getKind() == NOT && d_pattern[0].getKind() == EQUAL)
+        {
+          cg->excludeEqc(d_eq_class_rel);
+          d_eq_class_rel = Node::null();
+        }
       }
     }else if( d_match_pattern.getKind()==INST_CONSTANT ){
       if( d_pattern.getKind()==APPLY_SELECTOR_TOTAL ){
@@ -209,12 +236,15 @@ void InstMatchGenerator::initialize( Node q, QuantifiersEngine* qe, std::vector<
       }else{
         d_cg = new CandidateGeneratorQEAll( qe, d_match_pattern );
       }
-    }else if( d_match_pattern.getKind()==EQUAL &&
-              d_match_pattern[0].getKind()==INST_CONSTANT && d_match_pattern[1].getKind()==INST_CONSTANT ){
+    }
+    else if (d_match_pattern.getKind() == EQUAL)
+    {
       //we will be producing candidates via literal matching heuristics
-      Assert(d_pattern.getKind() == NOT);
-      // candidates will be all disequalities
-      d_cg = new inst::CandidateGeneratorQELitDeq(qe, d_match_pattern);
+      if (d_pattern.getKind() == NOT)
+      {
+        // candidates will be all disequalities
+        d_cg = new inst::CandidateGeneratorQELitDeq(qe, d_match_pattern);
+      }
     }else{
       Trace("inst-match-gen-warn") << "(?) Unknown matching pattern is " << d_match_pattern << std::endl;
     }
@@ -288,8 +318,10 @@ int InstMatchGenerator::getMatch(
           prev.push_back(d_children_types[0]);
         }
       }
+    }
     //for relational matching
-    }else if( !d_eq_class_rel.isNull() && d_eq_class_rel.getKind()==INST_CONSTANT ){
+    if (!d_eq_class_rel.isNull() && d_eq_class_rel.getKind() == INST_CONSTANT)
+    {
       int v = d_eq_class_rel.getAttribute(InstVarNumAttribute());
       //also must fit match to equivalence class
       bool pol = d_pattern.getKind()!=NOT;