mem-cache: Use SatCounter for prefetchers
authorDaniel <odanrc@yahoo.com.br>
Thu, 11 Apr 2019 06:37:56 +0000 (08:37 +0200)
committerDaniel Carvalho <odanrc@yahoo.com.br>
Tue, 14 May 2019 07:55:06 +0000 (07:55 +0000)
Many prefetchers re-implement saturating counters with ints. Make
them use SatCounters instead.

Added missing operators and constructors to SatCounter for that to
be possible and their respective tests.

Change-Id: I36f10c89c27c9b3d1bf461e9ea546920f6ebb888
Signed-off-by: Daniel <odanrc@yahoo.com.br>
Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/17995
Tested-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Javier Bueno Hedo <javier.bueno@metempsy.com>
Maintainer: Jason Lowe-Power <jason@lowepower.com>

src/mem/cache/prefetch/Prefetcher.py
src/mem/cache/prefetch/indirect_memory.cc
src/mem/cache/prefetch/indirect_memory.hh
src/mem/cache/prefetch/irregular_stream_buffer.cc
src/mem/cache/prefetch/irregular_stream_buffer.hh
src/mem/cache/prefetch/signature_path.cc
src/mem/cache/prefetch/signature_path.hh
src/mem/cache/prefetch/signature_path_v2.cc
src/mem/cache/prefetch/spatio_temporal_memory_streaming.cc
src/mem/cache/prefetch/spatio_temporal_memory_streaming.hh

index aaa140887a725be0756886f0ef73eb295c157f63..b933b4953c799a8f25fb5aac1ff763de4fd4abcd 100644 (file)
@@ -156,8 +156,8 @@ class IndirectMemoryPrefetcher(QueuedPrefetcher):
     pt_table_replacement_policy = Param.BaseReplacementPolicy(LRURP(),
         "Replacement policy of the pattern table")
     max_prefetch_distance = Param.Unsigned(16, "Maximum prefetch distance")
-    max_indirect_counter_value = Param.Unsigned(8,
-        "Maximum value of the indirect counter")
+    num_indirect_counter_bits = Param.Unsigned(3,
+        "Number of bits of the indirect counter")
     ipd_table_entries = Param.MemorySize("4",
         "Number of entries of the Indirect Pattern Detector")
     ipd_table_assoc = Param.Unsigned(4,
@@ -197,7 +197,8 @@ class SignaturePathPrefetcher(QueuedPrefetcher):
     signature_table_replacement_policy = Param.BaseReplacementPolicy(LRURP(),
         "Replacement policy of the signature table")
 
-    max_counter_value = Param.UInt8(7, "Maximum pattern counter value")
+    num_counter_bits = Param.UInt8(3,
+        "Number of bits of the saturating counters")
     pattern_table_entries = Param.MemorySize("4096",
         "Number of entries of the pattern table")
     pattern_table_assoc = Param.Unsigned(1,
@@ -225,7 +226,7 @@ class SignaturePathPrefetcherV2(SignaturePathPrefetcher):
     signature_table_assoc = 1
     pattern_table_entries = "512"
     pattern_table_assoc = 1
-    max_counter_value = 15
+    num_counter_bits = 4
     prefetch_confidence_threshold = 0.25
     lookahead_confidence_threshold = 0.25
 
@@ -318,8 +319,8 @@ class IrregularStreamBufferPrefetcher(QueuedPrefetcher):
     cxx_class = "IrregularStreamBufferPrefetcher"
     cxx_header = "mem/cache/prefetch/irregular_stream_buffer.hh"
 
-    max_counter_value = Param.Unsigned(3,
-        "Maximum value of the confidence counter")
+    num_counter_bits = Param.Unsigned(2,
+        "Number of bits of the confidence counter")
     chunk_size = Param.Unsigned(256,
         "Maximum number of addresses in a temporal stream")
     degree = Param.Unsigned(4, "Number of prefetches to generate")
index d49652fa883f0d7090482a0dab278d64ed58fb7a..703105166eef4e6abf74eec9ca536927b727185f 100644 (file)
@@ -38,11 +38,11 @@ IndirectMemoryPrefetcher::IndirectMemoryPrefetcher(
     const IndirectMemoryPrefetcherParams *p) : QueuedPrefetcher(p),
     maxPrefetchDistance(p->max_prefetch_distance),
     shiftValues(p->shift_values), prefetchThreshold(p->prefetch_threshold),
-    maxIndirectCounterValue(p->max_indirect_counter_value),
     streamCounterThreshold(p->stream_counter_threshold),
     streamingDistance(p->streaming_distance),
     prefetchTable(p->pt_table_assoc, p->pt_table_entries,
-                  p->pt_table_indexing_policy, p->pt_table_replacement_policy),
+                  p->pt_table_indexing_policy, p->pt_table_replacement_policy,
+                  PrefetchTableEntry(p->num_indirect_counter_bits)),
     ipd(p->ipd_table_assoc, p->ipd_table_entries, p->ipd_table_indexing_policy,
         p->ipd_table_replacement_policy,
         IndirectPatternDetectorEntry(p->addr_array_len, shiftValues.size())),
@@ -135,9 +135,7 @@ IndirectMemoryPrefetcher::calculatePrefetch(const PrefetchInfo &pfi,
                         // Enabled entry, update the index
                         pt_entry->index = index;
                         if (!pt_entry->increasedIndirectCounter) {
-                            if (pt_entry->indirectCounter > 0) {
-                                pt_entry->indirectCounter -= 1;
-                            }
+                            pt_entry->indirectCounter--;
                         } else {
                             // Set this to false, to see if the new index
                             // has any match
@@ -146,8 +144,8 @@ IndirectMemoryPrefetcher::calculatePrefetch(const PrefetchInfo &pfi,
 
                         // If the counter is high enough, start prefetching
                         if (pt_entry->indirectCounter > prefetchThreshold) {
-                            unsigned distance = pt_entry->indirectCounter *
-                                maxPrefetchDistance / maxIndirectCounterValue;
+                            unsigned distance = maxPrefetchDistance *
+                                pt_entry->indirectCounter.calcSaturation();
                             for (int delta = 1; delta < distance; delta += 1) {
                                 Addr pf_addr = pt_entry->baseAddr +
                                     (pt_entry->index << pt_entry->shift);
@@ -237,7 +235,7 @@ IndirectMemoryPrefetcher::trackMissIndex2(Addr miss_addr)
                 pt_entry->baseAddr = ba_array[idx];
                 pt_entry->shift = shift;
                 pt_entry->enabled = true;
-                pt_entry->indirectCounter = 0;
+                pt_entry->indirectCounter.reset();
                 // Release the current IPD Entry
                 entry->reset();
                 // Do not track more misses
@@ -256,10 +254,8 @@ IndirectMemoryPrefetcher::checkAccessMatchOnActiveEntries(Addr addr)
         if (pt_entry.enabled) {
             if (addr == pt_entry.baseAddr +
                        (pt_entry.index << pt_entry.shift)) {
-                if (pt_entry.indirectCounter < maxIndirectCounterValue) {
-                    pt_entry.indirectCounter += 1;
-                    pt_entry.increasedIndirectCounter = true;
-                }
+                pt_entry.indirectCounter++;
+                pt_entry.increasedIndirectCounter = true;
             }
         }
     }
index b67cdfb0aececceb14177eae6515fbc63925ae33..f177c5c06f1057ea7495704038568417020cbcd6 100644 (file)
@@ -43,6 +43,7 @@
 
 #include <vector>
 
+#include "base/sat_counter.hh"
 #include "mem/cache/prefetch/associative_set.hh"
 #include "mem/cache/prefetch/queued.hh"
 
@@ -56,8 +57,6 @@ class IndirectMemoryPrefetcher : public QueuedPrefetcher
     const std::vector<int> shiftValues;
     /** Counter threshold to start prefetching */
     const unsigned int prefetchThreshold;
-    /** Maximum value of the confidence indirectCounter */
-    const unsigned int maxIndirectCounterValue;
     /** streamCounter value to trigger the streaming prefetcher */
     const int streamCounterThreshold;
     /** Number of prefetches generated when using the streaming prefetcher */
@@ -86,7 +85,7 @@ class IndirectMemoryPrefetcher : public QueuedPrefetcher
         /** Shift detected */
         int shift;
         /** Confidence counter of the indirect fields */
-        int indirectCounter;
+        SatCounter indirectCounter;
         /**
          * This variable is set to indicate that there has been at least one
          * match with the current index value. This information is later used
@@ -95,9 +94,11 @@ class IndirectMemoryPrefetcher : public QueuedPrefetcher
          */
         bool increasedIndirectCounter;
 
-        PrefetchTableEntry() : TaggedEntry(), address(0), secure(false),
-            streamCounter(0), enabled(false), index(0), baseAddr(0), shift(0),
-            indirectCounter(0), increasedIndirectCounter(false)
+        PrefetchTableEntry(unsigned indirect_counter_bits)
+            : TaggedEntry(), address(0), secure(false), streamCounter(0),
+              enabled(false), index(0), baseAddr(0), shift(0),
+              indirectCounter(indirect_counter_bits),
+              increasedIndirectCounter(false)
         {}
 
         void reset() override {
@@ -108,7 +109,7 @@ class IndirectMemoryPrefetcher : public QueuedPrefetcher
             index = 0;
             baseAddr = 0;
             shift = 0;
-            indirectCounter = 0;
+            indirectCounter.reset();
             increasedIndirectCounter = false;
         }
     };
index 345fe70601aecd118cfb80a2fc40345cdca9d8fc..73fa9eb42a541a2483796e98d5210095e781f4b6 100644 (file)
@@ -36,7 +36,7 @@
 
 IrregularStreamBufferPrefetcher::IrregularStreamBufferPrefetcher(
     const IrregularStreamBufferPrefetcherParams *p)
-    : QueuedPrefetcher(p), maxCounterValue(p->max_counter_value),
+    : QueuedPrefetcher(p),
         chunkSize(p->chunk_size),
         prefetchCandidatesPerEntry(p->prefetch_candidates_per_entry),
         degree(p->degree),
@@ -47,12 +47,14 @@ IrregularStreamBufferPrefetcher::IrregularStreamBufferPrefetcher(
                               p->address_map_cache_entries,
                               p->ps_address_map_cache_indexing_policy,
                               p->ps_address_map_cache_replacement_policy,
-                              AddressMappingEntry(prefetchCandidatesPerEntry)),
+                              AddressMappingEntry(prefetchCandidatesPerEntry,
+                                                  p->num_counter_bits)),
         spAddressMappingCache(p->address_map_cache_assoc,
                               p->address_map_cache_entries,
                               p->sp_address_map_cache_indexing_policy,
                               p->sp_address_map_cache_replacement_policy,
-                              AddressMappingEntry(prefetchCandidatesPerEntry)),
+                              AddressMappingEntry(prefetchCandidatesPerEntry,
+                                                  p->num_counter_bits)),
         structuralAddressCounter(0)
 {
     assert(isPowerOf2(prefetchCandidatesPerEntry));
@@ -100,30 +102,29 @@ IrregularStreamBufferPrefetcher::calculatePrefetch(const PrefetchInfo &pfi,
         if (mapping_A.counter > 0 && mapping_B.counter > 0) {
             // Entry for A and B
             if (mapping_B.address == (mapping_A.address + 1)) {
-                if (mapping_B.counter < maxCounterValue) {
-                    mapping_B.counter += 1;
-                }
+                mapping_B.counter++;
             } else {
                 if (mapping_B.counter == 1) {
-                    // counter would hit 0, reassign address
-                    mapping_B.counter = 1;
+                    // Counter would hit 0, reassign address while keeping
+                    // counter at 1
                     mapping_B.address = mapping_A.address + 1;
                     addStructuralToPhysicalEntry(mapping_B.address, is_secure,
                             correlated_addr_B);
                 } else {
-                    mapping_B.counter -= 1;
+                    mapping_B.counter--;
                 }
             }
         } else {
             if (mapping_A.counter == 0) {
                 // if A is not valid, generate a new structural address
-                mapping_A.counter = 1;
+                mapping_A.counter++;
                 mapping_A.address = structuralAddressCounter;
                 structuralAddressCounter += chunkSize;
                 addStructuralToPhysicalEntry(mapping_A.address,
                         is_secure, correlated_addr_A);
             }
-            mapping_B.counter = 1;
+            mapping_B.counter.reset();
+            mapping_B.counter++;
             mapping_B.address = mapping_A.address + 1;
             // update SP-AMC
             addStructuralToPhysicalEntry(mapping_B.address, is_secure,
@@ -203,7 +204,8 @@ IrregularStreamBufferPrefetcher::addStructuralToPhysicalEntry(
     }
     AddressMapping &mapping = sp_entry->mappings[map_index];
     mapping.address = physical_address;
-    mapping.counter = 1;
+    mapping.counter.reset();
+    mapping.counter++;
 }
 
 IrregularStreamBufferPrefetcher*
index 47038cbb7152e493f6e53d2f9cf102657b9a8809..c97fde84db48b28e68010cca6740fa9f80052616 100644 (file)
@@ -41,6 +41,7 @@
 #define __MEM_CACHE_PREFETCH_IRREGULAR_STREAM_BUFFER_HH__
 
 #include "base/callback.hh"
+#include "base/sat_counter.hh"
 #include "mem/cache/prefetch/associative_set.hh"
 #include "mem/cache/prefetch/queued.hh"
 
@@ -48,8 +49,6 @@ struct IrregularStreamBufferPrefetcherParams;
 
 class IrregularStreamBufferPrefetcher : public QueuedPrefetcher
 {
-    /** Maximum value of the confidence counters */
-    const unsigned maxCounterValue;
     /** Size in bytes of a temporal stream */
     const size_t chunkSize;
     /** Number of prefetch candidates per Physical-to-Structural entry */
@@ -71,8 +70,8 @@ class IrregularStreamBufferPrefetcher : public QueuedPrefetcher
     /** Address Mapping entry, holds an address and a confidence counter */
     struct AddressMapping {
         Addr address;
-        unsigned counter;
-        AddressMapping() : address(0), counter(0)
+        SatCounter counter;
+        AddressMapping(unsigned bits) : address(0), counter(bits)
         {}
     };
 
@@ -82,13 +81,14 @@ class IrregularStreamBufferPrefetcher : public QueuedPrefetcher
      */
     struct AddressMappingEntry : public TaggedEntry {
         std::vector<AddressMapping> mappings;
-        AddressMappingEntry(size_t num_mappings) : mappings(num_mappings)
+        AddressMappingEntry(size_t num_mappings, unsigned counter_bits)
+            : mappings(num_mappings, counter_bits)
         {}
         void reset() override
         {
             for (auto &entry : mappings) {
                 entry.address = 0;
-                entry.counter = 0;
+                entry.counter.reset();
             }
         }
     };
index 857354e656156ba44a82956e91bb73eb08c95490..febc47132bdba94eecf34f063a3becd1966f0d07 100644 (file)
@@ -31,6 +31,7 @@
 #include "mem/cache/prefetch/signature_path.hh"
 
 #include <cassert>
+#include <climits>
 
 #include "debug/HWPrefetch.hh"
 #include "mem/cache/prefetch/associative_set_impl.hh"
@@ -42,7 +43,6 @@ SignaturePathPrefetcher::SignaturePathPrefetcher(
       stridesPerPatternEntry(p->strides_per_pattern_entry),
       signatureShift(p->signature_shift),
       signatureBits(p->signature_bits),
-      maxCounterValue(p->max_counter_value),
       prefetchConfidenceThreshold(p->prefetch_confidence_threshold),
       lookaheadConfidenceThreshold(p->lookahead_confidence_threshold),
       signatureTable(p->signature_table_assoc, p->signature_table_entries,
@@ -51,7 +51,7 @@ SignaturePathPrefetcher::SignaturePathPrefetcher(
       patternTable(p->pattern_table_assoc, p->pattern_table_entries,
                    p->pattern_table_indexing_policy,
                    p->pattern_table_replacement_policy,
-                   PatternEntry(stridesPerPatternEntry))
+                   PatternEntry(stridesPerPatternEntry, p->num_counter_bits))
 {
     fatal_if(prefetchConfidenceThreshold < 0,
         "The prefetch confidence threshold must be greater than 0\n");
@@ -64,8 +64,7 @@ SignaturePathPrefetcher::SignaturePathPrefetcher(
 }
 
 SignaturePathPrefetcher::PatternStrideEntry &
-SignaturePathPrefetcher::PatternEntry::getStrideEntry(stride_t stride,
-                                                     uint8_t max_counter_value)
+SignaturePathPrefetcher::PatternEntry::getStrideEntry(stride_t stride)
 {
     PatternStrideEntry *pstride_entry = findStride(stride);
     if (pstride_entry == nullptr) {
@@ -76,18 +75,16 @@ SignaturePathPrefetcher::PatternEntry::getStrideEntry(stride_t stride,
         // If all counters have the max value, this will be the pick
         PatternStrideEntry *victim_pstride_entry = &(strideEntries[0]);
 
-        uint8_t current_counter = max_counter_value;
+        unsigned long current_counter = ULONG_MAX;
         for (auto &entry : strideEntries) {
             if (entry.counter < current_counter) {
                 victim_pstride_entry = &entry;
                 current_counter = entry.counter;
             }
-            if (entry.counter > 0) {
-                entry.counter -= 1;
-            }
+            entry.counter--;
         }
         pstride_entry = victim_pstride_entry;
-        pstride_entry->counter = 0;
+        pstride_entry->counter.reset();
         pstride_entry->stride = stride;
     }
     return *pstride_entry;
@@ -147,9 +144,7 @@ void
 SignaturePathPrefetcher::increasePatternEntryCounter(
         PatternEntry &pattern_entry, PatternStrideEntry &pstride_entry)
 {
-    if (pstride_entry.counter < maxCounterValue) {
-        pstride_entry.counter += 1;
-    }
+    pstride_entry.counter++;
 }
 
 void
@@ -158,8 +153,7 @@ SignaturePathPrefetcher::updatePatternTable(Addr signature, stride_t stride)
     assert(stride != 0);
     // The pattern table is indexed by signatures
     PatternEntry &p_entry = getPatternEntry(signature);
-    PatternStrideEntry &ps_entry = p_entry.getStrideEntry(stride,
-                                                          maxCounterValue);
+    PatternStrideEntry &ps_entry = p_entry.getStrideEntry(stride);
     increasePatternEntryCounter(p_entry, ps_entry);
 }
 
@@ -209,23 +203,21 @@ double
 SignaturePathPrefetcher::calculatePrefetchConfidence(PatternEntry const &sig,
         PatternStrideEntry const &entry) const
 {
-    return ((double) entry.counter) / maxCounterValue;
+    return entry.counter.calcSaturation();
 }
 
 double
 SignaturePathPrefetcher::calculateLookaheadConfidence(PatternEntry const &sig,
         PatternStrideEntry const &lookahead) const
 {
-    double lookahead_confidence;
-    if (lookahead.counter == maxCounterValue) {
+    double lookahead_confidence = lookahead.counter.calcSaturation();
+    if (lookahead_confidence > 0.95) {
         /**
          * maximum confidence is 0.95, guaranteeing that
          * current confidence will eventually fall beyond
          * the threshold
          */
         lookahead_confidence = 0.95;
-    } else {
-        lookahead_confidence = ((double) lookahead.counter / maxCounterValue);
     }
     return lookahead_confidence;
 }
@@ -280,7 +272,7 @@ SignaturePathPrefetcher::calculatePrefetch(const PrefetchInfo &pfi,
             patternTable.findEntry(current_signature, false);
         PatternStrideEntry const *lookahead = nullptr;
         if (current_pattern_entry != nullptr) {
-            uint8_t max_counter = 0;
+            unsigned long max_counter = 0;
             for (auto const &entry : current_pattern_entry->strideEntries) {
                 //select the entry with the maximum counter value as lookahead
                 if (max_counter < entry.counter) {
index 974c027461991e0e95fc3114aef816a8abde83ab..3bf4dd29319a6631c1380fbca19c0dea3f7c912c 100644 (file)
@@ -42,6 +42,7 @@
 #ifndef __MEM_CACHE_PREFETCH_SIGNATURE_PATH_HH__
 #define __MEM_CACHE_PREFETCH_SIGNATURE_PATH_HH__
 
+#include "base/sat_counter.hh"
 #include "mem/cache/prefetch/associative_set.hh"
 #include "mem/cache/prefetch/queued.hh"
 #include "mem/packet.hh"
@@ -62,8 +63,6 @@ class SignaturePathPrefetcher : public QueuedPrefetcher
     const uint8_t signatureShift;
     /** Size of the signature, in bits */
     const signature_t signatureBits;
-    /** Maximum pattern entries counter value */
-    const uint8_t maxCounterValue;
     /** Minimum confidence to issue a prefetch */
     const double prefetchConfidenceThreshold;
     /** Minimum confidence to keep navigating lookahead entries */
@@ -87,9 +86,9 @@ class SignaturePathPrefetcher : public QueuedPrefetcher
     {
         /** stride in a page in blkSize increments */
         stride_t stride;
-        /** counter value (max value defined by maxCounterValue) */
-        uint8_t counter;
-        PatternStrideEntry() : stride(0), counter(0)
+        /** Saturating counter */
+        SatCounter counter;
+        PatternStrideEntry(unsigned bits) : stride(0), counter(bits)
         {}
     };
     /** Pattern entry data type, a set of stride and counter entries */
@@ -98,19 +97,19 @@ class SignaturePathPrefetcher : public QueuedPrefetcher
         /** group of stides */
         std::vector<PatternStrideEntry> strideEntries;
         /** use counter, used by SPPv2 */
-        uint8_t counter;
-        PatternEntry(size_t num_strides) : strideEntries(num_strides),
-                                           counter(0)
+        SatCounter counter;
+        PatternEntry(size_t num_strides, unsigned counter_bits)
+            : strideEntries(num_strides, counter_bits), counter(counter_bits)
         {}
 
         /** Reset the entries to their initial values */
         void reset() override
         {
             for (auto &entry : strideEntries) {
-                entry.counter = 0;
+                entry.counter.reset();
                 entry.stride = 0;
             }
-            counter = 0;
+            counter.reset();
         }
 
         /**
@@ -135,13 +134,9 @@ class SignaturePathPrefetcher : public QueuedPrefetcher
          * Gets the entry with the provided stride, if there is no entry with
          * the associated stride, it replaces one of them.
          * @param stride the stride to find
-         * @param max_counter_value maximum value of the confidence counters,
-         *        it is used when no strides are found and an entry needs to be
-         *        replaced
          * @result reference to the selected entry
          */
-        PatternStrideEntry &getStrideEntry(stride_t stride,
-                                           uint8_t max_counter_value);
+        PatternStrideEntry &getStrideEntry(stride_t stride);
     };
     /** Pattern table */
     AssociativeSet<PatternEntry> patternTable;
index 571c3d12bdf49c2a7455657beac8eae6f59a3688..908e8dac6f18bf412a66f49437f916095ad806bb 100644 (file)
@@ -96,20 +96,20 @@ void
 SignaturePathPrefetcherV2::increasePatternEntryCounter(
         PatternEntry &pattern_entry, PatternStrideEntry &pstride_entry)
 {
-    if (pattern_entry.counter == maxCounterValue) {
+    if (pattern_entry.counter.isSaturated()) {
         pattern_entry.counter >>= 1;
         for (auto &entry : pattern_entry.strideEntries) {
             entry.counter >>= 1;
         }
     }
-    if (pstride_entry.counter == maxCounterValue) {
+    if (pstride_entry.counter.isSaturated()) {
         pattern_entry.counter >>= 1;
         for (auto &entry : pattern_entry.strideEntries) {
             entry.counter >>= 1;
         }
     }
-    pattern_entry.counter += 1;
-    pstride_entry.counter += 1;
+    pattern_entry.counter++;
+    pstride_entry.counter++;
 }
 
 void
index df5190a971f004fa7fa093254351f0b044990b0d..cf6144a64f9773308ee7174d242dcfb476dfef6a 100644 (file)
@@ -212,7 +212,7 @@ STeMSPrefetcher::reconstructSequence(unsigned int rmob_idx,
             patternSequenceTable.accessEntry(pst_entry);
             for (auto &seq_entry : pst_entry->sequence) {
                 if (seq_entry.counter > 1) {
-                    // 3-bit counter: high enough confidence with a
+                    // 2-bit counter: high enough confidence with a
                     // value greater than 1
                     Addr rec_addr = rmob[i].srAddress * spatialRegionSize +
                         seq_entry.offset;
index a7e25fe02610a06fa3b0aa4ad0808ca30e6b6c21..34cf5d12a121056525e4884e866b8a21a5cbaac1 100644 (file)
@@ -45,6 +45,7 @@
 
 #include <vector>
 
+#include "base/sat_counter.hh"
 #include "mem/cache/prefetch/associative_set.hh"
 #include "mem/cache/prefetch/queued.hh"
 
@@ -74,12 +75,12 @@ class STeMSPrefetcher : public QueuedPrefetcher
         /** Sequence entry data type */
         struct SequenceEntry {
             /** 2-bit confidence counter */
-            unsigned int counter;
+            SatCounter counter;
             /** Offset, in cache lines, within the spatial region */
             unsigned int offset;
             /** Intearleaving position on the global access sequence */
             unsigned int delta;
-            SequenceEntry() : counter(0), offset(0), delta(0)
+            SequenceEntry() : counter(2), offset(0), delta(0)
             {}
         };
         /** Sequence of accesses */
@@ -95,7 +96,7 @@ class STeMSPrefetcher : public QueuedPrefetcher
             pc = 0;
             seqCounter = 0;
             for (auto &seq_entry : sequence) {
-                seq_entry.counter = 0;
+                seq_entry.counter.reset();
                 seq_entry.offset = 0;
                 seq_entry.delta = 0;
             }
@@ -125,15 +126,12 @@ class STeMSPrefetcher : public QueuedPrefetcher
             for (auto &seq_entry : sequence) {
                 if (seq_entry.counter > 0) {
                     if (seq_entry.offset == offset) {
-                        //2 bit counter, saturates at 3
-                        if (seq_entry.counter < 3) {
-                            seq_entry.counter += 1;
-                        }
+                        seq_entry.counter++;
                     }
                 } else {
                     // If the counter is 0 it means that this position is not
                     // being used, and we can allocate the new offset here
-                    seq_entry.counter = 1;
+                    seq_entry.counter++;
                     seq_entry.offset = offset;
                     seq_entry.delta = seqCounter;
                     break;