Skip to content

Commit

Permalink
Merge pull request #145 from maropu/NaNComparisonIssue
Browse files Browse the repository at this point in the history
Fix the bug that handles NaN comparisons incorrectly
  • Loading branch information
oontvoo authored Apr 25, 2021
2 parents 8f08f2e + 29aaa06 commit f16db69
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ class EvaluatorTest extends CommonsCompilerTestSuite {
if (comp == "!=") { return lhs != rhs; }
if (comp == "<") { return lhs < rhs; }
if (comp == "<=") { return lhs <= rhs; }
if (comp == ">") { return lhs < rhs; }
if (comp == ">=") { return lhs <= rhs; }
if (comp == ">") { return lhs > rhs; }
if (comp == ">=") { return lhs >= rhs; }
throw new RuntimeException("Unsupported comparison");
}
public static boolean
Expand All @@ -379,14 +379,14 @@ class EvaluatorTest extends CommonsCompilerTestSuite {
if (comp == "!=") { return lhs != rhs; }
if (comp == "<") { return lhs < rhs; }
if (comp == "<=") { return lhs <= rhs; }
if (comp == ">") { return lhs < rhs; }
if (comp == ">=") { return lhs <= rhs; }
if (comp == ">") { return lhs > rhs; }
if (comp == ">=") { return lhs >= rhs; }
throw new RuntimeException("Unsupported comparison");
}

@Test public void
testHandlingNaN() throws Exception {
String prog = (
String prog1 = (
""
+ "package test;\n"
+ "public class Test {\n"
Expand All @@ -395,64 +395,93 @@ class EvaluatorTest extends CommonsCompilerTestSuite {
+ " if (comp == \"!=\") { return lhs != rhs; }\n"
+ " if (comp == \"<\" ) { return lhs < rhs; }\n"
+ " if (comp == \"<=\") { return lhs <= rhs; }\n"
+ " if (comp == \">\" ) { return lhs < rhs; }\n"
+ " if (comp == \">=\") { return lhs <= rhs; }\n"
+ " if (comp == \">\" ) { return lhs > rhs; }\n"
+ " if (comp == \">=\") { return lhs >= rhs; }\n"
+ " throw new RuntimeException(\"Unsupported comparison\");\n"
+ " }\n"
+ " public static boolean compare(float lhs, float rhs, String comp) {\n"
+ " if (comp == \"==\") { return lhs == rhs; }\n"
+ " if (comp == \"!=\") { return lhs != rhs; }\n"
+ " if (comp == \"<\" ) { return lhs < rhs; }\n"
+ " if (comp == \"<=\") { return lhs <= rhs; }\n"
+ " if (comp == \">\" ) { return lhs < rhs; }\n"
+ " if (comp == \">=\") { return lhs <= rhs; }\n"
+ " if (comp == \">\" ) { return lhs > rhs; }\n"
+ " if (comp == \">=\") { return lhs >= rhs; }\n"
+ " throw new RuntimeException(\"Unsupported comparison\");\n"
+ " }\n"
+ "}\n"
);

String prog2 = (
""
+ "package test;\n"
+ "public class Test {\n"
+ " public static boolean compare(double lhs, double rhs, String comp) {\n"
+ " if (comp == \"==\") { if (lhs == rhs) { return true; } else { return false; } }\n"
+ " if (comp == \"!=\") { if (lhs != rhs) { return true; } else { return false; } }\n"
+ " if (comp == \"<\") { if (lhs < rhs) { return true; } else { return false; } }\n"
+ " if (comp == \"<=\") { if (lhs <= rhs) { return true; } else { return false; } }\n"
+ " if (comp == \">\") { if (lhs > rhs) { return true; } else { return false; } }\n"
+ " if (comp == \">=\") { if (lhs >= rhs) { return true; } else { return false; } }\n"
+ " throw new RuntimeException(\"Unsupported comparison\");\n"
+ " }\n"
+ " public static boolean compare(float lhs, float rhs, String comp) {\n"
+ " if (comp == \"==\") { if (lhs == rhs) { return true; } else { return false; } }\n"
+ " if (comp == \"!=\") { if (lhs != rhs) { return true; } else { return false; } }\n"
+ " if (comp == \"<\") { if (lhs < rhs) { return true; } else { return false; } }\n"
+ " if (comp == \"<=\") { if (lhs <= rhs) { return true; } else { return false; } }\n"
+ " if (comp == \">\") { if (lhs > rhs) { return true; } else { return false; } }\n"
+ " if (comp == \">=\") { if (lhs >= rhs) { return true; } else { return false; } }\n"
+ " throw new RuntimeException(\"Unsupported comparison\");\n"
+ " }\n"
+ "}\n"
);
ISimpleCompiler sc = this.compilerFactory.newSimpleCompiler();
sc.cook(prog);

final Class<?> c = sc.getClassLoader().loadClass("test.Test");
final Method dm = c.getMethod("compare", new Class[] { double.class, double.class, String.class });
final Method fm = c.getMethod("compare", new Class[] { float.class, float.class, String.class });
final Double[][] argss = new Double[][] {
{ new Double(Double.NaN), new Double(Double.NaN) },
{ new Double(Double.NaN), new Double(1.0) },
{ new Double(1.0), new Double(Double.NaN) },
{ new Double(1.0), new Double(2.0) },
{ new Double(2.0), new Double(1.0) },
{ new Double(1.0), new Double(1.0) },
};
String[] operators = new String[] { "==", "!=", "<", "<=", ">", ">=" };
for (String operator : operators) {
for (Double[] args : argss) {
String msg = "\"" + args[0] + " " + operator + " " + args[1] + "\"";
{
boolean exp = EvaluatorTest.compare(
args[0].doubleValue(),
args[1].doubleValue(),
operator
);
Object[] objs = new Object[] { args[0], args[1], operator };
Object actual = dm.invoke(null, objs);
Assert.assertEquals(msg, exp, actual);
}

{
msg = "float: " + msg;
boolean exp = EvaluatorTest.compare(
args[0].floatValue(),
args[1].floatValue(),
operator
);
Object[] objs = new Object[] {
new Float(args[0].floatValue()),
new Float(args[1].floatValue()),
operator
};
Object actual = fm.invoke(null, objs);
Assert.assertEquals(msg, exp, actual);
String[] progs = new String[] { prog1, prog2 };
for (String prog : progs) {
ISimpleCompiler sc = this.compilerFactory.newSimpleCompiler();
sc.cook(prog);

final Class<?> c = sc.getClassLoader().loadClass("test.Test");
final Method dm = c.getMethod("compare", new Class[] { double.class, double.class, String.class });
final Method fm = c.getMethod("compare", new Class[] { float.class, float.class, String.class });
final Double[][] argss = new Double[][] {
{ new Double(Double.NaN), new Double(Double.NaN) },
{ new Double(Double.NaN), new Double(1.0) },
{ new Double(1.0), new Double(Double.NaN) },
{ new Double(1.0), new Double(2.0) },
{ new Double(2.0), new Double(1.0) },
{ new Double(1.0), new Double(1.0) },
};
String[] operators = new String[] { "==", "!=", "<", "<=", ">", ">=" };
for (String operator : operators) {
for (Double[] args : argss) {
String msg = "\"" + args[0] + " " + operator + " " + args[1] + "\"";
{
boolean exp = EvaluatorTest.compare(
args[0].doubleValue(),
args[1].doubleValue(),
operator
);
Object[] objs = new Object[] { args[0], args[1], operator };
Object actual = dm.invoke(null, objs);
Assert.assertEquals(msg, exp, actual);
}

{
msg = "float: " + msg;
boolean exp = EvaluatorTest.compare(
args[0].floatValue(),
args[1].floatValue(),
operator
);
Object[] objs = new Object[] {
new Float(args[0].floatValue()),
new Float(args[1].floatValue()),
operator
};
Object actual = fm.invoke(null, objs);
Assert.assertEquals(msg, exp, actual);
}
}
}
}
Expand Down
17 changes: 9 additions & 8 deletions janino/src/main/java/org/codehaus/janino/UnitCompiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -4308,8 +4308,6 @@ private int getTargetVersion() {
);
if (opIdx == Integer.MIN_VALUE) break COMPARISON;

if (orientation == UnitCompiler.JUMP_IF_FALSE) opIdx ^= 1;

// Comparison with "null".
{
boolean lhsIsNull = this.getConstantValue(bo.lhs) == null;
Expand Down Expand Up @@ -4351,6 +4349,8 @@ private int getTargetVersion() {
this.consT(bo, (Object) null);
}

if (orientation == UnitCompiler.JUMP_IF_FALSE) opIdx ^= 1;

switch (opIdx) {

case EQ:
Expand Down Expand Up @@ -4389,7 +4389,7 @@ private int getTargetVersion() {
this.compileGetValue(bo.rhs);
this.numericPromotion(bo.rhs, this.convertToPrimitiveNumericType(bo.rhs, rhsType), promotedType);

this.ifNumeric(bo, opIdx, dst);
this.ifNumeric(bo, opIdx, dst, orientation);
return;
}

Expand Down Expand Up @@ -4418,7 +4418,7 @@ private int getTargetVersion() {
this.unboxingConversion(bo, icl.TYPE_java_lang_Boolean, IClass.BOOLEAN);
}

this.if_icmpxx(bo, opIdx, dst);
this.if_icmpxx(bo, orientation == UnitCompiler.JUMP_IF_FALSE ? opIdx ^ 1 : opIdx, dst);
return;
}

Expand All @@ -4435,7 +4435,7 @@ private int getTargetVersion() {

this.compileGetValue(bo.rhs);

this.if_acmpxx(bo, opIdx, dst);
this.if_acmpxx(bo, orientation == UnitCompiler.JUMP_IF_FALSE ? opIdx ^ 1 : opIdx, dst);
return;
}

Expand Down Expand Up @@ -11746,23 +11746,24 @@ interface Compilable { void compile() throws CompileException; }

/**
* @param opIdx One of {@link #EQ}, {@link #NE}, {@link #LT}, {@link #GE}, {@link #GT} or {@link #LE}
* @param orientation {@link #JUMP_IF_TRUE} or {@link #JUMP_IF_FALSE}
*/
private void
ifNumeric(Locatable locatable, int opIdx, Offset dst) {
ifNumeric(Locatable locatable, int opIdx, Offset dst, boolean orientation) {
assert opIdx >= UnitCompiler.EQ && opIdx <= UnitCompiler.LE;

VerificationTypeInfo topOperand = this.getCodeContext().peekOperand();

if (topOperand == StackMapTableAttribute.INTEGER_VARIABLE_INFO) {
this.if_icmpxx(locatable, opIdx, dst);
this.if_icmpxx(locatable, orientation == UnitCompiler.JUMP_IF_FALSE ? opIdx ^ 1 : opIdx, dst);
} else
if (
topOperand == StackMapTableAttribute.LONG_VARIABLE_INFO
|| topOperand == StackMapTableAttribute.FLOAT_VARIABLE_INFO
|| topOperand == StackMapTableAttribute.DOUBLE_VARIABLE_INFO
) {
this.cmp(locatable, opIdx);
this.ifxx(locatable, opIdx, dst);
this.ifxx(locatable, orientation == UnitCompiler.JUMP_IF_FALSE ? opIdx ^ 1 : opIdx, dst);
} else
{
throw new InternalCompilerException("Unexpected computational type \"" + topOperand + "\"");
Expand Down

0 comments on commit f16db69

Please sign in to comment.