diff options
| author | sinner <[email protected]> | 2026-04-15 15:16:02 -0400 |
|---|---|---|
| committer | sinner <[email protected]> | 2026-04-15 15:16:02 -0400 |
| commit | a5f907854f29e1c267ad30d1dfe85c2c47f5ac48 (patch) | |
| tree | bc8685c3b22e6d5d47702ba0607c694f938ba7fd | |
| parent | 8a1cf20dd5014ebe15ced77344902b79dcfa2e66 (diff) | |
| download | dborg-1.1.1.tar.gz dborg-1.1.1.zip | |
| -rw-r--r-- | cmd/breachforum.go | 24 | ||||
| -rw-r--r-- | cmd/bssid.go | 30 | ||||
| -rw-r--r-- | cmd/crawl.go | 16 | ||||
| -rw-r--r-- | cmd/dns.go | 77 | ||||
| -rw-r--r-- | cmd/email.go | 24 | ||||
| -rw-r--r-- | cmd/files.go | 36 | ||||
| -rw-r--r-- | cmd/moon.go | 68 | ||||
| -rw-r--r-- | cmd/root.go | 2 | ||||
| -rw-r--r-- | cmd/sl.go | 13 | ||||
| -rw-r--r-- | cmd/stdin.go | 98 | ||||
| -rw-r--r-- | cmd/telegram.go | 24 | ||||
| -rw-r--r-- | cmd/username.go | 45 | ||||
| -rw-r--r-- | internal/client/client.go | 176 | ||||
| -rw-r--r-- | internal/client/retry_test.go | 247 | ||||
| -rw-r--r-- | internal/config/config.go | 1 | ||||
| -rw-r--r-- | internal/formatter/formatter.go | 6 | ||||
| -rw-r--r-- | internal/utils/tty.go | 19 | ||||
| -rw-r--r-- | internal/utils/version.go | 47 | ||||
| -rw-r--r-- | internal/utils/version_test.go | 24 |
19 files changed, 743 insertions, 234 deletions
diff --git a/cmd/breachforum.go b/cmd/breachforum.go index 63254c1..229292a 100644 --- a/cmd/breachforum.go +++ b/cmd/breachforum.go @@ -11,7 +11,7 @@ var breachforumCmd = &cobra.Command{ Aliases: []string{"brf"}, Short: "Search BreachForum data", Long: `Search breachdetect index for BreachForum messages and detections`, - Args: cobra.ExactArgs(1), + Args: argsOrStdin(1), RunE: runBreachForumSearch, } @@ -26,15 +26,17 @@ func runBreachForumSearch(cmd *cobra.Command, args []string) error { return err } - params := &models.BreachForumSearchParams{ - Search: args[0], - } - params.MaxHits, _ = cmd.Flags().GetInt("max_hits") - - response, err := c.SearchBreachForum(params) - if err != nil { - return err - } + maxHits, _ := cmd.Flags().GetInt("max_hits") - return formatter.FormatBreachForumResults(response, IsJSONOutput()) + return forEachQuery(args, func(query string) error { + params := &models.BreachForumSearchParams{ + Search: query, + MaxHits: maxHits, + } + response, err := c.SearchBreachForum(params) + if err != nil { + return err + } + return formatter.FormatBreachForumResults(response, IsJSONOutput()) + }) } diff --git a/cmd/bssid.go b/cmd/bssid.go index 0cf751f..116c63a 100644 --- a/cmd/bssid.go +++ b/cmd/bssid.go @@ -11,7 +11,7 @@ var bssidCmd = &cobra.Command{ Aliases: []string{"bs"}, Short: "Lookup WiFi access point location by BSSID", Long: `Lookup geographic location of a WiFi access point by its BSSID (MAC address) using Apple's location services`, - Args: cobra.ExactArgs(1), + Args: argsOrStdin(1), RunE: runBSSIDLookup, } @@ -28,17 +28,21 @@ func runBSSIDLookup(cmd *cobra.Command, args []string) error { return err } - params := &models.BSSIDParams{ - BSSID: args[0], - } - params.All, _ = cmd.Flags().GetBool("all") - params.Google, _ = cmd.Flags().GetBool("google") - params.OSM, _ = cmd.Flags().GetBool("osm") - - response, err := c.LookupBSSID(params) - if err != nil { - return err - } + all, _ := cmd.Flags().GetBool("all") + google, _ := cmd.Flags().GetBool("google") + osm, _ := cmd.Flags().GetBool("osm") - return formatter.FormatBSSIDResults(*response, IsJSONOutput()) + return forEachQuery(args, func(bssid string) error { + params := &models.BSSIDParams{ + BSSID: bssid, + All: all, + Google: google, + OSM: osm, + } + response, err := c.LookupBSSID(params) + if err != nil { + return err + } + return formatter.FormatBSSIDResults(*response, IsJSONOutput()) + }) } diff --git a/cmd/crawl.go b/cmd/crawl.go index 2c9b719..5c5f78f 100644 --- a/cmd/crawl.go +++ b/cmd/crawl.go @@ -11,7 +11,7 @@ var crawlCmd = &cobra.Command{ Aliases: []string{"cw"}, Short: "Crawl domain", Long: `Resolves a domain using httpx and crawls it using katana. Returns discovered links as plain text, one per line, streamed in real-time. Supports both http:// and https:// URLs.`, - Args: cobra.ExactArgs(1), + Args: argsOrStdin(1), RunE: runCrawl, } @@ -28,14 +28,10 @@ func runCrawl(cmd *cobra.Command, args []string) error { return err } - err = c.CrawlDomain(args[0], subdomains, func(line string) error { - fmt.Println(line) - return nil + return forEachQuery(args, func(domain string) error { + return c.CrawlDomain(domain, subdomains, func(line string) error { + fmt.Println(line) + return nil + }) }) - - if err != nil { - return err - } - - return nil } @@ -20,7 +20,7 @@ var dnsTLDCmd = &cobra.Command{ Aliases: []string{"t"}, Short: "Check NXDOMAIN for custom term against all TLDs", Long: "Streams NDJSON results checking each TLD. For NXDOMAIN domains, returns status. For existing domains, runs httpx to get page title and tech stack.", - Args: cobra.ExactArgs(1), + Args: argsOrStdin(1), RunE: runDNSTLDCheck, } @@ -29,12 +29,11 @@ var dnsSiteCmd = &cobra.Command{ Aliases: []string{"s"}, Short: "Check if a website URL has been reused", Long: "Checks if a website URL has been reused across different domains", - Args: cobra.ExactArgs(1), + Args: argsOrStdin(1), RunE: runDNSSiteCheck, } func runDNSTLDCheck(cmd *cobra.Command, args []string) error { - term := args[0] showOnly, _ := cmd.Flags().GetString("show-only") c, err := newUnauthenticatedClient() @@ -42,58 +41,54 @@ func runDNSTLDCheck(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to create client: %w", err) } - params := &models.DNSTLDParams{ - Term: term, - ShowOnly: showOnly, - } - - fmt.Printf("Checking TLDs for term: %s\n\n", term) - - err = c.CheckDNSTLDStream(params, func(result json.RawMessage) error { - var domainResult models.DomainResult - if err := json.Unmarshal(result, &domainResult); err != nil { - return fmt.Errorf("failed to parse result: %w", err) + return forEachQuery(args, func(term string) error { + params := &models.DNSTLDParams{ + Term: term, + ShowOnly: showOnly, } - output, err := formatter.FormatDNSResults(&domainResult, IsJSONOutput()) + fmt.Printf("Checking TLDs for term: %s\n\n", term) + + err := c.CheckDNSTLDStream(params, func(result json.RawMessage) error { + var domainResult models.DomainResult + if err := json.Unmarshal(result, &domainResult); err != nil { + return fmt.Errorf("failed to parse result: %w", err) + } + output, err := formatter.FormatDNSResults(&domainResult, IsJSONOutput()) + if err != nil { + return err + } + printOutput(output) + return nil + }) if err != nil { - return err + return fmt.Errorf("TLD check failed: %w", err) } - printOutput(output) return nil }) - - if err != nil { - return fmt.Errorf("TLD check failed: %w", err) - } - - return nil } func runDNSSiteCheck(cmd *cobra.Command, args []string) error { - siteURL := args[0] - c, err := newClient() if err != nil { return fmt.Errorf("failed to create client: %w", err) } - response, err := c.CheckDNSSite(siteURL) - if err != nil { - return fmt.Errorf("site check failed: %w", err) - } - - if err := checkError(response.Error); err != nil { - return err - } - - output, err := formatter.FormatDNSSite(response, IsJSONOutput()) - if err != nil { - return err - } - - printOutput(output) - return nil + return forEachQuery(args, func(siteURL string) error { + response, err := c.CheckDNSSite(siteURL) + if err != nil { + return fmt.Errorf("site check failed: %w", err) + } + if err := checkError(response.Error); err != nil { + return err + } + output, err := formatter.FormatDNSSite(response, IsJSONOutput()) + if err != nil { + return err + } + printOutput(output) + return nil + }) } func init() { diff --git a/cmd/email.go b/cmd/email.go index 85cc686..871c272 100644 --- a/cmd/email.go +++ b/cmd/email.go @@ -17,7 +17,7 @@ var verifyEmailCmd = &cobra.Command{ Aliases: []string{"v"}, Short: "Verify email address", Long: `Performs comprehensive email verification including format validation, MX records check, SMTP verification, and disposable/webmail detection`, - Args: cobra.ExactArgs(1), + Args: argsOrStdin(1), RunE: runVerifyEmail, } @@ -27,21 +27,19 @@ func init() { } func runVerifyEmail(cmd *cobra.Command, args []string) error { - email := args[0] - c, err := newClient() if err != nil { return err } - response, err := c.VerifyEmail(email) - if err != nil { - return err - } - - if err := checkError(response.Error); err != nil { - return err - } - - return formatter.FormatEmailResults(response, IsJSONOutput()) + return forEachQuery(args, func(email string) error { + response, err := c.VerifyEmail(email) + if err != nil { + return err + } + if err := checkError(response.Error); err != nil { + return err + } + return formatter.FormatEmailResults(response, IsJSONOutput()) + }) } diff --git a/cmd/files.go b/cmd/files.go index aee7d14..6b608ca 100644 --- a/cmd/files.go +++ b/cmd/files.go @@ -11,7 +11,7 @@ var filesCmd = &cobra.Command{ Aliases: []string{"f"}, Short: "Search open directory files", Long: `Search for files in open directories using various filters (free OSINT endpoint)`, - Args: cobra.ExactArgs(1), + Args: argsOrStdin(1), RunE: runFilesSearch, } @@ -30,19 +30,25 @@ func runFilesSearch(cmd *cobra.Command, args []string) error { return err } - params := &models.OpenDirectorySearchParams{ - URL: args[0], - } - params.Filename, _ = cmd.Flags().GetString("filename") - params.Extension, _ = cmd.Flags().GetString("extension") - params.Exclude, _ = cmd.Flags().GetString("exclude") - params.Size, _ = cmd.Flags().GetInt("size") - params.From, _ = cmd.Flags().GetInt("from") - - response, err := c.SearchOpenDirectoryFiles(params) - if err != nil { - return err - } + filename, _ := cmd.Flags().GetString("filename") + extension, _ := cmd.Flags().GetString("extension") + exclude, _ := cmd.Flags().GetString("exclude") + size, _ := cmd.Flags().GetInt("size") + from, _ := cmd.Flags().GetInt("from") - return formatter.FormatFilesResults(*response, IsJSONOutput()) + return forEachQuery(args, func(url string) error { + params := &models.OpenDirectorySearchParams{ + URL: url, + Filename: filename, + Extension: extension, + Exclude: exclude, + Size: size, + From: from, + } + response, err := c.SearchOpenDirectoryFiles(params) + if err != nil { + return err + } + return formatter.FormatFilesResults(*response, IsJSONOutput()) + }) } diff --git a/cmd/moon.go b/cmd/moon.go index 931f64c..f749435 100644 --- a/cmd/moon.go +++ b/cmd/moon.go @@ -13,7 +13,7 @@ var moonCmd = &cobra.Command{ Aliases: []string{"mn"}, Short: "Search moon logs", Long: `Search moon logs with various filters. Requires admin API key.`, - Args: cobra.ExactArgs(1), + Args: argsOrStdin(1), RunE: runMoon, } @@ -36,40 +36,46 @@ func runMoon(cmd *cobra.Command, args []string) error { return err } - params := &models.MoonParams{ - Query: args[0], - } - params.Filename, _ = cmd.Flags().GetString("filename") - params.MaxHits, _ = cmd.Flags().GetInt("max_hits") + filename, _ := cmd.Flags().GetString("filename") + maxHits, _ := cmd.Flags().GetInt("max_hits") sortBy, _ := cmd.Flags().GetString("sort_by") if sortBy != "" && sortBy != "ingest_timestamp" && sortBy != "date_posted" { return fmt.Errorf("invalid sort_by value: must be 'ingest_timestamp' or 'date_posted'") } - params.SortBy = sortBy - params.IngestStartDate, _ = cmd.Flags().GetString("ingest_start_date") - params.IngestEndDate, _ = cmd.Flags().GetString("ingest_end_date") - params.PostedStartDate, _ = cmd.Flags().GetString("posted_start_date") - params.PostedEndDate, _ = cmd.Flags().GetString("posted_end_date") - params.Format, _ = cmd.Flags().GetString("format") + ingestStart, _ := cmd.Flags().GetString("ingest_start_date") + ingestEnd, _ := cmd.Flags().GetString("ingest_end_date") + postedStart, _ := cmd.Flags().GetString("posted_start_date") + postedEnd, _ := cmd.Flags().GetString("posted_end_date") + format, _ := cmd.Flags().GetString("format") - response, err := c.SearchMoonLogs(params) - if err != nil { - return err - } - - if err := checkError(response.Error); err != nil { - return err - } - - if params.Format != "json" { - fmt.Println(response.Message) + return forEachQuery(args, func(query string) error { + params := &models.MoonParams{ + Query: query, + Filename: filename, + MaxHits: maxHits, + SortBy: sortBy, + IngestStartDate: ingestStart, + IngestEndDate: ingestEnd, + PostedStartDate: postedStart, + PostedEndDate: postedEnd, + Format: format, + } + response, err := c.SearchMoonLogs(params) + if err != nil { + return err + } + if err := checkError(response.Error); err != nil { + return err + } + if params.Format != "json" { + fmt.Println(response.Message) + return nil + } + output, err := formatter.FormatMoonResults(response, IsJSONOutput()) + if err != nil { + return err + } + printOutput(output) return nil - } - - output, err := formatter.FormatMoonResults(response, IsJSONOutput()) - if err != nil { - return err - } - printOutput(output) - return nil + }) } diff --git a/cmd/root.go b/cmd/root.go index 292e96d..a5d407f 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -23,6 +23,8 @@ DB.org.ai CLI client`, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { return utils.CheckForUpdates(cmd) }, + SilenceUsage: true, + SilenceErrors: true, } func Execute() { @@ -3,6 +3,7 @@ package cmd import ( "fmt" + "git.db.org.ai/dborg/internal/client" "git.db.org.ai/dborg/internal/formatter" "git.db.org.ai/dborg/internal/models" "github.com/spf13/cobra" @@ -11,8 +12,8 @@ import ( var slCmd = &cobra.Command{ Use: "sl [query]", Short: "Search stealer logs", - Long: `Search stealer logs with various filters`, - Args: cobra.ExactArgs(1), + Long: `Search stealer logs with various filters. Accepts a query arg or newline-delimited queries on stdin.`, + Args: argsOrStdin(1), RunE: runSLSearch, } @@ -34,8 +35,14 @@ func runSLSearch(cmd *cobra.Command, args []string) error { return err } + return forEachQuery(args, func(query string) error { + return runSLSearchOne(cmd, c, query) + }) +} + +func runSLSearchOne(cmd *cobra.Command, c *client.Client, query string) error { params := &models.SLParams{ - Query: args[0], + Query: query, } params.Filename, _ = cmd.Flags().GetString("filename") params.MaxHits, _ = cmd.Flags().GetInt("max_hits") diff --git a/cmd/stdin.go b/cmd/stdin.go new file mode 100644 index 0000000..e9240db --- /dev/null +++ b/cmd/stdin.go @@ -0,0 +1,98 @@ +package cmd + +import ( + "bufio" + "fmt" + "os" + "strings" + + "git.db.org.ai/dborg/internal/utils" + "github.com/spf13/cobra" +) + +func stdinPiped() bool { + fi, err := os.Stdin.Stat() + if err != nil { + return false + } + return (fi.Mode() & os.ModeCharDevice) == 0 +} + +func readStdinLines() ([]string, error) { + var lines []string + scanner := bufio.NewScanner(os.Stdin) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + lines = append(lines, line) + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("failed to read stdin: %w", err) + } + return lines, nil +} + +func argsOrStdin(n int) cobra.PositionalArgs { + return func(cmd *cobra.Command, args []string) error { + if len(args) == n { + return nil + } + if len(args) == 0 && stdinPiped() { + return nil + } + return fmt.Errorf("accepts %d arg(s), received %d", n, len(args)) + } +} + +func resolveQueries(args []string) ([]string, error) { + if len(args) > 0 { + return args, nil + } + if !stdinPiped() { + return nil, fmt.Errorf("no query provided") + } + lines, err := readStdinLines() + if err != nil { + return nil, err + } + if len(lines) == 0 { + return nil, fmt.Errorf("no queries read from stdin") + } + return lines, nil +} + +func forEachQuery(args []string, fn func(query string) error) error { + queries, err := resolveQueries(args) + if err != nil { + return err + } + + multi := len(queries) > 1 + var firstErr error + for i, q := range queries { + if multi { + printQuerySeparator(i, q) + } + if err := fn(q); err != nil { + fmt.Fprintf(os.Stderr, "Error for %q: %v\n", q, err) + if firstErr == nil { + firstErr = err + } + } + } + return firstErr +} + +func printQuerySeparator(i int, q string) { + if utils.IsTerminal() { + if i > 0 { + fmt.Println() + } + fmt.Printf("\033[1;36m━━━ %s ━━━\033[0m\n", q) + } else { + fmt.Printf("--- %s ---\n", q) + } +} diff --git a/cmd/telegram.go b/cmd/telegram.go index 98092a6..23da10a 100644 --- a/cmd/telegram.go +++ b/cmd/telegram.go @@ -17,7 +17,7 @@ var phoneCmd = &cobra.Command{ Aliases: []string{"p"}, Short: "Get phone number for Telegram user", Long: `Retrieves the phone number associated with a Telegram username (with @ prefix) or user ID`, - Args: cobra.ExactArgs(1), + Args: argsOrStdin(1), RunE: runTelegramPhone, } @@ -27,21 +27,19 @@ func init() { } func runTelegramPhone(cmd *cobra.Command, args []string) error { - identifier := args[0] - c, err := newClient() if err != nil { return err } - response, err := c.GetTelegramPhone(identifier) - if err != nil { - return err - } - - if err := checkError(response.Error); err != nil { - return err - } - - return formatter.FormatTelegramResults(response, IsJSONOutput()) + return forEachQuery(args, func(identifier string) error { + response, err := c.GetTelegramPhone(identifier) + if err != nil { + return err + } + if err := checkError(response.Error); err != nil { + return err + } + return formatter.FormatTelegramResults(response, IsJSONOutput()) + }) } diff --git a/cmd/username.go b/cmd/username.go index ed51d2e..60da503 100644 --- a/cmd/username.go +++ b/cmd/username.go @@ -14,7 +14,7 @@ var usernameCmd = &cobra.Command{ Aliases: []string{"un"}, Short: "Check username availability across websites", Long: `Check username availability across hundreds of websites using WhatsMyName dataset`, - Args: cobra.ExactArgs(1), + Args: argsOrStdin(1), RunE: runUsernameCheck, } @@ -31,30 +31,27 @@ func runUsernameCheck(cmd *cobra.Command, args []string) error { return err } - params := &models.USRSXParams{ - Username: args[0], - } - params.Sites, _ = cmd.Flags().GetStringSlice("sites") - params.Fuzzy, _ = cmd.Flags().GetBool("fuzzy") - params.MaxTasks, _ = cmd.Flags().GetInt("max_tasks") - - err = c.CheckUsernameStream(params, func(result json.RawMessage) error { - if IsJSONOutput() { - fmt.Println(string(result)) - return nil - } + sites, _ := cmd.Flags().GetStringSlice("sites") + fuzzy, _ := cmd.Flags().GetBool("fuzzy") + maxTasks, _ := cmd.Flags().GetInt("max_tasks") - var siteResult models.SiteResult - if err := json.Unmarshal(result, &siteResult); err != nil { - return err + return forEachQuery(args, func(username string) error { + params := &models.USRSXParams{ + Username: username, + Sites: sites, + Fuzzy: fuzzy, + MaxTasks: maxTasks, } - - return formatter.FormatUsernameSiteResult(&siteResult) + return c.CheckUsernameStream(params, func(result json.RawMessage) error { + if IsJSONOutput() { + fmt.Println(string(result)) + return nil + } + var siteResult models.SiteResult + if err := json.Unmarshal(result, &siteResult); err != nil { + return err + } + return formatter.FormatUsernameSiteResult(&siteResult) + }) }) - - if err != nil { - return err - } - - return nil } diff --git a/internal/client/client.go b/internal/client/client.go index 4aa6f06..4ca61e4 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -4,11 +4,16 @@ import ( "bytes" "encoding/json" "fmt" - "git.db.org.ai/dborg/internal/config" "io" + "math/rand" "net/http" "net/url" + "os" + "strconv" + "strings" "time" + + "git.db.org.ai/dborg/internal/config" ) type Client struct { @@ -20,93 +25,178 @@ func New(cfg *config.Config) (*Client, error) { if err := cfg.Validate(); err != nil { return nil, err } - - return &Client{ - config: cfg, - httpClient: &http.Client{ - Timeout: cfg.Timeout, - }, - }, nil + return newClient(cfg), nil } func NewUnauthenticated(cfg *config.Config) (*Client, error) { + return newClient(cfg), nil +} + +func newClient(cfg *config.Config) *Client { return &Client{ config: cfg, httpClient: &http.Client{ Timeout: cfg.Timeout, }, - }, nil + } +} + +func (c *Client) debugf(format string, args ...interface{}) { + if !c.config.Debug { + return + } + fmt.Fprintf(os.Stderr, "[dborg] "+format+"\n", args...) +} + +func redactKey(key string) string { + if len(key) <= 8 { + return "***" + } + return key[:4] + "..." + key[len(key)-4:] +} + +func isRetryable(statusCode int) bool { + return statusCode == http.StatusTooManyRequests || + statusCode == http.StatusRequestTimeout || + statusCode >= 500 +} + +func backoffDelay(attempt int, retryAfter string) time.Duration { + if retryAfter != "" { + if secs, err := strconv.Atoi(strings.TrimSpace(retryAfter)); err == nil && secs > 0 { + return time.Duration(secs) * time.Second + } + if t, err := http.ParseTime(retryAfter); err == nil { + if d := time.Until(t); d > 0 { + return d + } + } + } + base := time.Duration(1<<attempt) * time.Second + if base > 30*time.Second { + base = 30 * time.Second + } + jitter := time.Duration(rand.Int63n(int64(base) / 2)) + return base + jitter } func (c *Client) doRequest(method, path string, params url.Values, body interface{}) ([]byte, error) { fullURL := c.config.BaseURL + path - if params != nil && len(params) > 0 { + if len(params) > 0 { fullURL += "?" + params.Encode() } - var reqBody io.Reader + var bodyBytes []byte if body != nil { - jsonData, err := json.Marshal(body) + var err error + bodyBytes, err = json.Marshal(body) if err != nil { return nil, fmt.Errorf("failed to marshal request body: %w", err) } - reqBody = bytes.NewBuffer(jsonData) - } - - req, err := http.NewRequest(method, fullURL, reqBody) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("X-API-Key", c.config.APIKey) - req.Header.Set("User-Agent", c.config.UserAgent) - if body != nil { - req.Header.Set("Content-Type", "application/json") + c.debugf("→ %s %s (api_key=%s)", method, fullURL, redactKey(c.config.APIKey)) + if len(bodyBytes) > 0 { + c.debugf(" body: %s", string(bodyBytes)) } - var resp *http.Response var lastErr error for attempt := 0; attempt <= c.config.MaxRetries; attempt++ { if attempt > 0 { - time.Sleep(time.Duration(attempt) * time.Second) + delay := backoffDelay(attempt, "") + if lastErr != nil { + if ra, ok := retryAfterFromErr(lastErr); ok { + delay = backoffDelay(attempt, ra) + } + } + c.debugf(" retry %d/%d after %s (last error: %v)", attempt, c.config.MaxRetries, delay, lastErr) + time.Sleep(delay) } - resp, err = c.httpClient.Do(req) + var reqBody io.Reader + if bodyBytes != nil { + reqBody = bytes.NewReader(bodyBytes) + } + + req, err := http.NewRequest(method, fullURL, reqBody) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("X-API-Key", c.config.APIKey) + req.Header.Set("User-Agent", c.config.UserAgent) + if bodyBytes != nil { + req.Header.Set("Content-Type", "application/json") + } + + start := time.Now() + resp, err := c.httpClient.Do(req) if err != nil { lastErr = err + c.debugf("← network error after %s: %v", time.Since(start), err) continue } - defer resp.Body.Close() + respBody, readErr := io.ReadAll(resp.Body) + resp.Body.Close() + c.debugf("← %d %s (%s, %d bytes)", resp.StatusCode, http.StatusText(resp.StatusCode), time.Since(start), len(respBody)) - if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated { - return io.ReadAll(resp.Body) + if readErr != nil { + lastErr = fmt.Errorf("failed to read response body: %w", readErr) + if !isRetryable(resp.StatusCode) { + return nil, lastErr + } + continue } - bodyBytes, _ := io.ReadAll(resp.Body) - - switch resp.StatusCode { - case http.StatusForbidden: - lastErr = fmt.Errorf("access denied (403): %s - This endpoint requires premium access", string(bodyBytes)) - case http.StatusUnauthorized: - lastErr = fmt.Errorf("unauthorized (401): %s - Check your API key", string(bodyBytes)) - case http.StatusTooManyRequests: - lastErr = fmt.Errorf("rate limit exceeded (429): %s", string(bodyBytes)) - case http.StatusBadRequest: - lastErr = fmt.Errorf("bad request (400): %s", string(bodyBytes)) - default: - lastErr = fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) + if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated { + return respBody, nil } - if resp.StatusCode != http.StatusTooManyRequests && resp.StatusCode < 500 { - break + lastErr = httpError(resp.StatusCode, respBody, resp.Header.Get("Retry-After")) + + if !isRetryable(resp.StatusCode) { + return nil, lastErr } } return nil, lastErr } +type apiError struct { + status int + message string + retryAfter string +} + +func (e *apiError) Error() string { return e.message } + +func retryAfterFromErr(err error) (string, bool) { + if ae, ok := err.(*apiError); ok && ae.retryAfter != "" { + return ae.retryAfter, true + } + return "", false +} + +func httpError(status int, body []byte, retryAfter string) error { + msg := string(body) + var formatted string + switch status { + case http.StatusForbidden: + formatted = fmt.Sprintf("access denied (403): %s - This endpoint requires premium access", msg) + case http.StatusUnauthorized: + formatted = fmt.Sprintf("unauthorized (401): %s - Check your API key", msg) + case http.StatusTooManyRequests: + formatted = fmt.Sprintf("rate limit exceeded (429): %s", msg) + case http.StatusBadRequest: + formatted = fmt.Sprintf("bad request (400): %s", msg) + default: + formatted = fmt.Sprintf("API request failed with status %d: %s", status, msg) + } + return &apiError{status: status, message: formatted, retryAfter: retryAfter} +} + func (c *Client) Get(path string, params url.Values) ([]byte, error) { return c.doRequest(http.MethodGet, path, params, nil) } diff --git a/internal/client/retry_test.go b/internal/client/retry_test.go new file mode 100644 index 0000000..532ddb7 --- /dev/null +++ b/internal/client/retry_test.go @@ -0,0 +1,247 @@ +package client + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "git.db.org.ai/dborg/internal/config" +) + +func testConfig(baseURL string) *config.Config { + return &config.Config{ + APIKey: "test-key", + BaseURL: baseURL, + Timeout: 5 * time.Second, + MaxRetries: 3, + UserAgent: "dborg-test", + } +} + +func TestRedactKey(t *testing.T) { + tests := []struct { + in, want string + }{ + {"", "***"}, + {"short", "***"}, + {"12345678", "***"}, + {"abcd1234efgh5678", "abcd...5678"}, + } + for _, tt := range tests { + if got := redactKey(tt.in); got != tt.want { + t.Errorf("redactKey(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} + +func TestIsRetryable(t *testing.T) { + retryable := []int{429, 408, 500, 502, 503, 504} + notRetryable := []int{200, 201, 301, 400, 401, 403, 404} + + for _, code := range retryable { + if !isRetryable(code) { + t.Errorf("isRetryable(%d) = false, want true", code) + } + } + for _, code := range notRetryable { + if isRetryable(code) { + t.Errorf("isRetryable(%d) = true, want false", code) + } + } +} + +func TestBackoffDelay_RetryAfterSeconds(t *testing.T) { + d := backoffDelay(1, "5") + if d != 5*time.Second { + t.Errorf("backoffDelay with Retry-After=5 = %v, want 5s", d) + } +} + +func TestBackoffDelay_RetryAfterHTTPDate(t *testing.T) { + future := time.Now().Add(10 * time.Second).UTC().Format(http.TimeFormat) + d := backoffDelay(1, future) + if d < 5*time.Second || d > 11*time.Second { + t.Errorf("backoffDelay with HTTP date = %v, want ~10s", d) + } +} + +func TestBackoffDelay_Exponential(t *testing.T) { + d1 := backoffDelay(1, "") + d2 := backoffDelay(2, "") + d3 := backoffDelay(3, "") + if d1 < 2*time.Second || d1 > 3*time.Second { + t.Errorf("attempt 1 delay = %v, want 2-3s", d1) + } + if d2 < 4*time.Second || d2 > 6*time.Second { + t.Errorf("attempt 2 delay = %v, want 4-6s", d2) + } + if d3 < 8*time.Second || d3 > 12*time.Second { + t.Errorf("attempt 3 delay = %v, want 8-12s", d3) + } +} + +func TestDoRequest_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-API-Key"); got != "test-key" { + t.Errorf("X-API-Key = %q, want test-key", got) + } + fmt.Fprintln(w, `{"ok":true}`) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + body, err := c.Get("/test", nil) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if !strings.Contains(string(body), `"ok":true`) { + t.Errorf("unexpected body: %s", body) + } +} + +func TestDoRequest_RetriesOn500ThenSucceeds(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&calls, 1) + if n < 3 { + http.Error(w, "boom", http.StatusInternalServerError) + return + } + fmt.Fprintln(w, `{"ok":true}`) + })) + defer srv.Close() + + cfg := testConfig(srv.URL) + cfg.MaxRetries = 5 + c, _ := New(cfg) + body, err := c.Get("/test", nil) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if atomic.LoadInt32(&calls) != 3 { + t.Errorf("calls = %d, want 3", calls) + } + if !strings.Contains(string(body), `"ok":true`) { + t.Errorf("unexpected body: %s", body) + } +} + +func TestDoRequest_NoRetryOn400(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + http.Error(w, "bad", http.StatusBadRequest) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + _, err := c.Get("/test", nil) + if err == nil { + t.Fatal("expected error") + } + if atomic.LoadInt32(&calls) != 1 { + t.Errorf("calls = %d, want 1 (no retries on 400)", calls) + } +} + +func TestDoRequest_NoRetryOn401(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + http.Error(w, "no auth", http.StatusUnauthorized) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + _, err := c.Get("/test", nil) + if err == nil { + t.Fatal("expected error") + } + if atomic.LoadInt32(&calls) != 1 { + t.Errorf("calls = %d, want 1", calls) + } +} + +func TestDoRequest_ExhaustsRetries(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + http.Error(w, "down", http.StatusBadGateway) + })) + defer srv.Close() + + cfg := testConfig(srv.URL) + cfg.MaxRetries = 2 + c, _ := New(cfg) + _, err := c.Get("/test", nil) + if err == nil { + t.Fatal("expected error after exhausting retries") + } + if atomic.LoadInt32(&calls) != 3 { + t.Errorf("calls = %d, want 3 (initial + 2 retries)", calls) + } +} + +func TestDoRequest_PostBodyResentOnRetry(t *testing.T) { + var calls int32 + var bodies []string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := io.ReadAll(r.Body) + bodies = append(bodies, string(b)) + n := atomic.AddInt32(&calls, 1) + if n < 2 { + http.Error(w, "boom", http.StatusServiceUnavailable) + return + } + fmt.Fprintln(w, `{"ok":true}`) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + payload := map[string]string{"hello": "world"} + _, err := c.Post("/test", payload) + if err != nil { + t.Fatalf("Post() error = %v", err) + } + if len(bodies) != 2 { + t.Fatalf("got %d requests, want 2", len(bodies)) + } + if bodies[0] != bodies[1] { + t.Errorf("body mismatch between retries:\n first: %q\n second: %q", bodies[0], bodies[1]) + } + if !strings.Contains(bodies[0], `"hello":"world"`) { + t.Errorf("unexpected body: %q", bodies[0]) + } +} + +func TestDoRequest_RetryAfterHeaderHonored(t *testing.T) { + var calls int32 + var firstAt, secondAt time.Time + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&calls, 1) + if n == 1 { + firstAt = time.Now() + w.Header().Set("Retry-After", "1") + http.Error(w, "slow down", http.StatusTooManyRequests) + return + } + secondAt = time.Now() + fmt.Fprintln(w, `{"ok":true}`) + })) + defer srv.Close() + + c, _ := New(testConfig(srv.URL)) + _, err := c.Get("/test", nil) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + gap := secondAt.Sub(firstAt) + if gap < 900*time.Millisecond { + t.Errorf("retry gap = %v, want >= 1s (Retry-After honored)", gap) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 9df03bf..0da7541 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -36,6 +36,7 @@ func New() *Config { Timeout: 30 * time.Second, MaxRetries: 3, UserAgent: "dborg-cli/1.0", + Debug: os.Getenv("DBORG_DEBUG") != "", } } diff --git a/internal/formatter/formatter.go b/internal/formatter/formatter.go index bb36fa8..159f027 100644 --- a/internal/formatter/formatter.go +++ b/internal/formatter/formatter.go @@ -64,11 +64,7 @@ func (f *BaseFormatter) FormatJSON(data any) error { } func isTerminal() bool { - fileInfo, err := os.Stdout.Stat() - if err != nil { - return false - } - return (fileInfo.Mode() & os.ModeCharDevice) != 0 + return utils.IsTerminal() } func GetTerminalWidth() int { diff --git a/internal/utils/tty.go b/internal/utils/tty.go new file mode 100644 index 0000000..4ac367d --- /dev/null +++ b/internal/utils/tty.go @@ -0,0 +1,19 @@ +package utils + +import ( + "os" + + "golang.org/x/term" +) + +func IsTerminal() bool { + return term.IsTerminal(int(os.Stdout.Fd())) +} + +func IsStderrTerminal() bool { + return term.IsTerminal(int(os.Stderr.Fd())) +} + +func isTerminal() bool { + return IsTerminal() +} diff --git a/internal/utils/version.go b/internal/utils/version.go index 46ebb73..ab72bfb 100644 --- a/internal/utils/version.go +++ b/internal/utils/version.go @@ -4,7 +4,9 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "runtime/debug" + "strconv" "strings" "syscall" @@ -83,10 +85,21 @@ func isNewerVersion(remote, local string) bool { localParts := strings.Split(local, ".") for i := 0; i < len(remoteParts) && i < len(localParts); i++ { - if remoteParts[i] > localParts[i] { + r, errR := strconv.Atoi(remoteParts[i]) + l, errL := strconv.Atoi(localParts[i]) + if errR != nil || errL != nil { + if remoteParts[i] > localParts[i] { + return true + } + if remoteParts[i] < localParts[i] { + return false + } + continue + } + if r > l { return true } - if remoteParts[i] < localParts[i] { + if r < l { return false } } @@ -125,26 +138,36 @@ func promptAndUpdate(newVersion string) { restartSelf() } +func resolveInstalledBinary() (string, error) { + if gobin := os.Getenv("GOBIN"); gobin != "" { + return filepath.Join(gobin, "dborg"), nil + } + + gopath := os.Getenv("GOPATH") + if gopath == "" { + out, err := exec.Command("go", "env", "GOPATH").Output() + if err != nil { + return "", fmt.Errorf("could not determine GOPATH: %w", err) + } + gopath = strings.TrimSpace(string(out)) + } + + return filepath.Join(gopath, "bin", "dborg"), nil +} + func restartSelf() { - executable, err := os.Executable() + newBinary, err := resolveInstalledBinary() if err != nil { - fmt.Fprintf(os.Stderr, "Failed to get executable path: %v\n", err) + fmt.Fprintf(os.Stderr, "Failed to locate installed binary: %v\n", err) os.Exit(1) } args := os.Args[1:] - err = syscall.Exec(executable, append([]string{executable}, args...), os.Environ()) + err = syscall.Exec(newBinary, append([]string{newBinary}, args...), os.Environ()) if err != nil { fmt.Fprintf(os.Stderr, "Failed to restart: %v\n", err) os.Exit(1) } } -func isTerminal() bool { - fileInfo, err := os.Stdout.Stat() - if err != nil { - return false - } - return (fileInfo.Mode() & os.ModeCharDevice) != 0 -} diff --git a/internal/utils/version_test.go b/internal/utils/version_test.go index e3a27b1..9db181a 100644 --- a/internal/utils/version_test.go +++ b/internal/utils/version_test.go @@ -72,6 +72,30 @@ func TestIsNewerVersion(t *testing.T) { local: "v0.5.0", expected: true, }, + { + name: "double digit patch newer", + remote: "v1.0.12", + local: "v1.0.9", + expected: true, + }, + { + name: "double digit patch older", + remote: "v1.0.9", + local: "v1.0.12", + expected: false, + }, + { + name: "double digit patch same", + remote: "v1.0.12", + local: "v1.0.12", + expected: false, + }, + { + name: "double digit minor newer", + remote: "v1.10.0", + local: "v1.9.0", + expected: true, + }, } for _, tt := range tests { |
