mem-cache: Add a masked pattern to compressors
[gem5.git] / src / mem / cache / compressors / dictionary_compressor.hh
index 87e69ccc890441fcca3009d1ac189f14eeeca053..3c828f040f64ee7616e543ad733c3f34ec2ee275 100644 (file)
 #include "base/types.hh"
 #include "mem/cache/compressors/base.hh"
 
-struct DictionaryCompressorParams;
+struct BaseDictionaryCompressorParams;
 
-class DictionaryCompressor : public BaseCacheCompressor
+class BaseDictionaryCompressor : public BaseCacheCompressor
 {
   protected:
+    /** Dictionary size. */
+    const std::size_t dictionarySize;
+
+    /** Number of valid entries in the dictionary. */
+    std::size_t numEntries;
+
+    /**
+     * @defgroup CompressionStats Compression specific statistics.
+     * @{
+     */
+
+    /** Number of data entries that were compressed to each pattern. */
+    Stats::Vector patternStats;
+
+    /**
+     * @}
+     */
+
+    /**
+     * Trick function to get the number of patterns.
+     *
+     * @return The number of defined patterns.
+     */
+    virtual uint64_t getNumPatterns() const = 0;
+
+    /**
+     * Get meta-name assigned to the given pattern.
+     *
+     * @param number The number of the pattern.
+     * @return The meta-name of the pattern.
+     */
+    virtual std::string getName(int number) const = 0;
+
+  public:
+    typedef BaseDictionaryCompressorParams Params;
+    BaseDictionaryCompressor(const Params *p);
+    ~BaseDictionaryCompressor() = default;
+
+    void regStats() override;
+};
+
+/**
+ * A template version of the dictionary compressor that allows to choose the
+ * dictionary size.
+ *
+ * @tparam The type of a dictionary entry (e.g., uint16_t, uint32_t, etc).
+ */
+template <class T>
+class DictionaryCompressor : public BaseDictionaryCompressor
+{
+  protected:
+    /** Convenience typedef for a dictionary entry. */
+    typedef std::array<uint8_t, sizeof(T)> DictionaryEntry;
+
     /**
      * Compression data for the dictionary compressor. It consists of a vector
      * of patterns.
@@ -68,6 +122,9 @@ class DictionaryCompressor : public BaseCacheCompressor
 
     // Forward declaration of a pattern
     class Pattern;
+    class UncompressedPattern;
+    template <T mask>
+    class MaskedPattern;
 
     /**
      * Create a factory to determine if input matches a pattern. The if else
@@ -78,8 +135,8 @@ class DictionaryCompressor : public BaseCacheCompressor
     struct Factory
     {
         static std::unique_ptr<Pattern> getPattern(
-            const std::array<uint8_t, 4>& bytes,
-            const std::array<uint8_t, 4>& dict_bytes, const int match_location)
+            const DictionaryEntry& bytes, const DictionaryEntry& dict_bytes,
+            const int match_location)
         {
             // If match this pattern, instantiate it. If a negative match
             // location is used, the patterns that use the dictionary bytes
@@ -96,65 +153,37 @@ class DictionaryCompressor : public BaseCacheCompressor
         }
     };
 
-    /** Specialization to end the recursion. */
+    /**
+     * Specialization to end the recursion. This must be called when all
+     * other patterns failed, and there is no choice but to leave data
+     * uncompressed. As such, this pattern must inherit from the uncompressed
+     * pattern.
+     */
     template <class Head>
     struct Factory<Head>
     {
-        static std::unique_ptr<Pattern> getPattern(
-            const std::array<uint8_t, 4>& bytes,
-            const std::array<uint8_t, 4>& dict_bytes, const int match_location)
+        static_assert(std::is_base_of<UncompressedPattern, Head>::value,
+            "The last pattern must always be derived from the uncompressed "
+            "pattern.");
+
+        static std::unique_ptr<Pattern>
+        getPattern(const DictionaryEntry& bytes,
+            const DictionaryEntry& dict_bytes, const int match_location)
         {
-            // Instantiate last pattern. Should be the XXXX pattern.
             return std::unique_ptr<Pattern>(new Head(bytes, match_location));
         }
     };
 
     /** The dictionary. */
-    std::vector<std::array<uint8_t, 4>> dictionary;
-
-    /** Dictionary size. */
-    const std::size_t dictionarySize;
-
-    /** Number of valid entries in the dictionary. */
-    std::size_t numEntries;
-
-    /**
-     * @defgroup CompressionStats Compression specific statistics.
-     * @{
-     */
-
-    /**
-     * Number of data entries that were compressed to each pattern.
-     */
-    Stats::Vector patternStats;
-
-    /**
-     * @}
-     */
-
-    /**
-     * Trick function to get the number of patterns.
-     *
-     * @return The number of defined patterns.
-     */
-    virtual uint64_t getNumPatterns() const = 0;
-
-    /**
-     * Get meta-name assigned to the given pattern.
-     *
-     * @param number The number of the pattern.
-     * @return The meta-name of the pattern.
-     */
-    virtual std::string getName(int number) const = 0;
+    std::vector<DictionaryEntry> dictionary;
 
     /**
      * Since the factory cannot be instantiated here, classes that inherit
      * from this base class have to implement the call to their factory's
      * getPattern.
      */
-    virtual std::unique_ptr<Pattern> getPattern(
-        const std::array<uint8_t, 4>& bytes,
-        const std::array<uint8_t, 4>& dict_bytes,
+    virtual std::unique_ptr<Pattern>
+    getPattern(const DictionaryEntry& bytes, const DictionaryEntry& dict_bytes,
         const int match_location) const = 0;
 
     /**
@@ -163,15 +192,15 @@ class DictionaryCompressor : public BaseCacheCompressor
      * @param data Data to be compressed.
      * @return The pattern this data matches.
      */
-    std::unique_ptr<Pattern> compressWord(const uint32_t data);
+    std::unique_ptr<Pattern> compressValue(const T data);
 
     /**
-     * Decompress a word.
+     * Decompress a pattern into a value that fits in a dictionary entry.
      *
      * @param pattern The pattern to be decompressed.
      * @return The decompressed word.
      */
-    uint32_t decompressWord(const Pattern* pattern);
+    T decompressValue(const Pattern* pattern);
 
     /** Clear all dictionary entries. */
     void resetDictionary();
@@ -181,7 +210,7 @@ class DictionaryCompressor : public BaseCacheCompressor
      *
      * @param data The new entry.
      */
-    virtual void addToDictionary(std::array<uint8_t, 4> data) = 0;
+    virtual void addToDictionary(const DictionaryEntry data) = 0;
 
     /**
      * Apply compression.
@@ -200,18 +229,26 @@ class DictionaryCompressor : public BaseCacheCompressor
      */
     void decompress(const CompressionData* comp_data, uint64_t* data) override;
 
-  public:
-    /** Convenience typedef. */
-    typedef DictionaryCompressorParams Params;
-
-    /** Default constructor. */
-    DictionaryCompressor(const Params *p);
+    /**
+     * Turn a value into a dictionary entry.
+     *
+     * @param value The value to turn.
+     * @return A dictionary entry containing the value.
+     */
+    static DictionaryEntry toDictionaryEntry(T value);
 
-    /** Default destructor. */
-    ~DictionaryCompressor() {};
+    /**
+     * Turn a dictionary entry into a value.
+     *
+     * @param The dictionary entry to turn.
+     * @return The value that the dictionary entry contained.
+     */
+    static T fromDictionaryEntry(const DictionaryEntry& entry);
 
-    /** Register local statistics. */
-    void regStats() override;
+  public:
+    typedef BaseDictionaryCompressorParams Params;
+    DictionaryCompressor(const Params *p);
+    ~DictionaryCompressor() = default;
 };
 
 /**
@@ -220,7 +257,8 @@ class DictionaryCompressor : public BaseCacheCompressor
  * decompress(). Then the new pattern must be added to the PatternFactory
  * declaration in crescent order of size (in the DictionaryCompressor class).
  */
-class DictionaryCompressor::Pattern
+template <class T>
+class DictionaryCompressor<T>::Pattern
 {
   protected:
     /** Pattern enum number. */
@@ -322,25 +360,124 @@ class DictionaryCompressor::Pattern
      * @param dict_bytes The bytes in the corresponding matching entry.
      * @return The decompressed pattern.
      */
-    virtual std::array<uint8_t, 4> decompress(
-        const std::array<uint8_t, 4> dict_bytes) const = 0;
+    virtual DictionaryEntry decompress(
+        const DictionaryEntry dict_bytes) const = 0;
 };
 
-class DictionaryCompressor::CompData : public CompressionData
+template <class T>
+class DictionaryCompressor<T>::CompData : public CompressionData
 {
   public:
     /** The patterns matched in the original line. */
     std::vector<std::unique_ptr<Pattern>> entries;
 
+    CompData();
+    ~CompData() = default;
+
     /**
-     * Default constructor.
+     * Add a pattern entry to the list of patterns.
      *
-     * @param dictionary_size Number of entries in the dictionary.
+     * @param entry The new pattern entry.
      */
-    CompData(const std::size_t dictionary_size);
+    virtual void addEntry(std::unique_ptr<Pattern>);
+};
 
-    /** Default destructor. */
-    ~CompData();
+/**
+ * A pattern containing the original uncompressed data. This should be the
+ * worst case of every pattern factory, where if all other patterns fail,
+ * an instance of this pattern is created.
+ */
+template <class T>
+class DictionaryCompressor<T>::UncompressedPattern
+    : public DictionaryCompressor<T>::Pattern
+{
+  private:
+    /** A copy of the original data. */
+    const DictionaryEntry data;
+
+  public:
+    UncompressedPattern(const int number,
+        const uint64_t code,
+        const uint64_t metadata_length,
+        const int match_location,
+        const DictionaryEntry bytes)
+      : DictionaryCompressor<T>::Pattern(number, code, metadata_length,
+            sizeof(T), match_location, true),
+        data(bytes)
+    {
+    }
+
+    static bool
+    isPattern(const DictionaryEntry& bytes, const DictionaryEntry& dict_bytes,
+        const int match_location)
+    {
+        // An entry can always be uncompressed
+        return true;
+    }
+
+    DictionaryEntry
+    decompress(const DictionaryEntry dict_bytes) const override
+    {
+        return data;
+    }
+};
+
+/**
+ * A pattern that compares masked values against dictionary entries. If
+ * the masked dictionary entry matches perfectly the masked value to be
+ * compressed, there is a pattern match.
+ *
+ * For example, if the mask is 0xFF00 (that is, this pattern matches the MSB),
+ * the value (V) 0xFF20 is being compressed, and the dictionary contains
+ * the value (D) 0xFF03, this is a match (V & mask == 0xFF00 == D & mask),
+ * and 0x0020 is added to the list of unmatched bits.
+ *
+ * @tparam mask A mask containing the bits that must match.
+ */
+template <class T>
+template <T mask>
+class DictionaryCompressor<T>::MaskedPattern
+    : public DictionaryCompressor<T>::Pattern
+{
+  private:
+    static_assert(mask != 0, "The pattern's value mask must not be zero. Use "
+        "the uncompressed pattern instead.");
+
+    /** A copy of the bits that do not belong to the mask. */
+    const T bits;
+
+  public:
+    MaskedPattern(const int number,
+        const uint64_t code,
+        const uint64_t metadata_length,
+        const int match_location,
+        const DictionaryEntry bytes,
+        const bool allocate = true)
+      : DictionaryCompressor<T>::Pattern(number, code, metadata_length,
+            popCount(~mask) / 8, match_location, allocate),
+        bits(DictionaryCompressor<T>::fromDictionaryEntry(bytes) & ~mask)
+    {
+    }
+
+    static bool
+    isPattern(const DictionaryEntry& bytes, const DictionaryEntry& dict_bytes,
+        const int match_location)
+    {
+        const T masked_bytes =
+            DictionaryCompressor<T>::fromDictionaryEntry(bytes) & mask;
+        const T masked_dict_bytes =
+            DictionaryCompressor<T>::fromDictionaryEntry(dict_bytes) & mask;
+        return (match_location >= 0) && (masked_bytes == masked_dict_bytes);
+    }
+
+    DictionaryEntry
+    decompress(const DictionaryEntry dict_bytes) const override
+    {
+        const T masked_dict_bytes =
+            DictionaryCompressor<T>::fromDictionaryEntry(dict_bytes) & mask;
+        return DictionaryCompressor<T>::toDictionaryEntry(
+            bits | masked_dict_bytes);
+    }
 };
 
 #endif //__MEM_CACHE_COMPRESSORS_DICTIONARY_COMPRESSOR_HH__