Skip to content

Commit

Permalink
make it thread-safe and simplify ctor (#3326)
Browse files Browse the repository at this point in the history
  • Loading branch information
wgtmac authored Mar 7, 2025
1 parent a2047e9 commit a7d27e4
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 2 deletions.
6 changes: 6 additions & 0 deletions lang/c++/impl/Compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,12 @@ static LogicalType makeLogicalType(const Entity &e, const Object &m) {
t = LogicalType::DURATION;
else if (typeField == "uuid")
t = LogicalType::UUID;
else {
auto custom = CustomLogicalTypeRegistry::instance().create(typeField, e.toString());
if (custom != nullptr) {
return LogicalType(std::move(custom));
}
}
return LogicalType(t);
}

Expand Down
35 changes: 34 additions & 1 deletion lang/c++/impl/LogicalType.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@
namespace avro {

// Constructs a built-in (non-custom) logical type.
//
// Precision and scale start at 0 and no custom logical type is attached.
//
// Throws Exception when `type` is CUSTOM: a CUSTOM logical type has to be
// built from a CustomLogicalType instance via the shared_ptr constructor.
LogicalType::LogicalType(Type type)
    : type_(type), precision_(0), scale_(0), custom_(nullptr) {
    if (type == CUSTOM) {
        throw Exception("Logical type CUSTOM must be initialized with a custom logical type");
    }
}

// Constructs a CUSTOM logical type that shares ownership of `custom`.
//
// Precision and scale are set to 0. The pointer is stored as given; this
// constructor performs no null check on it.
LogicalType::LogicalType(std::shared_ptr<CustomLogicalType> custom)
    : type_(CUSTOM),
      precision_(0),
      scale_(0),
      custom_(std::move(custom)) {
}

LogicalType::Type LogicalType::type() const {
return type_;
Expand Down Expand Up @@ -92,7 +99,33 @@ void LogicalType::printJson(std::ostream &os) const {
case UUID:
os << R"("logicalType": "uuid")";
break;
case CUSTOM:
custom_->printJson(os);
break;
}
}

// Default JSON rendering of a custom logical type: writes the single
// attribute `"logicalType": "<name>"` to `os`, without a trailing comma or
// newline. The name is written verbatim (no JSON escaping is applied here).
void CustomLogicalType::printJson(std::ostream &os) const {
    os << R"("logicalType": ")";
    os << name_;
    os << '"';
}

// Returns the process-wide registry.
//
// The registry is a function-local static, so it is created on first use;
// since C++11 the language guarantees that this initialization happens
// exactly once even when several threads race to call instance().
CustomLogicalTypeRegistry &CustomLogicalTypeRegistry::instance() {
    static CustomLogicalTypeRegistry registry;
    return registry;
}

// Registers `factory` as the creator for the custom logical type `name`.
//
// If `name` is already registered, the previous factory is replaced.
// The registry mutex is held for the duration of the update.
void CustomLogicalTypeRegistry::registerType(const std::string &name, Factory factory) {
    std::lock_guard<std::mutex> lock(mutex_);
    // `factory` is a by-value sink parameter: move it into the map instead of
    // copying the std::function (and whatever state it captured) a second time.
    registry_[name] = std::move(factory);
}

std::shared_ptr<CustomLogicalType> CustomLogicalTypeRegistry::create(const std::string &name, const std::string &json) const {
std::lock_guard<std::mutex> lock(mutex_);
auto it = registry_.find(name);
if (it == registry_.end()) {
return nullptr;
}
return it->second(json);
}

} // namespace avro
5 changes: 5 additions & 0 deletions lang/c++/impl/Node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,11 @@ void Node::setLogicalType(LogicalType logicalType) {
"STRING type");
}
break;
case LogicalType::CUSTOM:
if (logicalType.customLogicalType() == nullptr) {
throw Exception("CUSTOM logical type is not set");
}
break;
}

logicalType_ = logicalType;
Expand Down
12 changes: 12 additions & 0 deletions lang/c++/impl/NodeImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,18 @@ static void printName(std::ostream &os, const Name &n, size_t depth) {
os << indent(depth) << R"("name": ")" << n.simpleName() << "\",\n";
}

// Emits the logical type attribute of a node as one indented JSON line
// followed by ",\n". Writes nothing when the logical type is NONE.
static void printLogicalType(std::ostream &os, const LogicalType &logicalType, size_t depth) {
    if (logicalType.type() == LogicalType::NONE) {
        return;
    }
    os << indent(depth);
    logicalType.printJson(os);
    os << ",\n";
}

void NodeRecord::printJson(std::ostream &os, size_t depth) const {
os << "{\n";
os << indent(++depth) << "\"type\": \"record\",\n";
printLogicalType(os, logicalType(), depth);
const Name &name = nameAttribute_.get();
printName(os, name, depth);

Expand Down Expand Up @@ -524,6 +533,7 @@ void NodeMap::printDefaultToJson(const GenericDatum &g, std::ostream &os,
void NodeEnum::printJson(std::ostream &os, size_t depth) const {
os << "{\n";
os << indent(++depth) << "\"type\": \"enum\",\n";
printLogicalType(os, logicalType(), depth);
if (!getDoc().empty()) {
os << indent(depth) << R"("doc": ")"
<< escape(getDoc()) << "\",\n";
Expand All @@ -550,6 +560,7 @@ void NodeEnum::printJson(std::ostream &os, size_t depth) const {
void NodeArray::printJson(std::ostream &os, size_t depth) const {
os << "{\n";
os << indent(depth + 1) << "\"type\": \"array\",\n";
printLogicalType(os, logicalType(), depth + 1);
if (!getDoc().empty()) {
os << indent(depth + 1) << R"("doc": ")"
<< escape(getDoc()) << "\",\n";
Expand All @@ -566,6 +577,7 @@ void NodeArray::printJson(std::ostream &os, size_t depth) const {
void NodeMap::printJson(std::ostream &os, size_t depth) const {
os << "{\n";
os << indent(depth + 1) << "\"type\": \"map\",\n";
printLogicalType(os, logicalType(), depth + 1);
if (!getDoc().empty()) {
os << indent(depth + 1) << R"("doc": ")"
<< escape(getDoc()) << "\",\n";
Expand Down
51 changes: 50 additions & 1 deletion lang/c++/include/avro/LogicalType.hh
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,18 @@
#ifndef avro_LogicalType_hh__
#define avro_LogicalType_hh__

#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <unordered_map>

#include "Config.hh"

namespace avro {

class CustomLogicalType;

class AVRO_DECL LogicalType {
public:
enum Type {
Expand All @@ -41,10 +47,12 @@ public:
LOCAL_TIMESTAMP_MICROS,
LOCAL_TIMESTAMP_NANOS,
DURATION,
UUID
UUID,
CUSTOM // for registered custom logical types
};

explicit LogicalType(Type type);
explicit LogicalType(std::shared_ptr<CustomLogicalType> custom);

Type type() const;

Expand All @@ -57,12 +65,53 @@ public:
void setScale(int32_t scale);
int32_t scale() const { return scale_; }

const std::shared_ptr<CustomLogicalType> &customLogicalType() const {
return custom_;
}

void printJson(std::ostream &os) const;

private:
Type type_;
int32_t precision_;
int32_t scale_;
std::shared_ptr<CustomLogicalType> custom_;
};

// Base class for user-defined logical types.
//
// An instance carries the logical type name and knows how to print itself as
// a JSON attribute. printJson() is virtual so that subclasses can customize
// the output.
class AVRO_DECL CustomLogicalType {
public:
    // `explicit`, like the LogicalType constructors, so that a std::string is
    // never silently converted into a CustomLogicalType.
    explicit CustomLogicalType(const std::string &name) : name_(name) {}

    virtual ~CustomLogicalType() = default;

    // The logical type name given at construction.
    const std::string &name() const { return name_; }

    // Writes the JSON attribute(s) describing this logical type to `os`.
    virtual void printJson(std::ostream &os) const;

private:
    std::string name_;
};

// Registry for custom logical types.
// This class is thread-safe.
class AVRO_DECL CustomLogicalTypeRegistry {
public:
static CustomLogicalTypeRegistry &instance();

using Factory = std::function<std::shared_ptr<CustomLogicalType>(const std::string &json)>;

// Register a custom logical type and its factory function.
void registerType(const std::string &name, Factory factory);

// Create a custom logical type from a JSON string.
// Returns nullptr if the name is not registered.
std::shared_ptr<CustomLogicalType> create(const std::string &name, const std::string &json) const;

private:
CustomLogicalTypeRegistry() = default;

std::unordered_map<std::string, Factory> registry_;
mutable std::mutex mutex_;
};

} // namespace avro
Expand Down
38 changes: 38 additions & 0 deletions lang/c++/test/SchemaTests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,43 @@ static void testMalformedLogicalTypes(const char *schema) {
BOOST_CHECK(datum.logicalType().type() == LogicalType::NONE);
}

// Checks that a custom logical type registered with the registry is attached
// to the schema root when a schema is compiled, and that it is still present
// after the schema is printed with toJson() and compiled again.
static void testCustomLogicalType() {
    // Declare a custom logical type named "map" that uses the default
    // CustomLogicalType JSON printing.
    struct MapLogicalType : public CustomLogicalType {
        MapLogicalType() : CustomLogicalType("map") {}
    };

    // Register the custom logical type with the shared registry. The factory
    // ignores the JSON it is given.
    CustomLogicalTypeRegistry::instance().registerType("map", [](const std::string &) {
        return std::make_shared<MapLogicalType>();
    });

    auto verifyCustomLogicalType = [](const ValidSchema &schema) {
        auto logicalType = schema.root()->logicalType();
        // Fatal checks: if the logical type is not CUSTOM or carries no
        // instance, stop here and report a failure instead of dereferencing
        // a null pointer on the next line.
        BOOST_REQUIRE_EQUAL(logicalType.type(), LogicalType::CUSTOM);
        BOOST_REQUIRE(logicalType.customLogicalType() != nullptr);
        BOOST_CHECK_EQUAL(logicalType.customLogicalType()->name(), "map");
    };

    const std::string schema =
        R"({ "type": "array",
"logicalType": "map",
"items": {
"type": "record",
"name": "k12_v13",
"fields": [
{ "name": "key", "type": "int", "field-id": 12 },
{ "name": "value", "type": "string", "field-id": 13 }
]
}
})";
    auto compiledSchema = compileJsonSchemaFromString(schema);
    verifyCustomLogicalType(compiledSchema);

    // Round trip: print the compiled schema and compile the output again.
    auto json = compiledSchema.toJson();
    auto parsedSchema = compileJsonSchemaFromString(json);
    verifyCustomLogicalType(parsedSchema);
}

} // namespace schema
} // namespace avro

Expand All @@ -681,5 +718,6 @@ init_unit_test_suite(int /*argc*/, char * /*argv*/[]) {
ADD_PARAM_TEST(ts, avro::schema::testMalformedLogicalTypes,
avro::schema::malformedLogicalTypes);
ts->add(BOOST_TEST_CASE(&avro::schema::testCompactSchemas));
ts->add(BOOST_TEST_CASE(&avro::schema::testCustomLogicalType));
return ts;
}

0 comments on commit a7d27e4

Please sign in to comment.