Use matching heuristics for EPR instantiation.
authorajreynol <andrew.j.reynolds@gmail.com>
Sat, 17 Sep 2016 00:06:05 +0000 (19:06 -0500)
committerajreynol <andrew.j.reynolds@gmail.com>
Sat, 17 Sep 2016 00:06:16 +0000 (19:06 -0500)
src/options/quantifiers_options
src/theory/quantifiers/ceg_instantiator.cpp
src/theory/quantifiers/ceg_instantiator.h
src/theory/quantifiers/quant_conflict_find.cpp
src/theory/quantifiers/term_database.cpp

index 9e704691a0724cdc18f71ef10ecfd2271c8c67f1..3fc589b5ce1e430b30a0a59507ab36bda1f28406 100644 (file)
@@ -304,6 +304,8 @@ option cbqiNestedQE --cbqi-nested-qe bool :default false
  
 option quantEpr --quant-epr bool :default false :read-write
  infer whether in effectively propositional fragment, use for cbqi
+option quantEprMatching --quant-epr-match bool :default true
+ use matching heuristics for EPR instantiation
  
 ### local theory extensions options 
 
index 5876c7377950997c91144d2c360e56f8584490f5..987b69522c2111564c12490b1ca2792b9c57e066 100644 (file)
@@ -21,7 +21,9 @@
 #include "theory/quantifiers/first_order_model.h"
 #include "theory/quantifiers/term_database.h"
 #include "theory/quantifiers/quantifiers_rewriter.h"
+#include "theory/quantifiers/trigger.h"
 #include "theory/theory_engine.h"
+#include "theory/quantifiers/term_database.h"
 
 #include "theory/bv/theory_bv_utils.h"
 #include "util/bitvector.h"
@@ -1379,17 +1381,16 @@ Node CegInstantiator::getModelValue( Node n ) {
 void CegInstantiator::collectCeAtoms( Node n, std::map< Node, bool >& visited ) {
   if( n.getKind()==FORALL ){
     d_is_nested_quant = true;
-  }else{
-    if( visited.find( n )==visited.end() ){
-      visited[n] = true;
-      if( TermDb::isBoolConnective( n.getKind() ) ){
-        for( unsigned i=0; i<n.getNumChildren(); i++ ){
-          collectCeAtoms( n[i], visited );
-        }
-      }else{
-        if( std::find( d_ce_atoms.begin(), d_ce_atoms.end(), n )==d_ce_atoms.end() ){
-          d_ce_atoms.push_back( n );
-        }
+  }else if( visited.find( n )==visited.end() ){
+    visited[n] = true;
+    if( TermDb::isBoolConnective( n.getKind() ) ){
+      for( unsigned i=0; i<n.getNumChildren(); i++ ){
+        collectCeAtoms( n[i], visited );
+      }
+    }else{
+      if( std::find( d_ce_atoms.begin(), d_ce_atoms.end(), n )==d_ce_atoms.end() ){
+        Trace("cbqi-ce-atoms") << "CE atoms : " << n << std::endl;
+        d_ce_atoms.push_back( n );
       }
     }
   }
@@ -1711,12 +1712,88 @@ bool DtInstantiator::processEquality( CegInstantiator * ci, SolvedForm& sf, Node
   return false;
 }
 
+void EprInstantiator::reset( Node pv, unsigned effort ) {
+  d_equal_terms.clear();
+}
+
 bool EprInstantiator::processEqualTerm( CegInstantiator * ci, SolvedForm& sf, Node pv, Node pv_coeff, Node n, unsigned effort ) {
-  return ci->doAddInstantiationInc( pv, n, pv_coeff, 0, sf, effort );
+  if( options::quantEprMatching() ){
+    Assert( pv_coeff.isNull() );
+    d_equal_terms.push_back( n ); 
+    return false;  
+  }else{
+    return ci->doAddInstantiationInc( pv, n, pv_coeff, 0, sf, effort );
+  }
 }
 
+void EprInstantiator::computeMatchScore( CegInstantiator * ci, Node pv, Node catom, std::vector< Node >& arg_reps, TermArgTrie * tat, unsigned index, std::map< Node, int >& match_score ) {
+  if( index==catom.getNumChildren() ){
+    Assert( tat->hasNodeData() );
+    Node gcatom = tat->getNodeData();
+    Trace("epr-inst") << "Matched : " << catom << " and " << gcatom << std::endl;
+    for( unsigned i=0; i<catom.getNumChildren(); i++ ){
+      if( catom[i]==pv ){
+        Trace("epr-inst") << "...increment " << gcatom[i] << std::endl;
+        match_score[gcatom[i]]++;
+      }else{
+        //recursive matching
+        computeMatchScore( ci, pv, catom[i], gcatom[i], match_score );
+      }
+    }
+  }else{
+    std::map< TNode, TermArgTrie >::iterator it = tat->d_data.find( arg_reps[index] );
+    if( it!=tat->d_data.end() ){
+      computeMatchScore( ci, pv, catom, arg_reps, &it->second, index+1, match_score );
+    }
+  }
+}
+
+void EprInstantiator::computeMatchScore( CegInstantiator * ci, Node pv, Node catom, Node eqc, std::map< Node, int >& match_score ) {
+  if( inst::Trigger::isAtomicTrigger( catom ) && TermDb::containsTerm( catom, pv ) ){
+    Trace("epr-inst") << "Find matches for " << catom << "..." << std::endl;
+    std::vector< Node > arg_reps;
+    for( unsigned j=0; j<catom.getNumChildren(); j++ ){
+      arg_reps.push_back( ci->getQuantifiersEngine()->getMasterEqualityEngine()->getRepresentative( catom[j] ) );
+    }
+    if( ci->getQuantifiersEngine()->getMasterEqualityEngine()->hasTerm( eqc ) ){
+      Node rep = ci->getQuantifiersEngine()->getMasterEqualityEngine()->getRepresentative( eqc );
+      Node op = ci->getQuantifiersEngine()->getTermDatabase()->getMatchOperator( catom );
+      TermArgTrie * tat = ci->getQuantifiersEngine()->getTermDatabase()->getTermArgTrie( rep, op );
+      Trace("epr-inst") << "EPR instantiation match term : " << catom << ", check ground terms=" << (tat!=NULL) << std::endl;
+      if( tat ){
+        computeMatchScore( ci, pv, catom, arg_reps, tat, 0, match_score );
+      }
+    }
+  }
+}
+
+struct sortEqTermsMatch {
+  std::map< Node, int > d_match_score;
+  bool operator() (Node i, Node j) {
+    int match_score_i = d_match_score[i];
+    int match_score_j = d_match_score[j];
+    return match_score_i>match_score_j || ( match_score_i==match_score_j && i<j );
+  }
+};
+
+    
 bool EprInstantiator::processEqualTerms( CegInstantiator * ci, SolvedForm& sf, Node pv, std::vector< Node >& eqc, unsigned effort ) {
-  //TODO: heuristic for best matching constant
+  if( options::quantEprMatching() ){
+    //heuristic for best matching constant
+    sortEqTermsMatch setm;
+    for( unsigned i=0; i<ci->getNumCEAtoms(); i++ ){
+      Node catom = ci->getCEAtom( i );
+      computeMatchScore( ci, pv, catom, catom, setm.d_match_score );
+    }
+    //sort by match score
+    std::sort( d_equal_terms.begin(), d_equal_terms.end(), setm );
+    Node pv_coeff;
+    for( unsigned i=0; i<d_equal_terms.size(); i++ ){
+      if( ci->doAddInstantiationInc( pv, d_equal_terms[i], pv_coeff, 0, sf, effort ) ){
+        return true;
+      }
+    }
+  }
   return false;
 }
 
index 3b949c23700d3724cf7802dd59071ee8ea44f4f8..259c604dcb237d010977dc5b003f575e113b3100 100644 (file)
@@ -140,8 +140,6 @@ private:
   Node getModelBasedProjectionValue( Node e, Node t, bool isLower, Node c, Node me, Node mt, Node theta, Node inf_coeff, Node delta_coeff );
   void processAssertions();
   void addToAuxVarSubstitution( std::vector< Node >& subs_lhs, std::vector< Node >& subs_rhs, Node l, Node r );
-  //get model value
-  Node getModelValue( Node n );
 private:
   int solve_arith( Node v, Node atom, Node & veq_c, Node & val, Node& vts_coeff_inf, Node& vts_coeff_delta );
   Node solve_dt( Node v, Node a, Node b, Node sa, Node sb );
@@ -157,9 +155,14 @@ public:
 
 //interface for instantiators
 public:
+  //get quantifiers engine
+  QuantifiersEngine* getQuantifiersEngine() { return d_qe; }
   void pushStackVariable( Node v );
   void popStackVariable();
   bool doAddInstantiationInc( Node pv, Node n, Node pv_coeff, int bt, SolvedForm& sf, unsigned effort );
+  Node getModelValue( Node n );
+  unsigned getNumCEAtoms() { return d_ce_atoms.size(); }
+  Node getCEAtom( unsigned i ) { return d_ce_atoms[i]; }
 };
 
 
@@ -237,10 +240,17 @@ public:
   std::string identify() const { return "Dt"; }
 };
 
+class TermArgTrie;
+
 class EprInstantiator : public Instantiator {
+private:
+  std::vector< Node > d_equal_terms;
+  void computeMatchScore( CegInstantiator * ci, Node pv, Node catom, std::vector< Node >& arg_reps, TermArgTrie * tat, unsigned index, std::map< Node, int >& match_score );
+  void computeMatchScore( CegInstantiator * ci, Node pv, Node catom, Node eqc, std::map< Node, int >& match_score );
 public:
   EprInstantiator( QuantifiersEngine * qe, TypeNode tn ) : Instantiator( qe, tn ){}
   virtual ~EprInstantiator(){}
+  void reset( Node pv, unsigned effort );
   bool processEqualTerm( CegInstantiator * ci, SolvedForm& sf, Node pv, Node pv_coeff, Node n, unsigned effort );
   bool processEqualTerms( CegInstantiator * ci, SolvedForm& sf, Node pv, std::vector< Node >& eqc, unsigned effort );
   std::string identify() const { return "Epr"; }
index bac2aa35cd1a3a6f1154a98a53732078f91fb1ae..522f4dfceb200fed1d87e693d24d55d9d8df4f3b 100644 (file)
@@ -1853,7 +1853,7 @@ void MatchGen::setInvalid() {
 }
 
 bool MatchGen::isHandledBoolConnective( TNode n ) {
-  return TermDb::isBoolConnective( n.getKind() ) && ( n.getKind()!=ITE || n.getType().isBoolean() );
+  return TermDb::isBoolConnective( n.getKind() ) && ( n.getKind()!=ITE || n.getType().isBoolean() ) && n.getKind()!=SEP_STAR;
 }
 
 bool MatchGen::isHandledUfTerm( TNode n ) {
index dee4952bd36fdbefe03704c117b96ed3a1629292..ff11babc958a072ba32d6020710ab76460c74388 100644 (file)
@@ -1958,7 +1958,7 @@ bool TermDb::isComm( Kind k ) {
 }
 
 bool TermDb::isBoolConnective( Kind k ) {
-  return k==OR || k==AND || k==IFF || k==ITE || k==FORALL || k==NOT;
+  return k==OR || k==AND || k==IFF || k==ITE || k==FORALL || k==NOT || k==SEP_STAR;
 }
 
 void TermDb::registerTrigger( theory::inst::Trigger* tr, Node op ){