From 4748af3ee298ce5aae36a8ab8cad4426d1398c17 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 15 Aug 2018 13:02:46 -0500 Subject: [PATCH] Make sort inference a preprocessing pass (#2309) --- src/Makefile.am | 2 + src/preprocessing/passes/sort_infer.cpp | 85 +++++ src/preprocessing/passes/sort_infer.h | 58 +++ src/smt/smt_engine.cpp | 18 +- src/theory/sort_inference.cpp | 385 ++++++++++---------- src/theory/sort_inference.h | 73 +++- test/regress/Makefile.tests | 1 + test/regress/regress0/fmf/sort-inf-int.smt2 | 13 + test/regress/regress1/fmf/ALG008-1.smt2 | 1 + 9 files changed, 420 insertions(+), 216 deletions(-) create mode 100644 src/preprocessing/passes/sort_infer.cpp create mode 100644 src/preprocessing/passes/sort_infer.h create mode 100644 test/regress/regress0/fmf/sort-inf-int.smt2 diff --git a/src/Makefile.am b/src/Makefile.am index c2a620f57..43aa70174 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -85,6 +85,8 @@ libcvc4_la_SOURCES = \ preprocessing/passes/rewrite.h \ preprocessing/passes/sep_skolem_emp.cpp \ preprocessing/passes/sep_skolem_emp.h \ + preprocessing/passes/sort_infer.cpp \ + preprocessing/passes/sort_infer.h \ preprocessing/passes/static_learning.cpp \ preprocessing/passes/static_learning.h \ preprocessing/passes/symmetry_breaker.cpp \ diff --git a/src/preprocessing/passes/sort_infer.cpp b/src/preprocessing/passes/sort_infer.cpp new file mode 100644 index 000000000..e2b0bfb59 --- /dev/null +++ b/src/preprocessing/passes/sort_infer.cpp @@ -0,0 +1,85 @@ +/********************* */ +/*! \file sort_infer.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Sort inference preprocessing pass + **/ + +#include "preprocessing/passes/sort_infer.h" + +#include "options/smt_options.h" +#include "options/uf_options.h" +#include "theory/rewriter.h" + +using namespace std; + +namespace CVC4 { +namespace preprocessing { +namespace passes { + +SortInferencePass::SortInferencePass(PreprocessingPassContext* preprocContext, + SortInference* si) + : PreprocessingPass(preprocContext, "sort-inference"), d_si(si) +{ +} + +PreprocessingPassResult SortInferencePass::applyInternal( + AssertionPipeline* assertionsToPreprocess) +{ + if (options::sortInference()) + { + d_si->initialize(assertionsToPreprocess->ref()); + std::map model_replace_f; + std::map > visited; + for (unsigned i = 0, size = assertionsToPreprocess->size(); i < size; i++) + { + Node prev = (*assertionsToPreprocess)[i]; + Node next = d_si->simplify(prev, model_replace_f, visited); + if (next != prev) + { + next = theory::Rewriter::rewrite(next); + assertionsToPreprocess->replace(i, next); + Trace("sort-infer-preprocess") + << "*** Preprocess SortInferencePass " << prev << endl; + Trace("sort-infer-preprocess") + << " ...got " << (*assertionsToPreprocess)[i] << endl; + } + } + std::vector newAsserts; + d_si->getNewAssertions(newAsserts); + for (const Node& na : newAsserts) + { + Node nar = theory::Rewriter::rewrite(na); + Trace("sort-infer-preprocess") + << "*** Preprocess SortInferencePass : new constraint " << nar + << endl; + assertionsToPreprocess->push_back(nar); + } + // indicate correspondence between the functions + // TODO (#2308): move this to a better place + SmtEngine* smt = smt::currentSmtEngine(); + for (const std::pair& mrf : model_replace_f) + { + smt->setPrintFuncInModel(mrf.first.toExpr(), false); + smt->setPrintFuncInModel(mrf.second.toExpr(), true); + } + } + // only need to compute monotonicity on the resulting formula if we are + // using this option + if (options::ufssFairnessMonotone()) + { + d_si->computeMonotonicity(assertionsToPreprocess->ref()); + } + return PreprocessingPassResult::NO_CONFLICT; +} + +} // namespace passes +} // namespace preprocessing +} // namespace CVC4 diff --git a/src/preprocessing/passes/sort_infer.h b/src/preprocessing/passes/sort_infer.h new file mode 100644 index 000000000..e56d7ab60 --- /dev/null +++ b/src/preprocessing/passes/sort_infer.h @@ -0,0 +1,58 @@ +/********************* */ +/*! \file sort_infer.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Sort inference preprocessing pass + **/ + +#ifndef __CVC4__PREPROCESSING__PASSES__SORT_INFERENCE_PASS_H_ +#define __CVC4__PREPROCESSING__PASSES__SORT_INFERENCE_PASS_H_ + +#include +#include +#include +#include "expr/node.h" + +#include "preprocessing/preprocessing_pass.h" +#include "preprocessing/preprocessing_pass_context.h" +#include "theory/sort_inference.h" + +namespace CVC4 { +namespace preprocessing { +namespace passes { + +/** SortInferencePass + * + * This preprocessing pass runs sort inference techniques on the input formula. + * For details on these techniques, see theory/sort_inference.h. + */ +class SortInferencePass : public PreprocessingPass +{ + public: + SortInferencePass(PreprocessingPassContext* preprocContext, + SortInference* si); + + protected: + PreprocessingPassResult applyInternal( + AssertionPipeline* assertionsToPreprocess) override; + + private: + /** + * Pointer to the sort inference module. This should be the sort inference + * belonging to the theory engine of the current SMT engine. + */ + SortInference* d_si; +}; + +} // namespace passes +} // namespace preprocessing +} // namespace CVC4 + +#endif /* __CVC4__PREPROCESSING__PASSES__SORT_INFERENCE_PASS_H_ */ diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 38f6a2d5e..cc6f09801 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -80,6 +80,7 @@ #include "preprocessing/passes/real_to_int.h" #include "preprocessing/passes/rewrite.h" #include "preprocessing/passes/sep_skolem_emp.h" +#include "preprocessing/passes/sort_infer.h" #include "preprocessing/passes/static_learning.h" #include "preprocessing/passes/symmetry_breaker.h" #include "preprocessing/passes/symmetry_detect.h" @@ -2735,15 +2736,18 @@ void SmtEnginePrivate::finishInit() new RealToInt(d_preprocessingPassContext.get())); std::unique_ptr rewrite( new Rewrite(d_preprocessingPassContext.get())); + std::unique_ptr sortInfer( + new SortInferencePass(d_preprocessingPassContext.get(), + d_smt.d_theoryEngine->getSortInference())); std::unique_ptr staticLearning( new StaticLearning(d_preprocessingPassContext.get())); std::unique_ptr sbProc( new SymBreakerPass(d_preprocessingPassContext.get())); std::unique_ptr srrProc( new SynthRewRulesPass(d_preprocessingPassContext.get())); - std::unique_ptr sepSkolemEmp( + std::unique_ptr sepSkolemEmp( new SepSkolemEmp(d_preprocessingPassContext.get())); - d_preprocessingPassRegistry.registerPass("apply-substs", + d_preprocessingPassRegistry.registerPass("apply-substs", std::move(applySubsts)); d_preprocessingPassRegistry.registerPass("bool-to-bv", std::move(boolToBv)); d_preprocessingPassRegistry.registerPass("bv-abstraction", @@ -2761,6 +2765,8 @@ void SmtEnginePrivate::finishInit() d_preprocessingPassRegistry.registerPass("rewrite", std::move(rewrite)); d_preprocessingPassRegistry.registerPass("sep-skolem-emp", std::move(sepSkolemEmp)); + d_preprocessingPassRegistry.registerPass("sort-inference", + std::move(sortInfer)); d_preprocessingPassRegistry.registerPass("static-learning", std::move(staticLearning)); d_preprocessingPassRegistry.registerPass("sym-break", std::move(sbProc)); @@ -4332,13 +4338,7 @@ void SmtEnginePrivate::processAssertions() { } if( options::sortInference() || options::ufssFairnessMonotone() ){ - //sort inference technique - SortInference * si = d_smt.d_theoryEngine->getSortInference(); - si->simplify( d_assertions.ref(), options::sortInference(), options::ufssFairnessMonotone() ); - for( std::map< Node, Node >::iterator it = si->d_model_replace_f.begin(); it != si->d_model_replace_f.end(); ++it ){ - d_smt.setPrintFuncInModel( it->first.toExpr(), false ); - d_smt.setPrintFuncInModel( it->second.toExpr(), true ); - } + d_preprocessingPassRegistry.getPass("sort-inference")->apply(&d_assertions); } if( options::pbRewrites() ){ diff --git a/src/theory/sort_inference.cpp b/src/theory/sort_inference.cpp index 96e1e3a38..b6e8f7553 100644 --- a/src/theory/sort_inference.cpp +++ b/src/theory/sort_inference.cpp @@ -102,7 +102,7 @@ void SortInference::reset() { d_non_monotonic_sorts.clear(); d_type_sub_sorts.clear(); //reset info - sortCount = 1; + d_sortCount = 1; d_type_union_find.clear(); d_type_types.clear(); d_id_for_types.clear(); @@ -114,203 +114,191 @@ void SortInference::reset() { d_const_map.clear(); } -void SortInference::simplify( std::vector< Node >& assertions, bool doSortInference, bool doMonotonicyInference ){ - if( doSortInference ){ - Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl; - NodeManager* nm = NodeManager::currentNM(); - //process all assertions - std::map< Node, int > visited; - for( unsigned i=0; i var_bound; - process( assertions[i], var_bound, visited ); - } - Trace("sort-inference-proc") << "...done" << std::endl; - for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){ - Trace("sort-inference") << it->first << " : "; - TypeNode retTn = it->first.getType(); - if( !d_op_arg_types[ it->first ].empty() ){ - Trace("sort-inference") << "( "; - for( size_t i=0; ifirst ].size(); i++ ){ - recordSubsort( retTn[i], d_op_arg_types[ it->first ][i] ); - printSort( "sort-inference", d_op_arg_types[ it->first ][i] ); - Trace("sort-inference") << " "; - } - Trace("sort-inference") << ") -> "; - retTn = retTn[(int)retTn.getNumChildren()-1]; +void SortInference::initialize(const std::vector& assertions) +{ + Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl; + // process all assertions + std::map visited; + for (const Node& a : assertions) + { + Trace("sort-inference-debug") << "Process " << a << std::endl; + std::map var_bound; + process(a, var_bound, visited); + } + Trace("sort-inference-proc") << "...done" << std::endl; + for (const std::pair& rt : d_op_return_types) + { + Trace("sort-inference") << rt.first << " : "; + TypeNode retTn = rt.first.getType(); + if (!d_op_arg_types[rt.first].empty()) + { + Trace("sort-inference") << "( "; + for (size_t i = 0; i < d_op_arg_types[rt.first].size(); i++) + { + recordSubsort(retTn[i], d_op_arg_types[rt.first][i]); + printSort("sort-inference", d_op_arg_types[rt.first][i]); + Trace("sort-inference") << " "; } - recordSubsort( retTn, it->second ); - printSort( "sort-inference", it->second ); - Trace("sort-inference") << std::endl; + Trace("sort-inference") << ") -> "; + retTn = retTn[(int)retTn.getNumChildren() - 1]; } - for( std::map< Node, std::map< Node, int > >::iterator it = d_var_types.begin(); it != d_var_types.end(); ++it ){ - Trace("sort-inference") << "Quantified formula : " << it->first << " : " << std::endl; - for( unsigned i=0; ifirst[0].getNumChildren(); i++ ){ - recordSubsort( it->first[0][i].getType(), it->second[it->first[0][i]] ); - printSort( "sort-inference", it->second[it->first[0][i]] ); - Trace("sort-inference") << std::endl; - } + recordSubsort(retTn, rt.second); + printSort("sort-inference", rt.second); + Trace("sort-inference") << std::endl; + } + for (std::pair >& vt : d_var_types) + { + Trace("sort-inference") + << "Quantified formula : " << vt.first << " : " << std::endl; + for (const Node& v : vt.first[0]) + { + recordSubsort(v.getType(), vt.second[v]); + printSort("sort-inference", vt.second[v]); Trace("sort-inference") << std::endl; } + Trace("sort-inference") << std::endl; + } - bool rewritten = false; - // determine monotonicity of sorts - Trace("sort-inference-proc") << "Calculating monotonicty for subsorts..." - << std::endl; - std::map > visitedm; - for (const Node& a : assertions) + // determine monotonicity of sorts + Trace("sort-inference-proc") + << "Calculating monotonicty for subsorts..." << std::endl; + std::map > visitedm; + for (const Node& a : assertions) + { + Trace("sort-inference-debug") + << "Process monotonicity for " << a << std::endl; + std::map var_bound; + processMonotonic(a, true, true, var_bound, visitedm); + } + Trace("sort-inference-proc") << "...done" << std::endl; + + Trace("sort-inference") << "We have " << d_sub_sorts.size() + << " sub-sorts : " << std::endl; + for (unsigned i = 0, size = d_sub_sorts.size(); i < size; i++) + { + printSort("sort-inference", d_sub_sorts[i]); + if (d_type_types.find(d_sub_sorts[i]) != d_type_types.end()) { - Trace("sort-inference-debug") << "Process monotonicity for " << a - << std::endl; - std::map var_bound; - processMonotonic(a, true, true, var_bound, visitedm); + Trace("sort-inference") << " is interpreted." << std::endl; } - Trace("sort-inference-proc") << "...done" << std::endl; - - Trace("sort-inference") << "We have " << d_sub_sorts.size() - << " sub-sorts : " << std::endl; - for (unsigned i = 0, size = d_sub_sorts.size(); i < size; i++) + else if (d_non_monotonic_sorts.find(d_sub_sorts[i]) + == d_non_monotonic_sorts.end()) { - printSort("sort-inference", d_sub_sorts[i]); - if (d_type_types.find(d_sub_sorts[i]) != d_type_types.end()) - { - Trace("sort-inference") << " is interpreted." << std::endl; - } - else if (d_non_monotonic_sorts.find(d_sub_sorts[i]) - == d_non_monotonic_sorts.end()) - { - Trace("sort-inference") << " is monotonic." << std::endl; - } - else - { - Trace("sort-inference") << " is not monotonic." << std::endl; - } + Trace("sort-inference") << " is monotonic." << std::endl; } + else + { + Trace("sort-inference") << " is not monotonic." << std::endl; + } + } +} - // simplify all assertions by introducing new symbols wherever necessary - Trace("sort-inference-proc") << "Perform simplification..." << std::endl; - std::map > visited2; - for (unsigned i = 0, size = assertions.size(); i < size; i++) +Node SortInference::simplify(Node n, + std::map& model_replace_f, + std::map >& visited) +{ + Trace("sort-inference-debug") << "Simplify " << n << std::endl; + std::map var_bound; + TypeNode tnn; + Node ret = simplifyNode(n, var_bound, tnn, model_replace_f, visited); + ret = theory::Rewriter::rewrite(ret); + return ret; +} + +void SortInference::getNewAssertions(std::vector& new_asserts) +{ + NodeManager* nm = NodeManager::currentNM(); + // now, ensure constants are distinct + for (const std::pair >& cm : d_const_map) + { + std::vector consts; + for (const std::pair& c : cm.second) { - Node prev = assertions[i]; - std::map var_bound; - Trace("sort-inference-debug") << "Simplify " << prev << std::endl; - TypeNode tnn; - Node curr = simplifyNode(assertions[i], var_bound, tnn, visited2); - Trace("sort-inference-debug") << "Done." << std::endl; - if (curr != assertions[i]) - { - Trace("sort-inference-debug") << "Rewrite " << curr << std::endl; - curr = theory::Rewriter::rewrite(curr); - rewritten = true; - Trace("sort-inference-rewrite") << assertions << std::endl; - Trace("sort-inference-rewrite") << " --> " << curr << std::endl; - PROOF(ProofManager::currentPM()->addDependence(curr, assertions[i]);); - assertions[i] = curr; - } + Assert(c.first.isConst()); + consts.push_back(c.second); } - Trace("sort-inference-proc") << "...done" << std::endl; - // now, ensure constants are distinct - for (std::map >::iterator it = - d_const_map.begin(); - it != d_const_map.end(); - ++it) + // add lemma enforcing introduced constants to be distinct + if (consts.size() > 1) { - std::vector consts; - for (std::map::iterator it2 = it->second.begin(); - it2 != it->second.end(); - ++it2) - { - Assert(it2->first.isConst()); - consts.push_back(it2->second); - } - // add lemma enforcing introduced constants to be distinct - if (consts.size() > 1) - { - Node distinct_const = nm->mkNode(kind::DISTINCT, consts); - Trace("sort-inference-rewrite") - << "Add the constant distinctness lemma: " << std::endl; - Trace("sort-inference-rewrite") << " " << distinct_const << std::endl; - assertions.push_back(distinct_const); - rewritten = true; - } + Node distinct_const = nm->mkNode(kind::DISTINCT, consts); + Trace("sort-inference-rewrite") + << "Add the constant distinctness lemma: " << std::endl; + Trace("sort-inference-rewrite") << " " << distinct_const << std::endl; + new_asserts.push_back(distinct_const); } + } + + // enforce constraints based on monotonicity + Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl; - // enforce constraints based on monotonicity - Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl; - for (std::map >::iterator it = - d_type_sub_sorts.begin(); - it != d_type_sub_sorts.end(); - ++it) + for (const std::pair >& tss : + d_type_sub_sorts) + { + int nmonSort = -1; + unsigned nsorts = tss.second.size(); + for (unsigned i = 0; i < nsorts; i++) { - int nmonSort = -1; - unsigned nsorts = it->second.size(); - for (unsigned i = 0; i < nsorts; i++) + if (d_non_monotonic_sorts.find(tss.second[i]) + != d_non_monotonic_sorts.end()) { - if (d_non_monotonic_sorts.find(it->second[i]) - != d_non_monotonic_sorts.end()) - { - nmonSort = it->second[i]; - break; - } + nmonSort = tss.second[i]; + break; } - if (nmonSort != -1) + } + if (nmonSort != -1) + { + std::vector injections; + TypeNode base_tn = getOrCreateTypeForId(nmonSort, tss.first); + for (unsigned i = 0; i < nsorts; i++) { - std::vector injections; - TypeNode base_tn = getOrCreateTypeForId(nmonSort, it->first); - for (unsigned i = 0; i < nsorts; i++) - { - if (it->second[i] != nmonSort) - { - TypeNode new_tn = getOrCreateTypeForId(it->second[i], it->first); - // make injection to nmonSort - Node a1 = mkInjection(new_tn, base_tn); - injections.push_back(a1); - if (d_non_monotonic_sorts.find(it->second[i]) - != d_non_monotonic_sorts.end()) - { - // also must make injection from nmonSort to this - Node a2 = mkInjection(base_tn, new_tn); - injections.push_back(a2); - } - } - } - if (Trace.isOn("sort-inference-rewrite")) + if (tss.second[i] != nmonSort) { - Trace("sort-inference-rewrite") << "Add the following injections for " << it->first << " to ensure consistency wrt non-monotonic sorts : " << std::endl; - for (const Node& i : injections) + TypeNode new_tn = getOrCreateTypeForId(tss.second[i], tss.first); + // make injection to nmonSort + Node a1 = mkInjection(new_tn, base_tn); + injections.push_back(a1); + if (d_non_monotonic_sorts.find(tss.second[i]) + != d_non_monotonic_sorts.end()) { - Trace("sort-inference-rewrite") << " " << i << std::endl; + // also must make injection from nmonSort to this + Node a2 = mkInjection(base_tn, new_tn); + injections.push_back(a2); } } - assertions.insert( - assertions.end(), injections.begin(), injections.end()); - if (!injections.empty()) + } + if (Trace.isOn("sort-inference-rewrite")) + { + Trace("sort-inference-rewrite") + << "Add the following injections for " << tss.first + << " to ensure consistency wrt non-monotonic sorts : " << std::endl; + for (const Node& i : injections) { - rewritten = true; + Trace("sort-inference-rewrite") << " " << i << std::endl; } } + new_asserts.insert( + new_asserts.end(), injections.begin(), injections.end()); } - Trace("sort-inference-proc") << "...done" << std::endl; - // no sub-sort information is stored - reset(); - Trace("sort-inference-debug") - << "Finished sort inference, rewritten = " << rewritten << std::endl; - - initialSortCount = sortCount; } - if( doMonotonicyInference ){ - std::map > visitedmt; - Trace("sort-inference-proc") << "Calculating monotonicty for types..." << std::endl; - for (const Node& a : assertions) - { - Trace("sort-inference-debug") << "Process type monotonicity for " << a - << std::endl; - std::map< Node, Node > var_bound; - processMonotonic(a, true, true, var_bound, visitedmt, true); - } - Trace("sort-inference-proc") << "...done" << std::endl; + Trace("sort-inference-proc") << "...done" << std::endl; + // no sub-sort information is stored + reset(); + Trace("sort-inference-debug") << "Finished sort inference" << std::endl; +} + +void SortInference::computeMonotonicity(const std::vector& assertions) +{ + std::map > visitedmt; + Trace("sort-inference-proc") + << "Calculating monotonicty for types..." << std::endl; + for (const Node& a : assertions) + { + Trace("sort-inference-debug") + << "Process type monotonicity for " << a << std::endl; + std::map var_bound; + processMonotonic(a, true, true, var_bound, visitedmt, true); } + Trace("sort-inference-proc") << "...done" << std::endl; } void SortInference::setEqual( int t1, int t2 ){ @@ -357,10 +345,10 @@ int SortInference::getIdForType( TypeNode tn ){ //register the return type std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn ); if( it==d_id_for_types.end() ){ - int sc = sortCount; - d_type_types[ sortCount ] = tn; - d_id_for_types[ tn ] = sortCount; - sortCount++; + int sc = d_sortCount; + d_type_types[d_sortCount] = tn; + d_id_for_types[tn] = d_sortCount; + d_sortCount++; return sc; }else{ return it->second; @@ -381,8 +369,8 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map< }else{ for( size_t i=0; i& var_bound, std::map< d_op_return_types[op] = getIdForType( n.getType() ); }else{ //assign arbitrary sort for return type - d_op_return_types[op] = sortCount; - sortCount++; + d_op_return_types[op] = d_sortCount; + d_sortCount++; } - //d_type_eq_class[sortCount].push_back( op ); - //assign arbitrary sort for argument types + // d_type_eq_class[d_sortCount].push_back( op ); + // assign arbitrary sort for argument types for( size_t i=0; i& var_bound, std::map< Trace("sort-inference-debug") << n << " is a variable." << std::endl; if( d_op_return_types.find( n )==d_op_return_types.end() ){ //assign arbitrary sort - d_op_return_types[n] = sortCount; - sortCount++; - //d_type_eq_class[sortCount].push_back( n ); + d_op_return_types[n] = d_sortCount; + d_sortCount++; + // d_type_eq_class[d_sortCount].push_back( n ); } retType = d_op_return_types[n]; }else if( n.isConst() ){ Trace("sort-inference-debug") << n << " is a constant." << std::endl; //can be any type we want - retType = sortCount; - sortCount++; + retType = d_sortCount; + d_sortCount++; }else{ Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl; //it is an interpreted term @@ -556,8 +544,13 @@ TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){ return d_type_types[rt]; }else{ TypeNode retType; - //see if we can assign pref - if( !pref.isNull() && d_id_for_types.find( pref )==d_id_for_types.end() ){ + // See if we can assign pref. This is an optimization for reusing an + // uninterpreted sort as the first subsort, so that fewer symbols needed + // to be rewritten in the sort-inferred signature. Notice we only assign + // pref here if it is an uninterpreted sort. + if (!pref.isNull() && d_id_for_types.find(pref) == d_id_for_types.end() + && pref.isSort()) + { retType = pref; }else{ //must create new type @@ -606,7 +599,13 @@ Node SortInference::getNewSymbol( Node old, TypeNode tn ){ } } -Node SortInference::simplifyNode( Node n, std::map< Node, Node >& var_bound, TypeNode tnn, std::map< Node, std::map< TypeNode, Node > >& visited ){ +Node SortInference::simplifyNode( + Node n, + std::map& var_bound, + TypeNode tnn, + std::map& model_replace_f, + std::map >& visited) +{ std::map< TypeNode, Node >::iterator itv = visited[n].find( tnn ); if( itv!=visited[n].end() ){ return itv->second; @@ -654,7 +653,11 @@ Node SortInference::simplifyNode( Node n, std::map< Node, Node >& var_bound, Typ tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() ); Assert( !tnnc.isNull() ); } - Node nc = simplifyNode( n[i], var_bound, tnnc, use_new_visited ? new_visited : visited ); + Node nc = simplifyNode(n[i], + var_bound, + tnnc, + model_replace_f, + use_new_visited ? new_visited : visited); Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl; children.push_back( nc ); childChanged = childChanged || nc!=n[i]; @@ -701,7 +704,7 @@ Node SortInference::simplifyNode( Node n, std::map< Node, Node >& var_bound, Typ TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType ); d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" ); Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl; - d_model_replace_f[op] = d_symbol_map[op]; + model_replace_f[op] = d_symbol_map[op]; }else{ d_symbol_map[op] = op; } diff --git a/src/theory/sort_inference.h b/src/theory/sort_inference.h index 6daf6157a..b93d5531c 100644 --- a/src/theory/sort_inference.h +++ b/src/theory/sort_inference.h @@ -26,6 +26,14 @@ namespace CVC4 { +/** sort inference + * + * This class implements sort inference techniques, which rewrites a + * formula F into an equisatisfiable formula F', where the symbols g in F are + * replaced by others g', possibly of different types. For details, see e.g.: + * "Sort it out with Monotonicity" Claessen 2011 + * "Non-Cyclic Sorts for First-Order Satisfiability" Korovin 2013. + */ class SortInference { private: //all subsorts @@ -52,9 +60,10 @@ public: bool areEqual( int t1, int t2 ) { return getRepresentative( t1 )==getRepresentative( t2 ); } bool isValid(); }; -private: - int sortCount; - int initialSortCount; + + private: + /** the id count for all subsorts we have allocated */ + int d_sortCount; UnionFind d_type_union_find; std::map< int, TypeNode > d_type_types; std::map< TypeNode, int > d_id_for_types; @@ -70,8 +79,8 @@ private: void printSort( const char* c, int t ); //process int process( Node n, std::map< Node, Node >& var_bound, std::map< Node, int >& visited ); -//for monotonicity inference -private: + // for monotonicity inference + private: void processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode = false ); //for rewriting @@ -85,17 +94,56 @@ private: TypeNode getTypeForId( int t ); Node getNewSymbol( Node old, TypeNode tn ); //simplify - Node simplifyNode( Node n, std::map< Node, Node >& var_bound, TypeNode tnn, std::map< Node, std::map< TypeNode, Node > >& visited ); + Node simplifyNode(Node n, + std::map& var_bound, + TypeNode tnn, + std::map& model_replace_f, + std::map >& visited); //make injection Node mkInjection( TypeNode tn1, TypeNode tn2 ); //reset void reset(); public: - SortInference() : sortCount(1), initialSortCount() {} + SortInference() : d_sortCount(1) {} ~SortInference(){} - void simplify( std::vector< Node >& assertions, bool doSortInference, bool doMonotonicyInference ); + /** initialize + * + * This initializes this class. The input formula is indicated by assertions. + */ + void initialize(const std::vector& assertions); + /** simplify + * + * This returns the simplified form of formula n, based on the information + * computed during initialization. The argument model_replace_f stores the + * mapping between functions and their analog in the sort-inferred signature. + * The argument visited is a cache of the internal results of simplifying + * previous nodes with this class. + * + * Must call initialize() before this function. + */ + Node simplify(Node n, + std::map& model_replace_f, + std::map >& visited); + /** get new constraints + * + * This adds constraints to new_asserts that ensure the following. + * Let F be the conjunction of assertions from the input. Let F' be the + * conjunction of the simplified form of each conjunct in F. Let C be the + * conjunction of formulas adding to new_asserts. Then, F and F' ^ C are + * equisatisfiable. + */ + void getNewAssertions(std::vector& new_asserts); + /** compute monotonicity + * + * This computes whether sorts are monotonic (see e.g. Claessen 2011). If + * this function is called, then calls to isMonotonic() can subsequently be + * used to query whether sorts are monotonic. + */ + void computeMonotonicity(const std::vector& assertions); + /** return true if tn was inferred to be monotonic */ + bool isMonotonic(TypeNode tn); //get sort id for term n int getSortId( Node n ); //get sort id for variable of quantified formula f @@ -108,16 +156,9 @@ public: bool isWellSorted( Node n ); //get constraints for being well-typed according to computed sub-types void getSortConstraints( Node n, SortInference::UnionFind& uf ); -public: - //list of all functions and the uninterpreted symbols they were replaced with - std::map< Node, Node > d_model_replace_f; - private: // store monotonicity for original sorts as well - std::map< TypeNode, bool > d_non_monotonic_sorts_orig; -public: - //is monotonic - bool isMonotonic( TypeNode tn ); + std::map d_non_monotonic_sorts_orig; }; } diff --git a/test/regress/Makefile.tests b/test/regress/Makefile.tests index 543fbd158..cd79fe050 100644 --- a/test/regress/Makefile.tests +++ b/test/regress/Makefile.tests @@ -433,6 +433,7 @@ REG0_TESTS = \ regress0/fmf/quant_real_univ.cvc \ regress0/fmf/sat-logic.smt2 \ regress0/fmf/sc_bad_model_1221.smt2 \ + regress0/fmf/sort-inf-int.smt2 \ regress0/fmf/syn002-si-real-int.smt2 \ regress0/fmf/tail_rec.smt2 \ regress0/fp/simple.smt2 \ diff --git a/test/regress/regress0/fmf/sort-inf-int.smt2 b/test/regress/regress0/fmf/sort-inf-int.smt2 new file mode 100644 index 000000000..e4a8978d4 --- /dev/null +++ b/test/regress/regress0/fmf/sort-inf-int.smt2 @@ -0,0 +1,13 @@ +; COMMAND-LINE: --finite-model-find --sort-inference --no-check-models +; EXPECT: sat +(set-logic UFLIRA) +(set-info :status sat) +(declare-fun f (Int) Int) +(declare-fun g (Int) Int) +(declare-fun h (Int) Int) +(assert (forall ((x Int)) (or (= (f x) (h x)) (= (f x) (g x))))) +(assert (not (= (f 3) (h 3)))) +(assert (not (= (f 5) (g 5)))) +(assert (= (f 4) (g 8))) + +(check-sat) diff --git a/test/regress/regress1/fmf/ALG008-1.smt2 b/test/regress/regress1/fmf/ALG008-1.smt2 index 2c3bab80d..5bf36a715 100644 --- a/test/regress/regress1/fmf/ALG008-1.smt2 +++ b/test/regress/regress1/fmf/ALG008-1.smt2 @@ -1,4 +1,5 @@ ; COMMAND-LINE: --finite-model-find +; COMMAND-LINE: --finite-model-find --sort-inference --no-check-models ; EXPECT: sat ;%-------------------------------------------------------------------------- ;% File : ALG008-1 : TPTP v5.4.0. Released v2.2.0. -- 2.30.2