diff --git a/main.go b/main.go index d3e4004..2379a6b 100644 --- a/main.go +++ b/main.go @@ -64,6 +64,8 @@ type CLIArgs struct { maxPause time.Duration backoffInitial time.Duration backoffDeadline time.Duration + initRetries int + initRetryInterval time.Duration hideSNI bool userAgent string } @@ -84,6 +86,8 @@ func parse_args() CLIArgs { flag.DurationVar(&args.rotate, "rotate", 1*time.Hour, "rotate user ID once per given period") flag.DurationVar(&args.backoffInitial, "backoff-initial", 3*time.Second, "initial average backoff delay for zgettunnels (randomized by +/-50%)") flag.DurationVar(&args.backoffDeadline, "backoff-deadline", 5*time.Minute, "total duration of zgettunnels method attempts") + flag.IntVar(&args.initRetries, "init-retries", 0, "number of attempts for initialization steps, zero for unlimited retry") + flag.DurationVar(&args.initRetryInterval, "init-retry-interval", 5*time.Second, "delay between initialization retries") flag.StringVar(&args.proxy_type, "proxy-type", "direct", "proxy type: direct or lum") // or skip but not mentioned // skip would be used something like this: `./bin/hola-proxy -proxy-type skip -force-port-field 24232 -country ua.peer` for debugging flag.StringVar(&args.resolver, "resolver", "https://cloudflare-dns.com/dns-query", @@ -180,28 +184,32 @@ func run() int { SetUserAgent(args.userAgent) + try := retryPolicy(args.initRetries, args.initRetryInterval, mainLogger) + if args.list_countries { - return print_countries(args.timeout) + return print_countries(try, args.timeout) } + mainLogger.Info("hola-proxy client version %s is starting...", version) if args.extVer == "" { - ctx, cl := context.WithTimeout(context.Background(), args.timeout) - defer cl() - extVer, err := GetExtVer(ctx, nil, HolaExtStoreID, dialer) + err := try("get latest version of browser extension", func() error { + ctx, cl := context.WithTimeout(context.Background(), args.timeout) + defer cl() + extVer, err := GetExtVer(ctx, nil, HolaExtStoreID, dialer) + args.extVer = extVer + return err + }) if err != nil { mainLogger.Critical("Can't detect latest API version. Try to specify -ext-ver parameter. Error: %v", err) return 8 } - args.extVer = extVer mainLogger.Warning("Detected latest extension version: %q. Pass -ext-ver parameter to skip resolve and speedup startup", args.extVer) - cl() } if args.list_proxies { - return print_proxies(mainLogger, args.extVer, args.country, args.proxy_type, args.limit, args.timeout, + return print_proxies(try, mainLogger, args.extVer, args.country, args.proxy_type, args.limit, args.timeout, args.backoffInitial, args.backoffDeadline) } - mainLogger.Info("hola-proxy client version %s is starting...", version) mainLogger.Info("Constructing fallback DNS upstream...") resolver, err := NewResolver(args.resolver, args.timeout) if err != nil { @@ -209,11 +217,16 @@ func run() int { return 6 } - mainLogger.Info("Initializing configuration provider...") - auth, tunnels, err := CredService(args.rotate, args.timeout, args.extVer, args.country, - args.proxy_type, credLogger, args.backoffInitial, args.backoffDeadline) + var ( + auth AuthProvider + tunnels *ZGetTunnelsResponse + ) + err = try("run credentials service", func() error { + auth, tunnels, err = CredService(args.rotate, args.timeout, args.extVer, args.country, + args.proxy_type, credLogger, args.backoffInitial, args.backoffDeadline) + return err + }) if err != nil { - mainLogger.Critical("Unable to instantiate credential service: %v", err) return 4 } endpoint, err := get_endpoint(tunnels, args.proxy_type, args.use_trial, args.force_port_field) @@ -236,3 +249,24 @@ func run() int { func main() { os.Exit(run()) } + +func retryPolicy(retries int, retryInterval time.Duration, logger *CondLogger) func(string, func() error) error { + return func(name string, f func() error) error { + var err error + for i := 1; retries <= 0 || i <= retries; i++ { + if i > 1 { + logger.Warning("Retrying action %q in %v...", name, retryInterval) + time.Sleep(retryInterval) + } + logger.Info("Attempting action %q, attempt #%d...", name, i) + err = f() + if err == nil { + logger.Info("Action %q succeeded on attempt #%d", name, i) + return nil + } + logger.Warning("Action %q failed: %v", name, err) + } + logger.Critical("All attempts for action %q have failed. Last error: %v", name, err) + return err + } +} diff --git a/utils.go b/utils.go index 1b05508..8cae1ca 100644 --- a/utils.go +++ b/utils.go @@ -106,27 +106,33 @@ func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer return } -func print_countries(timeout time.Duration) int { +func print_countries(try func(string, func() error) error, timeout time.Duration) int { var ( countries CountryList err error + tx_res bool + tx_err error ) - tx_res, tx_err := EnsureTransaction(context.Background(), timeout, func(ctx context.Context, client *http.Client) bool { - ctx1, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - countries, err = VPNCountries(ctx1, client) - if err != nil { - fmt.Fprintf(os.Stderr, "Transaction error: %v. Retrying with the fallback mechanism...\n", err) - return false + err = try("list VPN countries", func() error { + tx_res, tx_err = EnsureTransaction(context.Background(), timeout, func(ctx context.Context, client *http.Client) bool { + ctx1, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + countries, err = VPNCountries(ctx1, client) + if err != nil { + fmt.Fprintf(os.Stderr, "Transaction error: %v. Retrying with the fallback mechanism...\n", err) + return false + } + return true + }) + if tx_err != nil { + return fmt.Errorf("transaction recovery mechanism failure: %v", err) } - return true + if !tx_res { + return errors.New("all fallback proxies failed.") + } + return nil }) - if tx_err != nil { - fmt.Fprintf(os.Stderr, "Transaction recovery mechanism failure: %v.\n", tx_err) - return 4 - } - if !tx_res { - fmt.Fprintf(os.Stderr, "All attempts failed.") + if err != nil { return 3 } for _, code := range countries { @@ -135,28 +141,34 @@ func print_countries(timeout time.Duration) int { return 0 } -func print_proxies(logger *CondLogger, extVer, country string, proxy_type string, +func print_proxies(try func(string, func() error) error, logger *CondLogger, extVer, country string, proxy_type string, limit uint, timeout time.Duration, backoffInitial time.Duration, backoffDeadline time.Duration, ) int { var ( tunnels *ZGetTunnelsResponse user_uuid string err error + tx_res bool + tx_err error ) - tx_res, tx_err := EnsureTransaction(context.Background(), timeout, func(ctx context.Context, client *http.Client) bool { - tunnels, user_uuid, err = Tunnels(ctx, logger, client, extVer, country, proxy_type, limit, timeout, backoffInitial, backoffDeadline) - if err != nil { - fmt.Fprintf(os.Stderr, "Transaction error: %v. Retrying with the fallback mechanism...\n", err) - return false + err = try("list proxies", func() error { + tx_res, tx_err = EnsureTransaction(context.Background(), timeout, func(ctx context.Context, client *http.Client) bool { + tunnels, user_uuid, err = Tunnels(ctx, logger, client, extVer, country, proxy_type, limit, timeout, backoffInitial, backoffDeadline) + if err != nil { + fmt.Fprintf(os.Stderr, "Transaction error: %v. Retrying with the fallback mechanism...\n", err) + return false + } + return true + }) + if tx_err != nil { + return fmt.Errorf("transaction recovery mechanism failure: %v", err) } - return true + if !tx_res { + return errors.New("all fallback proxies failed.") + } + return nil }) - if tx_err != nil { - fmt.Fprintf(os.Stderr, "Transaction recovery mechanism failure: %v.\n", tx_err) - return 4 - } - if !tx_res { - fmt.Fprintf(os.Stderr, "All attempts failed.") + if err != nil { return 3 } wr := csv.NewWriter(os.Stdout)