Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify TTL/MaxTTL logic in SSH CA paths #3507

Merged
merged 1 commit into from
Oct 30, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 22 additions & 45 deletions builtin/logical/ssh/path_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func pathRoles(b *backend) *framework.Path {
`,
},
"ttl": &framework.FieldSchema{
Type: framework.TypeString,
Type: framework.TypeDurationSecond,
Description: `
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
The lease duration if no specific lease duration is
Expand All @@ -184,7 +184,7 @@ func pathRoles(b *backend) *framework.Path {
the value of max_ttl.`,
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeString,
Type: framework.TypeDurationSecond,
Description: `
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
The maximum allowed lease duration
Expand Down Expand Up @@ -433,9 +433,9 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
}

func (b *backend) createCARole(allowedUsers, defaultUser string, data *framework.FieldData) (*sshRole, *logical.Response) {
ttl := time.Duration(data.Get("ttl").(int)) * time.Second
maxTTL := time.Duration(data.Get("max_ttl").(int)) * time.Second
role := &sshRole{
MaxTTL: data.Get("max_ttl").(string),
TTL: data.Get("ttl").(string),
AllowedCriticalOptions: data.Get("allowed_critical_options").(string),
AllowedExtensions: data.Get("allowed_extensions").(string),
AllowUserCertificates: data.Get("allow_user_certificates").(bool),
Expand All @@ -457,44 +457,12 @@ func (b *backend) createCARole(allowedUsers, defaultUser string, data *framework
defaultCriticalOptions := convertMapToStringValue(data.Get("default_critical_options").(map[string]interface{}))
defaultExtensions := convertMapToStringValue(data.Get("default_extensions").(map[string]interface{}))

var maxTTL time.Duration
maxSystemTTL := b.System().MaxLeaseTTL()
if len(role.MaxTTL) == 0 {
maxTTL = maxSystemTTL
} else {
var err error
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf(
"Invalid max ttl: %s", err))
}
}
if maxTTL > maxSystemTTL {
return nil, logical.ErrorResponse("Requested max TTL is higher than backend maximum")
if ttl != 0 && maxTTL != 0 && ttl > maxTTL {
return nil, logical.ErrorResponse(
`"ttl" value must be less than "max_ttl" when both are specified`)
}

ttl := b.System().DefaultLeaseTTL()
if len(role.TTL) != 0 {
var err error
ttl, err = parseutil.ParseDurationSecond(role.TTL)
if err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf(
"Invalid ttl: %s", err))
}
}
if ttl > maxTTL {
// If they are using the system default, cap it to the role max;
// if it was specified on the command line, make it an error
if len(role.TTL) == 0 {
ttl = maxTTL
} else {
return nil, logical.ErrorResponse(
`"ttl" value must be less than "max_ttl" and/or backend default max lease TTL value`,
)
}
}

// Persist clamped TTLs
// Persist TTLs
role.TTL = ttl.String()
role.MaxTTL = maxTTL.String()
role.DefaultCriticalOptions = defaultCriticalOptions
Expand Down Expand Up @@ -551,13 +519,22 @@ func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*l
},
}, nil
} else if role.KeyType == KeyTypeCA {
ttl, err := parseutil.ParseDurationSecond(role.TTL)
if err != nil {
return nil, err
}
maxTTL, err := parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return nil, err
}

return &logical.Response{
Data: map[string]interface{}{
"allowed_users": role.AllowedUsers,
"allowed_domains": role.AllowedDomains,
"default_user": role.DefaultUser,
"max_ttl": role.MaxTTL,
"ttl": role.TTL,
"allowed_users": role.AllowedUsers,
"allowed_domains": role.AllowedDomains,
"default_user": role.DefaultUser,
"ttl": int64(ttl.Seconds()),
"max_ttl": int64(maxTTL.Seconds()),
"allowed_critical_options": role.AllowedCriticalOptions,
"allowed_extensions": role.AllowedExtensions,
"allow_user_certificates": role.AllowUserCertificates,
Expand Down
38 changes: 16 additions & 22 deletions builtin/logical/ssh/path_sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func pathSign(b *backend) *framework.Path {
Description: `The desired role with configuration for this request.`,
},
"ttl": &framework.FieldSchema{
Type: framework.TypeString,
Type: framework.TypeDurationSecond,
Description: `The requested Time To Live for the SSH certificate;
sets the expiration date. If not specified
the role default, backend default, or system
Expand Down Expand Up @@ -345,40 +345,34 @@ func (b *backend) calculateExtensions(data *framework.FieldData, role *sshRole)
}

func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) {

var ttl, maxTTL time.Duration
var ttlField string
ttlFieldInt, ok := data.GetOk("ttl")
if !ok {
ttlField = role.TTL
} else {
ttlField = ttlFieldInt.(string)
}
var err error

if len(ttlField) == 0 {
ttl = b.System().DefaultLeaseTTL()
ttlRaw, specifiedTTL := data.GetOk("ttl")
if specifiedTTL {
ttl = time.Duration(ttlRaw.(int)) * time.Second
} else {
var err error
ttl, err = parseutil.ParseDurationSecond(ttlField)
ttl, err = parseutil.ParseDurationSecond(role.TTL)
if err != nil {
return 0, fmt.Errorf("invalid requested ttl: %s", err)
return 0, err
}
}
if ttl == 0 {
ttl = b.System().DefaultLeaseTTL()
}

if len(role.MaxTTL) == 0 {
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return 0, err
}
if maxTTL == 0 {
maxTTL = b.System().MaxLeaseTTL()
} else {
var err error
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return 0, fmt.Errorf("invalid requested max ttl: %s", err)
}
}

if ttl > maxTTL {
// Don't error if they were using system defaults, only error if
// they specifically chose a bad TTL
if len(ttlField) == 0 {
if !specifiedTTL {
ttl = maxTTL
} else {
return 0, fmt.Errorf("ttl is larger than maximum allowed (%d)", maxTTL/time.Second)
Expand Down