diff --git a/pkg/format/format.go b/pkg/format/format.go index 557495a2..e1ff0617 100644 --- a/pkg/format/format.go +++ b/pkg/format/format.go @@ -8,9 +8,9 @@ import ( // ParseBool returns true for "1","true","yes" or false for "0","false","no" or defaultValue for any other value func ParseBool(value string, defaultValue bool) (parsedValue bool, ok bool) { switch strings.ToLower(value) { - case "true", "1", "yes": + case "true", "1", "yes", "y": return true, true - case "false", "0", "no": + case "false", "0", "no", "n": return false, true default: return defaultValue, false diff --git a/pkg/format/format_colorize.go b/pkg/format/format_colorize.go index 601b1580..3516d9f2 100644 --- a/pkg/format/format_colorize.go +++ b/pkg/format/format_colorize.go @@ -29,6 +29,9 @@ var ColorizeError = ColorizeFalse // ColorizeContainer colorizes the input string as "Container" var ColorizeContainer = ColorizeDesc +// ColorizeLink colorizes the input string as "Link" +var ColorizeLink = color.New(color.FgHiBlue).SprintFunc() + // ColorizeValue colorizes the input string according to what type appears to be func ColorizeValue(value string, isEnum bool) string { if isEnum { diff --git a/pkg/generators/router.go b/pkg/generators/router.go index 84a68e32..092256b4 100644 --- a/pkg/generators/router.go +++ b/pkg/generators/router.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/containrrr/shoutrrr/pkg/generators/basic" "github.com/containrrr/shoutrrr/pkg/generators/xouath2" + "github.com/containrrr/shoutrrr/pkg/services/telegram" t "github.com/containrrr/shoutrrr/pkg/types" "strings" ) @@ -11,6 +12,7 @@ import ( var generatorMap = map[string]func() t.Generator{ "basic": func() t.Generator { return &basic.Generator{} }, "oauth2": func() t.Generator { return &xouath2.Generator{} }, + "telegram": func() t.Generator { return &telegram.Generator{} }, } // NewGenerator creates an instance of the generator that corresponds to the provided identifier diff --git a/pkg/services/telegram/telegram.go b/pkg/services/telegram/telegram.go index ffab9637..45f0364c 100644 --- a/pkg/services/telegram/telegram.go +++ b/pkg/services/telegram/telegram.go @@ -1,12 +1,8 @@ package telegram import ( - "bytes" - "encoding/json" "errors" - "fmt" "github.com/containrrr/shoutrrr/pkg/format" - "net/http" "net/url" "github.com/containrrr/shoutrrr/pkg/services/standard" @@ -14,7 +10,7 @@ import ( ) const ( - apiBase = "https://api.telegram.org/bot" + apiFormat = "https://api.telegram.org/bot%s/%s" maxlength = 4096 ) @@ -28,7 +24,7 @@ type Service struct { // Send notification to Telegram func (service *Service) Send(message string, params *types.Params) error { if len(message) > maxlength { - return errors.New("message exceeds the max length") + return errors.New("Message exceeds the max length") } config := *service.config @@ -55,8 +51,8 @@ func (service *Service) Initialize(configURL *url.URL, logger types.StdLogger) e } func (service *Service) sendMessageForChatIDs(message string, config *Config) error { - for _, channel := range service.config.Channels { - if err := sendMessageToAPI(message, channel, config); err != nil { + for _, chat := range service.config.Chats { + if err := sendMessageToAPI(message, chat, config); err != nil { return err } } @@ -68,19 +64,9 @@ func (service *Service) GetConfig() *Config { return service.config } -func sendMessageToAPI(message string, channel string, config *Config) error { - postURL := fmt.Sprintf("%s%s/sendMessage", apiBase, config.Token) - - payload := createSendMessagePayload(message, channel, config) - - jsonData, err := json.Marshal(payload) - if err != nil { - return err - } - - res, err := http.Post(postURL, "application/jsonData", bytes.NewBuffer(jsonData)) - if err == nil && res.StatusCode != http.StatusOK { - return fmt.Errorf("failed to send notification to \"%s\", response status code %s", channel, res.Status) - } +func sendMessageToAPI(message string, chat string, config *Config) error { + client := &Client{token: config.Token} + payload := createSendMessagePayload(message, chat, config) + _, err := client.SendMessage(&payload) return err } diff --git a/pkg/services/telegram/telegram_client.go b/pkg/services/telegram/telegram_client.go new file mode 100644 index 00000000..08624089 --- /dev/null +++ b/pkg/services/telegram/telegram_client.go @@ -0,0 +1,69 @@ +package telegram + +import ( + "encoding/json" + "fmt" + "github.com/containrrr/shoutrrr/pkg/util/jsonclient" +) + +// Client for Telegram API +type Client struct { + token string +} + +func (c *Client) apiURL(endpoint string) string { + return fmt.Sprintf(apiFormat, c.token, endpoint) +} + +// GetBotInfo returns the bot User info +func (c *Client) GetBotInfo() (*User, error) { + response := &userResponse{} + err := jsonclient.Get(c.apiURL("getMe"), response) + + if !response.OK { + return nil, GetErrorResponse(jsonclient.ErrorBody(err)) + } + + return &response.Result, nil +} + +// GetUpdates retrieves the latest updates +func (c *Client) GetUpdates(offset int, limit int, timeout int, allowedUpdates []string) ([]Update, error) { + + request := &updatesRequest{ + Offset: offset, + Limit: limit, + Timeout: timeout, + AllowedUpdates: allowedUpdates, + } + response := &updatesResponse{} + err := jsonclient.Post(c.apiURL("getUpdates"), request, response) + + if !response.OK { + return nil, GetErrorResponse(jsonclient.ErrorBody(err)) + } + + return response.Result, nil +} + +// SendMessage sends the specified Message +func (c *Client) SendMessage(message *SendMessagePayload) (*Message, error) { + + response := &messageResponse{} + err := jsonclient.Post(c.apiURL("sendMessage"), message, response) + + if !response.OK { + return nil, GetErrorResponse(jsonclient.ErrorBody(err)) + } + + return response.Result, nil +} + +// GetErrorResponse retrieves the error message from a failed request +func GetErrorResponse(body string) error { + response := &errorResponse{} + if err := json.Unmarshal([]byte(body), response); err == nil { + return response + } + return nil +} diff --git a/pkg/services/telegram/telegram_config.go b/pkg/services/telegram/telegram_config.go index f249a753..002dc8a6 100644 --- a/pkg/services/telegram/telegram_config.go +++ b/pkg/services/telegram/telegram_config.go @@ -13,16 +13,16 @@ import ( type Config struct { Token string `url:"user"` Preview bool `key:"preview" default:"Yes" desc:"If disabled, no web page preview will be displayed for URLs"` - Notification bool `key:"notification" default:"Yes" desc:"If disabled, sends message silently"` - ParseMode parseMode `key:"parsemode" default:"None" desc:"How the text message should be parsed"` - Channels []string `key:"channels"` + Notification bool `key:"notification" default:"Yes" desc:"If disabled, sends Message silently"` + ParseMode parseMode `key:"parsemode" default:"None" desc:"How the text Message should be parsed"` + Chats []string `key:"chats,channels"` Title string `key:"title" default:"" desc:"Notification title, optionally set by the sender"` } // Enums returns the fields that should use a corresponding EnumFormatter to Print/Parse their values func (config *Config) Enums() map[string]types.EnumFormatter { return map[string]types.EnumFormatter{ - "ParseMode": parseModes.Enum, + "ParseMode": ParseModes.Enum, } } @@ -67,7 +67,7 @@ func (config *Config) setURL(resolver types.ConfigQueryResolver, url *url.URL) e } } - if len(config.Channels) < 1 { + if len(config.Chats) < 1 { return errors.New("no channels defined in config URL") } diff --git a/pkg/services/telegram/telegram_generator.go b/pkg/services/telegram/telegram_generator.go new file mode 100644 index 00000000..2b3e945f --- /dev/null +++ b/pkg/services/telegram/telegram_generator.go @@ -0,0 +1,149 @@ +package telegram + +import ( + f "github.com/containrrr/shoutrrr/pkg/format" + "github.com/containrrr/shoutrrr/pkg/types" + "github.com/containrrr/shoutrrr/pkg/util/generator" + "os/signal" + "syscall" + + "fmt" + "os" + "strconv" +) + +// Generator is the telegram-specific URL generator +type Generator struct { + ud *generator.UserDialog + client *Client + chats []string + chatNames []string + chatTypes []string + done bool + owner *User + statusMessage int64 + botName string +} + +// Generate a telegram Shoutrrr configuration from a user dialog +func (g *Generator) Generate(_ types.Service, props map[string]string, _ []string) (types.ServiceConfig, error) { + var config Config + + g.ud = generator.NewUserDialog(os.Stdin, os.Stdout, props) + ud := g.ud + + ud.Writeln("To start we need your bot token. If you haven't created a bot yet, you can use this link:") + ud.Writeln(" %v", f.ColorizeLink("https://t.me/botfather?start")) + ud.Writeln("") + + token := ud.QueryString("Enter your bot token:", generator.ValidateFormat(IsTokenValid), "token") + + ud.Writeln("Fetching bot info...") + // ud.Writeln("Session token: %v", g.sessionToken) + + g.client = &Client{token: token} + botInfo, err := g.client.GetBotInfo() + if err != nil { + return &Config{}, err + } + + g.botName = botInfo.Username + ud.Writeln("") + ud.Writeln("Okay! %v will listen for any messages in PMs and group chats it is invited to.", + f.ColorizeString("@", g.botName, ":")) + + g.done = false + lastUpdate := 0 + + signals := make(chan os.Signal, 1) + + // Subscribe to system signals + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + + for !g.done { + + ud.Writeln("Waiting for messages to arrive...") + + updates, err := g.client.GetUpdates(lastUpdate, 10, 120, nil) + if err != nil { + panic(err) + } + + for _, update := range updates { + lastUpdate = update.UpdateID + 1 + + message := update.Message + if update.ChannelPost != nil { + message = update.ChannelPost + } + + if message != nil { + chat := message.Chat + + source := message.Chat.Username + if message.From != nil { + source = message.From.Username + } + ud.Writeln("Got Message '%v' from @%v in %v chat %v", + f.ColorizeString(message.Text), + f.ColorizeProp(source), + f.ColorizeEnum(chat.Type), + f.ColorizeNumber(chat.ID)) + ud.Writeln(g.addChat(chat)) + } else { + ud.Writeln("Got unknown Update. Ignored!") + } + } + + ud.Writeln("") + + g.done = !ud.QueryBool(fmt.Sprintf("Got %v chat ID(s) so far. Want to add some more?", + f.ColorizeNumber(len(g.chats))), "") + } + + ud.Writeln("") + ud.Writeln("Cleaning up the bot session...") + + // Notify API that we got the updates + if _, err = g.client.GetUpdates(lastUpdate, 0, 0, nil); err != nil { + g.ud.Writeln("Failed to mark last updates as received: %v", f.ColorizeError(err)) + } + + if len(g.chats) < 1 { + return nil, fmt.Errorf("no chats were selected") + } + + ud.Writeln("Selected chats:") + + for i, id := range g.chats { + name := g.chatNames[i] + chatType := g.chatTypes[i] + ud.Writeln(" %v (%v) %v", f.ColorizeNumber(id), f.ColorizeEnum(chatType), f.ColorizeString(name)) + } + + ud.Writeln("") + + config = Config{ + Notification: true, + Token: token, + Chats: g.chats, + } + + return &config, nil +} + +func (g *Generator) addChat(chat *chat) (result string) { + id := strconv.FormatInt(chat.ID, 10) + name := chat.Name() + + for _, c := range g.chats { + if c == id { + return fmt.Sprintf("chat %v is already selected!", f.ColorizeString(name)) + } + } + g.chats = append(g.chats, id) + g.chatNames = append(g.chatNames, name) + g.chatTypes = append(g.chatTypes, chat.Type) + + return fmt.Sprintf("Added new chat %v!", f.ColorizeString(name)) +} diff --git a/pkg/services/telegram/telegram_internal_test.go b/pkg/services/telegram/telegram_internal_test.go index e8864b21..bdbe652d 100644 --- a/pkg/services/telegram/telegram_internal_test.go +++ b/pkg/services/telegram/telegram_internal_test.go @@ -60,11 +60,11 @@ func getPayloadFromURL(testURL string, message string, logger *log.Logger) (Send return SendMessagePayload{}, err } - if len(telegram.config.Channels) < 1 { + if len(telegram.config.Chats) < 1 { return SendMessagePayload{}, errors.New("no channels were supplied") } - return createSendMessagePayload(message, telegram.config.Channels[0], telegram.config), nil + return createSendMessagePayload(message, telegram.config.Chats[0], telegram.config), nil } diff --git a/pkg/services/telegram/telegram_json.go b/pkg/services/telegram/telegram_json.go index c8b18139..a44930ed 100644 --- a/pkg/services/telegram/telegram_json.go +++ b/pkg/services/telegram/telegram_json.go @@ -2,11 +2,28 @@ package telegram // SendMessagePayload is the notification payload for the telegram notification service type SendMessagePayload struct { - Text string `json:"text"` - ID string `json:"chat_id"` - ParseMode string `json:"parse_mode,omitempty"` - DisablePreview bool `json:"disable_web_page_preview"` - DisableNotification bool `json:"disable_notification"` + Text string `json:"text"` + ID string `json:"chat_id"` + ParseMode string `json:"parse_mode,omitempty"` + DisablePreview bool `json:"disable_web_page_preview"` + DisableNotification bool `json:"disable_notification"` + ReplyMarkup *replyMarkup `json:"reply_markup,omitempty"` + Entities []entity `json:"entities,omitempty"` + ReplyTo int64 `json:"reply_to_message_id"` + MessageID int64 `json:"message_id,omitempty"` +} + +// Message represents one chat message +type Message struct { + MessageID int64 `json:"message_id"` + Text string `json:"text"` + From *User `json:"from"` + Chat *chat `json:"chat"` +} + +type messageResponse struct { + OK bool `json:"ok"` + Result *Message `json:"result"` } func createSendMessagePayload(message string, channel string, config *Config) SendMessagePayload { @@ -17,9 +34,141 @@ func createSendMessagePayload(message string, channel string, config *Config) Se DisablePreview: !config.Preview, } - if config.ParseMode != parseModes.None { + if config.ParseMode != ParseModes.None { payload.ParseMode = config.ParseMode.String() } return payload } + +type errorResponse struct { + OK bool `json:"ok"` + ErrorCode int `json:"error_code"` + Description string `json:"description"` +} + +func (e *errorResponse) Error() string { + return e.Description +} + +type userResponse struct { + OK bool `json:"ok"` + Result User `json:"result"` +} + +// User contains information about a telegram user or bot +type User struct { + // Unique identifier for this User or bot + ID int64 `json:"id"` + // True, if this User is a bot + IsBot bool `json:"is_bot"` + // User's or bot's first name + FirstName string `json:"first_name"` + // Optional. User's or bot's last name + LastName string `json:"last_name"` + // Optional. User's or bot's username + Username string `json:"username"` + // Optional. IETF language tag of the User's language + LanguageCode string `json:"language_code"` + // Optional. True, if the bot can be invited to groups. Returned only in getMe. + CanJoinGroups bool `json:"can_join_groups"` + // Optional. True, if privacy mode is disabled for the bot. Returned only in getMe. + CanReadAllGroupMessages bool `json:"can_read_all_group_messages"` + // Optional. True, if the bot supports inline queries. Returned only in getMe. + SupportsInlineQueries bool `json:"supports_inline_queries"` +} + +type updatesRequest struct { + Offset int `json:"offset"` + Limit int `json:"limit"` + Timeout int `json:"timeout"` + AllowedUpdates []string `json:"allowed_updates"` +} + +type updatesResponse struct { + OK bool `json:"ok"` + Result []Update `json:"result"` +} + +type inlineQuery struct { + // Unique identifier for this query + ID string `json:"id"` + // Sender + From User `json:"from"` + // Text of the query (up to 256 characters) + Query string `json:"query"` + // Offset of the results to be returned, can be controlled by the bot + Offset string `json:"offset"` +} + +type chosenInlineResult struct{} + +// Update contains state changes since the previous Update +type Update struct { + // The Update's unique identifier. Update identifiers start from a certain positive number and increase sequentially. This ID becomes especially handy if you're using Webhooks, since it allows you to ignore repeated updates or to restore the correct Update sequence, should they get out of order. If there are no new updates for at least a week, then identifier of the next Update will be chosen randomly instead of sequentially. + UpdateID int `json:"update_id"` + // Optional. New incoming Message of any kind — text, photo, sticker, etc. + Message *Message `json:"Message"` + // Optional. New version of a Message that is known to the bot and was edited + EditedMessage *Message `json:"edited_message"` + // Optional. New incoming channel post of any kind — text, photo, sticker, etc. + ChannelPost *Message `json:"channel_post"` + // Optional. New version of a channel post that is known to the bot and was edited + EditedChannelPost *Message `json:"edited_channel_post"` + // Optional. New incoming inline query + InlineQuery *inlineQuery `json:"inline_query"` + //// Optional. The result of an inline query that was chosen by a User and sent to their chat partner. Please see our documentation on the feedback collecting for details on how to enable these updates for your bot. + ChosenInlineResult *chosenInlineResult `json:"chosen_inline_result"` + //// Optional. New incoming callback query + CallbackQuery *callbackQuery `json:"callback_query"` + //// Optional. New incoming shipping query. Only for invoices with flexible price + //ShippingQuery ShippingQuery `json:"shipping_query"` + //// Optional. New incoming pre-checkout query. Contains full information about checkout + //PreCheckoutQuery PreCheckoutQuery `json:"pre_checkout_query"` + /* + // Optional. New poll state. Bots receive only updates about stopped polls and polls, which are sent by the bot + Poll Poll `json:"poll"` + // Optional. A User changed their answer in a non-anonymous poll. Bots receive new votes only in polls that were sent by the bot itself. + Poll_answer PollAnswer `json:"poll_answer"` + */ +} + +type chat struct { + ID int64 `json:"id"` + Type string `json:"type"` + Title string `json:"title"` + Username string `json:"username"` +} + +func (c *chat) Name() string { + if c.Type == "private" || c.Type == "channel" { + return "@" + c.Username + } + return c.Title +} + +type inlineKey struct { + Text string `json:"text"` + URL string `json:"url"` + LoginURL string `json:"login_url"` + CallbackData string `json:"callback_data"` + SwitchInlineQuery string `json:"switch_inline_query"` + SwitchInlineQueryCurrent string `json:"switch_inline_query_current_chat"` +} + +type replyMarkup struct { + InlineKeyboard [][]inlineKey `json:"inline_keyboard,omitempty"` +} + +type entity struct { + Type string `json:"type"` + Offset int `json:"offset"` + Length int `json:"length"` +} + +type callbackQuery struct { + ID string `json:"id"` + From *User `json:"from"` + Message *Message `json:"Message"` + Data string `json:"data"` +} diff --git a/pkg/services/telegram/telegram_parsemode.go b/pkg/services/telegram/telegram_parsemode.go index fa85dadc..eb69c717 100644 --- a/pkg/services/telegram/telegram_parsemode.go +++ b/pkg/services/telegram/telegram_parsemode.go @@ -15,7 +15,8 @@ type parseModeVals struct { Enum types.EnumFormatter } -var parseModes = &parseModeVals{ +// ParseModes is an enum helper for parseMode +var ParseModes = &parseModeVals{ None: 0, Markdown: 1, HTML: 2, @@ -30,5 +31,5 @@ var parseModes = &parseModeVals{ } func (pm parseMode) String() string { - return parseModes.Enum.Print(int(pm)) + return ParseModes.Enum.Print(int(pm)) } diff --git a/pkg/services/telegram/telegram_test.go b/pkg/services/telegram/telegram_test.go index 74eb4f99..bff1fd8e 100644 --- a/pkg/services/telegram/telegram_test.go +++ b/pkg/services/telegram/telegram_test.go @@ -43,16 +43,16 @@ var _ = Describe("the telegram service", func() { serviceURL, _ := url.Parse(envTelegramURL) err := telegram.Initialize(serviceURL, logger) Expect(err).NotTo(HaveOccurred()) - err = telegram.Send("This is an integration test message", nil) + err = telegram.Send("This is an integration test Message", nil) Expect(err).NotTo(HaveOccurred()) }) - When("given a message that exceeds the max length", func() { + When("given a Message that exceeds the max length", func() { It("should generate an error", func() { if envTelegramURL == "" { return } hundredChars := "this string is exactly (to the letter) a hundred characters long which will make the send func error" - serviceURL, _ := url.Parse("telegram://12345:mock-token/channel-1") + serviceURL, _ := url.Parse("telegram://12345:mock-token/?chats=channel-1") builder := strings.Builder{} for i := 0; i < 42; i++ { builder.WriteString(hundredChars) @@ -69,8 +69,8 @@ var _ = Describe("the telegram service", func() { return } It("should generate a 401", func() { - serviceURL, _ := url.Parse("telegram://000000000:AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA@telegram/?channels=channel-id") - message := "this is a perfectly valid message" + serviceURL, _ := url.Parse("telegram://000000000:AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA@telegram/?chats=channel-id") + message := "this is a perfectly valid Message" err := telegram.Initialize(serviceURL, logger) Expect(err).NotTo(HaveOccurred()) @@ -99,7 +99,7 @@ var _ = Describe("the telegram service", func() { var err error BeforeEach(func() { - serviceURL, _ := url.Parse("telegram://12345:mock-token@telegram/?channels=channel-1,channel-2,channel-3") + serviceURL, _ := url.Parse("telegram://12345:mock-token@telegram/?chats=channel-1,channel-2,channel-3") err = telegram.Initialize(serviceURL, logger) config = telegram.GetConfig() }) @@ -113,9 +113,9 @@ var _ = Describe("the telegram service", func() { Expect(err).NotTo(HaveOccurred()) Expect(config.Token).To(Equal("12345:mock-token")) }) - It("should add every subsequent argument as a channel id", func() { + It("should add every chats query field as a chat ID", func() { Expect(err).NotTo(HaveOccurred()) - Expect(config.Channels).To(Equal([]string{ + Expect(config.Chats).To(Equal([]string{ "channel-1", "channel-2", "channel-3", @@ -134,7 +134,7 @@ var _ = Describe("the telegram service", func() { httpmock.DeactivateAndReset() }) It("should not report an error if the server accepts the payload", func() { - serviceURL, _ := url.Parse("telegram://12345:mock-token@telegram/?channels=channel-1,channel-2,channel-3") + serviceURL, _ := url.Parse("telegram://12345:mock-token@telegram/?chats=channel-1,channel-2,channel-3") err = telegram.Initialize(serviceURL, logger) Expect(err).NotTo(HaveOccurred()) @@ -148,10 +148,10 @@ var _ = Describe("the telegram service", func() { It("should implement basic service API methods correctly", func() { testutils.TestConfigGetInvalidQueryValue(&Config{}) - testutils.TestConfigSetInvalidQueryValue(&Config{}, "telegram://12345:mock-token@telegram/?channels=channel-1&foo=bar") + testutils.TestConfigSetInvalidQueryValue(&Config{}, "telegram://12345:mock-token@telegram/?chats=channel-1&foo=bar") testutils.TestConfigGetEnumsCount(&Config{}, 1) - testutils.TestConfigGetFieldsCount(&Config{}, 5) + testutils.TestConfigGetFieldsCount(&Config{}, 6) }) }) @@ -162,7 +162,7 @@ func expectErrorAndEmptyObject(telegram *Service, rawURL string, logger *log.Log config := telegram.GetConfig() fmt.Printf("Token: \"%+v\" \"%s\" \n", config.Token, config.Token) Expect(config.Token).To(BeEmpty()) - Expect(len(config.Channels)).To(BeZero()) + Expect(len(config.Chats)).To(BeZero()) } func setupResponder(endpoint string, token string, code int, body string) { diff --git a/pkg/util/generator/generator_common.go b/pkg/util/generator/generator_common.go new file mode 100644 index 00000000..9bc84434 --- /dev/null +++ b/pkg/util/generator/generator_common.go @@ -0,0 +1,183 @@ +package generator + +import ( + "bufio" + "errors" + "fmt" + f "github.com/containrrr/shoutrrr/pkg/format" + "github.com/fatih/color" + "io" + re "regexp" + "strconv" +) + +var errInvalidFormat = errors.New("invalid format") + +// ValidateFormat is a validation wrapper turning false bool results into errors +func ValidateFormat(validator func(string) bool) func(string) error { + return func(answer string) error { + if validator(answer) { + return nil + } + return errInvalidFormat + } +} + +var errRequired = errors.New("field is required") + +// Required is a validator that checks whether the input contains any characters +func Required(answer string) error { + if answer == "" { + return errRequired + } + return nil +} + +// UserDialog is an abstraction for question/answer based user interaction +type UserDialog struct { + reader io.Reader + writer io.Writer + scanner *bufio.Scanner + props map[string]string +} + +// NewUserDialog initializes a UserDialog with safe defaults +func NewUserDialog(reader io.Reader, writer io.Writer, props map[string]string) *UserDialog { + if props == nil { + props = map[string]string{} + } + return &UserDialog{ + reader: reader, + writer: writer, + scanner: bufio.NewScanner(reader), + props: props, + } +} + +// Write message to user +func (ud *UserDialog) Write(message string, v ...interface{}) { + if _, err := fmt.Fprintf(ud.writer, message, v...); err != nil { + fmt.Printf("failed to write to output: %v", err) + } +} + +// Writeln writes a message to the user that completes a line +func (ud *UserDialog) Writeln(format string, v ...interface{}) { + ud.Write(format+"\n", v...) +} + +// Query writes the prompt to the user and returns the regex groups if it matches the validator pattern +func (ud *UserDialog) Query(prompt string, validator *re.Regexp, key string) (groups []string) { + ud.QueryString(prompt, ValidateFormat(func(answer string) bool { + groups = validator.FindStringSubmatch(answer) + return groups != nil + }), key) + + return groups +} + +// QueryAll is a version of Query that can return multiple matches +func (ud *UserDialog) QueryAll(prompt string, validator *re.Regexp, key string, maxMatches int) (matches [][]string) { + ud.QueryString(prompt, ValidateFormat(func(answer string) bool { + matches = validator.FindAllStringSubmatch(answer, maxMatches) + return matches != nil + }), key) + + return matches +} + +// QueryString writes the prompt to the user and returns the answer if it passes the validator function +func (ud *UserDialog) QueryString(prompt string, validator func(string) error, key string) string { + + if validator == nil { + validator = func(string) error { + return nil + } + } + + answer, foundProp := ud.props[key] + if foundProp { + err := validator(answer) + colAnswer := f.ColorizeValue(answer, false) + colKey := f.ColorizeProp(key) + if err == nil { + ud.Writeln("Using prop value %v for %v", colAnswer, colKey) + return answer + } + ud.Writeln("Supplied prop value %v is not valid for %v: %v", colAnswer, colKey, err) + } + + for { + ud.Write("%v ", prompt) + color.Set(color.FgHiWhite) + if !ud.scanner.Scan() { + if err := ud.scanner.Err(); err != nil { + ud.Writeln(err.Error()) + continue + } + + // Input closed, so let's just return an empty string + return "" + } + answer = ud.scanner.Text() + color.Unset() + + if err := validator(answer); err != nil { + ud.Writeln("%v", err) + continue + } + return answer + } +} + +// QueryStringPattern is a version of QueryString taking a regular expression pattern as the validator +func (ud *UserDialog) QueryStringPattern(prompt string, validator *re.Regexp, key string) (answer string) { + + if validator == nil { + panic("validator cannot be nil") + } + + return ud.QueryString(prompt, func(s string) error { + if validator.MatchString(s) { + return nil + } + return errInvalidFormat + }, key) +} + +// QueryInt writes the prompt to the user and returns the answer if it can be parsed as an integer +func (ud *UserDialog) QueryInt(prompt string, key string, bitSize int) (value int64) { + validator := re.MustCompile(`^((0x|#)([0-9a-fA-F]+))|(-?[0-9]+)$`) + ud.QueryString(prompt, func(answer string) error { + groups := validator.FindStringSubmatch(answer) + if len(groups) < 1 { + return errors.New("not a number") + } + number := groups[0] + base := 0 + if groups[2] == "#" { + // Explicitly treat #ffa080 as hexadecimal + base = 16 + number = groups[3] + } + + var err error + value, err = strconv.ParseInt(number, base, bitSize) + + return err + }, key) + return value +} + +// QueryBool writes the prompt to the user and returns the answer if it can be parsed as a boolean +func (ud *UserDialog) QueryBool(prompt string, key string) (value bool) { + ud.QueryString(prompt, func(answer string) error { + parsed, ok := f.ParseBool(answer, false) + if ok { + value = parsed + return nil + } + return fmt.Errorf("answer using %v or %v", f.ColorizeTrue("yes"), f.ColorizeFalse("no")) + }, key) + return value +} diff --git a/pkg/util/generator/generator_test.go b/pkg/util/generator/generator_test.go new file mode 100644 index 00000000..645fb665 --- /dev/null +++ b/pkg/util/generator/generator_test.go @@ -0,0 +1,169 @@ +package generator_test + +import ( + "fmt" + "github.com/containrrr/shoutrrr/pkg/util/generator" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" + re "regexp" + "strings" + "testing" +) + +func TestGenerator(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Generator Suite") +} + +var ( + client *generator.UserDialog + userOut *gbytes.Buffer + userIn *gbytes.Buffer +) + +func mockTyped(a ...interface{}) { + _, _ = fmt.Fprint(userOut, a...) + _, _ = fmt.Fprint(userOut, "\n") +} + +func dumpBuffers() { + for _, line := range strings.Split(string(userIn.Contents()), "\n") { + println(">", line) + } + for _, line := range strings.Split(string(userOut.Contents()), "\n") { + println("<", line) + } +} + +var _ = Describe("GeneratorCommon", func() { + Describe("attach to the data stream", func() { + + BeforeEach(func() { + userOut = gbytes.NewBuffer() + userIn = gbytes.NewBuffer() + client = generator.NewUserDialog(userOut, userIn, map[string]string{"propKey": "propVal"}) + }) + + It("reprompt upon invalid answers", func() { + defer dumpBuffers() + answer := make(chan string) + go func() { + answer <- client.QueryString("name:", generator.Required, "") + }() + + mockTyped("") + mockTyped("Normal Human Name") + + Eventually(userIn).Should(gbytes.Say(`name: `)) + + Eventually(userIn).Should(gbytes.Say(`field is required`)) + Eventually(userIn).Should(gbytes.Say(`name: `)) + Eventually(answer).Should(Receive(Equal("Normal Human Name"))) + }) + + It("should accept any input when validator is nil", func() { + defer dumpBuffers() + answer := make(chan string) + go func() { + answer <- client.QueryString("name:", nil, "") + }() + mockTyped("") + Eventually(answer).Should(Receive(BeEmpty())) + }) + + It("should use predefined prop value if key is present", func() { + defer dumpBuffers() + answer := make(chan string) + go func() { + answer <- client.QueryString("name:", generator.Required, "propKey") + }() + Eventually(answer).Should(Receive(Equal("propVal"))) + }) + + It("Query", func() { + defer dumpBuffers() + answer := make(chan []string) + query := "pick foo or bar:" + go func() { + answer <- client.Query(query, re.MustCompile("(foo|bar)"), "") + }() + + mockTyped("") + mockTyped("foo") + + Eventually(userIn).Should(gbytes.Say(query)) + Eventually(userIn).Should(gbytes.Say(`invalid format`)) + Eventually(userIn).Should(gbytes.Say(query)) + Eventually(answer).Should(Receive(ContainElement("foo"))) + }) + + It("QueryAll", func() { + defer dumpBuffers() + answer := make(chan [][]string) + query := "pick foo or bar:" + go func() { + answer <- client.QueryAll(query, re.MustCompile(`foo(ba[rz])`), "", -1) + }() + + mockTyped("foobar foobaz") + + Eventually(userIn).Should(gbytes.Say(query)) + var matches [][]string + Eventually(answer).Should(Receive(&matches)) + Expect(matches).To(ContainElement([]string{"foobar", "bar"})) + Expect(matches).To(ContainElement([]string{"foobaz", "baz"})) + }) + + It("QueryStringPattern", func() { + defer dumpBuffers() + answer := make(chan string) + query := "type of bar:" + go func() { + answer <- client.QueryStringPattern(query, re.MustCompile(".*bar"), "") + }() + + mockTyped("foo") + mockTyped("foobar") + + Eventually(userIn).Should(gbytes.Say(query)) + Eventually(userIn).Should(gbytes.Say(`invalid format`)) + Eventually(userIn).Should(gbytes.Say(query)) + Eventually(answer).Should(Receive(Equal("foobar"))) + }) + + It("QueryInt", func() { + defer dumpBuffers() + answer := make(chan int64) + query := "number:" + go func() { + answer <- client.QueryInt(query, "", 64) + }() + + mockTyped("x") + mockTyped("0x20") + + Eventually(userIn).Should(gbytes.Say(query)) + Eventually(userIn).Should(gbytes.Say(`not a number`)) + Eventually(userIn).Should(gbytes.Say(query)) + Eventually(answer).Should(Receive(Equal(int64(32)))) + }) + + It("QueryBool", func() { + defer dumpBuffers() + answer := make(chan bool) + query := "cool?" + go func() { + answer <- client.QueryBool(query, "") + }() + + mockTyped("maybe") + mockTyped("y") + + Eventually(userIn).Should(gbytes.Say(query)) + Eventually(userIn).Should(gbytes.Say(`answer using yes or no`)) + Eventually(userIn).Should(gbytes.Say(query)) + Eventually(answer).Should(Receive(BeTrue())) + }) + }) +}) diff --git a/pkg/util/jsonclient/error.go b/pkg/util/jsonclient/error.go new file mode 100644 index 00000000..7317f8ae --- /dev/null +++ b/pkg/util/jsonclient/error.go @@ -0,0 +1,29 @@ +package jsonclient + +import "fmt" + +// Error contains additional http/JSON details +type Error struct { + StatusCode int + Body string + err error +} + +func (je Error) Error() string { + return je.String() +} + +func (je Error) String() string { + if je.err == nil { + return fmt.Sprintf("unknown error (HTTP %v)", je.StatusCode) + } + return je.err.Error() +} + +// ErrorBody returns the request body from an Error +func ErrorBody(e error) string { + if jsonError, ok := e.(Error); ok { + return jsonError.Body + } + return "" +} diff --git a/pkg/util/jsonclient/jsonclient.go b/pkg/util/jsonclient/jsonclient.go new file mode 100644 index 00000000..a6396d0f --- /dev/null +++ b/pkg/util/jsonclient/jsonclient.go @@ -0,0 +1,86 @@ +package jsonclient + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" +) + +// ContentType is the default mime type for JSON +const ContentType = "application/json" + +// DefaultClient is the singleton instance of jsonclient using http.DefaultClient +var DefaultClient = &Client{HTTPClient: http.DefaultClient} + +// Get fetches url using GET and unmarshals into the passed response using DefaultClient +func Get(url string, response interface{}) error { + return DefaultClient.Get(url, response) +} + +// Post sends request as JSON and unmarshals the response JSON into the supplied struct using DefaultClient +func Post(url string, request interface{}, response interface{}) error { + return DefaultClient.Post(url, request, response) +} + +// Client is a JSON wrapper around http.Client +type Client struct { + HTTPClient *http.Client +} + +// Get fetches url using GET and unmarshals into the passed response +func (c *Client) Get(url string, response interface{}) error { + res, err := c.HTTPClient.Get(url) + if err != nil { + return err + } + + return parseResponse(res, response) +} + +// Post sends request as JSON and unmarshals the response JSON into the supplied struct +func (c *Client) Post(url string, request interface{}, response interface{}) error { + + var err error + var body []byte + + body, err = json.Marshal(request) + if err != nil { + return fmt.Errorf("error creating payload: %v", err) + } + + var res *http.Response + res, err = c.HTTPClient.Post(url, ContentType, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("error sending payload: %v", err) + } + + return parseResponse(res, response) +} + +func parseResponse(res *http.Response, response interface{}) error { + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + + if res.StatusCode >= 400 { + err = fmt.Errorf("got HTTP %v", res.Status) + } + + if err == nil { + err = json.Unmarshal(body, response) + } + + if err != nil { + if body == nil { + body = []byte{} + } + return Error{ + StatusCode: res.StatusCode, + Body: string(body), + err: err, + } + } + + return nil +} diff --git a/pkg/util/jsonclient/jsonclient_test.go b/pkg/util/jsonclient/jsonclient_test.go new file mode 100644 index 00000000..23dbf8d9 --- /dev/null +++ b/pkg/util/jsonclient/jsonclient_test.go @@ -0,0 +1,138 @@ +package jsonclient_test + +import ( + "errors" + "github.com/containrrr/shoutrrr/pkg/util/jsonclient" + "github.com/onsi/gomega/ghttp" + "net/http" + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestJSONClient(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "JSONClient Suite") +} + +var _ = Describe("JSONClient", func() { + var server *ghttp.Server + + BeforeEach(func() { + server = ghttp.NewServer() + }) + + When("the server returns an invalid JSON response", func() { + It("should return an error", func() { + server.AppendHandlers(ghttp.RespondWith(http.StatusOK, "invalid json")) + res := &mockResponse{} + err := jsonclient.Get(server.URL(), &res) + Expect(server.ReceivedRequests()).Should(HaveLen(1)) + Expect(err).To(MatchError("invalid character 'i' looking for beginning of value")) + Expect(res.Status).To(BeEmpty()) + }) + }) + + When("the server returns an empty response", func() { + It("should return an error", func() { + server.AppendHandlers(ghttp.RespondWith(http.StatusOK, nil)) + res := &mockResponse{} + err := jsonclient.Get(server.URL(), &res) + Expect(server.ReceivedRequests()).Should(HaveLen(1)) + Expect(err).To(MatchError("unexpected end of JSON input")) + Expect(res.Status).To(BeEmpty()) + }) + }) + + It("should deserialize GET response", func() { + server.AppendHandlers(ghttp.RespondWithJSONEncoded(http.StatusOK, mockResponse{Status: "OK"})) + res := &mockResponse{} + err := jsonclient.Get(server.URL(), &res) + Expect(server.ReceivedRequests()).Should(HaveLen(1)) + Expect(err).ToNot(HaveOccurred()) + Expect(res.Status).To(Equal("OK")) + }) + + Describe("POST", func() { + It("should de-/serialize request and response", func() { + + req := &mockRequest{Number: 5} + res := &mockResponse{} + + server.AppendHandlers(ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", "/"), + ghttp.VerifyJSONRepresenting(&req), + ghttp.RespondWithJSONEncoded(http.StatusOK, &mockResponse{Status: "That's Numberwang!"})), + ) + + err := jsonclient.Post(server.URL(), &req, &res) + Expect(server.ReceivedRequests()).Should(HaveLen(1)) + Expect(err).ToNot(HaveOccurred()) + Expect(res.Status).To(Equal("That's Numberwang!")) + }) + + It("should return error on error status responses", func() { + server.AppendHandlers(ghttp.RespondWith(404, "Not found!")) + err := jsonclient.Post(server.URL(), &mockRequest{}, &mockResponse{}) + Expect(server.ReceivedRequests()).Should(HaveLen(1)) + Expect(err).To(MatchError("got HTTP 404 Not Found")) + }) + + It("should return error on invalid request", func() { + server.AppendHandlers(ghttp.VerifyRequest("POST", "/")) + err := jsonclient.Post(server.URL(), func() {}, &mockResponse{}) + Expect(server.ReceivedRequests()).Should(HaveLen(0)) + Expect(err).To(MatchError("error creating payload: json: unsupported type: func()")) + }) + + It("should return error on invalid response type", func() { + res := &mockResponse{Status: "cool skirt"} + server.AppendHandlers(ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", "/"), + ghttp.RespondWithJSONEncoded(http.StatusOK, res)), + ) + + err := jsonclient.Post(server.URL(), nil, &[]bool{}) + Expect(server.ReceivedRequests()).Should(HaveLen(1)) + Expect(err).To(MatchError("json: cannot unmarshal object into Go value of type []bool")) + Expect(jsonclient.ErrorBody(err)).To(MatchJSON(`{"Status":"cool skirt"}`)) + }) + }) + + AfterEach(func() { + //shut down the server between tests + server.Close() + }) +}) + +var _ = Describe("Error", func() { + When("no internal error has been set", func() { + It("should return a generic message with status code", func() { + errorWithNoError := jsonclient.Error{StatusCode: http.StatusEarlyHints} + Expect(errorWithNoError.String()).To(Equal("unknown error (HTTP 103)")) + }) + }) + Describe("ErrorBody", func() { + When("passed a non-json error", func() { + It("should return an empty string", func() { + Expect(jsonclient.ErrorBody(errors.New("unrelated error"))).To(BeEmpty()) + }) + }) + When("passed a jsonclient.Error", func() { + It("should return the request body from that error", func() { + errorBody := `{"error": "bad user"}` + jsonError := jsonclient.Error{Body: errorBody} + Expect(jsonclient.ErrorBody(jsonError)).To(MatchJSON(errorBody)) + }) + }) + }) +}) + +type mockResponse struct { + Status string +} + +type mockRequest struct { + Number int +}