improved bitpattern (proc_mux) performance
authorClifford Wolf <clifford@clifford.at>
Wed, 31 Dec 2014 12:15:35 +0000 (13:15 +0100)
committerClifford Wolf <clifford@clifford.at>
Wed, 31 Dec 2014 12:15:35 +0000 (13:15 +0100)
kernel/bitpattern.h

index 7416a488d79a360c44aac1f61b6fe19a67dec9f8..00bbc3bfb50fdc7d6dbd58411f15bb8bdedf57c9 100644 (file)
@@ -28,14 +28,34 @@ YOSYS_NAMESPACE_BEGIN
 struct BitPatternPool
 {
        int width;
-       typedef std::vector<RTLIL::State> bits_t;
+       struct bits_t {
+               std::vector<RTLIL::State> bitdata;
+               unsigned int cached_hash;
+               bits_t(int width = 0) : bitdata(width), cached_hash(0) { }
+               RTLIL::State &operator[](int index) {
+                       return bitdata[index];
+               }
+               const RTLIL::State &operator[](int index) const {
+                       return bitdata[index];
+               }
+               bool operator==(const bits_t &other) const {
+                       if (hash() != other.hash())
+                               return false;
+                       return bitdata == other.bitdata;
+               }
+               unsigned int hash() const {
+                       if (!cached_hash)
+                               ((bits_t*)this)->cached_hash = hash_ops<std::vector<RTLIL::State>>::hash(bitdata);
+                       return cached_hash;
+               }
+       };
        pool<bits_t> database;
 
        BitPatternPool(RTLIL::SigSpec sig)
        {
                width = sig.size();
                if (width > 0) {
-                       std::vector<RTLIL::State> pattern(width);
+                       bits_t pattern(width);
                        for (int i = 0; i < width; i++) {
                                if (sig[i].wire == NULL && sig[i].data <= RTLIL::State::S1)
                                        pattern[i] = sig[i].data;
@@ -50,7 +70,7 @@ struct BitPatternPool
        {
                this->width = width;
                if (width > 0) {
-                       std::vector<RTLIL::State> pattern(width);
+                       bits_t pattern(width);
                        for (int i = 0; i < width; i++)
                                pattern[i] = RTLIL::State::Sa;
                        database.insert(pattern);
@@ -59,8 +79,9 @@ struct BitPatternPool
 
        bits_t sig2bits(RTLIL::SigSpec sig)
        {
-               bits_t bits = sig.as_const().bits;
-               for (auto &b : bits)
+               bits_t bits;
+               bits.bitdata = sig.as_const().bits;
+               for (auto &b : bits.bitdata)
                        if (b > RTLIL::State::S1)
                                b = RTLIL::State::Sa;
                return bits;
@@ -68,8 +89,8 @@ struct BitPatternPool
 
        bool match(bits_t a, bits_t b)
        {
-               log_assert(int(a.size()) == width);
-               log_assert(int(b.size()) == width);
+               log_assert(int(a.bitdata.size()) == width);
+               log_assert(int(b.bitdata.size()) == width);
                for (int i = 0; i < width; i++)
                        if (a[i] <= RTLIL::State::S1 && b[i] <= RTLIL::State::S1 && a[i] != b[i])
                                return false;
@@ -103,21 +124,21 @@ struct BitPatternPool
        {
                bool status = false;
                bits_t bits = sig2bits(sig);
-               std::vector<bits_t> pattern_list;
-               for (auto &it : database)
-                       if (match(it, bits))
-                               pattern_list.push_back(it);
-               for (auto pattern : pattern_list) {
-                       database.erase(pattern);
-                       for (int i = 0; i < width; i++) {
-                               if (pattern[i] != RTLIL::State::Sa || bits[i] == RTLIL::State::Sa)
-                                       continue;
-                               bits_t new_pattern = pattern;
-                               new_pattern[i] = bits[i] == RTLIL::State::S1 ? RTLIL::State::S0 : RTLIL::State::S1;
-                               database.insert(new_pattern);
-                       }
-                       status = true;
-               }
+               for (auto it = database.begin(); it != database.end();)
+                       if (match(*it, bits)) {
+                               for (int i = 0; i < width; i++) {
+                                       if ((*it)[i] != RTLIL::State::Sa || bits[i] == RTLIL::State::Sa)
+                                               continue;
+                                       bits_t new_pattern;
+                                       new_pattern.bitdata = it->bitdata;
+                                       new_pattern[i] = bits[i] == RTLIL::State::S1 ? RTLIL::State::S0 : RTLIL::State::S1;
+                                       database.insert(new_pattern);
+                               }
+                               it = database.erase(it);
+                               status = true;
+                               continue;
+                       } else
+                               ++it;
                return status;
        }