diff --git a/cni/network/network.go b/cni/network/network.go index 9dc206a6d..a78ef7b91 100644 --- a/cni/network/network.go +++ b/cni/network/network.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "os" + "regexp" "strconv" "time" @@ -35,6 +36,9 @@ import ( "go.uber.org/zap" ) +// matches if the string fully consists of zero or more alphanumeric, dots, dashes, parentheses, spaces, or underscores +var allowedInput = regexp.MustCompile(`^[a-zA-Z0-9._\-\(\) ]*$`) + const ( dockerNetworkOption = "com.docker.network.generic" OpModeTransparent = "transparent" @@ -408,6 +412,11 @@ func (plugin *NetPlugin) Add(args *cniSkel.CmdArgs) error { return err } + if argErr := plugin.validateArgs(args, nwCfg); argErr != nil { + err = argErr + return err + } + iptables.DisableIPTableLock = nwCfg.DisableIPTableLock plugin.setCNIReportDetails(nwCfg, CNI_ADD, "") @@ -933,6 +942,11 @@ func (plugin *NetPlugin) Get(args *cniSkel.CmdArgs) error { logger.Info("Read network configuration", zap.Any("config", nwCfg)) + if argErr := plugin.validateArgs(args, nwCfg); argErr != nil { + err = argErr + return err + } + iptables.DisableIPTableLock = nwCfg.DisableIPTableLock // Initialize values from network config. @@ -1015,6 +1029,11 @@ func (plugin *NetPlugin) Delete(args *cniSkel.CmdArgs) error { return err } + if argErr := plugin.validateArgs(args, nwCfg); argErr != nil { + err = argErr + return err + } + // Parse Pod arguments. if k8sPodName, k8sNamespace, err = plugin.getPodInfo(args.Args); err != nil { logger.Error("Failed to get POD info", zap.Error(err)) @@ -1206,6 +1225,11 @@ func (plugin *NetPlugin) Update(args *cniSkel.CmdArgs) error { return err } + if argErr := plugin.validateArgs(args, nwCfg); argErr != nil { + err = argErr + return err + } + logger.Info("Read network configuration", zap.Any("config", nwCfg)) iptables.DisableIPTableLock = nwCfg.DisableIPTableLock @@ -1468,3 +1492,14 @@ func convertCniResultToInterfaceInfo(result *cniTypesCurr.Result) network.Interf return interfaceInfo } + +func (plugin *NetPlugin) validateArgs(args *cniSkel.CmdArgs, nwCfg *cni.NetworkConfig) error { + if !allowedInput.MatchString(args.ContainerID) || !allowedInput.MatchString(args.IfName) { + return errors.New("invalid args value") + } + if !allowedInput.MatchString(nwCfg.Bridge) { + return errors.New("invalid network config value") + } + + return nil +} diff --git a/cni/network/network_test.go b/cni/network/network_test.go index 5845225ce..6d19ce673 100644 --- a/cni/network/network_test.go +++ b/cni/network/network_test.go @@ -1484,3 +1484,73 @@ func TestPluginSwiftV2MultipleAddDelete(t *testing.T) { }) } } + +func TestValidateArgs(t *testing.T) { + p, _ := cni.NewPlugin("name", "0.3.0") + plugin := &NetPlugin{ + Plugin: p, + } + + tests := []struct { + name string + args *cniSkel.CmdArgs + nwCfg *cni.NetworkConfig + wantErr bool + }{ + { + name: "Args", + args: &cniSkel.CmdArgs{ + ContainerID: "5419067fa51b3b942bdd1af1ae78ea5f9cabc67ae71c7b5ef57ba8ca1b2386ec", + IfName: "eth0", + }, + nwCfg: &cni.NetworkConfig{ + Bridge: "azure0", + }, + wantErr: false, + }, + { + name: "Args with spaces and special characters", + args: &cniSkel.CmdArgs{ + ContainerID: "test2-container", + IfName: "vEthernet (Ethernet 2)", + }, + nwCfg: &cni.NetworkConfig{ + Bridge: ".-_", + }, + wantErr: false, + }, + { + name: "Empty args", + args: &cniSkel.CmdArgs{ + ContainerID: "", + IfName: "", + }, + nwCfg: &cni.NetworkConfig{ + Bridge: "", + }, + wantErr: false, + }, + { + name: "Invalid args", + args: &cniSkel.CmdArgs{ + ContainerID: "", + IfName: "", + }, + nwCfg: &cni.NetworkConfig{ + Bridge: "\\value/\"", + }, + wantErr: true, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + err := plugin.validateArgs(tt.args, tt.nwCfg) + if tt.wantErr { + require.Error(t, err, "Expected error but did not receive one") + } else { + require.NoError(t, err, "Expected no error but received one") + } + }) + } +} diff --git a/cns/NetworkContainerContract.go b/cns/NetworkContainerContract.go index d11463a83..ae0f3bc5b 100644 --- a/cns/NetworkContainerContract.go +++ b/cns/NetworkContainerContract.go @@ -102,6 +102,7 @@ const ( ) var ErrInvalidNCID = errors.New("invalid NetworkContainerID") +var ErrInvalidIP = errors.New("invalid IP") // CreateNetworkContainerRequest specifies request to create a network container or network isolation boundary. type CreateNetworkContainerRequest struct { @@ -132,9 +133,24 @@ func (req *CreateNetworkContainerRequest) Validate() error { if _, err := uuid.Parse(strings.TrimPrefix(req.NetworkContainerid, SwiftPrefix)); err != nil { return errors.Wrapf(ErrInvalidNCID, "NetworkContainerID %s is not a valid UUID: %s", req.NetworkContainerid, err.Error()) } + if req.PrimaryInterfaceIdentifier != "" && !isValidIP(req.PrimaryInterfaceIdentifier) { + return errors.Wrapf(ErrInvalidIP, "PrimaryInterfaceIdentifier %s is not a valid ip address", req.PrimaryInterfaceIdentifier) + } + if req.IPConfiguration.GatewayIPAddress != "" && !isValidIP(req.IPConfiguration.GatewayIPAddress) { + return errors.Wrapf(ErrInvalidIP, "GatewayIPAddress %s is not a valid ip address", req.IPConfiguration.GatewayIPAddress) + } return nil } +func isValidIP(ipStr string) bool { + // if can parse (i.e. not nil), then valid ip + if ip, _, err := net.ParseCIDR(ipStr); err == nil { + return ip != nil + } + ip := net.ParseIP(ipStr) + return ip != nil +} + // CreateNetworkContainerRequest implements fmt.Stringer for logging func (req *CreateNetworkContainerRequest) String() string { return fmt.Sprintf("CreateNetworkContainerRequest"+ diff --git a/cns/NetworkContainerContract_test.go b/cns/NetworkContainerContract_test.go index 88f7e43f7..ab0197e78 100644 --- a/cns/NetworkContainerContract_test.go +++ b/cns/NetworkContainerContract_test.go @@ -157,7 +157,15 @@ func TestPostNetworkContainersRequest_Validate(t *testing.T) { NetworkContainerid: "f47ac10b-58cc-0372-8567-0e02b2c3d479", }, { - NetworkContainerid: "f47ac10b-58cc-0372-8567-0e02b2c3d478", + NetworkContainerid: "f47ac10b-58cc-0372-8567-0e02b2c3d478", + PrimaryInterfaceIdentifier: "10.240.0.4", + IPConfiguration: IPConfiguration{ + GatewayIPAddress: "10.0.0.1", + }, + }, + { + NetworkContainerid: "a47ac10b-58cc-0372-8567-0e02b2c3d478", + PrimaryInterfaceIdentifier: "10.240.0.4/24", }, }, }, @@ -191,6 +199,36 @@ func TestPostNetworkContainersRequest_Validate(t *testing.T) { }, wantErr: true, }, + { + name: "invalid", + req: PostNetworkContainersRequest{ + CreateNetworkContainerRequests: []CreateNetworkContainerRequest{ + { + NetworkContainerid: "f47ac10b-58cc-0372-8567-0e02b2c3d478", + PrimaryInterfaceIdentifier: "10.240.0.4", + IPConfiguration: IPConfiguration{ + GatewayIPAddress: "10.0.0.1;", + }, + }, + }, + }, + wantErr: true, + }, + { + name: "invalid", + req: PostNetworkContainersRequest{ + CreateNetworkContainerRequests: []CreateNetworkContainerRequest{ + { + NetworkContainerid: "f47ac10b-58cc-0372-8567-0e02b2c3d478", + PrimaryInterfaceIdentifier: "-10.240.0.4", + IPConfiguration: IPConfiguration{ + GatewayIPAddress: "10.0.0.1", + }, + }, + }, + }, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {