Skip to content

Commit

Permalink
lang/funcs: Experimental "defaults" function
Browse files Browse the repository at this point in the history
This is a new part of the existing module_variable_optional_attrs
experiment, because it's intended to complement the ability to declare
an input variable whose type constraint is an object type with optional
attributes. Module authors can use this to replace null values (that were
either explicitly set or implied by attribute omission) with other
non-null values of the same type.

This function is a bit more type-fussy than our functions typically are
because it's intended for use primarily with input variables that have
fully-specified type constraints, and thus it uses that type information
to help inform how the defaults data structure should be interpreted.

Other uses of this function will probably be harder today because it takes
a lot of extra annotation to build a value of a specific type if it isn't
passing through a variable type constraint. Perhaps later language
features for more general type conversion will make this more applicable,
but for now the more general form of this problem is better solved other
ways.
  • Loading branch information
apparentlymart committed Nov 2, 2020
1 parent 397c457 commit 5be8a0e
Show file tree
Hide file tree
Showing 4 changed files with 399 additions and 0 deletions.
256 changes: 256 additions & 0 deletions lang/funcs/defaults.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
package funcs

import (
"fmt"

"github.com/hashicorp/terraform/tfdiags"
"github.com/zclconf/go-cty/cty"
"github.com/zclconf/go-cty/cty/convert"
"github.com/zclconf/go-cty/cty/function"
)

// DefaultsFunc is a helper function for substituting default values in
// place of null values in a given data structure.
//
// See the documentation for function Defaults for more information.
var DefaultsFunc = function.New(&function.Spec{
Params: []function.Parameter{
{
Name: "input",
Type: cty.DynamicPseudoType,
},
{
Name: "defaults",
Type: cty.DynamicPseudoType,
},
},
Type: func(args []cty.Value) (cty.Type, error) {
// The result type is guaranteed to be the same as the input type,
// since all we're doing is replacing null values with non-null
// values of the same type.
retType := args[0].Type()
defaultsType := args[1].Type()

// This function is aimed at filling in object types or collections
// of object types where some of the attributes might be null, so
// it doesn't make sense to use a primitive type directly with it.
// (The "coalesce" function may be appropriate for such cases.)
if retType.IsPrimitiveType() {
// This error message is a bit of a fib because we can actually
// apply defaults to tuples too, but we expect that to be so
// unusual as to not be worth mentioning here, because mentioning
// it would require using some less-well-known Terraform language
// terminology in the message (tuple types, structural types).
return cty.DynamicPseudoType, function.NewArgErrorf(1, "only object types and collections of object types can have defaults applied")
}

defaultsPath := make(cty.Path, 0, 4) // some capacity so that most structures won't reallocate
if err := defaultsAssertSuitableFallback(retType, defaultsType, defaultsPath); err != nil {
errMsg := tfdiags.FormatError(err) // add attribute path prefix
return cty.DynamicPseudoType, function.NewArgErrorf(1, "%s", errMsg)
}

return retType, nil
},
Impl: func(args []cty.Value, retType cty.Type) (cty.Value, error) {
if args[0].Type().HasDynamicTypes() {
// If the types our input object aren't known yet for some reason
// then we'll defer all of our work here, because our
// interpretation of the defaults depends on the types in
// the input.
return cty.UnknownVal(retType), nil
}

v := defaultsApply(args[0], args[1])
return v, nil
},
})

func defaultsApply(input, fallback cty.Value) cty.Value {
const fallbackArgIdx = 1

wantTy := input.Type()
if !(input.IsKnown() && fallback.IsKnown()) {
return cty.UnknownVal(wantTy)
}

// For the rest of this function we're assuming that the given defaults
// will always be valid, because we expect to have caught any problems
// during the type checking phase. Any inconsistencies that reach here are
// therefore considered to be implementation bugs, and so will panic.

// Our strategy depends on the kind of type we're working with.
switch {
case wantTy.IsPrimitiveType():
// For leaf primitive values the rule is relatively simple: use the
// input if it's non-null, or fallback if input is null.
if !input.IsNull() {
return input
}
v, err := convert.Convert(fallback, wantTy)
if err != nil {
// Should not happen because we checked in defaultsAssertSuitableFallback
panic(err.Error())
}
return v

case wantTy.IsObjectType():
atys := wantTy.AttributeTypes()
ret := map[string]cty.Value{}
for attr, aty := range atys {
inputSub := input.GetAttr(attr)
fallbackSub := cty.NullVal(aty)
if fallback.Type().HasAttribute(attr) {
fallbackSub = fallback.GetAttr(attr)
}
ret[attr] = defaultsApply(inputSub, fallbackSub)
}
return cty.ObjectVal(ret)

case wantTy.IsTupleType():
l := wantTy.Length()
ret := make([]cty.Value, l)
for i := 0; i < l; i++ {
inputSub := input.Index(cty.NumberIntVal(int64(i)))
fallbackSub := fallback.Index(cty.NumberIntVal(int64(i)))
ret[i] = defaultsApply(inputSub, fallbackSub)
}
return cty.TupleVal(ret)

case wantTy.IsCollectionType():
// For collection types we apply a single fallback value to each
// element of the input collection, because in the situations this
// function is intended for we assume that the number of elements
// is the caller's decision, and so we'll just apply the same defaults
// to all of the elements.
ety := wantTy.ElementType()
switch {
case wantTy.IsMapType():
newVals := map[string]cty.Value{}

for it := input.ElementIterator(); it.Next(); {
k, v := it.Element()
newVals[k.AsString()] = defaultsApply(v, fallback)
}

if len(newVals) == 0 {
return cty.MapValEmpty(ety)
}
return cty.MapVal(newVals)
case wantTy.IsListType(), wantTy.IsSetType():
var newVals []cty.Value

for it := input.ElementIterator(); it.Next(); {
_, v := it.Element()
newV := defaultsApply(v, fallback)
newVals = append(newVals, newV)
}

if len(newVals) == 0 {
if wantTy.IsSetType() {
return cty.SetValEmpty(ety)
}
return cty.ListValEmpty(ety)
}
if wantTy.IsSetType() {
return cty.SetVal(newVals)
}
return cty.ListVal(newVals)
default:
// There are no other collection types, so this should not happen
panic(fmt.Sprintf("invalid collection type %#v", wantTy))
}
default:
// We should've caught anything else in defaultsAssertSuitableFallback,
// so this should not happen.
panic(fmt.Sprintf("invalid target type %#v", wantTy))
}
}

func defaultsAssertSuitableFallback(wantTy, fallbackTy cty.Type, fallbackPath cty.Path) error {
// If the type we want is a collection type then we need to keep peeling
// away collection type wrappers until we find the non-collection-type
// that's underneath, which is what the fallback will actually be applied
// to.
inCollection := false
for wantTy.IsCollectionType() {
wantTy = wantTy.ElementType()
inCollection = true
}

switch {
case wantTy.IsPrimitiveType():
// The fallback is valid if it's equal to or convertible to what we want.
if fallbackTy.Equals(wantTy) {
return nil
}
conversion := convert.GetConversionUnsafe(fallbackTy, wantTy)
if conversion == nil {
msg := convert.MismatchMessage(fallbackTy, wantTy)
return fallbackPath.NewErrorf("invalid default value for %s: %s", wantTy.FriendlyName(), msg)
}
return nil
case wantTy.IsObjectType():
if !fallbackTy.IsObjectType() {
if inCollection {
return fallbackPath.NewErrorf("the default value for a collection of an object type must itself be an object type, not %s", fallbackTy.FriendlyName())
}
return fallbackPath.NewErrorf("the default value for an object type must itself be an object type, not %s", fallbackTy.FriendlyName())
}
for attr, wantAty := range wantTy.AttributeTypes() {
if !fallbackTy.HasAttribute(attr) {
continue // it's always okay to not have a default value
}
fallbackSubpath := fallbackPath.GetAttr(attr)
fallbackSubTy := fallbackTy.AttributeType(attr)
err := defaultsAssertSuitableFallback(wantAty, fallbackSubTy, fallbackSubpath)
if err != nil {
return err
}
}
for attr := range fallbackTy.AttributeTypes() {
if !wantTy.HasAttribute(attr) {
fallbackSubpath := fallbackPath.GetAttr(attr)
return fallbackSubpath.NewErrorf("target type does not expect an attribute named %q", attr)
}
}
return nil
case wantTy.IsTupleType():
if !fallbackTy.IsTupleType() {
if inCollection {
return fallbackPath.NewErrorf("the default value for a collection of a tuple type must itself be a tuple type, not %s", fallbackTy.FriendlyName())
}
return fallbackPath.NewErrorf("the default value for a tuple type must itself be a tuple type, not %s", fallbackTy.FriendlyName())
}
wantEtys := wantTy.TupleElementTypes()
fallbackEtys := fallbackTy.TupleElementTypes()
if got, want := len(wantEtys), len(fallbackEtys); got != want {
return fallbackPath.NewErrorf("the default value for a tuple type of length %d must also have length %d, not %d", want, want, got)
}
for i := 0; i < len(wantEtys); i++ {
fallbackSubpath := fallbackPath.IndexInt(i)
wantSubTy := wantEtys[i]
fallbackSubTy := fallbackEtys[i]
err := defaultsAssertSuitableFallback(wantSubTy, fallbackSubTy, fallbackSubpath)
if err != nil {
return err
}
}
return nil
default:
// No other types are supported right now.
return fallbackPath.NewErrorf("cannot apply defaults to %s", wantTy.FriendlyName())
}
}

// Defaults is a helper function for substituting default values in
// place of null values in a given data structure.
//
// This is primarily intended for use with a module input variable that
// has an object type constraint (or a collection thereof) that has optional
// attributes, so that the receiver of a value that omits those attributes
// can insert non-null default values in place of the null values caused by
// omitting the attributes.
func Defaults(input, defaults cty.Value) (cty.Value, error) {
return DefaultsFunc.Call([]cty.Value{input, defaults})
}
123 changes: 123 additions & 0 deletions lang/funcs/defaults_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package funcs

import (
"fmt"
"testing"

"github.com/zclconf/go-cty/cty"
)

func TestDefaults(t *testing.T) {
tests := []struct {
Input, Defaults cty.Value
Want cty.Value
WantErr string
}{
{
Input: cty.ObjectVal(map[string]cty.Value{
"a": cty.NullVal(cty.String),
}),
Defaults: cty.ObjectVal(map[string]cty.Value{
"a": cty.StringVal("hello"),
}),
Want: cty.ObjectVal(map[string]cty.Value{
"a": cty.StringVal("hello"),
}),
},
{
Input: cty.ObjectVal(map[string]cty.Value{
"a": cty.StringVal("hey"),
}),
Defaults: cty.ObjectVal(map[string]cty.Value{
"a": cty.StringVal("hello"),
}),
Want: cty.ObjectVal(map[string]cty.Value{
"a": cty.StringVal("hey"),
}),
},
{
Input: cty.ObjectVal(map[string]cty.Value{
"a": cty.NullVal(cty.String),
}),
Defaults: cty.ObjectVal(map[string]cty.Value{
"a": cty.NullVal(cty.String),
}),
Want: cty.ObjectVal(map[string]cty.Value{
"a": cty.NullVal(cty.String),
}),
},
{
Input: cty.ObjectVal(map[string]cty.Value{
"a": cty.NullVal(cty.String),
}),
Defaults: cty.ObjectVal(map[string]cty.Value{}),
Want: cty.ObjectVal(map[string]cty.Value{
"a": cty.NullVal(cty.String),
}),
},
{
Input: cty.ObjectVal(map[string]cty.Value{}),
Defaults: cty.ObjectVal(map[string]cty.Value{
"a": cty.NullVal(cty.String),
}),
WantErr: `.a: target type does not expect an attribute named "a"`,
},

{
Input: cty.ObjectVal(map[string]cty.Value{
"a": cty.ListVal([]cty.Value{
cty.NullVal(cty.String),
}),
}),
Defaults: cty.ObjectVal(map[string]cty.Value{
"a": cty.StringVal("hello"),
}),
Want: cty.ObjectVal(map[string]cty.Value{
"a": cty.ListVal([]cty.Value{
cty.StringVal("hello"),
}),
}),
},
{
Input: cty.ObjectVal(map[string]cty.Value{
"a": cty.ListVal([]cty.Value{
cty.NullVal(cty.String),
cty.StringVal("hey"),
cty.NullVal(cty.String),
}),
}),
Defaults: cty.ObjectVal(map[string]cty.Value{
"a": cty.StringVal("hello"),
}),
Want: cty.ObjectVal(map[string]cty.Value{
"a": cty.ListVal([]cty.Value{
cty.StringVal("hello"),
cty.StringVal("hey"),
cty.StringVal("hello"),
}),
}),
},
}

for _, test := range tests {
t.Run(fmt.Sprintf("defaults(%#v, %#v)", test.Input, test.Defaults), func(t *testing.T) {
got, gotErr := Defaults(test.Input, test.Defaults)

if test.WantErr != "" {
if gotErr == nil {
t.Fatalf("unexpected success\nwant error: %s", test.WantErr)
}
if got, want := gotErr.Error(), test.WantErr; got != want {
t.Fatalf("wrong error\ngot: %s\nwant: %s", got, want)
}
return
} else if gotErr != nil {
t.Fatalf("unexpected error\ngot: %s", gotErr.Error())
}

if !test.Want.RawEquals(got) {
t.Errorf("wrong result\ngot: %#v\nwant: %#v", got, test.Want)
}
})
}
}
1 change: 1 addition & 0 deletions lang/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ func (s *Scope) Functions() map[string]function.Function {
"concat": stdlib.ConcatFunc,
"contains": stdlib.ContainsFunc,
"csvdecode": stdlib.CSVDecodeFunc,
"defaults": s.experimentalFunction(experiments.ModuleVariableOptionalAttrs, funcs.DefaultsFunc),
"dirname": funcs.DirnameFunc,
"distinct": stdlib.DistinctFunc,
"element": stdlib.ElementFunc,
Expand Down
Loading

0 comments on commit 5be8a0e

Please sign in to comment.