diff options
| author | sinner <[email protected]> | 2026-04-15 15:16:02 -0400 |
|---|---|---|
| committer | sinner <[email protected]> | 2026-04-15 15:16:02 -0400 |
| commit | a5f907854f29e1c267ad30d1dfe85c2c47f5ac48 (patch) | |
| tree | bc8685c3b22e6d5d47702ba0607c694f938ba7fd /internal/client | |
| parent | 8a1cf20dd5014ebe15ced77344902b79dcfa2e66 (diff) | |
| download | dborg-master.tar.gz dborg-master.zip | |
Diffstat (limited to 'internal/client')
| -rw-r--r-- | internal/client/client.go | 176 | ||||
| -rw-r--r-- | internal/client/retry_test.go | 247 |
2 files changed, 380 insertions, 43 deletions
diff --git a/internal/client/client.go b/internal/client/client.go index 4aa6f06..4ca61e4 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -4,11 +4,16 @@ import ( "bytes" "encoding/json" "fmt" - "git.db.org.ai/dborg/internal/config" "io" + "math/rand" "net/http" "net/url" + "os" + "strconv" + "strings" "time" + + "git.db.org.ai/dborg/internal/config" ) type Client struct { @@ -20,93 +25,178 @@ func New(cfg *config.Config) (*Client, error) { if err := cfg.Validate(); err != nil { return nil, err } - - return &Client{ - config: cfg, - httpClient: &http.Client{ - Timeout: cfg.Timeout, - }, - }, nil + return newClient(cfg), nil } func NewUnauthenticated(cfg *config.Config) (*Client, error) { + return newClient(cfg), nil +} + +func newClient(cfg *config.Config) *Client { return &Client{ config: cfg, httpClient: &http.Client{ Timeout: cfg.Timeout, }, - }, nil + } +} + +func (c *Client) debugf(format string, args ...interface{}) { + if !c.config.Debug { + return + } + fmt.Fprintf(os.Stderr, "[dborg] "+format+"\n", args...) +} + +func redactKey(key string) string { + if len(key) <= 8 { + return "***" + } + return key[:4] + "..." + key[len(key)-4:] +} + +func isRetryable(statusCode int) bool { + return statusCode == http.StatusTooManyRequests || + statusCode == http.StatusRequestTimeout || + statusCode >= 500 +} + +func backoffDelay(attempt int, retryAfter string) time.Duration { + if retryAfter != "" { + if secs, err := strconv.Atoi(strings.TrimSpace(retryAfter)); err == nil && secs > 0 { + return time.Duration(secs) * time.Second + } + if t, err := http.ParseTime(retryAfter); err == nil { + if d := time.Until(t); d > 0 { + return d + } + } + } + base := time.Duration(1<<attempt) * time.Second + if base > 30*time.Second { + base = 30 * time.Second + } + jitter := time.Duration(rand.Int63n(int64(base) / 2)) + return base + jitter } func (c *Client) doRequest(method, path string, params url.Values, body interface{}) ([]byte, error) { fullURL := c.config.BaseURL + path - if params != nil && len(params) > 0 { + if len(params) > 0 { fullURL += "?" + params.Encode() } - var reqBody io.Reader + var bodyBytes []byte if body != nil { - jsonData, err := json.Marshal(body) + var err error + bodyBytes, err = json.Marshal(body) if err != nil { return nil, fmt.Errorf("failed to marshal request body: %w", err) } - reqBody = bytes.NewBuffer(jsonData) - } - - req, err := http.NewRequest(method, fullURL, reqBody) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("X-API-Key", c.config.APIKey) - req.Header.Set("User-Agent", c.config.UserAgent) - if body != nil { - req.Header.Set("Content-Type", "application/json") + c.debugf("→ %s %s (api_key=%s)", method, fullURL, redactKey(c.config.APIKey)) + if len(bodyBytes) > 0 { + c.debugf(" body: %s", string(bodyBytes)) } - var resp *http.Response var lastErr error for attempt := 0; attempt <= c.config.MaxRetries; attempt++ { if attempt > 0 { - time.Sleep(time.Duration(attempt) * time.Second) + delay := backoffDelay(attempt, "") + if lastErr != nil { + if ra, ok := retryAfterFromErr(lastErr); ok { + delay = backoffDelay(attempt, ra) + } + } + c.debugf(" retry %d/%d after %s (last error: %v)", attempt, c.config.MaxRetries, delay, lastErr) + time.Sleep(delay) } - resp, err = c.httpClient.Do(req) + var reqBody io.Reader + if bodyBytes != nil { + reqBody = bytes.NewReader(bodyBytes) + } + + req, err := http.NewRequest(method, fullURL, reqBody) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("X-API-Key", c.config.APIKey) + req.Header.Set("User-Agent", c.config.UserAgent) + if bodyBytes != nil { + req.Header.Set("Content-Type", "application/json") + } + + start := time.Now() + resp, err := c.httpClient.Do(req) if err != nil { lastErr = err + c.debugf("← network error after %s: %v", time.Since(start), err) continue } - defer resp.Body.Close() + respBody, readErr := io.ReadAll(resp.Body) + resp.Body.Close() + c.debugf("← %d %s (%s, %d bytes)", resp.StatusCode, http.StatusText(resp.StatusCode), time.Since(start), len(respBody)) - if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated { - return io.ReadAll(resp.Body) + if readErr != nil { + lastErr = fmt.Errorf("failed to read response body: %w", readErr) + if !isRetryable(resp.StatusCode) { + return nil, lastErr + } + continue } - bodyBytes, _ := io.ReadAll(resp.Body) - - switch resp.StatusCode { - case http.StatusForbidden: - lastErr = fmt.Errorf("access denied (403): %s - This endpoint requires premium access", string(bodyBytes)) - case http.StatusUnauthorized: - lastErr = fmt.Errorf("unauthorized (401): %s - Check your API key", string(bodyBytes)) - case http.StatusTooManyRequests: - lastErr = fmt.Errorf("rate limit exceeded (429): %s", string(bodyBytes)) - case http.StatusBadRequest: - lastErr = fmt.Errorf("bad request (400): %s", string(bodyBytes)) - default: - lastErr = fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) + if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated { + return respBody, nil } - if resp.StatusCode != http.StatusTooManyRequests && resp.StatusCode < 500 { - break + lastErr = httpError(resp.StatusCode, respBody, resp.Header.Get("Retry-After")) + + if !isRetryable(resp.StatusCode) { + return nil, lastErr } } return nil, lastErr } +type apiError struct { + status int + message string + retryAfter string +} + +func (e *apiError) Error() string { return e.message } + +func retryAfterFromErr(err error) (string, bool) { + if ae, ok := err.(*apiError); ok && ae.retryAfter != "" { + return ae.retryAfter, true + } + return "", false +} + +func httpError(status int, body []byte, retryAfter string) error { + msg := string(body) + var formatted string + switch status { + case http.StatusForbidden: + formatted = fmt.Sprintf("access denied (403): %s - This endpoint requires premium access", msg) + case http.StatusUnauthorized: + formatted = fmt.Sprintf("unauthorized (401): %s - Check your API key", msg) + case http.StatusTooManyRequests: + formatted = fmt.Sprintf("rate limit exceeded (429): %s", msg) + case http.StatusBadRequest: + formatted = fmt.Sprintf("bad request (400): %s", msg) + default: + formatted = fmt.Sprintf("API request failed with status %d: %s", status, msg) + } + return &apiError{status: status, message: formatted, retryAfter: retryAfter} +} + func (c *Client) Get(path string, params url.Values) ([]byte, error) { return c.doRequest(http.MethodGet, path, params, nil) } diff --git a/internal/client/retry_test.go b/internal/client/retry_test.go new file mode 100644 index 0000000..532ddb7 --- /dev/null +++ b/internal/client/retry_test.go @@ -0,0 +1,247 @@ +package client + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "git.db.org.ai/dborg/internal/config" +) + +func testConfig(baseURL string) *config.Config { + return &config.Config{ + APIKey: "test-key", + BaseURL: baseURL, + Timeout: 5 * time.Second, + MaxRetries: 3, + UserAgent: "dborg-test", + } +} + +func TestRedactKey(t *testing.T) { + tests := []struct { + in, want string + }{ + {"", "***"}, + {"short", "***"}, + {"12345678", "***"}, + {"abcd1234efgh5678", "abcd...5678"}, + } + for _, tt := range tests { + if got := redactKey(tt.in); got != tt.want { + t.Errorf("redactKey(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} + +func TestIsRetryable(t *testing.T) { + retryable := []int{429, 408, 500, 502, 503, 504} + notRetryable := []int{200, 201, 301, 400, 401, 403, 404} + + for _, code := range retryable { + if !isRetryable(code) { + t.Errorf("isRetryable(%d) = false, want true", code) + } + } + for _, code := range notRetryable { + if isRetryable(code) { + t.Errorf("isRetryable(%d) = true, want false", code) + } + } +} + +func TestBackoffDelay_RetryAfterSeconds(t *testing.T) { + d := backoffDelay(1, "5") + if d != 5*time.Second { + t.Errorf("backoffDelay with Retry-After=5 = %v, want 5s", d) + } +} + +func TestBackoffDelay_RetryAfterHTTPDate(t *testing.T) { + future := time.Now().Add(10 * time.Second).UTC().Format(http.TimeFormat) + d := backoffDelay(1, future) + if d < 5*time.Second || d > 11*time.Second { + t.Errorf("backoffDelay with HTTP date = %v, want ~10s", d) + } +} + +func TestBackoffDelay_Exponential(t *testing.T) { + d1 := backoffDelay(1, "") + d2 := backoffDelay(2, "") + d3 := backoffDelay(3, "") + if d1 < 2*time.Second || d1 > 3*time.Second { + t.Errorf("attempt 1 delay = %v, want 2-3s", d1) + } + if d2 < 4*time.Second || d2 > 6*time.Second { + t.Errorf("attempt 2 delay = %v, want 4-6s", d2) + } + if d3 < 8*time.Second || d3 > 12*time.Second { + t.Errorf("attempt 3 delay = %v, want 8-12s", d3) + } +} + +func TestDoRequest_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-API-Key"); got != "test-key" { + t.Errorf("X-API-Key = %q, want test-key", got) + } + fmt.Fprintln(w, `{"ok":true}`) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + body, err := c.Get("/test", nil) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if !strings.Contains(string(body), `"ok":true`) { + t.Errorf("unexpected body: %s", body) + } +} + +func TestDoRequest_RetriesOn500ThenSucceeds(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&calls, 1) + if n < 3 { + http.Error(w, "boom", http.StatusInternalServerError) + return + } + fmt.Fprintln(w, `{"ok":true}`) + })) + defer srv.Close() + + cfg := testConfig(srv.URL) + cfg.MaxRetries = 5 + c, _ := New(cfg) + body, err := c.Get("/test", nil) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if atomic.LoadInt32(&calls) != 3 { + t.Errorf("calls = %d, want 3", calls) + } + if !strings.Contains(string(body), `"ok":true`) { + t.Errorf("unexpected body: %s", body) + } +} + +func TestDoRequest_NoRetryOn400(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + http.Error(w, "bad", http.StatusBadRequest) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + _, err := c.Get("/test", nil) + if err == nil { + t.Fatal("expected error") + } + if atomic.LoadInt32(&calls) != 1 { + t.Errorf("calls = %d, want 1 (no retries on 400)", calls) + } +} + +func TestDoRequest_NoRetryOn401(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + http.Error(w, "no auth", http.StatusUnauthorized) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + _, err := c.Get("/test", nil) + if err == nil { + t.Fatal("expected error") + } + if atomic.LoadInt32(&calls) != 1 { + t.Errorf("calls = %d, want 1", calls) + } +} + +func TestDoRequest_ExhaustsRetries(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + http.Error(w, "down", http.StatusBadGateway) + })) + defer srv.Close() + + cfg := testConfig(srv.URL) + cfg.MaxRetries = 2 + c, _ := New(cfg) + _, err := c.Get("/test", nil) + if err == nil { + t.Fatal("expected error after exhausting retries") + } + if atomic.LoadInt32(&calls) != 3 { + t.Errorf("calls = %d, want 3 (initial + 2 retries)", calls) + } +} + +func TestDoRequest_PostBodyResentOnRetry(t *testing.T) { + var calls int32 + var bodies []string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := io.ReadAll(r.Body) + bodies = append(bodies, string(b)) + n := atomic.AddInt32(&calls, 1) + if n < 2 { + http.Error(w, "boom", http.StatusServiceUnavailable) + return + } + fmt.Fprintln(w, `{"ok":true}`) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + payload := map[string]string{"hello": "world"} + _, err := c.Post("/test", payload) + if err != nil { + t.Fatalf("Post() error = %v", err) + } + if len(bodies) != 2 { + t.Fatalf("got %d requests, want 2", len(bodies)) + } + if bodies[0] != bodies[1] { + t.Errorf("body mismatch between retries:\n first: %q\n second: %q", bodies[0], bodies[1]) + } + if !strings.Contains(bodies[0], `"hello":"world"`) { + t.Errorf("unexpected body: %q", bodies[0]) + } +} + +func TestDoRequest_RetryAfterHeaderHonored(t *testing.T) { + var calls int32 + var firstAt, secondAt time.Time + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&calls, 1) + if n == 1 { + firstAt = time.Now() + w.Header().Set("Retry-After", "1") + http.Error(w, "slow down", http.StatusTooManyRequests) + return + } + secondAt = time.Now() + fmt.Fprintln(w, `{"ok":true}`) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + _, err := c.Get("/test", nil) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + gap := secondAt.Sub(firstAt) + if gap < 900*time.Millisecond { + t.Errorf("retry gap = %v, want >= 1s (Retry-After honored)", gap) + } +} |
