diff --git a/README.md b/README.md index 3f77fca..1cfaf6d 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,8 @@ windscribe-proxy -list-proxies | cafile | String | use custom CA certificate bundle file | | fake-sni | String | fake SNI to use to contact windscribe servers (default "com") | | force-cold-init | - | force cold init | +| init-retries | Number | number of attempts for initialization steps, zero for unlimited retry | +| init-retry-interval | Duration | delay between initialization retries (default 5s) | | list-locations | - | list available locations and exit | | list-proxies | - | output proxy list and exit | | location | String | desired proxy location. Default: best location | diff --git a/main.go b/main.go index ab57307..8f16ac6 100644 --- a/main.go +++ b/main.go @@ -46,23 +46,25 @@ func arg_fail(msg string) { } type CLIArgs struct { - location string - listLocations bool - listProxies bool - bindAddress string - verbosity int - timeout time.Duration - showVersion bool - proxy string - resolver string - caFile string - clientAuthSecret string - stateFile string - username string - password string - tfacode string - fakeSNI string - forceColdInit bool + location string + listLocations bool + listProxies bool + bindAddress string + verbosity int + timeout time.Duration + showVersion bool + proxy string + resolver string + caFile string + clientAuthSecret string + stateFile string + username string + password string + tfacode string + fakeSNI string + forceColdInit bool + initRetries int + initRetryInterval time.Duration } func parse_args() CLIArgs { @@ -92,6 +94,8 @@ func parse_args() CLIArgs { flag.StringVar(&args.tfacode, "2fa", "", "2FA code for login") flag.StringVar(&args.fakeSNI, "fake-sni", "com", "fake SNI to use to contact windscribe servers") flag.BoolVar(&args.forceColdInit, "force-cold-init", false, "force cold init") + flag.IntVar(&args.initRetries, "init-retries", 0, "number of attempts for initialization steps, zero for unlimited retry") + flag.DurationVar(&args.initRetryInterval, "init-retry-interval", 5*time.Second, "delay between initialization retries") flag.Parse() if args.listLocations && args.listProxies { arg_fail("list-locations and list-proxies flags are mutually exclusive") @@ -189,9 +193,9 @@ func run() int { mainLogger.Critical("Unable to construct WndClient: %v", err) return 8 } - wndc.Mux.Lock() wndc.State.Settings.ClientAuthSecret = args.clientAuthSecret - wndc.Mux.Unlock() + + try := retryPolicy(args.initRetries, args.initRetryInterval, mainLogger) // Try to resurrect state state, err := maybeLoadState(args.forceColdInit, args.stateFile) @@ -202,7 +206,7 @@ func run() int { default: mainLogger.Warning("Failed to load client state: %v. It is OK for a first run. Performing cold init...", err) } - err = coldInit(wndc, args.username, args.password, args.tfacode, args.timeout) + err = coldInit(wndc, try, args.username, args.password, args.tfacode, args.timeout) if err != nil { mainLogger.Critical("Cold init failed: %v", err) return 9 @@ -212,18 +216,19 @@ func run() int { mainLogger.Error("Unable to save state file! Error: %v", err) } } else { - wndc.Mux.Lock() wndc.State = *state - wndc.Mux.Unlock() } var serverList wndclient.ServerList if args.listProxies || args.listLocations || args.location != "" { - ctx, cl := context.WithTimeout(context.Background(), args.timeout) - serverList, err = wndc.ServerList(ctx) - cl() + err := try("retrieve server list", func() error { + ctx, cl := context.WithTimeout(context.Background(), args.timeout) + defer cl() + l, err := wndc.ServerList(ctx) + serverList = l + return err + }) if err != nil { - mainLogger.Critical("Server list retrieve failed: %v", err) return 12 } } @@ -239,14 +244,18 @@ func run() int { var proxyHostname string if args.location == "" { - ctx, cl := context.WithTimeout(context.Background(), args.timeout) - bestLocation, err := wndc.BestLocation(ctx) - cl() + err := try("find best location", func() error { + ctx, cl := context.WithTimeout(context.Background(), args.timeout) + defer cl() + bestLocation, err := wndc.BestLocation(ctx) + if err == nil { + proxyHostname = bestLocation.Hostname + } + return err + }) if err != nil { - mainLogger.Critical("Unable to get best location endpoint: %v", err) return 13 } - proxyHostname = bestLocation.Hostname } else { proxyHostname = pickServer(serverList, args.location) if proxyHostname == "" { @@ -353,6 +362,7 @@ func pickServer(serverList wndclient.ServerList, location string) string { } var errColdInitForced = errors.New("cold init forced!") + func maybeLoadState(forceColdInit bool, filename string) (*wndclient.WndClientState, error) { if forceColdInit { return nil, errColdInitForced @@ -390,20 +400,24 @@ func saveState(filename string, state *wndclient.WndClientState) error { return err } -func coldInit(wndc *wndclient.WndClient, username, password, tfacode string, timeout time.Duration) error { +func coldInit(wndc *wndclient.WndClient, try func(string, func() error) error, username, password, tfacode string, timeout time.Duration) error { if username == "" || password == "" { return errors.New(`Please provide "-username" and "-password" command line arguments!`) } - ctx, cl := context.WithTimeout(context.Background(), timeout) - err := wndc.Session(ctx, username, password, tfacode) - cl() + err := try("init session", func() error { + ctx, cl := context.WithTimeout(context.Background(), timeout) + defer cl() + return wndc.Session(ctx, username, password, tfacode) + }) if err != nil { return fmt.Errorf("Session call failed: %w", err) } - ctx, cl = context.WithTimeout(context.Background(), timeout) - err = wndc.ServerCredentials(ctx) - cl() + err = try("get server credentials", func() error { + ctx, cl := context.WithTimeout(context.Background(), timeout) + defer cl() + return wndc.ServerCredentials(ctx) + }) if err != nil { return fmt.Errorf("ServerCredentials call failed: %w", err) } @@ -414,3 +428,24 @@ func coldInit(wndc *wndclient.WndClient, username, password, tfacode string, tim func main() { os.Exit(run()) } + +func retryPolicy(retries int, retryInterval time.Duration, logger *CondLogger) func(string, func() error) error { + return func(name string, f func() error) error { + var err error + for i := 1; retries <= 0 || i <= retries; i++ { + if i > 1 { + logger.Warning("Retrying action %q in %v...", name, retryInterval) + time.Sleep(retryInterval) + } + logger.Info("Attempting action %q, attempt #%d...", name, i) + err = f() + if err == nil { + logger.Info("Action %q succeeded on attempt #%d", name, i) + return nil + } + logger.Warning("Action %q failed: %v", name, err) + } + logger.Critical("All attempts for action %q have failed. Last error: %v", name, err) + return err + } +}