From 62f6f52f94aa55eb8d49aa04aee0e8f27422aa51 Mon Sep 17 00:00:00 2001 From: Murtaza Aliakbar Date: Sun, 15 Dec 2024 19:23:36 +0530 Subject: [PATCH] ipn/proton: store multiple wg configs per cc --- intra/backend/ipn_proxies.go | 2 +- intra/ipn/proxies.go | 7 +- intra/ipn/warp/proton.go | 167 +++++++++++++++++++++++------------ 3 files changed, 115 insertions(+), 61 deletions(-) diff --git a/intra/backend/ipn_proxies.go b/intra/backend/ipn_proxies.go index 7391edd8..235340a0 100644 --- a/intra/backend/ipn_proxies.go +++ b/intra/backend/ipn_proxies.go @@ -63,7 +63,7 @@ type Rpn interface { // RegisterAmnezia registers a new Amnezia installation. RegisterAmnezia(publicKeyBase64 string) (json []byte, err error) // RegisterProton registers a new Proton installation. - RegisterProton(existingStateJson []byte, serversFile string) (json []byte, err error) + RegisterProton(existingStateJson []byte) (json []byte, err error) // TestWarp connects to some Warp IPs and returns reachable ones. TestWarp() (ips string, errs error) // TestAmnezia connects to the Amnezia gateway and returns its IP if reachable. diff --git a/intra/ipn/proxies.go b/intra/ipn/proxies.go index 79ef5008..6d8b9af0 100644 --- a/intra/ipn/proxies.go +++ b/intra/ipn/proxies.go @@ -846,14 +846,15 @@ func (px *proxifier) RegisterAmnezia(pub string) ([]byte, error) { } // RegisterProton implements x.Rpn. -func (px *proxifier) RegisterProton(existingStateJson []byte, serversFile string) (stateJson []byte, err error) { +func (px *proxifier) RegisterProton(existingStateJson []byte) (stateJson []byte, err error) { + const nostore = "" var id *warp.ProtonWgConfig // may be nil redo := len(existingStateJson) > 0 if redo { - id, err = px.extc.MakeProtonWgFrom(px.ctx, existingStateJson, serversFile) + id, err = px.extc.MakeProtonWgFrom(px.ctx, existingStateJson, nostore) } else { - id, err = px.extc.MakeProtonWg(px.ctx, serversFile) + id, err = px.extc.MakeProtonWg(px.ctx, nostore) } px.lastProtonErr = err // may be nil if err != nil { diff --git a/intra/ipn/warp/proton.go b/intra/ipn/warp/proton.go index 909a1bbe..1da7cb7a 100644 --- a/intra/ipn/warp/proton.go +++ b/intra/ipn/warp/proton.go @@ -66,7 +66,11 @@ var ( errProtonCredsMismatch = errors.New("proton: creds mismatch") ) -const maxProtonLogicalsRefreshThreshold = 72 * time.Hour +const ( + maxProtonLogicalsRefreshThreshold = 72 * time.Hour + maxPerRegionWgConfs = 6 + maxRegisterCertTries = 3 +) var protonLogicalsUpdateTime = time.Time{} @@ -310,19 +314,20 @@ type ProtonServerResponse struct { // "Load": 63 // } type ProtonLogicals struct { - Name string `json:"Name"` - EntryCountry string `json:"EntryCountry"` - ExitCountry string `json:"ExitCountry"` - Domain string `json:"Domain"` - Tier int `json:"Tier"` - Features int `json:"Features"` - Region string `json:"Region"` - City string `json:"City"` - Score float64 - HostCountry string `json:"HostCountry"` - Organization string `json:"OrganizationID"` - VPNGatewayID string `json:"VPNGatewayID"` - ID string `json:"ID"` + Name string `json:"Name"` + EntryCountry string `json:"EntryCountry"` + ExitCountry string `json:"ExitCountry"` + Domain string `json:"Domain"` + Tier int `json:"Tier"` + Features int `json:"Features"` + Region string `json:"Region"` + City string `json:"City"` + Score float64 `json:"Score"` + HostCountry string `json:"HostCountry"` + Organization string `json:"OrganizationID"` + VPNGatewayID string `json:"VPNGatewayID"` + ID string `json:"ID"` + Load int `json:"Load"` Location ProtonServerLocation Status int `json:"Status"` Servers []ProtonServer @@ -350,6 +355,8 @@ type ProtonServerLocation struct { // "ServicesDownReason": null // } type ProtonServer struct { + Name string `json:"Name"` + Load int `json:"Load"` EntryIP string `json:"EntryIP"` ExitIP string `json:"ExitIP"` Domain string `json:"Domain"` @@ -390,19 +397,22 @@ type ProtonWgConfig struct { UID string `json:"UID"` SessionAccessToken string `json:"SessionAccessToken"` SessionRefreshToken string `json:"SessionRefreshToken"` - UserID string `json:"UserID"` - CredsAccessToken string `json:"UserAccessToken"` - CredsRefreshToken string `json:"UserRefreshToken"` - CertSerialNumber string `json:"CertSerialNumber"` - CertExpTime int `json:"CertExpTime"` - CertRefreshTime int `json:"CertRefreshTime"` + UserID string `json:"UserID"` + CredsAccessToken string `json:"UserAccessToken"` + CredsRefreshToken string `json:"UserRefreshToken"` + + CertSerialNumber string `json:"CertSerialNumber"` + CertExpTime int `json:"CertExpTime"` + CertRefreshTime int `json:"CertRefreshTime"` + + CreateTimestamp int64 `json:"CreateTimestamp"` RegionalWgConfs []*RegionalWgConf `json:"RegionalWgConfs"` } type RegionalWgConf struct { CC string `json:"CC"` - Load int `json:"Load"` + Name string `json:"Name"` ClientAddr4 string `json:"ClientAddr4"` ClientAddr6 string `json:"ClientAddr6"` @@ -509,30 +519,12 @@ func newProtonGw(ctx context.Context, k ProtonKey, logicals []ProtonLogicals, h2 publicKeyPem = publicKeyPem[6:16] } - m := make(map[string][]ProtonServer, 0) - skips := 0 - tot := 0 - for _, x := range logicals { - // github.com/ProtonVPN/android-app/blob/b9c6e59de40/app/src/main/java/com/protonvpn/android/utils/ServerManager.kt#L251 - // skip premium or restricted or offline servers - if x.securecore() || x.gateway() || x.offline() { - skips++ - continue - } - if c, ok := m[x.EntryCountry]; ok { - m[x.EntryCountry] = append(c, x.Servers...) - } else { - m[x.EntryCountry] = x.Servers - } - tot += len(x.Servers) - } - log.I("proton: new gw for %s: sz: l(%d) => [cc(%d) => svcs(%d) / skip: %d]", - publicKeyPem, len(logicals), len(m), tot, skips) + m := protonServersByCountry(logicals) a := &protongw{ http: h2, key: k, - servers: m, + servers: m, // may be empty sched: core.NewScheduler(ctx), sess: struct { uid string @@ -547,6 +539,8 @@ func newProtonGw(ctx context.Context, k ProtonKey, logicals []ProtonLogicals, h2 config: nil, } + log.I("proton: gw: new: %s / %d", publicKeyPem, len(m)) + return a, nil } @@ -599,14 +593,21 @@ func (a *protongw) newConf() error { rwgConfs := make([]*RegionalWgConf, 0, len(a.servers)) for cc, ss := range a.servers { wc := new(RegionalWgConf) + wc.CC = cc wc.ClientAddr4 = protonClientAddr4 wc.ClientPrivKey = clientPrivKey wc.ClientPubKey = clientPubKey wc.ClientDNS4 = protonDNSAddr4 + n := 0 for _, s := range ss { + if n > maxPerRegionWgConfs { + break + } // github.com/ProtonVPN/android-app/blob/b9c6e59de40/app/src/main/java/com/protonvpn/android/utils/ServerManager.kt#L251 if s.online() && s.wg() { + n++ + wc.Name = s.Name wc.ServerPubKey = s.X25519PublicKey wc.ServerIPPort4 = fmt.Sprintf("%s:%d", s.EntryIP, protonPrimaryTcpPort) wc.ServerDomainPort = fmt.Sprintf("%s:%d", s.Domain, protonPrimaryTcpPort) @@ -614,10 +615,8 @@ func (a *protongw) newConf() error { rwgConfs = append(rwgConfs, wc) - log.VV("proton: genconf: %s: %s@%s; peer: %s@%s", - cc, wc.ClientAddr4, wc.ClientPubKey[:6], wc.ServerIPPort4, wc.ServerPubKey[:6]) - - break + log.VV("proton: genconf: %s n:%d, l:%d; x: %s@%s; p: %s@%s", + s.Name, n, s.Load, wc.ClientAddr4, wc.ClientPubKey[:6], wc.ServerIPPort4, wc.ServerPubKey[:6]) } } } @@ -643,13 +642,14 @@ func (a *protongw) newConf() error { // wg info pc.RegionalWgConfs = rwgConfs + pc.CreateTimestamp = time.Now().Unix() + a.config = pc return nil // success } func (a *protongw) registerCert() error { - const maxRegisterCertTries = 3 tries := 0 retryAfterRefresh: @@ -1041,6 +1041,20 @@ func (a *protongw) reg() error { return a.newConf() } +func (a *protongw) refreshServers() error { + const nofile = "" + + oldEnough := time.Since(protonLogicalsUpdateTime) > maxProtonLogicalsRefreshThreshold + missingConfig := a.config == nil || len(a.config.RegionalWgConfs) <= 0 + if oldEnough || missingConfig { + t := protonLogicalsUpdateTime.Format(time.RFC1123) + log.I("proton: refresh servers; old(%s)? %t / missing? %t", oldEnough, t, missingConfig) + a.servers = protonServersByCountry(protonServersFrom(nofile, a.http)) + } + + return nil +} + func (a *protongw) rereg() error { if len(a.sess.uid) <= 0 { log.W("proton: re-reg: no session; initiating reg") @@ -1102,8 +1116,7 @@ func (w *Client) MakeProtonWgFrom(ctx context.Context, fromConfigJson []byte, al return nil, err } - svcs := protonServersFrom(allServersFilePath, &w.h2) - + svcs := protonServersPrebuilt() // refreshed if needed later a, err := newProtonGw(ctx, k, svcs, &w.h2) if err != nil { return nil, err @@ -1119,6 +1132,11 @@ func (w *Client) MakeProtonWgFrom(ctx context.Context, fromConfigJson []byte, al return nil, err } + err = a.refreshServers() + if err != nil { + return nil, err + } + return a.config, nil } @@ -1138,16 +1156,52 @@ func (a *protongw) load(conf *ProtonWgConfig) error { a.cert.ExpirationTime = conf.CertExpTime a.cert.RefreshTime = conf.CertRefreshTime + protonLogicalsUpdateTime = time.Unix(conf.CreateTimestamp, 0) + return nil } -// go.dev/play/p/9kapzPiG72r -func protonServersFrom(allServersFilePath string, c *http.Client) []ProtonLogicals { - var prebuilts, all ProtonServerResponse +func protonServersByCountry(logicals []ProtonLogicals) map[string][]ProtonServer { + m := make(map[string][]ProtonServer, 0) + skips := 0 + tot := 0 + for _, x := range logicals { + // github.com/ProtonVPN/android-app/blob/b9c6e59de40/app/src/main/java/com/protonvpn/android/utils/ServerManager.kt#L251 + // skip premium or restricted or offline servers + if x.securecore() || x.gateway() || x.offline() { + skips++ + continue + } + for _, s := range x.Servers { + s.Load = x.Load + s.Name = x.Name + } + if c, ok := m[x.EntryCountry]; ok { + m[x.EntryCountry] = append(c, x.Servers...) + } else { + m[x.EntryCountry] = x.Servers + } + tot += len(x.Servers) + } + log.I("proton: servers: sz: l(%d) => [cc(%d) => svcs(%d) / skip: %d]", + len(logicals), len(m), tot, skips) + return m +} + +func protonServersPrebuilt() []ProtonLogicals { + var prebuilts []ProtonLogicals err := json.Unmarshal(prebuiltProtonServersJson, &prebuilts) if err != nil { log.E("proton: servers: %d unmarshal: %v", len(prebuiltProtonServersJson), err) } + return prebuilts +} + +// go.dev/play/p/9kapzPiG72r +func protonServersFrom(allServersFilePath string, c *http.Client) []ProtonLogicals { + var all ProtonServerResponse + + prebuilts := protonServersPrebuilt() if len(allServersFilePath) > 0 { fp := filepath.Clean(allServersFilePath) @@ -1224,16 +1278,15 @@ func protonServersFrom(allServersFilePath string, c *http.Client) []ProtonLogica _, err = f.Write(b) if err != nil { log.E("proton: servers: write %s, err: %v", fp, err) - } else { - protonLogicalsUpdateTime = time.Now() - } + } // else: written } - } + } // else: no-store + protonLogicalsUpdateTime = time.Now() } } } } } - } - return append(all.R, prebuilts.R...) + } // else: contains remote servers + return append(all.R, prebuilts...) }