Skip to content

Commit

Permalink
[SPARK-10617] [SQL] Fixed AddMonths leap year calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
alexrovner committed Oct 5, 2015
1 parent d323e5e commit 4e64409
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ object DateTimeUtils {
val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0
val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay

val currentDayInMonth = if (daysToMonthEnd == 0 || dayOfMonth >= lastDayOfMonth) {
val currentDayInMonth = if (dayOfMonth >= lastDayOfMonth) {
// last day of the month
lastDayOfMonth
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("add_months") {
checkEvaluation(AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(1)),
DateTimeUtils.fromJavaDate(Date.valueOf("2015-02-28")))
checkEvaluation(AddMonths(Literal(Date.valueOf("2015-02-28")), Literal(12)),
DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-28")))
checkEvaluation(AddMonths(Literal(Date.valueOf("2016-03-30")), Literal(-1)),
DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29")))
checkEvaluation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,11 @@ class DateTimeUtilsSuite extends SparkFunSuite {
c1.set(1997, 1, 28, 10, 30, 0)
val days1 = millisToDays(c1.getTimeInMillis)
val c2 = Calendar.getInstance()
c2.set(2000, 1, 29)
c2.set(2000, 1, 28)
assert(dateAddMonths(days1, 36) === millisToDays(c2.getTimeInMillis))
c2.set(1996, 0, 31)
c2.set(1996, 0, 28)
assert(dateAddMonths(days1, -13) === millisToDays(c2.getTimeInMillis))

}

test("timestamp add months") {
Expand All @@ -383,7 +384,7 @@ class DateTimeUtilsSuite extends SparkFunSuite {
c1.set(Calendar.MILLISECOND, 0)
val ts1 = c1.getTimeInMillis * 1000L
val c2 = Calendar.getInstance()
c2.set(2000, 1, 29, 10, 30, 0)
c2.set(2000, 1, 28, 10, 30, 0)
c2.set(Calendar.MILLISECOND, 123)
val ts2 = c2.getTimeInMillis * 1000L
assert(timestampAddInterval(ts1, 36, 123000) === ts2)
Expand Down

0 comments on commit 4e64409

Please sign in to comment.