summaryrefslogtreecommitdiffstats
path: root/internal
diff options
context:
space:
mode:
authorsinner <[email protected]>2026-04-15 15:16:02 -0400
committersinner <[email protected]>2026-04-15 15:16:02 -0400
commita5f907854f29e1c267ad30d1dfe85c2c47f5ac48 (patch)
treebc8685c3b22e6d5d47702ba0607c694f938ba7fd /internal
parent8a1cf20dd5014ebe15ced77344902b79dcfa2e66 (diff)
downloaddborg-master.tar.gz
dborg-master.zip
feat: add stdin support and retry logic for all search commandsHEADv1.1.1v0.1.14master
Diffstat (limited to 'internal')
-rw-r--r--internal/client/client.go176
-rw-r--r--internal/client/retry_test.go247
-rw-r--r--internal/config/config.go1
-rw-r--r--internal/formatter/formatter.go6
-rw-r--r--internal/utils/tty.go19
-rw-r--r--internal/utils/version.go47
-rw-r--r--internal/utils/version_test.go24
7 files changed, 460 insertions, 60 deletions
diff --git a/internal/client/client.go b/internal/client/client.go
index 4aa6f06..4ca61e4 100644
--- a/internal/client/client.go
+++ b/internal/client/client.go
@@ -4,11 +4,16 @@ import (
"bytes"
"encoding/json"
"fmt"
- "git.db.org.ai/dborg/internal/config"
"io"
+ "math/rand"
"net/http"
"net/url"
+ "os"
+ "strconv"
+ "strings"
"time"
+
+ "git.db.org.ai/dborg/internal/config"
)
type Client struct {
@@ -20,93 +25,178 @@ func New(cfg *config.Config) (*Client, error) {
if err := cfg.Validate(); err != nil {
return nil, err
}
-
- return &Client{
- config: cfg,
- httpClient: &http.Client{
- Timeout: cfg.Timeout,
- },
- }, nil
+ return newClient(cfg), nil
}
func NewUnauthenticated(cfg *config.Config) (*Client, error) {
+ return newClient(cfg), nil
+}
+
+func newClient(cfg *config.Config) *Client {
return &Client{
config: cfg,
httpClient: &http.Client{
Timeout: cfg.Timeout,
},
- }, nil
+ }
+}
+
+func (c *Client) debugf(format string, args ...interface{}) {
+ if !c.config.Debug {
+ return
+ }
+ fmt.Fprintf(os.Stderr, "[dborg] "+format+"\n", args...)
+}
+
+func redactKey(key string) string {
+ if len(key) <= 8 {
+ return "***"
+ }
+ return key[:4] + "..." + key[len(key)-4:]
+}
+
+func isRetryable(statusCode int) bool {
+ return statusCode == http.StatusTooManyRequests ||
+ statusCode == http.StatusRequestTimeout ||
+ statusCode >= 500
+}
+
+func backoffDelay(attempt int, retryAfter string) time.Duration {
+ if retryAfter != "" {
+ if secs, err := strconv.Atoi(strings.TrimSpace(retryAfter)); err == nil && secs > 0 {
+ return time.Duration(secs) * time.Second
+ }
+ if t, err := http.ParseTime(retryAfter); err == nil {
+ if d := time.Until(t); d > 0 {
+ return d
+ }
+ }
+ }
+ base := time.Duration(1<<attempt) * time.Second
+ if base > 30*time.Second {
+ base = 30 * time.Second
+ }
+ jitter := time.Duration(rand.Int63n(int64(base) / 2))
+ return base + jitter
}
func (c *Client) doRequest(method, path string, params url.Values, body interface{}) ([]byte, error) {
fullURL := c.config.BaseURL + path
- if params != nil && len(params) > 0 {
+ if len(params) > 0 {
fullURL += "?" + params.Encode()
}
- var reqBody io.Reader
+ var bodyBytes []byte
if body != nil {
- jsonData, err := json.Marshal(body)
+ var err error
+ bodyBytes, err = json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
- reqBody = bytes.NewBuffer(jsonData)
- }
-
- req, err := http.NewRequest(method, fullURL, reqBody)
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
}
- req.Header.Set("X-API-Key", c.config.APIKey)
- req.Header.Set("User-Agent", c.config.UserAgent)
- if body != nil {
- req.Header.Set("Content-Type", "application/json")
+ c.debugf("→ %s %s (api_key=%s)", method, fullURL, redactKey(c.config.APIKey))
+ if len(bodyBytes) > 0 {
+ c.debugf(" body: %s", string(bodyBytes))
}
- var resp *http.Response
var lastErr error
for attempt := 0; attempt <= c.config.MaxRetries; attempt++ {
if attempt > 0 {
- time.Sleep(time.Duration(attempt) * time.Second)
+ delay := backoffDelay(attempt, "")
+ if lastErr != nil {
+ if ra, ok := retryAfterFromErr(lastErr); ok {
+ delay = backoffDelay(attempt, ra)
+ }
+ }
+ c.debugf(" retry %d/%d after %s (last error: %v)", attempt, c.config.MaxRetries, delay, lastErr)
+ time.Sleep(delay)
}
- resp, err = c.httpClient.Do(req)
+ var reqBody io.Reader
+ if bodyBytes != nil {
+ reqBody = bytes.NewReader(bodyBytes)
+ }
+
+ req, err := http.NewRequest(method, fullURL, reqBody)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ req.Header.Set("X-API-Key", c.config.APIKey)
+ req.Header.Set("User-Agent", c.config.UserAgent)
+ if bodyBytes != nil {
+ req.Header.Set("Content-Type", "application/json")
+ }
+
+ start := time.Now()
+ resp, err := c.httpClient.Do(req)
if err != nil {
lastErr = err
+ c.debugf("← network error after %s: %v", time.Since(start), err)
continue
}
- defer resp.Body.Close()
+ respBody, readErr := io.ReadAll(resp.Body)
+ resp.Body.Close()
+ c.debugf("← %d %s (%s, %d bytes)", resp.StatusCode, http.StatusText(resp.StatusCode), time.Since(start), len(respBody))
- if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated {
- return io.ReadAll(resp.Body)
+ if readErr != nil {
+ lastErr = fmt.Errorf("failed to read response body: %w", readErr)
+ if !isRetryable(resp.StatusCode) {
+ return nil, lastErr
+ }
+ continue
}
- bodyBytes, _ := io.ReadAll(resp.Body)
-
- switch resp.StatusCode {
- case http.StatusForbidden:
- lastErr = fmt.Errorf("access denied (403): %s - This endpoint requires premium access", string(bodyBytes))
- case http.StatusUnauthorized:
- lastErr = fmt.Errorf("unauthorized (401): %s - Check your API key", string(bodyBytes))
- case http.StatusTooManyRequests:
- lastErr = fmt.Errorf("rate limit exceeded (429): %s", string(bodyBytes))
- case http.StatusBadRequest:
- lastErr = fmt.Errorf("bad request (400): %s", string(bodyBytes))
- default:
- lastErr = fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
+ if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated {
+ return respBody, nil
}
- if resp.StatusCode != http.StatusTooManyRequests && resp.StatusCode < 500 {
- break
+ lastErr = httpError(resp.StatusCode, respBody, resp.Header.Get("Retry-After"))
+
+ if !isRetryable(resp.StatusCode) {
+ return nil, lastErr
}
}
return nil, lastErr
}
+type apiError struct {
+ status int
+ message string
+ retryAfter string
+}
+
+func (e *apiError) Error() string { return e.message }
+
+func retryAfterFromErr(err error) (string, bool) {
+ if ae, ok := err.(*apiError); ok && ae.retryAfter != "" {
+ return ae.retryAfter, true
+ }
+ return "", false
+}
+
+func httpError(status int, body []byte, retryAfter string) error {
+ msg := string(body)
+ var formatted string
+ switch status {
+ case http.StatusForbidden:
+ formatted = fmt.Sprintf("access denied (403): %s - This endpoint requires premium access", msg)
+ case http.StatusUnauthorized:
+ formatted = fmt.Sprintf("unauthorized (401): %s - Check your API key", msg)
+ case http.StatusTooManyRequests:
+ formatted = fmt.Sprintf("rate limit exceeded (429): %s", msg)
+ case http.StatusBadRequest:
+ formatted = fmt.Sprintf("bad request (400): %s", msg)
+ default:
+ formatted = fmt.Sprintf("API request failed with status %d: %s", status, msg)
+ }
+ return &apiError{status: status, message: formatted, retryAfter: retryAfter}
+}
+
func (c *Client) Get(path string, params url.Values) ([]byte, error) {
return c.doRequest(http.MethodGet, path, params, nil)
}
diff --git a/internal/client/retry_test.go b/internal/client/retry_test.go
new file mode 100644
index 0000000..532ddb7
--- /dev/null
+++ b/internal/client/retry_test.go
@@ -0,0 +1,247 @@
+package client
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "git.db.org.ai/dborg/internal/config"
+)
+
+func testConfig(baseURL string) *config.Config {
+ return &config.Config{
+ APIKey: "test-key",
+ BaseURL: baseURL,
+ Timeout: 5 * time.Second,
+ MaxRetries: 3,
+ UserAgent: "dborg-test",
+ }
+}
+
+func TestRedactKey(t *testing.T) {
+ tests := []struct {
+ in, want string
+ }{
+ {"", "***"},
+ {"short", "***"},
+ {"12345678", "***"},
+ {"abcd1234efgh5678", "abcd...5678"},
+ }
+ for _, tt := range tests {
+ if got := redactKey(tt.in); got != tt.want {
+ t.Errorf("redactKey(%q) = %q, want %q", tt.in, got, tt.want)
+ }
+ }
+}
+
+func TestIsRetryable(t *testing.T) {
+ retryable := []int{429, 408, 500, 502, 503, 504}
+ notRetryable := []int{200, 201, 301, 400, 401, 403, 404}
+
+ for _, code := range retryable {
+ if !isRetryable(code) {
+ t.Errorf("isRetryable(%d) = false, want true", code)
+ }
+ }
+ for _, code := range notRetryable {
+ if isRetryable(code) {
+ t.Errorf("isRetryable(%d) = true, want false", code)
+ }
+ }
+}
+
+func TestBackoffDelay_RetryAfterSeconds(t *testing.T) {
+ d := backoffDelay(1, "5")
+ if d != 5*time.Second {
+ t.Errorf("backoffDelay with Retry-After=5 = %v, want 5s", d)
+ }
+}
+
+func TestBackoffDelay_RetryAfterHTTPDate(t *testing.T) {
+ future := time.Now().Add(10 * time.Second).UTC().Format(http.TimeFormat)
+ d := backoffDelay(1, future)
+ if d < 5*time.Second || d > 11*time.Second {
+ t.Errorf("backoffDelay with HTTP date = %v, want ~10s", d)
+ }
+}
+
+func TestBackoffDelay_Exponential(t *testing.T) {
+ d1 := backoffDelay(1, "")
+ d2 := backoffDelay(2, "")
+ d3 := backoffDelay(3, "")
+ if d1 < 2*time.Second || d1 > 3*time.Second {
+ t.Errorf("attempt 1 delay = %v, want 2-3s", d1)
+ }
+ if d2 < 4*time.Second || d2 > 6*time.Second {
+ t.Errorf("attempt 2 delay = %v, want 4-6s", d2)
+ }
+ if d3 < 8*time.Second || d3 > 12*time.Second {
+ t.Errorf("attempt 3 delay = %v, want 8-12s", d3)
+ }
+}
+
+func TestDoRequest_Success(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if got := r.Header.Get("X-API-Key"); got != "test-key" {
+ t.Errorf("X-API-Key = %q, want test-key", got)
+ }
+ fmt.Fprintln(w, `{"ok":true}`)
+ }))
+ defer srv.Close()
+
+ c, _ := New(testConfig(srv.URL))
+ body, err := c.Get("/test", nil)
+ if err != nil {
+ t.Fatalf("Get() error = %v", err)
+ }
+ if !strings.Contains(string(body), `"ok":true`) {
+ t.Errorf("unexpected body: %s", body)
+ }
+}
+
+func TestDoRequest_RetriesOn500ThenSucceeds(t *testing.T) {
+ var calls int32
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ n := atomic.AddInt32(&calls, 1)
+ if n < 3 {
+ http.Error(w, "boom", http.StatusInternalServerError)
+ return
+ }
+ fmt.Fprintln(w, `{"ok":true}`)
+ }))
+ defer srv.Close()
+
+ cfg := testConfig(srv.URL)
+ cfg.MaxRetries = 5
+ c, _ := New(cfg)
+ body, err := c.Get("/test", nil)
+ if err != nil {
+ t.Fatalf("Get() error = %v", err)
+ }
+ if atomic.LoadInt32(&calls) != 3 {
+ t.Errorf("calls = %d, want 3", calls)
+ }
+ if !strings.Contains(string(body), `"ok":true`) {
+ t.Errorf("unexpected body: %s", body)
+ }
+}
+
+func TestDoRequest_NoRetryOn400(t *testing.T) {
+ var calls int32
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ atomic.AddInt32(&calls, 1)
+ http.Error(w, "bad", http.StatusBadRequest)
+ }))
+ defer srv.Close()
+
+ c, _ := New(testConfig(srv.URL))
+ _, err := c.Get("/test", nil)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if atomic.LoadInt32(&calls) != 1 {
+ t.Errorf("calls = %d, want 1 (no retries on 400)", calls)
+ }
+}
+
+func TestDoRequest_NoRetryOn401(t *testing.T) {
+ var calls int32
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ atomic.AddInt32(&calls, 1)
+ http.Error(w, "no auth", http.StatusUnauthorized)
+ }))
+ defer srv.Close()
+
+ c, _ := New(testConfig(srv.URL))
+ _, err := c.Get("/test", nil)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if atomic.LoadInt32(&calls) != 1 {
+ t.Errorf("calls = %d, want 1", calls)
+ }
+}
+
+func TestDoRequest_ExhaustsRetries(t *testing.T) {
+ var calls int32
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ atomic.AddInt32(&calls, 1)
+ http.Error(w, "down", http.StatusBadGateway)
+ }))
+ defer srv.Close()
+
+ cfg := testConfig(srv.URL)
+ cfg.MaxRetries = 2
+ c, _ := New(cfg)
+ _, err := c.Get("/test", nil)
+ if err == nil {
+ t.Fatal("expected error after exhausting retries")
+ }
+ if atomic.LoadInt32(&calls) != 3 {
+ t.Errorf("calls = %d, want 3 (initial + 2 retries)", calls)
+ }
+}
+
+func TestDoRequest_PostBodyResentOnRetry(t *testing.T) {
+ var calls int32
+ var bodies []string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ b, _ := io.ReadAll(r.Body)
+ bodies = append(bodies, string(b))
+ n := atomic.AddInt32(&calls, 1)
+ if n < 2 {
+ http.Error(w, "boom", http.StatusServiceUnavailable)
+ return
+ }
+ fmt.Fprintln(w, `{"ok":true}`)
+ }))
+ defer srv.Close()
+
+ c, _ := New(testConfig(srv.URL))
+ payload := map[string]string{"hello": "world"}
+ _, err := c.Post("/test", payload)
+ if err != nil {
+ t.Fatalf("Post() error = %v", err)
+ }
+ if len(bodies) != 2 {
+ t.Fatalf("got %d requests, want 2", len(bodies))
+ }
+ if bodies[0] != bodies[1] {
+ t.Errorf("body mismatch between retries:\n first: %q\n second: %q", bodies[0], bodies[1])
+ }
+ if !strings.Contains(bodies[0], `"hello":"world"`) {
+ t.Errorf("unexpected body: %q", bodies[0])
+ }
+}
+
+func TestDoRequest_RetryAfterHeaderHonored(t *testing.T) {
+ var calls int32
+ var firstAt, secondAt time.Time
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ n := atomic.AddInt32(&calls, 1)
+ if n == 1 {
+ firstAt = time.Now()
+ w.Header().Set("Retry-After", "1")
+ http.Error(w, "slow down", http.StatusTooManyRequests)
+ return
+ }
+ secondAt = time.Now()
+ fmt.Fprintln(w, `{"ok":true}`)
+ }))
+ defer srv.Close()
+
+ c, _ := New(testConfig(srv.URL))
+ _, err := c.Get("/test", nil)
+ if err != nil {
+ t.Fatalf("Get() error = %v", err)
+ }
+ gap := secondAt.Sub(firstAt)
+ if gap < 900*time.Millisecond {
+ t.Errorf("retry gap = %v, want >= 1s (Retry-After honored)", gap)
+ }
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index 9df03bf..0da7541 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -36,6 +36,7 @@ func New() *Config {
Timeout: 30 * time.Second,
MaxRetries: 3,
UserAgent: "dborg-cli/1.0",
+ Debug: os.Getenv("DBORG_DEBUG") != "",
}
}
diff --git a/internal/formatter/formatter.go b/internal/formatter/formatter.go
index bb36fa8..159f027 100644
--- a/internal/formatter/formatter.go
+++ b/internal/formatter/formatter.go
@@ -64,11 +64,7 @@ func (f *BaseFormatter) FormatJSON(data any) error {
}
func isTerminal() bool {
- fileInfo, err := os.Stdout.Stat()
- if err != nil {
- return false
- }
- return (fileInfo.Mode() & os.ModeCharDevice) != 0
+ return utils.IsTerminal()
}
func GetTerminalWidth() int {
diff --git a/internal/utils/tty.go b/internal/utils/tty.go
new file mode 100644
index 0000000..4ac367d
--- /dev/null
+++ b/internal/utils/tty.go
@@ -0,0 +1,19 @@
+package utils
+
+import (
+ "os"
+
+ "golang.org/x/term"
+)
+
+func IsTerminal() bool {
+ return term.IsTerminal(int(os.Stdout.Fd()))
+}
+
+func IsStderrTerminal() bool {
+ return term.IsTerminal(int(os.Stderr.Fd()))
+}
+
+func isTerminal() bool {
+ return IsTerminal()
+}
diff --git a/internal/utils/version.go b/internal/utils/version.go
index 46ebb73..ab72bfb 100644
--- a/internal/utils/version.go
+++ b/internal/utils/version.go
@@ -4,7 +4,9 @@ import (
"fmt"
"os"
"os/exec"
+ "path/filepath"
"runtime/debug"
+ "strconv"
"strings"
"syscall"
@@ -83,10 +85,21 @@ func isNewerVersion(remote, local string) bool {
localParts := strings.Split(local, ".")
for i := 0; i < len(remoteParts) && i < len(localParts); i++ {
- if remoteParts[i] > localParts[i] {
+ r, errR := strconv.Atoi(remoteParts[i])
+ l, errL := strconv.Atoi(localParts[i])
+ if errR != nil || errL != nil {
+ if remoteParts[i] > localParts[i] {
+ return true
+ }
+ if remoteParts[i] < localParts[i] {
+ return false
+ }
+ continue
+ }
+ if r > l {
return true
}
- if remoteParts[i] < localParts[i] {
+ if r < l {
return false
}
}
@@ -125,26 +138,36 @@ func promptAndUpdate(newVersion string) {
restartSelf()
}
+func resolveInstalledBinary() (string, error) {
+ if gobin := os.Getenv("GOBIN"); gobin != "" {
+ return filepath.Join(gobin, "dborg"), nil
+ }
+
+ gopath := os.Getenv("GOPATH")
+ if gopath == "" {
+ out, err := exec.Command("go", "env", "GOPATH").Output()
+ if err != nil {
+ return "", fmt.Errorf("could not determine GOPATH: %w", err)
+ }
+ gopath = strings.TrimSpace(string(out))
+ }
+
+ return filepath.Join(gopath, "bin", "dborg"), nil
+}
+
func restartSelf() {
- executable, err := os.Executable()
+ newBinary, err := resolveInstalledBinary()
if err != nil {
- fmt.Fprintf(os.Stderr, "Failed to get executable path: %v\n", err)
+ fmt.Fprintf(os.Stderr, "Failed to locate installed binary: %v\n", err)
os.Exit(1)
}
args := os.Args[1:]
- err = syscall.Exec(executable, append([]string{executable}, args...), os.Environ())
+ err = syscall.Exec(newBinary, append([]string{newBinary}, args...), os.Environ())
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to restart: %v\n", err)
os.Exit(1)
}
}
-func isTerminal() bool {
- fileInfo, err := os.Stdout.Stat()
- if err != nil {
- return false
- }
- return (fileInfo.Mode() & os.ModeCharDevice) != 0
-}
diff --git a/internal/utils/version_test.go b/internal/utils/version_test.go
index e3a27b1..9db181a 100644
--- a/internal/utils/version_test.go
+++ b/internal/utils/version_test.go
@@ -72,6 +72,30 @@ func TestIsNewerVersion(t *testing.T) {
local: "v0.5.0",
expected: true,
},
+ {
+ name: "double digit patch newer",
+ remote: "v1.0.12",
+ local: "v1.0.9",
+ expected: true,
+ },
+ {
+ name: "double digit patch older",
+ remote: "v1.0.9",
+ local: "v1.0.12",
+ expected: false,
+ },
+ {
+ name: "double digit patch same",
+ remote: "v1.0.12",
+ local: "v1.0.12",
+ expected: false,
+ },
+ {
+ name: "double digit minor newer",
+ remote: "v1.10.0",
+ local: "v1.9.0",
+ expected: true,
+ },
}
for _, tt := range tests {