Merge pull request #40 from Snawoot/try_hard

Retry init
This commit is contained in:
Snawoot
2024-11-05 23:02:55 +02:00
committed by GitHub
2 changed files with 75 additions and 38 deletions

View File

@@ -70,6 +70,8 @@ windscribe-proxy -list-proxies
| cafile | String | use custom CA certificate bundle file |
| fake-sni | String | fake SNI to use to contact windscribe servers (default "com") |
| force-cold-init | - | force cold init |
| init-retries | Number | number of attempts for initialization steps, zero for unlimited retry |
| init-retry-interval | Duration | delay between initialization retries (default 5s) |
| list-locations | - | list available locations and exit |
| list-proxies | - | output proxy list and exit |
| location | String | desired proxy location. Default: best location |

111
main.go
View File

@@ -46,23 +46,25 @@ func arg_fail(msg string) {
}
type CLIArgs struct {
location string
listLocations bool
listProxies bool
bindAddress string
verbosity int
timeout time.Duration
showVersion bool
proxy string
resolver string
caFile string
clientAuthSecret string
stateFile string
username string
password string
tfacode string
fakeSNI string
forceColdInit bool
location string
listLocations bool
listProxies bool
bindAddress string
verbosity int
timeout time.Duration
showVersion bool
proxy string
resolver string
caFile string
clientAuthSecret string
stateFile string
username string
password string
tfacode string
fakeSNI string
forceColdInit bool
initRetries int
initRetryInterval time.Duration
}
func parse_args() CLIArgs {
@@ -92,6 +94,8 @@ func parse_args() CLIArgs {
flag.StringVar(&args.tfacode, "2fa", "", "2FA code for login")
flag.StringVar(&args.fakeSNI, "fake-sni", "com", "fake SNI to use to contact windscribe servers")
flag.BoolVar(&args.forceColdInit, "force-cold-init", false, "force cold init")
flag.IntVar(&args.initRetries, "init-retries", 0, "number of attempts for initialization steps, zero for unlimited retry")
flag.DurationVar(&args.initRetryInterval, "init-retry-interval", 5*time.Second, "delay between initialization retries")
flag.Parse()
if args.listLocations && args.listProxies {
arg_fail("list-locations and list-proxies flags are mutually exclusive")
@@ -189,9 +193,9 @@ func run() int {
mainLogger.Critical("Unable to construct WndClient: %v", err)
return 8
}
wndc.Mux.Lock()
wndc.State.Settings.ClientAuthSecret = args.clientAuthSecret
wndc.Mux.Unlock()
try := retryPolicy(args.initRetries, args.initRetryInterval, mainLogger)
// Try to resurrect state
state, err := maybeLoadState(args.forceColdInit, args.stateFile)
@@ -202,7 +206,7 @@ func run() int {
default:
mainLogger.Warning("Failed to load client state: %v. It is OK for a first run. Performing cold init...", err)
}
err = coldInit(wndc, args.username, args.password, args.tfacode, args.timeout)
err = coldInit(wndc, try, args.username, args.password, args.tfacode, args.timeout)
if err != nil {
mainLogger.Critical("Cold init failed: %v", err)
return 9
@@ -212,18 +216,19 @@ func run() int {
mainLogger.Error("Unable to save state file! Error: %v", err)
}
} else {
wndc.Mux.Lock()
wndc.State = *state
wndc.Mux.Unlock()
}
var serverList wndclient.ServerList
if args.listProxies || args.listLocations || args.location != "" {
ctx, cl := context.WithTimeout(context.Background(), args.timeout)
serverList, err = wndc.ServerList(ctx)
cl()
err := try("retrieve server list", func() error {
ctx, cl := context.WithTimeout(context.Background(), args.timeout)
defer cl()
l, err := wndc.ServerList(ctx)
serverList = l
return err
})
if err != nil {
mainLogger.Critical("Server list retrieve failed: %v", err)
return 12
}
}
@@ -239,14 +244,18 @@ func run() int {
var proxyHostname string
if args.location == "" {
ctx, cl := context.WithTimeout(context.Background(), args.timeout)
bestLocation, err := wndc.BestLocation(ctx)
cl()
err := try("find best location", func() error {
ctx, cl := context.WithTimeout(context.Background(), args.timeout)
defer cl()
bestLocation, err := wndc.BestLocation(ctx)
if err == nil {
proxyHostname = bestLocation.Hostname
}
return err
})
if err != nil {
mainLogger.Critical("Unable to get best location endpoint: %v", err)
return 13
}
proxyHostname = bestLocation.Hostname
} else {
proxyHostname = pickServer(serverList, args.location)
if proxyHostname == "" {
@@ -353,6 +362,7 @@ func pickServer(serverList wndclient.ServerList, location string) string {
}
var errColdInitForced = errors.New("cold init forced!")
func maybeLoadState(forceColdInit bool, filename string) (*wndclient.WndClientState, error) {
if forceColdInit {
return nil, errColdInitForced
@@ -390,20 +400,24 @@ func saveState(filename string, state *wndclient.WndClientState) error {
return err
}
func coldInit(wndc *wndclient.WndClient, username, password, tfacode string, timeout time.Duration) error {
func coldInit(wndc *wndclient.WndClient, try func(string, func() error) error, username, password, tfacode string, timeout time.Duration) error {
if username == "" || password == "" {
return errors.New(`Please provide "-username" and "-password" command line arguments!`)
}
ctx, cl := context.WithTimeout(context.Background(), timeout)
err := wndc.Session(ctx, username, password, tfacode)
cl()
err := try("init session", func() error {
ctx, cl := context.WithTimeout(context.Background(), timeout)
defer cl()
return wndc.Session(ctx, username, password, tfacode)
})
if err != nil {
return fmt.Errorf("Session call failed: %w", err)
}
ctx, cl = context.WithTimeout(context.Background(), timeout)
err = wndc.ServerCredentials(ctx)
cl()
err = try("get server credentials", func() error {
ctx, cl := context.WithTimeout(context.Background(), timeout)
defer cl()
return wndc.ServerCredentials(ctx)
})
if err != nil {
return fmt.Errorf("ServerCredentials call failed: %w", err)
}
@@ -414,3 +428,24 @@ func coldInit(wndc *wndclient.WndClient, username, password, tfacode string, tim
func main() {
os.Exit(run())
}
func retryPolicy(retries int, retryInterval time.Duration, logger *CondLogger) func(string, func() error) error {
return func(name string, f func() error) error {
var err error
for i := 1; retries <= 0 || i <= retries; i++ {
if i > 1 {
logger.Warning("Retrying action %q in %v...", name, retryInterval)
time.Sleep(retryInterval)
}
logger.Info("Attempting action %q, attempt #%d...", name, i)
err = f()
if err == nil {
logger.Info("Action %q succeeded on attempt #%d", name, i)
return nil
}
logger.Warning("Action %q failed: %v", name, err)
}
logger.Critical("All attempts for action %q have failed. Last error: %v", name, err)
return err
}
}