From 6d89f8fe9cea5618dd0c572ffa13e8729dee6f20 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Sat, 30 Nov 2024 13:27:37 -0800 Subject: [PATCH] Add binary safety check to server --- llamafile/server/client.cpp | 32 +++++++++++++++++++++----- llamafile/server/client.h | 1 + llamafile/server/utils.h | 4 ++++ llamafile/server/writev.cpp | 46 +++++++++++++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 6 deletions(-) create mode 100644 llamafile/server/writev.cpp diff --git a/llamafile/server/client.cpp b/llamafile/server/client.cpp index 7ce3be3f89..2f6e2e52fa 100644 --- a/llamafile/server/client.cpp +++ b/llamafile/server/client.cpp @@ -24,6 +24,7 @@ #include "llamafile/server/server.h" #include "llamafile/server/time.h" #include "llamafile/server/tokenbucket.h" +#include "llamafile/server/utils.h" #include "llamafile/server/worker.h" #include "llamafile/string.h" #include "llamafile/threadlocal.h" @@ -478,7 +479,7 @@ Client::send_response_chunk(const std::string_view content) // perform send system call ssize_t sent; - if ((sent = writev(fd_, iov, 3)) != bytes) { + if ((sent = safe_writev(fd_, iov, 3)) != bytes) { if (sent == -1 && errno != EAGAIN && errno != ECONNRESET) SLOG("writev failed %m"); close_connection_ = true; @@ -504,15 +505,34 @@ Client::send_response_finish() return send("0\r\n\r\n"); } -// writes raw data to socket +// writes any old data to socket +// +// unlike send() this won't fail if binary content is detected. +bool +Client::send_binary(const void* p, size_t n) +{ + ssize_t sent; + if ((sent = write(fd_, p, n)) != n) { + if (sent == -1 && errno != EAGAIN && errno != ECONNRESET) + SLOG("write failed %m"); + close_connection_ = true; + return false; + } + return true; +} + +// writes non-binary data to socket // // consider using the higher level methods like send_error(), // send_response(), send_response_start(), etc. bool Client::send(const std::string_view s) { + iovec iov[1]; ssize_t sent; - if ((sent = write(fd_, s.data(), s.size())) != s.size()) { + iov[0].iov_base = (void*)s.data(); + iov[0].iov_len = s.size(); + if ((sent = safe_writev(fd_, iov, 1)) != s.size()) { if (sent == -1 && errno != EAGAIN && errno != ECONNRESET) SLOG("write failed %m"); close_connection_ = true; @@ -521,7 +541,7 @@ Client::send(const std::string_view s) return true; } -// writes two pieces of raw data to socket in single system call +// writes two pieces of non-binary data to socket in single system call // // consider using the higher level methods like send_error(), // send_response(), send_response_start(), etc. @@ -534,7 +554,7 @@ Client::send2(const std::string_view s1, const std::string_view s2) iov[0].iov_len = s1.size(); iov[1].iov_base = (void*)s2.data(); iov[1].iov_len = s2.size(); - if ((sent = writev(fd_, iov, 2)) != s1.size() + s2.size()) { + if ((sent = safe_writev(fd_, iov, 2)) != s1.size() + s2.size()) { if (sent == -1 && errno != EAGAIN && errno != ECONNRESET) SLOG("writev failed %m"); close_connection_ = true; @@ -755,7 +775,7 @@ Client::dispatcher() close_connection_ = true; return false; } - if (!send(std::string_view(buf, chunk))) { + if (!send_binary(buf, chunk)) { close_connection_ = true; return false; } diff --git a/llamafile/server/client.h b/llamafile/server/client.h index 4e1e209084..3b5e69cf2c 100644 --- a/llamafile/server/client.h +++ b/llamafile/server/client.h @@ -85,6 +85,7 @@ struct Client bool read_content() __wur; bool send_continue() __wur; bool send(const std::string_view) __wur; + bool send_binary(const void*, size_t) __wur; void defer_cleanup(void (*)(void*), void*); bool send_error(int, const char* = nullptr); char* append_http_response_message(char*, int, const char* = nullptr); diff --git a/llamafile/server/utils.h b/llamafile/server/utils.h index 1b2ad8047f..346be6a979 100644 --- a/llamafile/server/utils.h +++ b/llamafile/server/utils.h @@ -20,6 +20,7 @@ #include <__fwd/string_view.h> #include <__fwd/vector.h> #include +#include struct llama_model; @@ -28,6 +29,9 @@ namespace server { class Atom; +ssize_t +safe_writev(int, const iovec*, int); + bool atob(std::string_view, bool); diff --git a/llamafile/server/writev.cpp b/llamafile/server/writev.cpp new file mode 100644 index 0000000000..841af57679 --- /dev/null +++ b/llamafile/server/writev.cpp @@ -0,0 +1,46 @@ +// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- +// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi +// +// Copyright 2024 Mozilla Foundation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "llamafile/server/log.h" +#include "utils.h" +#include +#include + +namespace lf { +namespace server { + +ssize_t +safe_writev(int fd, const iovec* iov, int iovcnt) +{ + for (int i = 0; i < iovcnt; ++i) { + bool has_binary = false; + size_t n = iov[i].iov_len; + unsigned char* p = (unsigned char*)iov[i].iov_base; + for (size_t j = 0; j < n; ++j) { + has_binary |= p[j] < 7; + } + if (has_binary) { + SLOG("safe_writev() detected binary server is compromised"); + errno = EINVAL; + return -1; + } + } + return writev(fd, iov, iovcnt); +} + +} // namespace server +} // namespace lf