diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index 083ee17..8ce99ec 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -286,56 +286,68 @@ def handle(self): def read_bytes(self, num): return self.rfile.read(num) - def read_next_message(self): - try: - b1, b2 = self.read_bytes(2) - except SocketError as e: # to be replaced with ConnectionResetError for py3 - if e.errno == errno.ECONNRESET: - logger.info("Client closed connection.") - self.keep_alive = 0 - return - b1, b2 = 0, 0 - except ValueError as e: - b1, b2 = 0, 0 - - fin = b1 & FIN - opcode = b1 & OPCODE - masked = b2 & MASKED - payload_length = b2 & PAYLOAD_LEN - + def is_opcode_valid(self, opcode): + if opcode == OPCODE_TEXT: + return True, self.server._message_received_ + if opcode == OPCODE_PING: + return True, self.server._ping_received_ + if opcode == OPCODE_PONG: + return True, self.server._pong_received_ + return False, None + + def log_info_about_bad_opcode(self, opcode): if opcode == OPCODE_CLOSE_CONN: logger.info("Client asked to close connection.") self.keep_alive = 0 - return - if not masked: - logger.warning("Client must always be masked.") - self.keep_alive = 0 - return + return True if opcode == OPCODE_CONTINUATION: logger.warning("Continuation frames are not supported.") - return - elif opcode == OPCODE_BINARY: + return True + if opcode == OPCODE_BINARY: logger.warning("Binary frames are not supported.") - return - elif opcode == OPCODE_TEXT: - opcode_handler = self.server._message_received_ - elif opcode == OPCODE_PING: - opcode_handler = self.server._ping_received_ - elif opcode == OPCODE_PONG: - opcode_handler = self.server._pong_received_ - else: - logger.warning("Unknown opcode %#x." % opcode) + return True + logger.warning("Unknown opcode %#x." % opcode) + return False + + def choose_opcode_handler(self, opcode, masked): + if not masked: + logger.warning("Client must always be masked.") self.keep_alive = 0 - return + return None + ok, handler = self.is_opcode_valid(opcode) + if ok: + return handler + ok = self.log_info_about_bad_opcode(opcode) + if not ok: + self.keep_alive = 0 + return None + def compute_payload_length(self, secondByte): + payload_length = secondByte & PAYLOAD_LEN if payload_length == 126: - payload_length = struct.unpack(">H", self.rfile.read(2))[0] - elif payload_length == 127: - payload_length = struct.unpack(">Q", self.rfile.read(8))[0] + return struct.unpack(">H", self.rfile.read(2))[0] + if payload_length == 127: + return struct.unpack(">Q", self.rfile.read(8))[0] + return payload_length + + + def read_next_message(self): + b1, b2 = 0, 0 + try: + b1, b2 = self.read_bytes(2) + except SocketError as e: # to be replaced with ConnectionResetError for py3 + if e.errno == errno.ECONNRESET: + logger.info("Client closed connection.") + self.keep_alive = 0 + return + + opcode_handler = self.choose_opcode_handler(b1 & OPCODE, b2 & MASKED) + if opcode_handler is None: + return masks = self.read_bytes(4) message_bytes = bytearray() - for message_byte in self.read_bytes(payload_length): + for message_byte in self.read_bytes(self.compute_payload_length(b2)): message_byte ^= masks[len(message_bytes) % 4] message_bytes.append(message_byte) opcode_handler(self, message_bytes.decode('utf8'))