From ad0863ae8333c4dcd950153e0db8cd4565a250b3 Mon Sep 17 00:00:00 2001 From: ajreynol Date: Fri, 12 Jun 2015 14:15:14 +0200 Subject: [PATCH] Accelerate sygus solution reconstruction for constants and id functions. Minor changes to sygus type registration. Print sygus let solutions assuming fixed variable names. --- src/parser/smt2/smt2.cpp | 15 +- .../quantifiers/ce_guided_single_inv_sol.cpp | 218 ++++++++++-------- src/theory/quantifiers/term_database.cpp | 165 ++++++++++++- src/theory/quantifiers/term_database.h | 15 +- src/util/datatype.cpp | 5 + src/util/datatype.h | 2 + 6 files changed, 304 insertions(+), 116 deletions(-) diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index 47022da3e..8ed8e40a1 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -1001,17 +1001,22 @@ void Smt2::mkSygusDatatype( CVC4::Datatype& dt, std::vector& ops, void Smt2::addSygusDatatypeConstructor( CVC4::Datatype& dt, CVC4::Expr op, std::string& cname, std::vector< CVC4::Type >& cargs, CVC4::Expr& let_body, std::vector< CVC4::Expr >& let_args, unsigned let_num_input_args ) { - + Debug("parser-sygus") << "--> Add constructor " << cname << " to " << dt.getName() << std::endl; + if( !let_body.isNull() ){ + Debug("parser-sygus") << " let body = " << let_body << ", args = " << let_args.size() << "," << let_num_input_args << std::endl; + //TODO : remove arguments not occurring in body + //if this is a self identity function, ignore + if( let_args.size()==0 && let_args[0]==let_body ){ + Debug("parser-sygus") << " identity function " << cargs[0] << " to " << dt.getName() << std::endl; + //TODO + } + } std::string name = dt.getName() + "_" + cname; std::string testerId("is-"); testerId.append(name); checkDeclaration(name, CHECK_UNDECLARED, SYM_VARIABLE); checkDeclaration(testerId, CHECK_UNDECLARED, SYM_VARIABLE); CVC4::DatatypeConstructor c(name, testerId ); - Debug("parser-sygus") << "--> Add constructor " << cname << " to " << dt.getName() << std::endl; - if( !let_body.isNull() ){ - Debug("parser-sygus") << " let body = " << let_body << ", args = " << let_args.size() << "," << let_num_input_args << std::endl; - } c.setSygus( op, let_body, let_args, let_num_input_args ); for( unsigned j=0; jgetTermDatabaseSygus()->getOpArg( stn, min_t ); if( carg!=-1 ){ - Trace("csi-rcons-debug") << " Type has constant." << std::endl; + Trace("csi-rcons-debug") << " Type has operator." << std::endl; d_reconstruct[id] = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[carg].getConstructor() ) ); status = 0; }else{ @@ -754,117 +754,140 @@ int CegConjectureSingleInvSol::collectReconstructNodes( Node t, TypeNode stn, in Trace("csi-rcons-debug") << "Type for " << id << " has kind " << min_t.getKind() << ", but argument # mismatch." << std::endl; } } - } - if( status!=0 ){ - if( min_t.isConst() ){ - Node min_t_c = d_qe->getTermDatabaseSygus()->builtinToSygusConst( min_t, stn ); - if( !min_t_c.isNull() ){ - d_reconstruct[id] = min_t_c; - status = 0; - } - } if( status!=0 ){ - Trace("csi-rcons-debug") << "Try matching for " << id << "." << std::endl; - //try other options - //match against other constructors - bool success; - int c_index = 0; - do{ - success = false; - int index_found; - std::vector< Node > args; - if( d_qe->getTermDatabaseSygus()->getMatch( min_t, stn, index_found, args, karg, c_index ) ){ - success = true; + //try constant reconstruction + if( min_t.isConst() ){ + Trace("csi-rcons-debug") << "...try constant reconstruction." << std::endl; + Node min_t_c = d_qe->getTermDatabaseSygus()->builtinToSygusConst( min_t, stn ); + if( !min_t_c.isNull() ){ + Trace("csi-rcons-debug") << " constant reconstruction success for " << id << ", result = " << min_t_c << std::endl; + d_reconstruct[id] = min_t_c; status = 0; - Node cons = Node::fromExpr( dt[index_found].getConstructor() ); - Trace("csi-rcons-debug") << "Try alternative for " << id << ", matching " << dt[index_found].getName() << " with children : " << std::endl; - for( unsigned i=0; igetTermDatabaseSygus()->getNumIdFuncs( stn ); i++ ){ + unsigned ii = d_qe->getTermDatabaseSygus()->getIdFuncIndex( stn, i ); + Assert( dt[ii].getNumArgs()==1 ); + //try to directly reconstruct from single argument + std::vector< Node > tchildren; + tchildren.push_back( min_t ); + TypeNode stnc = TypeNode::fromType( ((SelectorType)dt[ii][0].getType()).getRangeType() ); + Trace("csi-rcons-debug") << "...try identity function " << dt[ii].getSygusOp() << ", child type is " << stnc << std::endl; + status = 0; + Node cons = Node::fromExpr( dt[ii].getConstructor() ); + if( !collectReconstructNodes( id, tchildren, dt[ii], d_reconstruct_op[id][cons], status ) ){ d_reconstruct_op[id].erase( cons ); status = 1; }else{ - c_index = index_found+1; + Trace("csi-rcons-debug") << " identity function success for " << id << std::endl; + break; } } - }while( success && status!=0 ); - - if( status!=0 ){ - // construct an equivalence class of terms that are equivalent to t - if( d_rep[id]==id ){ - Trace("csi-rcons-debug") << "Try rewriting for " << id << "." << std::endl; - //get equivalence class of term - std::vector< Node > equiv; - if( tn.isBoolean() ){ - Node curr = min_t; - Node new_t; - do{ - new_t = Node::null(); - if( curr.getKind()==EQUAL && ( curr[0].getType().isInteger() || curr[0].getType().isReal() ) ){ - new_t = NodeManager::currentNM()->mkNode( AND, NodeManager::currentNM()->mkNode( LEQ, curr[0], curr[1] ), - NodeManager::currentNM()->mkNode( LEQ, curr[1], curr[0] ) ); - }else if( curr.getKind()==ITE ){ - new_t = NodeManager::currentNM()->mkNode( OR, NodeManager::currentNM()->mkNode( AND, curr[0], curr[1] ), - NodeManager::currentNM()->mkNode( AND, curr[0].negate(), curr[2] ) ); - }else if( curr.getKind()==IFF ){ - new_t = NodeManager::currentNM()->mkNode( OR, NodeManager::currentNM()->mkNode( AND, curr[0], curr[1] ), - NodeManager::currentNM()->mkNode( AND, curr[0].negate(), curr[1].negate() ) ); - }else if( curr.getKind()==OR || curr.getKind()==AND ){ - new_t = TermDb::simpleNegate( curr ).negate(); - }else if( curr.getKind()==NOT ){ - new_t = TermDb::simpleNegate( curr[0] ); + if( status!=0 ){ + //try other options, such as matching against other constructors + Trace("csi-rcons-debug") << "Try matching for " << id << "." << std::endl; + bool success; + int c_index = 0; + do{ + success = false; + int index_found; + std::vector< Node > args; + if( d_qe->getTermDatabaseSygus()->getMatch( min_t, stn, index_found, args, karg, c_index ) ){ + success = true; + status = 0; + Node cons = Node::fromExpr( dt[index_found].getConstructor() ); + Trace("csi-rcons-debug") << "Try alternative for " << id << ", matching " << dt[index_found].getName() << " with children : " << std::endl; + for( unsigned i=0; imkNode( NOT, NodeManager::currentNM()->mkNode( NOT, curr ) ); + c_index = index_found+1; } - if( !new_t.isNull() ){ - if( new_t!=min_t && std::find( equiv.begin(), equiv.end(), new_t )==equiv.end() ){ - curr = new_t; - equiv.push_back( new_t ); - }else{ + } + }while( success && status!=0 ); + + if( status!=0 ){ + // construct an equivalence class of terms that are equivalent to t + if( d_rep[id]==id ){ + Trace("csi-rcons-debug") << "Try rewriting for " << id << "." << std::endl; + //get equivalence class of term + std::vector< Node > equiv; + if( tn.isBoolean() ){ + Node curr = min_t; + Node new_t; + do{ new_t = Node::null(); + if( curr.getKind()==EQUAL && ( curr[0].getType().isInteger() || curr[0].getType().isReal() ) ){ + new_t = NodeManager::currentNM()->mkNode( AND, NodeManager::currentNM()->mkNode( LEQ, curr[0], curr[1] ), + NodeManager::currentNM()->mkNode( LEQ, curr[1], curr[0] ) ); + }else if( curr.getKind()==ITE ){ + new_t = NodeManager::currentNM()->mkNode( OR, NodeManager::currentNM()->mkNode( AND, curr[0], curr[1] ), + NodeManager::currentNM()->mkNode( AND, curr[0].negate(), curr[2] ) ); + }else if( curr.getKind()==IFF ){ + new_t = NodeManager::currentNM()->mkNode( OR, NodeManager::currentNM()->mkNode( AND, curr[0], curr[1] ), + NodeManager::currentNM()->mkNode( AND, curr[0].negate(), curr[1].negate() ) ); + }else if( curr.getKind()==OR || curr.getKind()==AND ){ + new_t = TermDb::simpleNegate( curr ).negate(); + }else if( curr.getKind()==NOT ){ + new_t = TermDb::simpleNegate( curr[0] ); + }else{ + new_t = NodeManager::currentNM()->mkNode( NOT, NodeManager::currentNM()->mkNode( NOT, curr ) ); + } + if( !new_t.isNull() ){ + if( new_t!=min_t && std::find( equiv.begin(), equiv.end(), new_t )==equiv.end() ){ + curr = new_t; + equiv.push_back( new_t ); + }else{ + new_t = Node::null(); + } + } + }while( !new_t.isNull() ); + } + for( unsigned i=0; igetTermDatabaseSygus()->getArgKind( stn, i ); + if( k==AND || k==OR ){ + equiv.push_back( NodeManager::currentNM()->mkNode( k, min_t, min_t ) ); + equiv.push_back( NodeManager::currentNM()->mkNode( k, min_t, NodeManager::currentNM()->mkConst( k==AND ) ) ); } } - }while( !new_t.isNull() ); - } - for( unsigned i=0; igetTermDatabaseSygus()->getArgKind( stn, i ); - if( k==AND || k==OR ){ - equiv.push_back( NodeManager::currentNM()->mkNode( k, min_t, min_t ) ); - equiv.push_back( NodeManager::currentNM()->mkNode( k, min_t, NodeManager::currentNM()->mkConst( k==AND ) ) ); - } - } - //assign ids to terms - Trace("csi-rcons-debug") << "Term " << id << " is equivalent to " << equiv.size() << " terms : " << std::endl; - std::vector< int > equiv_ids; - for( unsigned i=0; i equiv_ids; + for( unsigned i=0; i::iterator it = pre.find( i ); if( it!=pre.end() ){ @@ -1589,14 +1592,18 @@ Node TermDbSygus::sygusToBuiltin( Node n, TypeNode tn ) { } } -Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn ) { +//rcons_depth limits the number of recursive calls when doing accelerated constant reconstruction (currently limited to 1000) +//this is hacky : depending upon order of calls, constant rcons may succeed, e.g. 1001, 999 vs. 999, 1001 +Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn, int rcons_depth ) { std::map< Node, Node >::iterator it = d_builtin_const_to_sygus[tn].find( c ); if( it==d_builtin_const_to_sygus[tn].end() ){ + Node sc; + d_builtin_const_to_sygus[tn][c] = sc; Assert( c.isConst() ); Assert( datatypes::DatatypesRewriter::isTypeDatatype( tn ) ); const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + Trace("csi-rcons-debug") << "Try to reconstruct " << c << " in " << dt.getName() << std::endl; Assert( dt.isSygus() ); - Node sc; // if we are not interested in reconstructing constants, or the grammar allows them, return a proxy if( !options::cegqiSingleInvReconstructConst() || dt.getSygusAllowConst() ){ Node k = NodeManager::currentNM()->mkSkolem( "sy", tn, "sygus proxy" ); @@ -1606,9 +1613,60 @@ Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn ) { }else{ int carg = getOpArg( tn, c ); if( carg!=-1 ){ - sc = Node::fromExpr( dt[carg].getSygusOp() ); + //sc = Node::fromExpr( dt[carg].getSygusOp() ); + sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[carg].getConstructor() ) ); }else{ - //TODO + //identity functions + for( unsigned i=0; imkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[ii].getConstructor() ), n ); + break; + } + } + if( sc.isNull() ){ + if( rcons_depth<1000 ){ + //accelerated, recursive reconstruction of constants + Kind pk = getPlusKind( TypeNode::fromType( dt.getSygusType() ) ); + if( pk!=UNDEFINED_KIND ){ + int arg = getKindArg( tn, pk ); + if( arg!=-1 ){ + Kind ck = getComparisonKind( TypeNode::fromType( dt.getSygusType() ) ); + Kind pkm = getPlusKind( TypeNode::fromType( dt.getSygusType() ), true ); + //get types + Assert( dt[arg].getNumArgs()==2 ); + TypeNode tn1 = getArgType( dt[arg], 0 ); + TypeNode tn2 = getArgType( dt[arg], 1 ); + //iterate over all positive constants, largest to smallest + int start = d_const_list[tn1].size()-1; + int end = d_const_list[tn1].size()-d_const_list_pos[tn1]; + for( int i=start; i>=end; --i ){ + Node c1 = d_const_list[tn1][i]; + //only consider if smaller than c, and + if( doCompare( c1, c, ck ) ){ + Node c2 = NodeManager::currentNM()->mkNode( pkm, c, c1 ); + c2 = Rewriter::rewrite( c2 ); + if( c2.isConst() ){ + //reconstruct constant on the other side + Node sc2 = builtinToSygusConst( c2, tn2, rcons_depth+1 ); + if( !sc2.isNull() ){ + Node sc1 = builtinToSygusConst( c1, tn1, rcons_depth ); + Assert( !sc1.isNull() ); + sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[arg].getConstructor() ), sc1, sc2 ); + break; + } + } + } + } + } + } + } + } } } d_builtin_const_to_sygus[tn][c] = sc; @@ -1798,7 +1856,7 @@ Node TermDbSygus::getTypeValue( TypeNode tn, int val ) { n = NodeManager::currentNM()->mkConst(bval); }else if( tn.isBoolean() ){ if( val==0 ){ - n = NodeManager::currentNM()->mkConst( false ); + n = d_false; } } d_type_value[tn][val] = n; @@ -1815,7 +1873,7 @@ Node TermDbSygus::getTypeMaxValue( TypeNode tn ) { if( tn.isBitVector() ){ n = bv::utils::mkOnes(tn.getConst()); }else if( tn.isBoolean() ){ - n = NodeManager::currentNM()->mkConst( true ); + n = d_true; } d_type_max_value[tn] = n; return n; @@ -1847,6 +1905,18 @@ Node TermDbSygus::getTypeValueOffset( TypeNode tn, Node val, int offset, int& st } } +struct sortConstants { + TermDbSygus * d_tds; + Kind d_comp_kind; + bool operator() (Node i, Node j) { + if( i!=j ){ + return d_tds->doCompare( i, j, d_comp_kind ); + }else{ + return false; + } + } +}; + void TermDbSygus::registerSygusType( TypeNode tn ){ if( d_register.find( tn )==d_register.end() ){ if( !datatypes::DatatypesRewriter::isTypeDatatype( tn ) ){ @@ -1858,6 +1928,11 @@ void TermDbSygus::registerSygusType( TypeNode tn ){ if( d_register[tn].isNull() ){ Trace("sygus-util") << "...not sygus." << std::endl; }else{ + //for constant reconstruction + Kind ck = getComparisonKind( TypeNode::fromType( dt.getSygusType() ) ); + Node z = getTypeValue( TypeNode::fromType( dt.getSygusType() ), 0 ); + d_const_list_pos[tn] = 0; + //iterate over constructors for( unsigned i=0; i=0 && i<(int)c.getNumArgs() ); return TypeNode::fromType( ((SelectorType)c[i].getType()).getRangeType() ); @@ -2041,7 +2152,33 @@ Node TermDbSygus::expandBuiltinTerm( Node t ){ } -void doReplace(std::string& str, const std::string& oldStr, const std::string& newStr){ +Kind TermDbSygus::getComparisonKind( TypeNode tn ) { + if( tn.isInteger() ){ + return LT; + }else if( tn.isBitVector() ){ + return BITVECTOR_ULT; + }else{ + return UNDEFINED_KIND; + } +} + +Kind TermDbSygus::getPlusKind( TypeNode tn, bool is_neg ) { + if( tn.isInteger() ){ + return is_neg ? MINUS : PLUS; + }else if( tn.isBitVector() ){ + return is_neg ? BITVECTOR_SUB : BITVECTOR_PLUS; + }else{ + return UNDEFINED_KIND; + } +} + +bool TermDbSygus::doCompare( Node a, Node b, Kind k ) { + Node com = NodeManager::currentNM()->mkNode( k, a, b ); + com = Rewriter::rewrite( com ); + return com==d_true; +} + +void doStrReplace(std::string& str, const std::string& oldStr, const std::string& newStr){ size_t pos = 0; while((pos = str.find(oldStr, pos)) != std::string::npos){ str.replace(pos, oldStr.length(), newStr); @@ -2101,12 +2238,16 @@ void TermDbSygus::printSygusTerm( std::ostream& out, Node n, std::vector< Node > std::stringstream body_out; printSygusTerm( body_out, let_body, new_lvs ); std::string body = body_out.str(); - for( unsigned i=dt[cIndex].getNumSygusLetInputArgs(); i=dt[cIndex].getNumSygusLetInputArgs() ){ + printSygusTerm( new_str, n[i], lvs ); + }else{ + new_str << Node::fromExpr( dt[cIndex].getSygusLetArg( i ) ); + } + doStrReplace( body, old_str.str().c_str(), new_str.str().c_str() ); } out << body; if( dt[cIndex].getNumSygusLetInputArgs()>0 ){ diff --git a/src/theory/quantifiers/term_database.h b/src/theory/quantifiers/term_database.h index 0bb2c3224..455287feb 100644 --- a/src/theory/quantifiers/term_database.h +++ b/src/theory/quantifiers/term_database.h @@ -382,6 +382,8 @@ private: std::map< TypeNode, std::vector< Node > > d_fv; std::map< Node, TypeNode > d_fv_stype; std::map< Node, int > d_fv_num; + Node d_true; + Node d_false; public: TNode getVar( TypeNode tn, int i ); TNode getVarInc( TypeNode tn, std::map< TypeNode, int >& var_count ); @@ -404,6 +406,9 @@ private: std::map< TypeNode, std::map< Node, int > > d_consts; std::map< TypeNode, std::map< Node, int > > d_ops; std::map< TypeNode, std::map< int, Node > > d_arg_ops; + std::map< TypeNode, std::vector< int > > d_id_funcs; + std::map< TypeNode, std::vector< Node > > d_const_list; //sorted list of constants for type + std::map< TypeNode, unsigned > d_const_list_pos; //information for builtin types std::map< TypeNode, std::map< int, Node > > d_type_value; std::map< TypeNode, Node > d_type_max_value; @@ -414,7 +419,7 @@ private: std::map< TypeNode, std::map< Node, Node > > d_sygus_to_builtin; std::map< TypeNode, std::map< Node, Node > > d_builtin_const_to_sygus; public: - TermDbSygus(){} + TermDbSygus(); bool isRegistered( TypeNode tn ); int getKindArg( TypeNode tn, Kind k ); int getConstArg( TypeNode tn, Node n ); @@ -427,6 +432,8 @@ public: Kind getArgKind( TypeNode tn, int i ); bool isKindArg( TypeNode tn, int i ); bool isConstArg( TypeNode tn, int i ); + unsigned getNumIdFuncs( TypeNode tn ); + unsigned getIdFuncIndex( TypeNode tn, unsigned i ); void registerSygusType( TypeNode tn ); /** get arg type */ TypeNode getArgType( const DatatypeConstructor& c, int i ); @@ -451,7 +458,7 @@ public: TypeNode getSygusType( Node v ); Node mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >& var_count, std::map< int, Node >& pre ); Node sygusToBuiltin( Node n, TypeNode tn ); - Node builtinToSygusConst( Node c, TypeNode tn ); + Node builtinToSygusConst( Node c, TypeNode tn, int rcons_depth = 0 ); Node getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs ); Node getNormalized( TypeNode t, Node prog, bool do_pre_norm = false, bool do_post_norm = true ); int getTermSize( Node n ); @@ -459,6 +466,10 @@ public: Node minimizeBuiltinTerm( Node n ); /** given a term, expand it into more basic components */ Node expandBuiltinTerm( Node n ); + /** get comparison kind */ + Kind getComparisonKind( TypeNode tn ); + Kind getPlusKind( TypeNode tn, bool is_neg = false ); + bool doCompare( Node a, Node b, Kind k ); /** print sygus term */ static void printSygusTerm( std::ostream& out, Node n, std::vector< Node >& lvs ); }; diff --git a/src/util/datatype.cpp b/src/util/datatype.cpp index b1ab011ef..138fb4bb0 100644 --- a/src/util/datatype.cpp +++ b/src/util/datatype.cpp @@ -707,6 +707,11 @@ unsigned DatatypeConstructor::getNumSygusLetInputArgs() const { CheckArgument(isResolved(), this, "this datatype constructor is not yet resolved"); return d_sygus_num_let_input_args; } + +bool DatatypeConstructor::isSygusIdFunc() const { + CheckArgument(isResolved(), this, "this datatype constructor is not yet resolved"); + return d_sygus_let_args.size()==1 && d_sygus_let_args[0]==d_sygus_let_body; +} Cardinality DatatypeConstructor::getCardinality() const throw(IllegalArgumentException) { CheckArgument(isResolved(), this, "this datatype constructor is not yet resolved"); diff --git a/src/util/datatype.h b/src/util/datatype.h index 1945c4390..0b8b8c61f 100644 --- a/src/util/datatype.h +++ b/src/util/datatype.h @@ -294,6 +294,8 @@ public: Expr getSygusLetArg( unsigned i ) const; /** get number of let arguments that should be printed as arguments to let */ unsigned getNumSygusLetInputArgs() const; + /** is this a sygus identity function */ + bool isSygusIdFunc() const; /** * Get the tester name for this Datatype constructor. -- 2.30.2