From 06996279dfc16901f7ba5d12aeb1eae950a5b823 Mon Sep 17 00:00:00 2001 From: Jack Francis Date: Fri, 9 Nov 2018 15:36:54 -0800 Subject: [PATCH] ensure N series clusters get aks-docker-engine (#4221) --- pkg/acsengine/engine.go | 38 ------- pkg/acsengine/engine_test.go | 9 +- pkg/acsengine/template_generator.go | 4 +- pkg/api/addons.go | 2 +- pkg/api/common/helper.go | 149 +++++++++++++++++++++++++ pkg/api/common/helper_test.go | 11 ++ pkg/api/defaults.go | 30 +++-- pkg/api/defaults_test.go | 120 +++++++++++++++++++- pkg/api/types.go | 11 +- pkg/api/types_test.go | 33 +++++- pkg/api/vlabs/types.go | 3 +- pkg/api/vlabs/types_test.go | 26 +++++ test/e2e/engine/template.go | 10 -- test/e2e/kubernetes/kubernetes_test.go | 2 +- 14 files changed, 370 insertions(+), 78 deletions(-) diff --git a/pkg/acsengine/engine.go b/pkg/acsengine/engine.go index 144852310..dc1f3a0dd 100644 --- a/pkg/acsengine/engine.go +++ b/pkg/acsengine/engine.go @@ -363,44 +363,6 @@ func getDCOSDefaultRepositoryURL(orchestratorType string, orchestratorVersion st return "" } -func isNSeriesSKU(profile *api.AgentPoolProfile) bool { - /* If a new GPU sku becomes available, add a key to this map, but only if you have a confirmation - that we have an agreement with NVIDIA for this specific gpu. - */ - dm := map[string]bool{ - // K80 - "Standard_NC6": true, - "Standard_NC12": true, - "Standard_NC24": true, - "Standard_NC24r": true, - // M60 - "Standard_NV6": true, - "Standard_NV12": true, - "Standard_NV24": true, - "Standard_NV24r": true, - // P40 - "Standard_ND6s": true, - "Standard_ND12s": true, - "Standard_ND24s": true, - "Standard_ND24rs": true, - // P100 - "Standard_NC6s_v2": true, - "Standard_NC12s_v2": true, - "Standard_NC24s_v2": true, - "Standard_NC24rs_v2": true, - // V100 - "Standard_NC6s_v3": true, - "Standard_NC12s_v3": true, - "Standard_NC24s_v3": true, - "Standard_NC24rs_v3": true, - } - if _, ok := dm[profile.VMSize]; ok { - return dm[profile.VMSize] - } - - return false -} - func getDCOSCustomDataPublicIPStr(orchestratorType string, masterCount int) string { if orchestratorType == api.DCOS { var buf bytes.Buffer diff --git a/pkg/acsengine/engine_test.go b/pkg/acsengine/engine_test.go index 2525e498c..eaf774cf9 100644 --- a/pkg/acsengine/engine_test.go +++ b/pkg/acsengine/engine_test.go @@ -12,6 +12,7 @@ import ( "github.com/Azure/acs-engine/pkg/acsengine/transform" "github.com/Azure/acs-engine/pkg/api" + "github.com/Azure/acs-engine/pkg/api/common" "github.com/Azure/acs-engine/pkg/api/v20160330" "github.com/Azure/acs-engine/pkg/api/vlabs" "github.com/Azure/acs-engine/pkg/i18n" @@ -497,14 +498,14 @@ func TestIsNSeriesSKU(t *testing.T) { } for _, sku := range validSkus { - if !isNSeriesSKU(&api.AgentPoolProfile{VMSize: sku}) { - t.Fatalf("Expected isNSeriesSKU(%s) to be true", sku) + if !common.IsNvidiaEnabledSKU(sku) { + t.Fatalf("Expected common.IsNvidiaEnabledSKU(%s) to be true", sku) } } for _, sku := range invalidSkus { - if isNSeriesSKU(&api.AgentPoolProfile{VMSize: sku}) { - t.Fatalf("Expected isNSeriesSKU(%s) to be false", sku) + if common.IsNvidiaEnabledSKU(sku) { + t.Fatalf("Expected common.IsNvidiaEnabledSKU(%s) to be false", sku) } } } diff --git a/pkg/acsengine/template_generator.go b/pkg/acsengine/template_generator.go index 72cf4aa5c..30e44ec89 100644 --- a/pkg/acsengine/template_generator.go +++ b/pkg/acsengine/template_generator.go @@ -232,7 +232,7 @@ func (t *TemplateGenerator) getTemplateFuncMap(cs *api.ContainerService) templat storagetier, _ := getStorageAccountType(profile.VMSize) buf.WriteString(fmt.Sprintf(",storageprofile=managed,storagetier=%s", storagetier)) } - if isNSeriesSKU(profile) { + if common.IsNvidiaEnabledSKU(profile.VMSize) { accelerator := "nvidia" buf.WriteString(fmt.Sprintf(",accelerator=%s", accelerator)) } @@ -786,7 +786,7 @@ func (t *TemplateGenerator) getTemplateFuncMap(cs *api.ContainerService) templat return cs.Properties.IsNVIDIADevicePluginEnabled() }, "IsNSeriesSKU": func(profile *api.AgentPoolProfile) bool { - return isNSeriesSKU(profile) + return common.IsNvidiaEnabledSKU(profile.VMSize) }, "UseSinglePlacementGroup": func(profile *api.AgentPoolProfile) bool { return *profile.SinglePlacementGroup diff --git a/pkg/api/addons.go b/pkg/api/addons.go index 5623d21a2..174cfca26 100644 --- a/pkg/api/addons.go +++ b/pkg/api/addons.go @@ -155,7 +155,7 @@ func (cs *ContainerService) setAddonsConfig() { defaultNVIDIADevicePluginAddonsConfig := KubernetesAddon{ Name: NVIDIADevicePluginAddonName, - Enabled: helpers.PointerToBool(IsNSeriesSKU(cs.Properties) && common.IsKubernetesVersionGe(o.OrchestratorVersion, "1.10.0")), + Enabled: helpers.PointerToBool(cs.Properties.HasNSeriesSKU() && common.IsKubernetesVersionGe(o.OrchestratorVersion, "1.10.0")), Containers: []KubernetesContainerSpec{ { Name: NVIDIADevicePluginAddonName, diff --git a/pkg/api/common/helper.go b/pkg/api/common/helper.go index 924d76e71..ebab97de1 100644 --- a/pkg/api/common/helper.go +++ b/pkg/api/common/helper.go @@ -65,3 +65,152 @@ func ValidateDNSPrefix(dnsName string) error { } return nil } + +// IsNvidiaEnabledSKU determines if an VM SKU has nvidia driver support +func IsNvidiaEnabledSKU(vmSize string) bool { + /* If a new GPU sku becomes available, add a key to this map, but only if you have a confirmation + that we have an agreement with NVIDIA for this specific gpu. + */ + dm := map[string]bool{ + // K80 + "Standard_NC6": true, + "Standard_NC12": true, + "Standard_NC24": true, + "Standard_NC24r": true, + // M60 + "Standard_NV6": true, + "Standard_NV12": true, + "Standard_NV24": true, + "Standard_NV24r": true, + // P40 + "Standard_ND6s": true, + "Standard_ND12s": true, + "Standard_ND24s": true, + "Standard_ND24rs": true, + // P100 + "Standard_NC6s_v2": true, + "Standard_NC12s_v2": true, + "Standard_NC24s_v2": true, + "Standard_NC24rs_v2": true, + // V100 + "Standard_NC6s_v3": true, + "Standard_NC12s_v3": true, + "Standard_NC24s_v3": true, + "Standard_NC24rs_v3": true, + } + if _, ok := dm[vmSize]; ok { + return dm[vmSize] + } + + return false +} + +// GetNSeriesVMCasesForTesting returns a struct w/ VM SKUs and whether or not we expect them to be nvidia-enabled +func GetNSeriesVMCasesForTesting() []struct { + VMSKU string + Expected bool +} { + cases := []struct { + VMSKU string + Expected bool + }{ + { + "Standard_NC6", + true, + }, + { + "Standard_NC12", + true, + }, + { + "Standard_NC24", + true, + }, + { + "Standard_NC24r", + true, + }, + { + "Standard_NV6", + true, + }, + { + "Standard_NV12", + true, + }, + { + "Standard_NV24", + true, + }, + { + "Standard_NV24r", + true, + }, + { + "Standard_ND6s", + true, + }, + { + "Standard_ND12s", + true, + }, + { + "Standard_ND24s", + true, + }, + { + "Standard_ND24rs", + true, + }, + { + "Standard_NC6s_v2", + true, + }, + { + "Standard_NC12s_v2", + true, + }, + { + "Standard_NC24s_v2", + true, + }, + { + "Standard_NC24rs_v2", + true, + }, + { + "Standard_NC24rs_v2", + true, + }, + { + "Standard_NC6s_v3", + true, + }, + { + "Standard_NC12s_v3", + true, + }, + { + "Standard_NC24s_v3", + true, + }, + { + "Standard_NC24rs_v3", + true, + }, + { + "Standard_D2_v2", + false, + }, + { + "gobledygook", + false, + }, + { + "", + false, + }, + } + + return cases +} diff --git a/pkg/api/common/helper_test.go b/pkg/api/common/helper_test.go index 9fb8d6b94..909dc9a9f 100644 --- a/pkg/api/common/helper_test.go +++ b/pkg/api/common/helper_test.go @@ -56,3 +56,14 @@ func TestValidateDNSPrefix(t *testing.T) { } } } + +func TestIsNvidiaEnabledSKU(t *testing.T) { + cases := GetNSeriesVMCasesForTesting() + + for _, c := range cases { + ret := IsNvidiaEnabledSKU(c.VMSKU) + if ret != c.Expected { + t.Fatalf("expected IsNvidiaEnabledSKU(%s) to return %t, but instead got %t", c.VMSKU, c.Expected, ret) + } + } +} diff --git a/pkg/api/defaults.go b/pkg/api/defaults.go index 2186bfc53..f4d269ffc 100644 --- a/pkg/api/defaults.go +++ b/pkg/api/defaults.go @@ -438,19 +438,29 @@ func (p *Properties) setAgentProfileDefaults(isUpgrade, isScale bool) { profile.AcceleratedNetworkingEnabledWindows = helpers.PointerToBool(DefaultAcceleratedNetworkingWindowsEnabled) } - if profile.Distro == "" && profile.OSType != Windows { - if p.OrchestratorProfile.IsKubernetes() { - if profile.OSDiskSizeGB != 0 && profile.OSDiskSizeGB < VHDDiskSizeAKS { - profile.Distro = Ubuntu - } else { - if IsNSeriesSKU(p) { - profile.Distro = AKSDockerEngine + if profile.OSType != Windows { + if profile.Distro == "" { + if p.OrchestratorProfile.IsKubernetes() { + if profile.OSDiskSizeGB != 0 && profile.OSDiskSizeGB < VHDDiskSizeAKS { + profile.Distro = Ubuntu } else { - profile.Distro = AKS + if profile.IsNSeriesSKU() { + profile.Distro = AKSDockerEngine + } else { + profile.Distro = AKS + } } + } else if !p.OrchestratorProfile.IsOpenShift() { + profile.Distro = Ubuntu + } + // Ensure distro is set properly for N Series SKUs, because + // (1) At present, "aks-docker-engine" and "ubuntu" are the only working distro base for running GPU workloads on N Series SKUs + // (2) Previous versions of acs-engine had working implementations using the "aks" distro value, + // so we need to hard override it in order to produce a working cluster in upgrade/scale contexts + } else if p.OrchestratorProfile.IsKubernetes() && (isUpgrade || isScale) && profile.IsNSeriesSKU() { + if profile.Distro == AKS { + profile.Distro = AKSDockerEngine } - } else if !p.OrchestratorProfile.IsOpenShift() { - profile.Distro = Ubuntu } } diff --git a/pkg/api/defaults_test.go b/pkg/api/defaults_test.go index c185b133f..c1517a5db 100644 --- a/pkg/api/defaults_test.go +++ b/pkg/api/defaults_test.go @@ -914,26 +914,135 @@ func TestSetVMSSDefaultsAndZones(t *testing.T) { } func TestAKSDockerEngineDistro(t *testing.T) { + // N Series agent pools should always get the "aks-docker-engine" distro for default create flows + // D Series agent pools should always get the "aks" distro for default create flows mockCS := getMockBaseContainerService("1.10.9") properties := mockCS.Properties properties.OrchestratorProfile.OrchestratorType = "Kubernetes" properties.MasterProfile.Count = 1 properties.AgentPoolProfiles[0].VMSize = "Standard_NC6" + properties.AgentPoolProfiles[1].VMSize = "Standard_D2_V2" + properties.AgentPoolProfiles[2].VMSize = "Standard_NC6" + properties.AgentPoolProfiles[2].Distro = Ubuntu + properties.AgentPoolProfiles[3].VMSize = "Standard_D2_V2" + properties.AgentPoolProfiles[3].Distro = Ubuntu properties.setAgentProfileDefaults(false, false) if properties.AgentPoolProfiles[0].Distro != AKSDockerEngine { - t.Fatalf("Expected %s distro for N-series VM got %s instead", AKSDockerEngine, properties.AgentPoolProfiles[0].Distro) + t.Fatalf("Expected %s distro for N-series pool, got %s instead", AKSDockerEngine, properties.AgentPoolProfiles[0].Distro) + } + if properties.AgentPoolProfiles[1].Distro != AKS { + t.Fatalf("Expected %s distro for D-series pool, got %s instead", AKS, properties.AgentPoolProfiles[1].Distro) + } + if properties.AgentPoolProfiles[2].Distro != Ubuntu { + t.Fatalf("Expected %s distro for D-series pool, got %s instead", Ubuntu, properties.AgentPoolProfiles[2].Distro) + } + if properties.AgentPoolProfiles[3].Distro != Ubuntu { + t.Fatalf("Expected %s distro for D-series pool, got %s instead", Ubuntu, properties.AgentPoolProfiles[3].Distro) } + // N Series agent pools with small disk size should always get the "ubuntu" distro for default create flows + // D Series agent pools with small disk size should always get the "ubuntu" distro for default create flows mockCS = getMockBaseContainerService("1.10.9") properties = mockCS.Properties properties.OrchestratorProfile.OrchestratorType = "Kubernetes" properties.MasterProfile.Count = 1 - properties.AgentPoolProfiles[0].VMSize = "Standard_D2_V2" + properties.AgentPoolProfiles[0].VMSize = "Standard_NC6" + properties.AgentPoolProfiles[0].OSDiskSizeGB = VHDDiskSizeAKS - 1 + properties.AgentPoolProfiles[1].VMSize = "Standard_D2_V2" + properties.AgentPoolProfiles[1].OSDiskSizeGB = VHDDiskSizeAKS - 1 properties.setAgentProfileDefaults(false, false) - if properties.AgentPoolProfiles[0].Distro != AKS { - t.Fatalf("Expected %s distro for N-series VM got %s instead", AKS, properties.AgentPoolProfiles[0].Distro) + if properties.AgentPoolProfiles[0].Distro != Ubuntu { + t.Fatalf("Expected %s distro for N-series pool with small disk, got %s instead", Ubuntu, properties.AgentPoolProfiles[0].Distro) + } + if properties.AgentPoolProfiles[1].Distro != Ubuntu { + t.Fatalf("Expected %s distro for D-series pool with small disk, got %s instead", Ubuntu, properties.AgentPoolProfiles[1].Distro) + } + + // N Series agent pools should always get the "aks-docker-engine" distro for upgrade flows unless Ubuntu + // D Series agent pools should always get the distro they requested for upgrade flows + mockCS = getMockBaseContainerService("1.10.9") + properties = mockCS.Properties + properties.OrchestratorProfile.OrchestratorType = "Kubernetes" + properties.MasterProfile.Count = 1 + properties.AgentPoolProfiles[0].VMSize = "Standard_NC6" + properties.AgentPoolProfiles[0].Distro = AKS + properties.AgentPoolProfiles[1].VMSize = "Standard_D2_V2" + properties.AgentPoolProfiles[1].Distro = AKS + properties.AgentPoolProfiles[2].VMSize = "Standard_D2_V2" + properties.AgentPoolProfiles[2].Distro = AKSDockerEngine + properties.AgentPoolProfiles[3].VMSize = "Standard_NC6" + properties.AgentPoolProfiles[3].Distro = Ubuntu + properties.setAgentProfileDefaults(true, false) + + if properties.AgentPoolProfiles[0].Distro != AKSDockerEngine { + t.Fatalf("Expected %s distro for N-series pool, got %s instead", AKSDockerEngine, properties.AgentPoolProfiles[0].Distro) + } + if properties.AgentPoolProfiles[1].Distro != AKS { + t.Fatalf("Expected %s distro for D-series pool, got %s instead", AKS, properties.AgentPoolProfiles[1].Distro) + } + if properties.AgentPoolProfiles[2].Distro != AKSDockerEngine { + t.Fatalf("Expected %s distro for D-series pool, got %s instead", AKSDockerEngine, properties.AgentPoolProfiles[2].Distro) + } + if properties.AgentPoolProfiles[3].Distro != Ubuntu { + t.Fatalf("Expected %s distro for D-series pool, got %s instead", Ubuntu, properties.AgentPoolProfiles[3].Distro) + } + + // N Series agent pools should always get the "aks-docker-engine" distro for scale flows unless Ubuntu + // D Series agent pools should always get the distro they requested for scale flows + mockCS = getMockBaseContainerService("1.10.9") + properties = mockCS.Properties + properties.OrchestratorProfile.OrchestratorType = "Kubernetes" + properties.MasterProfile.Count = 1 + properties.AgentPoolProfiles[0].VMSize = "Standard_NC6" + properties.AgentPoolProfiles[0].Distro = AKS + properties.AgentPoolProfiles[1].VMSize = "Standard_D2_V2" + properties.AgentPoolProfiles[1].Distro = AKS + properties.AgentPoolProfiles[2].VMSize = "Standard_D2_V2" + properties.AgentPoolProfiles[2].Distro = AKSDockerEngine + properties.AgentPoolProfiles[3].VMSize = "Standard_NC6" + properties.AgentPoolProfiles[3].Distro = Ubuntu + properties.setAgentProfileDefaults(false, true) + + if properties.AgentPoolProfiles[0].Distro != AKSDockerEngine { + t.Fatalf("Expected %s distro for N-series pool, got %s instead", AKSDockerEngine, properties.AgentPoolProfiles[0].Distro) + } + if properties.AgentPoolProfiles[1].Distro != AKS { + t.Fatalf("Expected %s distro for D-series pool, got %s instead", AKS, properties.AgentPoolProfiles[1].Distro) + } + if properties.AgentPoolProfiles[2].Distro != AKSDockerEngine { + t.Fatalf("Expected %s distro for D-series pool, got %s instead", AKSDockerEngine, properties.AgentPoolProfiles[2].Distro) + } + if properties.AgentPoolProfiles[3].Distro != Ubuntu { + t.Fatalf("Expected %s distro for D-series pool, got %s instead", Ubuntu, properties.AgentPoolProfiles[3].Distro) + } + + // N Series Windows agent pools should always get no distro value + mockCS = getMockBaseContainerService("1.10.9") + properties = mockCS.Properties + properties.OrchestratorProfile.OrchestratorType = "Kubernetes" + properties.MasterProfile.Count = 1 + properties.AgentPoolProfiles[0].VMSize = "Standard_NC6" + properties.AgentPoolProfiles[0].OSType = Windows + properties.AgentPoolProfiles[1].VMSize = "Standard_NC6" + properties.setAgentProfileDefaults(false, false) + + if properties.AgentPoolProfiles[0].Distro != "" { + t.Fatalf("Expected no distro value for N-series Windows VM, got %s instead", properties.AgentPoolProfiles[0].Distro) + } + if properties.AgentPoolProfiles[1].Distro != AKSDockerEngine { + t.Fatalf("Expected %s distro for N-series pool, got %s instead", AKSDockerEngine, properties.AgentPoolProfiles[1].Distro) + } + + // Non-k8s context + mockCS = getMockBaseContainerService("1.10.9") + properties = mockCS.Properties + properties.MasterProfile.Count = 1 + properties.setAgentProfileDefaults(false, false) + + if properties.AgentPoolProfiles[0].Distro != Ubuntu { + t.Fatalf("Expected %s distro for N-series pool, got %s instead", Ubuntu, properties.AgentPoolProfiles[1].Distro) } } @@ -1179,6 +1288,9 @@ func getMockAPIProperties(orchestratorVersion string) Properties { MasterProfile: &MasterProfile{}, AgentPoolProfiles: []*AgentPoolProfile{ {}, + {}, + {}, + {}, }} } diff --git a/pkg/api/types.go b/pkg/api/types.go index 2e64f1255..8af5a7086 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -1268,8 +1268,13 @@ func (k *KubernetesConfig) IsIPMasqAgentEnabled() bool { return k.isAddonEnabled(IPMASQAgentAddonName, IPMasqAgentAddonEnabled) } -// IsNSeriesSKU returns whether or not the agent pool has Standard_N SKU VMs -func IsNSeriesSKU(p *Properties) bool { +// IsNSeriesSKU returns true if the agent pool contains an N-series (NVIDIA GPU) VM +func (a *AgentPoolProfile) IsNSeriesSKU() bool { + return common.IsNvidiaEnabledSKU(a.VMSize) +} + +// HasNSeriesSKU returns whether or not there is an N series SKU agent pool +func (p *Properties) HasNSeriesSKU() bool { for _, profile := range p.AgentPoolProfiles { if strings.Contains(profile.VMSize, "Standard_N") { return true @@ -1288,7 +1293,7 @@ func (p *Properties) IsNVIDIADevicePluginEnabled() bool { func getDefaultNVIDIADevicePluginEnabled(p *Properties) bool { o := p.OrchestratorProfile var addonEnabled bool - if IsNSeriesSKU(p) && common.IsKubernetesVersionGe(o.OrchestratorVersion, "1.10.0") { + if p.HasNSeriesSKU() && common.IsKubernetesVersionGe(o.OrchestratorVersion, "1.10.0") { addonEnabled = true } else { addonEnabled = false diff --git a/pkg/api/types_test.go b/pkg/api/types_test.go index 0d4b9dae3..74f2aaf42 100644 --- a/pkg/api/types_test.go +++ b/pkg/api/types_test.go @@ -5,6 +5,7 @@ import ( "reflect" "testing" + "github.com/Azure/acs-engine/pkg/api/common" "github.com/Azure/acs-engine/pkg/helpers" ) @@ -1214,8 +1215,8 @@ func TestIsNVIDIADevicePluginEnabled(t *testing.T) { }, } - if !IsNSeriesSKU(&p) { - t.Fatalf("IsNSeriesSKU should return true when explicitly using VM Size %s", p.AgentPoolProfiles[0].VMSize) + if !p.HasNSeriesSKU() { + t.Fatalf("HasNSeriesSKU should return true when explicitly using VM Size %s", p.AgentPoolProfiles[0].VMSize) } if p.IsNVIDIADevicePluginEnabled() { t.Fatalf("KubernetesConfig.IsNVIDIADevicePluginEnabled() should return false with N-series VMs with < k8s 1.10, instead returned %t", p.IsNVIDIADevicePluginEnabled()) @@ -1234,14 +1235,38 @@ func TestIsNVIDIADevicePluginEnabled(t *testing.T) { }, } - if IsNSeriesSKU(&p) { - t.Fatalf("IsNSeriesSKU should return false when explicitly using VM Size %s", p.AgentPoolProfiles[0].VMSize) + if p.HasNSeriesSKU() { + t.Fatalf("HasNSeriesSKU should return false when explicitly using VM Size %s", p.AgentPoolProfiles[0].VMSize) } if p.IsNVIDIADevicePluginEnabled() { t.Fatalf("KubernetesConfig.IsNVIDIADevicePluginEnabled() should return false when explicitly disabled") } } +func TestAgentPoolIsNSeriesSKU(t *testing.T) { + cases := common.GetNSeriesVMCasesForTesting() + + for _, c := range cases { + p := Properties{ + AgentPoolProfiles: []*AgentPoolProfile{ + { + Name: "agentpool", + VMSize: c.VMSKU, + Count: 1, + }, + }, + OrchestratorProfile: &OrchestratorProfile{ + OrchestratorType: Kubernetes, + OrchestratorVersion: "1.12.2", + }, + } + ret := p.AgentPoolProfiles[0].IsNSeriesSKU() + if ret != c.Expected { + t.Fatalf("expected IsNvidiaEnabledSKU(%s) to return %t, but instead got %t", c.VMSKU, c.Expected, ret) + } + } +} + func TestIsContainerMonitoringEnabled(t *testing.T) { v := "1.9.0" o := OrchestratorProfile{ diff --git a/pkg/api/vlabs/types.go b/pkg/api/vlabs/types.go index 768252554..8b7a8350f 100644 --- a/pkg/api/vlabs/types.go +++ b/pkg/api/vlabs/types.go @@ -4,6 +4,7 @@ import ( "encoding/json" "strings" + "github.com/Azure/acs-engine/pkg/api/common" "github.com/pkg/errors" ) @@ -644,7 +645,7 @@ func (a *AgentPoolProfile) IsVirtualMachineScaleSets() bool { // IsNSeriesSKU returns true if the agent pool contains an N-series (NVIDIA GPU) VM func (a *AgentPoolProfile) IsNSeriesSKU() bool { - return strings.Contains(a.VMSize, "Standard_N") + return common.IsNvidiaEnabledSKU(a.VMSize) } // IsManagedDisks returns true if the customer specified managed disks diff --git a/pkg/api/vlabs/types_test.go b/pkg/api/vlabs/types_test.go index 4b1985f8b..0f6cae7f7 100644 --- a/pkg/api/vlabs/types_test.go +++ b/pkg/api/vlabs/types_test.go @@ -3,6 +3,8 @@ package vlabs import ( "encoding/json" "testing" + + "github.com/Azure/acs-engine/pkg/api/common" ) func TestKubernetesAddon(t *testing.T) { @@ -284,3 +286,27 @@ func TestContainerServiceProperties(t *testing.T) { t.Fatalf("unexpectedly detected ContainerServiceProperties MastersAndAgentsUseAvailabilityZones returns false after unmarshal") } } + +func TestAgentPoolIsNSeriesSKU(t *testing.T) { + cases := common.GetNSeriesVMCasesForTesting() + + for _, c := range cases { + p := Properties{ + AgentPoolProfiles: []*AgentPoolProfile{ + { + Name: "agentpool", + VMSize: c.VMSKU, + Count: 1, + }, + }, + OrchestratorProfile: &OrchestratorProfile{ + OrchestratorType: Kubernetes, + OrchestratorRelease: "1.12", + }, + } + ret := p.AgentPoolProfiles[0].IsNSeriesSKU() + if ret != c.Expected { + t.Fatalf("expected IsNvidiaEnabledSKU(%s) to return %t, but instead got %t", c.VMSKU, c.Expected, ret) + } + } +} diff --git a/test/e2e/engine/template.go b/test/e2e/engine/template.go index cb01b2dd7..a75d91ddd 100644 --- a/test/e2e/engine/template.go +++ b/test/e2e/engine/template.go @@ -222,16 +222,6 @@ func (e *Engine) HasWindowsAgents() bool { return false } -// HasGPUNodes will return true if the VM SKU is GPU-enabled -func (e *Engine) HasGPUNodes() bool { - for _, ap := range e.ExpandedDefinition.Properties.AgentPoolProfiles { - if strings.Contains(ap.VMSize, "Standard_N") { - return true - } - } - return false -} - // HasAddon will return true if an addon is enabled func (e *Engine) HasAddon(name string) (bool, api.KubernetesAddon) { for _, addon := range e.ExpandedDefinition.Properties.OrchestratorProfile.KubernetesConfig.Addons { diff --git a/test/e2e/kubernetes/kubernetes_test.go b/test/e2e/kubernetes/kubernetes_test.go index 109f49f48..6c571ca93 100644 --- a/test/e2e/kubernetes/kubernetes_test.go +++ b/test/e2e/kubernetes/kubernetes_test.go @@ -690,7 +690,7 @@ var _ = Describe("Azure Container Cluster using the Kubernetes Orchestrator", fu Describe("with a GPU-enabled agent pool", func() { It("should be able to run a nvidia-gpu job", func() { - if eng.HasGPUNodes() { + if eng.ExpandedDefinition.Properties.HasNSeriesSKU() { version := common.RationalizeReleaseAndVersion( common.Kubernetes, eng.ClusterDefinition.Properties.OrchestratorProfile.OrchestratorRelease,