Skip to content

Commit

Permalink
Fixed printing of nan floats/doubles in Python.
Browse files Browse the repository at this point in the history
The second assert in _upb_EncodeRoundTripFloat is raised if val is a nan. This fix just returns the output of first spnprintf.

I am not sure how changes to this repo are made so feel free to ignore this CL.

To test this, you could
1. Define a proto with a float field
message Test {
  float val = 1;
}
2. In a python script, import the library and then set the val to nan and try to print it.
proto = Test(val=float('nan'))
print(proto)

This will cause a coredump due to assertion error:

assert.h assertion failed at third_party/upb/upb/lex/round_trip.c:46 in void _upb_EncodeRoundTripFloat(float, char *, size_t): strtof(buf, NULL) == val

Added the corresponding change to double too

PiperOrigin-RevId: 637127851
  • Loading branch information
protobuf-github-bot authored and copybara-github committed May 25, 2024
1 parent ee98ba2 commit f651080
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/google/protobuf/internal/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,23 @@ def testFloatPrinting(self, message_module):
message.optional_float = 2.0
self.assertEqual(str(message), 'optional_float: 2.0\n')

def testFloatNanPrinting(self, message_module):
message = message_module.TestAllTypes()
message.optional_float = float('nan')
self.assertEqual(str(message), 'optional_float: nan\n')

def testHighPrecisionFloatPrinting(self, message_module):
msg = message_module.TestAllTypes()
msg.optional_float = 0.12345678912345678
old_float = msg.optional_float
msg.ParseFromString(msg.SerializeToString())
self.assertEqual(old_float, msg.optional_float)

def testDoubleNanPrinting(self, message_module):
message = message_module.TestAllTypes()
message.optional_double = float('nan')
self.assertEqual(str(message), 'optional_double: nan\n')

def testHighPrecisionDoublePrinting(self, message_module):
msg = message_module.TestAllTypes()
msg.optional_double = 0.12345678912345678
Expand Down
10 changes: 10 additions & 0 deletions upb/lex/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ cc_test(
],
)

cc_test(
name = "round_trip_test",
srcs = ["round_trip_test.cc"],
deps = [
":lex",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
],
)

# begin:github_only
filegroup(
name = "source_files",
Expand Down
10 changes: 10 additions & 0 deletions upb/lex/round_trip.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "upb/lex/round_trip.h"

#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

// Must be last.
Expand All @@ -28,6 +30,10 @@ static void upb_FixLocale(char* p) {

void _upb_EncodeRoundTripDouble(double val, char* buf, size_t size) {
assert(size >= kUpb_RoundTripBufferSize);
if (isnan(val)) {
snprintf(buf, size, "%s", "nan");
return;
}
snprintf(buf, size, "%.*g", DBL_DIG, val);
if (strtod(buf, NULL) != val) {
snprintf(buf, size, "%.*g", DBL_DIG + 2, val);
Expand All @@ -38,6 +44,10 @@ void _upb_EncodeRoundTripDouble(double val, char* buf, size_t size) {

void _upb_EncodeRoundTripFloat(float val, char* buf, size_t size) {
assert(size >= kUpb_RoundTripBufferSize);
if (isnan(val)) {
snprintf(buf, size, "%s", "nan");
return;
}
snprintf(buf, size, "%.*g", FLT_DIG, val);
if (strtof(buf, NULL) != val) {
snprintf(buf, size, "%.*g", FLT_DIG + 3, val);
Expand Down
35 changes: 35 additions & 0 deletions upb/lex/round_trip_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#include "upb/lex/round_trip.h"

#include <math.h>

#include <gtest/gtest.h>

namespace {

TEST(RoundTripTest, Double) {
char buf[32];

_upb_EncodeRoundTripDouble(0.123456789, buf, sizeof(buf));
EXPECT_STREQ(buf, "0.123456789");

_upb_EncodeRoundTripDouble(0.0, buf, sizeof(buf));
EXPECT_STREQ(buf, "0");

_upb_EncodeRoundTripDouble(nan(""), buf, sizeof(buf));
EXPECT_STREQ(buf, "nan");
}

TEST(RoundTripTest, Float) {
char buf[32];

_upb_EncodeRoundTripFloat(0.123456, buf, sizeof(buf));
EXPECT_STREQ(buf, "0.123456");

_upb_EncodeRoundTripFloat(0.0, buf, sizeof(buf));
EXPECT_STREQ(buf, "0");

_upb_EncodeRoundTripFloat(nan(""), buf, sizeof(buf));
EXPECT_STREQ(buf, "nan");
}

} // namespace

0 comments on commit f651080

Please sign in to comment.