summaryrefslogtreecommitdiffstats
path: root/cmd/stdin.go
blob: e9240db95abfd0bd427aed323457b0ef86f26614 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package cmd

import (
	"bufio"
	"fmt"
	"os"
	"strings"

	"git.db.org.ai/dborg/internal/utils"
	"github.com/spf13/cobra"
)

// stdinPiped reports whether standard input is not a character device
// (such as a terminal or the null device), i.e. whether it is a pipe or a
// redirected file that the program can read queries from.
// If stdin cannot be stat'ed, it reports false.
func stdinPiped() bool {
	info, statErr := os.Stdin.Stat()
	if statErr == nil {
		return info.Mode()&os.ModeCharDevice == 0
	}
	return false
}

// readStdinLines reads standard input line by line and returns the usable
// lines with surrounding whitespace removed.
//
// Blank lines and lines whose first non-space character is '#' are skipped,
// so query lists may contain comments. A leading UTF-8 byte-order mark
// (U+FEFF) is stripped from each line: files saved with a BOM would otherwise
// silently prefix the first query with an invisible character.
//
// Lines may be up to about 1 MiB long; a longer line, or any other read
// failure, returns a wrapped error and no lines.
func readStdinLines() ([]string, error) {
	// U+FEFF is not whitespace per unicode.IsSpace, so TrimSpace keeps it.
	const bom = "\uFEFF"

	var lines []string
	scanner := bufio.NewScanner(os.Stdin)
	// Start with a 64 KiB buffer but let it grow to 1 MiB before the scanner
	// gives up on an over-long line.
	scanner.Buffer(make([]byte, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := strings.TrimSpace(strings.TrimPrefix(scanner.Text(), bom))
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		lines = append(lines, line)
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("failed to read stdin: %w", err)
	}
	return lines, nil
}

// argsOrStdin builds a cobra positional-argument validator. It accepts
// exactly n arguments, or zero arguments when stdin is piped (the queries
// are then expected on stdin). Any other count is rejected with an error
// reporting the expected and received counts.
func argsOrStdin(n int) cobra.PositionalArgs {
	return func(_ *cobra.Command, args []string) error {
		got := len(args)
		switch {
		case got == n:
			return nil
		case got == 0 && stdinPiped():
			return nil
		default:
			return fmt.Errorf("accepts %d arg(s), received %d", n, got)
		}
	}
}

// resolveQueries decides which queries to run. Positional args win when any
// are present; otherwise the lines read from piped stdin are used.
//
// It returns an error when there are no args and stdin is not piped, when
// reading stdin fails, or when stdin contains no usable lines.
func resolveQueries(args []string) ([]string, error) {
	switch {
	case len(args) > 0:
		return args, nil
	case !stdinPiped():
		return nil, fmt.Errorf("no query provided")
	}

	queries, err := readStdinLines()
	switch {
	case err != nil:
		return nil, err
	case len(queries) == 0:
		return nil, fmt.Errorf("no queries read from stdin")
	}
	return queries, nil
}

// forEachQuery resolves the query list (positional args, or piped stdin
// lines) and invokes fn once per query, in order.
//
// A failing query does not stop the run: its error is written to stderr
// alongside the query, and the first error seen is returned after every
// query has been attempted. When more than one query runs, each one's output
// is preceded by a separator naming the query.
func forEachQuery(args []string, fn func(query string) error) error {
	queries, err := resolveQueries(args)
	if err != nil {
		return err
	}

	showHeaders := len(queries) > 1
	var result error
	for idx := range queries {
		query := queries[idx]
		if showHeaders {
			printQuerySeparator(idx, query)
		}
		callErr := fn(query)
		if callErr == nil {
			continue
		}
		fmt.Fprintf(os.Stderr, "Error for %q: %v\n", query, callErr)
		if result == nil {
			result = callErr
		}
	}
	return result
}

// printQuerySeparator writes a header line for query q (at 0-based position
// i) to stdout. When utils.IsTerminal reports true, the header is styled with
// an ANSI bold-cyan escape sequence and, for positions after the first, is
// preceded by a blank line. Otherwise a plain "--- q ---" line is printed
// with no extra spacing.
func printQuerySeparator(i int, q string) {
	if !utils.IsTerminal() {
		fmt.Printf("--- %s ---\n", q)
		return
	}
	if i > 0 {
		fmt.Println()
	}
	fmt.Printf("\033[1;36m━━━ %s ━━━\033[0m\n", q)
}