Add new symmetry breaking technique for finite model finding. Improvements to bounde...
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 27 Sep 2013 14:27:19 +0000 (09:27 -0500)
committerAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 27 Sep 2013 14:27:29 +0000 (09:27 -0500)
15 files changed:
src/smt/options
src/smt/smt_engine.cpp
src/theory/quantifiers/Makefile.am
src/theory/quantifiers/bounded_integers.cpp
src/theory/quantifiers/bounded_integers.h
src/theory/quantifiers/full_model_check.cpp
src/theory/quantifiers/symmetry_breaking.cpp [new file with mode: 0755]
src/theory/quantifiers/symmetry_breaking.h [new file with mode: 0755]
src/theory/quantifiers_engine.cpp [changed mode: 0644->0755]
src/theory/rep_set.cpp
src/theory/uf/options
src/theory/uf/theory_uf_strong_solver.cpp
src/theory/uf/theory_uf_strong_solver.h
src/util/sort_inference.cpp
src/util/sort_inference.h

index f39662c106d39bbb3e266a09401f22bb93930661..7a72881b4b187fba5a9f9f944e5db39fdf490191 100644 (file)
@@ -48,7 +48,7 @@ option unconstrainedSimp --unconstrained-simp bool :default false :read-write
 option repeatSimp --repeat-simp bool :read-write
  make multiple passes with nonclausal simplifier
 
-option sortInference --sort-inference bool :default false
+option sortInference --sort-inference bool :read-write :default false
  apply sort inference to input problem
 
 common-option incrementalSolving incremental -i --incremental bool
index e1dc3531ed0a49255f5d9e414c08407140e42946..7fadb477ba6ebd28560b8ffee1cc3a90aeda6130 100644 (file)
@@ -1028,6 +1028,9 @@ void SmtEngine::setLogicInternal() throw() {
       options::fmfInstGen.set( false );
     }
   }
+  if( options::ufssSymBreak() ){
+    options::sortInference.set( true );
+  }
 
   //until bugs 371,431 are fixed
   if( ! options::minisatUseElim.wasSetByUser()){
index 80011868bcd903ab961e29e683b2c47cfa8cd504..be24d6c678bbc897f84054ff3e7a71a5caf2b3ca 100644 (file)
@@ -52,7 +52,10 @@ libquantifiers_la_SOURCES = \
        rewrite_engine.h \
        rewrite_engine.cpp \
        relevant_domain.h \
-       relevant_domain.cpp
+       relevant_domain.cpp \
+       symmetry_breaking.h \
+       symmetry_breaking.cpp
+
 
 EXTRA_DIST = \
        kinds \
index e1e2f96c2ba73dfbc3ea129e1af5ae3f99c2c795..30ff5242b6c1d2a80f2eb8a1a26eecc7c4745434 100644 (file)
@@ -321,7 +321,7 @@ Node BoundedIntegers::getNextDecisionRequest() {
   return Node::null();
 }
 
-void BoundedIntegers::getBoundValues( Node f, Node v, RepSetIterator * rsi, Node & l, Node & u ) {
+void BoundedIntegers::getBounds( Node f, Node v, RepSetIterator * rsi, Node & l, Node & u ) {
   l = d_bounds[0][f][v];
   u = d_bounds[1][f][v];
   if( d_nground_range[f].find(v)!=d_nground_range[f].end() ){
@@ -356,6 +356,10 @@ void BoundedIntegers::getBoundValues( Node f, Node v, RepSetIterator * rsi, Node
       l = l.substitute( vars.begin(), vars.end(), subs.begin(), subs.end() );
     }
   }
+}
+
+void BoundedIntegers::getBoundValues( Node f, Node v, RepSetIterator * rsi, Node & l, Node & u ) {
+  getBounds( f, v, rsi, l, u );
   Trace("bound-int-rsi") << "Get value in model for..." << l << " and " << u << std::endl;
   l = d_quantEngine->getModel()->getCurrentModelValue( l );
   u = d_quantEngine->getModel()->getCurrentModelValue( u );
index 27d5b75692f6f6392ffdbb5fe07590049b162fde..3da938d3120506e063b9a830aacef1c31f6882cc 100644 (file)
@@ -115,6 +115,7 @@ public:
   int getBoundVarNum( Node f, int i ) { return d_set_nums[f][i]; }
   Node getLowerBound( Node f, Node v ){ return d_bounds[0][f][v]; }
   Node getUpperBound( Node f, Node v ){ return d_bounds[1][f][v]; }
+  void getBounds( Node f, Node v, RepSetIterator * rsi, Node & l, Node & u );
   void getBoundValues( Node f, Node v, RepSetIterator * rsi, Node & l, Node & u );
   bool isGroundRange(Node f, Node v);
 };
index cdf6976756af526a3f439585dbf2f353b9724b9d..bf10369e627cf1e445f68491f2df92a672837db7 100644 (file)
@@ -840,14 +840,15 @@ void FullModelChecker::doCheck(FirstOrderModelFmc * fm, Node f, Def & d, Node n
       Node i = fm->getUsedRepresentative( r[1] );
       Node e = fm->getUsedRepresentative( r[2] );
       d.addEntry(fm, mkArrayCond(i), e );
-      r = r[0];
+      r = fm->getRepresentative( r[0] );
     }
     Node defC = mkArrayCond(fm->getStar(n.getType().getArrayIndexType()));
     bool success = false;
+    Node odefaultValue;
     if( r.getKind() == kind::STORE_ALL ){
       ArrayStoreAll storeAll = r.getConst<ArrayStoreAll>();
-      Node defaultValue = Node::fromExpr(storeAll.getExpr());
-      defaultValue = fm->getUsedRepresentative( defaultValue, true );
+      odefaultValue = Node::fromExpr(storeAll.getExpr());
+      Node defaultValue = fm->getUsedRepresentative( odefaultValue, true );
       if( !defaultValue.isNull() ){
         d.addEntry(fm, defC, defaultValue);
         success = true;
@@ -855,6 +856,7 @@ void FullModelChecker::doCheck(FirstOrderModelFmc * fm, Node f, Def & d, Node n
     }
     if( !success ){
       Trace("fmc-warn") << "WARNING : ARRAYS : Can't process base array " << r << std::endl;
+      Trace("fmc-warn") << "          Default value was : " << odefaultValue << std::endl;
       Trace("fmc-debug") << "Can't process base array " << r << std::endl;
       //can't process this array
       d.reset();
@@ -1191,29 +1193,35 @@ bool FullModelChecker::doMeet( FirstOrderModelFmc * fm, std::vector< Node > & co
 }
 
 Node FullModelChecker::doIntervalMeet( FirstOrderModelFmc * fm, Node i1, Node i2, bool mk ) {
-  if( !fm->isInterval( i1 ) || !fm->isInterval( i2 ) ){
-    std::cout << "Not interval during meet! " << i1 << " " << i2 << std::endl;
-    exit( 0 );
-  }
-  Node b[2];
-  for( unsigned j=0; j<2; j++ ){
-    Node b1 = i1[j];
-    Node b2 = i2[j];
-    if( fm->isStar( b1 ) ){
-      b[j] = b2;
-    }else if( fm->isStar( b2 ) ){
-      b[j] = b1;
-    }else if( b1.getConst<Rational>() < b2.getConst<Rational>() ){
-      b[j] = j==0 ? b2 : b1;
+  if( fm->isStar( i1 ) ){
+    return i2;
+  }else if( fm->isStar( i2 ) ){
+    return i1;
+  }else{
+    if( !fm->isInterval( i1 ) || !fm->isInterval( i2 ) ){
+      std::cout << "Not interval during meet! " << i1 << " " << i2 << std::endl;
+      exit( 0 );
+    }
+    Node b[2];
+    for( unsigned j=0; j<2; j++ ){
+      Node b1 = i1[j];
+      Node b2 = i2[j];
+      if( fm->isStar( b1 ) ){
+        b[j] = b2;
+      }else if( fm->isStar( b2 ) ){
+        b[j] = b1;
+      }else if( b1.getConst<Rational>() < b2.getConst<Rational>() ){
+        b[j] = j==0 ? b2 : b1;
+      }else{
+        b[j] = j==0 ? b1 : b2;
+      }
+    }
+    if( fm->isStar( b[0] ) || fm->isStar( b[1] ) || b[0].getConst<Rational>() < b[1].getConst<Rational>() ){
+      return mk ? fm->getInterval( b[0], b[1] ) : i1;
     }else{
-      b[j] = j==0 ? b1 : b2;
+      return Node::null();
     }
   }
-  if( fm->isStar( b[0] ) || fm->isStar( b[1] ) || b[0].getConst<Rational>() < b[1].getConst<Rational>() ){
-    return mk ? fm->getInterval( b[0], b[1] ) : i1;
-  }else{
-    return Node::null();
-  }
 }
 
 Node FullModelChecker::mkCond( std::vector< Node > & cond ) {
diff --git a/src/theory/quantifiers/symmetry_breaking.cpp b/src/theory/quantifiers/symmetry_breaking.cpp
new file mode 100755 (executable)
index 0000000..6a7baeb
--- /dev/null
@@ -0,0 +1,296 @@
+/*********************                                                        */
+/*! \file symmetry_breaking.cpp
+ ** \verbatim
+ ** Original author: ajreynol
+ ** Major contributors: none
+ ** Minor contributors (to current version): none
+ ** This file is part of the CVC4 prototype.
+ ** Copyright (c) 2009-2012  New York University and The University of Iowa
+ ** See the file COPYING in the top-level source directory for licensing
+ ** information.\endverbatim
+ **
+ ** \brief symmetry breaking module
+ **
+ **/
+
+#include <vector>
+
+#include "theory/quantifiers/symmetry_breaking.h"
+#include "theory/rewriter.h"
+#include "theory/quantifiers_engine.h"
+#include "theory/theory_engine.h"
+#include "util/sort_inference.h"
+#include "theory/uf/theory_uf_strong_solver.h"
+
+using namespace CVC4;
+using namespace CVC4::kind;
+using namespace CVC4::theory;
+using namespace std;
+
+namespace CVC4 {
+
+eq::EqualityEngine * SubsortSymmetryBreaker::getEqualityEngine() {
+  return ((uf::TheoryUF*)d_qe->getTheoryEngine()->theoryOf( theory::THEORY_UF ))->getEqualityEngine();
+}
+
+bool SubsortSymmetryBreaker::areEqual( Node n1, Node n2 ) {
+  return getEqualityEngine()->hasTerm( n1 ) && getEqualityEngine()->hasTerm( n2 ) && getEqualityEngine()->areEqual( n1,n2 );
+}
+
+bool SubsortSymmetryBreaker::areDisequal( Node n1, Node n2 ) {
+  return getEqualityEngine()->hasTerm( n1 ) && getEqualityEngine()->hasTerm( n2 ) && getEqualityEngine()->areDisequal( n1,n2, false );
+}
+
+
+Node SubsortSymmetryBreaker::getRepresentative( Node n ) {
+  return getEqualityEngine()->getRepresentative( n );
+}
+
+uf::StrongSolverTheoryUF * SubsortSymmetryBreaker::getStrongSolver() {
+  return ((uf::TheoryUF*)d_qe->getTheoryEngine()->theoryOf( theory::THEORY_UF ))->getStrongSolver();
+}
+
+SubsortSymmetryBreaker::SubsortSymmetryBreaker(QuantifiersEngine* qe, context::Context* c) :
+d_qe(qe), d_conflict(c,false), d_max_dom_const_sort(c,0), d_has_dom_const_sort(c,false),
+d_fact_index(c,0), d_fact_list(c) {
+  d_true =  NodeManager::currentNM()->mkConst( true );
+}
+
+SubsortSymmetryBreaker::TypeInfo::TypeInfo( SubsortSymmetryBreaker * ssb, context::Context * c ) :
+d_ssb( ssb ), d_dom_constants( c ), d_first_active( c, 0 ){
+  d_dc_nodes = 0;
+}
+
+unsigned SubsortSymmetryBreaker::TypeInfo::getNumDomainConstants() {
+  if( d_nodes.empty() ){
+    return 0;
+  }else{
+    return 1 + d_dom_constants.size();
+  }
+}
+
+Node SubsortSymmetryBreaker::TypeInfo::getDomainConstant( int i ) {
+  if( i==0 ){
+    return d_nodes[0];
+  }else{
+    Assert( i<=(int)d_dom_constants.size() );
+    return d_dom_constants[i-1];
+  }
+}
+
+Node SubsortSymmetryBreaker::TypeInfo::getFirstActive() {
+  if( d_first_active.get()<(int)d_nodes.size() ){
+    Node fa = d_nodes[d_first_active.get()];
+    return d_ssb->getEqualityEngine()->hasTerm( fa ) ? fa : Node::null();
+  }else{
+    return Node::null();
+  }
+}
+
+SubsortSymmetryBreaker::TypeInfo * SubsortSymmetryBreaker::getTypeInfo( TypeNode tn, int sid ) {
+  if( d_type_info.find( sid )==d_type_info.end() ){
+    d_type_info[sid] = new TypeInfo( this, d_qe->getSatContext() );
+    d_sub_sorts[tn].push_back( sid );
+    d_sid_to_type[sid] = tn;
+  }
+  return d_type_info[sid];
+}
+
+void SubsortSymmetryBreaker::newEqClass( Node n ) {
+  Trace("sym-break-temp") << "New eq class " << n << std::endl;
+  if( !d_conflict ){
+    TypeNode tn = n.getType();
+    SortInference * si = d_qe->getTheoryEngine()->getSortInference();
+    if( si->isWellSorted( n ) ){
+      int sid = si->getSortId( n );
+      Trace("sym-break-debug") << "SSB: New eq class " << n << " : " << n.getType() << " : " << sid << std::endl;
+      TypeInfo * ti = getTypeInfo( tn, sid );
+      if( std::find( ti->d_nodes.begin(), ti->d_nodes.end(), n )==ti->d_nodes.end() ){
+        if( ti->d_nodes.empty() ){
+          //for first subsort, we add unit equality
+          if( d_sub_sorts[tn][0]!=sid ){
+            Trace("sym-break-temp") << "Do sym break unit with " << d_type_info[d_sub_sorts[tn][0]]->getBaseConstant() << std::endl;
+            //add unit symmetry breaking lemma
+            Node eq = n.eqNode( d_type_info[d_sub_sorts[tn][0]]->getBaseConstant() );
+            eq = Rewriter::rewrite( eq );
+            d_unit_lemmas.push_back( eq );
+            Trace("sym-break-lemma") << "*** SymBreak : Unit lemma (" << sid << "==" << d_sub_sorts[tn][0] << ") : " << eq << std::endl;
+            d_pending_lemmas.push_back( eq );
+          }
+          Trace("sym-break-dc") << "* Set first domain constant : " << n << " for " << tn << " : " << sid << std::endl;
+          ti->d_dc_nodes++;
+        }
+        ti->d_node_to_id[n] = ti->d_nodes.size();
+        ti->d_nodes.push_back( n );
+      }
+      if( !d_has_dom_const_sort.get() ){
+        d_has_dom_const_sort.set( true );
+        d_max_dom_const_sort.set( sid );
+      }
+    }
+  }
+  Trace("sym-break-temp") << "Done new eq class" << std::endl;
+}
+
+
+
+void SubsortSymmetryBreaker::merge( Node a, Node b ) {
+
+}
+
+void SubsortSymmetryBreaker::assertDisequal( Node a, Node b ) {
+
+}
+
+void SubsortSymmetryBreaker::processFirstActive( TypeNode tn, int sid, int curr_card ){
+  TypeInfo * ti = getTypeInfo( tn, sid );
+  if( (int)ti->getNumDomainConstants()<curr_card ){
+    Trace("sym-break-dc-debug") << "Check for domain constants " << tn << " : " << sid << ", curr_card = " << curr_card << ", ";
+    Trace("sym-break-dc-debug") << "#domain constants = " << ti->getNumDomainConstants() << std::endl;
+    Node fa = ti->getFirstActive();
+    bool invalid = true;
+    while( invalid && !fa.isNull() && (int)ti->getNumDomainConstants()<curr_card ){
+      invalid = false;
+      unsigned deq = 0;
+      for( unsigned i=0; i<ti->getNumDomainConstants(); i++ ){
+        Node dc = ti->getDomainConstant( i );
+        if( areEqual( fa, dc ) ){
+          invalid = true;
+          break;
+        }else if( areDisequal( fa, dc ) ){
+          deq++;
+        }
+      }
+      if( deq==ti->getNumDomainConstants() ){
+        Trace("sym-break-dc") << "* Can infer domain constant #" << ti->getNumDomainConstants()+1;
+        Trace("sym-break-dc") << " : " << fa << " for " << tn << " : " << sid << std::endl;
+        //add to domain constants
+        ti->d_dom_constants.push_back( fa );
+        if( ti->d_node_to_id[fa]>ti->d_dc_nodes ){
+          Trace("sym-break-dc-debug") << "Swap nodes... " << ti->d_dc_nodes << " " << ti->d_node_to_id[fa] << " " << ti->d_nodes.size() << std::endl;
+          //swap
+          Node on = ti->d_nodes[ti->d_dc_nodes];
+          int id = ti->d_node_to_id[fa];
+
+          ti->d_nodes[ti->d_dc_nodes] = fa;
+          ti->d_nodes[id] = on;
+          ti->d_node_to_id[fa] = ti->d_dc_nodes;
+          ti->d_node_to_id[on] = id;
+        }
+        ti->d_dc_nodes++;
+        Trace("sym-break-dc-debug") << "Get max type info..." << std::endl;
+        Assert( d_has_dom_const_sort.get() );
+        int msid = d_max_dom_const_sort.get();
+        TypeInfo * max_ti = getTypeInfo( d_sid_to_type[msid], msid );
+        Trace("sym-break-dc-debug") << "Swap nodes..." << std::endl;
+        //now, check if we can apply symmetry breaking to another sort
+        if( ti->getNumDomainConstants()>max_ti->getNumDomainConstants() ){
+          Trace("sym-break-dc") << "Max domain constant subsort for " << tn << " becomes " << sid << std::endl;
+          d_max_dom_const_sort.set( sid );
+        }else if( ti!=max_ti ){
+          //construct symmetry breaking lemma
+          //current domain constant must be disequal from all current ones
+          Trace("sym-break-dc") << "Get domain constant " << ti->getNumDomainConstants()-1;
+          Trace("sym-break-dc") << " from max_ti, " << max_ti->getNumDomainConstants() << std::endl;
+          //apply a symmetry breaking lemma
+          Node m = max_ti->getDomainConstant(ti->getNumDomainConstants()-1);
+          //if fa and m are disequal from all previous domain constants in the other sort
+          std::vector< Node > cc;
+          for( unsigned r=0; r<2; r++ ){
+            Node n = ((r==0)==(msid>sid)) ? fa : m;
+            Node on = ((r==0)==(msid>sid)) ? m : fa;
+            TypeInfo * t = ((r==0)==(msid>sid)) ? max_ti : ti;
+            for( unsigned i=0; i<t->d_node_to_id[on]; i++ ){
+              cc.push_back( n.eqNode( t->d_nodes[i] ) );
+            }
+          }
+          //then, we can assume fa = m
+          cc.push_back( fa.eqNode( m ) );
+          Node lem = NodeManager::currentNM()->mkNode( kind::OR, cc );
+          lem = Rewriter::rewrite( lem );
+          if( std::find( d_lemmas.begin(), d_lemmas.end(), lem )==d_lemmas.end() ){
+            d_lemmas.push_back( lem );
+            Trace("sym-break-lemma") << "*** Symmetry break lemma for " << tn << " (" << sid << "==" << d_max_dom_const_sort.get() << ") : ";
+            Trace("sym-break-lemma") << lem << std::endl;
+            d_pending_lemmas.push_back( lem );
+          }
+        }
+        invalid = true;
+      }
+      if( invalid ){
+        ti->d_first_active.set( ti->d_first_active + 1 );
+        fa = ti->getFirstActive();
+      }
+    }
+  }
+}
+
+void SubsortSymmetryBreaker::printDebugTypeInfo( const char * c, TypeNode tn, int sid ) {
+  Trace(c) << "TypeInfo( " << tn << ", " << sid << " ) = " << std::endl;
+  Trace(c) << "  Domain constants : ";
+  TypeInfo * ti = getTypeInfo( tn, sid );
+  for( NodeList::const_iterator it = ti->d_dom_constants.begin(); it != ti->d_dom_constants.end(); ++it ){
+    Node dc = *it;
+    Trace(c) << dc << " ";
+  }
+  Trace(c) << std::endl;
+  Trace(c) << "  First active node : " << ti->getFirstActive() << std::endl;
+}
+
+
+void SubsortSymmetryBreaker::queueFact( Node n ) {
+  d_fact_list.push_back( n );
+  /*
+  if( n.getKind()==EQUAL ){
+    merge( n[0], n[1] );
+  }else if( n.getKind()==NOT && n[0].getKind()==EQUAL ){
+    assertDisequal( n[0][0], n[0][1] );
+  }else{
+    newEqClass( n );
+  }
+  */
+}
+
+bool SubsortSymmetryBreaker::check( Theory::Effort level ) {
+  d_pending_lemmas.clear();
+
+  Trace("sym-break-debug") << "SymBreak : check " << level << std::endl;
+  while( d_fact_index.get()<d_fact_list.size() ){
+    Node f = d_fact_list[d_fact_index.get()];
+    d_fact_index.set( d_fact_index.get() + 1 );
+    if( f.getKind()==EQUAL ){
+      merge( f[0], f[1] );
+    }else if( f.getKind()==NOT && f[0].getKind()==EQUAL ){
+      assertDisequal( f[0][0], f[0][1] );
+    }else{
+      newEqClass( f );
+    }
+  }
+  Trace("sym-break-debug") << "SymBreak : update first actives" << std::endl;
+  for( std::map< TypeNode, std::vector< int > >::iterator it = d_sub_sorts.begin(); it != d_sub_sorts.end(); ++it ){
+    int card = getStrongSolver()->getCardinality( it->first );
+    for( unsigned i=0; i<it->second.size(); i++ ){
+      //check if the first active is disequal from all domain constants
+      processFirstActive( it->first, it->second[i], card );
+    }
+  }
+
+
+  Trace("sym-break-debug") << "SymBreak : finished check, now flush lemmas... (#lemmas = " << d_pending_lemmas.size() << ")" << std::endl;
+  //flush pending lemmas
+  if( !d_pending_lemmas.empty() ){
+    for( unsigned i=0; i<d_pending_lemmas.size(); i++ ){
+      getStrongSolver()->getOutputChannel().lemma( d_pending_lemmas[i] );
+      ++( getStrongSolver()->d_statistics.d_sym_break_lemmas );
+    }
+    d_pending_lemmas.clear();
+    return true;
+  }else{
+    return false;
+  }
+}
+
+
+
+}
+
diff --git a/src/theory/quantifiers/symmetry_breaking.h b/src/theory/quantifiers/symmetry_breaking.h
new file mode 100755 (executable)
index 0000000..3db9097
--- /dev/null
@@ -0,0 +1,121 @@
+/*********************                                                        */
+/*! \file symmetry_breaking.h
+ ** \verbatim
+ ** Original author: ajreynol
+ ** Major contributors: none
+ ** Minor contributors (to current version): none
+ ** This file is part of the CVC4 prototype.
+ ** Copyright (c) 2009-2012  New York University and The University of Iowa
+ ** See the file COPYING in the top-level source directory for licensing
+ ** information.\endverbatim
+ **
+ ** \brief Pre-process step for first-order reasoning
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef __CVC4__QUANT_SYMMETRY_BREAKING_H
+#define __CVC4__QUANT_SYMMETRY_BREAKING_H
+
+#include "theory/theory.h"
+
+#include <iostream>
+#include <string>
+#include <vector>
+#include <map>
+#include "expr/node.h"
+#include "expr/type_node.h"
+
+#include "util/sort_inference.h"
+#include "context/context.h"
+#include "context/context_mm.h"
+#include "context/cdchunk_list.h"
+
+namespace CVC4 {
+namespace theory {
+
+namespace uf {
+  class StrongSolverTheoryUF;
+}
+
+class SubsortSymmetryBreaker {
+  typedef context::CDHashMap<Node, bool, NodeHashFunction> NodeBoolMap;
+  typedef context::CDHashMap<Node, int, NodeHashFunction> NodeIntMap;
+  typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeNodeMap;
+  //typedef context::CDChunkList<int> IntList;
+  typedef context::CDList<Node> NodeList;
+  typedef context::CDHashMap<Node, NodeList*, NodeHashFunction> NodeListMap;
+private:
+  /** quantifiers engine */
+  QuantifiersEngine* d_qe;
+  eq::EqualityEngine * getEqualityEngine();
+  bool areDisequal( Node n1, Node n2 );
+  bool areEqual( Node n1, Node n2 );
+  Node getRepresentative( Node n );
+  uf::StrongSolverTheoryUF * getStrongSolver();
+  std::vector< Node > d_unit_lemmas;
+  Node d_true;
+  context::CDO< bool > d_conflict;
+public:
+  SubsortSymmetryBreaker( QuantifiersEngine* qe, context::Context* c );
+  ~SubsortSymmetryBreaker(){}
+
+private:
+  class TypeInfo {
+  private:
+    SubsortSymmetryBreaker * d_ssb;
+    //bool isActive( Node n, unsigned & deq );
+  public:
+    TypeInfo( SubsortSymmetryBreaker * ssb, context::Context* c );
+    //list of all nodes from this (sub)type
+    std::vector< Node > d_nodes;
+    //the current domain constants for this (sub)type
+    NodeList d_dom_constants;
+    //# nodes in d_nodes that have been domain constants, size of this distinct # of domain constants seen
+    unsigned d_dc_nodes;
+    //the node we are currently watching to become a domain constant
+    context::CDO< int > d_first_active;
+    //node to id
+    std::map< Node, unsigned > d_node_to_id;
+    Node getBaseConstant() { return d_nodes.empty() ? Node::null() : d_nodes[0]; }
+    bool hasDomainConstant( Node n );
+    unsigned getNumDomainConstants();
+    Node getDomainConstant( int i );
+    Node getFirstActive();
+  };
+  std::map< TypeNode, std::vector< int > > d_sub_sorts;
+  std::map< int, TypeNode > d_sid_to_type;
+  std::map< int, TypeInfo * > d_type_info;
+
+  //maximum domain constants sort
+  context::CDO< int > d_max_dom_const_sort;
+  context::CDO< bool > d_has_dom_const_sort;
+
+  TypeInfo * getTypeInfo( TypeNode tn, int sid );
+
+  void processFirstActive( TypeNode tn, int sid, int curr_card );
+private:
+  //void printDebugNodeInfo( const char * c, Node n );
+  void printDebugTypeInfo( const char * c, TypeNode tn, int sid );
+  /** new node */
+  void newEqClass( Node n );
+  /** merge */
+  void merge( Node a, Node b );
+  /** assert disequal */
+  void assertDisequal( Node a, Node b );
+  /** fact list */
+  context::CDO< unsigned > d_fact_index;
+  NodeList d_fact_list;
+  std::vector< Node > d_pending_lemmas;
+  std::vector< Node > d_lemmas;
+public:
+  /** queue fact */
+  void queueFact( Node n );
+  /** check */
+  bool check( Theory::Effort level );
+};
+
+}
+}
+
+#endif
old mode 100644 (file)
new mode 100755 (executable)
index e5cc8a1..0fe50aa
@@ -29,6 +29,7 @@
 #include "theory/rewriterules/rr_trigger.h"
 #include "theory/quantifiers/bounded_integers.h"
 #include "theory/quantifiers/rewrite_engine.h"
+#include "theory/uf/options.h"
 
 using namespace std;
 using namespace CVC4;
@@ -632,6 +633,7 @@ Node EqualityQueryQuantifiersEngine::getInternalRepresentative( Node a, Node f,
   }else{
     int sortId = 0;
     if( optInternalRepSortInference() ){
+    //if( options::ufssSymBreak() ){
       sortId = d_qe->getTheoryEngine()->getSortInference()->getSortId( f, f[0][index] );
     }
     if( d_int_rep[sortId].find( r )==d_int_rep[sortId].end() ){
index 647ef965ae18a97c23a9de14e7d2771870746257..800e007f79e68712ca19b6a9454ded437b2111bb 100644 (file)
@@ -278,7 +278,12 @@ bool RepSetIterator::resetIndex( int i, bool initial ) {
         Node range = Rewriter::rewrite( NodeManager::currentNM()->mkNode( MINUS, u, l ) );
         Node ra = Rewriter::rewrite( NodeManager::currentNM()->mkNode( LEQ, range, NodeManager::currentNM()->mkConst( Rational( 9999 ) ) ) );
         d_domain[ii].clear();
-        d_lower_bounds[ii] = l;
+        Node tl = l;
+        Node tu = u;
+        if( d_qe->getBoundedIntegers() && d_qe->getBoundedIntegers()->isBoundVar( d_owner, d_owner[0][ii] ) ){
+          d_qe->getBoundedIntegers()->getBounds( d_owner, d_owner[0][ii], this, tl, tu );
+        }
+        d_lower_bounds[ii] = tl;
         if( ra==NodeManager::currentNM()->mkConst(true) ){
           long rr = range.getConst<Rational>().getNumerator().getLong()+1;
           Trace("bound-int-rsi")  << "Actual bound range is " << rr << std::endl;
index 437e30e461340e55e1e5a03107b8560eb1298e7a..b9f60b83d72c55972b8838fd8e2fed62bd00fa34 100644 (file)
@@ -40,5 +40,7 @@ option ufssMinimalModel /--disable-uf-ss-min-model bool :default true
 option ufssCliqueSplits --uf-ss-clique-splits bool :default false
  use cliques instead of splitting on demand to shrink model
 
+option ufssSymBreak --uf-ss-sym-break bool :default false
+ finite model finding symmetry breaking techniques
 
 endmodule
index adcf78a8643088fd35b5ff9c9155e7f8b3d1befa..82cd1f809cc623fb38a19a3c4847dd02669d7a80 100644 (file)
@@ -20,6 +20,8 @@
 #include "theory/quantifiers/term_database.h"
 #include "theory/uf/options.h"
 #include "theory/model.h"
+#include "theory/quantifiers/symmetry_breaking.h"
+
 
 //#define ONE_SPLIT_REGION
 //#define DISABLE_QUICK_CLIQUE_CHECKS
@@ -117,6 +119,10 @@ void StrongSolverTheoryUF::SortModel::Region::setEqual( Node a, Node b ){
           if( options::ufssDiseqPropagation() ){
             d_cf->d_thss->getDisequalityPropagator()->assertDisequal(a, n, Node::null());
           }
+          if( options::ufssSymBreak() ){
+            //d_cf->d_thss->getSymmetryBreaker()->assertDisequal( a, n );
+            d_cf->d_thss->getSymmetryBreaker()->queueFact( a.eqNode( n ).negate() );
+          }
         }
         setDisequal( b, n, t, false );
         nr->setDisequal( n, b, t, false );
@@ -515,9 +521,15 @@ void StrongSolverTheoryUF::SortModel::merge( Node a, Node b ){
       }
       d_reps = d_reps - 1;
 
-      if( options::ufssDiseqPropagation() && !d_conflict ){
-        //notify the disequality propagator
-        d_thss->getDisequalityPropagator()->merge(a, b);
+      if( !d_conflict ){
+        if( options::ufssDiseqPropagation() ){
+          //notify the disequality propagator
+          d_thss->getDisequalityPropagator()->merge(a, b);
+        }
+        if( options::ufssSymBreak() ){
+          //d_thss->getSymmetryBreaker()->merge(a, b);
+          d_thss->getSymmetryBreaker()->queueFact( a.eqNode( b ) );
+        }
       }
     }
   }
@@ -565,9 +577,15 @@ void StrongSolverTheoryUF::SortModel::assertDisequal( Node a, Node b, Node reaso
           checkRegion( bi );
         }
 
-        if( options::ufssDiseqPropagation() && !d_conflict ){
-          //notify the disequality propagator
-          d_thss->getDisequalityPropagator()->assertDisequal(a, b, Node::null());
+        if( !d_conflict ){
+          if( options::ufssDiseqPropagation() ){
+            //notify the disequality propagator
+            d_thss->getDisequalityPropagator()->assertDisequal(a, b, Node::null());
+          }
+          if( options::ufssSymBreak() ){
+            //d_thss->getSymmetryBreaker()->assertDisequal(a, b);
+            d_thss->getSymmetryBreaker()->queueFact( a.eqNode( b ).negate() );
+          }
         }
       }
     }
@@ -670,7 +688,7 @@ void StrongSolverTheoryUF::SortModel::check( Theory::Effort level, OutputChannel
               for( int i=0; i<(int)d_regions_index; i++ ){
                 if( d_regions[i]->d_valid ){
                   Node op = d_regions[i]->d_nodes.begin()->first;
-                  int sort_id = d_thss->getTheory()->getQuantifiersEngine()->getTheoryEngine()->getSortInference()->getSortId(op);
+                  int sort_id = d_thss->getSortInference()->getSortId(op);
                   if( sortsFound.find( sort_id )!=sortsFound.end() ){
                     combineRegions( sortsFound[sort_id], i );
                     recheck = true;
@@ -979,17 +997,32 @@ void StrongSolverTheoryUF::SortModel::moveNode( Node n, int ri ){
 void StrongSolverTheoryUF::SortModel::allocateCardinality( OutputChannel* out ){
   if( d_aloc_cardinality>0 ){
     Trace("uf-ss-fmf") << "No model of size " << d_aloc_cardinality << " exists for type " << d_type << " in this branch" << std::endl;
-    if( Trace.isOn("uf-ss-cliques") ){
-      Trace("uf-ss-cliques") << "Cliques of size " << (d_aloc_cardinality+1) << " : " << std::endl;
-      for( size_t i=0; i<d_cliques[ d_aloc_cardinality ].size(); i++ ){
-        Trace("uf-ss-cliques") << "  ";
-        for( size_t j=0; j<d_cliques[ d_aloc_cardinality ][i].size(); j++ ){
-          Trace("uf-ss-cliques") << d_cliques[ d_aloc_cardinality ][i][j] << " ";
-        }
-        Trace("uf-ss-cliques") << std::endl;
+  }
+  if( Trace.isOn("uf-ss-cliques") ){
+    Trace("uf-ss-cliques") << "Cliques of size " << (d_aloc_cardinality+1) << " for " << d_type << " : " << std::endl;
+    for( size_t i=0; i<d_cliques[ d_aloc_cardinality ].size(); i++ ){
+      Trace("uf-ss-cliques") << "  ";
+      for( size_t j=0; j<d_cliques[ d_aloc_cardinality ][i].size(); j++ ){
+        Trace("uf-ss-cliques") << d_cliques[ d_aloc_cardinality ][i][j] << " ";
       }
+      Trace("uf-ss-cliques") << std::endl;
+    }
+  }
+  /*
+  if( options::ufssSymBreak() ){
+    std::vector< Node > reps;
+    getRepresentatives( reps );
+    if( d_aloc_cardinality>0 ){
+      d_thss->getSymmetryBreaker()->allocateCardinality( out, d_type, d_aloc_cardinality+1, d_cliques[ d_aloc_cardinality ], reps );
+    }else{
+      std::vector< Node > clique;
+      clique.push_back( d_cardinality_term );
+      std::vector< std::vector< Node > > cliques;
+      cliques.push_back( clique );
+      d_thss->getSymmetryBreaker()->allocateCardinality( out, d_type, 1, cliques, reps );
     }
   }
+  */
   d_aloc_cardinality = d_aloc_cardinality + 1;
 
   //check for abort case
@@ -1094,7 +1127,7 @@ bool StrongSolverTheoryUF::SortModel::addSplit( Region* r, OutputChannel* out ){
     Trace("uf-ss-lemma") << "*** Split on " << s << std::endl;
     if( options::sortInference()) {
       for( int i=0; i<2; i++ ){
-        int si = d_thss->getTheory()->getQuantifiersEngine()->getTheoryEngine()->getSortInference()->getSortId( s[i] );
+        int si = d_thss->getSortInference()->getSortId( s[i] );
         Trace("uf-ss-split-si") << si << " ";
       }
       Trace("uf-ss-split-si")  << std::endl;
@@ -1122,10 +1155,10 @@ void StrongSolverTheoryUF::SortModel::addCliqueLemma( std::vector< Node >& cliqu
     clique.pop_back();
   }
   //debugging information
-  if( Trace.isOn("uf-ss-cliques") ){
+  if( options::ufssSymBreak() ){
     std::vector< Node > clique_vec;
     clique_vec.insert( clique_vec.begin(), clique.begin(), clique.end() );
-    d_cliques[ d_cardinality ].push_back( clique_vec );
+    addClique( d_cardinality, clique_vec );
   }
   if( options::ufssSimpleCliques() && !options::ufssExplainedCliques() ){
     //add as lemma
@@ -1273,7 +1306,7 @@ void StrongSolverTheoryUF::SortModel::addTotalityAxiom( Node n, int cardinality,
       Node cardLit = d_cardinality_literal[ cardinality ];
       int sort_id = 0;
       if( options::sortInference() ){
-        sort_id = d_thss->getTheory()->getQuantifiersEngine()->getTheoryEngine()->getSortInference()->getSortId(n);
+        sort_id = d_thss->getSortInference()->getSortId(n);
       }
       Trace("uf-ss-totality") << "Add totality lemma for " << n << " " << cardinality << ", sort id is " << sort_id << std::endl;
       int use_cardinality = cardinality;
@@ -1302,6 +1335,14 @@ void StrongSolverTheoryUF::SortModel::addTotalityAxiom( Node n, int cardinality,
   }
 }
 
+void StrongSolverTheoryUF::SortModel::addClique( int c, std::vector< Node >& clique ) {
+
+  if( d_clique_trie[c].add( clique ) ){
+    d_cliques[ c ].push_back( clique );
+  }
+}
+
+
 /** apply totality */
 bool StrongSolverTheoryUF::SortModel::applyTotality( int cardinality ){
   return options::ufssTotality() || cardinality<=options::ufssTotalityLimited();
@@ -1379,22 +1420,16 @@ int StrongSolverTheoryUF::SortModel::getNumRegions(){
 }
 
 void StrongSolverTheoryUF::SortModel::getRepresentatives( std::vector< Node >& reps ){
-  //if( !options::ufssColoringSat() ){
-    bool foundRegion = false;
-    for( int i=0; i<(int)d_regions_index; i++ ){
-      //should not have multiple regions at this point
-      if( foundRegion ){
-        Assert( !d_regions[i]->d_valid );
-      }
-      if( d_regions[i]->d_valid ){
-        //this is the only valid region
-        d_regions[i]->getRepresentatives( reps );
-        foundRegion = true;
-      }
+  for( int i=0; i<(int)d_regions_index; i++ ){
+    //should not have multiple regions at this point
+    //if( foundRegion ){
+    //  Assert( !d_regions[i]->d_valid );
+    //}
+    if( d_regions[i]->d_valid ){
+      //this is the only valid region
+      d_regions[i]->getRepresentatives( reps );
     }
-  //}else{
-  //  Unimplemented("Build representatives for fmf region sat is not implemented");
-  //}
+  }
 }
 
 StrongSolverTheoryUF::StrongSolverTheoryUF(context::Context* c, context::UserContext* u, OutputChannel& out, TheoryUF* th) :
@@ -1415,6 +1450,15 @@ d_rep_model_init( c )
   }else{
     d_deq_prop = NULL;
   }
+  if( options::ufssSymBreak() ){
+    d_sym_break = new SubsortSymmetryBreaker( th->getQuantifiersEngine(), c );
+  }else{
+    d_sym_break = NULL;
+  }
+}
+
+SortInference* StrongSolverTheoryUF::getSortInference() {
+  return d_th->getQuantifiersEngine()->getTheoryEngine()->getSortInference();
 }
 
 /** get default sat context */
@@ -1433,6 +1477,10 @@ void StrongSolverTheoryUF::newEqClass( Node n ){
   if( c ){
     Trace("uf-ss-solver") << "StrongSolverTheoryUF: New eq class " << n << " : " << n.getType() << std::endl;
     c->newEqClass( n );
+    if( options::ufssSymBreak() ){
+      //d_sym_break->newEqClass( n );
+      d_sym_break->queueFact( n );
+    }
   }
 }
 
@@ -1539,6 +1587,10 @@ void StrongSolverTheoryUF::check( Theory::Effort level ){
         break;
       }
     }
+    //check symmetry breaker
+    if( !d_conflict && options::ufssSymBreak() ){
+      d_sym_break->check( level );
+    }
     //disambiguate terms if necessary
     //if( !d_conflict && level==Theory::EFFORT_FULL && options::ufssColoringSat() ){
     //  Assert( d_term_amb!=NULL );
@@ -1644,6 +1696,14 @@ int StrongSolverTheoryUF::getCardinality( Node n ) {
   }
 }
 
+int StrongSolverTheoryUF::getCardinality( TypeNode tn ) {
+  std::map< TypeNode, SortModel* >::iterator it = d_rep_model.find( tn );
+  if( it!=d_rep_model.end() && it->second ){
+    return it->second->getCardinality();
+  }
+  return -1;
+}
+
 void StrongSolverTheoryUF::getRepresentatives( Node n, std::vector< Node >& reps ){
   SortModel* c = getSortModel( n );
   if( c ){
@@ -1698,6 +1758,7 @@ StrongSolverTheoryUF::Statistics::Statistics():
   d_clique_lemmas("StrongSolverTheoryUF::Clique_Lemmas", 0),
   d_split_lemmas("StrongSolverTheoryUF::Split_Lemmas", 0),
   d_disamb_term_lemmas("StrongSolverTheoryUF::Disambiguate_Term_Lemmas", 0),
+  d_sym_break_lemmas("StrongSolverTheoryUF::Symmetry_Breaking_Lemmas", 0),
   d_totality_lemmas("StrongSolverTheoryUF::Totality_Lemmas", 0),
   d_max_model_size("StrongSolverTheoryUF::Max_Model_Size", 1)
 {
@@ -1705,6 +1766,7 @@ StrongSolverTheoryUF::Statistics::Statistics():
   StatisticsRegistry::registerStat(&d_clique_lemmas);
   StatisticsRegistry::registerStat(&d_split_lemmas);
   StatisticsRegistry::registerStat(&d_disamb_term_lemmas);
+  StatisticsRegistry::registerStat(&d_sym_break_lemmas);
   StatisticsRegistry::registerStat(&d_totality_lemmas);
   StatisticsRegistry::registerStat(&d_max_model_size);
 }
@@ -1714,6 +1776,7 @@ StrongSolverTheoryUF::Statistics::~Statistics(){
   StatisticsRegistry::unregisterStat(&d_clique_lemmas);
   StatisticsRegistry::unregisterStat(&d_split_lemmas);
   StatisticsRegistry::unregisterStat(&d_disamb_term_lemmas);
+  StatisticsRegistry::unregisterStat(&d_sym_break_lemmas);
   StatisticsRegistry::unregisterStat(&d_totality_lemmas);
   StatisticsRegistry::unregisterStat(&d_max_model_size);
 }
index fa8d60b491add8ddba7937af13b8d9b0d4be1054..8e568444b99cc2b1fc01102ece4bda75611658fe 100644 (file)
 #include "util/statistics_registry.h"
 
 namespace CVC4 {
+
+class SortInference;
+
 namespace theory {
+
+class SubsortSymmetryBreaker;
+
 namespace uf {
 
 class TheoryUF;
@@ -40,7 +46,6 @@ protected:
   typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeNodeMap;
   typedef context::CDChunkList<Node> NodeList;
   typedef context::CDList<bool> BoolList;
-  typedef context::CDList<bool> IntList;
   typedef context::CDHashMap<TypeNode, bool, TypeNodeHashFunction> TypeNodeBoolMap;
 public:
   /** information for incremental conflict/clique finding for a particular sort */
@@ -201,6 +206,23 @@ public:
     void addCliqueLemma( std::vector< Node >& clique, OutputChannel* out );
     /** add totality axiom */
     void addTotalityAxiom( Node n, int cardinality, OutputChannel* out );
+  private:
+    class NodeTrie {
+      std::map< Node, NodeTrie > d_children;
+    public:
+      bool add( std::vector< Node >& n, unsigned i = 0 ){
+        Assert( i<n.size() );
+        if( i==(n.size()-1) ){
+          bool ret = d_children.find( n[i] )==d_children.end();
+          d_children[n[i]].d_children.clear();
+          return ret;
+        }else{
+          return d_children[n[i]].add( n, i+1 );
+        }
+      }
+    };
+    std::map< int, NodeTrie > d_clique_trie;
+    void addClique( int c, std::vector< Node >& clique );
   private:
     /** Are we in conflict */
     context::CDO<bool> d_conflict;
@@ -286,6 +308,8 @@ private:
   TermDisambiguator* d_term_amb;
   /** disequality propagator */
   DisequalityPropagator* d_deq_prop;
+  /** symmetry breaking techniques */
+  SubsortSymmetryBreaker* d_sym_break;
 public:
   StrongSolverTheoryUF(context::Context* c, context::UserContext* u, OutputChannel& out, TheoryUF* th);
   ~StrongSolverTheoryUF() {}
@@ -295,6 +319,10 @@ public:
   TermDisambiguator* getTermDisambiguator() { return d_term_amb; }
   /** disequality propagator */
   DisequalityPropagator* getDisequalityPropagator() { return d_deq_prop; }
+  /** symmetry breaker */
+  SubsortSymmetryBreaker* getSymmetryBreaker() { return d_sym_break; }
+  /** get sort inference module */
+  SortInference* getSortInference();
   /** get default sat context */
   context::Context* getSatContext();
   /** get default output channel */
@@ -336,8 +364,10 @@ public:
   TypeNode getCardinalityType( int i ) { return d_conf_types[i]; }
   /** get is in conflict */
   bool isConflict() { return d_conflict; }
-  /** get cardinality for sort */
+  /** get cardinality for node */
   int getCardinality( Node n );
+  /** get cardinality for type */
+  int getCardinality( TypeNode tn );
   /** get representatives */
   void getRepresentatives( Node n, std::vector< Node >& reps );
   /** minimize */
@@ -349,6 +379,7 @@ public:
     IntStat d_clique_lemmas;
     IntStat d_split_lemmas;
     IntStat d_disamb_term_lemmas;
+    IntStat d_sym_break_lemmas;
     IntStat d_totality_lemmas;
     IntStat d_max_model_size;
     Statistics();
index 13631e590a32c3dd3ed650194571bf24d0927ec0..a4c34faec6ef32b0c552f6589934f082e461d04a 100644 (file)
@@ -27,8 +27,55 @@ using namespace std;
 namespace CVC4 {
 
 
+void SortInference::UnionFind::print(const char * c){
+  for( std::map< int, int >::iterator it = d_eqc.begin(); it != d_eqc.end(); ++it ){
+    Trace(c) << "s_" << it->first << " = s_" << it->second << ", ";
+  }
+  for( unsigned i=0; i<d_deq.size(); i++ ){
+    Trace(c) << "s_" << d_deq[i].first << " != s_" << d_deq[i].second << ", ";
+  }
+  Trace(c) << std::endl;
+}
+void SortInference::UnionFind::set( UnionFind& c ) {
+  clear();
+  for( std::map< int, int >::iterator it = c.d_eqc.begin(); it != c.d_eqc.end(); ++it ){
+    d_eqc[ it->first ] = it->second;
+  }
+  d_deq.insert( d_deq.end(), c.d_deq.begin(), c.d_deq.end() );
+}
+int SortInference::UnionFind::getRepresentative( int t ){
+  std::map< int, int >::iterator it = d_eqc.find( t );
+  if( it==d_eqc.end() || it->second==t ){
+    return t;
+  }else{
+    int rt = getRepresentative( it->second );
+    d_eqc[t] = rt;
+    return rt;
+  }
+}
+void SortInference::UnionFind::setEqual( int t1, int t2 ){
+  if( t1!=t2 ){
+    int rt1 = getRepresentative( t1 );
+    int rt2 = getRepresentative( t2 );
+    if( rt1>rt2 ){
+      d_eqc[rt1] = rt2;
+    }else{
+      d_eqc[rt2] = rt1;
+    }
+  }
+}
+bool SortInference::UnionFind::isValid() {
+  for( unsigned i=0; i<d_deq.size(); i++ ){
+    if( areEqual( d_deq[i].first, d_deq[i].second ) ){
+      return false;
+    }
+  }
+  return true;
+}
+
+
 void SortInference::printSort( const char* c, int t ){
-  int rt = getRepresentative( t );
+  int rt = d_type_union_find.getRepresentative( t );
   if( d_type_types.find( rt )!=d_type_types.end() ){
     Trace(c) << d_type_types[rt];
   }else{
@@ -83,46 +130,19 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doRewrite ){
       //add lemma enforcing introduced constants to be distinct?
     }
   }
-}
-
-int SortInference::getRepresentative( int t ){
-  std::map< int, int >::iterator it = d_type_union_find.find( t );
-  if( it!=d_type_union_find.end() ){
-    if( it->second==t ){
-      return t;
-    }else{
-      int rt = getRepresentative( it->second );
-      d_type_union_find[t] = rt;
-      return rt;
-    }
-  }else{
-    return t;
-  }
+  initialSortCount = sortCount;
 }
 
 void SortInference::setEqual( int t1, int t2 ){
   if( t1!=t2 ){
-    int rt1 = getRepresentative( t1 );
-    int rt2 = getRepresentative( t2 );
+    int rt1 = d_type_union_find.getRepresentative( t1 );
+    int rt2 = d_type_union_find.getRepresentative( t2 );
     if( rt1!=rt2 ){
       Trace("sort-inference-debug") << "Set equal : ";
       printSort( "sort-inference-debug", rt1 );
       Trace("sort-inference-debug") << " ";
       printSort( "sort-inference-debug", rt2 );
       Trace("sort-inference-debug") << std::endl;
-      //check if they must be a type
-      std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
-      std::map< int, TypeNode >::iterator it2 = d_type_types.find( rt2 );
-      if( it2!=d_type_types.end() ){
-        if( it1==d_type_types.end() ){
-          //swap sides
-          int swap = rt1;
-          rt1 = rt2;
-          rt2 = swap;
-        }else{
-          Assert( rt1==rt2 );
-        }
-      }
       /*
       d_type_eq_class[rt1].insert( d_type_eq_class[rt1].end(), d_type_eq_class[rt2].begin(), d_type_eq_class[rt2].end() );
       d_type_eq_class[rt2].clear();
@@ -132,7 +152,19 @@ void SortInference::setEqual( int t1, int t2 ){
       }
       Trace("sort-inference-debug") << "}" << std::endl;
       */
-      d_type_union_find[rt2] = rt1;
+      if( rt2>rt1 ){
+        //swap
+        int swap = rt1;
+        rt1 = rt2;
+        rt2 = swap;
+      }
+      d_type_union_find.d_eqc[rt1] = rt2;
+      std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
+      if( it1!=d_type_types.end() ){
+        Assert( d_type_types.find( rt2 )==d_type_types.end() );
+        d_type_types[rt2] = it1->second;
+        d_type_types.erase( rt1 );
+      }
     }
   }
 }
@@ -155,14 +187,17 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound ){
   Trace("sort-inference-debug") << "Process " << n << std::endl;
   //add to variable bindings
   if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-    for( size_t i=0; i<n[0].getNumChildren(); i++ ){
-      //TODO: try applying sort inference to quantified variables
-      d_var_types[n][ n[0][i] ] = sortCount;
-      sortCount++;
+    if( d_var_types.find( n )!=d_var_types.end() ){
+      return getIdForType( n.getType() );
+    }else{
+      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
+        //apply sort inference to quantified variables
+        d_var_types[n][ n[0][i] ] = sortCount;
+        sortCount++;
 
-      //type of the quantified variable must be the same
-      //d_var_types[n][ n[0][i] ] = getIdForType( n[0][i].getType() );
-      var_bound[ n[0][i] ] = n;
+        //type of the quantified variable must be the same
+        var_bound[ n[0][i] ] = n;
+      }
     }
   }
 
@@ -191,10 +226,10 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound ){
   int retType;
   if( n.getKind()==kind::EQUAL ){
     //we only require that the left and right hand side must be equal
-    //setEqual( child_types[0], child_types[1] );
-    int eqType = getIdForType( n[0].getType() );
-    setEqual( child_types[0], eqType );
-    setEqual( child_types[1], eqType );
+    setEqual( child_types[0], child_types[1] );
+    //int eqType = getIdForType( n[0].getType() );
+    //setEqual( child_types[0], eqType );
+    //setEqual( child_types[1], eqType );
     retType = getIdForType( n.getType() );
   }else if( n.getKind()==kind::APPLY_UF ){
     Node op = n.getOperator();
@@ -256,7 +291,7 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound ){
 
 
 TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
-  int rt = getRepresentative( t );
+  int rt = d_type_union_find.getRepresentative( t );
   if( d_type_types.find( rt )!=d_type_types.end() ){
     return d_type_types[rt];
   }else{
@@ -281,7 +316,7 @@ TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
 }
 
 TypeNode SortInference::getTypeForId( int t ){
-  int rt = getRepresentative( t );
+  int rt = d_type_union_find.getRepresentative( t );
   if( d_type_types.find( rt )!=d_type_types.end() ){
     return d_type_types[rt];
   }else{
@@ -417,15 +452,71 @@ Node SortInference::simplify( Node n, std::map< Node, Node >& var_bound ){
 }
 int SortInference::getSortId( Node n ) {
   Node op = n.getKind()==kind::APPLY_UF ? n.getOperator() : n;
-  return getRepresentative( d_op_return_types[op] );
+  if( d_op_return_types.find( op )!=d_op_return_types.end() ){
+    return d_type_union_find.getRepresentative( d_op_return_types[op] );
+  }else{
+    return 0;
+  }
 }
 
 int SortInference::getSortId( Node f, Node v ) {
-  return getRepresentative( d_var_types[f][v] );
+  if( d_var_types.find( f )!=d_var_types.end() ){
+    return d_type_union_find.getRepresentative( d_var_types[f][v] );
+  }else{
+    return 0;
+  }
 }
 
 void SortInference::setSkolemVar( Node f, Node v, Node sk ){
+  Trace("sort-inference-temp") << "Set skolem var for " << f << ", variable " << v << std::endl;
+  if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
+    std::map< Node, Node > var_bound;
+    process( f, var_bound );
+  }
   d_op_return_types[sk] = getSortId( f, v );
+  Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
+}
+
+bool SortInference::isWellSortedFormula( Node n ) {
+  if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
+    for( unsigned i=0; i<n.getNumChildren(); i++ ){
+      if( !isWellSortedFormula( n[i] ) ){
+        return false;
+      }
+    }
+    return true;
+  }else{
+    return isWellSorted( n );
+  }
+}
+
+bool SortInference::isWellSorted( Node n ) {
+  if( getSortId( n )==0 ){
+    return false;
+  }else{
+    if( n.getKind()==kind::APPLY_UF ){
+      for( unsigned i=0; i<n.getNumChildren(); i++ ){
+        int s1 = getSortId( n[i] );
+        int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
+        if( s1!=s2 ){
+          return false;
+        }
+        if( !isWellSorted( n[i] ) ){
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+}
+
+void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
+  if( n.getKind()==kind::APPLY_UF ){
+    for( unsigned i=0; i<n.getNumChildren(); i++ ){
+      getSortConstraints( n[i], uf );
+      uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
+    }
+  }
 }
 
 }/* CVC4 namespace */
index 1bcb8a2088dec9a30b6a5a4eb64636efb4acbe83..53dff823fa858b0435ac648da9ea5f00beda3017 100644 (file)
@@ -30,9 +30,29 @@ class SortInference{
 private:
   //for debugging
   //std::map< int, std::vector< Node > > d_type_eq_class;
+public:
+  class UnionFind {
+  public:
+    UnionFind(){}
+    UnionFind( UnionFind& c ){
+      set( c );
+    }
+    std::map< int, int > d_eqc;
+    //pairs that must be disequal
+    std::vector< std::pair< int, int > > d_deq;
+    void print(const char * c);
+    void clear() { d_eqc.clear(); d_deq.clear(); }
+    void set( UnionFind& c );
+    int getRepresentative( int t );
+    void setEqual( int t1, int t2 );
+    void setDisequal( int t1, int t2 ){ d_deq.push_back( std::pair< int, int >( t1, t2 ) ); }
+    bool areEqual( int t1, int t2 ) { return getRepresentative( t1 )==getRepresentative( t2 ); }
+    bool isValid();
+  };
 private:
   int sortCount;
-  std::map< int, int > d_type_union_find;
+  int initialSortCount;
+  UnionFind d_type_union_find;
   std::map< int, TypeNode > d_type_types;
   std::map< TypeNode, int > d_id_for_types;
   //for apply uf operators
@@ -41,7 +61,6 @@ private:
   //for bound variables
   std::map< Node, std::map< Node, int > > d_var_types;
   //get representative
-  int getRepresentative( int t );
   void setEqual( int t1, int t2 );
   int getIdForType( TypeNode tn );
   void printSort( const char* c, int t );
@@ -61,14 +80,22 @@ private:
   //simplify
   Node simplify( Node n, std::map< Node, Node >& var_bound );
 public:
-  SortInference() : sortCount( 0 ){}
+  SortInference() : sortCount( 1 ){}
   ~SortInference(){}
 
   void simplify( std::vector< Node >& assertions, bool doRewrite = false );
+  //get sort id for term n
   int getSortId( Node n );
+  //get sort id for variable of quantified formula f
   int getSortId( Node f, Node v );
   //set that sk is the skolem variable of v for quantifier f
   void setSkolemVar( Node f, Node v, Node sk );
+public:
+  //is well sorted
+  bool isWellSortedFormula( Node n );
+  bool isWellSorted( Node n );
+  //get constraints for being well-typed according to computed sub-types
+  void getSortConstraints( Node n, SortInference::UnionFind& uf );
 };
 
 }