Ruby: Improve Change PerfectSwitch's wakeup function
authorNilay Vaish <nilay@cs.wisc.edu>
Mon, 14 Feb 2011 22:14:54 +0000 (16:14 -0600)
committerNilay Vaish <nilay@cs.wisc.edu>
Mon, 14 Feb 2011 22:14:54 +0000 (16:14 -0600)
Currently the wakeup function for the PerfectSwitch contains three loops -

loop on number of virtual networks
  loop on number of incoming links
    loop till all messages for this (link, network) have been routed

With an 8 processor mesh network and Hammer protocol, about 11-12% of the
was observed to have been spent in this function, which is the highest
amongst all the functions. It was found that the innermost loop is executed
about 45 times per invocation of the wakeup function, when each invocation
of the wakeup function processes just about one message.

The patch tries to do away with the redundant executions of the innermost
loop. Counters have been added for each virtual network that record the
number of messages that need to be routed for that virtual network. The
inner loops are only executed when the number of messages for that particular
virtual network > 0. This does away with almost 80% of the executions of the
innermost loop. The function now consumes about 5-6% of the total execution
time.

src/mem/ruby/buffers/MessageBuffer.cc
src/mem/ruby/buffers/MessageBuffer.hh
src/mem/ruby/common/Consumer.hh
src/mem/ruby/network/simple/PerfectSwitch.cc
src/mem/ruby/network/simple/PerfectSwitch.hh
src/mem/ruby/slicc_interface/Message.hh
src/mem/ruby/slicc_interface/NetworkMessage.hh

index f6b79c5803913baf0b07ecfbeec785dffa10a580..2255950053006cd10b1a61d78632d33f080ed025 100644 (file)
@@ -58,6 +58,8 @@ MessageBuffer::MessageBuffer(const string &name)
     m_name = name;
 
     m_stall_msg_map.clear();
+    m_input_link_id = 0;
+    m_vnet_id = 0;
 }
 
 int
@@ -228,6 +230,7 @@ MessageBuffer::enqueue(MsgPtr message, Time delta)
     // Schedule the wakeup
     if (m_consumer_ptr != NULL) {
         g_eventQueue_ptr->scheduleEventAbsolute(m_consumer_ptr, arrival_time);
+        m_consumer_ptr->storeEventInfo(m_vnet_id);
     } else {
         panic("No consumer: %s name: %s\n", *this, m_name);
     }
index 62cc656701989f5f03bb8850ef2abb0d30ad7872..88df5b788e23d14543a23e5d8ead596856092b92 100644 (file)
@@ -142,6 +142,9 @@ class MessageBuffer
     void printStats(std::ostream& out);
     void clearStats() { m_not_avail_count = 0; m_msg_counter = 0; }
 
+    void setIncomingLink(int link_id) { m_input_link_id = link_id; }
+    void setVnet(int net) { m_vnet_id = net; }
+
   private:
     //added by SS
     int m_recycle_latency;
@@ -184,6 +187,9 @@ class MessageBuffer
     bool m_ordering_set;
     bool m_randomization;
     Time m_last_arrival_time;
+
+    int m_input_link_id;
+    int m_vnet_id;
 };
 
 inline std::ostream&
index c1f8bc42e6354be81e8690461b5ccc71703f6c7d..a119abb3909e0e665f6a447e460d0f020eeb5687 100644 (file)
@@ -67,6 +67,7 @@ class Consumer
 
     virtual void wakeup() = 0;
     virtual void print(std::ostream& out) const = 0;
+    virtual void storeEventInfo(int info) {}
 
     const Time&
     getLastScheduledWakeup() const
index 7229c724f0c6c85dd8415206e2f99dc1a689be81..5c461c63fe7be542c1c50a6097d9d311bf01d08d 100644 (file)
@@ -54,6 +54,11 @@ PerfectSwitch::PerfectSwitch(SwitchID sid, SimpleNetwork* network_ptr)
     m_round_robin_start = 0;
     m_network_ptr = network_ptr;
     m_wakeups_wo_switch = 0;
+
+    for(int i = 0;i < m_virtual_networks;++i)
+    {
+        m_pending_message_count.push_back(0);
+    }
 }
 
 void
@@ -62,12 +67,15 @@ PerfectSwitch::addInPort(const vector<MessageBuffer*>& in)
     assert(in.size() == m_virtual_networks);
     NodeID port = m_in.size();
     m_in.push_back(in);
+
     for (int j = 0; j < m_virtual_networks; j++) {
         m_in[port][j]->setConsumer(this);
         string desc = csprintf("[Queue from port %s %s %s to PerfectSwitch]",
             NodeIDToString(m_switch_id), NodeIDToString(port),
             NodeIDToString(j));
         m_in[port][j]->setDescription(desc);
+        m_in[port][j]->setIncomingLink(port);
+        m_in[port][j]->setVnet(j);
     }
 }
 
@@ -154,160 +162,169 @@ PerfectSwitch::wakeup()
             m_round_robin_start = 0;
         }
 
-        // for all input ports, use round robin scheduling
-        for (int counter = 0; counter < m_in.size(); counter++) {
-            // Round robin scheduling
-            incoming++;
-            if (incoming >= m_in.size()) {
-                incoming = 0;
-            }
+        if(m_pending_message_count[vnet] > 0) {
+            // for all input ports, use round robin scheduling
+            for (int counter = 0; counter < m_in.size(); counter++) {
+                // Round robin scheduling
+                incoming++;
+                if (incoming >= m_in.size()) {
+                    incoming = 0;
+                }
 
-            // temporary vectors to store the routing results
-            vector<LinkID> output_links;
-            vector<NetDest> output_link_destinations;
-
-            // Is there a message waiting?
-            while (m_in[incoming][vnet]->isReady()) {
-                DPRINTF(RubyNetwork, "incoming: %d\n", incoming);
-
-                // Peek at message
-                msg_ptr = m_in[incoming][vnet]->peekMsgPtr();
-                net_msg_ptr = safe_cast<NetworkMessage*>(msg_ptr.get());
-                DPRINTF(RubyNetwork, "Message: %s\n", (*net_msg_ptr));
-
-                output_links.clear();
-                output_link_destinations.clear();
-                NetDest msg_dsts =
-                    net_msg_ptr->getInternalDestination();
-
-                // Unfortunately, the token-protocol sends some
-                // zero-destination messages, so this assert isn't valid
-                // assert(msg_dsts.count() > 0);
-
-                assert(m_link_order.size() == m_routing_table.size());
-                assert(m_link_order.size() == m_out.size());
-
-                if (m_network_ptr->getAdaptiveRouting()) {
-                    if (m_network_ptr->isVNetOrdered(vnet)) {
-                        // Don't adaptively route
-                        for (int out = 0; out < m_out.size(); out++) {
-                            m_link_order[out].m_link = out;
-                            m_link_order[out].m_value = 0;
-                        }
-                    } else {
-                        // Find how clogged each link is
-                        for (int out = 0; out < m_out.size(); out++) {
-                            int out_queue_length = 0;
-                            for (int v = 0; v < m_virtual_networks; v++) {
-                                out_queue_length += m_out[out][v]->getSize();
+                // temporary vectors to store the routing results
+                vector<LinkID> output_links;
+                vector<NetDest> output_link_destinations;
+
+                // Is there a message waiting?
+                while (m_in[incoming][vnet]->isReady()) {
+                    DPRINTF(RubyNetwork, "incoming: %d\n", incoming);
+
+                    // Peek at message
+                    msg_ptr = m_in[incoming][vnet]->peekMsgPtr();
+                    net_msg_ptr = safe_cast<NetworkMessage*>(msg_ptr.get());
+                    DPRINTF(RubyNetwork, "Message: %s\n", (*net_msg_ptr));
+
+                    output_links.clear();
+                    output_link_destinations.clear();
+                    NetDest msg_dsts =
+                        net_msg_ptr->getInternalDestination();
+
+                    // Unfortunately, the token-protocol sends some
+                    // zero-destination messages, so this assert isn't valid
+                    // assert(msg_dsts.count() > 0);
+
+                    assert(m_link_order.size() == m_routing_table.size());
+                    assert(m_link_order.size() == m_out.size());
+
+                    if (m_network_ptr->getAdaptiveRouting()) {
+                        if (m_network_ptr->isVNetOrdered(vnet)) {
+                            // Don't adaptively route
+                            for (int out = 0; out < m_out.size(); out++) {
+                                m_link_order[out].m_link = out;
+                                m_link_order[out].m_value = 0;
+                            }
+                        } else {
+                            // Find how clogged each link is
+                            for (int out = 0; out < m_out.size(); out++) {
+                                int out_queue_length = 0;
+                                for (int v = 0; v < m_virtual_networks; v++) {
+                                    out_queue_length += m_out[out][v]->getSize();
+                                }
+                                int value =
+                                    (out_queue_length << 8) | (random() & 0xff);
+                                m_link_order[out].m_link = out;
+                                m_link_order[out].m_value = value;
                             }
-                            int value =
-                                (out_queue_length << 8) | (random() & 0xff);
-                            m_link_order[out].m_link = out;
-                            m_link_order[out].m_value = value;
+
+                            // Look at the most empty link first
+                            sort(m_link_order.begin(), m_link_order.end());
                         }
+                    }
 
-                        // Look at the most empty link first
-                        sort(m_link_order.begin(), m_link_order.end());
+                    for (int i = 0; i < m_routing_table.size(); i++) {
+                        // pick the next link to look at
+                        int link = m_link_order[i].m_link;
+                        NetDest dst = m_routing_table[link];
+                        DPRINTF(RubyNetwork, "dst: %s\n", dst);
+
+                        if (!msg_dsts.intersectionIsNotEmpty(dst))
+                            continue;
+
+                        // Remember what link we're using
+                        output_links.push_back(link);
+
+                        // Need to remember which destinations need this
+                        // message in another vector.  This Set is the
+                        // intersection of the routing_table entry and the
+                        // current destination set.  The intersection must
+                        // not be empty, since we are inside "if"
+                        output_link_destinations.push_back(msg_dsts.AND(dst));
+
+                        // Next, we update the msg_destination not to
+                        // include those nodes that were already handled
+                        // by this link
+                        msg_dsts.removeNetDest(dst);
                     }
-                }
 
-                for (int i = 0; i < m_routing_table.size(); i++) {
-                    // pick the next link to look at
-                    int link = m_link_order[i].m_link;
-                    NetDest dst = m_routing_table[link];
-                    DPRINTF(RubyNetwork, "dst: %s\n", dst);
-
-                    if (!msg_dsts.intersectionIsNotEmpty(dst))
-                        continue;
-
-                    // Remember what link we're using
-                    output_links.push_back(link);
-
-                    // Need to remember which destinations need this
-                    // message in another vector.  This Set is the
-                    // intersection of the routing_table entry and the
-                    // current destination set.  The intersection must
-                    // not be empty, since we are inside "if"
-                    output_link_destinations.push_back(msg_dsts.AND(dst));
-
-                    // Next, we update the msg_destination not to
-                    // include those nodes that were already handled
-                    // by this link
-                    msg_dsts.removeNetDest(dst);
-                }
+                    assert(msg_dsts.count() == 0);
+                    //assert(output_links.size() > 0);
+
+                    // Check for resources - for all outgoing queues
+                    bool enough = true;
+                    for (int i = 0; i < output_links.size(); i++) {
+                        int outgoing = output_links[i];
+                        if (!m_out[outgoing][vnet]->areNSlotsAvailable(1))
+                            enough = false;
+                        DPRINTF(RubyNetwork, "Checking if node is blocked\n"
+                                "outgoing: %d, vnet: %d, enough: %d\n",
+                                outgoing, vnet, enough);
+                    }
 
-                assert(msg_dsts.count() == 0);
-                //assert(output_links.size() > 0);
-
-                // Check for resources - for all outgoing queues
-                bool enough = true;
-                for (int i = 0; i < output_links.size(); i++) {
-                    int outgoing = output_links[i];
-                    if (!m_out[outgoing][vnet]->areNSlotsAvailable(1))
-                        enough = false;
-                    DPRINTF(RubyNetwork, "Checking if node is blocked\n"
-                            "outgoing: %d, vnet: %d, enough: %d\n",
-                            outgoing, vnet, enough);
-                }
+                    // There were not enough resources
+                    if (!enough) {
+                        g_eventQueue_ptr->scheduleEvent(this, 1);
+                        DPRINTF(RubyNetwork, "Can't deliver message since a node "
+                                "is blocked\n"
+                                "Message: %s\n", (*net_msg_ptr));
+                        break; // go to next incoming port
+                    }
 
-                // There were not enough resources
-                if (!enough) {
-                    g_eventQueue_ptr->scheduleEvent(this, 1);
-                    DPRINTF(RubyNetwork, "Can't deliver message since a node "
-                            "is blocked\n"
-                            "Message: %s\n", (*net_msg_ptr));
-                    break; // go to next incoming port
-                }
+                    MsgPtr unmodified_msg_ptr;
 
-                MsgPtr unmodified_msg_ptr;
+                    if (output_links.size() > 1) {
+                        // If we are sending this message down more than
+                        // one link (size>1), we need to make a copy of
+                        // the message so each branch can have a different
+                        // internal destination we need to create an
+                        // unmodified MsgPtr because the MessageBuffer
+                        // enqueue func will modify the message
 
-                if (output_links.size() > 1) {
-                    // If we are sending this message down more than
-                    // one link (size>1), we need to make a copy of
-                    // the message so each branch can have a different
-                    // internal destination we need to create an
-                    // unmodified MsgPtr because the MessageBuffer
-                    // enqueue func will modify the message
+                        // This magic line creates a private copy of the
+                        // message
+                        unmodified_msg_ptr = msg_ptr->clone();
+                    }
 
-                    // This magic line creates a private copy of the
-                    // message
-                    unmodified_msg_ptr = msg_ptr->clone();
-                }
+                    // Enqueue it - for all outgoing queues
+                    for (int i=0; i<output_links.size(); i++) {
+                        int outgoing = output_links[i];
 
-                // Enqueue it - for all outgoing queues
-                for (int i=0; i<output_links.size(); i++) {
-                    int outgoing = output_links[i];
+                        if (i > 0) {
+                            // create a private copy of the unmodified
+                            // message
+                            msg_ptr = unmodified_msg_ptr->clone();
+                        }
 
-                    if (i > 0) {
-                        // create a private copy of the unmodified
-                        // message
-                        msg_ptr = unmodified_msg_ptr->clone();
-                    }
+                        // Change the internal destination set of the
+                        // message so it knows which destinations this
+                        // link is responsible for.
+                        net_msg_ptr = safe_cast<NetworkMessage*>(msg_ptr.get());
+                        net_msg_ptr->getInternalDestination() =
+                            output_link_destinations[i];
 
-                    // Change the internal destination set of the
-                    // message so it knows which destinations this
-                    // link is responsible for.
-                    net_msg_ptr = safe_cast<NetworkMessage*>(msg_ptr.get());
-                    net_msg_ptr->getInternalDestination() =
-                        output_link_destinations[i];
+                        // Enqeue msg
+                        DPRINTF(RubyNetwork, "Switch: %d enqueuing net msg from "
+                                "inport[%d][%d] to outport [%d][%d] time: %lld.\n",
+                                m_switch_id, incoming, vnet, outgoing, vnet,
+                                g_eventQueue_ptr->getTime());
 
-                    // Enqeue msg
-                    DPRINTF(RubyNetwork, "Switch: %d enqueuing net msg from "
-                            "inport[%d][%d] to outport [%d][%d] time: %lld.\n",
-                            m_switch_id, incoming, vnet, outgoing, vnet,
-                            g_eventQueue_ptr->getTime());
+                        m_out[outgoing][vnet]->enqueue(msg_ptr);
+                    }
 
-                    m_out[outgoing][vnet]->enqueue(msg_ptr);
+                    // Dequeue msg
+                    m_in[incoming][vnet]->pop();
+                    m_pending_message_count[vnet]--;
                 }
-
-                // Dequeue msg
-                m_in[incoming][vnet]->pop();
             }
         }
     }
 }
 
+void
+PerfectSwitch::storeEventInfo(int info)
+{
+    m_pending_message_count[info]++;
+}
+
 void
 PerfectSwitch::printStats(std::ostream& out) const
 {
index a7e577df01502b46ff003251c64eb611957ddd49..cd0219fd9f1af4d3ec314461af191189fa5558a0 100644 (file)
@@ -69,6 +69,7 @@ class PerfectSwitch : public Consumer
     int getOutLinks() const { return m_out.size(); }
 
     void wakeup();
+    void storeEventInfo(int info);
 
     void printStats(std::ostream& out) const;
     void clearStats();
@@ -92,6 +93,7 @@ class PerfectSwitch : public Consumer
     int m_round_robin_start;
     int m_wakeups_wo_switch;
     SimpleNetwork* m_network_ptr;
+    std::vector<int> m_pending_message_count;
 };
 
 inline std::ostream&
index ff94fdd409e7392ec2c0cba26e6bb4761b90bc09..7fcfabe9ca0c81e7a75d9b867c4f2d7d31335d78 100644 (file)
@@ -57,6 +57,8 @@ class Message : public RefCounted
 
     virtual Message* clone() const = 0;
     virtual void print(std::ostream& out) const = 0;
+    virtual void setIncomingLink(int) {}
+    virtual void setVnet(int) {}
 
     void setDelayedCycles(const int& cycles) { m_DelayedCycles = cycles; }
     const int& getDelayedCycles() const {return m_DelayedCycles;}
index 082481e054cd5dfdb35c3be9a56c2531e18ba0b4..a8f9c625b3b993507fa3d6388275355941fa409b 100644 (file)
@@ -82,9 +82,16 @@ class NetworkMessage : public Message
 
     virtual void print(std::ostream& out) const = 0;
 
+    int getIncomingLink() const { return incoming_link; }
+    void setIncomingLink(int link) { incoming_link = link; }
+    int getVnet() const { return vnet; }
+    void setVnet(int net) { vnet = net; }
+
   private:
     NetDest m_internal_dest;
     bool m_internal_dest_valid;
+    int incoming_link;
+    int vnet;
 };
 
 inline std::ostream&