class NormalForm {
public:
+ /**
+ * Constructs a set of the form:
+ * (union (singleton c1) ... (union (singleton c_{n-1}) (singleton c_n))))
+ * from the set { c1 ... cn }, also handles empty set case, which is why
+ * setType is passed to this method.
+ */
template <bool ref_count>
static Node elementsToSet(const std::set<NodeTemplate<ref_count> >& elements,
TypeNode setType)
Node cur = nm->mkNode(kind::SINGLETON, *it);
while (++it != elements.end())
{
- cur = nm->mkNode(kind::UNION, cur, nm->mkNode(kind::SINGLETON, *it));
+ cur = nm->mkNode(kind::UNION, nm->mkNode(kind::SINGLETON, *it), cur);
}
return cur;
}
}
+ /**
+ * Returns true if n is considered a to be a (canonical) constant set value.
+ * A canonical set value is one whose AST is:
+ * (union (singleton c1) ... (union (singleton c_{n-1}) (singleton c_n))))
+ * where c1 ... cn are constants and the node identifier of these constants
+ * are such that:
+ * c1 > ... > cn.
+ * Also handles the corner cases of empty set and singleton set.
+ */
static bool checkNormalConstant(TNode n) {
Debug("sets-checknormal") << "[sets-checknormal] checkNormal " << n << " :"
<< std::endl;
} else if (n.getKind() == kind::SINGLETON) {
return n[0].isConst();
} else if (n.getKind() == kind::UNION) {
- // assuming (union ... (union {SmallestNodeID} {BiggerNodeId}) ...
- // {BiggestNodeId})
-
- // store BiggestNodeId in prvs
- if (n[1].getKind() != kind::SINGLETON) return false;
- if (!n[1][0].isConst()) return false;
- Debug("sets-checknormal")
- << "[sets-checknormal] frst element = " << n[1][0] << " "
- << n[1][0].getId() << std::endl;
- TNode prvs = n[1][0];
- n = n[0];
+ // assuming (union {SmallestNodeID} ... (union {BiggerNodeId} ...
+ Node orig = n;
+ TNode prvs;
// check intermediate nodes
- while (n.getKind() == kind::UNION) {
- if (n[1].getKind() != kind::SINGLETON) return false;
- if (!n[1].isConst()) return false;
+ while (n.getKind() == kind::UNION)
+ {
+ if (n[0].getKind() != kind::SINGLETON || !n[0][0].isConst())
+ {
+ // not a constant
+ Trace("sets-isconst") << "sets::isConst: " << orig << " not due to "
+ << n[0] << std::endl;
+ return false;
+ }
Debug("sets-checknormal")
- << "[sets-checknormal] element = " << n[1][0] << " "
- << n[1][0].getId() << std::endl;
- if (n[1][0] >= prvs) return false;
- prvs = n[1][0];
- n = n[0];
+ << "[sets-checknormal] element = " << n[0][0] << " "
+ << n[0][0].getId() << std::endl;
+ if (!prvs.isNull() && n[0][0] >= prvs)
+ {
+ Trace("sets-isconst")
+ << "sets::isConst: " << orig << " not due to compare " << n[0][0]
+ << std::endl;
+ return false;
+ }
+ prvs = n[0][0];
+ n = n[1];
}
// check SmallestNodeID is smallest
- if (n.getKind() != kind::SINGLETON) return false;
- if (!n[0].isConst()) return false;
+ if (n.getKind() != kind::SINGLETON || !n[0].isConst())
+ {
+ Trace("sets-isconst") << "sets::isConst: " << orig
+ << " not due to final " << n << std::endl;
+ return false;
+ }
Debug("sets-checknormal")
<< "[sets-checknormal] lst element = " << n[0] << " "
<< n[0].getId() << std::endl;
- if (n[0] >= prvs) return false;
-
- // we made it
- return true;
-
- } else {
- return false;
+ // compare last ID
+ if (n[0] < prvs)
+ {
+ return true;
+ }
+ Trace("sets-isconst")
+ << "sets::isConst: " << orig << " not due to compare final " << n[0]
+ << std::endl;
}
+ return false;
}
+ /**
+ * Converts a set term to a std::set of its elements. This expects a set of
+ * the form:
+ * (union (singleton c1) ... (union (singleton c_{n-1}) (singleton c_n))))
+ * Also handles the corner cases of empty set and singleton set.
+ */
static std::set<Node> getElementsFromNormalConstant(TNode n) {
Assert(n.isConst());
std::set<Node> ret;
return ret;
}
while (n.getKind() == kind::UNION) {
- Assert(n[1].getKind() == kind::SINGLETON);
- ret.insert(ret.begin(), n[1][0]);
- n = n[0];
+ Assert(n[0].getKind() == kind::SINGLETON);
+ ret.insert(ret.begin(), n[0][0]);
+ n = n[1];
}
Assert(n.getKind() == kind::SINGLETON);
ret.insert(n[0]);
return ret;
}
-
- //AJR
-
- static void getElementsFromBop( Kind k, Node n, std::vector< Node >& els ){
- if( n.getKind()==k ){
- for( unsigned i=0; i<n.getNumChildren(); i++ ){
- getElementsFromBop( k, n[i], els );
- }
- }else{
- if( std::find( els.begin(), els.end(), n )==els.end() ){
- els.push_back( n );
- }
- }
- }
static Node mkBop( Kind k, std::vector< Node >& els, TypeNode tn, unsigned index = 0 ){
if( index>=els.size() ){
return NodeManager::currentNM()->mkConst(EmptySet(tn));
namespace theory {
namespace sets {
-bool checkConstantMembership(TNode elementTerm, TNode setTerm)
+bool TheorySetsRewriter::checkConstantMembership(TNode elementTerm, TNode setTerm)
{
if(setTerm.getKind() == kind::EMPTYSET) {
return false;
}
Assert(setTerm.getKind() == kind::UNION
- && setTerm[1].getKind() == kind::SINGLETON)
+ && setTerm[0].getKind() == kind::SINGLETON)
<< "kind was " << setTerm.getKind() << ", term: " << setTerm;
- return
- elementTerm == setTerm[1][0] ||
- checkConstantMembership(elementTerm, setTerm[0]);
+ return elementTerm == setTerm[0][0]
+ || checkConstantMembership(elementTerm, setTerm[1]);
}
// static
Trace("sets-postrewrite") << "Process: " << node << std::endl;
if(node.isConst()) {
+ Trace("sets-rewrite-nf")
+ << "Sets::rewrite: no rewrite (constant) " << node << std::endl;
// Dare you touch the const and mangle it to something else.
return RewriteResponse(REWRITE_DONE, node);
}
Assert(newNode.isConst());
Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
return RewriteResponse(REWRITE_DONE, newNode);
- } else {
- std::vector< Node > els;
- NormalForm::getElementsFromBop( kind::INTERSECTION, node, els );
- std::sort( els.begin(), els.end() );
- Node rew = NormalForm::mkBop( kind::INTERSECTION, els, node.getType() );
- if( rew!=node ){
- Trace("sets-rewrite") << "Sets::rewrite " << node << " -> " << rew << std::endl;
- }
- return RewriteResponse(REWRITE_DONE, rew);
}
- /*
- } else if (node[0] > node[1]) {
+ else if (node[0] > node[1])
+ {
Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
- Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
return RewriteResponse(REWRITE_DONE, newNode);
}
- */
+ // we don't merge non-constant intersections
break;
}//kind::INTERSECION
std::inserter(newSet, newSet.begin()));
Node newNode = NormalForm::elementsToSet(newSet, node.getType());
Assert(newNode.isConst());
- Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
+ Trace("sets-rewrite")
+ << "Sets::rewrite: UNION_CONSTANT_MERGE: " << newNode << std::endl;
return RewriteResponse(REWRITE_DONE, newNode);
- } else {
- std::vector< Node > els;
- NormalForm::getElementsFromBop( kind::UNION, node, els );
- std::sort( els.begin(), els.end() );
- Node rew = NormalForm::mkBop( kind::UNION, els, node.getType() );
- if( rew!=node ){
- Trace("sets-rewrite") << "Sets::rewrite " << node << " -> " << rew << std::endl;
- }
- Trace("sets-rewrite") << "...no rewrite." << std::endl;
- return RewriteResponse(REWRITE_DONE, rew);
}
+ else if (node[0] > node[1])
+ {
+ Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
+ return RewriteResponse(REWRITE_DONE, newNode);
+ }
+ // we don't merge non-constant unions
break;
}//kind::UNION
case kind::COMPLEMENT: {
// static
RewriteResponse TheorySetsRewriter::preRewrite(TNode node) {
NodeManager* nm = NodeManager::currentNM();
-
- if(node.getKind() == kind::EQUAL) {
-
+ Kind k = node.getKind();
+ if (k == kind::EQUAL)
+ {
if(node[0] == node[1]) {
return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
}
-
- }//kind::EQUAL
- else if(node.getKind() == kind::INSERT) {
-
+ }
+ else if (k == kind::INSERT)
+ {
Node insertedElements = nm->mkNode(kind::SINGLETON, node[0]);
size_t setNodeIndex = node.getNumChildren()-1;
for(size_t i = 1; i < setNodeIndex; ++i) {
nm->mkNode(kind::UNION,
insertedElements,
node[setNodeIndex]));
-
- }//kind::INSERT
- else if(node.getKind() == kind::SUBSET) {
-
+ }
+ else if (k == kind::SUBSET)
+ {
// rewrite (A subset-or-equal B) as (A union B = B)
return RewriteResponse(REWRITE_AGAIN,
nm->mkNode(kind::EQUAL,
nm->mkNode(kind::UNION, node[0], node[1]),
node[1]) );
-
- }//kind::SUBSET
+ }
+ // could have an efficient normalizer for union here
return RewriteResponse(REWRITE_DONE, node);
}
delete d_em;
}
+ void addAndCheckUnique(Node n, std::vector<Node>& elems)
+ {
+ TS_ASSERT(n.isConst());
+ TS_ASSERT(std::find(elems.begin(), elems.end(), n) == elems.end());
+ elems.push_back(n);
+ }
+
void testSetOfBooleans()
{
TypeNode boolType = d_nm->booleanType();
SetEnumerator setEnumerator(d_nm->mkSetType(boolType));
TS_ASSERT(!setEnumerator.isFinished());
+ std::vector<Node> elems;
+
Node actual0 = *setEnumerator;
- Node expected0 =
- d_nm->mkConst(EmptySet(d_nm->mkSetType(boolType)));
- TS_ASSERT_EQUALS(expected0, actual0);
+ addAndCheckUnique(actual0, elems);
TS_ASSERT(!setEnumerator.isFinished());
Node actual1 = *++setEnumerator;
- Node expected1 = d_nm->mkNode(Kind::SINGLETON, d_nm->mkConst(false));
- TS_ASSERT_EQUALS(expected1, actual1);
+ addAndCheckUnique(actual1, elems);
TS_ASSERT(!setEnumerator.isFinished());
Node actual2 = *++setEnumerator;
- Node expected2 = d_nm->mkNode(Kind::SINGLETON, d_nm->mkConst(true));
- TS_ASSERT_EQUALS(expected2, actual2);
+ addAndCheckUnique(actual2, elems);
TS_ASSERT(!setEnumerator.isFinished());
Node actual3 = Rewriter::rewrite(*++setEnumerator);
- Node expected3 =
- Rewriter::rewrite(d_nm->mkNode(Kind::UNION, expected1, expected2));
- TS_ASSERT_EQUALS(expected3, actual3);
+ addAndCheckUnique(actual3, elems);
TS_ASSERT(!setEnumerator.isFinished());
TS_ASSERT_THROWS(*++setEnumerator, NoMoreValuesException&);
TS_ASSERT_EQUALS(expected0, actual0);
TS_ASSERT(!setEnumerator.isFinished());
- Node actual1 = *++setEnumerator;
- Node expected1 = d_nm->mkNode(
- Kind::SINGLETON, d_nm->mkConst(UninterpretedConstant(sort, 0)));
- TS_ASSERT_EQUALS(expected1, actual1);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual2 = *++setEnumerator;
- Node expected2 = d_nm->mkNode(
- Kind::SINGLETON, d_nm->mkConst(UninterpretedConstant(sort, 1)));
- TS_ASSERT_EQUALS(expected2, actual2);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual3 = *++setEnumerator;
- Node expected3 = d_nm->mkNode(Kind::UNION, expected1, expected2);
- TS_ASSERT_EQUALS(expected3, actual3);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual4 = *++setEnumerator;
- Node expected4 = d_nm->mkNode(
- Kind::SINGLETON, d_nm->mkConst(UninterpretedConstant(sort, 2)));
- TS_ASSERT_EQUALS(expected4, actual4);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual5 = *++setEnumerator;
- Node expected5 = d_nm->mkNode(Kind::UNION, expected1, expected4);
- TS_ASSERT_EQUALS(expected5, actual5);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual6 = *++setEnumerator;
- Node expected6 = d_nm->mkNode(Kind::UNION, expected2, expected4);
- TS_ASSERT_EQUALS(expected6, actual6);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual7 = *++setEnumerator;
- Node expected7 = d_nm->mkNode(Kind::UNION, expected3, expected4);
- TS_ASSERT_EQUALS(expected7, actual7);
- TS_ASSERT(!setEnumerator.isFinished());
+ std::vector<Node> elems;
+ for (unsigned i = 0; i < 7; i++)
+ {
+ Node actual = *setEnumerator;
+ addAndCheckUnique(actual, elems);
+ TS_ASSERT(!setEnumerator.isFinished());
+ ++setEnumerator;
+ }
}
void testSetOfFiniteDatatype()
Node blue = d_nm->mkNode(APPLY_CONSTRUCTOR, dtcons[2]->getConstructor());
- Node actual0 = *setEnumerator;
- Node expected0 =
- d_nm->mkConst(EmptySet(d_nm->mkSetType(datatype)));
- TS_ASSERT_EQUALS(expected0, actual0);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual1 = *++setEnumerator;
- Node expected1 = d_nm->mkNode(Kind::SINGLETON, red);
- TS_ASSERT_EQUALS(expected1, actual1);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual2 = *++setEnumerator;
- Node expected2 = d_nm->mkNode(Kind::SINGLETON, green);
- TS_ASSERT_EQUALS(expected2, actual2);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual3 = *++setEnumerator;
- Node expected3 = d_nm->mkNode(Kind::UNION, expected1, expected2);
- TS_ASSERT_EQUALS(expected3, actual3);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual4 = *++setEnumerator;
- Node expected4 = d_nm->mkNode(Kind::SINGLETON, blue);
- TS_ASSERT_EQUALS(expected4, actual4);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual5 = *++setEnumerator;
- Node expected5 = d_nm->mkNode(Kind::UNION, expected1, expected4);
- TS_ASSERT_EQUALS(expected5, actual5);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual6 = *++setEnumerator;
- Node expected6 = d_nm->mkNode(Kind::UNION, expected2, expected4);
- TS_ASSERT_EQUALS(expected6, actual6);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual7 = *++setEnumerator;
- Node expected7 = d_nm->mkNode(Kind::UNION, expected3, expected4);
- TS_ASSERT_EQUALS(expected7, actual7);
- TS_ASSERT(!setEnumerator.isFinished());
+ std::vector<Node> elems;
+ for (unsigned i = 0; i < 8; i++)
+ {
+ Node actual = *setEnumerator;
+ addAndCheckUnique(actual, elems);
+ TS_ASSERT(!setEnumerator.isFinished());
+ ++setEnumerator;
+ }
TS_ASSERT_THROWS(*++setEnumerator, NoMoreValuesException&);
TS_ASSERT(setEnumerator.isFinished());
{
TypeNode bitVector2 = d_nm->mkBitVectorType(2);
SetEnumerator setEnumerator(d_nm->mkSetType(bitVector2));
- Node zero = d_nm->mkConst(BitVector(2u, 0u));
- Node one = d_nm->mkConst(BitVector(2u, 1u));
- Node two = d_nm->mkConst(BitVector(2u, 2u));
- Node three = d_nm->mkConst(BitVector(2u, 3u));
- Node four = d_nm->mkConst(BitVector(2u, 4u));
-
- Node actual0 = *setEnumerator;
- Node expected0 =
- d_nm->mkConst(EmptySet(d_nm->mkSetType(bitVector2)));
- TS_ASSERT_EQUALS(expected0, actual0);
- TS_ASSERT(!setEnumerator.isFinished());
- Node actual1 = *++setEnumerator;
- Node expected1 = d_nm->mkNode(Kind::SINGLETON, zero);
- TS_ASSERT_EQUALS(expected1, actual1);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual2 = *++setEnumerator;
- Node expected2 = d_nm->mkNode(Kind::SINGLETON, one);
- TS_ASSERT_EQUALS(expected2, actual2);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual3 = *++setEnumerator;
- Node expected3 = d_nm->mkNode(Kind::UNION, expected1, expected2);
- TS_ASSERT_EQUALS(expected3, actual3);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual4 = *++setEnumerator;
- Node expected4 = d_nm->mkNode(Kind::SINGLETON, two);
- TS_ASSERT_EQUALS(expected4, actual4);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual5 = *++setEnumerator;
- Node expected5 = d_nm->mkNode(Kind::UNION, expected1, expected4);
- TS_ASSERT_EQUALS(expected5, actual5);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual6 = *++setEnumerator;
- Node expected6 = d_nm->mkNode(Kind::UNION, expected2, expected4);
- TS_ASSERT_EQUALS(expected6, actual6);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual7 = *++setEnumerator;
- Node expected7 = d_nm->mkNode(Kind::UNION, expected3, expected4);
- TS_ASSERT_EQUALS(expected7, actual7);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual8 = *++setEnumerator;
- Node expected8 = d_nm->mkNode(Kind::SINGLETON, three);
- TS_ASSERT_EQUALS(expected8, actual8);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual9 = *++setEnumerator;
- Node expected9 = d_nm->mkNode(Kind::UNION, expected1, expected8);
- TS_ASSERT_EQUALS(expected9, actual9);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual10 = *++setEnumerator;
- Node expected10 = d_nm->mkNode(Kind::UNION, expected2, expected8);
- TS_ASSERT_EQUALS(expected10, actual10);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual11 = *++setEnumerator;
- Node expected11 = d_nm->mkNode(Kind::UNION, expected3, expected8);
- TS_ASSERT_EQUALS(expected11, actual11);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual12 = *++setEnumerator;
- Node expected12 = d_nm->mkNode(Kind::UNION, expected4, expected8);
- TS_ASSERT_EQUALS(expected12, actual12);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual13 = *++setEnumerator;
- Node expected13 = d_nm->mkNode(Kind::UNION, expected5, expected8);
- TS_ASSERT_EQUALS(expected13, actual13);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual14 = *++setEnumerator;
- Node expected14 = d_nm->mkNode(Kind::UNION, expected6, expected8);
- TS_ASSERT_EQUALS(expected14, actual14);
- TS_ASSERT(!setEnumerator.isFinished());
-
- Node actual15 = *++setEnumerator;
- Node expected15 = d_nm->mkNode(Kind::UNION, expected7, expected8);
- TS_ASSERT_EQUALS(expected15, actual15);
- TS_ASSERT(!setEnumerator.isFinished());
+ std::vector<Node> elems;
+ for (unsigned i = 0; i < 16; i++)
+ {
+ Node actual = *setEnumerator;
+ addAndCheckUnique(actual, elems);
+ TS_ASSERT(!setEnumerator.isFinished());
+ ++setEnumerator;
+ }
TS_ASSERT_THROWS(*++setEnumerator, NoMoreValuesException&);
TS_ASSERT(setEnumerator.isFinished());