diff --git a/internal/patch/patch.go b/internal/patch/patch.go index e540784..b19711b 100644 --- a/internal/patch/patch.go +++ b/internal/patch/patch.go @@ -1,6 +1,7 @@ package patch import ( + "bytes" "encoding/json" "fmt" "strings" @@ -41,7 +42,11 @@ func NewValidator(patchReq string, s schema.Schema, extensions ...schema.Schema) Path string Value interface{} } - if err := json.Unmarshal([]byte(patchReq), &operation); err != nil { + + // Decode a number into a json.Number instead of floag64 + d := json.NewDecoder(bytes.NewBufferString(patchReq)) + d.UseNumber() + if err := d.Decode(&operation); err != nil { return OperationValidator{}, err } diff --git a/internal/patch/patch_test.go b/internal/patch/patch_test.go index 991f20b..e181077 100644 --- a/internal/patch/patch_test.go +++ b/internal/patch/patch_test.go @@ -2,9 +2,10 @@ package patch import ( "fmt" + "testing" + "github.com/elimity-com/scim/schema" "github.com/scim2/filter-parser/v2" - "testing" ) func TestNewPathValidator(t *testing.T) { @@ -31,6 +32,94 @@ func TestNewPathValidator(t *testing.T) { t.Error("expected JSON error, got none") } }) + t.Run("Valid integer", func(t *testing.T) { + ops := []string{ + `{"op":"add","path":"attr2","value":1234}`, + `{"op":"add","path":"attr2","value":"1234"}`, + } + for _, op := range ops { + validator, err := NewValidator(op, patchSchema) + if err != nil { + t.Errorf("unexpected error, got %v", err) + return + } + v, err := validator.Validate() + if err != nil { + t.Errorf("unexpected error, got %v", err) + return + } + n, ok := v.(int64) + if !ok { + t.Errorf("unexpected type, got %T", v) + return + } + if n != 1234 { + t.Errorf("unexpected integer, got %d", n) + return + } + } + }) + + t.Run("Valid float64", func(t *testing.T) { + ops := []string{ + `{"op":"add","path":"attr3","value":12.34}`, + `{"op":"add","path":"attr3","value":"12.34"}`, + } + for _, op := range ops { + validator, err := NewValidator(op, patchSchema) + if err != nil { + t.Errorf("unexpected error, got %v", err) + return + } + v, err := validator.Validate() + if err != nil { + t.Errorf("unexpected error, got %v", err) + return + } + n, ok := v.(float64) + if !ok { + t.Errorf("unexpected type, got %T", v) + return + } + if n != 12.34 { + t.Errorf("unexpected integer, got %f", n) + return + } + } + }) + + t.Run("Valid Booleans", func(t *testing.T) { + tests := []struct { + op string + expected bool + }{ + {`{"op":"add","path":"attr4","value":true}`, true}, + {`{"op":"add","path":"attr4","value":"True"}`, true}, + {`{"op":"add","path":"attr4","value":false}`, false}, + {`{"op":"add","path":"attr4","value":"False"}`, false}, + } + for _, tc := range tests { + validator, err := NewValidator(tc.op, patchSchema) + if err != nil { + t.Errorf("unexpected error, got %v", err) + return + } + v, err := validator.Validate() + if err != nil { + t.Errorf("unexpected error, got %v", err) + return + } + b, ok := v.(bool) + if !ok { + t.Errorf("unexpected type, got %T", v) + return + } + if b != tc.expected { + t.Errorf("unexpected integer, got %v", b) + return + } + } + }) } func TestOperationValidator_getRefAttribute(t *testing.T) { diff --git a/internal/patch/remove_test.go b/internal/patch/remove_test.go index a3e9816..6f95800 100644 --- a/internal/patch/remove_test.go +++ b/internal/patch/remove_test.go @@ -31,18 +31,6 @@ func Example_removeComplexMultiValuedAttributeValue() { // } -// The following example shows how remove a single member from a group. -func Example_removeSingleMember() { - operation := `{ - "op": "remove", - "path": "members[value eq \"0001\"]" -}` - validator, _ := NewValidator(operation, schema.CoreGroupSchema()) - fmt.Println(validator.Validate()) - // Output: - // -} - // The following example shows how remove a single group from a user. func Example_removeSingleGroup() { operation := `{ @@ -59,6 +47,18 @@ func Example_removeSingleGroup() { // [map[]] } +// The following example shows how remove a single member from a group. +func Example_removeSingleMember() { + operation := `{ + "op": "remove", + "path": "members[value eq \"0001\"]" +}` + validator, _ := NewValidator(operation, schema.CoreGroupSchema()) + fmt.Println(validator.Validate()) + // Output: + // +} + // The following example shows how to replace all of the members of a group with a different members list. func Example_replaceAllMembers() { operations := []string{`{ diff --git a/internal/patch/schema_test.go b/internal/patch/schema_test.go index 711864e..c7c6c6a 100644 --- a/internal/patch/schema_test.go +++ b/internal/patch/schema_test.go @@ -14,6 +14,17 @@ var ( schema.SimpleCoreAttribute(schema.SimpleStringParams(schema.StringParams{ Name: "attr1", })), + schema.SimpleCoreAttribute(schema.SimpleNumberParams(schema.NumberParams{ + Name: "attr2", + Type: schema.AttributeTypeInteger(), + })), + schema.SimpleCoreAttribute(schema.SimpleNumberParams(schema.NumberParams{ + Name: "attr3", + Type: schema.AttributeTypeDecimal(), + })), + schema.SimpleCoreAttribute(schema.SimpleBooleanParams(schema.BooleanParams{ + Name: "attr4", + })), schema.SimpleCoreAttribute(schema.SimpleStringParams(schema.StringParams{ Name: "multiValued", MultiValued: true, diff --git a/schema/core.go b/schema/core.go index 18399de..1f65b7c 100644 --- a/schema/core.go +++ b/schema/core.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "regexp" + "strconv" "strings" datetime "github.com/di-wu/xsd-datetime" @@ -179,12 +180,19 @@ func (a CoreAttribute) ValidateSingular(attribute interface{}) (interface{}, *er return bin, nil case attributeDataTypeBoolean: - b, ok := attribute.(bool) - if !ok { + switch b := attribute.(type) { + case bool: + return b, nil + case string: + bb, err := strconv.ParseBool(b) + if err != nil { + return nil, &errors.ScimErrorInvalidValue + } + + return bb, nil + default: return nil, &errors.ScimErrorInvalidValue } - - return b, nil case attributeDataTypeComplex: obj, ok := attribute.(map[string]interface{}) if !ok { @@ -237,6 +245,13 @@ func (a CoreAttribute) ValidateSingular(attribute interface{}) (interface{}, *er return f, nil case float64: return n, nil + case string: + f, err := strconv.ParseFloat(n, 64) + if err != nil { + return nil, &errors.ScimErrorInvalidValue + } + + return f, nil default: return nil, &errors.ScimErrorInvalidValue } @@ -251,6 +266,13 @@ func (a CoreAttribute) ValidateSingular(attribute interface{}) (interface{}, *er return i, nil case int, int8, int16, int32, int64: return n, nil + case string: + i, err := strconv.ParseInt(n, 10, 64) + if err != nil { + return nil, &errors.ScimErrorInvalidValue + } + + return i, nil default: return nil, &errors.ScimErrorInvalidValue } diff --git a/schema/schema_test.go b/schema/schema_test.go index 91bd9ed..3497551 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -229,7 +229,7 @@ func TestValidationInvalid(t *testing.T) { "booleans": []interface{}{ true, }, - "decimal": "1.1", + "decimal": "1,000", }, { // invalid type integer (json.Number) "required": "present",