From a5f907854f29e1c267ad30d1dfe85c2c47f5ac48 Mon Sep 17 00:00:00 2001 From: sinner Date: Wed, 15 Apr 2026 15:16:02 -0400 Subject: feat: add stdin support and retry logic for all search commands --- internal/client/client.go | 176 +++++++++++++++++++++------- internal/client/retry_test.go | 247 ++++++++++++++++++++++++++++++++++++++++ internal/config/config.go | 1 + internal/formatter/formatter.go | 6 +- internal/utils/tty.go | 19 ++++ internal/utils/version.go | 47 ++++++-- internal/utils/version_test.go | 24 ++++ 7 files changed, 460 insertions(+), 60 deletions(-) create mode 100644 internal/client/retry_test.go create mode 100644 internal/utils/tty.go (limited to 'internal') diff --git a/internal/client/client.go b/internal/client/client.go index 4aa6f06..4ca61e4 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -4,11 +4,16 @@ import ( "bytes" "encoding/json" "fmt" - "git.db.org.ai/dborg/internal/config" "io" + "math/rand" "net/http" "net/url" + "os" + "strconv" + "strings" "time" + + "git.db.org.ai/dborg/internal/config" ) type Client struct { @@ -20,93 +25,178 @@ func New(cfg *config.Config) (*Client, error) { if err := cfg.Validate(); err != nil { return nil, err } - - return &Client{ - config: cfg, - httpClient: &http.Client{ - Timeout: cfg.Timeout, - }, - }, nil + return newClient(cfg), nil } func NewUnauthenticated(cfg *config.Config) (*Client, error) { + return newClient(cfg), nil +} + +func newClient(cfg *config.Config) *Client { return &Client{ config: cfg, httpClient: &http.Client{ Timeout: cfg.Timeout, }, - }, nil + } +} + +func (c *Client) debugf(format string, args ...interface{}) { + if !c.config.Debug { + return + } + fmt.Fprintf(os.Stderr, "[dborg] "+format+"\n", args...) +} + +func redactKey(key string) string { + if len(key) <= 8 { + return "***" + } + return key[:4] + "..." + key[len(key)-4:] +} + +func isRetryable(statusCode int) bool { + return statusCode == http.StatusTooManyRequests || + statusCode == http.StatusRequestTimeout || + statusCode >= 500 +} + +func backoffDelay(attempt int, retryAfter string) time.Duration { + if retryAfter != "" { + if secs, err := strconv.Atoi(strings.TrimSpace(retryAfter)); err == nil && secs > 0 { + return time.Duration(secs) * time.Second + } + if t, err := http.ParseTime(retryAfter); err == nil { + if d := time.Until(t); d > 0 { + return d + } + } + } + base := time.Duration(1< 30*time.Second { + base = 30 * time.Second + } + jitter := time.Duration(rand.Int63n(int64(base) / 2)) + return base + jitter } func (c *Client) doRequest(method, path string, params url.Values, body interface{}) ([]byte, error) { fullURL := c.config.BaseURL + path - if params != nil && len(params) > 0 { + if len(params) > 0 { fullURL += "?" + params.Encode() } - var reqBody io.Reader + var bodyBytes []byte if body != nil { - jsonData, err := json.Marshal(body) + var err error + bodyBytes, err = json.Marshal(body) if err != nil { return nil, fmt.Errorf("failed to marshal request body: %w", err) } - reqBody = bytes.NewBuffer(jsonData) - } - - req, err := http.NewRequest(method, fullURL, reqBody) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("X-API-Key", c.config.APIKey) - req.Header.Set("User-Agent", c.config.UserAgent) - if body != nil { - req.Header.Set("Content-Type", "application/json") + c.debugf("→ %s %s (api_key=%s)", method, fullURL, redactKey(c.config.APIKey)) + if len(bodyBytes) > 0 { + c.debugf(" body: %s", string(bodyBytes)) } - var resp *http.Response var lastErr error for attempt := 0; attempt <= c.config.MaxRetries; attempt++ { if attempt > 0 { - time.Sleep(time.Duration(attempt) * time.Second) + delay := backoffDelay(attempt, "") + if lastErr != nil { + if ra, ok := retryAfterFromErr(lastErr); ok { + delay = backoffDelay(attempt, ra) + } + } + c.debugf(" retry %d/%d after %s (last error: %v)", attempt, c.config.MaxRetries, delay, lastErr) + time.Sleep(delay) + } + + var reqBody io.Reader + if bodyBytes != nil { + reqBody = bytes.NewReader(bodyBytes) } - resp, err = c.httpClient.Do(req) + req, err := http.NewRequest(method, fullURL, reqBody) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("X-API-Key", c.config.APIKey) + req.Header.Set("User-Agent", c.config.UserAgent) + if bodyBytes != nil { + req.Header.Set("Content-Type", "application/json") + } + + start := time.Now() + resp, err := c.httpClient.Do(req) if err != nil { lastErr = err + c.debugf("← network error after %s: %v", time.Since(start), err) continue } - defer resp.Body.Close() + respBody, readErr := io.ReadAll(resp.Body) + resp.Body.Close() + c.debugf("← %d %s (%s, %d bytes)", resp.StatusCode, http.StatusText(resp.StatusCode), time.Since(start), len(respBody)) - if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated { - return io.ReadAll(resp.Body) + if readErr != nil { + lastErr = fmt.Errorf("failed to read response body: %w", readErr) + if !isRetryable(resp.StatusCode) { + return nil, lastErr + } + continue } - bodyBytes, _ := io.ReadAll(resp.Body) - - switch resp.StatusCode { - case http.StatusForbidden: - lastErr = fmt.Errorf("access denied (403): %s - This endpoint requires premium access", string(bodyBytes)) - case http.StatusUnauthorized: - lastErr = fmt.Errorf("unauthorized (401): %s - Check your API key", string(bodyBytes)) - case http.StatusTooManyRequests: - lastErr = fmt.Errorf("rate limit exceeded (429): %s", string(bodyBytes)) - case http.StatusBadRequest: - lastErr = fmt.Errorf("bad request (400): %s", string(bodyBytes)) - default: - lastErr = fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) + if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated { + return respBody, nil } - if resp.StatusCode != http.StatusTooManyRequests && resp.StatusCode < 500 { - break + lastErr = httpError(resp.StatusCode, respBody, resp.Header.Get("Retry-After")) + + if !isRetryable(resp.StatusCode) { + return nil, lastErr } } return nil, lastErr } +type apiError struct { + status int + message string + retryAfter string +} + +func (e *apiError) Error() string { return e.message } + +func retryAfterFromErr(err error) (string, bool) { + if ae, ok := err.(*apiError); ok && ae.retryAfter != "" { + return ae.retryAfter, true + } + return "", false +} + +func httpError(status int, body []byte, retryAfter string) error { + msg := string(body) + var formatted string + switch status { + case http.StatusForbidden: + formatted = fmt.Sprintf("access denied (403): %s - This endpoint requires premium access", msg) + case http.StatusUnauthorized: + formatted = fmt.Sprintf("unauthorized (401): %s - Check your API key", msg) + case http.StatusTooManyRequests: + formatted = fmt.Sprintf("rate limit exceeded (429): %s", msg) + case http.StatusBadRequest: + formatted = fmt.Sprintf("bad request (400): %s", msg) + default: + formatted = fmt.Sprintf("API request failed with status %d: %s", status, msg) + } + return &apiError{status: status, message: formatted, retryAfter: retryAfter} +} + func (c *Client) Get(path string, params url.Values) ([]byte, error) { return c.doRequest(http.MethodGet, path, params, nil) } diff --git a/internal/client/retry_test.go b/internal/client/retry_test.go new file mode 100644 index 0000000..532ddb7 --- /dev/null +++ b/internal/client/retry_test.go @@ -0,0 +1,247 @@ +package client + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "git.db.org.ai/dborg/internal/config" +) + +func testConfig(baseURL string) *config.Config { + return &config.Config{ + APIKey: "test-key", + BaseURL: baseURL, + Timeout: 5 * time.Second, + MaxRetries: 3, + UserAgent: "dborg-test", + } +} + +func TestRedactKey(t *testing.T) { + tests := []struct { + in, want string + }{ + {"", "***"}, + {"short", "***"}, + {"12345678", "***"}, + {"abcd1234efgh5678", "abcd...5678"}, + } + for _, tt := range tests { + if got := redactKey(tt.in); got != tt.want { + t.Errorf("redactKey(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} + +func TestIsRetryable(t *testing.T) { + retryable := []int{429, 408, 500, 502, 503, 504} + notRetryable := []int{200, 201, 301, 400, 401, 403, 404} + + for _, code := range retryable { + if !isRetryable(code) { + t.Errorf("isRetryable(%d) = false, want true", code) + } + } + for _, code := range notRetryable { + if isRetryable(code) { + t.Errorf("isRetryable(%d) = true, want false", code) + } + } +} + +func TestBackoffDelay_RetryAfterSeconds(t *testing.T) { + d := backoffDelay(1, "5") + if d != 5*time.Second { + t.Errorf("backoffDelay with Retry-After=5 = %v, want 5s", d) + } +} + +func TestBackoffDelay_RetryAfterHTTPDate(t *testing.T) { + future := time.Now().Add(10 * time.Second).UTC().Format(http.TimeFormat) + d := backoffDelay(1, future) + if d < 5*time.Second || d > 11*time.Second { + t.Errorf("backoffDelay with HTTP date = %v, want ~10s", d) + } +} + +func TestBackoffDelay_Exponential(t *testing.T) { + d1 := backoffDelay(1, "") + d2 := backoffDelay(2, "") + d3 := backoffDelay(3, "") + if d1 < 2*time.Second || d1 > 3*time.Second { + t.Errorf("attempt 1 delay = %v, want 2-3s", d1) + } + if d2 < 4*time.Second || d2 > 6*time.Second { + t.Errorf("attempt 2 delay = %v, want 4-6s", d2) + } + if d3 < 8*time.Second || d3 > 12*time.Second { + t.Errorf("attempt 3 delay = %v, want 8-12s", d3) + } +} + +func TestDoRequest_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-API-Key"); got != "test-key" { + t.Errorf("X-API-Key = %q, want test-key", got) + } + fmt.Fprintln(w, `{"ok":true}`) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + body, err := c.Get("/test", nil) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if !strings.Contains(string(body), `"ok":true`) { + t.Errorf("unexpected body: %s", body) + } +} + +func TestDoRequest_RetriesOn500ThenSucceeds(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&calls, 1) + if n < 3 { + http.Error(w, "boom", http.StatusInternalServerError) + return + } + fmt.Fprintln(w, `{"ok":true}`) + })) + defer srv.Close() + + cfg := testConfig(srv.URL) + cfg.MaxRetries = 5 + c, _ := New(cfg) + body, err := c.Get("/test", nil) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if atomic.LoadInt32(&calls) != 3 { + t.Errorf("calls = %d, want 3", calls) + } + if !strings.Contains(string(body), `"ok":true`) { + t.Errorf("unexpected body: %s", body) + } +} + +func TestDoRequest_NoRetryOn400(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + http.Error(w, "bad", http.StatusBadRequest) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + _, err := c.Get("/test", nil) + if err == nil { + t.Fatal("expected error") + } + if atomic.LoadInt32(&calls) != 1 { + t.Errorf("calls = %d, want 1 (no retries on 400)", calls) + } +} + +func TestDoRequest_NoRetryOn401(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + http.Error(w, "no auth", http.StatusUnauthorized) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + _, err := c.Get("/test", nil) + if err == nil { + t.Fatal("expected error") + } + if atomic.LoadInt32(&calls) != 1 { + t.Errorf("calls = %d, want 1", calls) + } +} + +func TestDoRequest_ExhaustsRetries(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + http.Error(w, "down", http.StatusBadGateway) + })) + defer srv.Close() + + cfg := testConfig(srv.URL) + cfg.MaxRetries = 2 + c, _ := New(cfg) + _, err := c.Get("/test", nil) + if err == nil { + t.Fatal("expected error after exhausting retries") + } + if atomic.LoadInt32(&calls) != 3 { + t.Errorf("calls = %d, want 3 (initial + 2 retries)", calls) + } +} + +func TestDoRequest_PostBodyResentOnRetry(t *testing.T) { + var calls int32 + var bodies []string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := io.ReadAll(r.Body) + bodies = append(bodies, string(b)) + n := atomic.AddInt32(&calls, 1) + if n < 2 { + http.Error(w, "boom", http.StatusServiceUnavailable) + return + } + fmt.Fprintln(w, `{"ok":true}`) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + payload := map[string]string{"hello": "world"} + _, err := c.Post("/test", payload) + if err != nil { + t.Fatalf("Post() error = %v", err) + } + if len(bodies) != 2 { + t.Fatalf("got %d requests, want 2", len(bodies)) + } + if bodies[0] != bodies[1] { + t.Errorf("body mismatch between retries:\n first: %q\n second: %q", bodies[0], bodies[1]) + } + if !strings.Contains(bodies[0], `"hello":"world"`) { + t.Errorf("unexpected body: %q", bodies[0]) + } +} + +func TestDoRequest_RetryAfterHeaderHonored(t *testing.T) { + var calls int32 + var firstAt, secondAt time.Time + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&calls, 1) + if n == 1 { + firstAt = time.Now() + w.Header().Set("Retry-After", "1") + http.Error(w, "slow down", http.StatusTooManyRequests) + return + } + secondAt = time.Now() + fmt.Fprintln(w, `{"ok":true}`) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + _, err := c.Get("/test", nil) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + gap := secondAt.Sub(firstAt) + if gap < 900*time.Millisecond { + t.Errorf("retry gap = %v, want >= 1s (Retry-After honored)", gap) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 9df03bf..0da7541 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -36,6 +36,7 @@ func New() *Config { Timeout: 30 * time.Second, MaxRetries: 3, UserAgent: "dborg-cli/1.0", + Debug: os.Getenv("DBORG_DEBUG") != "", } } diff --git a/internal/formatter/formatter.go b/internal/formatter/formatter.go index bb36fa8..159f027 100644 --- a/internal/formatter/formatter.go +++ b/internal/formatter/formatter.go @@ -64,11 +64,7 @@ func (f *BaseFormatter) FormatJSON(data any) error { } func isTerminal() bool { - fileInfo, err := os.Stdout.Stat() - if err != nil { - return false - } - return (fileInfo.Mode() & os.ModeCharDevice) != 0 + return utils.IsTerminal() } func GetTerminalWidth() int { diff --git a/internal/utils/tty.go b/internal/utils/tty.go new file mode 100644 index 0000000..4ac367d --- /dev/null +++ b/internal/utils/tty.go @@ -0,0 +1,19 @@ +package utils + +import ( + "os" + + "golang.org/x/term" +) + +func IsTerminal() bool { + return term.IsTerminal(int(os.Stdout.Fd())) +} + +func IsStderrTerminal() bool { + return term.IsTerminal(int(os.Stderr.Fd())) +} + +func isTerminal() bool { + return IsTerminal() +} diff --git a/internal/utils/version.go b/internal/utils/version.go index 46ebb73..ab72bfb 100644 --- a/internal/utils/version.go +++ b/internal/utils/version.go @@ -4,7 +4,9 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "runtime/debug" + "strconv" "strings" "syscall" @@ -83,10 +85,21 @@ func isNewerVersion(remote, local string) bool { localParts := strings.Split(local, ".") for i := 0; i < len(remoteParts) && i < len(localParts); i++ { - if remoteParts[i] > localParts[i] { + r, errR := strconv.Atoi(remoteParts[i]) + l, errL := strconv.Atoi(localParts[i]) + if errR != nil || errL != nil { + if remoteParts[i] > localParts[i] { + return true + } + if remoteParts[i] < localParts[i] { + return false + } + continue + } + if r > l { return true } - if remoteParts[i] < localParts[i] { + if r < l { return false } } @@ -125,26 +138,36 @@ func promptAndUpdate(newVersion string) { restartSelf() } +func resolveInstalledBinary() (string, error) { + if gobin := os.Getenv("GOBIN"); gobin != "" { + return filepath.Join(gobin, "dborg"), nil + } + + gopath := os.Getenv("GOPATH") + if gopath == "" { + out, err := exec.Command("go", "env", "GOPATH").Output() + if err != nil { + return "", fmt.Errorf("could not determine GOPATH: %w", err) + } + gopath = strings.TrimSpace(string(out)) + } + + return filepath.Join(gopath, "bin", "dborg"), nil +} + func restartSelf() { - executable, err := os.Executable() + newBinary, err := resolveInstalledBinary() if err != nil { - fmt.Fprintf(os.Stderr, "Failed to get executable path: %v\n", err) + fmt.Fprintf(os.Stderr, "Failed to locate installed binary: %v\n", err) os.Exit(1) } args := os.Args[1:] - err = syscall.Exec(executable, append([]string{executable}, args...), os.Environ()) + err = syscall.Exec(newBinary, append([]string{newBinary}, args...), os.Environ()) if err != nil { fmt.Fprintf(os.Stderr, "Failed to restart: %v\n", err) os.Exit(1) } } -func isTerminal() bool { - fileInfo, err := os.Stdout.Stat() - if err != nil { - return false - } - return (fileInfo.Mode() & os.ModeCharDevice) != 0 -} diff --git a/internal/utils/version_test.go b/internal/utils/version_test.go index e3a27b1..9db181a 100644 --- a/internal/utils/version_test.go +++ b/internal/utils/version_test.go @@ -72,6 +72,30 @@ func TestIsNewerVersion(t *testing.T) { local: "v0.5.0", expected: true, }, + { + name: "double digit patch newer", + remote: "v1.0.12", + local: "v1.0.9", + expected: true, + }, + { + name: "double digit patch older", + remote: "v1.0.9", + local: "v1.0.12", + expected: false, + }, + { + name: "double digit patch same", + remote: "v1.0.12", + local: "v1.0.12", + expected: false, + }, + { + name: "double digit minor newer", + remote: "v1.10.0", + local: "v1.9.0", + expected: true, + }, } for _, tt := range tests { -- cgit v1.2.3