# Patch summary: null-safe substring checks for UTF8String.
#
# UTF8String.java
#   - contains(), startsWith(UTF8String) and endsWith() return false when the argument is null.
#   - The private startsWith helper now takes the prefix as a byte[] instead of a UTF8String,
#     so contains() hands over the bytes it already fetched instead of the helper calling
#     getBytes() again for every offset whose first byte matches.
# UTF8StringSuite.java
#   - Adds one assertFalse(...(null)) case to each of the contains, startsWith and endsWith tests.
#
# NOTE(review): indentation and blank context lines were restored to satisfy the hunk line
# counts; confirm against the target tree before applying.
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index bc46ca52aa284..474ea58c8d465 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -125,37 +125,37 @@ public UTF8String substring(final int start, final int until) {
   }
 
   public boolean contains(final UTF8String substring) {
+    if (substring == null) return false;
     final byte[] b = substring.getBytes();
     if (b.length == 0) {
       return true;
     }
 
     for (int i = 0; i <= bytes.length - b.length; i++) {
-      if (bytes[i] == b[0] && startsWith(substring, i)) {
+      if (bytes[i] == b[0] && startsWith(b, i)) {
         return true;
       }
     }
     return false;
   }
 
-  private boolean startsWith(final UTF8String prefix, int offset) {
-    byte[] b = prefix.getBytes();
-    if (b.length + offset > bytes.length || offset < 0) {
+  private boolean startsWith(final byte[] prefix, int offset) {
+    if (prefix.length + offset > bytes.length || offset < 0) {
       return false;
     }
     int i = 0;
-    while (i < b.length && b[i] == bytes[i + offset]) {
+    while (i < prefix.length && prefix[i] == bytes[i + offset]) {
       i++;
     }
-    return i == b.length;
+    return i == prefix.length;
   }
 
   public boolean startsWith(final UTF8String prefix) {
-    return startsWith(prefix, 0);
+    return prefix != null && startsWith(prefix.getBytes(), 0);
   }
 
   public boolean endsWith(final UTF8String suffix) {
-    return startsWith(suffix, bytes.length - suffix.getBytes().length);
+    return suffix != null && startsWith(suffix.getBytes(), bytes.length - suffix.getBytes().length);
   }
 
   public UTF8String toUpperCase() {
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 80c179a1b5e75..f0f530418b08f 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -46,6 +46,7 @@ public void basicTest() throws UnsupportedEncodingException {
 
   @Test
   public void contains() {
+    Assert.assertFalse(UTF8String.fromString("hello").contains(null));
     Assert.assertTrue(UTF8String.fromString("hello").contains(UTF8String.fromString("ello")));
     Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("vello")));
     Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("hellooo")));
@@ -57,6 +58,7 @@ public void contains() {
 
   @Test
   public void startsWith() {
+    Assert.assertFalse(UTF8String.fromString("hello").startsWith(null));
     Assert.assertTrue(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hell")));
     Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("ell")));
     Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hellooo")));
@@ -68,6 +70,7 @@ public void startsWith() {
 
   @Test
   public void endsWith() {
+    Assert.assertFalse(UTF8String.fromString("hello").endsWith(null));
     Assert.assertTrue(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ello")));
     Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ellov")));
     Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("hhhello")));