From 5cef06bd2beff38a911c74ec082d9789eed83421 Mon Sep 17 00:00:00 2001 From: MikolasJanota Date: Sat, 4 Sep 2021 00:59:46 +0200 Subject: [PATCH] Avoiding duplicate search in maps (#7055) This commit identifies a couple of scenarios where an operation required 2 searches into a map/hashmap and replaces them with a single search. This also makes the code shorter. Signed-off-by: Mikolas Janota --- src/context/cdhashmap.h | 185 ++++++++++++++----------- src/theory/quantifiers/instantiate.cpp | 14 +- src/theory/quantifiers/term_pools.cpp | 4 +- 3 files changed, 109 insertions(+), 94 deletions(-) diff --git a/src/context/cdhashmap.h b/src/context/cdhashmap.h index 5ea1fc5a7..16c3a09e3 100644 --- a/src/context/cdhashmap.h +++ b/src/context/cdhashmap.h @@ -97,7 +97,8 @@ namespace context { // Auxiliary class: almost the same as CDO (see cdo.h) template > -class CDOhash_map : public ContextObj { +class CDOhash_map : public ContextObj +{ friend class CDHashMap; public: @@ -130,14 +131,16 @@ class CDOhash_map : public ContextObj { ContextObj* save(ContextMemoryManager* pCMM) override { - return new(pCMM) CDOhash_map(*this); + return new (pCMM) CDOhash_map(*this); } void restore(ContextObj* data) override { CDOhash_map* p = static_cast(data); - if(d_map != NULL) { - if(p->d_map == NULL) { + if (d_map != NULL) + { + if (p->d_map == NULL) + { Assert(d_map->d_map.find(getKey()) != d_map->d_map.end() && (*d_map->d_map.find(getKey())).second == this); // no longer in map (popped beyond first level in which it was) @@ -146,16 +149,24 @@ class CDOhash_map : public ContextObj { // put it on a "trash heap" instead, for later deletion. // // FIXME multithreading - if(d_map->d_first == this) { - Debug("gc") << "remove first-elem " << this << " from map " << d_map << " with next-elem " << d_next << std::endl; - if(d_next == this) { + if (d_map->d_first == this) + { + Debug("gc") << "remove first-elem " << this << " from map " << d_map + << " with next-elem " << d_next << std::endl; + if (d_next == this) + { Assert(d_prev == this); d_map->d_first = NULL; - } else { + } + else + { d_map->d_first = d_next; } - } else { - Debug("gc") << "remove nonfirst-elem " << this << " from map " << d_map << std::endl; + } + else + { + Debug("gc") << "remove nonfirst-elem " << this << " from map " + << d_map << std::endl; } d_next->d_prev = d_prev; d_prev->d_next = d_next; @@ -163,7 +174,9 @@ class CDOhash_map : public ContextObj { Debug("gc") << "CDHashMap<> trash push_back " << this << std::endl; // this->deleteSelf(); enqueueToGarbageCollect(); - } else { + } + else + { mutable_data() = p->get(); } } @@ -194,13 +207,16 @@ class CDOhash_map : public ContextObj { bool atLevelZero = false) : ContextObj(false, context), d_value(key, data), d_map(NULL) { - if(atLevelZero) { + if (atLevelZero) + { // "Initializing" map insertion: this entry will never be // removed from the map, it's inserted at level 0 as an // "initializing" element. See // CDHashMap<>::insertAtContextLevelZero(). mutable_data() = data; - } else { + } + else + { // Normal map insertion: first makeCurrent(), then set the data // and then, later, the map. Order is important; we can't // initialize d_map in the constructor init list above, because @@ -212,11 +228,17 @@ class CDOhash_map : public ContextObj { d_map = map; CDOhash_map*& first = d_map->d_first; - if(first == NULL) { + if (first == NULL) + { first = d_next = d_prev = this; - Debug("gc") << "add first-elem " << this << " to map " << d_map << std::endl; - } else { - Debug("gc") << "add nonfirst-elem " << this << " to map " << d_map << " with first-elem " << first << "[" << first->d_prev << " " << first->d_next << std::endl; + Debug("gc") << "add first-elem " << this << " to map " << d_map + << std::endl; + } + else + { + Debug("gc") << "add nonfirst-elem " << this << " to map " << d_map + << " with first-elem " << first << "[" << first->d_prev << " " + << first->d_next << std::endl; d_prev = first->d_prev; d_next = first; d_prev->d_next = this; @@ -224,11 +246,10 @@ class CDOhash_map : public ContextObj { } } - ~CDOhash_map() { - destroy(); - } + ~CDOhash_map() { destroy(); } - void set(const Data& data) { + void set(const Data& data) + { makeCurrent(); mutable_data() = data; } @@ -239,24 +260,26 @@ class CDOhash_map : public ContextObj { const value_type& getValue() const { return d_value; } - operator Data() { - return get(); - } + operator Data() { return get(); } - const Data& operator=(const Data& data) { + const Data& operator=(const Data& data) + { set(data); return data; } - CDOhash_map* next() const { - if(d_next == d_map->d_first) { + CDOhash_map* next() const + { + if (d_next == d_map->d_first) + { return NULL; - } else { + } + else + { return d_next; } } -};/* class CDOhash_map<> */ - +}; /* class CDOhash_map<> */ /** * Generic templated class for a map which must be saved and restored @@ -264,8 +287,8 @@ class CDOhash_map : public ContextObj { * defined for the data class, and operator== for the key class. */ template -class CDHashMap : public ContextObj { - +class CDHashMap : public ContextObj +{ typedef CDOhash_map Element; typedef std::unordered_map table_type; @@ -290,11 +313,14 @@ class CDHashMap : public ContextObj { CDHashMap(const CDHashMap&) = delete; CDHashMap& operator=(const CDHashMap&) = delete; -public: + public: CDHashMap(Context* context) - : ContextObj(context), d_map(), d_first(NULL), d_context(context) {} + : ContextObj(context), d_map(), d_first(NULL), d_context(context) + { + } - ~CDHashMap() { + ~CDHashMap() + { Debug("gc") << "cdhashmap" << this << " disappearing, destroying..." << std::endl; destroy(); @@ -303,12 +329,14 @@ public: clear(); } - void clear() { + void clear() + { Debug("gc") << "clearing cdhashmap" << this << ", emptying trash" << std::endl; Debug("gc") << "done emptying trash for " << this << std::endl; - for (auto& key_element_pair : d_map) { + for (auto& key_element_pair : d_map) + { // mark it as being a destruction (short-circuit restore()) Element* element = key_element_pair.second; element->d_map = nullptr; @@ -320,43 +348,35 @@ public: // The usual operators of map - size_t size() const { - return d_map.size(); - } + size_t size() const { return d_map.size(); } - bool empty() const { - return d_map.empty(); - } + bool empty() const { return d_map.empty(); } - size_t count(const Key& k) const { - return d_map.count(k); - } + size_t count(const Key& k) const { return d_map.count(k); } // If a key is not present, a new object is created and inserted - Element& operator[](const Key& k) { - typename table_type::iterator i = d_map.find(k); - - Element* obj; - if(i == d_map.end()) {// create new object - obj = new(true) Element(d_context, this, k, Data()); - d_map.insert(std::make_pair(k, obj)); - } else { - obj = (*i).second; + Element& operator[](const Key& k) + { + const auto res = d_map.insert({k, nullptr}); + if (res.second) + { // create new object + res.first->second = new (true) Element(d_context, this, k, Data()); } - return *obj; + return *(res.first->second); } - bool insert(const Key& k, const Data& d) { - typename table_type::iterator i = d_map.find(k); - - if(i == d_map.end()) {// create new object - Element* obj = new(true) Element(d_context, this, k, d); - d_map.insert(std::make_pair(k, obj)); - return true; - } else { - (*i).second->set(d); - return false; + bool insert(const Key& k, const Data& d) + { + const auto res = d_map.insert({k, nullptr}); + if (res.second) + { // create new object + res.first->second = new (true) Element(d_context, this, k, d); } + else + { + res.first->second->set(d); + } + return res.second; } /** @@ -383,11 +403,12 @@ public: * It is an error (checked via AlwaysAssert()) to * insertAtContextLevelZero() a key that already is in the map. */ - void insertAtContextLevelZero(const Key& k, const Data& d) { + void insertAtContextLevelZero(const Key& k, const Data& d) + { AlwaysAssert(d_map.find(k) == d_map.end()); - Element* obj = new(true) Element(d_context, this, k, d, - true /* atLevelZero */); + Element* obj = + new (true) Element(d_context, this, k, d, true /* atLevelZero */); d_map.insert(std::make_pair(k, obj)); } @@ -395,7 +416,8 @@ public: using value_type = typename CDOhash_map::value_type; - class iterator { + class iterator + { const Element* d_it; public: @@ -427,30 +449,29 @@ public: } // Postfix increment is not yet supported. - };/* class CDHashMap<>::iterator */ + }; /* class CDHashMap<>::iterator */ typedef iterator const_iterator; - iterator begin() const { - return iterator(d_first); - } + iterator begin() const { return iterator(d_first); } - iterator end() const { - return iterator(NULL); - } + iterator end() const { return iterator(NULL); } - iterator find(const Key& k) const { + iterator find(const Key& k) const + { typename table_type::const_iterator i = d_map.find(k); - if(i == d_map.end()) { + if (i == d_map.end()) + { return end(); - } else { + } + else + { return iterator((*i).second); } } - -};/* class CDHashMap<> */ +}; /* class CDHashMap<> */ } // namespace context } // namespace cvc5 diff --git a/src/theory/quantifiers/instantiate.cpp b/src/theory/quantifiers/instantiate.cpp index be1633d16..0daf53d2d 100644 --- a/src/theory/quantifiers/instantiate.cpp +++ b/src/theory/quantifiers/instantiate.cpp @@ -589,19 +589,13 @@ bool Instantiate::recordInstantiationInternal(Node q, std::vector& terms) { Trace("inst-add-debug") << "Adding into context-dependent inst trie" << std::endl; - CDInstMatchTrie* imt; - std::map::iterator it = d_c_inst_match_trie.find(q); - if (it != d_c_inst_match_trie.end()) - { - imt = it->second; - } - else + const auto res = d_c_inst_match_trie.insert({q, nullptr}); + if (res.second) { - imt = new CDInstMatchTrie(d_qstate.getUserContext()); - d_c_inst_match_trie[q] = imt; + res.first->second = new CDInstMatchTrie(d_qstate.getUserContext()); } d_c_inst_match_trie_dom.insert(q); - return imt->addInstMatch(d_qstate, q, terms); + return res.first->second->addInstMatch(d_qstate, q, terms); } Trace("inst-add-debug") << "Adding into inst trie" << std::endl; return d_inst_match_trie[q].addInstMatch(d_qstate, q, terms); diff --git a/src/theory/quantifiers/term_pools.cpp b/src/theory/quantifiers/term_pools.cpp index 883161f1a..2d95c8b20 100644 --- a/src/theory/quantifiers/term_pools.cpp +++ b/src/theory/quantifiers/term_pools.cpp @@ -103,9 +103,9 @@ void TermPools::getTermsForPool(Node p, std::vector& terms) for (const Node& t : dom.d_terms) { Node r = d_qs.getRepresentative(t); - if (reps.find(r) == reps.end()) + const auto i = reps.insert(r); + if (i.second) { - reps.insert(r); dom.d_currTerms.push_back(t); } } -- 2.30.2