From 4a21e7341f880e091be0f9ccf0abd125c8e7e95b Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 11 Aug 2023 19:13:13 +0800 Subject: [PATCH] Handle special characters in JSON model dump. --- src/common/common.cc | 63 +++++++++++++++------ src/common/common.h | 43 ++++++--------- src/common/json.cc | 92 ++++++++++++------------------- src/common/numeric.h | 1 + src/learner.cc | 4 +- src/tree/tree_model.cc | 11 +++- tests/python/test_basic_models.py | 20 +++++++ 7 files changed, 129 insertions(+), 105 deletions(-) diff --git a/src/common/common.cc b/src/common/common.cc index 8f4f4b5c85ca..086f4c00d167 100644 --- a/src/common/common.cc +++ b/src/common/common.cc @@ -1,16 +1,17 @@ -/*! - * Copyright 2015-2019 by Contributors - * \file common.cc - * \brief Enable all kinds of global variables in common. +/** + * Copyright 2015-2023 by Contributors */ -#include -#include - #include "common.h" -#include "./random.h" -namespace xgboost { -namespace common { +#include // for ThreadLocalStore + +#include // for uint8_t +#include // for snprintf, size_t +#include // for string + +#include "./random.h" // for GlobalRandomEngine, GlobalRandom + +namespace xgboost::common { /*! \brief thread local entry for random. */ struct RandomThreadLocalEntry { /*! \brief the random engine instance. */ @@ -19,15 +20,43 @@ struct RandomThreadLocalEntry { using RandomThreadLocalStore = dmlc::ThreadLocalStore; -GlobalRandomEngine& GlobalRandom() { - return RandomThreadLocalStore::Get()->engine; +GlobalRandomEngine &GlobalRandom() { return RandomThreadLocalStore::Get()->engine; } + +void EscapeU8(std::string const &string, std::string *p_buffer) { + auto &buffer = *p_buffer; + for (size_t i = 0; i < string.length(); i++) { + const auto ch = string[i]; + if (ch == '\\') { + if (i < string.size() && string[i + 1] == 'u') { + buffer += "\\"; + } else { + buffer += "\\\\"; + } + } else if (ch == '"') { + buffer += "\\\""; + } else if (ch == '\b') { + buffer += "\\b"; + } else if (ch == '\f') { + buffer += "\\f"; + } else if (ch == '\n') { + buffer += "\\n"; + } else if (ch == '\r') { + buffer += "\\r"; + } else if (ch == '\t') { + buffer += "\\t"; + } else if (static_cast(ch) <= 0x1f) { + // Unit separator + char buf[8]; + snprintf(buf, sizeof buf, "\\u%04x", ch); + buffer += buf; + } else { + buffer += ch; + } + } } #if !defined(XGBOOST_USE_CUDA) -int AllVisibleGPUs() { - return 0; -} +int AllVisibleGPUs() { return 0; } #endif // !defined(XGBOOST_USE_CUDA) -} // namespace common -} // namespace xgboost +} // namespace xgboost::common diff --git a/src/common/common.h b/src/common/common.h index 35c807bef46a..bedff80b33d5 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -6,20 +6,19 @@ #ifndef XGBOOST_COMMON_COMMON_H_ #define XGBOOST_COMMON_COMMON_H_ -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include // for max +#include // for array +#include // for ceil +#include // for size_t +#include // for int32_t, int64_t +#include // for basic_istream, operator<<, istringstream +#include // for string, basic_string, getline, char_traits +#include // for make_tuple +#include // for forward, index_sequence, make_index_sequence +#include // for vector + +#include "xgboost/base.h" // for XGBOOST_DEVICE +#include "xgboost/logging.h" // for LOG, LOG_FATAL, LogMessageFatal #if defined(__CUDACC__) #include @@ -52,8 +51,7 @@ inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file, #endif // defined(__CUDACC__) } // namespace dh -namespace xgboost { -namespace common { +namespace xgboost::common { /*! * \brief Split a string by delimiter * \param s String to be split. @@ -69,19 +67,13 @@ inline std::vector Split(const std::string& s, char delim) { return ret; } +void EscapeU8(std::string const &string, std::string *p_buffer); + template XGBOOST_DEVICE T Max(T a, T b) { return a < b ? b : a; } -// simple routine to convert any data to string -template -inline std::string ToString(const T& data) { - std::ostringstream os; - os << data; - return os.str(); -} - template XGBOOST_DEVICE T1 DivRoundUp(const T1 a, const T2 b) { return static_cast(std::ceil(static_cast(a) / b)); @@ -195,6 +187,5 @@ template XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) { return indptr[group + 1] - 1; } -} // namespace common -} // namespace xgboost +} // namespace xgboost::common #endif // XGBOOST_COMMON_COMMON_H_ diff --git a/src/common/json.cc b/src/common/json.cc index c3d61b47d498..de9a89f78df8 100644 --- a/src/common/json.cc +++ b/src/common/json.cc @@ -1,23 +1,29 @@ -/*! - * Copyright (c) by Contributors 2019-2022 +/** + * Copyright 2019-2023, XGBoost Contributors */ #include "xgboost/json.h" -#include - -#include -#include -#include -#include -#include -#include - -#include "./math.h" -#include "charconv.h" -#include "xgboost/base.h" -#include "xgboost/json_io.h" -#include "xgboost/logging.h" -#include "xgboost/string_view.h" +#include // for array +#include // for isdigit +#include // for isinf, isnan +#include // for EOF +#include // for size_t, strtof +#include // for memcpy +#include // for initializer_list +#include // for distance +#include // for numeric_limits +#include // for allocator +#include // for operator<<, basic_ostream, operator&, ios, stringstream +#include // for errc + +#include "./math.h" // for CheckNAN +#include "charconv.h" // for to_chars, NumericLimits, from_chars, to_chars_result +#include "common.h" // for EscapeU8 +#include "xgboost/base.h" // for XGBOOST_EXPECT +#include "xgboost/intrusive_ptr.h" // for IntrusivePtr +#include "xgboost/json_io.h" // for JsonReader, UBJReader, UBJWriter, JsonWriter, ToBigEn... +#include "xgboost/logging.h" // for LOG, LOG_FATAL, LogMessageFatal, LogCheck_NE, CHECK +#include "xgboost/string_view.h" // for StringView, operator<< namespace xgboost { @@ -57,12 +63,12 @@ void JsonWriter::Visit(JsonObject const* obj) { } void JsonWriter::Visit(JsonNumber const* num) { - char number[NumericLimits::kToCharsSize]; - auto res = to_chars(number, number + sizeof(number), num->GetNumber()); + std::array::kToCharsSize> number; + auto res = to_chars(number.data(), number.data() + number.size(), num->GetNumber()); auto end = res.ptr; auto ori_size = stream_->size(); - stream_->resize(stream_->size() + end - number); - std::memcpy(stream_->data() + ori_size, number, end - number); + stream_->resize(stream_->size() + end - number.data()); + std::memcpy(stream_->data() + ori_size, number.data(), end - number.data()); } void JsonWriter::Visit(JsonInteger const* num) { @@ -88,43 +94,15 @@ void JsonWriter::Visit(JsonNull const* ) { } void JsonWriter::Visit(JsonString const* str) { - std::string buffer; - buffer += '"'; - auto const& string = str->GetString(); - for (size_t i = 0; i < string.length(); i++) { - const char ch = string[i]; - if (ch == '\\') { - if (i < string.size() && string[i+1] == 'u') { - buffer += "\\"; - } else { - buffer += "\\\\"; - } - } else if (ch == '"') { - buffer += "\\\""; - } else if (ch == '\b') { - buffer += "\\b"; - } else if (ch == '\f') { - buffer += "\\f"; - } else if (ch == '\n') { - buffer += "\\n"; - } else if (ch == '\r') { - buffer += "\\r"; - } else if (ch == '\t') { - buffer += "\\t"; - } else if (static_cast(ch) <= 0x1f) { - // Unit separator - char buf[8]; - snprintf(buf, sizeof buf, "\\u%04x", ch); - buffer += buf; - } else { - buffer += ch; - } - } - buffer += '"'; + std::string buffer; + buffer += '"'; + auto const& string = str->GetString(); + common::EscapeU8(string, &buffer); + buffer += '"'; - auto s = stream_->size(); - stream_->resize(s + buffer.size()); - std::memcpy(stream_->data() + s, buffer.data(), buffer.size()); + auto s = stream_->size(); + stream_->resize(s + buffer.size()); + std::memcpy(stream_->data() + s, buffer.data(), buffer.size()); } void JsonWriter::Visit(JsonBoolean const* boolean) { diff --git a/src/common/numeric.h b/src/common/numeric.h index 2da85502ad17..5b45bba8c03e 100644 --- a/src/common/numeric.h +++ b/src/common/numeric.h @@ -10,6 +10,7 @@ #include // for size_t #include // for int32_t #include // for iterator_traits +#include // for accumulate #include #include "common.h" // AssertGPUSupport diff --git a/src/learner.cc b/src/learner.cc index b2d6baff0841..81d1b795b0bc 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -797,7 +797,7 @@ class LearnerConfiguration : public Learner { bool has_nc {cfg_.find("num_class") != cfg_.cend()}; // Inject num_class into configuration. // FIXME(jiamingy): Remove the duplicated parameter in softmax - cfg_["num_class"] = common::ToString(mparam_.num_class); + cfg_["num_class"] = std::to_string(mparam_.num_class); auto& args = *p_args; args = {cfg_.cbegin(), cfg_.cend()}; // renew obj_->Configure(args); @@ -1076,7 +1076,7 @@ class LearnerIO : public LearnerConfiguration { mparam_.major_version = std::get<0>(Version::Self()); mparam_.minor_version = std::get<1>(Version::Self()); - cfg_["num_feature"] = common::ToString(mparam_.num_feature); + cfg_["num_feature"] = std::to_string(mparam_.num_feature); auto n = tparam_.__DICT__(); cfg_.insert(n.cbegin(), n.cend()); diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index f32ea701f3a5..d37be14b894d 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -398,11 +398,14 @@ class JsonGenerator : public TreeGenerator { static std::string const kIndicatorTemplate = R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID"; auto split_index = tree[nid].SplitIndex(); + auto fname = fmap_.Name(split_index); + std::string qfname; // quoted + common::EscapeU8(fname, &qfname); auto result = SuperT::Match( kIndicatorTemplate, {{"{nid}", std::to_string(nid)}, {"{depth}", std::to_string(depth)}, - {"{fname}", fmap_.Name(split_index)}, + {"{fname}", qfname}, {"{yes}", std::to_string(nyes)}, {"{no}", std::to_string(tree[nid].DefaultChild())}}); return result; @@ -430,12 +433,14 @@ class JsonGenerator : public TreeGenerator { std::string const &template_str, std::string cond, uint32_t depth) const { auto split_index = tree[nid].SplitIndex(); + auto fname = split_index < fmap_.Size() ? fmap_.Name(split_index) : std::to_string(split_index); + std::string qfname; // quoted + common::EscapeU8(fname, &qfname); std::string const result = SuperT::Match( template_str, {{"{nid}", std::to_string(nid)}, {"{depth}", std::to_string(depth)}, - {"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) : - std::to_string(split_index)}, + {"{fname}", qfname}, {"{cond}", cond}, {"{left}", std::to_string(tree[nid].LeftChild())}, {"{right}", std::to_string(tree[nid].RightChild())}, diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index 610a9236e490..f0c80124d905 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -439,6 +439,26 @@ def validate_model(parameters): 'objective': 'multi:softmax'} validate_model(parameters) + def test_special_model_dump_characters(self): + params = {"objective": "reg:squarederror", "max_depth": 3} + feature_names = ['"feature 0"', "\tfeature\n1", "feature 2"] + X, y, w = tm.make_regression(n_samples=128, n_features=3, use_cupy=False) + Xy = xgb.DMatrix(X, label=y, feature_names=feature_names) + booster = xgb.train(params, Xy, num_boost_round=3) + json_dump = booster.get_dump(dump_format="json") + assert len(json_dump) == 3 + + def validate(obj: dict) -> None: + for k, v in obj.items(): + if k == "split": + assert v in feature_names + elif isinstance(v, dict): + validate(v) + + for j_tree in json_dump: + loaded = json.loads(j_tree) + validate(loaded) + def test_categorical_model_io(self): X, y = tm.make_categorical(256, 16, 71, False) Xy = xgb.DMatrix(X, y, enable_categorical=True)