diff --git a/aitelemetry/telemetrywrapper_test.go b/aitelemetry/telemetrywrapper_test.go index 08e63c8dc..2ccb17521 100644 --- a/aitelemetry/telemetrywrapper_test.go +++ b/aitelemetry/telemetrywrapper_test.go @@ -32,12 +32,12 @@ func TestMain(m *testing.M) { p := platform.NewExecClient(nil) if runtime.GOOS == "linux" { //nolint:errcheck // initial test setup - p.ExecuteCommand("cp metadata_test.json /tmp/azuremetadata.json") + p.ExecuteRawCommand("cp metadata_test.json /tmp/azuremetadata.json") } else { metadataFile := filepath.FromSlash(os.Getenv("TEMP")) + "\\azuremetadata.json" cmd := fmt.Sprintf("copy metadata_test.json %s", metadataFile) //nolint:errcheck // initial test setup - p.ExecuteCommand(cmd) + p.ExecuteRawCommand(cmd) } hostu, _ := url.Parse("tcp://" + hostAgentUrl) @@ -58,12 +58,12 @@ func TestMain(m *testing.M) { if runtime.GOOS == "linux" { //nolint:errcheck // test cleanup - p.ExecuteCommand("rm /tmp/azuremetadata.json") + p.ExecuteRawCommand("rm /tmp/azuremetadata.json") } else { metadataFile := filepath.FromSlash(os.Getenv("TEMP")) + "\\azuremetadata.json" cmd := fmt.Sprintf("del %s", metadataFile) //nolint:errcheck // initial test cleanup - p.ExecuteCommand(cmd) + p.ExecuteRawCommand(cmd) } log.Close() diff --git a/cns/dockerclient/dockerclient.go b/cns/dockerclient/dockerclient.go index 2e4686d16..b1a8a3f47 100644 --- a/cns/dockerclient/dockerclient.go +++ b/cns/dockerclient/dockerclient.go @@ -181,7 +181,7 @@ func (c *Client) DeleteNetwork(networkName string) error { cmd := fmt.Sprintf("iptables -t nat -D POSTROUTING -m iprange ! --dst-range 168.63.129.16 -m addrtype ! --dst-type local ! -d %v -j MASQUERADE", primaryNic.Subnet) - _, err = p.ExecuteCommand(cmd) + _, err = p.ExecuteRawCommand(cmd) if err != nil { logger.Printf("[Azure CNS] Error Removing Outbound SNAT rule %v", err) } diff --git a/ebtables/ebtables.go b/ebtables/ebtables.go index 44340701d..0f1013c51 100644 --- a/ebtables/ebtables.go +++ b/ebtables/ebtables.go @@ -133,7 +133,7 @@ func GetEbtableRules(tableName, chainName string) ([]string, error) { command := fmt.Sprintf( "ebtables -t %s -L %s --Lmac2", tableName, chainName) - out, err := p.ExecuteCommand(command) + out, err := p.ExecuteRawCommand(command) if err != nil { return nil, err } @@ -228,7 +228,7 @@ func EbTableRuleExists(tableName, chainName, matchSet string) (bool, error) { func runEbCmd(table, action, chain, rule string) error { p := platform.NewExecClient(nil) command := fmt.Sprintf("ebtables -t %s %s %s %s", table, action, chain, rule) - _, err := p.ExecuteCommand(command) + _, err := p.ExecuteRawCommand(command) return err } diff --git a/iptables/iptables.go b/iptables/iptables.go index 81fb58f3c..76dec97fe 100644 --- a/iptables/iptables.go +++ b/iptables/iptables.go @@ -109,7 +109,7 @@ func (c *Client) RunCmd(version, params string) error { cmd = fmt.Sprintf("%s -w %d %s", iptCmd, lockTimeout, params) } - if _, err := p.ExecuteCommand(cmd); err != nil { + if _, err := p.ExecuteRawCommand(cmd); err != nil { return err } diff --git a/network/network_linux.go b/network/network_linux.go index 7f688d2b2..21eaaca7e 100644 --- a/network/network_linux.go +++ b/network/network_linux.go @@ -256,7 +256,7 @@ func isGreaterOrEqaulUbuntuVersion(versionToMatch int) bool { } func (nm *networkManager) systemVersion() (string, error) { - osVersion, err := nm.plClient.ExecuteCommand("lsb_release -rs") + osVersion, err := nm.plClient.ExecuteRawCommand("lsb_release -rs") if err != nil { return osVersion, errors.Wrap(err, "error retrieving the system distribution version") } @@ -327,7 +327,7 @@ func (nm *networkManager) readDNSInfo(ifName string) (DNSInfo, error) { return dnsInfo, errors.Wrap(err, "Error generating interface name status cmd") } - out, err := nm.plClient.ExecuteCommand(cmd) + out, err := nm.plClient.ExecuteRawCommand(cmd) if err != nil { return dnsInfo, errors.Wrapf(err, "Error executing interface status with cmd %s", cmd) } @@ -434,7 +434,7 @@ func (nm *networkManager) applyDNSConfig(extIf *externalInterface, ifName string return errors.Wrap(err, "Error generating add DNS Servers cmd") } if cmd != "" { - _, err = nm.plClient.ExecuteCommand(cmd) + _, err = nm.plClient.ExecuteRawCommand(cmd) if err != nil { return errors.Wrapf(err, "Error executing add DNS Servers with cmd %s", cmd) } @@ -447,7 +447,7 @@ func (nm *networkManager) applyDNSConfig(extIf *externalInterface, ifName string return errors.Wrap(err, "Error generating add domain cmd") } - _, err = nm.plClient.ExecuteCommand(cmd) + _, err = nm.plClient.ExecuteRawCommand(cmd) if err != nil { return errors.Wrapf(err, "Error executing add Domain with cmd %s", cmd) } @@ -533,7 +533,7 @@ func (nm *networkManager) connectExternalInterface(extIf *externalInterface, nwI isSystemdResolvedActive := false if isGreaterOrEqualUbuntu17 { // Don't copy dns servers if systemd-resolved isn't available - if _, cmderr := nm.plClient.ExecuteCommand("systemctl status systemd-resolved"); cmderr == nil { + if _, cmderr := nm.plClient.ExecuteRawCommand("systemctl status systemd-resolved"); cmderr == nil { isSystemdResolvedActive = true logger.Info("Saving dns config from", zap.String("Name", hostIf.Name)) if err = nm.saveDNSConfig(extIf); err != nil { diff --git a/network/network_windows.go b/network/network_windows.go index 9b445503b..f217a6c0c 100644 --- a/network/network_windows.go +++ b/network/network_windows.go @@ -214,13 +214,13 @@ func (nm *networkManager) appIPV6RouteEntry(nwInfo *EndpointInfo) error { cmd := fmt.Sprintf(routeCmd, "delete", nwInfo.Subnets[1].Prefix.String(), ifName, ipv6DefaultHop) - if out, err = nm.plClient.ExecuteCommand(cmd); err != nil { + if out, err = nm.plClient.ExecuteRawCommand(cmd); err != nil { logger.Error("Deleting ipv6 route failed", zap.Any("out", out), zap.Error(err)) } cmd = fmt.Sprintf(routeCmd, "add", nwInfo.Subnets[1].Prefix.String(), ifName, ipv6DefaultHop) - if out, err = nm.plClient.ExecuteCommand(cmd); err != nil { + if out, err = nm.plClient.ExecuteRawCommand(cmd); err != nil { logger.Error("Adding ipv6 route failed", zap.Any("out", out), zap.Error(err)) } } diff --git a/network/networkutils/networkutils_linux.go b/network/networkutils/networkutils_linux.go index 87cc35547..9fbd90657 100644 --- a/network/networkutils/networkutils_linux.go +++ b/network/networkutils/networkutils_linux.go @@ -209,7 +209,7 @@ func (nu NetworkUtils) BlockIPAddresses(iptablesClient ipTablesClient, bridgeNam } func (nu NetworkUtils) EnableIPV4Forwarding() error { - _, err := nu.plClient.ExecuteCommand(enableIPV4ForwardCmd) + _, err := nu.plClient.ExecuteRawCommand(enableIPV4ForwardCmd) if err != nil { logger.Error("Enable ipv4 forwarding failed with", zap.Error(err)) return errors.Wrap(err, "enable ipv4 forwarding failed") @@ -220,7 +220,7 @@ func (nu NetworkUtils) EnableIPV4Forwarding() error { func (nu NetworkUtils) EnableIPV6Forwarding() error { cmd := fmt.Sprint(enableIPV6ForwardCmd) - _, err := nu.plClient.ExecuteCommand(cmd) + _, err := nu.plClient.ExecuteRawCommand(cmd) if err != nil { logger.Error("Enable ipv6 forwarding failed with", zap.Error(err)) return err @@ -233,7 +233,7 @@ func (nu NetworkUtils) EnableIPV6Forwarding() error { func (nu NetworkUtils) UpdateIPV6Setting(disable int) error { // sysctl -w net.ipv6.conf.all.disable_ipv6=0/1 cmd := fmt.Sprintf(toggleIPV6Cmd, disable) - _, err := nu.plClient.ExecuteCommand(cmd) + _, err := nu.plClient.ExecuteRawCommand(cmd) if err != nil { logger.Error("Update IPV6 Setting failed with", zap.Error(err)) } @@ -261,7 +261,7 @@ func (nu NetworkUtils) DisableRAForInterface(ifName string) error { } cmd := fmt.Sprintf(disableRACmd, ifName) - out, err := nu.plClient.ExecuteCommand(cmd) + out, err := nu.plClient.ExecuteRawCommand(cmd) if err != nil { logger.Error("Diabling ra failed with", zap.Error(err), zap.Any("out", out)) } @@ -271,7 +271,7 @@ func (nu NetworkUtils) DisableRAForInterface(ifName string) error { func (nu NetworkUtils) SetProxyArp(ifName string) error { cmd := fmt.Sprintf("echo 1 > /proc/sys/net/ipv4/conf/%v/proxy_arp", ifName) - _, err := nu.plClient.ExecuteCommand(cmd) + _, err := nu.plClient.ExecuteRawCommand(cmd) return errors.Wrapf(err, "failed to set proxy arp for interface %v", ifName) } diff --git a/network/snat/snat_linux.go b/network/snat/snat_linux.go index 8150161d7..936ce8ef8 100644 --- a/network/snat/snat_linux.go +++ b/network/snat/snat_linux.go @@ -469,7 +469,7 @@ func (client *Client) addMasqueradeRule(snatBridgeIPWithPrefix string) error { // Drop all vlan traffic on linux bridge func (client *Client) addVlanDropRule() error { - out, err := client.plClient.ExecuteCommand(l2PreroutingEntries) + out, err := client.plClient.ExecuteRawCommand(l2PreroutingEntries) if err != nil { logger.Error("Error while listing ebtable rules") return err @@ -482,7 +482,7 @@ func (client *Client) addVlanDropRule() error { } logger.Info("Adding ebtable rule to drop vlan traffic on snat bridge", zap.String("vlanDropAddRule", vlanDropAddRule)) - _, err = client.plClient.ExecuteCommand(vlanDropAddRule) + _, err = client.plClient.ExecuteRawCommand(vlanDropAddRule) return err } @@ -490,7 +490,7 @@ func (client *Client) addVlanDropRule() error { func (client *Client) EnableIPForwarding() error { // Enable ip forwading on linux vm. // sysctl -w net.ipv4.ip_forward=1 - _, err := client.plClient.ExecuteCommand(enableIPForwardCmd) + _, err := client.plClient.ExecuteRawCommand(enableIPForwardCmd) if err != nil { return errors.Wrap(err, "enable ipforwarding command failed") } diff --git a/network/transparent_endpointclient_linux.go b/network/transparent_endpointclient_linux.go index db4935ace..08135a41e 100644 --- a/network/transparent_endpointclient_linux.go +++ b/network/transparent_endpointclient_linux.go @@ -73,7 +73,7 @@ func NewTransparentEndpointClient( func (client *TransparentEndpointClient) setArpProxy(ifName string) error { cmd := fmt.Sprintf("echo 1 > /proc/sys/net/ipv4/conf/%v/proxy_arp", ifName) - _, err := client.plClient.ExecuteCommand(cmd) + _, err := client.plClient.ExecuteRawCommand(cmd) return err } diff --git a/network/transparent_vlan_endpointclient_linux.go b/network/transparent_vlan_endpointclient_linux.go index 7c6d5d286..731353c23 100644 --- a/network/transparent_vlan_endpointclient_linux.go +++ b/network/transparent_vlan_endpointclient_linux.go @@ -371,12 +371,12 @@ func (client *TransparentVlanEndpointClient) PopulateVnet(epInfo *EndpointInfo) } client.vnetMac = vnetVethIf.HardwareAddr // Disable rp filter again to allow asymmetric routing for tunneling packets - _, err = client.plClient.ExecuteCommand(DisableRPFilterCmd) + _, err = client.plClient.ExecuteRawCommand(DisableRPFilterCmd) if err != nil { return errors.Wrap(err, "transparent vlan failed to disable rp filter in vnet") } disableRPFilterVlanIfCmd := strings.Replace(DisableRPFilterCmd, "all", client.vlanIfName, 1) - _, err = client.plClient.ExecuteCommand(disableRPFilterVlanIfCmd) + _, err = client.plClient.ExecuteRawCommand(disableRPFilterVlanIfCmd) if err != nil { return errors.Wrap(err, "transparent vlan failed to disable rp filter vlan interface in vnet") } diff --git a/ovsctl/ovsctl.go b/ovsctl/ovsctl.go index 24ecd7f0a..2d99e77f7 100644 --- a/ovsctl/ovsctl.go +++ b/ovsctl/ovsctl.go @@ -65,7 +65,7 @@ func (o Ovsctl) CreateOVSBridge(bridgeName string) error { logger.Info("Creating OVS Bridge", zap.String("name", bridgeName)) ovsCreateCmd := fmt.Sprintf("ovs-vsctl add-br %s", bridgeName) - _, err := o.execcli.ExecuteCommand(ovsCreateCmd) + _, err := o.execcli.ExecuteRawCommand(ovsCreateCmd) if err != nil { logger.Error("Error while creating OVS bridge", zap.Error(err)) return newErrorOvsctl(err.Error()) @@ -78,7 +78,7 @@ func (o Ovsctl) DeleteOVSBridge(bridgeName string) error { logger.Info("Deleting OVS Bridge", zap.String("name", bridgeName)) ovsCreateCmd := fmt.Sprintf("ovs-vsctl del-br %s", bridgeName) - _, err := o.execcli.ExecuteCommand(ovsCreateCmd) + _, err := o.execcli.ExecuteRawCommand(ovsCreateCmd) if err != nil { logger.Error("Error while deleting OVS bridge", zap.Error(err)) return newErrorOvsctl(err.Error()) @@ -89,7 +89,7 @@ func (o Ovsctl) DeleteOVSBridge(bridgeName string) error { func (o Ovsctl) AddPortOnOVSBridge(hostIfName, bridgeName string, vlanID int) error { cmd := fmt.Sprintf("ovs-vsctl add-port %s %s", bridgeName, hostIfName) - _, err := o.execcli.ExecuteCommand(cmd) + _, err := o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Error while setting OVS as master to primary interface", zap.Error(err)) return newErrorOvsctl(err.Error()) @@ -100,7 +100,7 @@ func (o Ovsctl) AddPortOnOVSBridge(hostIfName, bridgeName string, vlanID int) er func (o Ovsctl) GetOVSPortNumber(interfaceName string) (string, error) { cmd := fmt.Sprintf("ovs-vsctl get Interface %s ofport", interfaceName) - ofport, err := o.execcli.ExecuteCommand(cmd) + ofport, err := o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Get ofport failed with", zap.Error(err)) return "", newErrorOvsctl(err.Error()) @@ -111,7 +111,7 @@ func (o Ovsctl) GetOVSPortNumber(interfaceName string) (string, error) { func (o Ovsctl) AddVMIpAcceptRule(bridgeName, primaryIP, mac string) error { cmd := fmt.Sprintf("ovs-ofctl add-flow %s ip,nw_dst=%s,dl_dst=%s,priority=%d,actions=normal", bridgeName, primaryIP, mac, high) - _, err := o.execcli.ExecuteCommand(cmd) + _, err := o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Adding SNAT rule failed with", zap.Error(err)) return newErrorOvsctl(err.Error()) @@ -123,7 +123,7 @@ func (o Ovsctl) AddVMIpAcceptRule(bridgeName, primaryIP, mac string) error { func (o Ovsctl) AddArpSnatRule(bridgeName, mac, macHex, ofport string) error { cmd := fmt.Sprintf(`ovs-ofctl add-flow %v table=1,priority=%d,arp,arp_op=1,actions='mod_dl_src:%s, load:0x%s->NXM_NX_ARP_SHA[],output:%s'`, bridgeName, low, mac, macHex, ofport) - _, err := o.execcli.ExecuteCommand(cmd) + _, err := o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Adding ARP SNAT rule failed with", zap.Error(err)) return newErrorOvsctl(err.Error()) @@ -149,7 +149,7 @@ func (o Ovsctl) AddIPSnatRule(bridgeName string, ip net.IP, vlanID int, port, ma cmd = fmt.Sprintf("%s,strip_vlan,%v", commonPrefix, outport) } - _, err := o.execcli.ExecuteCommand(cmd) + _, err := o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Adding IP SNAT rule failed with", zap.Error(err)) return newErrorOvsctl(err.Error()) @@ -158,7 +158,7 @@ func (o Ovsctl) AddIPSnatRule(bridgeName string, ip net.IP, vlanID int, port, ma // Drop other packets which doesn't satisfy above condition cmd = fmt.Sprintf("ovs-ofctl add-flow %v priority=%d,ip,in_port=%s,actions=drop", bridgeName, low, port) - _, err = o.execcli.ExecuteCommand(cmd) + _, err = o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Dropping vlantag packet rule failed with", zap.Error(err)) return newErrorOvsctl(err.Error()) @@ -171,7 +171,7 @@ func (o Ovsctl) AddArpDnatRule(bridgeName, port, mac string) error { // Add DNAT rule to forward ARP replies to container interfaces. cmd := fmt.Sprintf(`ovs-ofctl add-flow %s arp,arp_op=2,in_port=%s,actions='mod_dl_dst:ff:ff:ff:ff:ff:ff, load:0x%s->NXM_NX_ARP_THA[],normal'`, bridgeName, port, mac) - _, err := o.execcli.ExecuteCommand(cmd) + _, err := o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Adding DNAT rule failed with", zap.Error(err)) return newErrorOvsctl(err.Error()) @@ -191,7 +191,7 @@ func (o Ovsctl) AddFakeArpReply(bridgeName string, ip net.IP) error { move:NXM_NX_ARP_SHA[]->NXM_NX_ARP_THA[],move:NXM_OF_ARP_TPA[]->NXM_OF_ARP_SPA[], load:0x%s->NXM_NX_ARP_SHA[],load:0x%x->NXM_OF_ARP_TPA[],IN_PORT'`, bridgeName, high, defaultMacForArpResponse, macAddrHex, ipAddrInt) - _, err := o.execcli.ExecuteCommand(cmd) + _, err := o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("[ovs] Adding ARP reply rule failed with", zap.Error(err)) return newErrorOvsctl(err.Error()) @@ -207,7 +207,7 @@ func (o Ovsctl) AddArpReplyRule(bridgeName, port string, ip net.IP, mac string, logger.Info("Adding ARP reply rule to add vlan and forward packet to table 1 for port", zap.Int("vlanid", vlanid), zap.String("port", port)) cmd := fmt.Sprintf(`ovs-ofctl add-flow %s arp,arp_op=1,in_port=%s,actions='mod_vlan_vid:%v,resubmit(,1)'`, bridgeName, port, vlanid) - _, err := o.execcli.ExecuteCommand(cmd) + _, err := o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Adding ARP reply rule failed with", zap.Error(err)) return newErrorOvsctl(err.Error()) @@ -220,7 +220,7 @@ func (o Ovsctl) AddArpReplyRule(bridgeName, port string, ip net.IP, mac string, move:NXM_NX_ARP_SHA[]->NXM_NX_ARP_THA[],move:NXM_OF_ARP_SPA[]->NXM_OF_ARP_TPA[], load:0x%s->NXM_NX_ARP_SHA[],load:0x%x->NXM_OF_ARP_SPA[],strip_vlan,IN_PORT'`, bridgeName, ip.String(), vlanid, high, mac, macAddrHex, ipAddrInt) - _, err = o.execcli.ExecuteCommand(cmd) + _, err = o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Adding ARP reply rule failed with", zap.Error(err)) return newErrorOvsctl(err.Error()) @@ -241,7 +241,7 @@ func (o Ovsctl) AddMacDnatRule(bridgeName, port string, ip net.IP, mac string, v } else { cmd = fmt.Sprintf("%s,actions=mod_dl_dst:%s,strip_vlan,%s", commonPrefix, mac, containerPort) } - _, err := o.execcli.ExecuteCommand(cmd) + _, err := o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Adding MAC DNAT rule failed with", zap.Error(err)) return newErrorOvsctl(err.Error()) @@ -253,14 +253,14 @@ func (o Ovsctl) AddMacDnatRule(bridgeName, port string, ip net.IP, mac string, v func (o Ovsctl) DeleteArpReplyRule(bridgeName, port string, ip net.IP, vlanid int) { cmd := fmt.Sprintf("ovs-ofctl del-flows %s arp,arp_op=1,in_port=%s", bridgeName, port) - _, err := o.execcli.ExecuteCommand(cmd) + _, err := o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Deleting ARP reply rule failed with", zap.Error(err)) } cmd = fmt.Sprintf("ovs-ofctl del-flows %s table=1,arp,arp_tpa=%s,dl_vlan=%v,arp_op=1", bridgeName, ip.String(), vlanid) - _, err = o.execcli.ExecuteCommand(cmd) + _, err = o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Deleting ARP reply rule failed with", zap.Error(err)) } @@ -269,7 +269,7 @@ func (o Ovsctl) DeleteArpReplyRule(bridgeName, port string, ip net.IP, vlanid in func (o Ovsctl) DeleteIPSnatRule(bridgeName, port string) { cmd := fmt.Sprintf("ovs-ofctl del-flows %v ip,in_port=%s", bridgeName, port) - _, err := o.execcli.ExecuteCommand(cmd) + _, err := o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Error while deleting ovs rule", zap.String("cmd", cmd), zap.Error(err)) } @@ -286,7 +286,7 @@ func (o Ovsctl) DeleteMacDnatRule(bridgeName, port string, ip net.IP, vlanid int bridgeName, ip.String(), port) } - _, err := o.execcli.ExecuteCommand(cmd) + _, err := o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Deleting MAC DNAT rule failed with", zap.Error(err)) } @@ -295,7 +295,7 @@ func (o Ovsctl) DeleteMacDnatRule(bridgeName, port string, ip net.IP, vlanid int func (o Ovsctl) DeletePortFromOVS(bridgeName, interfaceName string) error { // Disconnect external interface from its bridge. cmd := fmt.Sprintf("ovs-vsctl del-port %s %s", bridgeName, interfaceName) - _, err := o.execcli.ExecuteCommand(cmd) + _, err := o.execcli.ExecuteRawCommand(cmd) if err != nil { logger.Error("Failed to disconnect interface", zap.String("from", interfaceName), zap.Error(err)) return newErrorOvsctl(err.Error()) diff --git a/platform/mockexec.go b/platform/mockexec.go index c0471983d..c390117a4 100644 --- a/platform/mockexec.go +++ b/platform/mockexec.go @@ -8,12 +8,14 @@ import ( type MockExecClient struct { returnError bool + setExecRawCommand execRawCommandValidator setExecCommand execCommandValidator powershellCommandResponder powershellCommandResponder } type ( - execCommandValidator func(string) (string, error) + execRawCommandValidator func(string) (string, error) + execCommandValidator func(string, ...string) (string, error) powershellCommandResponder func(string) (string, error) ) @@ -26,9 +28,9 @@ func NewMockExecClient(returnErr bool) *MockExecClient { } } -func (e *MockExecClient) ExecuteCommand(cmd string) (string, error) { - if e.setExecCommand != nil { - return e.setExecCommand(cmd) +func (e *MockExecClient) ExecuteRawCommand(cmd string) (string, error) { + if e.setExecRawCommand != nil { + return e.setExecRawCommand(cmd) } if e.returnError { @@ -38,6 +40,22 @@ func (e *MockExecClient) ExecuteCommand(cmd string) (string, error) { return "", nil } +func (e *MockExecClient) ExecuteCommand(_ context.Context, cmd string, args ...string) (string, error) { + if e.setExecCommand != nil { + return e.setExecCommand(cmd, args...) + } + + if e.returnError { + return "", ErrMockExec + } + + return "", nil +} + +func (e *MockExecClient) SetExecRawCommand(fn execRawCommandValidator) { + e.setExecRawCommand = fn +} + func (e *MockExecClient) SetExecCommand(fn execCommandValidator) { e.setExecCommand = fn } diff --git a/platform/osInterface.go b/platform/osInterface.go index a70f6203d..9cac6029c 100644 --- a/platform/osInterface.go +++ b/platform/osInterface.go @@ -18,7 +18,8 @@ type execClient struct { //nolint:revive // ExecClient make sense type ExecClient interface { - ExecuteCommand(command string) (string, error) + ExecuteRawCommand(command string) (string, error) + ExecuteCommand(ctx context.Context, command string, args ...string) (string, error) GetLastRebootTime() (time.Time, error) ClearNetworkConfiguration() (bool, error) ExecutePowershellCommand(command string) (string, error) diff --git a/platform/os_linux.go b/platform/os_linux.go index e6c219728..9e659b06a 100644 --- a/platform/os_linux.go +++ b/platform/os_linux.go @@ -13,6 +13,7 @@ import ( "time" "github.com/Azure/azure-container-networking/log" + "github.com/pkg/errors" "go.uber.org/zap" ) @@ -56,7 +57,7 @@ func GetOSInfo() string { func GetProcessSupport() error { p := NewExecClient(nil) cmd := fmt.Sprintf("ps -p %v -o comm=", os.Getpid()) - _, err := p.ExecuteCommand(cmd) + _, err := p.ExecuteRawCommand(cmd) return err } @@ -88,7 +89,8 @@ func (p *execClient) GetLastRebootTime() (time.Time, error) { return rebootTime.UTC(), nil } -func (p *execClient) ExecuteCommand(command string) (string, error) { +// Deprecated: ExecuteRawCommand is deprecated, it is recommended to use ExecuteCommand when possible +func (p *execClient) ExecuteRawCommand(command string) (string, error) { if p.logger != nil { p.logger.Info("[Azure-Utils]", zap.String("command", command)) } else { @@ -114,11 +116,38 @@ func (p *execClient) ExecuteCommand(command string) (string, error) { return out.String(), nil } +// ExecuteCommand passes its parameters to an exec.CommandContext, runs the command, and returns its output, or an error if the command fails or times out +func (p *execClient) ExecuteCommand(ctx context.Context, command string, args ...string) (string, error) { + if p.logger != nil { + p.logger.Info("[Azure-Utils]", zap.String("command", command), zap.Strings("args", args)) + } else { + log.Printf("[Azure-Utils] %s %v", command, args) + } + + var stderr bytes.Buffer + var out bytes.Buffer + + // Create a new context and add a timeout to it + derivedCtx, cancel := context.WithTimeout(ctx, p.Timeout) + defer cancel() // The cancel should be deferred so resources are cleaned up + + cmd := exec.CommandContext(derivedCtx, command, args...) + cmd.Stderr = &stderr + cmd.Stdout = &out + + err := cmd.Run() + if err != nil { + return "", errors.Wrapf(err, "%s:%s", err.Error(), stderr.String()) + } + + return out.String(), nil +} + func SetOutboundSNAT(subnet string) error { p := NewExecClient(nil) cmd := fmt.Sprintf("iptables -t nat -A POSTROUTING -m iprange ! --dst-range 168.63.129.16 -m addrtype ! --dst-type local ! -d %v -j MASQUERADE", subnet) - _, err := p.ExecuteCommand(cmd) + _, err := p.ExecuteRawCommand(cmd) if err != nil { log.Printf("SNAT Iptable rule was not set") return err @@ -132,17 +161,19 @@ func (p *execClient) ClearNetworkConfiguration() (bool, error) { return false, nil } +// not supported on linux func (p *execClient) ExecutePowershellCommand(_ string) (string, error) { return "", nil } +// not supported on linux func (p *execClient) ExecutePowershellCommandWithContext(_ context.Context, _ string) (string, error) { return "", nil } func (p *execClient) KillProcessByName(processName string) error { cmd := fmt.Sprintf("pkill -f %v", processName) - _, err := p.ExecuteCommand(cmd) + _, err := p.ExecuteRawCommand(cmd) return err } @@ -174,7 +205,7 @@ func GetProcessNameByID(pidstr string) (string, error) { p := NewExecClient(nil) pidstr = strings.Trim(pidstr, "\n") cmd := fmt.Sprintf("ps -p %s -o comm=", pidstr) - out, err := p.ExecuteCommand(cmd) + out, err := p.ExecuteRawCommand(cmd) if err != nil { log.Printf("GetProcessNameByID returned error: %v", err) return "", err @@ -188,7 +219,7 @@ func GetProcessNameByID(pidstr string) (string, error) { func PrintDependencyPackageDetails() { p := NewExecClient(nil) - out, err := p.ExecuteCommand("iptables --version") + out, err := p.ExecuteRawCommand("iptables --version") out = strings.TrimSuffix(out, "\n") log.Printf("[cni-net] iptable version:%s, err:%v", out, err) } diff --git a/platform/os_linux_test.go b/platform/os_linux_test.go index 1848f22d0..a322afabf 100644 --- a/platform/os_linux_test.go +++ b/platform/os_linux_test.go @@ -1,29 +1,62 @@ package platform import ( + "context" + "errors" + "os/exec" + "strings" "testing" "time" ) +// Command execution time is more than timeout, so ExecuteRawCommand should return error +func TestExecuteRawCommandTimeout(t *testing.T) { + const timeout = 2 * time.Second + client := NewExecClientTimeout(timeout) + + _, err := client.ExecuteRawCommand("sleep 3") + if err == nil { + t.Errorf("TestExecuteRawCommandTimeout should have returned timeout error") + } + t.Logf("%s", err.Error()) +} + +// Command execution time is less than timeout, so ExecuteRawCommand should work without error +func TestExecuteRawCommandNoTimeout(t *testing.T) { + const timeout = 2 * time.Second + client := NewExecClientTimeout(timeout) + + _, err := client.ExecuteRawCommand("sleep 1") + if err != nil { + t.Errorf("TestExecuteRawCommandNoTimeout failed with error %v", err) + } +} + +func TestExecuteCommand(t *testing.T) { + output, err := NewExecClient(nil).ExecuteCommand(context.Background(), "echo", "/B && echo two") + if err != nil { + t.Errorf("TestExecuteCommand failed with error %v", err) + } + if strings.TrimRight(output, "\n\r") != "/B && echo two" { + t.Errorf("TestExecuteCommand failed with output %s", output) + } +} + +func TestExecuteCommandError(t *testing.T) { + _, err := NewExecClient(nil).ExecuteCommand(context.Background(), "donotaddtopath") + if !errors.Is(err, exec.ErrNotFound) { + t.Errorf("TestExecuteCommand failed with error %v", err) + } +} + // Command execution time is more than timeout, so ExecuteCommand should return error func TestExecuteCommandTimeout(t *testing.T) { const timeout = 2 * time.Second client := NewExecClientTimeout(timeout) - _, err := client.ExecuteCommand("sleep 3") + _, err := client.ExecuteCommand(context.Background(), "sleep", "3") if err == nil { t.Errorf("TestExecuteCommandTimeout should have returned timeout error") } t.Logf("%s", err.Error()) } - -// Command execution time is less than timeout, so ExecuteCommand should work without error -func TestExecuteCommandNoTimeout(t *testing.T) { - const timeout = 2 * time.Second - client := NewExecClientTimeout(timeout) - - _, err := client.ExecuteCommand("sleep 1") - if err != nil { - t.Errorf("TestExecuteCommandNoTimeout failed with error %v", err) - } -} diff --git a/platform/os_windows.go b/platform/os_windows.go index 064d47969..63900c6e5 100644 --- a/platform/os_windows.go +++ b/platform/os_windows.go @@ -128,11 +128,12 @@ func (p *execClient) GetLastRebootTime() (time.Time, error) { return rebootTime.UTC(), nil } -func (p *execClient) ExecuteCommand(command string) (string, error) { +// Deprecated: ExecuteRawCommand is deprecated, it is recommended to use ExecuteCommand when possible +func (p *execClient) ExecuteRawCommand(command string) (string, error) { if p.logger != nil { - p.logger.Info("[Azure-Utils]", zap.String("ExecuteCommand", command)) + p.logger.Info("[Azure-Utils]", zap.String("ExecuteRawCommand", command)) } else { - log.Printf("[Azure-Utils] ExecuteCommand: %q", command) + log.Printf("[Azure-Utils] ExecuteRawCommand: %q", command) } var stderr, stdout bytes.Buffer @@ -141,6 +142,31 @@ func (p *execClient) ExecuteCommand(command string) (string, error) { cmd.Stderr = &stderr cmd.Stdout = &stdout + if err := cmd.Run(); err != nil { + return "", errors.Wrapf(err, "ExecuteRawCommand failed. stdout: %q, stderr: %q", stdout.String(), stderr.String()) + } + + return stdout.String(), nil +} + +// ExecuteCommand passes its parameters to an exec.CommandContext, runs the command, and returns its output, or an error if the command fails or times out +func (p *execClient) ExecuteCommand(ctx context.Context, command string, args ...string) (string, error) { + if p.logger != nil { + p.logger.Info("[Azure-Utils]", zap.String("ExecuteCommand", command), zap.Strings("args", args)) + } else { + log.Printf("[Azure-Utils] ExecuteCommand: %q %v", command, args) + } + + var stderr, stdout bytes.Buffer + + // Create a new context and add a timeout to it + derivedCtx, cancel := context.WithTimeout(ctx, p.Timeout) + defer cancel() // The cancel should be deferred so resources are cleaned up + + cmd := exec.CommandContext(derivedCtx, command, args...) + cmd.Stderr = &stderr + cmd.Stdout = &stdout + if err := cmd.Run(); err != nil { return "", errors.Wrapf(err, "ExecuteCommand failed. stdout: %q, stderr: %q", stdout.String(), stderr.String()) } @@ -169,11 +195,12 @@ func (p *execClient) ClearNetworkConfiguration() (bool, error) { func (p *execClient) KillProcessByName(processName string) error { cmd := fmt.Sprintf("taskkill /IM %v /F", processName) - _, err := p.ExecuteCommand(cmd) + _, err := p.ExecuteRawCommand(cmd) return err // nolint } // ExecutePowershellCommand executes powershell command +// Deprecated: ExecutePowershellCommand is deprecated, it is recommended to use ExecuteCommand when possible func (p *execClient) ExecutePowershellCommand(command string) (string, error) { ps, err := exec.LookPath("powershell.exe") if err != nil { @@ -201,6 +228,7 @@ func (p *execClient) ExecutePowershellCommand(command string) (string, error) { } // ExecutePowershellCommandWithContext executes powershell command wth context +// Deprecated: ExecutePowershellCommandWithContext is deprecated, it is recommended to use ExecuteCommand when possible func (p *execClient) ExecutePowershellCommandWithContext(ctx context.Context, command string) (string, error) { ps, err := exec.LookPath("powershell.exe") if err != nil { diff --git a/platform/os_windows_test.go b/platform/os_windows_test.go index 7c8644178..5cb5dacc1 100644 --- a/platform/os_windows_test.go +++ b/platform/os_windows_test.go @@ -6,6 +6,7 @@ import ( "os/exec" "strings" "testing" + "time" "github.com/Azure/azure-container-networking/platform/windows/adapter/mocks" "github.com/golang/mock/gomock" @@ -86,14 +87,14 @@ func TestUpdatePriorityVLANTagIfRequiredIfCurrentValNotEqualDesiredValAndSetRetu assert.EqualError(t, result, "error while setting Priority VLAN Tag value: test failure") } -func TestExecuteCommand(t *testing.T) { - out, err := NewExecClient(nil).ExecuteCommand("dir") +func TestExecuteRawCommand(t *testing.T) { + out, err := NewExecClient(nil).ExecuteRawCommand("dir") require.NoError(t, err) require.NotEmpty(t, out) } -func TestExecuteCommandError(t *testing.T) { - _, err := NewExecClient(nil).ExecuteCommand("dontaddtopath") +func TestExecuteRawCommandError(t *testing.T) { + _, err := NewExecClient(nil).ExecuteRawCommand("dontaddtopath") require.Error(t, err) var xErr *exec.ExitError @@ -101,6 +102,19 @@ func TestExecuteCommandError(t *testing.T) { assert.Equal(t, 1, xErr.ExitCode()) } +func TestExecuteCommand(t *testing.T) { + _, err := NewExecClient(nil).ExecuteCommand(context.Background(), "ping", "localhost") + if err != nil { + t.Errorf("TestExecuteCommand failed with error %v", err) + } +} + +func TestExecuteCommandError(t *testing.T) { + _, err := NewExecClient(nil).ExecuteCommand(context.Background(), "dontaddtopath") + require.Error(t, err) + require.ErrorIs(t, err, exec.ErrNotFound) +} + func TestSetSdnRemoteArpMacAddress_hnsNotEnabled(t *testing.T) { mockExecClient := NewMockExecClient(false) // testing skip setting SdnRemoteArpMacAddress when hns not enabled @@ -143,19 +157,28 @@ func TestFetchPnpIDMapping(t *testing.T) { return "6C-A1-00-50-E4-2D PCI\\VEN_8086&DEV_2723&SUBSYS_00808086&REV_1A\\4&328243d9&0&00E0\n80-6D-97-1E-CF-4E USB\\VID_17EF&PID_A359\\3010019E3", nil }) vfmapping, _ := FetchMacAddressPnpIDMapping(context.Background(), mockExecClient) - require.Len(t, 2, len(vfmapping)) + require.Len(t, vfmapping, 2) // Test when no adapters are found mockExecClient.SetPowershellCommandResponder(func(cmd string) (string, error) { return "", nil }) vfmapping, _ = FetchMacAddressPnpIDMapping(context.Background(), mockExecClient) - require.Empty(t, 0, len(vfmapping)) + require.Empty(t, vfmapping) // Adding carriage returns mockExecClient.SetPowershellCommandResponder(func(cmd string) (string, error) { return "6C-A1-00-50-E4-2D PCI\\VEN_8086&DEV_2723&SUBSYS_00808086&REV_1A\\4&328243d9&0&00E0\r\n\r80-6D-97-1E-CF-4E USB\\VID_17EF&PID_A359\\3010019E3", nil }) vfmapping, _ = FetchMacAddressPnpIDMapping(context.Background(), mockExecClient) - require.Len(t, 2, len(vfmapping)) + require.Len(t, vfmapping, 2) +} + +// ping -t localhost will ping indefinitely and should exceed the 5 second timeout +func TestExecuteCommandTimeout(t *testing.T) { + const timeout = 5 * time.Second + client := NewExecClientTimeout(timeout) + + _, err := client.ExecuteCommand(context.Background(), "ping", "-t", "localhost") + require.Error(t, err) } diff --git a/telemetry/telemetry_windows.go b/telemetry/telemetry_windows.go index 31a50828d..f68b211b2 100644 --- a/telemetry/telemetry_windows.go +++ b/telemetry/telemetry_windows.go @@ -40,7 +40,7 @@ func (report *CNIReport) GetSystemDetails() { func (report *CNIReport) GetOSDetails() { p := platform.NewExecClient(report.Logger) report.OSDetails = OSInfo{OSType: runtime.GOOS} - out, err := p.ExecuteCommand(versionCmd) + out, err := p.ExecuteRawCommand(versionCmd) if err == nil { report.OSDetails.OSVersion = strings.Replace(out, delimiter, "", -1) }