diff --git a/.github/workflows/agent.yml b/.github/workflows/agent.yml index b0bd7d5d..f2a7a789 100644 --- a/.github/workflows/agent.yml +++ b/.github/workflows/agent.yml @@ -61,13 +61,6 @@ jobs: with: go-version: "1.20.14" - # WARNING: This is temporary remove it after the next release of gopsutil - - name: Patch gopsutil - run: | - sudo rm -rf $(printf "%s%s" "$(go env GOMODCACHE)" "/github.com/shirou/gopsutil/v4@v4.24.9") - go mod tidy -v - curl -L https://github.com/shirou/gopsutil/pull/1722.patch | sudo patch -p1 -d $(printf "%s%s" "$(go env GOMODCACHE)" "/github.com/shirou/gopsutil/v4@v4.24.9") - - name: Build Test if: github.event_name != 'push' || !contains(github.ref, 'refs/tags/') uses: goreleaser/goreleaser-action@v6 diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 371a9dd6..ef022470 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "context" "crypto/md5" "crypto/tls" @@ -12,7 +13,6 @@ import ( "net/http" "net/url" "os" - "os/exec" "path/filepath" "runtime" "strings" @@ -45,10 +45,10 @@ import ( ) var ( - version string + version = monitor.Version // 来自于 GoReleaser 的版本号 arch string - defaultConfigPath string executablePath string + defaultConfigPath = loadDefaultConfigPath() client pb.NezhaServiceClient initialized bool dnsResolver = &net.Resolver{PreferGo: true} @@ -73,7 +73,7 @@ const ( networkTimeOut = time.Second * 5 // 普通网络超时 ) -func init() { +func setEnv() { resolver.SetDefaultScheme("passthrough") net.DefaultResolver.PreferGo = true // 使用 Go 内置的 DNS 解析器解析域名 net.DefaultResolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { @@ -102,17 +102,67 @@ func init() { utls.HelloChrome_Auto, new(utls.Config), http.DefaultTransport, nil, &headers, ) +} - // 来自于 GoReleaser 的版本号 - monitor.Version = version - +func loadDefaultConfigPath() string { var err error executablePath, err = os.Executable() if err != nil { - panic(err) + return "" + } + return filepath.Join(filepath.Dir(executablePath), "config.yml") +} + +func preRun(configPath string) error { + // init + setEnv() + + if configPath == "" { + configPath = defaultConfigPath + } + + // windows环境处理 + if runtime.GOOS == "windows" { + hostArch, err := host.KernelArch() + if err != nil { + return err + } + switch hostArch { + case "i386", "i686": + hostArch = "386" + case "x86_64": + hostArch = "amd64" + case "aarch64": + hostArch = "arm64" + } + if arch != hostArch { + return fmt.Errorf("与当前系统不匹配,当前运行 %s_%s, 需要下载 %s_%s", runtime.GOOS, arch, runtime.GOOS, hostArch) + } } - defaultConfigPath = filepath.Join(filepath.Dir(executablePath), "config.yml") + if err := agentConfig.Read(configPath); err != nil { + return fmt.Errorf("打开配置文件失败:%v", err) + } + + monitor.InitConfig(&agentConfig) + + if agentConfig.ClientSecret == "" { + return errors.New("ClientSecret 不能为空") + } + + if agentConfig.ReportDelay < 1 || agentConfig.ReportDelay > 4 { + return errors.New("report-delay 的区间为 1-4") + } + + if agentConfig.UUID == "" { + if uuid, err := uuid.GenerateUUID(); err == nil { + agentConfig.UUID = uuid + return agentConfig.Save() + } else { + return fmt.Errorf("生成 UUID 失败:%v", err) + } + } + return nil } func main() { @@ -129,9 +179,13 @@ func main() { return nil } if path := c.String("config"); path != "" { - preRun(path) + if err := preRun(path); err != nil { + return err + } } else { - preRun(defaultConfigPath) + if err := preRun(""); err != nil { + return err + } } runService("", "") return nil @@ -181,51 +235,6 @@ func main() { } } -func preRun(configPath string) { - // windows环境处理 - if runtime.GOOS == "windows" { - hostArch, err := host.KernelArch() - if err != nil { - panic(err) - } - if hostArch == "i386" { - hostArch = "386" - } - if hostArch == "i686" || hostArch == "ia64" || hostArch == "x86_64" { - hostArch = "amd64" - } - if hostArch == "aarch64" { - hostArch = "arm64" - } - if arch != hostArch { - panic(fmt.Sprintf("与当前系统不匹配,当前运行 %s_%s, 需要下载 %s_%s", runtime.GOOS, arch, runtime.GOOS, hostArch)) - } - } - - if err := agentConfig.Read(configPath); err != nil { - log.Fatalf("打开配置文件失败:%v", err) - } - - monitor.InitConfig(&agentConfig) - - if agentConfig.ClientSecret == "" { - log.Fatal("ClientSecret 不能为空") - } - - if agentConfig.ReportDelay < 1 || agentConfig.ReportDelay > 4 { - log.Fatal("report-delay 的区间为 1-4") - } - - if agentConfig.UUID == "" { - if uuid, err := uuid.GenerateUUID(); err == nil { - agentConfig.UUID = uuid - agentConfig.Save() - } else { - log.Fatalf("生成 UUID 失败:%v", err) - } - } -} - func run() { auth := model.AuthHandler{ ClientSecret: agentConfig.ClientSecret, @@ -645,8 +654,7 @@ func handleCommandTask(task *pb.Task, result *pb.TaskResult) { return } startedAt := time.Now() - var cmd *exec.Cmd - var endCh = make(chan struct{}) + endCh := make(chan struct{}) pg, err := processgroup.NewProcessExitGroup() if err != nil { // 进程组创建失败,直接退出 @@ -654,12 +662,14 @@ func handleCommandTask(task *pb.Task, result *pb.TaskResult) { return } timeout := time.NewTimer(time.Hour * 2) - if util.IsWindows() { - cmd = exec.Command("cmd", "/c", task.GetData()) // #nosec - } else { - cmd = exec.Command("sh", "-c", task.GetData()) // #nosec - } + cmd := processgroup.NewCommand(task.GetData()) + var b bytes.Buffer + cmd.Stdout = &b cmd.Env = os.Environ() + if err = cmd.Start(); err != nil { + result.Data = err.Error() + return + } pg.AddProcess(cmd) go func() { select { @@ -671,12 +681,11 @@ func handleCommandTask(task *pb.Task, result *pb.TaskResult) { timeout.Stop() } }() - output, err := cmd.Output() - if err != nil { - result.Data += fmt.Sprintf("%s\n%s", string(output), err.Error()) + if err = cmd.Wait(); err != nil { + result.Data += fmt.Sprintf("%s\n%s", b.String(), err.Error()) } else { close(endCh) - result.Data = string(output) + result.Data = b.String() result.Successful = true } pg.Dispose() diff --git a/go.mod b/go.mod index 022fa3ac..5d8b0c87 100644 --- a/go.mod +++ b/go.mod @@ -7,10 +7,10 @@ require ( github.com/UserExistsError/conpty v0.1.4 github.com/artdarek/go-unzip v1.0.0 github.com/blang/semver v3.5.1+incompatible - github.com/creack/pty v1.1.23 + github.com/creack/pty v1.1.24 github.com/dean2021/goss v0.0.0-20230129073947-df90431348f1 github.com/ebi-yade/altsvc-go v0.1.1 - github.com/ebitengine/purego v0.8.0 + github.com/ebitengine/purego v0.8.1 github.com/hashicorp/go-uuid v1.0.3 github.com/iamacarpet/go-winpty v1.0.4 github.com/jaypipes/ghw v0.12.0 @@ -20,12 +20,12 @@ require ( github.com/prometheus-community/pro-bing v0.4.1 github.com/quic-go/quic-go v0.40.1 github.com/refraction-networking/utls v1.6.3 - github.com/shirou/gopsutil/v4 v4.24.9 + github.com/shirou/gopsutil/v4 v4.24.10 github.com/spf13/viper v1.19.0 github.com/tidwall/gjson v1.18.0 github.com/urfave/cli/v2 v2.27.5 golang.org/x/net v0.29.0 - golang.org/x/sys v0.25.0 + golang.org/x/sys v0.26.0 google.golang.org/grpc v1.64.1 google.golang.org/protobuf v1.34.2 sigs.k8s.io/yaml v1.4.0 diff --git a/go.sum b/go.sum index a06680bd..25857a4e 100644 --- a/go.sum +++ b/go.sum @@ -25,8 +25,8 @@ github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBS github.com/cpuguy83/go-md2man/v2 v2.0.5 h1:ZtcqGrnekaHpVLArFSe4HK5DoKx1T0rq2DwVB0alcyc= github.com/cpuguy83/go-md2man/v2 v2.0.5/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= -github.com/creack/pty v1.1.23 h1:4M6+isWdcStXEf15G/RbrMPOQj1dZ7HPZCGwE4kOeP0= -github.com/creack/pty v1.1.23/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= +github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= +github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -34,8 +34,8 @@ github.com/dean2021/goss v0.0.0-20230129073947-df90431348f1 h1:5UiJ324LiCdOF/3w/ github.com/dean2021/goss v0.0.0-20230129073947-df90431348f1/go.mod h1:NiLueuVb3hYcdF4ta+2ezcKJh6BEjhrBz9Hts6XJ5Sc= github.com/ebi-yade/altsvc-go v0.1.1 h1:HmZDNb5ZOPlkyXhi34LnRckawFCux7yPYw+dtInIixo= github.com/ebi-yade/altsvc-go v0.1.1/go.mod h1:K/U20bLcsOVrbTeDhqRjp+e3tgNT5iAqSiQzPoU0/Q0= -github.com/ebitengine/purego v0.8.0 h1:JbqvnEzRvPpxhCJzJJ2y0RbiZ8nyjccVUrSM3q+GvvE= -github.com/ebitengine/purego v0.8.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +github.com/ebitengine/purego v0.8.1 h1:sdRKd6plj7KYW33EH5As6YKfe8m9zbN9JMrOjNVF/BE= +github.com/ebitengine/purego v0.8.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= @@ -140,8 +140,8 @@ github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6ke github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= -github.com/shirou/gopsutil/v4 v4.24.9 h1:KIV+/HaHD5ka5f570RZq+2SaeFsb/pq+fp2DGNWYoOI= -github.com/shirou/gopsutil/v4 v4.24.9/go.mod h1:3fkaHNeYsUFCGZ8+9vZVWtbyM1k2eRnlL+bWO8Bxa/Q= +github.com/shirou/gopsutil/v4 v4.24.10 h1:7VOzPtfw/5YDU+jLEoBwXwxJbQetULywoSV4RYY7HkM= +github.com/shirou/gopsutil/v4 v4.24.10/go.mod h1:s4D/wg+ag4rG0WO7AiTj2BeYCRhym0vM7DHbZRxnIT8= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= @@ -229,8 +229,8 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= -golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= diff --git a/pkg/gpu/gpu_fallback.go b/pkg/gpu/gpu_fallback.go deleted file mode 100644 index 044eafc3..00000000 --- a/pkg/gpu/gpu_fallback.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build !darwin && !linux && !windows - -package gpu - -func GetGPUModel() ([]string, error) { - return nil, nil -} - -func GetGPUStat() ([]float64, error) { - return nil, nil -} diff --git a/pkg/monitor/conn/conn_fallback.go b/pkg/monitor/conn/conn_fallback.go new file mode 100644 index 00000000..7e18cf68 --- /dev/null +++ b/pkg/monitor/conn/conn_fallback.go @@ -0,0 +1,26 @@ +//go:build !linux + +package conn + +import ( + "context" + "syscall" + + "github.com/shirou/gopsutil/v4/net" +) + +func GetState(_ context.Context) ([]uint64, error) { + var tcpConnCount, udpConnCount uint64 + + conns, _ := net.Connections("all") + for i := 0; i < len(conns); i++ { + switch conns[i].Type { + case syscall.SOCK_STREAM: + tcpConnCount++ + case syscall.SOCK_DGRAM: + udpConnCount++ + } + } + + return []uint64{tcpConnCount, udpConnCount}, nil +} diff --git a/pkg/monitor/conn/conn_linux.go b/pkg/monitor/conn/conn_linux.go new file mode 100644 index 00000000..dea8454a --- /dev/null +++ b/pkg/monitor/conn/conn_linux.go @@ -0,0 +1,50 @@ +//go:build linux + +package conn + +import ( + "context" + "syscall" + + "github.com/dean2021/goss" + "github.com/shirou/gopsutil/v4/net" +) + +func GetState(_ context.Context) ([]uint64, error) { + var tcpConnCount, udpConnCount uint64 + + tcpStat, err := goss.ConnectionsWithProtocol(goss.AF_INET, syscall.IPPROTO_TCP) + if err == nil { + tcpConnCount = uint64(len(tcpStat)) + } + + udpStat, err := goss.ConnectionsWithProtocol(goss.AF_INET, syscall.IPPROTO_UDP) + if err == nil { + udpConnCount = uint64(len(udpStat)) + } + + tcpStat6, err := goss.ConnectionsWithProtocol(goss.AF_INET6, syscall.IPPROTO_TCP) + if err == nil { + tcpConnCount += uint64(len(tcpStat6)) + } + + udpStat6, err := goss.ConnectionsWithProtocol(goss.AF_INET6, syscall.IPPROTO_UDP) + if err == nil { + udpConnCount += uint64(len(udpStat6)) + } + + if tcpConnCount < 1 && udpConnCount < 1 { + // fallback to parsing files + conns, _ := net.Connections("all") + for _, conn := range conns { + switch conn.Type { + case syscall.SOCK_STREAM: + tcpConnCount++ + case syscall.SOCK_DGRAM: + udpConnCount++ + } + } + } + + return []uint64{tcpConnCount, udpConnCount}, nil +} diff --git a/pkg/monitor/cpu/cpu.go b/pkg/monitor/cpu/cpu.go new file mode 100644 index 00000000..84355381 --- /dev/null +++ b/pkg/monitor/cpu/cpu.go @@ -0,0 +1,50 @@ +package cpu + +import ( + "context" + "fmt" + + psCpu "github.com/shirou/gopsutil/v4/cpu" +) + +type CPUHostType string + +const CPUHostKey CPUHostType = "cpu" + +func GetHost(ctx context.Context) ([]string, error) { + ci, err := psCpu.InfoWithContext(ctx) + if err != nil { + return nil, err + } + + cpuModelCount := make(map[string]int) + for _, c := range ci { + cpuModelCount[c.ModelName]++ + } + + var cpuType string + if t, ok := ctx.Value(CPUHostKey).(string); ok { + cpuType = t + } + + var ch []string + u := len(ci) > 1 + for model, count := range cpuModelCount { + if u { + ch = append(ch, fmt.Sprintf("%s %d %s Core", model, count, cpuType)) + } else { + ch = append(ch, fmt.Sprintf("%s %d %s Core", model, ci[0].Cores, cpuType)) + } + } + + return ch, nil +} + +func GetState(ctx context.Context) ([]float64, error) { + cp, err := psCpu.PercentWithContext(ctx, 0, false) + if err != nil { + return nil, err + } + + return cp, nil +} diff --git a/pkg/monitor/disk/disk.go b/pkg/monitor/disk/disk.go new file mode 100644 index 00000000..ea560fcd --- /dev/null +++ b/pkg/monitor/disk/disk.go @@ -0,0 +1,122 @@ +package disk + +import ( + "context" + "os/exec" + "runtime" + "strconv" + "strings" + + psDisk "github.com/shirou/gopsutil/v4/disk" + + "github.com/nezhahq/agent/pkg/util" +) + +type DiskKeyType string + +const DiskKey DiskKeyType = "disk" + +var expectDiskFsTypes = []string{ + "apfs", "ext4", "ext3", "ext2", "f2fs", "reiserfs", "jfs", "btrfs", + "fuseblk", "zfs", "simfs", "ntfs", "fat32", "exfat", "xfs", "fuse.rclone", +} + +func GetHost(ctx context.Context) (uint64, error) { + devices, err := getDevices(ctx) + if err != nil { + return 0, err + } + + var total uint64 + for _, mountPath := range devices { + diskUsageOf, err := psDisk.Usage(mountPath) + if err == nil { + total += diskUsageOf.Total + } + } + + // Fallback 到这个方法,仅统计根路径,适用于OpenVZ之类的. + if runtime.GOOS == "linux" && total == 0 { + cmd := exec.Command("df") + out, err := cmd.CombinedOutput() + if err == nil { + s := strings.Split(string(out), "\n") + for _, c := range s { + info := strings.Fields(c) + if len(info) == 6 { + if info[5] == "/" { + total, _ = strconv.ParseUint(info[1], 0, 64) + // 默认获取的是1K块为单位的. + total = total * 1024 + } + } + } + } + } + + return total, nil +} + +func GetState(ctx context.Context) (uint64, error) { + devices, err := getDevices(ctx) + if err != nil { + return 0, err + } + + var used uint64 + for _, mountPath := range devices { + diskUsageOf, err := psDisk.Usage(mountPath) + if err == nil { + used += diskUsageOf.Used + } + } + + // Fallback 到这个方法,仅统计根路径,适用于OpenVZ之类的. + if runtime.GOOS == "linux" && used == 0 { + cmd := exec.Command("df") + out, err := cmd.CombinedOutput() + if err == nil { + s := strings.Split(string(out), "\n") + for _, c := range s { + info := strings.Fields(c) + if len(info) == 6 { + if info[5] == "/" { + used, _ = strconv.ParseUint(info[2], 0, 64) + // 默认获取的是1K块为单位的. + used = used * 1024 + } + } + } + } + } + + return used, nil +} + +func getDevices(ctx context.Context) (map[string]string, error) { + devices := make(map[string]string) + + // 如果配置了白名单,使用白名单的列表 + if s, ok := ctx.Value(DiskKey).([]string); ok && len(s) > 0 { + for i, v := range s { + devices[strconv.Itoa(i)] = v + } + return devices, nil + } + + // 否则使用默认过滤规则 + diskList, err := psDisk.Partitions(false) + if err != nil { + return nil, err + } + + for _, d := range diskList { + fsType := strings.ToLower(d.Fstype) + // 不统计 K8s 的虚拟挂载点:https://github.com/shirou/gopsutil/issues/1007 + if devices[d.Device] == "" && util.ContainsStr(expectDiskFsTypes, fsType) && !strings.Contains(d.Mountpoint, "/var/lib/kubelet") { + devices[d.Device] = d.Mountpoint + } + } + + return devices, nil +} diff --git a/pkg/gpu/gpu_darwin.go b/pkg/monitor/gpu/gpu_darwin.go similarity index 98% rename from pkg/gpu/gpu_darwin.go rename to pkg/monitor/gpu/gpu_darwin.go index 25057e78..9e6f09d2 100644 --- a/pkg/gpu/gpu_darwin.go +++ b/pkg/monitor/gpu/gpu_darwin.go @@ -3,6 +3,7 @@ package gpu import ( + "context" "fmt" "unsafe" @@ -117,7 +118,7 @@ func init() { purego.RegisterLibFunc(&IOObjectRelease, ioKit, "IOObjectRelease") } -func GetGPUModel() ([]string, error) { +func GetHost(_ context.Context) ([]string, error) { models, err := findDevices("model") if err != nil { return nil, err @@ -125,7 +126,7 @@ func GetGPUModel() ([]string, error) { return util.RemoveDuplicate(models), nil } -func GetGPUStat() ([]float64, error) { +func GetState(_ context.Context) ([]float64, error) { usage, err := findUtilization("PerformanceStatistics", "Device Utilization %") return []float64{float64(usage)}, err } diff --git a/pkg/monitor/gpu/gpu_fallback.go b/pkg/monitor/gpu/gpu_fallback.go new file mode 100644 index 00000000..48ce1f5d --- /dev/null +++ b/pkg/monitor/gpu/gpu_fallback.go @@ -0,0 +1,13 @@ +//go:build !darwin && !linux && !windows + +package gpu + +import "context" + +func GetHost(_ context.Context) ([]string, error) { + return nil, nil +} + +func GetState(_ context.Context) ([]float64, error) { + return nil, nil +} diff --git a/pkg/gpu/gpu_linux.go b/pkg/monitor/gpu/gpu_linux.go similarity index 86% rename from pkg/gpu/gpu_linux.go rename to pkg/monitor/gpu/gpu_linux.go index 7ba0d913..27918452 100644 --- a/pkg/gpu/gpu_linux.go +++ b/pkg/monitor/gpu/gpu_linux.go @@ -3,9 +3,10 @@ package gpu import ( + "context" "errors" - "github.com/nezhahq/agent/pkg/gpu/vendor" + "github.com/nezhahq/agent/pkg/monitor/gpu/vendor" ) const ( @@ -13,14 +14,14 @@ const ( vendorNVIDIA ) -var vendorType uint8 +var vendorType = getVendor() -func init() { +func getVendor() uint8 { _, err := getNvidiaStat() if err != nil { - vendorType = vendorAMD + return vendorAMD } else { - vendorType = vendorNVIDIA + return vendorNVIDIA } } @@ -84,7 +85,7 @@ func getAMDHost() ([]string, error) { return data, nil } -func GetGPUModel() ([]string, error) { +func GetHost(_ context.Context) ([]string, error) { var gi []string var err error @@ -104,7 +105,7 @@ func GetGPUModel() ([]string, error) { return gi, nil } -func GetGPUStat() ([]float64, error) { +func GetState(_ context.Context) ([]float64, error) { var gs []float64 var err error diff --git a/pkg/gpu/gpu_windows.go b/pkg/monitor/gpu/gpu_windows.go similarity index 97% rename from pkg/gpu/gpu_windows.go rename to pkg/monitor/gpu/gpu_windows.go index 8aca70a6..1a6a850e 100644 --- a/pkg/gpu/gpu_windows.go +++ b/pkg/monitor/gpu/gpu_windows.go @@ -3,6 +3,7 @@ package gpu import ( + "context" "errors" "fmt" "time" @@ -41,7 +42,7 @@ type PDH_FMT_COUNTERVALUE_ITEM_DOUBLE struct { FmtValue PDH_FMT_COUNTERVALUE_DOUBLE } -func GetGPUModel() ([]string, error) { +func GetHost(_ context.Context) ([]string, error) { var gpuModel []string gi, err := ghw.GPU(ghw.WithDisableWarnings()) if err != nil { @@ -58,7 +59,7 @@ func GetGPUModel() ([]string, error) { return gpuModel, nil } -func GetGPUStat() ([]float64, error) { +func GetState(_ context.Context) ([]float64, error) { counter, err := newWin32PerformanceCounter("gpu_utilization", "\\GPU Engine(*engtype_3D)\\Utilization Percentage") if err != nil { return nil, err diff --git a/pkg/gpu/vendor/amd_rocm_smi.go b/pkg/monitor/gpu/vendor/amd_rocm_smi.go similarity index 100% rename from pkg/gpu/vendor/amd_rocm_smi.go rename to pkg/monitor/gpu/vendor/amd_rocm_smi.go diff --git a/pkg/gpu/vendor/nvidia_smi.go b/pkg/monitor/gpu/vendor/nvidia_smi.go similarity index 100% rename from pkg/gpu/vendor/nvidia_smi.go rename to pkg/monitor/gpu/vendor/nvidia_smi.go diff --git a/pkg/monitor/load/load.go b/pkg/monitor/load/load.go new file mode 100644 index 00000000..810985bc --- /dev/null +++ b/pkg/monitor/load/load.go @@ -0,0 +1,11 @@ +package load + +import ( + "context" + + psLoad "github.com/shirou/gopsutil/v4/load" +) + +func GetState(ctx context.Context) (*psLoad.AvgStat, error) { + return psLoad.AvgWithContext(ctx) +} diff --git a/pkg/monitor/monitor.go b/pkg/monitor/monitor.go index b20cfc6f..1790bafa 100644 --- a/pkg/monitor/monitor.go +++ b/pkg/monitor/monitor.go @@ -1,44 +1,28 @@ package monitor import ( - "fmt" - "os/exec" + "context" "runtime" - "sort" - "strconv" - "strings" "sync/atomic" - "syscall" "time" - "github.com/dean2021/goss" - "github.com/shirou/gopsutil/v4/cpu" - "github.com/shirou/gopsutil/v4/disk" "github.com/shirou/gopsutil/v4/host" - "github.com/shirou/gopsutil/v4/load" "github.com/shirou/gopsutil/v4/mem" - "github.com/shirou/gopsutil/v4/net" "github.com/shirou/gopsutil/v4/process" - "github.com/shirou/gopsutil/v4/sensors" "github.com/nezhahq/agent/model" - "github.com/nezhahq/agent/pkg/gpu" + "github.com/nezhahq/agent/pkg/monitor/conn" + "github.com/nezhahq/agent/pkg/monitor/cpu" + "github.com/nezhahq/agent/pkg/monitor/disk" + "github.com/nezhahq/agent/pkg/monitor/gpu" + "github.com/nezhahq/agent/pkg/monitor/load" + "github.com/nezhahq/agent/pkg/monitor/nic" + "github.com/nezhahq/agent/pkg/monitor/temperature" "github.com/nezhahq/agent/pkg/util" ) var ( - Version string - expectDiskFsTypes = []string{ - "apfs", "ext4", "ext3", "ext2", "f2fs", "reiserfs", "jfs", "btrfs", - "fuseblk", "zfs", "simfs", "ntfs", "fat32", "exfat", "xfs", "fuse.rclone", - } - excludeNetInterfaces = []string{ - "lo", "tun", "docker", "veth", "br-", "vmbr", "vnet", "kube", - } - sensorIgnoreList = []string{ - "PMU tcal", // the calibration sensor on arm macs, value is fixed - "noname", - } + Version string agentConfig *model.AgentConfig ) @@ -51,18 +35,25 @@ var ( // 获取设备数据的最大尝试次数 const maxDeviceDataFetchAttempts = 3 +const ( + CPU = iota + 1 + GPU + Load + Temperatures +) + // 获取主机数据的尝试次数,Key 为 Host 的属性名 -var hostDataFetchAttempts = map[string]int{ - "CPU": 0, - "GPU": 0, +var hostDataFetchAttempts = map[uint8]uint8{ + CPU: 0, + GPU: 0, } // 获取状态数据的尝试次数,Key 为 HostState 的属性名 -var statDataFetchAttempts = map[string]int{ - "CPU": 0, - "Load": 0, - "GPU": 0, - "Temperatures": 0, +var statDataFetchAttempts = map[uint8]uint8{ + CPU: 0, + GPU: 0, + Load: 0, + Temperatures: 0, } var ( @@ -95,40 +86,14 @@ func GetHost() *model.Host { ret.BootTime = hi.BootTime } - cpuModelCount := make(map[string]int) - if hostDataFetchAttempts["CPU"] < maxDeviceDataFetchAttempts { - ci, err := cpu.Info() - if err != nil { - hostDataFetchAttempts["CPU"]++ - printf("cpu.Info error: %v, attempt: %d", err, hostDataFetchAttempts["CPU"]) - } else { - hostDataFetchAttempts["CPU"] = 0 - for i := 0; i < len(ci); i++ { - cpuModelCount[ci[i].ModelName]++ - } - for model, count := range cpuModelCount { - if len(ci) > 1 { - ret.CPU = append(ret.CPU, fmt.Sprintf("%s %d %s Core", model, count, cpuType)) - } else { - ret.CPU = append(ret.CPU, fmt.Sprintf("%s %d %s Core", model, ci[0].Cores, cpuType)) - } - } - } - } + ctxCpu := context.WithValue(context.Background(), cpu.CPUHostKey, cpuType) + ret.CPU = tryHost(ctxCpu, CPU, cpu.GetHost) if agentConfig.GPU { - if hostDataFetchAttempts["GPU"] < maxDeviceDataFetchAttempts { - ret.GPU, err = gpu.GetGPUModel() - if err != nil { - hostDataFetchAttempts["GPU"]++ - printf("gpu.GetGPUModel error: %v, attempt: %d", err, hostDataFetchAttempts["GPU"]) - } else { - hostDataFetchAttempts["GPU"] = 0 - } - } + ret.GPU = tryHost(context.Background(), GPU, gpu.GetHost) } - ret.DiskTotal, _ = getDiskTotalAndUsed() + ret.DiskTotal = getDiskTotal() mv, err := mem.VirtualMemory() if err != nil { @@ -160,15 +125,9 @@ func GetHost() *model.Host { func GetState(skipConnectionCount bool, skipProcsCount bool) *model.HostState { var ret model.HostState - if statDataFetchAttempts["CPU"] < maxDeviceDataFetchAttempts { - cp, err := cpu.Percent(0, false) - if err != nil || len(cp) == 0 { - statDataFetchAttempts["CPU"]++ - printf("cpu.Percent error: %v, attempt: %d", err, statDataFetchAttempts["CPU"]) - } else { - statDataFetchAttempts["CPU"] = 0 - ret.CPU = cp[0] - } + cp := tryStat(context.Background(), CPU, cpu.GetState) + if len(cp) > 0 { + ret.CPU = cp[0] } vm, err := mem.VirtualMemory() @@ -190,20 +149,12 @@ func GetState(skipConnectionCount bool, skipProcsCount bool) *model.HostState { } } - _, ret.DiskUsed = getDiskTotalAndUsed() + ret.DiskUsed = getDiskUsed() - if statDataFetchAttempts["Load"] < maxDeviceDataFetchAttempts { - loadStat, err := load.Avg() - if err != nil { - statDataFetchAttempts["Load"]++ - printf("load.Avg error: %v, attempt: %d", err, statDataFetchAttempts["Load"]) - } else { - statDataFetchAttempts["Load"] = 0 - ret.Load1 = loadStat.Load1 - ret.Load5 = loadStat.Load5 - ret.Load15 = loadStat.Load15 - } - } + loadStat := tryStat(context.Background(), Load, load.GetState) + ret.Load1 = loadStat.Load1 + ret.Load5 = loadStat.Load5 + ret.Load15 = loadStat.Load15 var procs []int32 if !skipProcsCount { @@ -220,12 +171,15 @@ func GetState(skipConnectionCount bool, skipProcsCount bool) *model.HostState { ret.Temperatures = temperatureStat } - ret.GPU = updateGPUStat() + ret.GPU = tryStat(context.Background(), GPU, gpu.GetState) ret.NetInTransfer, ret.NetOutTransfer = netInTransfer, netOutTransfer ret.NetInSpeed, ret.NetOutSpeed = netInSpeed, netOutSpeed ret.Uptime = uint64(time.Since(cachedBootTime).Seconds()) - ret.TcpConnCount, ret.UdpConnCount = getConns(skipConnectionCount) + + if !skipConnectionCount { + ret.TcpConnCount, ret.UdpConnCount = getConns() + } return &ret } @@ -233,136 +187,52 @@ func GetState(skipConnectionCount bool, skipProcsCount bool) *model.HostState { // TrackNetworkSpeed NIC监控,统计流量与速度 func TrackNetworkSpeed() { var innerNetInTransfer, innerNetOutTransfer uint64 - nc, err := net.IOCounters(true) - if err == nil { - for _, v := range nc { - if len(agentConfig.NICAllowlist) > 0 { - if !agentConfig.NICAllowlist[v.Name] { - continue - } - } else { - if util.ContainsStr(excludeNetInterfaces, v.Name) { - continue - } - } - innerNetInTransfer += v.BytesRecv - innerNetOutTransfer += v.BytesSent - } - now := uint64(time.Now().Unix()) - diff := now - lastUpdateNetStats - if diff > 0 { - netInSpeed = (innerNetInTransfer - netInTransfer) / diff - netOutSpeed = (innerNetOutTransfer - netOutTransfer) / diff - } - netInTransfer = innerNetInTransfer - netOutTransfer = innerNetOutTransfer - lastUpdateNetStats = now + + ctx := context.WithValue(context.Background(), nic.NICKey, agentConfig.NICAllowlist) + nc, err := nic.GetState(ctx) + if err != nil { + return } -} -func getDiskTotalAndUsed() (total uint64, used uint64) { - devices := make(map[string]string) + innerNetInTransfer = nc[0] + innerNetOutTransfer = nc[1] - if len(agentConfig.HardDrivePartitionAllowlist) > 0 { - // 如果配置了白名单,使用白名单的列表 - for i, v := range agentConfig.HardDrivePartitionAllowlist { - devices[strconv.Itoa(i)] = v - } - } else { - // 否则使用默认过滤规则 - diskList, _ := disk.Partitions(false) - for _, d := range diskList { - fsType := strings.ToLower(d.Fstype) - // 不统计 K8s 的虚拟挂载点:https://github.com/shirou/gopsutil/issues/1007 - if devices[d.Device] == "" && util.ContainsStr(expectDiskFsTypes, fsType) && !strings.Contains(d.Mountpoint, "/var/lib/kubelet") { - devices[d.Device] = d.Mountpoint - } - } + now := uint64(time.Now().Unix()) + diff := now - lastUpdateNetStats + if diff > 0 { + netInSpeed = (innerNetInTransfer - netInTransfer) / diff + netOutSpeed = (innerNetOutTransfer - netOutTransfer) / diff } + netInTransfer = innerNetInTransfer + netOutTransfer = innerNetOutTransfer + lastUpdateNetStats = now +} - for _, mountPath := range devices { - diskUsageOf, err := disk.Usage(mountPath) - if err == nil { - total += diskUsageOf.Total - used += diskUsageOf.Used - } - } +func getDiskTotal() uint64 { + ctx := context.WithValue(context.Background(), disk.DiskKey, agentConfig.HardDrivePartitionAllowlist) + total, _ := disk.GetHost(ctx) - // Fallback 到这个方法,仅统计根路径,适用于OpenVZ之类的. - if runtime.GOOS == "linux" && total == 0 && used == 0 { - cmd := exec.Command("df") - out, err := cmd.CombinedOutput() - if err == nil { - s := strings.Split(string(out), "\n") - for _, c := range s { - info := strings.Fields(c) - if len(info) == 6 { - if info[5] == "/" { - total, _ = strconv.ParseUint(info[1], 0, 64) - used, _ = strconv.ParseUint(info[2], 0, 64) - // 默认获取的是1K块为单位的. - total = total * 1024 - used = used * 1024 - } - } - } - } - } + return total +} + +func getDiskUsed() uint64 { + ctx := context.WithValue(context.Background(), disk.DiskKey, agentConfig.HardDrivePartitionAllowlist) + used, _ := disk.GetState(ctx) - return + return used } -func getConns(skipConnectionCount bool) (tcpConnCount, udpConnCount uint64) { - if !skipConnectionCount { - ss_err := true - if runtime.GOOS == "linux" { - tcpStat, err_tcp := goss.ConnectionsWithProtocol(goss.AF_INET, syscall.IPPROTO_TCP) - udpStat, err_udp := goss.ConnectionsWithProtocol(goss.AF_INET, syscall.IPPROTO_UDP) - if err_tcp == nil && err_udp == nil { - ss_err = false - tcpConnCount = uint64(len(tcpStat)) - udpConnCount = uint64(len(udpStat)) - } - if strings.Contains(CachedIP, ":") { - tcpStat6, err_tcp := goss.ConnectionsWithProtocol(goss.AF_INET6, syscall.IPPROTO_TCP) - udpStat6, err_udp := goss.ConnectionsWithProtocol(goss.AF_INET6, syscall.IPPROTO_UDP) - if err_tcp == nil && err_udp == nil { - ss_err = false - tcpConnCount += uint64(len(tcpStat6)) - udpConnCount += uint64(len(udpStat6)) - } - } - } - if ss_err { - conns, _ := net.Connections("all") - for i := 0; i < len(conns); i++ { - switch conns[i].Type { - case syscall.SOCK_STREAM: - tcpConnCount++ - case syscall.SOCK_DGRAM: - udpConnCount++ - } - } - } +func getConns() (tcpConnCount, udpConnCount uint64) { + connStat, err := conn.GetState(context.Background()) + if err != nil { + return } - return tcpConnCount, udpConnCount -} -func updateGPUStat() []float64 { - if agentConfig.GPU { - if statDataFetchAttempts["GPU"] < maxDeviceDataFetchAttempts { - gs, err := gpu.GetGPUStat() - if err != nil { - statDataFetchAttempts["GPU"]++ - printf("gpustat.GetGPUStat error: %v, attempt: %d", err, statDataFetchAttempts["GPU"]) - return nil - } else { - statDataFetchAttempts["GPU"] = 0 - return gs - } - } + if len(connStat) < 2 { + return } - return nil + + return connStat[0], connStat[1] } func updateTemperatureStat() { @@ -371,30 +241,44 @@ func updateTemperatureStat() { } defer updateTempStatus.Store(0) - if statDataFetchAttempts["Temperatures"] < maxDeviceDataFetchAttempts { - temperatures, err := sensors.SensorsTemperatures() + stat := tryStat(context.Background(), Temperatures, temperature.GetState) + temperatureStat = stat +} + +type hostStateFunc[T any] func(context.Context) (T, error) + +func tryHost[T any](ctx context.Context, typ uint8, f hostStateFunc[T]) T { + var val T + + if hostDataFetchAttempts[typ] < maxDeviceDataFetchAttempts { + v, err := f(ctx) + if err != nil { + hostDataFetchAttempts[typ]++ + printf("monitor error: %v, attempt: %d", err, hostDataFetchAttempts[typ]) + return val + } else { + val = v + hostDataFetchAttempts[typ] = 0 + } + } + return val +} + +func tryStat[T any](ctx context.Context, typ uint8, f hostStateFunc[T]) T { + var val T + + if statDataFetchAttempts[typ] < maxDeviceDataFetchAttempts { + v, err := f(ctx) if err != nil { - statDataFetchAttempts["Temperatures"]++ - printf("host.SensorsTemperatures error: %v, attempt: %d", err, statDataFetchAttempts["Temperatures"]) + statDataFetchAttempts[typ]++ + printf("monitor error: %v, attempt: %d", err, statDataFetchAttempts[typ]) + return val } else { - statDataFetchAttempts["Temperatures"] = 0 - tempStat := []model.SensorTemperature{} - for _, t := range temperatures { - if t.Temperature > 0 && !util.ContainsStr(sensorIgnoreList, t.SensorKey) { - tempStat = append(tempStat, model.SensorTemperature{ - Name: t.SensorKey, - Temperature: t.Temperature, - }) - } - } - - sort.Slice(tempStat, func(i, j int) bool { - return tempStat[i].Name < tempStat[j].Name - }) - - temperatureStat = tempStat + val = v + statDataFetchAttempts[typ] = 0 } } + return val } func printf(format string, v ...interface{}) { diff --git a/pkg/monitor/nic/nic.go b/pkg/monitor/nic/nic.go new file mode 100644 index 00000000..cf350c24 --- /dev/null +++ b/pkg/monitor/nic/nic.go @@ -0,0 +1,45 @@ +package nic + +import ( + "context" + + "github.com/shirou/gopsutil/v4/net" +) + +type NICKeyType string + +const NICKey NICKeyType = "nic" + +var excludeNetInterfaces = map[string]bool{ + "lo": true, + "tun": true, + "docker": true, + "veth": true, + "br-": true, + "vmbr": true, + "vnet": true, + "kube": true, +} + +func GetState(ctx context.Context) ([]uint64, error) { + var netInTransfer, netOutTransfer uint64 + nc, err := net.IOCountersWithContext(ctx, true) + if err != nil { + return nil, err + } + + allowList := excludeNetInterfaces + if m, ok := ctx.Value(NICKey).(map[string]bool); ok && len(m) > 0 { + allowList = m + } + + for _, v := range nc { + if !allowList[v.Name] { + continue + } + netInTransfer += v.BytesRecv + netOutTransfer += v.BytesSent + } + + return []uint64{netInTransfer, netOutTransfer}, nil +} diff --git a/pkg/monitor/temperature/temperature.go b/pkg/monitor/temperature/temperature.go new file mode 100644 index 00000000..64d1f916 --- /dev/null +++ b/pkg/monitor/temperature/temperature.go @@ -0,0 +1,40 @@ +package temperature + +import ( + "context" + "fmt" + "sort" + + "github.com/shirou/gopsutil/v4/sensors" + + "github.com/nezhahq/agent/model" + "github.com/nezhahq/agent/pkg/util" +) + +var sensorIgnoreList = []string{ + "PMU tcal", // the calibration sensor on arm macs, value is fixed + "noname", +} + +func GetState(_ context.Context) ([]model.SensorTemperature, error) { + temperatures, err := sensors.SensorsTemperatures() + if err != nil { + return nil, fmt.Errorf("SensorsTemperatures: %v", err) + } + + var tempStat []model.SensorTemperature + for _, t := range temperatures { + if t.Temperature > 0 && !util.ContainsStr(sensorIgnoreList, t.SensorKey) { + tempStat = append(tempStat, model.SensorTemperature{ + Name: t.SensorKey, + Temperature: t.Temperature, + }) + } + } + + sort.Slice(tempStat, func(i, j int) bool { + return tempStat[i].Name < tempStat[j].Name + }) + + return tempStat, nil +} diff --git a/pkg/processgroup/process_group.go b/pkg/processgroup/process_group.go index 2bda9bdc..49ad65f0 100644 --- a/pkg/processgroup/process_group.go +++ b/pkg/processgroup/process_group.go @@ -1,5 +1,4 @@ //go:build !windows -// +build !windows package processgroup @@ -17,38 +16,39 @@ func NewProcessExitGroup() (ProcessExitGroup, error) { return ProcessExitGroup{}, nil } -func (g *ProcessExitGroup) killChildProcess(c *exec.Cmd) error { - pgid, err := syscall.Getpgid(c.Process.Pid) - if err != nil { - // Fall-back on error. Kill the main process only. - c.Process.Kill() - } - // Kill the whole process group. - syscall.Kill(-pgid, syscall.SIGTERM) - return c.Wait() +func NewCommand(arg string) *exec.Cmd { + cmd := exec.Command("sh", "-c", arg) + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + return cmd } -func (g *ProcessExitGroup) Dispose() []error { - var errors []error - mutex := new(sync.Mutex) - wg := new(sync.WaitGroup) +func (g *ProcessExitGroup) Dispose() error { + var wg sync.WaitGroup wg.Add(len(g.cmds)) + for _, c := range g.cmds { go func(c *exec.Cmd) { defer wg.Done() - if err := g.killChildProcess(c); err != nil { - mutex.Lock() - defer mutex.Unlock() - errors = append(errors, err) - } + killChildProcess(c) }(c) } + wg.Wait() - return errors + return nil } func (g *ProcessExitGroup) AddProcess(cmd *exec.Cmd) error { - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} g.cmds = append(g.cmds, cmd) return nil } + +func killChildProcess(c *exec.Cmd) { + pgid, err := syscall.Getpgid(c.Process.Pid) + if err != nil { + // Fall-back on error. Kill the main process only. + c.Process.Kill() + } + // Kill the whole process group. + syscall.Kill(-pgid, syscall.SIGTERM) + c.Wait() +} diff --git a/pkg/processgroup/process_group_windows.go b/pkg/processgroup/process_group_windows.go index 74f57b10..fb3c4c98 100644 --- a/pkg/processgroup/process_group_windows.go +++ b/pkg/processgroup/process_group_windows.go @@ -5,26 +5,80 @@ package processgroup import ( "fmt" "os/exec" + "unsafe" + + "golang.org/x/sys/windows" ) type ProcessExitGroup struct { - cmds []*exec.Cmd + cmds []*exec.Cmd + jobHandle windows.Handle + procs []windows.Handle } -func NewProcessExitGroup() (ProcessExitGroup, error) { - return ProcessExitGroup{}, nil +func NewProcessExitGroup() (*ProcessExitGroup, error) { + job, err := windows.CreateJobObject(nil, nil) + if err != nil { + return nil, err + } + + info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{ + BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{ + LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, + }, + } + + _, err = windows.SetInformationJobObject( + job, + windows.JobObjectExtendedLimitInformation, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info))) + + return &ProcessExitGroup{jobHandle: job}, nil } -func (g *ProcessExitGroup) Dispose() error { - for _, c := range g.cmds { - if err := exec.Command("taskkill", "/F", "/T", "/PID", fmt.Sprint(c.Process.Pid)).Run(); err != nil { - return err - } +func NewCommand(args string) *exec.Cmd { + cmd := exec.Command("cmd") + cmd.SysProcAttr = &windows.SysProcAttr{ + CmdLine: fmt.Sprintf("/c %s", args), + CreationFlags: windows.CREATE_NEW_PROCESS_GROUP, } - return nil + return cmd } func (g *ProcessExitGroup) AddProcess(cmd *exec.Cmd) error { + proc, err := windows.OpenProcess(windows.PROCESS_TERMINATE|windows.PROCESS_SET_QUOTA|windows.PROCESS_SET_INFORMATION, false, uint32(cmd.Process.Pid)) + if err != nil { + return err + } + + g.procs = append(g.procs, proc) g.cmds = append(g.cmds, cmd) + + return windows.AssignProcessToJobObject(g.jobHandle, proc) +} + +func (g *ProcessExitGroup) Dispose() error { + defer func() { + windows.CloseHandle(g.jobHandle) + for _, proc := range g.procs { + windows.CloseHandle(proc) + } + }() + + if err := windows.TerminateJobObject(g.jobHandle, 1); err != nil { + // Fall-back on error. Kill the main process only. + for _, cmd := range g.cmds { + cmd.Process.Kill() + } + return err + } + + // wait for job to be terminated + status, err := windows.WaitForSingleObject(g.jobHandle, windows.INFINITE) + if status != windows.WAIT_OBJECT_0 { + return err + } + return nil } diff --git a/pkg/pty/pty.go b/pkg/pty/pty.go index 949a5321..853b8a49 100644 --- a/pkg/pty/pty.go +++ b/pkg/pty/pty.go @@ -1,5 +1,4 @@ //go:build !windows -// +build !windows package pty diff --git a/pkg/pty/pty_windows.go b/pkg/pty/pty_windows.go index bba4a44c..18979fb9 100644 --- a/pkg/pty/pty_windows.go +++ b/pkg/pty/pty_windows.go @@ -19,7 +19,7 @@ import ( "github.com/shirou/gopsutil/v4/host" ) -var isWin10 bool +var isWin10 = VersionCheck() type winPTY struct { tty *winpty.WinPTY @@ -29,10 +29,6 @@ type conPty struct { tty *conpty.ConPty } -func init() { - isWin10 = VersionCheck() -} - func VersionCheck() bool { hi, err := host.Info() if err != nil {