Skip to content

Commit

Permalink
Merge pull request #108 from projectdiscovery/issue-103-add-callback-var
Browse files Browse the repository at this point in the history
Add Callback Vars to flagset (#103)
  • Loading branch information
Mzack9999 authored Feb 28, 2023
2 parents 24e2560 + 2fc46cd commit 783b6c4
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 3 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ The following types are supported by the goflags library. The `<name>P` suffix m
| VarP | Custom value with long short name implementing flag.Value interface |
| EnumVar | Enum value with long name |
| EnumVarP | Enum value with long short name |
| CallbackVar | Callback function as value with long name |
| CallbackVarP | Callback function as value with long short name |


### String Slice Options

Expand Down Expand Up @@ -95,6 +98,14 @@ func main() {
flagSet.BoolVar(&opt.silent, "silent", true, "show silent output")
flagSet.StringSliceVarP(&opt.inputs, "inputs", "i", nil, "list of inputs (file,comma-separated)", goflags.FileCommaSeparatedStringSliceOptions)

update := func(tool string ) func() {
return func() {
fmt.Printf("%v updated successfully!", tool)
}
}
flagSet.CallbackVarP(update("tool_1"), "update", "up", "update tool_1")


// Group example
flagSet.CreateGroup("config", "Configuration",
flagSet.StringVar(&opt.config, "config", "", "file to read config from"),
Expand Down
64 changes: 64 additions & 0 deletions callback_var.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package goflags

import (
"fmt"
"strconv"
)

// CallBackFunc
type CallBackFunc func()

// callBackVar
type callBackVar struct {
Value CallBackFunc
}

// Set
func (c *callBackVar) Set(s string) error {
v, err := strconv.ParseBool(s)
if err != nil {
return fmt.Errorf("failed to parse callback flag")
}
if v {
// if flag found execute callback
c.Value()
}
return nil
}

// IsBoolFlag
func (c *callBackVar) IsBoolFlag() bool {
return true
}

// String
func (c *callBackVar) String() string {
return "false"
}

// CallbackVar adds a Callback flag with a longname
func (flagSet *FlagSet) CallbackVar(callback CallBackFunc, long string, usage string) *FlagData {
return flagSet.CallbackVarP(callback, long, "", usage)
}

// CallbackVarP adds a Callback flag with a shortname and longname
func (flagSet *FlagSet) CallbackVarP(callback CallBackFunc, long, short string, usage string) *FlagData {
if callback == nil {
panic(fmt.Errorf("callback cannot be nil for flag -%v", long))
}
flagData := &FlagData{
usage: usage,
long: long,
defaultValue: strconv.FormatBool(false),
field: &callBackVar{Value: callback},
skipMarshal: true,
}
if short != "" {
flagData.short = short
flagSet.CommandLine.Var(flagData.field, short, usage)
flagSet.flagKeys.Set(short, flagData)
}
flagSet.CommandLine.Var(flagData.field, long, usage)
flagSet.flagKeys.Set(long, flagData)
return flagData
}
64 changes: 64 additions & 0 deletions callback_var_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package goflags

import (
"bytes"
"fmt"
"io"
"os"
"os/exec"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSuccessfulCallback(t *testing.T) {
toolName := "tool_1"
want := `updated successfully!`
got := &bytes.Buffer{}

flagSet := NewFlagSet()
flagSet.CreateGroup("Update", "Update",
flagSet.CallbackVar(updateCallbackFunc(toolName, got), "update", fmt.Sprintf("update %v to the latest released version", toolName)),
flagSet.CallbackVarP(func() {}, "disable-update-check", "duc", "disable automatic update check"),
)
os.Args = []string{
os.Args[0],
"-update",
}
err := flagSet.Parse()
assert.Nil(t, err)
assert.Equal(t, want, got.String())
tearDown(t.Name())
}

func TestFailCallback(t *testing.T) {
toolName := "tool_1"
got := &bytes.Buffer{}

if os.Getenv("IS_SUB_PROCESS") == "1" {
flagSet := NewFlagSet()
flagSet.CommandLine.SetOutput(got)
flagSet.CreateGroup("Update", "Update",
flagSet.CallbackVar(nil, "update", fmt.Sprintf("update %v to the latest released version", toolName)),
)
os.Args = []string{
os.Args[0],
"-update",
}
_ = flagSet.Parse()
}
cmd := exec.Command(os.Args[0], "-test.run=TestFailCallback")
cmd.Env = append(os.Environ(), "IS_SUB_PROCESS=1")
err := cmd.Run()
if e, ok := err.(*exec.ExitError); ok && !e.Success() {
return
}
t.Fatalf("process ran with err %v, want exit error", err)
tearDown(t.Name())
}

func updateCallbackFunc(toolName string, cliOutput io.Writer) func() {
return func() {
fmt.Fprintf(cliOutput, "updated successfully!")
}
}
6 changes: 3 additions & 3 deletions enum_var_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ const (
Type2
)

func TestEnumVarPositive(t *testing.T) {
func TestSuccessfulEnumVar(t *testing.T) {
flagSet := NewFlagSet()
flagSet.EnumVar(&enumString, "enum", Type1, "enum", AllowdTypes{"type1": Type1, "type2": Type2})
os.Args = []string{
Expand All @@ -29,7 +29,7 @@ func TestEnumVarPositive(t *testing.T) {
tearDown(t.Name())
}

func TestEnumVarNegative(t *testing.T) {
func TestFailEnumVar(t *testing.T) {
if os.Getenv("IS_SUB_PROCESS") == "1" {
flagSet := NewFlagSet()

Expand All @@ -41,7 +41,7 @@ func TestEnumVarNegative(t *testing.T) {
_ = flagSet.Parse()
return
}
cmd := exec.Command(os.Args[0], "-test.run=TestEnumVarNegative")
cmd := exec.Command(os.Args[0], "-test.run=TestFailEnumVar")
cmd.Env = append(os.Environ(), "IS_SUB_PROCESS=1")
err := cmd.Run()
if e, ok := err.(*exec.ExitError); ok && !e.Success() {
Expand Down
7 changes: 7 additions & 0 deletions examples/basic/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"fmt"
"log"

"github.com/projectdiscovery/goflags"
Expand All @@ -15,6 +16,11 @@ type Options struct {

func main() {
testOptions := &Options{}
CheckUpdate := func() {
fmt.Println("checking if new version is available")
fmt.Println("updating tool....")
}

flagSet := goflags.NewFlagSet()
flagSet.CreateGroup("info", "Info",
flagSet.StringVarP(&testOptions.name, "name", "n", "", "name of the user"),
Expand All @@ -23,6 +29,7 @@ func main() {
flagSet.CreateGroup("additional", "Additional",
flagSet.StringVarP(&testOptions.Phone, "phone", "ph", "", "phone of the user"),
flagSet.StringSliceVarP(&testOptions.Address, "address", "add", nil, "address of the user", goflags.StringSliceOptions),
flagSet.CallbackVarP(CheckUpdate, "update", "ut", "update this tool to latest version"),
)

if err := flagSet.Parse(); err != nil {
Expand Down
7 changes: 7 additions & 0 deletions goflags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ func TestUsageOrder(t *testing.T) {
"two": EnumVariable(2),
}).Group("Enum")

flagSet.SetGroup("Update", "Update")
flagSet.CallbackVar(func() {}, "update", "update tool_1 to the latest released version").Group("Update")
flagSet.CallbackVarP(func() {}, "disable-update-check", "duc", "disable automatic update check").Group("Update")

output := &bytes.Buffer{}
flagSet.CommandLine.SetOutput(output)

Expand Down Expand Up @@ -155,6 +159,9 @@ BOOLEAN:
-bwdv, -bool-with-default-value2 Bool with default value example #2 (default true)
ENUM:
-en, -enum-with-default-value value Enum with default value(zero/one/two) (default zero)
UPDATE:
-update update tool_1 to the latest released version
-duc, -disable-update-check disable automatic update check
`
assert.Equal(t, expected, actual)

Expand Down

0 comments on commit 783b6c4

Please sign in to comment.