Skip to content

Commit

Permalink
Add Duration type
Browse files Browse the repository at this point in the history
  • Loading branch information
mabeyj committed May 26, 2017
1 parent c056b81 commit b881be9
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 0 deletions.
128 changes: 128 additions & 0 deletions duration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package saml

import (
"fmt"
"regexp"
"strconv"
"strings"
"time"
)

// Duration is a time.Duration that uses the xsd:duration format for text
// marshalling and unmarshalling.
type Duration time.Duration

// MarshalText implements the encoding.TextMarshaler interface.
func (d Duration) MarshalText() ([]byte, error) {
if d == 0 {
return nil, nil
}

out := "PT"
if d < 0 {
d *= -1
out = "-" + out
}

h := time.Duration(d) / time.Hour
m := time.Duration(d) % time.Hour / time.Minute
s := time.Duration(d) % time.Minute / time.Second
ns := time.Duration(d) % time.Second
if h > 0 {
out += fmt.Sprintf("%dH", h)
}
if m > 0 {
out += fmt.Sprintf("%dM", m)
}
if s > 0 || ns > 0 {
out += fmt.Sprintf("%d", s)
if ns > 0 {
out += strings.TrimRight(fmt.Sprintf(".%09d", ns), "0")
}
out += "S"
}

return []byte(out), nil
}

const (
day = 24 * time.Hour
month = 30 * day // Assumed to be 30 days.
year = 12 * month // Assumed to be non-leap year.
)

var (
durationRegexp = regexp.MustCompile(`^(-?)P(?:(\d+)Y)?(?:(\d+)M)?(?:(\d+)D)?(?:T(.+))?$`)
durationTimeRegexp = regexp.MustCompile(`^(?:(\d+)H)?(?:(\d+)M)?(?:(\d+(?:\.\d+)?)S)?$`)
)

// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (d *Duration) UnmarshalText(text []byte) error {
if text == nil {
*d = 0
return nil
}

var (
out time.Duration
sign time.Duration = 1
)
match := durationRegexp.FindStringSubmatch(string(text))
if match == nil || strings.Join(match[2:6], "") == "" {
return fmt.Errorf("invalid duration (%s)", text)
}
if match[1] == "-" {
sign = -1
}
if match[2] != "" {
y, err := strconv.Atoi(match[2])
if err != nil {
return fmt.Errorf("invalid duration years (%s): %s", text, err)
}
out += time.Duration(y) * year
}
if match[3] != "" {
m, err := strconv.Atoi(match[3])
if err != nil {
return fmt.Errorf("invalid duration months (%s): %s", text, err)
}
out += time.Duration(m) * month
}
if match[4] != "" {
d, err := strconv.Atoi(match[4])
if err != nil {
return fmt.Errorf("invalid duration days (%s): %s", text, err)
}
out += time.Duration(d) * day
}
if match[5] != "" {
match := durationTimeRegexp.FindStringSubmatch(match[5])
if match == nil {
return fmt.Errorf("invalid duration (%s)", text)
}
if match[1] != "" {
h, err := strconv.Atoi(match[1])
if err != nil {
return fmt.Errorf("invalid duration hours (%s): %s", text, err)
}
out += time.Duration(h) * time.Hour
}
if match[2] != "" {
m, err := strconv.Atoi(match[2])
if err != nil {
return fmt.Errorf("invalid duration minutes (%s): %s", text, err)
}
out += time.Duration(m) * time.Minute
}
if match[3] != "" {
s, err := strconv.ParseFloat(match[3], 64)
if err != nil {
return fmt.Errorf("invalid duration seconds (%s): %s", text, err)
}
out += time.Duration(s * float64(time.Second))
}
}

*d = Duration(sign * out)
return nil
}
78 changes: 78 additions & 0 deletions duration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package saml

import (
"errors"
"time"

. "gopkg.in/check.v1"
)

var _ = Suite(&DurationTest{})

type DurationTest struct{}

var durationMarshalTests = []struct {
in time.Duration
out []byte
}{
{0, nil},
{time.Nanosecond, []byte("PT0.000000001S")},
{time.Millisecond, []byte("PT0.001S")},
{time.Second, []byte("PT1S")},
{time.Minute, []byte("PT1M")},
{time.Hour, []byte("PT1H")},
{-time.Hour, []byte("-PT1H")},
{2*time.Hour + 3*time.Minute + 4*time.Second + 5*time.Nanosecond, []byte("PT2H3M4.000000005S")},
}

func (t DurationTest) TestMarshalText(c *C) {
for _, tc := range durationMarshalTests {
got, err := Duration(tc.in).MarshalText()
c.Assert(err, IsNil)
c.Assert(got, DeepEquals, tc.out)
}
}

var durationUnmarshalTests = []struct {
in []byte
out time.Duration
err error
}{
{nil, 0, nil},
{[]byte("PT0.0000000001S"), 0, nil},
{[]byte("PT0.000000001S"), time.Nanosecond, nil},
{[]byte("PT0.001S"), time.Millisecond, nil},
{[]byte("PT1S"), time.Second, nil},
{[]byte("PT1M"), time.Minute, nil},
{[]byte("PT1H"), time.Hour, nil},
{[]byte("-PT1H"), -time.Hour, nil},
{[]byte("P1D"), day, nil},
{[]byte("P1M"), month, nil},
{[]byte("P1Y"), year, nil},
{[]byte("P2Y3M4DT5H6M7.000000008S"), 2*year + 3*month + 4*day + 5*time.Hour + 6*time.Minute + 7*time.Second + 8*time.Nanosecond, nil},
{[]byte("P0Y0M0DT0H0M0S"), 0, nil},
{[]byte("P0001Y"), year, nil},
{[]byte(""), 0, errors.New("invalid duration ()")},
{[]byte("12345"), 0, errors.New("invalid duration (12345)")},
{[]byte("P1D1M1Y"), 0, errors.New("invalid duration (P1D1M1Y)")},
{[]byte("P1H1M1S"), 0, errors.New("invalid duration (P1H1M1S)")},
{[]byte("PT1S1M1H"), 0, errors.New("invalid duration (PT1S1M1H)")},
{[]byte(" P1Y "), 0, errors.New("invalid duration ( P1Y )")},
{[]byte("P"), 0, errors.New("invalid duration (P)")},
{[]byte("-P"), 0, errors.New("invalid duration (-P)")},
{[]byte("PT"), 0, errors.New("invalid duration (PT)")},
{[]byte("P1YMD"), 0, errors.New("invalid duration (P1YMD)")},
{[]byte("P1YT"), 0, errors.New("invalid duration (P1YT)")},
{[]byte("P-1Y"), 0, errors.New("invalid duration (P-1Y)")},
{[]byte("P1.5Y"), 0, errors.New("invalid duration (P1.5Y)")},
{[]byte("PT1.S"), 0, errors.New("invalid duration (PT1.S)")},
}

func (t DurationTest) TestUnmarshalText(c *C) {
for _, tc := range durationUnmarshalTests {
var d Duration
err := d.UnmarshalText(tc.in)
c.Assert(err, DeepEquals, tc.err)
c.Assert(d, Equals, Duration(tc.out))
}
}

0 comments on commit b881be9

Please sign in to comment.