cli/opts/gpus.go

112 строки
2.3 KiB
Go

package opts
import (
"encoding/csv"
"fmt"
"strconv"
"strings"
"github.com/docker/docker/api/types/container"
"github.com/pkg/errors"
)
// GpuOpts is a Value type for parsing GPU device requests.
// Each successful call to Set appends one container.DeviceRequest to values.
type GpuOpts struct {
// values holds the requests appended by Set, in call order.
values []container.DeviceRequest
}
// parseCount converts the count portion of a GPU request into the integer
// that Set stores in the request's Count field.
//
// The literal "all" maps to -1; any other input must be an integer accepted
// by strconv.Atoi. On failure it returns 0 together with an error whose
// message starts with "count must be an integer: " and which wraps the
// underlying strconv error.
func parseCount(s string) (int, error) {
	if s == "all" {
		return -1, nil
	}
	i, err := strconv.Atoi(s)
	if err != nil {
		// Atoi hands back a clamped value on range errors; never return it
		// alongside an error.
		return 0, fmt.Errorf("count must be an integer: %w", err)
	}
	return i, nil
}
// Set parses one GPU request specification and appends the resulting
// container.DeviceRequest to the option's values.
//
// value is read as a single CSV record whose fields are either key=value
// pairs or a bare count. Recognised keys are "driver", "count", "device",
// "capabilities" and "options"; a bare field such as "all" or "2" is
// shorthand for count=<field>. Each key may appear at most once, and the
// bare form shares the "count" key, so "count=1,all" and "1,2" are rejected
// instead of silently keeping the last count.
//
// Defaults applied after parsing: Count is 1 when no count was given and no
// device IDs were set, Options is an empty non-nil map, and Capabilities is
// [["gpu"]]. An explicit capabilities list always has "gpu" appended.
//
// It returns an error when the CSV record cannot be read, when a key is
// unknown or repeated, or when parseCount rejects the count; in those cases
// no request is appended.
//
//nolint:gocyclo
func (o *GpuOpts) Set(value string) error {
	csvReader := csv.NewReader(strings.NewReader(value))
	fields, err := csvReader.Read()
	if err != nil {
		return err
	}

	req := container.DeviceRequest{}
	seen := map[string]struct{}{}
	for _, field := range fields {
		key, val, withValue := strings.Cut(field, "=")
		if !withValue {
			// A bare field is shorthand for count=<field>. Fold it into the
			// "count" key so the duplicate check covers both spellings.
			key, val = "count", field
		}
		if _, ok := seen[key]; ok {
			return fmt.Errorf("gpu request key '%s' can be specified only once", key)
		}
		seen[key] = struct{}{}

		switch key {
		case "driver":
			req.Driver = val
		case "count":
			req.Count, err = parseCount(val)
			if err != nil {
				return err
			}
		case "device":
			req.DeviceIDs = strings.Split(val, ",")
		case "capabilities":
			req.Capabilities = [][]string{append(strings.Split(val, ","), "gpu")}
		case "options":
			// The options value is itself a CSV record of key=value strings.
			r := csv.NewReader(strings.NewReader(val))
			optFields, optErr := r.Read()
			if optErr != nil {
				return errors.Wrap(optErr, "failed to read gpu options")
			}
			req.Options = ConvertKVStringsToMap(optFields)
		default:
			return fmt.Errorf("unexpected key '%s' in '%s'", key, field)
		}
	}

	// With neither an explicit count nor explicit device IDs, request one GPU.
	if _, ok := seen["count"]; !ok && req.DeviceIDs == nil {
		req.Count = 1
	}
	if req.Options == nil {
		req.Options = make(map[string]string)
	}
	if req.Capabilities == nil {
		req.Capabilities = [][]string{{"gpu"}}
	}

	o.values = append(o.values, req)
	return nil
}
// Type reports the name of this option's value type.
func (o *GpuOpts) Type() string {
	const typeName = "gpu-request"
	return typeName
}
// String renders every stored request with fmt's default formatting and
// joins the results with ", ". It yields "" when nothing has been stored.
func (o *GpuOpts) String() string {
	rendered := make([]string, len(o.values))
	for i := range o.values {
		rendered[i] = fmt.Sprint(o.values[i])
	}
	return strings.Join(rendered, ", ")
}
// Value returns the device requests accumulated by Set, in the order they
// were appended. The returned slice is o.values itself, not a copy.
func (o *GpuOpts) Value() []container.DeviceRequest {
return o.values
}