From c6a6fbe9fdf3b7d586f83d50522ce2b91b3f2ba9 Mon Sep 17 00:00:00 2001 From: Gabe Black Date: Wed, 10 May 2017 00:35:43 -0700 Subject: [PATCH] base: Make the VNC server more resilient. If the client does something bad, don't kill the whole simulation, just complain, drop the client and keep going. Change-Id: I824f2d121e2fe03cdf4323a25c192b68e0370acc Reviewed-on: https://gem5-review.googlesource.com/3200 Reviewed-by: Andreas Sandberg Maintainer: Andreas Sandberg --- src/base/vnc/vncserver.cc | 145 +++++++++++++++++++++----------------- src/base/vnc/vncserver.hh | 18 ++--- 2 files changed, 89 insertions(+), 74 deletions(-) diff --git a/src/base/vnc/vncserver.cc b/src/base/vnc/vncserver.cc index 216fa2fb4..9cf38dc2d 100644 --- a/src/base/vnc/vncserver.cc +++ b/src/base/vnc/vncserver.cc @@ -198,7 +198,10 @@ VncServer::accept() panic("%s: cannot accept a connection if not listening!", name()); int fd = listener.accept(true); - fatal_if(fd < 0, "%s: failed to accept VNC connection!", name()); + if (fd < 0) { + warn("%s: failed to accept VNC connection!", name()); + return; + } if (dataFd != -1) { char message[] = "vnc server already attached!\n"; @@ -210,7 +213,7 @@ VncServer::accept() dataFd = fd; // Send our version number to the client - write((uint8_t*)vncVersion(), strlen(vncVersion())); + write((uint8_t *)vncVersion(), strlen(vncVersion())); // read the client response dataEvent = new DataEvent(this, dataFd, POLLIN); @@ -224,7 +227,6 @@ void VncServer::data() { // We have new data, see if we can handle it - size_t len; DPRINTF(VNC, "Vnc client message recieved\n"); switch (curState) { @@ -237,8 +239,8 @@ VncServer::data() case WaitForClientInit: // Don't care about shared, just need to read it out of the socket uint8_t shared; - len = read(&shared); - assert(len == 1); + if (!read(&shared)) + return; // Send our idea of the frame buffer sendServerInit(); @@ -246,12 +248,8 @@ VncServer::data() break; case NormalPhase: uint8_t message_type; - len = read(&message_type); - if (!len) { - detach(); + if (!read(&message_type)) return; - } - assert(len == 1); switch (message_type) { case ClientSetPixelFormat: @@ -273,8 +271,9 @@ VncServer::data() recvCutText(); break; default: - panic("Unimplemented message type recv from client: %d\n", - message_type); + warn("Unimplemented message type recv from client: %d\n", + message_type); + detach(); break; } break; @@ -285,7 +284,7 @@ VncServer::data() // read from socket -size_t +bool VncServer::read(uint8_t *buf, size_t len) { if (dataFd < 0) @@ -297,59 +296,58 @@ VncServer::read(uint8_t *buf, size_t len) } while (ret == -1 && errno == EINTR); - if (ret <= 0){ - DPRINTF(VNC, "Read failed.\n"); + if (ret != len) { + DPRINTF(VNC, "Read failed %d.\n", ret); detach(); - return 0; + return false; } - return ret; + return true; } -size_t +bool VncServer::read1(uint8_t *buf, size_t len) { - size_t read_len M5_VAR_USED; - read_len = read(buf + 1, len - 1); - assert(read_len == len - 1); - return read_len; + return read(buf + 1, len - 1); } template -size_t +bool VncServer::read(T* val) { - return read((uint8_t*)val, sizeof(T)); + return read((uint8_t *)val, sizeof(T)); } // write to socket -size_t +bool VncServer::write(const uint8_t *buf, size_t len) { if (dataFd < 0) panic("Vnc client not properly attached.\n"); - ssize_t ret; - ret = atomic_write(dataFd, buf, len); + ssize_t ret = atomic_write(dataFd, buf, len); - if (ret < len) + if (ret != len) { + DPRINTF(VNC, "Write failed.\n"); detach(); + return false; + } - return ret; + return true; } template -size_t +bool VncServer::write(T* val) { - return write((uint8_t*)val, sizeof(T)); + return write((uint8_t *)val, sizeof(T)); } -size_t +bool VncServer::write(const char* str) { - return write((uint8_t*)str, strlen(str)); + return write((uint8_t *)str, strlen(str)); } // detach a vnc client @@ -377,7 +375,8 @@ void VncServer::sendError(const char* error_msg) { uint32_t len = strlen(error_msg); - write(&len); + if (!write(&len)) + return; write(error_msg); } @@ -392,8 +391,10 @@ VncServer::checkProtocolVersion() // Null terminate the message so it's easier to work with version_string[12] = 0; - len = read((uint8_t*)version_string, 12); - assert(len == 12); + if (!read((uint8_t *)version_string, sizeof(version_string) - 1)) { + warn("Failed to read protocol version."); + return; + } uint32_t major, minor; @@ -402,6 +403,7 @@ VncServer::checkProtocolVersion() warn(" Malformed protocol version %s\n", version_string); sendError("Malformed protocol version\n"); detach(); + return; } DPRINTF(VNC, "Client request protocol version %d.%d\n", major, minor); @@ -412,16 +414,18 @@ VncServer::checkProtocolVersion() uint8_t err = AuthInvalid; write(&err); detach(); + return; } // Auth is different based on version number if (minor < 7) { uint32_t sec_type = htobe((uint32_t)AuthNone); - write(&sec_type); + if (!write(&sec_type)) + return; } else { uint8_t sec_cnt = 1; uint8_t sec_type = htobe((uint8_t)AuthNone); - write(&sec_cnt); - write(&sec_type); + if (!write(&sec_cnt) || !write(&sec_type)) + return; } // Wait for client to respond @@ -434,9 +438,8 @@ VncServer::checkSecurity() assert(curState == WaitForSecurityResponse); uint8_t security_type; - size_t len M5_VAR_USED = read(&security_type); - - assert(len == 1); + if (!read(&security_type)) + return; if (security_type != AuthNone) { warn("Unknown VNC security type\n"); @@ -446,7 +449,8 @@ VncServer::checkSecurity() DPRINTF(VNC, "Sending security auth OK\n"); uint32_t success = htobe(VncOK); - write(&success); + if (!write(&success)) + return; curState = WaitForClientInit; } @@ -475,7 +479,8 @@ VncServer::sendServerInit() msg.namelen = htobe(msg.namelen); memcpy(msg.name, "M5", 2); - write(&msg); + if (!write(&msg)) + return; curState = NormalPhase; } @@ -485,7 +490,8 @@ VncServer::setPixelFormat() DPRINTF(VNC, "Received pixel format from client message\n"); PixelFormatMessage pfm; - read1((uint8_t*)&pfm, sizeof(PixelFormatMessage)); + if (!read1((uint8_t *)&pfm, sizeof(PixelFormatMessage))) + return; DPRINTF(VNC, " -- bpp = %d; depth = %d; be = %d\n", pfm.px.bpp, pfm.px.depth, pfm.px.bigendian); @@ -504,8 +510,10 @@ VncServer::setPixelFormat() betoh(pfm.px.bluemax) != pixelFormat.bluemax || betoh(pfm.px.redshift) != pixelFormat.redshift || betoh(pfm.px.greenshift) != pixelFormat.greenshift || - betoh(pfm.px.blueshift) != pixelFormat.blueshift) - fatal("VNC client doesn't support true color raw encoding\n"); + betoh(pfm.px.blueshift) != pixelFormat.blueshift) { + warn("VNC client doesn't support true color raw encoding\n"); + detach(); + } } void @@ -514,7 +522,8 @@ VncServer::setEncodings() DPRINTF(VNC, "Received supported encodings from client\n"); PixelEncodingsMessage pem; - read1((uint8_t*)&pem, sizeof(PixelEncodingsMessage)); + if (!read1((uint8_t *)&pem, sizeof(PixelEncodingsMessage))) + return; pem.num_encodings = betoh(pem.num_encodings); @@ -523,9 +532,8 @@ VncServer::setEncodings() for (int x = 0; x < pem.num_encodings; x++) { int32_t encoding; - size_t len M5_VAR_USED; - len = read(&encoding); - assert(len == sizeof(encoding)); + if (!read(&encoding)) + return; DPRINTF(VNC, " -- supports %d\n", betoh(encoding)); switch (betoh(encoding)) { @@ -538,8 +546,10 @@ VncServer::setEncodings() } } - if (!supportsRawEnc) - fatal("VNC clients must always support raw encoding\n"); + if (!supportsRawEnc) { + warn("VNC clients must always support raw encoding\n"); + detach(); + } } void @@ -548,7 +558,8 @@ VncServer::requestFbUpdate() DPRINTF(VNC, "Received frame buffer update request from client\n"); FrameBufferUpdateReq fbr; - read1((uint8_t*)&fbr, sizeof(FrameBufferUpdateReq)); + if (!read1((uint8_t *)&fbr, sizeof(FrameBufferUpdateReq))) + return; fbr.x = betoh(fbr.x); fbr.y = betoh(fbr.y); @@ -566,7 +577,8 @@ VncServer::recvKeyboardInput() { DPRINTF(VNC, "Received keyboard input from client\n"); KeyEventMessage kem; - read1((uint8_t*)&kem, sizeof(KeyEventMessage)); + if (!read1((uint8_t *)&kem, sizeof(KeyEventMessage))) + return; kem.key = betoh(kem.key); DPRINTF(VNC, " -- received key code %d (%s)\n", kem.key, kem.down_flag ? @@ -582,7 +594,8 @@ VncServer::recvPointerInput() DPRINTF(VNC, "Received pointer input from client\n"); PointerEventMessage pem; - read1((uint8_t*)&pem, sizeof(PointerEventMessage));; + if (!read1((uint8_t *)&pem, sizeof(PointerEventMessage))) + return; pem.x = betoh(pem.x); pem.y = betoh(pem.y); @@ -599,18 +612,18 @@ VncServer::recvCutText() DPRINTF(VNC, "Received client copy buffer message\n"); ClientCutTextMessage cct; - read1((uint8_t*)&cct, sizeof(ClientCutTextMessage)); + if (!read1((uint8_t *)&cct, sizeof(ClientCutTextMessage))) + return; char str[1025]; size_t data_len = betoh(cct.length); DPRINTF(VNC, "String length %d\n", data_len); while (data_len > 0) { - size_t len; size_t bytes_to_read = data_len > 1024 ? 1024 : data_len; - len = read((uint8_t*)&str, bytes_to_read); + if (!read((uint8_t *)&str, bytes_to_read)) + return; str[bytes_to_read] = 0; - assert(len >= data_len); - data_len -= len; + data_len -= bytes_to_read; DPRINTF(VNC, "Buffer: %s\n", str); } @@ -651,8 +664,8 @@ VncServer::sendFrameBufferUpdate() fbr.encoding = htobe(fbr.encoding); // send headers to client - write(&fbu); - write(&fbr); + if (!write(&fbu) || !write(&fbr)) + return; assert(fb); @@ -665,7 +678,8 @@ VncServer::sendFrameBufferUpdate() raw_pixel += pixelConverter.length; } - write(line_buffer.data(), line_buffer.size()); + if (!write(line_buffer.data(), line_buffer.size())) + return; } } @@ -695,7 +709,8 @@ VncServer::sendFrameBufferResized() fbr.encoding = htobe(fbr.encoding); // send headers to client - write(&fbu); + if (!write(&fbu)) + return; write(&fbr); // No actual data is sent in this message diff --git a/src/base/vnc/vncserver.hh b/src/base/vnc/vncserver.hh index a52850323..99f4b5fe1 100644 --- a/src/base/vnc/vncserver.hh +++ b/src/base/vnc/vncserver.hh @@ -216,9 +216,9 @@ class VncServer : public VncInput /** Read some data from the client * @param buf the data to read * @param len the amount of data to read - * @return length read + * @return whether the read was successful */ - size_t read(uint8_t *buf, size_t len); + bool read(uint8_t *buf, size_t len); /** Read len -1 bytes from the client into the buffer provided + 1 * assert that we read enough bytes. This function exists to handle @@ -226,35 +226,35 @@ class VncServer : public VncInput * the first byte which describes which one we're reading * @param buf the address of the buffer to add one to and read data into * @param len the amount of data + 1 to read - * @return length read + * @return whether the read was successful. */ - size_t read1(uint8_t *buf, size_t len); + bool read1(uint8_t *buf, size_t len); /** Templated version of the read function above to * read simple data to the client * @param val data to recv from the client */ - template size_t read(T* val); + template bool read(T* val); /** Write a buffer to the client. * @param buf buffer to send * @param len length of the buffer - * @return number of bytes sent + * @return whether the write was successful */ - size_t write(const uint8_t *buf, size_t len); + bool write(const uint8_t *buf, size_t len); /** Templated version of the write function above to * write simple data to the client * @param val data to send to the client */ - template size_t write(T* val); + template bool write(T* val); /** Send a string to the client * @param str string to transmit */ - size_t write(const char* str); + bool write(const char* str); /** Check the client's protocol verion for compatibility and send * the security types we support -- 2.30.2