From 3f981d5b892a1de89f161d72b542ae96c4342b4c Mon Sep 17 00:00:00 2001 From: "Dr. Stefan Schimanski" Date: Wed, 7 Mar 2018 09:36:48 +0100 Subject: [PATCH] WIP: towards orthogonal defaulting Defaulting was implemented nested into the main validation algorithm. This makes the defaulting algorithm much more complex. This commit prototypes an externally implemented trivial looking defaulting algorithm. --- defaulter.go | 36 +++++++++++++++++ defaulter_test.go | 71 ++++++++++++++++++++++++++++++++ object_validator.go | 13 ++++-- result.go | 98 ++++++++++++++++++++++++++++++++++++++++++++- schema.go | 4 ++ slice_validator.go | 22 +++++----- 6 files changed, 229 insertions(+), 15 deletions(-) create mode 100644 defaulter.go diff --git a/defaulter.go b/defaulter.go new file mode 100644 index 0000000..a7ccf15 --- /dev/null +++ b/defaulter.go @@ -0,0 +1,36 @@ +package validate + +import "fmt" + +func ApplyDefaults(root interface{}, result *Result) { + applyDefaults(root, result) +} + +func applyDefaults(root interface{}, result *Result) { + switch obj := root.(type) { + case map[string]interface{}: + for _, val := range obj { + applyDefaults(val, result) + } + applyObjectDefaults(obj, result) + case []interface{}: + for _, val := range obj { + applyDefaults(val, result) + } + } +} + +func applyObjectDefaults(obj map[string]interface{}, result *Result) { + for fld, schemata := range result.propertySchemata[fmt.Sprintf("%p", obj)] { + if _, ok := obj[fld]; ok { + continue + } + + for _, schema := range schemata { + if schema.Default != nil { + obj[fld] = schema.Default + break + } + } + } +} diff --git a/defaulter_test.go b/defaulter_test.go index e7f37d1..bf1e91c 100644 --- a/defaulter_test.go +++ b/defaulter_test.go @@ -51,3 +51,74 @@ func TestDefaulter(t *testing.T) { assert.NoError(t, err) assert.Equal(t, expected, x) } + +func TestApplyDefaultsSimple(t *testing.T) { + schema := spec.Schema{ + SchemaProps: spec.SchemaProps{ + Properties: map[string]spec.Schema{ + "int": spec.Schema{ + SchemaProps: spec.SchemaProps{ + Default: float64(42), + }, + }, + "str": spec.Schema{ + SchemaProps: spec.SchemaProps{ + Default: "Hello", + }, + }, + }, + }, + } + validator := NewSchemaValidator(&schema, nil, "", strfmt.Default) + x := map[string]interface{}{} + t.Logf("Before: %v", x) + r := validator.Validate(x) + assert.False(t, r.HasErrors(), fmt.Sprintf("unexpected validation error: %v", r.AsError())) + + ApplyDefaults(x, r) + t.Logf("After: %v", x) + var expected interface{} + err := json.Unmarshal([]byte(`{ + "int": 42, + "str": "Hello" + }`), &expected) + assert.NoError(t, err) + assert.Equal(t, expected, x) +} + +func TestApplyDefaultsNested(t *testing.T) { + fname := filepath.Join(defaulterFixturesPath, "schema.json") + b, err := ioutil.ReadFile(fname) + assert.NoError(t, err) + var schema spec.Schema + assert.NoError(t, json.Unmarshal(b, &schema)) + + err = spec.ExpandSchema(&schema, nil, nil /*new(noopResCache)*/) + assert.NoError(t, err, fname+" should expand cleanly") + + validator := NewSchemaValidator(&schema, nil, "", strfmt.Default) + x := map[string]interface{}{ + "nested": map[string]interface{}{}, + "all": map[string]interface{}{}, + "any": map[string]interface{}{}, + "one": map[string]interface{}{}, + } + t.Logf("Before: %v", x) + r := validator.Validate(x) + assert.False(t, r.HasErrors(), fmt.Sprintf("unexpected validation error: %v", r.AsError())) + + ApplyDefaults(x, r) + t.Logf("After: %v", x) + var expected interface{} + err = json.Unmarshal([]byte(`{ + "int": 42, + "str": "Hello", + "obj": {"foo": "bar"}, + "nested": {"inner": 7}, + "all": {"foo": 42, "bar": 42}, + "any": {"foo": 42}, + "one": {"bar": 42} + }`), &expected) + assert.NoError(t, err) + assert.Equal(t, expected, x) +} diff --git a/object_validator.go b/object_validator.go index 0dd5802..525b785 100644 --- a/object_validator.go +++ b/object_validator.go @@ -101,6 +101,7 @@ func (o *objectValidator) Validate(data interface{}) *Result { o.precheck(res, val) + // check validity of field names if o.AdditionalProperties != nil && !o.AdditionalProperties.Allows { for k := range val { _, regularProperty := o.Properties[k] @@ -122,7 +123,8 @@ func (o *objectValidator) Validate(data interface{}) *Result { matched, succeededOnce, _ := o.validatePatternProperty(key, value, res) if !(regularProperty || matched || succeededOnce) { if o.AdditionalProperties != nil && o.AdditionalProperties.Schema != nil { - res.Merge(NewSchemaValidator(o.AdditionalProperties.Schema, o.Root, o.Path+"."+key, o.KnownFormats).Validate(value)) + r := NewSchemaValidator(o.AdditionalProperties.Schema, o.Root, o.Path+"."+key, o.KnownFormats).Validate(value) + res.mergeForField(data.(map[string]interface{}), key, r) } else if regularProperty && !(matched || succeededOnce) { res.AddErrors(errors.FailedAllPatternProperties(o.Path, o.In, key)) } @@ -132,7 +134,8 @@ func (o *objectValidator) Validate(data interface{}) *Result { createdFromDefaults := map[string]bool{} - for pName, pSchema := range o.Properties { + for pName := range o.Properties { + pSchema := o.Properties[pName] // one instance per iteration rName := pName if o.Path != "" { rName = o.Path + "." + pName @@ -140,9 +143,10 @@ func (o *objectValidator) Validate(data interface{}) *Result { if v, ok := val[pName]; ok { r := NewSchemaValidator(&pSchema, o.Root, rName, o.KnownFormats).Validate(v) - res.Merge(r) + res.mergeForField(data.(map[string]interface{}), pName, r) } else if pSchema.Default != nil { createdFromDefaults[pName] = true + res.addPropertySchemata(data.(map[string]interface{}), pName, &pSchema) pName := pName // shaddow def := pSchema.Default res.Defaulters = append(res.Defaulters, DefaulterFunc(func() { @@ -166,7 +170,8 @@ func (o *objectValidator) Validate(data interface{}) *Result { if !regularProperty && (matched || succeededOnce) { for _, pName := range patterns { if v, ok := o.PatternProperties[pName]; ok { - res.Merge(NewSchemaValidator(&v, o.Root, o.Path+"."+key, o.KnownFormats).Validate(value)) + r := NewSchemaValidator(&v, o.Root, o.Path+"."+key, o.KnownFormats).Validate(value) + res.mergeForField(data.(map[string]interface{}), key, r) } } } diff --git a/result.go b/result.go index 7b58e50..f212660 100644 --- a/result.go +++ b/result.go @@ -15,9 +15,11 @@ package validate import ( + "fmt" "os" "github.com/go-openapi/errors" + "github.com/go-openapi/spec" ) var ( @@ -35,11 +37,20 @@ func (f DefaulterFunc) Apply() { f() } +type SchemaPropsProperty struct { + SchemaProps *spec.SchemaProps + Property string +} + // Result represents a validation result type Result struct { Errors []error MatchCount int Defaulters []Defaulter + + objectSchemata []*spec.Schema + propertySchemata map[string]map[string][]*spec.Schema + sliceSchemata map[string][][]*spec.Schema } // Merge merges this result with the other one, preserving match counts etc @@ -47,10 +58,95 @@ func (r *Result) Merge(other *Result) *Result { if other == nil { return r } + r.mergeWithoutRootSchemata(other) + r.objectSchemata = append(r.objectSchemata, other.objectSchemata...) + return r +} + +func (r *Result) mergeForField(obj map[string]interface{}, field string, other *Result) *Result { + if other == nil { + return r + } + r.mergeWithoutRootSchemata(other) + r.addPropertySchemata(obj, field, other.objectSchemata...) + return r +} + +func (r *Result) mergeForSlice(slice []interface{}, i int, other *Result) *Result { + if other == nil { + return r + } + r.mergeWithoutRootSchemata(other) + r.addSliceSchemata(slice, i, other.objectSchemata...) + return r +} + +func (r *Result) addPropertySchemata(obj map[string]interface{}, field string, schemata ...*spec.Schema) { + if len(schemata) == 0 { + return + } + key := fmt.Sprintf("%p", obj) + if r.propertySchemata == nil { + r.propertySchemata = make(map[string]map[string][]*spec.Schema) + } + if _, ok := r.propertySchemata[key]; !ok { + r.propertySchemata[key] = make(map[string][]*spec.Schema) + } + r.propertySchemata[key][field] = append(r.propertySchemata[key][field], schemata...) +} + +func (r *Result) addSliceSchemata(slice []interface{}, i int, schemata ...*spec.Schema) { + if len(schemata) == 0 { + return + } + key := fmt.Sprintf("%p", slice) + if r.sliceSchemata == nil { + r.sliceSchemata = make(map[string][][]*spec.Schema) + } + if _, ok := r.sliceSchemata[key]; !ok { + r.sliceSchemata[key] = make([][]*spec.Schema, len(slice)) + } + if i >= len(slice) { + panic(fmt.Sprintf("index %d out of bounds for slice %#v of length %d", i, slice, len(slice))) + } + r.sliceSchemata[key][i] = append(r.sliceSchemata[key][i], schemata...) +} + +func (r *Result) mergeWithoutRootSchemata(other *Result) { + if other == nil { + return + } r.AddErrors(other.Errors...) r.MatchCount += other.MatchCount r.Defaulters = append(r.Defaulters, other.Defaulters...) - return r + + if r.propertySchemata == nil && other.propertySchemata != nil { + r.propertySchemata = make(map[string]map[string][]*spec.Schema, len(other.propertySchemata)) + } + for obj, objSchemata := range other.propertySchemata { + if _, ok := r.propertySchemata[obj]; !ok { + r.propertySchemata[obj] = make(map[string][]*spec.Schema, len(objSchemata)) + } + for fld, schemata := range objSchemata { + r.propertySchemata[obj][fld] = append(r.propertySchemata[obj][fld], schemata...) + } + } + + if r.sliceSchemata == nil && other.sliceSchemata != nil { + r.sliceSchemata = make(map[string][][]*spec.Schema, len(other.sliceSchemata)) + } + for slc, sliceSchemata := range other.sliceSchemata { + if _, ok := r.sliceSchemata[slc]; !ok { + r.sliceSchemata[slc] = make([][]*spec.Schema, len(other.sliceSchemata[slc])) + } + for i, schemata := range sliceSchemata { + if i < len(r.sliceSchemata[slc]) { + r.sliceSchemata[slc] = append(r.sliceSchemata[slc], schemata) + } else { + r.sliceSchemata[slc][i] = append(r.sliceSchemata[slc][i], schemata...) + } + } + } } // AddErrors adds errors to this validation result diff --git a/schema.go b/schema.go index f859c6d..79f00fd 100644 --- a/schema.go +++ b/schema.go @@ -86,6 +86,9 @@ func (s *SchemaValidator) Validate(data interface{}) *Result { if s == nil { return result } + if s.Schema != nil { + result.objectSchemata = []*spec.Schema{s.Schema} + } if data == nil { v := s.validators[0].Validate(data) @@ -141,6 +144,7 @@ func (s *SchemaValidator) Validate(data interface{}) *Result { result.Inc() } result.Inc() + return result } diff --git a/slice_validator.go b/slice_validator.go index 2665a0f..3600ef9 100644 --- a/slice_validator.go +++ b/slice_validator.go @@ -52,36 +52,38 @@ func (s *schemaSliceValidator) Validate(data interface{}) *Result { } val := reflect.ValueOf(data) size := val.Len() + result.sliceSchemata = map[string][][]*spec.Schema{ + fmt.Sprintf("%p", data): make([][]*spec.Schema, size), + } if s.Items != nil && s.Items.Schema != nil { validator := NewSchemaValidator(s.Items.Schema, s.Root, s.Path, s.KnownFormats) for i := 0; i < size; i++ { validator.SetPath(fmt.Sprintf("%s.%d", s.Path, i)) value := val.Index(i) - result.Merge(validator.Validate(value.Interface())) + result.mergeForSlice(data.([]interface{}), i, validator.Validate(value.Interface())) } } - itemsSize := int64(0) + itemsSize := 0 if s.Items != nil && len(s.Items.Schemas) > 0 { - itemsSize = int64(len(s.Items.Schemas)) - for i := int64(0); i < itemsSize; i++ { + itemsSize = len(s.Items.Schemas) + for i := 0; i < itemsSize; i++ { validator := NewSchemaValidator(&s.Items.Schemas[i], s.Root, fmt.Sprintf("%s.%d", s.Path, i), s.KnownFormats) - if val.Len() <= int(i) { + if val.Len() <= i { break } - result.Merge(validator.Validate(val.Index(int(i)).Interface())) + result.mergeForSlice(data.([]interface{}), int(i), validator.Validate(val.Index(i).Interface())) } - } - if s.AdditionalItems != nil && itemsSize < int64(size) { + if s.AdditionalItems != nil && itemsSize < size { if s.Items != nil && len(s.Items.Schemas) > 0 && !s.AdditionalItems.Allows { result.AddErrors(errors.New(422, "array doesn't allow for additional items")) } if s.AdditionalItems.Schema != nil { - for i := itemsSize; i < (int64(size)-itemsSize)+1; i++ { + for i := itemsSize; i < size-itemsSize+1; i++ { validator := NewSchemaValidator(s.AdditionalItems.Schema, s.Root, fmt.Sprintf("%s.%d", s.Path, i), s.KnownFormats) - result.Merge(validator.Validate(val.Index(int(i)).Interface())) + result.mergeForSlice(data.([]interface{}), int(i), validator.Validate(val.Index(int(i)).Interface())) } } }