Skip to content

Commit

Permalink
ipn/proton: store multiple wg configs per cc
Browse files Browse the repository at this point in the history
  • Loading branch information
ignoramous committed Dec 15, 2024
1 parent 610de80 commit 62f6f52
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 61 deletions.
2 changes: 1 addition & 1 deletion intra/backend/ipn_proxies.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type Rpn interface {
// RegisterAmnezia registers a new Amnezia installation.
RegisterAmnezia(publicKeyBase64 string) (json []byte, err error)
// RegisterProton registers a new Proton installation.
RegisterProton(existingStateJson []byte, serversFile string) (json []byte, err error)
RegisterProton(existingStateJson []byte) (json []byte, err error)
// TestWarp connects to some Warp IPs and returns reachable ones.
TestWarp() (ips string, errs error)
// TestAmnezia connects to the Amnezia gateway and returns its IP if reachable.
Expand Down
7 changes: 4 additions & 3 deletions intra/ipn/proxies.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,14 +846,15 @@ func (px *proxifier) RegisterAmnezia(pub string) ([]byte, error) {
}

// RegisterProton implements x.Rpn.
func (px *proxifier) RegisterProton(existingStateJson []byte, serversFile string) (stateJson []byte, err error) {
func (px *proxifier) RegisterProton(existingStateJson []byte) (stateJson []byte, err error) {
const nostore = ""
var id *warp.ProtonWgConfig // may be nil

redo := len(existingStateJson) > 0
if redo {
id, err = px.extc.MakeProtonWgFrom(px.ctx, existingStateJson, serversFile)
id, err = px.extc.MakeProtonWgFrom(px.ctx, existingStateJson, nostore)
} else {
id, err = px.extc.MakeProtonWg(px.ctx, serversFile)
id, err = px.extc.MakeProtonWg(px.ctx, nostore)
}
px.lastProtonErr = err // may be nil
if err != nil {
Expand Down
167 changes: 110 additions & 57 deletions intra/ipn/warp/proton.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ var (
errProtonCredsMismatch = errors.New("proton: creds mismatch")
)

const maxProtonLogicalsRefreshThreshold = 72 * time.Hour
const (
maxProtonLogicalsRefreshThreshold = 72 * time.Hour
maxPerRegionWgConfs = 6
maxRegisterCertTries = 3
)

var protonLogicalsUpdateTime = time.Time{}

Expand Down Expand Up @@ -310,19 +314,20 @@ type ProtonServerResponse struct {
// "Load": 63
// }
type ProtonLogicals struct {
Name string `json:"Name"`
EntryCountry string `json:"EntryCountry"`
ExitCountry string `json:"ExitCountry"`
Domain string `json:"Domain"`
Tier int `json:"Tier"`
Features int `json:"Features"`
Region string `json:"Region"`
City string `json:"City"`
Score float64
HostCountry string `json:"HostCountry"`
Organization string `json:"OrganizationID"`
VPNGatewayID string `json:"VPNGatewayID"`
ID string `json:"ID"`
Name string `json:"Name"`
EntryCountry string `json:"EntryCountry"`
ExitCountry string `json:"ExitCountry"`
Domain string `json:"Domain"`
Tier int `json:"Tier"`
Features int `json:"Features"`
Region string `json:"Region"`
City string `json:"City"`
Score float64 `json:"Score"`
HostCountry string `json:"HostCountry"`
Organization string `json:"OrganizationID"`
VPNGatewayID string `json:"VPNGatewayID"`
ID string `json:"ID"`
Load int `json:"Load"`
Location ProtonServerLocation
Status int `json:"Status"`
Servers []ProtonServer
Expand Down Expand Up @@ -350,6 +355,8 @@ type ProtonServerLocation struct {
// "ServicesDownReason": null
// }
type ProtonServer struct {
Name string `json:"Name"`
Load int `json:"Load"`
EntryIP string `json:"EntryIP"`
ExitIP string `json:"ExitIP"`
Domain string `json:"Domain"`
Expand Down Expand Up @@ -390,19 +397,22 @@ type ProtonWgConfig struct {
UID string `json:"UID"`
SessionAccessToken string `json:"SessionAccessToken"`
SessionRefreshToken string `json:"SessionRefreshToken"`
UserID string `json:"UserID"`
CredsAccessToken string `json:"UserAccessToken"`
CredsRefreshToken string `json:"UserRefreshToken"`
CertSerialNumber string `json:"CertSerialNumber"`
CertExpTime int `json:"CertExpTime"`
CertRefreshTime int `json:"CertRefreshTime"`

UserID string `json:"UserID"`
CredsAccessToken string `json:"UserAccessToken"`
CredsRefreshToken string `json:"UserRefreshToken"`

CertSerialNumber string `json:"CertSerialNumber"`
CertExpTime int `json:"CertExpTime"`
CertRefreshTime int `json:"CertRefreshTime"`

CreateTimestamp int64 `json:"CreateTimestamp"`
RegionalWgConfs []*RegionalWgConf `json:"RegionalWgConfs"`
}

type RegionalWgConf struct {
CC string `json:"CC"`
Load int `json:"Load"`
Name string `json:"Name"`

ClientAddr4 string `json:"ClientAddr4"`
ClientAddr6 string `json:"ClientAddr6"`
Expand Down Expand Up @@ -509,30 +519,12 @@ func newProtonGw(ctx context.Context, k ProtonKey, logicals []ProtonLogicals, h2
publicKeyPem = publicKeyPem[6:16]
}

m := make(map[string][]ProtonServer, 0)
skips := 0
tot := 0
for _, x := range logicals {
// github.com/ProtonVPN/android-app/blob/b9c6e59de40/app/src/main/java/com/protonvpn/android/utils/ServerManager.kt#L251
// skip premium or restricted or offline servers
if x.securecore() || x.gateway() || x.offline() {
skips++
continue
}
if c, ok := m[x.EntryCountry]; ok {
m[x.EntryCountry] = append(c, x.Servers...)
} else {
m[x.EntryCountry] = x.Servers
}
tot += len(x.Servers)
}
log.I("proton: new gw for %s: sz: l(%d) => [cc(%d) => svcs(%d) / skip: %d]",
publicKeyPem, len(logicals), len(m), tot, skips)
m := protonServersByCountry(logicals)

a := &protongw{
http: h2,
key: k,
servers: m,
servers: m, // may be empty
sched: core.NewScheduler(ctx),
sess: struct {
uid string
Expand All @@ -547,6 +539,8 @@ func newProtonGw(ctx context.Context, k ProtonKey, logicals []ProtonLogicals, h2
config: nil,
}

log.I("proton: gw: new: %s / %d", publicKeyPem, len(m))

return a, nil
}

Expand Down Expand Up @@ -599,25 +593,30 @@ func (a *protongw) newConf() error {
rwgConfs := make([]*RegionalWgConf, 0, len(a.servers))
for cc, ss := range a.servers {
wc := new(RegionalWgConf)
wc.CC = cc
wc.ClientAddr4 = protonClientAddr4
wc.ClientPrivKey = clientPrivKey
wc.ClientPubKey = clientPubKey
wc.ClientDNS4 = protonDNSAddr4

n := 0
for _, s := range ss {
if n > maxPerRegionWgConfs {
break
}
// github.com/ProtonVPN/android-app/blob/b9c6e59de40/app/src/main/java/com/protonvpn/android/utils/ServerManager.kt#L251
if s.online() && s.wg() {
n++
wc.Name = s.Name
wc.ServerPubKey = s.X25519PublicKey
wc.ServerIPPort4 = fmt.Sprintf("%s:%d", s.EntryIP, protonPrimaryTcpPort)
wc.ServerDomainPort = fmt.Sprintf("%s:%d", s.Domain, protonPrimaryTcpPort)
wc.AllowedIPs = protonAllowedIPs

rwgConfs = append(rwgConfs, wc)

log.VV("proton: genconf: %s: %s@%s; peer: %s@%s",
cc, wc.ClientAddr4, wc.ClientPubKey[:6], wc.ServerIPPort4, wc.ServerPubKey[:6])

break
log.VV("proton: genconf: %s n:%d, l:%d; x: %s@%s; p: %s@%s",
s.Name, n, s.Load, wc.ClientAddr4, wc.ClientPubKey[:6], wc.ServerIPPort4, wc.ServerPubKey[:6])
}
}
}
Expand All @@ -643,13 +642,14 @@ func (a *protongw) newConf() error {
// wg info
pc.RegionalWgConfs = rwgConfs

pc.CreateTimestamp = time.Now().Unix()

a.config = pc

return nil // success
}

func (a *protongw) registerCert() error {
const maxRegisterCertTries = 3
tries := 0

retryAfterRefresh:
Expand Down Expand Up @@ -1041,6 +1041,20 @@ func (a *protongw) reg() error {
return a.newConf()
}

func (a *protongw) refreshServers() error {
const nofile = ""

oldEnough := time.Since(protonLogicalsUpdateTime) > maxProtonLogicalsRefreshThreshold
missingConfig := a.config == nil || len(a.config.RegionalWgConfs) <= 0
if oldEnough || missingConfig {
t := protonLogicalsUpdateTime.Format(time.RFC1123)
log.I("proton: refresh servers; old(%s)? %t / missing? %t", oldEnough, t, missingConfig)
a.servers = protonServersByCountry(protonServersFrom(nofile, a.http))
}

return nil
}

func (a *protongw) rereg() error {
if len(a.sess.uid) <= 0 {
log.W("proton: re-reg: no session; initiating reg")
Expand Down Expand Up @@ -1102,8 +1116,7 @@ func (w *Client) MakeProtonWgFrom(ctx context.Context, fromConfigJson []byte, al
return nil, err
}

svcs := protonServersFrom(allServersFilePath, &w.h2)

svcs := protonServersPrebuilt() // refreshed if needed later
a, err := newProtonGw(ctx, k, svcs, &w.h2)
if err != nil {
return nil, err
Expand All @@ -1119,6 +1132,11 @@ func (w *Client) MakeProtonWgFrom(ctx context.Context, fromConfigJson []byte, al
return nil, err
}

err = a.refreshServers()
if err != nil {
return nil, err
}

return a.config, nil
}

Expand All @@ -1138,16 +1156,52 @@ func (a *protongw) load(conf *ProtonWgConfig) error {
a.cert.ExpirationTime = conf.CertExpTime
a.cert.RefreshTime = conf.CertRefreshTime

protonLogicalsUpdateTime = time.Unix(conf.CreateTimestamp, 0)

return nil
}

// go.dev/play/p/9kapzPiG72r
func protonServersFrom(allServersFilePath string, c *http.Client) []ProtonLogicals {
var prebuilts, all ProtonServerResponse
func protonServersByCountry(logicals []ProtonLogicals) map[string][]ProtonServer {
m := make(map[string][]ProtonServer, 0)
skips := 0
tot := 0
for _, x := range logicals {
// github.com/ProtonVPN/android-app/blob/b9c6e59de40/app/src/main/java/com/protonvpn/android/utils/ServerManager.kt#L251
// skip premium or restricted or offline servers
if x.securecore() || x.gateway() || x.offline() {
skips++
continue
}
for _, s := range x.Servers {
s.Load = x.Load
s.Name = x.Name
}
if c, ok := m[x.EntryCountry]; ok {
m[x.EntryCountry] = append(c, x.Servers...)
} else {
m[x.EntryCountry] = x.Servers
}
tot += len(x.Servers)
}
log.I("proton: servers: sz: l(%d) => [cc(%d) => svcs(%d) / skip: %d]",
len(logicals), len(m), tot, skips)
return m
}

func protonServersPrebuilt() []ProtonLogicals {
var prebuilts []ProtonLogicals
err := json.Unmarshal(prebuiltProtonServersJson, &prebuilts)
if err != nil {
log.E("proton: servers: %d unmarshal: %v", len(prebuiltProtonServersJson), err)
}
return prebuilts
}

// go.dev/play/p/9kapzPiG72r
func protonServersFrom(allServersFilePath string, c *http.Client) []ProtonLogicals {
var all ProtonServerResponse

prebuilts := protonServersPrebuilt()

if len(allServersFilePath) > 0 {
fp := filepath.Clean(allServersFilePath)
Expand Down Expand Up @@ -1224,16 +1278,15 @@ func protonServersFrom(allServersFilePath string, c *http.Client) []ProtonLogica
_, err = f.Write(b)
if err != nil {
log.E("proton: servers: write %s, err: %v", fp, err)
} else {
protonLogicalsUpdateTime = time.Now()
}
} // else: written
}
}
} // else: no-store
protonLogicalsUpdateTime = time.Now()
}
}
}
}
}
}
return append(all.R, prebuilts.R...)
} // else: contains remote servers
return append(all.R, prebuilts...)
}

0 comments on commit 62f6f52

Please sign in to comment.