mem-garnet: Integration of HeteroGarnet
[gem5.git] / src / mem / ruby / network / garnet2.0 / NetworkInterface.hh
index 7e3083844a67e14ae69e0a6a9da3de70ee37046b..945b446d7b6191eedba432a01e0338f629e950b7 100644 (file)
@@ -1,4 +1,5 @@
 /*
+ * Copyright (c) 2020 Advanced Micro Devices, Inc.
  * Copyright (c) 2020 Inria
  * Copyright (c) 2016 Georgia Institute of Technology
  * Copyright (c) 2008 Princeton University
@@ -37,6 +38,7 @@
 
 #include "mem/ruby/common/Consumer.hh"
 #include "mem/ruby/network/garnet2.0/CommonTypes.hh"
+#include "mem/ruby/network/garnet2.0/Credit.hh"
 #include "mem/ruby/network/garnet2.0/CreditLink.hh"
 #include "mem/ruby/network/garnet2.0/GarnetNetwork.hh"
 #include "mem/ruby/network/garnet2.0/NetworkLink.hh"
@@ -67,37 +69,213 @@ class NetworkInterface : public ClockedObject, public Consumer
 
     void print(std::ostream& out) const;
     int get_vnet(int vc);
-    int get_router_id() { return m_router_id; }
     void init_net_ptr(GarnetNetwork *net_ptr) { m_net_ptr = net_ptr; }
 
     uint32_t functionalWrite(Packet *);
 
+    void scheduleFlit(flit *t_flit);
+
+    int get_router_id(int vnet)
+    {
+        OutputPort *oPort = getOutportForVnet(vnet);
+        assert(oPort);
+        return oPort->routerID();
+    }
+
+    class OutputPort
+    {
+      public:
+          OutputPort(NetworkLink *outLink, CreditLink *creditLink,
+              int routerID)
+          {
+              _vnets = outLink->mVnets;
+              _outFlitQueue = new flitBuffer();
+
+              _outNetLink = outLink;
+              _inCreditLink = creditLink;
+
+              _routerID = routerID;
+              _bitWidth = outLink->bitWidth;
+              _vcRoundRobin = 0;
+
+          }
+
+          flitBuffer *
+          outFlitQueue()
+          {
+              return _outFlitQueue;
+          }
+
+          NetworkLink *
+          outNetLink()
+          {
+              return _outNetLink;
+          }
+
+          CreditLink *
+          inCreditLink()
+          {
+              return _inCreditLink;
+          }
+
+          int
+          routerID()
+          {
+              return _routerID;
+          }
+
+          uint32_t bitWidth()
+          {
+              return _bitWidth;
+          }
+
+          bool isVnetSupported(int pVnet)
+          {
+              if (!_vnets.size()) {
+                  return true;
+              }
+
+              for (auto &it : _vnets) {
+                  if (it == pVnet) {
+                      return true;
+                  }
+              }
+              return false;
+
+          }
+
+          std::string
+          printVnets()
+          {
+              std::stringstream ss;
+              for (auto &it : _vnets) {
+                  ss << it;
+                  ss << " ";
+              }
+              return ss.str();
+          }
+
+          int vcRoundRobin()
+          {
+              return _vcRoundRobin;
+          }
+
+          void vcRoundRobin(int vc)
+          {
+              _vcRoundRobin = vc;
+          }
+
+
+      private:
+          std::vector<int> _vnets;
+          flitBuffer *_outFlitQueue;
+
+          NetworkLink *_outNetLink;
+          CreditLink *_inCreditLink;
+
+          int _vcRoundRobin; // For round robin scheduling
+
+          int _routerID;
+          uint32_t _bitWidth;
+    };
+
+    class InputPort
+    {
+      public:
+          InputPort(NetworkLink *inLink, CreditLink *creditLink)
+          {
+              _vnets = inLink->mVnets;
+              _outCreditQueue = new flitBuffer();
+
+              _inNetLink = inLink;
+              _outCreditLink = creditLink;
+              _bitWidth = inLink->bitWidth;
+          }
+
+          flitBuffer *
+          outCreditQueue()
+          {
+              return _outCreditQueue;
+          }
+
+          NetworkLink *
+          inNetLink()
+          {
+              return _inNetLink;
+          }
+
+          CreditLink *
+          outCreditLink()
+          {
+              return _outCreditLink;
+          }
+
+          bool isVnetSupported(int pVnet)
+          {
+              if (!_vnets.size()) {
+                  return true;
+              }
+
+              for (auto &it : _vnets) {
+                  if (it == pVnet) {
+                      return true;
+                  }
+              }
+              return false;
+
+          }
+
+          void sendCredit(Credit *cFlit)
+          {
+              _outCreditQueue->insert(cFlit);
+          }
+
+          uint32_t bitWidth()
+          {
+              return _bitWidth;
+          }
+
+          std::string
+          printVnets()
+          {
+              std::stringstream ss;
+              for (auto &it : _vnets) {
+                  ss << it;
+                  ss << " ";
+              }
+              return ss.str();
+          }
+
+          // Queue for stalled flits
+          std::deque<flit *> m_stall_queue;
+          bool messageEnqueuedThisCycle;
+      private:
+          std::vector<int> _vnets;
+          flitBuffer *_outCreditQueue;
+
+          NetworkLink *_inNetLink;
+          CreditLink *_outCreditLink;
+          uint32_t _bitWidth;
+    };
+
+
   private:
     GarnetNetwork *m_net_ptr;
     const NodeID m_id;
     const int m_virtual_networks, m_vc_per_vnet;
     int m_router_id; // id of my router
     std::vector<int> m_vc_allocator;
-    int m_vc_round_robin; // For round robin scheduling
-    /** Used to model link contention. */
-    flitBuffer outFlitQueue;
-    flitBuffer outCreditQueue;
+    std::vector<OutputPort *> outPorts;
+    std::vector<InputPort *> inPorts;
     int m_deadlock_threshold;
     std::vector<OutVcState> outVcState;
 
-    NetworkLink *inNetLink;
-    NetworkLink *outNetLink;
-    CreditLink *inCreditLink;
-    CreditLink *outCreditLink;
-
-    // Queue for stalled flits
-    std::deque<flit *> m_stall_queue;
     std::vector<int> m_stall_count;
 
     // Input Flit Buffers
     // The flit buffers which will serve the Consumer
     std::vector<flitBuffer>  niOutVcs;
-    std::vector<Cycles> m_ni_out_vcs_enqueue_time;
+    std::vector<Tick> m_ni_out_vcs_enqueue_time;
 
     // The Message buffers that takes messages from the protocol
     std::vector<MessageBuffer *> inNode_ptr;
@@ -106,15 +284,19 @@ class NetworkInterface : public ClockedObject, public Consumer
     // When a vc stays busy for a long time, it indicates a deadlock
     std::vector<int> vc_busy_counter;
 
-    bool checkStallQueue();
+    void checkStallQueue();
     bool flitisizeMessage(MsgPtr msg_ptr, int vnet);
     int calculateVC(int vnet);
 
+
+    void scheduleOutputPort(OutputPort *oPort);
     void scheduleOutputLink();
     void checkReschedule();
-    void sendCredit(flit *t_flit, bool is_free);
 
     void incrementStats(flit *t_flit);
+
+    InputPort *getInportForVnet(int vnet);
+    OutputPort *getOutportForVnet(int vnet);
 };
 
 #endif // __MEM_RUBY_NETWORK_GARNET2_0_NETWORKINTERFACE_HH__