#include "theory/quantifiers/fun_def_process.h"
#include "theory/rewriter.h"
+#include "theory/quantifiers/term_database.h"
+#include "theory/quantifiers/quant_util.h"
using namespace CVC4;
using namespace std;
void FunDefFmf::simplify( std::vector< Node >& assertions, bool doRewrite ) {
+ std::vector< int > fd_assertions;
+ //first pass : find defined functions, transform quantifiers
+ for( unsigned i=0; i<assertions.size(); i++ ){
+ if( assertions[i].getKind()==FORALL ){
+ if( quantifiers::TermDb::isFunDef( assertions[i] ) ){
+ Assert( assertions[i][1].getKind()==EQUAL || assertions[i][1].getKind()==IFF );
+ Node n = assertions[i][1][0];
+ Assert( n.getKind()==APPLY_UF );
+ Node f = n.getOperator();
+
+ //create a sort S that represents the inputs of the function
+ std::stringstream ss;
+ ss << "I_" << f;
+ TypeNode iType = NodeManager::currentNM()->mkSort( ss.str() );
+ d_sorts[f] = iType;
+
+ //create functions f1...fn mapping from this sort to concrete elements
+ for( unsigned j=0; j<n.getNumChildren(); j++ ){
+ TypeNode typ = NodeManager::currentNM()->mkFunctionType( iType, n[j].getType() );
+ std::stringstream ss;
+ ss << f << "_arg_" << j;
+ d_input_arg_inj[f].push_back( NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during fun def fmf" ) );
+ }
+
+ //construct new quantifier forall S. F[f1(S)/x1....fn(S)/xn]
+ std::vector< Node > children;
+ Node bv = NodeManager::currentNM()->mkBoundVar("?i", iType );
+ Node bvl = NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, bv );
+ std::vector< Node > subs;
+ std::vector< Node > vars;
+ for( unsigned j=0; j<n.getNumChildren(); j++ ){
+ vars.push_back( n[j] );
+ subs.push_back( NodeManager::currentNM()->mkNode( APPLY_UF, d_input_arg_inj[f][j], bv ) );
+ }
+ Node bd = assertions[i][1].substitute( vars.begin(), vars.end(), subs.begin(), subs.end() );
+
+ Trace("fmf-fun-def") << "FMF fun def: rewrite " << assertions[i] << " to ";
+ assertions[i] = NodeManager::currentNM()->mkNode( FORALL, bvl, bd );
+ Trace("fmf-fun-def") << assertions[i] << std::endl;
+ fd_assertions.push_back( i );
+ }
+ }
+ }
+ //second pass : rewrite assertions
+ for( unsigned i=0; i<assertions.size(); i++ ){
+ bool is_fd = std::find( fd_assertions.begin(), fd_assertions.end(), i )!=fd_assertions.end();
+ std::vector< Node > constraints;
+ Node n = simplify( assertions[i], true, true, constraints, is_fd );
+ Assert( constraints.empty() );
+ if( n!=assertions[i] ){
+ n = Rewriter::rewrite( n );
+ Trace("fmf-fun-def-rewrite") << "FMF fun def : rewrite " << assertions[i] << " to " << n << std::endl;
+ assertions[i] = n;
+ }
+ }
+}
+
+Node FunDefFmf::simplify( Node n, bool pol, bool hasPol, std::vector< Node >& constraints, bool is_fun_def ) {
+ Trace("fmf-fun-def-debug") << "Simplify " << n << " " << pol << " " << hasPol << " " << is_fun_def << std::endl;
+ if( n.getKind()==FORALL ){
+ Node c = simplify( n[1], pol, hasPol, constraints, is_fun_def );
+ if( c!=n[1] ){
+ return NodeManager::currentNM()->mkNode( FORALL, n[0], c );
+ }else{
+ return n;
+ }
+ }else if( n.getType().isBoolean() && n.getKind()!=APPLY_UF ){
+ std::vector< Node > children;
+ bool childChanged = false;
+ for( unsigned i=0; i<n.getNumChildren(); i++ ){
+ Node c = n[i];
+ //do not process LHS of definition
+ if( !is_fun_def || i!=0 ){
+ bool newHasPol;
+ bool newPol;
+ QuantPhaseReq::getPolarity( n, i, hasPol, pol, newHasPol, newPol );
+ //get child constraints
+ std::vector< Node > cconstraints;
+ c = simplify( n[i], newPol, newHasPol, cconstraints );
+ constraints.insert( constraints.end(), cconstraints.begin(), cconstraints.end() );
+ }
+ children.push_back( c );
+ childChanged = c!=n[i] || childChanged;
+ }
+ if( !constraints.empty() || childChanged ){
+ std::vector< Node > c;
+ if( childChanged ){
+ c.push_back( NodeManager::currentNM()->mkNode( n.getKind(), children ) );
+ }else{
+ c.push_back( n );
+ }
+ if( hasPol ){
+ //conjoin with current
+ for( unsigned i=0; i<constraints.size(); i++ ){
+ if( pol ){
+ c.push_back( constraints[i] );
+ }else{
+ c.push_back( constraints[i].negate() );
+ }
+ }
+ constraints.clear();
+ }else{
+ //must add at higher level
+ }
+ return c.size()==1 ? c[0] : NodeManager::currentNM()->mkNode( AND, c );
+ }
+ }else{
+ //simplify term
+ simplifyTerm( n, constraints );
+ }
+ return n;
+}
+void FunDefFmf::simplifyTerm( Node n, std::vector< Node >& constraints ) {
+ Trace("fmf-fun-def-debug") << "Simplify term " << n << std::endl;
+ if( n.getKind()==ITE ){
+ simplifyTerm( n[0], constraints );
+ std::vector< Node > ccons1;
+ std::vector< Node > ccons2;
+ simplifyTerm( n[1], ccons1 );
+ simplifyTerm( n[2], ccons2 );
+ if( !ccons1.empty() || !ccons2.empty() ){
+ Node n1 = ccons1.empty() ? NodeManager::currentNM()->mkConst( true ) : ( ccons1.size()==1 ? ccons1[0] : NodeManager::currentNM()->mkNode( AND, ccons1 ) );
+ Node n2 = ccons2.empty() ? NodeManager::currentNM()->mkConst( true ) : ( ccons2.size()==1 ? ccons2[0] : NodeManager::currentNM()->mkNode( AND, ccons2 ) );
+ constraints.push_back( NodeManager::currentNM()->mkNode( ITE, n[0], n1, n2 ) );
+ }
+ }else{
+ if( n.getKind()==APPLY_UF ){
+ //check if f is defined, if so, we must enforce domain constraints for this f-application
+ Node f = n.getOperator();
+ std::map< Node, TypeNode >::iterator it = d_sorts.find( f );
+ if( it!=d_sorts.end() ){
+ //create existential
+ Node z = NodeManager::currentNM()->mkBoundVar("?z", it->second );
+ Node bvl = NodeManager::currentNM()->mkNode( BOUND_VAR_LIST, z );
+ std::vector< Node > children;
+ for( unsigned j=0; j<n.getNumChildren(); j++ ){
+ Node uz = NodeManager::currentNM()->mkNode( APPLY_UF, d_input_arg_inj[f][j], z );
+ if( !n[j].getType().isBoolean() ){
+ children.push_back( uz.eqNode( n[j] ) );
+ }else{
+ children.push_back( uz.iffNode( n[j] ) );
+ }
+ }
+ Node bd = children.size()==1 ? children[0] : NodeManager::currentNM()->mkNode( AND, children );
+ bd = bd.negate();
+ Node ex = NodeManager::currentNM()->mkNode( FORALL, bvl, bd );
+ ex = ex.negate();
+ constraints.push_back( ex );
+ Trace("fmf-fun-def-debug") << "---> add constraint " << ex << std::endl;
+ }
+ }
+ for( unsigned i=0; i<n.getNumChildren(); i++ ){
+ simplifyTerm( n[i], constraints );
+ }
+ }
}