diff --git a/.gitignore b/.gitignore index fe1d3d5..7c740bc 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,7 @@ # vendor/ # CUSTOM -*.env \ No newline at end of file +*.env +*.json +*.crt +*.key diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..c7b72bf --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,20 @@ +# contributing to go-schwab/trader + +contributing to this project is very easy! development happens on the [dev](https://github.com/go-schwab/trader/tree/dev) branch. the goal is for the main branch to remain essentially unchanged (barring library-breaking behavior) until the next semantic release, to keep things as stable as possible. + +--- + +BEFORE DOING ANY OF THIS, YOU MUST TEST YOUR CHANGES BY RUNNING GO TEST. + +IF YOUR CODE DOESNT PASS OUR CI TESTS, YOUR PR WILL NOT BE REVIEWED KINDLY :) + +0. create a fork +1. commit your changes +2. create a pull request to the dev branch, preferably with the following description: + +``` +major | minor: +ref: [issue] #_ | [pr] v_._._ +desc: +- ... +``` diff --git a/README.md b/README.md index 70abcdd..3112581 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,17 @@ # go wrapper for schwab's trader-api + [![Go Reference](https://pkg.go.dev/badge/github.com/samjtro/schwab.svg)](https://pkg.go.dev/github.com/samjtro/schwab) [![Go Report Card](https://goreportcard.com/badge/github.com/samjtro/schwab)](https://goreportcard.com/report/github.com/samjtro/schwab) [![License](https://img.shields.io/badge/License-GPLv2-green)](LICENSE) -built by [@samjtro](https://github.com/samjtro) +built by: [@samjtro](https://github.com/samjtro) + +see: [CONTRIBUTING.md](https://github.com/go-schwab/trader/blob/main/CONTRIBUTING.md) -if you want to contribute - go for it! there is no contribution guide, just a simple golden rule: if it ain't broke, don't fix it: -**all** contributions should be tested via `go test` before submission. +--- why should you use this project? + - lightning fast - return structs are easily indexable - easy to setup, easy to use (personal preference, i know - but trust me!) @@ -17,14 +20,23 @@ why should you use this project? ### 0.0 quick start -0. go to https://developer.schwab.com, create an account, create an app, get app credentials from https://developer.schwab.com/dashboard/apps -1. create `config.env` in your project directory, formatted as such: +0. go to , create an account, create an app, get app credentials from +1. create any file with the `.env` extension in your project directory (can also have multiple, if necessary), formatted as such: + ``` APPKEY=KEY0 // App Key SECRET=KEY1 // App Secret CBURL=https://127.0.0.1 // App Callback URL ``` -2. `go get github.com/go-schwab/trader@v0.5.1` + +2. run the following command in your cwd to generate ssl certs for secure tls transmission of your bearer token: + +``` +openssl req -x509 -out localhost.crt -keyout localhost.key -newkey rsa:2048 -nodes -sha256 -subj '/CN=localhost' -extensions EXT -config <( \ +printf "[dn]\nCN=localhost\n[req]\ndistinguished_name = dn\n[EXT]\nsubjectAltName=DNS.1:localhost,IP:127.0.0.1\nkeyUsage=digitalSignature\nextendedKeyUsage=serverAuth") +``` + +3. `go get github.com/go-schwab/trader@v0.9.0` ### 0.1 agent @@ -44,7 +56,7 @@ agent := trader.Initiate() code samples: -``` +```go df, err := agent.GetPriceHistory("AAPL", "month", "1", "daily", "1", "", "") check(err) ``` @@ -62,7 +74,7 @@ return: code samples: -``` +```go quote, err := agent.GetQuote("AAPL") check(err) ``` @@ -76,14 +88,13 @@ return: {EQUITY COE NBBO true 1973757747 AAPL 199.62 164.075 EDGX 195.75 5 1717687921950 XNAS 195.74 4 1717687920970 195.87 196.5 XADF 195.745 100 195.21 195.745 -0.125 -0.06381784 -0.125 -0.06381784 0 0 0 1717687921950 Normal 14237020 1717687921574} ``` - #### 1.3.0 instruments ##### 1.3.1 simple code samples: -``` +```go simple, err := agent.SearchInstrumentSimple("AAPL") check(err) ``` @@ -101,7 +112,7 @@ return: code samples: -``` +```go fundamental, err := agent.SearchInstrumentFundamental("AAPL") check(err) ``` @@ -119,7 +130,7 @@ return: code samples: -``` +```go movers, err := agent.GetMovers("$DJI", "up", "percent") check(err) ``` @@ -139,7 +150,7 @@ return: to submit any trades in this library, one must use your encrypted account id. this as accessed by using the `agent.GetAccountNumbers()` function, which is then passed to the submission function. this is because there are use cases where you might want to change between multiple accounts while trading the same session. -``` +```go an, err := agent.GetAccountNumbers() check(err) ``` @@ -150,15 +161,15 @@ the rest of the docs assume you want to use the first element of the `[]AccountN suppose we wanted to submit a single-leg market order for the symbol "AAPL". this is as easy as: -``` +```go err = agent.SubmitSingleLegOrder(an[0].HashValue, CreateSingleLegOrder(OrderType("MARKET"), Session("NORMAL"), Duration("DAY"), Strategy("SINGLE"), Instruction("BUY"), Quantity(1.0), Instrument(SimpleOrderInstrument{ - Symbol: "AAPL", - AssetType: "EQUITY", + Symbol: "AAPL", + AssetType: "EQUITY", }))) check(err) ``` -let's break this down, although it's fairly straight forward. `CreateSingleLegOrder` returns a `SingleLegOrder`, passed to `agent.SubmitSingleLegOrder` after the hash value of your encrypted id. `CreateSingleOrder` accepts an unknown amount of parameters setting the various required elements for the order: +let's break this down, although it's fairly straight forward. `CreateSingleLegOrder` returns a `SingleLegOrder`, passed to `agent.SubmitSingleLegOrder` after the hash value of your encrypted id. `CreateSingleOrder` accepts an unknown amount of parameters setting the various elements for the order: ``` OrderType: @@ -170,12 +181,32 @@ Quantity: Instrument: ``` +the default behavior of CreateSingleLegOrder() assumes you are submitting an order with the following parameters: + +``` +OrderType: MARKET +Session: NORMAL +Duration: DAY +Strategy: SINGLE +``` + +meaning only `INSTRUCTION`, `QUANTITY` & `INSTRUMENT` are the only required directives. the above example can be simplified thusly: + +```go +err = agent.SubmitSingleLegOrder(an[0].HashValue, CreateSingleLegOrder(Instruction("BUY"), Quantity(1.0), Instrument(SimpleOrderInstrument{ + Symbol: "AAPL", + AssetType: "EQUITY", +}))) +check(err) +``` + ## WIP: DO NOT CROSS, DANGER DANGER + #### 2.2.0 accessing account data ##### 2.2.1.0 -``` +```go an, err := agent.GetAccountNumbers() check(err) fmt.Println(an) diff --git a/accounts-trading.go b/accounts-trading.go index 4f2fdd6..c9b142b 100644 --- a/accounts-trading.go +++ b/accounts-trading.go @@ -16,21 +16,21 @@ var ( endpointAccountNumbers string = accountEndpoint + "/accounts/accountNumbers" endpointAccounts string = accountEndpoint + "/accounts" endpointAccount string = accountEndpoint + "/accounts/%s" - //endpointUserPreference string = accountEndpoint + "/userPreference" + // endpointUserPreference string = accountEndpoint + "/userPreference" // Orders endpointOrders string = accountEndpoint + "/orders" endpointAccountOrders string = accountEndpoint + "/accounts/%s/orders" endpointAccountOrder string = accountEndpoint + "/accounts/%s/orders/%s" - //endpointPreviewOrder string = accountEndpoint + "/accounts/%s/previewOrder" + // endpointPreviewOrder string = accountEndpoint + "/accounts/%s/previewOrder" // Transactions - //endpointTransactions string = accountEndpoint + "/accounts/%s/transactions" + // endpointTransactions string = accountEndpoint + "/accounts/%s/transactions" endpointTransaction string = accountEndpoint + "/accounts/%s/transactions/%s" ) type Transaction struct { - ActivityID int `json:"ActivityId"` + ActivityId int Time string User User Description string @@ -315,8 +315,10 @@ type SimpleOrderInstrument struct { AssetType string // EQUITY } -type SingleLegOrderComposition func(order *SingleLegOrder) -type MultiLegSimpleOrderComposition func(order *MultiLegOrder) +type ( + SingleLegOrderComposition func(order *SingleLegOrder) + MultiLegSimpleOrderComposition func(order *MultiLegOrder) +) // Create a new Market order func CreateSingleLegOrder(opts ...SingleLegOrderComposition) *SingleLegOrder { @@ -376,52 +378,92 @@ func Instrument(instrument SimpleOrderInstrument) SingleLegOrderComposition { } } -var SingleLegOrderTemplate = ` +var OrderTemplate = ` { "orderType": "%s", "session": "%s", "duration": "%s", "orderStrategyType": "%s", "orderLegCollection": [ - { - "instruction": "%s", - "quantity": %f, - "instrument": { - "symbol": "%s", - "assetType": "%s" - } - } + %s ] } ` +var LegTemplate = ` +{ + "instruction": "%s", + "quantity": %f, + "instrument": { + "symbol": "%s", + "assetType": "%s" + } +}, +` + +var LegTemplateLast = ` +{ + "instruction": "%s", + "quantity": %f, + "instrument": { + "symbol": "%s", + "assetType": "%s" + } +}, +` + func marshalSingleLegOrder(order *SingleLegOrder) string { - return fmt.Sprintf(SingleLegOrderTemplate, order.OrderType, order.Session, order.Duration, order.Strategy, order.Instruction, order.Quantity, order.Instrument.Symbol, order.Instrument.AssetType) + return fmt.Sprintf(OrderTemplate, order.OrderType, order.Session, order.Duration, order.Strategy, fmt.Sprintf(LegTemplate, order.Instruction, order.Quantity, order.Instrument.Symbol, order.Instrument.AssetType)) +} + +func marshalMultiLegOrder(order *MultiLegOrder) string { + var legs string + // UNTESTED + for i, leg := range order.OrderLegCollection { + if i != len(order.OrderLegCollection)-1 { + legs += fmt.Sprintf(LegTemplate, leg.Instruction, leg.Quantity, leg.Instrument.Symbol, leg.Instrument.AssetType) + } else { + legs += fmt.Sprintf(LegTemplateLast, leg.Instruction, leg.Quantity, leg.Instrument.Symbol, leg.Instrument.AssetType) + } + } + return fmt.Sprintf(OrderTemplate) } // Submit a single-leg order for the specified encrypted account ID func (agent *Agent) SubmitSingleLegOrder(hashValue string, order *SingleLegOrder) error { orderJson := marshalSingleLegOrder(order) req, err := http.NewRequest("POST", fmt.Sprintf(endpointAccountOrders, hashValue), strings.NewReader(orderJson)) - check(err) + if err != nil { + return err + } req.Header.Set("Content-Type", "application/json") _, err = agent.Handler(req) - check(err) + if err != nil { + return err + } return nil } // Get a specific order by account number & order ID func (agent *Agent) GetOrder(accountNumber, orderID string) (FullOrder, error) { req, err := http.NewRequest("GET", fmt.Sprintf(endpointAccountOrder, accountNumber, orderID), nil) - check(err) + if err != nil { + return FullOrder{}, err + } resp, err := agent.Handler(req) - check(err) + if err != nil { + return FullOrder{}, err + } var order FullOrder defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + if err != nil { + return FullOrder{}, err + } err = sonic.Unmarshal(body, &order) - check(err) + if err != nil { + return FullOrder{}, err + } return order, nil } @@ -429,19 +471,27 @@ func (agent *Agent) GetOrder(accountNumber, orderID string) (FullOrder, error) { // yyyy-MM-ddTHH:mm:ss.SSSZ func (agent *Agent) GetAccountOrders(accountNumber, fromEnteredTime, toEnteredTime string) ([]FullOrder, error) { req, err := http.NewRequest("GET", fmt.Sprintf(endpointAccountOrders, accountNumber), nil) - check(err) + if err != nil { + return []FullOrder{}, err + } q := req.URL.Query() q.Add("fromEnteredTime", fromEnteredTime) q.Add("toEnteredTime", toEnteredTime) req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + if err != nil { + return []FullOrder{}, err + } var orders []FullOrder defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + if err != nil { + return []FullOrder{}, err + } err = sonic.Unmarshal(body, &orders) - check(err) + if err != nil { + return []FullOrder{}, err + } return orders, nil } @@ -450,19 +500,26 @@ func (agent *Agent) GetAccountOrders(accountNumber, fromEnteredTime, toEnteredTi // yyyy-MM-ddTHH:mm:ss.SSSZ func (agent *Agent) GetAllOrders(fromEnteredTime, toEnteredTime string) ([]FullOrder, error) { req, err := http.NewRequest("GET", endpointOrders, nil) - check(err) + if err != nil { + return []FullOrder{}, err + } q := req.URL.Query() q.Add("fromEnteredTime", fromEnteredTime) q.Add("toEnteredTime", toEnteredTime) req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + if err != nil { + return []FullOrder{}, err + } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + if err != nil { + return []FullOrder{}, err + } var orders []FullOrder - /*err = sonic.Unmarshal(body, &orders) - check(err)*/ + /* TODO: + err = sonic.Unmarshal(body, &orders) + isErrNil(err)*/ fmt.Println(body) return orders, nil } @@ -470,62 +527,94 @@ func (agent *Agent) GetAllOrders(fromEnteredTime, toEnteredTime string) ([]FullO // Get encrypted account numbers for trading func (agent *Agent) GetAccountNumbers() ([]AccountNumbers, error) { req, err := http.NewRequest("GET", endpointAccountNumbers, nil) - check(err) + if err != nil { + return []AccountNumbers{}, err + } resp, err := agent.Handler(req) - check(err) + if err != nil { + return []AccountNumbers{}, err + } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + if err != nil { + return []AccountNumbers{}, err + } var accountNumbers []AccountNumbers err = sonic.Unmarshal(body, &accountNumbers) - check(err) + if err != nil { + return []AccountNumbers{}, err + } return accountNumbers, nil } // Get all accounts associated with the user logged in func (agent *Agent) GetAccounts() ([]Account, error) { req, err := http.NewRequest("GET", endpointAccounts, nil) - check(err) + if err != nil { + return []Account{}, err + } resp, err := agent.Handler(req) - check(err) + if err != nil { + return []Account{}, err + } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + if err != nil { + return []Account{}, err + } var accounts []Account err = sonic.Unmarshal(body, &accounts) - check(err) + if err != nil { + return []Account{}, err + } return accounts, nil } // Get account by encrypted account id func (agent *Agent) GetAccount(id string) (Account, error) { req, err := http.NewRequest("GET", fmt.Sprintf(endpointAccount, id), nil) - check(err) + if err != nil { + return Account{}, err + } resp, err := agent.Handler(req) - check(err) + if err != nil { + return Account{}, err + } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + if err != nil { + return Account{}, err + } var account Account err = sonic.Unmarshal(body, &account) - check(err) + if err != nil { + return Account{}, err + } return account, nil } // Get all transactions for the user logged in -//func (agent *Agent) GetTransactions() ([]Transaction, error) {} +// func (agent *Agent) GetTransactions() ([]Transaction, error) {} // Get a transaction for a specific account id func (agent *Agent) GetTransaction(accountNumber, transactionId string) (Transaction, error) { req, err := http.NewRequest("GET", fmt.Sprintf(endpointTransaction, accountNumber, transactionId), nil) - check(err) + if err != nil { + return Transaction{}, err + } resp, err := agent.Handler(req) - check(err) + if err != nil { + return Transaction{}, err + } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + if err != nil { + return Transaction{}, err + } var transaction Transaction err = sonic.Unmarshal(body, &transaction) - check(err) + if err != nil { + return Transaction{}, err + } return transaction, nil } diff --git a/go.mod b/go.mod index a4bb5aa..e015679 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,24 @@ module github.com/go-schwab/trader -go 1.22.4 +go 1.23.2 -require github.com/joho/godotenv v1.5.1 +require ( + github.com/bytedance/sonic v1.12.2 + github.com/go-schwab/utils/oauth v0.0.0-20241103230919-01b09562dfc2 + github.com/joho/godotenv v1.5.1 + golang.org/x/oauth2 v0.23.0 +) require ( - github.com/bytedance/sonic v1.12.2 // indirect github.com/bytedance/sonic/loader v0.2.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/fatih/color v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/stretchr/testify v1.9.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect + golang.org/x/sys v0.25.0 // indirect ) diff --git a/go.sum b/go.sum index 1f1d15a..01ea035 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,5 @@ -github.com/bytedance/sonic v1.11.8 h1:Zw/j1KfiS+OYTi9lyB3bb0CFxPJVkM17k1wyDG32LRA= -github.com/bytedance/sonic v1.11.8/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic v1.12.2 h1:oaMFuRTpMHYLpCntGca65YWt5ny+wAceDERTkT2L9lg= github.com/bytedance/sonic v1.12.2/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= -github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM= github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= @@ -13,11 +10,22 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/go-schwab/utils/oauth v0.0.0-20241103230919-01b09562dfc2 h1:C7kojSdgemr586G27iYSol5L8TUHTR9TWxIwxcBHiag= +github.com/go-schwab/utils/oauth v0.0.0-20241103230919-01b09562dfc2/go.mod h1:8TuoJsd0EHSHnhDFNIQL9MTZKxYt5s0Xv45mGcUnE9U= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -26,15 +34,21 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= +golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= -rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/market-data.go b/market-data.go index c04fd33..aaf28db 100644 --- a/market-data.go +++ b/market-data.go @@ -30,164 +30,164 @@ var ( ) type Candle struct { - Time int `json:"datetime"` - Volume int `json:"volume"` - Open float64 `json:"open"` - Close float64 `json:"close"` + Time int `json:"datetime"` + Volume int + Open float64 + Close float64 Hi float64 `json:"high"` Lo float64 `json:"low"` } type Quote struct { - AssetMainType string `json:"assetMainType"` - AssetSubType string `json:"assetSubType"` - QuoteType string `json:"quoteType"` - RealTime bool `json:"realTime"` - SSID int `json:"ssid"` - Symbol string `json:"symbol"` + AssetMainType string + AssetSubType string + QuoteType string + RealTime bool + SSID int `json:"ssid"` + Symbol string Hi52 float64 `json:"52WeekHigh"` Lo52 float64 `json:"52WeekLow"` - AskMICID string `json:"askMICId"` - Ask float64 `json:"askPrice"` - AskSize int `json:"askSize"` - AskTime int `json:"askTime"` - BidMICID string `json:"bidMICId"` - Bid float64 `json:"bidPrice"` - BidSize int `json:"bidSize"` - BidTime int `json:"bidTime"` + AskMICId string + AskPrice float64 + AskSize int + AskTime int + BidMICId string + BidPrice float64 + BidSize int + BidTime int Close float64 `json:"closePrice"` - Hi float64 `json:"highPrice"` - LastMICID string `json:"lastMICId"` - LastPrice float64 `json:"lastPrice"` - LastSize int `json:"lastSize"` - Lo float64 `json:"lowPrice"` + HiPrice float64 `json:"highPrice"` + LastMICId string + LastPrice float64 + LastSize int + LoPrice float64 `json:"lowPrice"` Mark float64 `json:"mark"` - MarkChange float64 `json:"markChange"` - MarkPercentChange float64 `json:"markPercentChange"` - NetChange float64 `json:"netChange"` - NetPercentChange float64 `json:"netPercentChange"` - Open float64 `json:"open"` - PostMarketChange float64 `json:"postMarketChange"` - PostMarketPercentChange float64 `json:"postMarketPercentChange"` - QuoteTime int `json:"quoteTime"` - SecurityStatus string `json:"securityStatus"` - TotalVolume int `json:"totalVolume"` - TradeTime int `json:"tradeTime"` + MarkChange float64 + MarkPercentChange float64 + NetChange float64 + NetPercentChange float64 + Open float64 + PostMarketChange float64 + PostMarketPercentChange float64 + QuoteTime int + SecurityStatus string + TotalVolume int + TradeTime int } type SimpleInstrument struct { - Cusip string `json:"cusip"` - Symbol string `json:"symbol"` - Description string `json:"description"` - Exchange string `json:"exchange"` - AssetType string `json:"assetType"` + Cusip string + Symbol string + Description string + Exchange string + AssetType string } // Change this to reflect ordering of schwab return type FundamentalInstrument struct { - Symbol string `json:"symbol"` - Cusip string `json:"cusip"` - Description string `json:"description"` - Exchange string `json:"exchange"` - Type string `json:"assetType"` + Symbol string + Cusip string + Description string + Exchange string + AssetType string Hi52 float64 `json:"high52"` Lo52 float64 `json:"low52"` - DividendYield float64 `json:"dividendYield"` - DividendAmount float64 `json:"dividendAmount"` - DividendDate string `json:"dividendDate"` + DividendYield float64 + DividendAmount float64 + DividendDate string PE float64 `json:"peRatio"` PEG float64 `json:"pegRatio"` PB float64 `json:"pbRatio"` PR float64 `json:"prRatio"` PCF float64 `json:"pcfRatio"` - GrossMarginTTM float64 `json:"grossMarginTTM"` - NetProfitMarginTTM float64 `json:"netMarginTTM"` - OperatingMarginTTM float64 `json:"operatingMarginTTM"` - GrossMarginMRQ float64 `json:"grossMarginMRQ"` - NetProfitMarginMRQ float64 `json:"netMarginMRQ"` - OperatingMarginMRQ float64 `json:"operatingMarginMRQ"` + GrossMarginTTM float64 + NetMarginTTM float64 + OperatingMarginTTM float64 + GrossMarginMRQ float64 + NetProfitMarginMRQ float64 + OperatingMarginMRQ float64 ROE float64 `json:"returnOnEquity"` ROA float64 `json:"returnOnAssets"` ROI float64 `json:"returnOnInvestment"` - QuickRatio float64 `json:"quickRatio"` - CurrentRatio float64 `json:"currentRatio"` - InterestCoverage float64 `json:"interestCoverage"` - TotalDebtToCapital float64 `json:"totalDebtToCapital"` - LTDebtToEquity float64 `json:"ltDebtToEquity"` - TotalDebtToEquity float64 `json:"totalDebtToEquity"` - EPSTTM float64 `json:"epsTTM"` - EPSChangePercentTTM float64 `json:"epsChangePercentTTM"` - EPSChangeYear float64 `json:"epsChangeYear"` - EPSChange float64 `json:"epsChange"` - RevenueChangeYear float64 `json:"revChangeYear"` - RevenueChangeTTM float64 `json:"revChangeTTM"` - RevenueChangeIn float64 `json:"revChangeIn"` - SharesOutstanding float64 `json:"sharesOutstanding"` - MarketCapFloat float64 `json:"marketCapFloat"` - MarketCap float64 `json:"marketCap"` - BookValuePerShare float64 `json:"bookValuePerShare"` - ShortIntToFloat float64 `json:"shortIntToFloat"` - ShortIntDayToCover float64 `json:"shortIntDayToCover"` - DividendGrowthRate3Year float64 `json:"dividendGrowthRate3Year"` - DividendPayAmount float64 `json:"dividendPayAmount"` - DividendPayDate string `json:"dividendPayDate"` - Beta float64 `json:"beta"` - Vol1DayAverage float64 `json:"vol1DayAvg"` - Vol10DayAverage float64 `json:"vol10DayAvg"` - Vol3MonthAverage float64 `json:"vol3MonthAvg"` - Avg1DayVolume int `json:"avg1DayVolume"` - Avg10DaysVolume int `json:"avg10DaysVolume"` - Avg3MonthVolume int `json:"avg3MonthVolume"` - DeclarationDate string `json:"declarationDate"` - DividendFrequency int `json:"dividendFreq"` - EPS float64 `json:"eps"` - DTNVolume int `json:"dtnVolume"` - NextDividendPayDate string `json:"nextDividendPayDate"` - NextDividendDate string `json:"nextDividendDate"` - FundLeverageFactor float64 `json:"fundLeverageFactor"` + QuickRatio float64 + CurrentRatio float64 + InterestCoverage float64 + TotalDebtToCapital float64 + LtDebtToEquity float64 + TotalDebtToEquity float64 + EpsTTM float64 + EpsChangePercentTTM float64 + EpsChangeYear float64 + EpsChange float64 + RevChangeYear float64 + RevChangeTTM float64 + RevChangeIn float64 + SharesOutstanding float64 + MarketCapFloat float64 + MarketCap float64 + BookValuePerShare float64 + ShortIntToFloat float64 + ShortIntDayToCover float64 + DividendGrowthRate3Year float64 + DividendPayAmount float64 + DividendPayDate string + Beta float64 + Vol1DayAvg float64 + Vol10DayAvg float64 + Vol3MonthAvg float64 + Avg1DayVolume int + Avg10DaysVolume int + Avg3MonthVolume int + DeclarationDate string + DividendFreq int + Eps float64 + DtnVolume int + NextDividendPayDate string + NextDividendDate string + FundLeverageFactor float64 } type Screener struct { - Symbol string `json:"symbol"` - Description string `json:"description"` - Volume int `json:"volume"` - LastPrice float64 `json:"lastPrice"` - NetChange float64 `json:"netChange"` - MarketShare float64 `json:"marketShare"` - TotalVolume int `json:"totalVolume"` - Trades int `json:"trades"` - NetPercentChange float64 `json:"netPercentChange"` + Symbol string + Description string + Volume int + LastPrice float64 + NetChange float64 + MarketShare float64 + TotalVolume int + Trades int + NetPercentChange float64 } // WIP: type Underlying struct{} // WIP: type Contract struct { - TYPE string `json:""` - SYMBOL string `json:""` - STRIKE float64 `json:""` - EXCHANGE string `json:""` - EXPIRATION float64 `json:""` - DAYS2EXPIRATION float64 `json:""` - BID float64 `json:""` - ASK float64 `json:""` - LAST float64 `json:""` - MARK float64 `json:""` - BIDASK_SIZE string `json:""` - VOLATILITY float64 `json:""` - DELTA float64 `json:""` - GAMMA float64 `json:""` - THETA float64 `json:""` - VEGA float64 `json:""` - RHO float64 `json:""` - OPEN_INTEREST float64 `json:""` - TIME_VALUE float64 `json:""` - THEORETICAL_VALUE float64 `json:""` - THEORETICAL_VOLATILITY float64 `json:""` - PERCENT_CHANGE float64 `json:""` - MARK_CHANGE float64 `json:""` - MARK_PERCENT_CHANGE float64 `json:""` - INTRINSIC_VALUE float64 `json:""` - IN_THE_MONEY bool `json:""` + TYPE string + SYMBOL string + STRIKE float64 + EXCHANGE string + EXPIRATION float64 + DAYS2EXPIRATION float64 + BID float64 + ASK float64 + LAST float64 + MARK float64 + BIDASK_SIZE string + VOLATILITY float64 + DELTA float64 + GAMMA float64 + THETA float64 + VEGA float64 + RHO float64 + OPEN_INTEREST float64 + TIME_VALUE float64 + THEORETICAL_VALUE float64 + THEORETICAL_VOLATILITY float64 + PERCENT_CHANGE float64 + MARK_CHANGE float64 + MARK_PERCENT_CHANGE float64 + INTRINSIC_VALUE float64 + IN_THE_MONEY bool } // WIP: func GetQuotes(symbols string) Quote, error) {} @@ -197,19 +197,27 @@ type Contract struct { // ticker = "AAPL", etc. func (agent *Agent) GetQuote(symbol string) (Quote, error) { req, err := http.NewRequest("GET", endpointQuotes, nil) - check(err) + if err != nil { + return Quote{}, err + } q := req.URL.Query() q.Add("symbols", symbol) q.Add("fields", "quote") req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + if err != nil { + return Quote{}, err + } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + if err != nil { + return Quote{}, err + } var quote Quote err = sonic.Unmarshal([]byte(strings.Join(strings.Split(strings.Split(string(body), fmt.Sprintf("\"%s\":", symbol))[1], "\"quote\":{"), "")[:len(strings.Join(strings.Split(strings.Split(string(body), fmt.Sprintf("\"%s\":", symbol))[1], "\"quote\":{"), ""))-2]), "e) - check(err) + if err != nil { + return Quote{}, err + } return quote, err } @@ -217,19 +225,27 @@ func (agent *Agent) GetQuote(symbol string) (Quote, error) { // It takes one param: func (agent *Agent) SearchInstrumentSimple(symbols string) (SimpleInstrument, error) { req, err := http.NewRequest("GET", endpointSearchInstrument, nil) - check(err) + if err != nil { + return SimpleInstrument{}, err + } q := req.URL.Query() q.Add("symbol", symbols) q.Add("projection", "symbol-search") req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + if err != nil { + return SimpleInstrument{}, err + } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + if err != nil { + return SimpleInstrument{}, err + } var instrument SimpleInstrument err = sonic.Unmarshal([]byte(strings.Split(string(body), "[")[1][:len(strings.Split(string(body), "[")[1])-2]), &instrument) - check(err) + if err != nil { + return SimpleInstrument{}, err + } return instrument, nil } @@ -237,21 +253,29 @@ func (agent *Agent) SearchInstrumentSimple(symbols string) (SimpleInstrument, er // It takes one param: func (agent *Agent) SearchInstrumentFundamental(symbol string) (FundamentalInstrument, error) { req, err := http.NewRequest("GET", endpointSearchInstrument, nil) - check(err) + if err != nil { + return FundamentalInstrument{}, err + } q := req.URL.Query() q.Add("symbol", symbol) q.Add("projection", "fundamental") req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + if err != nil { + return FundamentalInstrument{}, err + } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + if err != nil { + return FundamentalInstrument{}, err + } var instrument FundamentalInstrument split0 := strings.Split(string(body), "[{\"fundamental\":")[1] split := strings.Split(split0, "}") err = sonic.Unmarshal([]byte(fmt.Sprintf("%s}", strings.Join(split[:2], ""))), &instrument) - check(err) + if err != nil { + return FundamentalInstrument{}, err + } return instrument, nil } @@ -274,7 +298,9 @@ func (agent *Agent) SearchInstrumentFundamental(symbol string) (FundamentalInstr // endDate = func (agent *Agent) GetPriceHistory(symbol, periodType, period, frequencyType, frequency, startDate, endDate string) ([]Candle, error) { req, err := http.NewRequest("GET", endpointPriceHistory, nil) - check(err) + if err != nil { + return []Candle{}, err + } q := req.URL.Query() q.Add("symbol", symbol) q.Add("periodType", periodType) @@ -285,13 +311,19 @@ func (agent *Agent) GetPriceHistory(symbol, periodType, period, frequencyType, f q.Add("endDate", endDate) req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + if err != nil { + return []Candle{}, err + } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + if err != nil { + return []Candle{}, err + } var candles []Candle err = sonic.Unmarshal([]byte(fmt.Sprintf("[%s]", strings.Split(strings.Split(string(body), "[")[1], "]")[0])), &candles) - check(err) + if err != nil { + return []Candle{}, err + } return candles, nil } @@ -302,23 +334,57 @@ func (agent *Agent) GetPriceHistory(symbol, periodType, period, frequencyType, f // change = "percent" or "value" func (agent *Agent) GetMovers(index, direction, change string) ([]Screener, error) { req, err := http.NewRequest("GET", fmt.Sprintf(endpointMovers, index), nil) - check(err) + if err != nil { + return []Screener{}, err + } q := req.URL.Query() q.Add("direction", direction) q.Add("change", change) req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + if err != nil { + return []Screener{}, err + } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + if err != nil { + return []Screener{}, err + } var movers []Screener stringToParse := fmt.Sprintf("[%s]", strings.Split(string(body), "[")[1][:len(strings.Split(string(body), "[")[1])-2]) err = sonic.Unmarshal([]byte(stringToParse), &movers) - check(err) + if err != nil { + return []Screener{}, err + } return movers, nil } +// get all option chains for a ticker +func (agent *Agent) GetChains(symbol string) ([]Contract, error) { + req, err := http.NewRequest("GET", endpointOptions, nil) + if err != nil { + return []Contract{}, err + } + q := req.URL.Query() + q.Add("symbol", symbol) + req.URL.RawQuery = q.Encode() + resp, err := agent.Handler(req) + if err != nil { + return []Contract{}, err + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return []Contract{}, err + } + var chain []Contract + err = sonic.Unmarshal(body, &chain) + if err != nil { + return []Contract{}, err + } + return chain, nil +} + // Single returns a []CONTRACT; containing a SINGLE option chain of your desired strike, type, etc., // it takes four parameters: // ticker = "AAPL", etc. @@ -335,25 +401,25 @@ func (agent *Agent) GetMovers(index, direction, change string) ([]Screener, erro // toDate = Only return expirations before this date. Valid ISO-8601 formats are: yyyy-MM-dd and yyyy-MM-dd'T'HH:mm:ssz. // Lets examine a sample call of Single: Single("AAPL","CALL","ALL","5","2022-07-01"). // This returns 5 AAPL CALL contracts both above and below the at the money price, with no preference as to the status of the contract ("ALL"), expiring before 2022-07-01 -func (agent *Agent) Single(ticker, contractType, strikeRange, strikeCount, toDate string) ([]Contract, error) { +func (agent *Agent) Single(symbol, contractType, strikeRange, strikeCount, toDate string) ([]Contract, error) { req, err := http.NewRequest("GET", endpointOptions, nil) - check(err) + isErrNil(err) q := req.URL.Query() - q.Add("symbol", ticker) + q.Add("symbol", symbol) q.Add("contractType", contractType) q.Add("range", strikeRange) q.Add("strikeCount", strikeCount) q.Add("toDate", toDate) req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + isErrNil(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) var chain []Contract // WIP err = sonic.Unmarshal(body, &chain) - check(err) + isErrNil(err) return chain, nil } diff --git a/utils.go b/utils.go index 5e587bd..63bf0e8 100644 --- a/utils.go +++ b/utils.go @@ -1,251 +1,131 @@ package trader import ( - "bytes" - "encoding/base64" - "encoding/json" + "context" + "crypto/tls" "errors" - "fmt" "io" + "io/fs" "log" "net/http" - "net/url" "os" - "os/exec" - "runtime" + "path/filepath" "strings" - "time" + "github.com/bytedance/sonic" + o "github.com/go-schwab/utils/oauth" "github.com/joho/godotenv" + "golang.org/x/oauth2" ) -func init() { - err := godotenv.Load("config.env") - check(err) -} - -type Agent struct { - tokens Token -} +type Agent struct{ *o.AuthorizedClient } -type Token struct { - RefreshExpiration time.Time - Refresh string - BearerExpiration time.Time - Bearer string -} - -// Helper: parse access token response -func parseAccessTokenResponse(s string) Token { - token := Token{ - RefreshExpiration: time.Now().Add(time.Hour * 168), - BearerExpiration: time.Now().Add(time.Minute * 30), - } - for _, x := range strings.Split(s, ",") { - for i1, x1 := range strings.Split(x, ":") { - if trimOneFirstOneLast(x1) == "refresh_token" { - token.Refresh = trimOneFirstOneLast(strings.Split(x, ":")[i1+1]) - } else if trimOneFirstOneLast(x1) == "access_token" { - token.Bearer = trimOneFirstOneLast(strings.Split(x, ":")[i1+1]) - } - } - } - return token -} - -// Read in tokens from ~/.trade/bar.json -func readDB() Token { - var tokens Token - body, err := os.ReadFile(fmt.Sprintf("%s/.trade/bar.json", homeDir())) - check(err) - err = json.Unmarshal(body, &tokens) - check(err) - return tokens -} - -// Credit: https://go.dev/play/p/C2sZRYC15XN -func getStringInBetween(str string, start string, end string) (result string) { - s := strings.Index(str, start) - if s == -1 { - return - } - s += len(start) - e := strings.Index(str[s:], end) - if e == -1 { - return - } - return str[s : s+e] -} +var ( + APPKEY string + SECRET string + CBURL string +) -// Credit: https://gist.github.com/hyg/9c4afcd91fe24316cbf0 -func openBrowser(url string) { - var err error - switch runtime.GOOS { - case "linux": - err = exec.Command("xdg-open", url).Start() - case "windows": - err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() - case "darwin": - err = exec.Command("open", url).Start() - default: - log.Fatalf("Unsupported platform.") - } - check(err) +func init() { + err := godotenv.Load(findAllEnvFiles()...) + isErrNil(err) + APPKEY = os.Getenv("APPKEY") + SECRET = os.Getenv("SECRET") + CBURL = os.Getenv("CBURL") } -// Generic error checking, will be implementing more robust error/exception handling >v0.9.0 -func check(err error) { +// is the err nil? +func isErrNil(err error) { if err != nil { - log.Fatalf("[ERR] %s", err.Error()) - } -} - -// trim one FIRST character in the string -func trimOneFirst(s string) string { - if len(s) < 1 { - return "" + log.Fatalf("[fatal] %s", err.Error()) } - return s[1:] } -// trim one LAST character in the string -func trimOneLast(s string) string { - if len(s) < 1 { - return "" - } - return s[:len(s)-1] -} - -// trim one FIRST & one LAST character in the string -func trimOneFirstOneLast(s string) string { - if len(s) < 1 { - return "" - } - return s[1 : len(s)-1] -} - -// trim two FIRST & one LAST character in the string -func trimTwoFirstOneLast(s string) string { - if len(s) < 1 { - return "" - } - return s[2 : len(s)-1] -} - -// trim one FIRST & two LAST character in the string -func trimOneFirstTwoLast(s string) string { - if len(s) < 1 { - return "" - } - return s[1 : len(s)-2] -} - -// trim one FIRST & three LAST character in the string -func trimOneFirstThreeLast(s string) string { - if len(s) < 1 { - return "" +// find all env files +func findAllEnvFiles() []string { + var files []string + err := filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + split := strings.Split(d.Name(), ".") + if len(split) > 1 { + if split[1] == "env" { + files = append(files, d.Name()) + } + } + return err + }) + isErrNil(err) + return files +} + +// Read in tokens from .json +func readDB() Agent { + var tok *oauth2.Token + body, err := os.ReadFile(".json") + isErrNil(err) + err = sonic.Unmarshal(body, &tok) + isErrNil(err) + conf := &oauth2.Config{ + ClientID: APPKEY, // Schwab App Key + ClientSecret: SECRET, // Schwab App Secret + Endpoint: oauth2.Endpoint{ + AuthURL: "https://api.schwabapi.com/v1/oauth/authorize", + TokenURL: "https://api.schwabapi.com/v1/oauth/token", + }, + } + tr := &http.Transport{ + TLSClientConfig: &tls.Config{}, + } + sslcli := &http.Client{Transport: tr} + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, sslcli) + return Agent{ + &o.AuthorizedClient{ + conf.Client(ctx, tok), + tok, + }, } - return s[1 : len(s)-3] -} - -// wrapper for os.UserHomeDir() -func homeDir() string { - dir, err := os.UserHomeDir() - check(err) - return dir } -// Initiate the Schwab oAuth process to retrieve bearer/refresh tokens func Initiate() *Agent { - agent := Agent{} - if _, err := os.Stat(fmt.Sprintf("%s/.trade", homeDir())); errors.Is(err, os.ErrNotExist) { - err := os.Mkdir(fmt.Sprintf("%s/.trade", homeDir()), os.ModePerm) - check(err) - // oAuth Leg 1 - Authorization Code - openBrowser(fmt.Sprintf("https://api.schwabapi.com/v1/oauth/authorize?client_id=%s&redirect_uri=%s", os.Getenv("APPKEY"), os.Getenv("CBURL"))) - fmt.Printf("Log into your Schwab brokerage account. Copy Error404 URL and paste it here: ") - var urlInput string - fmt.Scanln(&urlInput) - authCodeEncoded := getStringInBetween(urlInput, "?code=", "&session=") - authCode, err := url.QueryUnescape(authCodeEncoded) - check(err) - // oAuth Leg 2 - Refresh, Bearer Tokens - authStringLegTwo := fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", os.Getenv("APPKEY"), os.Getenv("SECRET"))))) - client := http.Client{} - payload := fmt.Sprintf("grant_type=authorization_code&code=%s&redirect_uri=%s", string(authCode), os.Getenv("CBURL")) - req, err := http.NewRequest("POST", "https://api.schwabapi.com/v1/oauth/token", bytes.NewBuffer([]byte(payload))) - check(err) - req.Header = http.Header{ - "Authorization": {authStringLegTwo}, - "Content-Type": {"application/x-www-form-urlencoded"}, - } - res, err := client.Do(req) - check(err) - defer res.Body.Close() - bodyBytes, err := io.ReadAll(res.Body) - check(err) - agent.tokens = parseAccessTokenResponse(string(bodyBytes)) - tokensJson, err := json.Marshal(agent.tokens) - check(err) - err = os.WriteFile(fmt.Sprintf("%s/.trade/bar.json", homeDir()), tokensJson, 0777) - check(err) + var agent Agent + // TODO: test this block, this is to attempt to resolve the error described in #67 + if _, err := os.Stat(".json"); errors.Is(err, os.ErrNotExist) { + agent = Agent{o.Initiate(APPKEY, SECRET)} + bytes, err := sonic.Marshal(agent.Token) + isErrNil(err) + err = os.WriteFile(".json", bytes, 0777) + isErrNil(err) } else { - agent.tokens = readDB() - if agent.tokens.Bearer == "" { - err := os.RemoveAll(fmt.Sprintf("%s/.trade", homeDir())) - check(err) - log.Fatalf("[err] please reinitiate, something went wrong\n") - } + agent = readDB() } return &agent } -// Use refresh token to generate a new bearer token for authentication -func (agent *Agent) refresh() { - oldTokens := readDB() - authStringRefresh := fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", os.Getenv("APPKEY"), os.Getenv("SECRET"))))) - client := http.Client{} - req, err := http.NewRequest("POST", "https://api.schwabapi.com/v1/oauth/token", bytes.NewBuffer([]byte(fmt.Sprintf("grant_type=refresh_token&refresh_token=%s", oldTokens.Refresh)))) - check(err) - req.Header = http.Header{ - "Authorization": {authStringRefresh}, - "Content-Type": {"application/x-www-form-urlencoded"}, - } - res, err := client.Do(req) - check(err) - defer res.Body.Close() - bodyBytes, err := io.ReadAll(res.Body) - check(err) - agent.tokens = parseAccessTokenResponse(string(bodyBytes)) -} - // Handler is the general purpose request function for the td-ameritrade api, all functions will be routed through this handler function, which does all of the API calling work // It performs a GET request after adding the apikey found in the config.env file in the same directory as the program calling the function, // then returns the body of the GET request's return. // It takes one parameter: // req = a request of type *http.Request func (agent *Agent) Handler(req *http.Request) (*http.Response, error) { - if (&Agent{}) == agent { - log.Fatal("[ERR] empty agent - call 'Agent.Initiate' before making any API function calls.") - } - if !time.Now().Before(agent.tokens.BearerExpiration) { - agent.refresh() + var err error + if agent.Token.AccessToken == "" { + log.Fatal("[fatal] no access token found, please reinitiate with 'Initiate'") + // TODO: auto reinitiate? } - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", agent.tokens.Bearer)) - client := http.Client{} - resp, err := client.Do(req) + resp, err := agent.Do(req) if err != nil { return resp, err } - if resp.StatusCode == 401 { - err := os.Remove(fmt.Sprintf("%s/.trade", homeDir())) - check(err) - } - if resp.StatusCode < 200 || resp.StatusCode > 300 { + switch true { + case resp.StatusCode == 401: + log.Fatal("[fatal] invalid token - please reinitiate with 'Initiate'") + case resp.StatusCode < 200, resp.StatusCode > 300: defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) - log.Fatalf("[ERR] %d, %s", resp.StatusCode, body) + isErrNil(err) + log.Println("[err] ", string(body)) } return resp, nil }