base: Make the VNC server more resilient.
authorGabe Black <gabeblack@google.com>
Wed, 10 May 2017 07:35:43 +0000 (00:35 -0700)
committerGabe Black <gabeblack@google.com>
Fri, 12 May 2017 09:43:20 +0000 (09:43 +0000)
If the client does something bad, don't kill the whole simulation, just
complain, drop the client and keep going.

Change-Id: I824f2d121e2fe03cdf4323a25c192b68e0370acc
Reviewed-on: https://gem5-review.googlesource.com/3200
Reviewed-by: Andreas Sandberg <andreas.sandberg@arm.com>
Maintainer: Andreas Sandberg <andreas.sandberg@arm.com>

src/base/vnc/vncserver.cc
src/base/vnc/vncserver.hh

index 216fa2fb4421a1d8c72b8521eef04ce4ba7ded1d..9cf38dc2db65a97faf39d26849a22b5e5801174a 100644 (file)
@@ -198,7 +198,10 @@ VncServer::accept()
         panic("%s: cannot accept a connection if not listening!", name());
 
     int fd = listener.accept(true);
-    fatal_if(fd < 0, "%s: failed to accept VNC connection!", name());
+    if (fd < 0) {
+        warn("%s: failed to accept VNC connection!", name());
+        return;
+    }
 
     if (dataFd != -1) {
         char message[] = "vnc server already attached!\n";
@@ -210,7 +213,7 @@ VncServer::accept()
     dataFd = fd;
 
     // Send our version number to the client
-    write((uint8_t*)vncVersion(), strlen(vncVersion()));
+    write((uint8_t *)vncVersion(), strlen(vncVersion()));
 
     // read the client response
     dataEvent = new DataEvent(this, dataFd, POLLIN);
@@ -224,7 +227,6 @@ void
 VncServer::data()
 {
     // We have new data, see if we can handle it
-    size_t len;
     DPRINTF(VNC, "Vnc client message recieved\n");
 
     switch (curState) {
@@ -237,8 +239,8 @@ VncServer::data()
       case WaitForClientInit:
         // Don't care about shared, just need to read it out of the socket
         uint8_t shared;
-        len = read(&shared);
-        assert(len == 1);
+        if (!read(&shared))
+            return;
 
         // Send our idea of the frame buffer
         sendServerInit();
@@ -246,12 +248,8 @@ VncServer::data()
         break;
       case NormalPhase:
         uint8_t message_type;
-        len = read(&message_type);
-        if (!len) {
-            detach();
+        if (!read(&message_type))
             return;
-        }
-        assert(len == 1);
 
         switch (message_type) {
           case ClientSetPixelFormat:
@@ -273,8 +271,9 @@ VncServer::data()
             recvCutText();
             break;
           default:
-            panic("Unimplemented message type recv from client: %d\n",
-                  message_type);
+            warn("Unimplemented message type recv from client: %d\n",
+                 message_type);
+            detach();
             break;
         }
         break;
@@ -285,7 +284,7 @@ VncServer::data()
 
 
 // read from socket
-size_t
+bool
 VncServer::read(uint8_t *buf, size_t len)
 {
     if (dataFd < 0)
@@ -297,59 +296,58 @@ VncServer::read(uint8_t *buf, size_t len)
     } while (ret == -1 && errno == EINTR);
 
 
-    if (ret <= 0){
-        DPRINTF(VNC, "Read failed.\n");
+    if (ret != len) {
+        DPRINTF(VNC, "Read failed %d.\n", ret);
         detach();
-        return 0;
+        return false;
     }
 
-    return ret;
+    return true;
 }
 
-size_t
+bool
 VncServer::read1(uint8_t *buf, size_t len)
 {
-    size_t read_len M5_VAR_USED;
-    read_len = read(buf + 1, len - 1);
-    assert(read_len == len - 1);
-    return read_len;
+    return read(buf + 1, len - 1);
 }
 
 
 template<typename T>
-size_t
+bool
 VncServer::read(T* val)
 {
-    return read((uint8_t*)val, sizeof(T));
+    return read((uint8_t *)val, sizeof(T));
 }
 
 // write to socket
-size_t
+bool
 VncServer::write(const uint8_t *buf, size_t len)
 {
     if (dataFd < 0)
         panic("Vnc client not properly attached.\n");
 
-    ssize_t ret;
-    ret = atomic_write(dataFd, buf, len);
+    ssize_t ret = atomic_write(dataFd, buf, len);
 
-    if (ret < len)
+    if (ret != len) {
+        DPRINTF(VNC, "Write failed.\n");
         detach();
+        return false;
+    }
 
-    return ret;
+    return true;
 }
 
 template<typename T>
-size_t
+bool
 VncServer::write(T* val)
 {
-    return write((uint8_t*)val, sizeof(T));
+    return write((uint8_t *)val, sizeof(T));
 }
 
-size_t
+bool
 VncServer::write(const char* str)
 {
-    return write((uint8_t*)str, strlen(str));
+    return write((uint8_t *)str, strlen(str));
 }
 
 // detach a vnc client
@@ -377,7 +375,8 @@ void
 VncServer::sendError(const char* error_msg)
 {
    uint32_t len = strlen(error_msg);
-   write(&len);
+   if (!write(&len))
+       return;
    write(error_msg);
 }
 
@@ -392,8 +391,10 @@ VncServer::checkProtocolVersion()
     // Null terminate the message so it's easier to work with
     version_string[12] = 0;
 
-    len = read((uint8_t*)version_string, 12);
-    assert(len == 12);
+    if (!read((uint8_t *)version_string, sizeof(version_string) - 1)) {
+        warn("Failed to read protocol version.");
+        return;
+    }
 
     uint32_t major, minor;
 
@@ -402,6 +403,7 @@ VncServer::checkProtocolVersion()
         warn(" Malformed protocol version %s\n", version_string);
         sendError("Malformed protocol version\n");
         detach();
+        return;
     }
 
     DPRINTF(VNC, "Client request protocol version %d.%d\n", major, minor);
@@ -412,16 +414,18 @@ VncServer::checkProtocolVersion()
         uint8_t err = AuthInvalid;
         write(&err);
         detach();
+        return;
     }
     // Auth is different based on version number
     if (minor < 7) {
         uint32_t sec_type = htobe((uint32_t)AuthNone);
-        write(&sec_type);
+        if (!write(&sec_type))
+            return;
     } else {
         uint8_t sec_cnt = 1;
         uint8_t sec_type = htobe((uint8_t)AuthNone);
-        write(&sec_cnt);
-        write(&sec_type);
+        if (!write(&sec_cnt) || !write(&sec_type))
+            return;
     }
 
     // Wait for client to respond
@@ -434,9 +438,8 @@ VncServer::checkSecurity()
     assert(curState == WaitForSecurityResponse);
 
     uint8_t security_type;
-    size_t len M5_VAR_USED = read(&security_type);
-
-    assert(len == 1);
+    if (!read(&security_type))
+        return;
 
     if (security_type != AuthNone) {
         warn("Unknown VNC security type\n");
@@ -446,7 +449,8 @@ VncServer::checkSecurity()
     DPRINTF(VNC, "Sending security auth OK\n");
 
     uint32_t success = htobe(VncOK);
-    write(&success);
+    if (!write(&success))
+        return;
     curState = WaitForClientInit;
 }
 
@@ -475,7 +479,8 @@ VncServer::sendServerInit()
     msg.namelen = htobe(msg.namelen);
     memcpy(msg.name, "M5", 2);
 
-    write(&msg);
+    if (!write(&msg))
+        return;
     curState = NormalPhase;
 }
 
@@ -485,7 +490,8 @@ VncServer::setPixelFormat()
     DPRINTF(VNC, "Received pixel format from client message\n");
 
     PixelFormatMessage pfm;
-    read1((uint8_t*)&pfm, sizeof(PixelFormatMessage));
+    if (!read1((uint8_t *)&pfm, sizeof(PixelFormatMessage)))
+        return;
 
     DPRINTF(VNC, " -- bpp = %d; depth = %d; be = %d\n", pfm.px.bpp,
             pfm.px.depth, pfm.px.bigendian);
@@ -504,8 +510,10 @@ VncServer::setPixelFormat()
         betoh(pfm.px.bluemax) != pixelFormat.bluemax ||
         betoh(pfm.px.redshift) != pixelFormat.redshift ||
         betoh(pfm.px.greenshift) != pixelFormat.greenshift ||
-        betoh(pfm.px.blueshift) != pixelFormat.blueshift)
-        fatal("VNC client doesn't support true color raw encoding\n");
+        betoh(pfm.px.blueshift) != pixelFormat.blueshift) {
+        warn("VNC client doesn't support true color raw encoding\n");
+        detach();
+    }
 }
 
 void
@@ -514,7 +522,8 @@ VncServer::setEncodings()
     DPRINTF(VNC, "Received supported encodings from client\n");
 
     PixelEncodingsMessage pem;
-    read1((uint8_t*)&pem, sizeof(PixelEncodingsMessage));
+    if (!read1((uint8_t *)&pem, sizeof(PixelEncodingsMessage)))
+        return;
 
     pem.num_encodings = betoh(pem.num_encodings);
 
@@ -523,9 +532,8 @@ VncServer::setEncodings()
 
     for (int x = 0; x < pem.num_encodings; x++) {
         int32_t encoding;
-        size_t len M5_VAR_USED;
-        len = read(&encoding);
-        assert(len == sizeof(encoding));
+        if (!read(&encoding))
+            return;
         DPRINTF(VNC, " -- supports %d\n", betoh(encoding));
 
         switch (betoh(encoding)) {
@@ -538,8 +546,10 @@ VncServer::setEncodings()
         }
     }
 
-    if (!supportsRawEnc)
-        fatal("VNC clients must always support raw encoding\n");
+    if (!supportsRawEnc) {
+        warn("VNC clients must always support raw encoding\n");
+        detach();
+    }
 }
 
 void
@@ -548,7 +558,8 @@ VncServer::requestFbUpdate()
     DPRINTF(VNC, "Received frame buffer update request from client\n");
 
     FrameBufferUpdateReq fbr;
-    read1((uint8_t*)&fbr, sizeof(FrameBufferUpdateReq));
+    if (!read1((uint8_t *)&fbr, sizeof(FrameBufferUpdateReq)))
+        return;
 
     fbr.x = betoh(fbr.x);
     fbr.y = betoh(fbr.y);
@@ -566,7 +577,8 @@ VncServer::recvKeyboardInput()
 {
     DPRINTF(VNC, "Received keyboard input from client\n");
     KeyEventMessage kem;
-    read1((uint8_t*)&kem, sizeof(KeyEventMessage));
+    if (!read1((uint8_t *)&kem, sizeof(KeyEventMessage)))
+        return;
 
     kem.key = betoh(kem.key);
     DPRINTF(VNC, " -- received key code %d (%s)\n", kem.key, kem.down_flag ?
@@ -582,7 +594,8 @@ VncServer::recvPointerInput()
     DPRINTF(VNC, "Received pointer input from client\n");
     PointerEventMessage pem;
 
-    read1((uint8_t*)&pem, sizeof(PointerEventMessage));;
+    if (!read1((uint8_t *)&pem, sizeof(PointerEventMessage)))
+        return;
 
     pem.x = betoh(pem.x);
     pem.y = betoh(pem.y);
@@ -599,18 +612,18 @@ VncServer::recvCutText()
     DPRINTF(VNC, "Received client copy buffer message\n");
 
     ClientCutTextMessage cct;
-    read1((uint8_t*)&cct, sizeof(ClientCutTextMessage));
+    if (!read1((uint8_t *)&cct, sizeof(ClientCutTextMessage)))
+        return;
 
     char str[1025];
     size_t data_len = betoh(cct.length);
     DPRINTF(VNC, "String length %d\n", data_len);
     while (data_len > 0) {
-        size_t len;
         size_t bytes_to_read = data_len > 1024 ? 1024 : data_len;
-        len = read((uint8_t*)&str, bytes_to_read);
+        if (!read((uint8_t *)&str, bytes_to_read))
+            return;
         str[bytes_to_read] = 0;
-        assert(len >= data_len);
-        data_len -= len;
+        data_len -= bytes_to_read;
         DPRINTF(VNC, "Buffer: %s\n", str);
     }
 
@@ -651,8 +664,8 @@ VncServer::sendFrameBufferUpdate()
     fbr.encoding = htobe(fbr.encoding);
 
     // send headers to client
-    write(&fbu);
-    write(&fbr);
+    if (!write(&fbu) || !write(&fbr))
+        return;
 
     assert(fb);
 
@@ -665,7 +678,8 @@ VncServer::sendFrameBufferUpdate()
             raw_pixel += pixelConverter.length;
         }
 
-        write(line_buffer.data(), line_buffer.size());
+        if (!write(line_buffer.data(), line_buffer.size()))
+            return;
     }
 }
 
@@ -695,7 +709,8 @@ VncServer::sendFrameBufferResized()
     fbr.encoding = htobe(fbr.encoding);
 
     // send headers to client
-    write(&fbu);
+    if (!write(&fbu))
+        return;
     write(&fbr);
 
     // No actual data is sent in this message
index a52850323b7c5d9b34461857b6e18497155529d6..99f4b5fe114633cd585235e6688961e10d24f2b1 100644 (file)
@@ -216,9 +216,9 @@ class VncServer : public VncInput
     /** Read some data from the client
      * @param buf the data to read
      * @param len the amount of data to read
-     * @return length read
+     * @return whether the read was successful
      */
-    size_t read(uint8_t *buf, size_t len);
+    bool read(uint8_t *buf, size_t len);
 
     /** Read len -1 bytes from the client into the buffer provided + 1
      * assert that we read enough bytes. This function exists to handle
@@ -226,35 +226,35 @@ class VncServer : public VncInput
      * the first byte which describes which one we're reading
      * @param buf the address of the buffer to add one to and read data into
      * @param len the amount of data  + 1 to read
-     * @return length read
+     * @return whether the read was successful.
      */
-    size_t read1(uint8_t *buf, size_t len);
+    bool read1(uint8_t *buf, size_t len);
 
 
     /** Templated version of the read function above to
      * read simple data to the client
      * @param val data to recv from the client
      */
-    template <typename T> size_t read(T* val);
+    template <typename T> bool read(T* val);
 
 
     /** Write a buffer to the client.
      * @param buf buffer to send
      * @param len length of the buffer
-     * @return number of bytes sent
+     * @return whether the write was successful
      */
-    size_t write(const uint8_t *buf, size_t len);
+    bool write(const uint8_t *buf, size_t len);
 
     /** Templated version of the write function above to
      * write simple data to the client
      * @param val data to send to the client
      */
-    template <typename T> size_t write(T* val);
+    template <typename T> bool write(T* val);
 
     /** Send a string to the client
      * @param str string to transmit
      */
-    size_t write(const char* str);
+    bool write(const char* str);
 
     /** Check the client's protocol verion for compatibility and send
      * the security types we support