diff --git a/internal/local/info.go b/internal/local/info.go index d6b3d889..6423ce1b 100644 --- a/internal/local/info.go +++ b/internal/local/info.go @@ -162,18 +162,28 @@ func (service *ProvisioningService) DisplayAMTInfo() (err error) { service.PrintOutput("Control Mode : " + string(utils.InterpretControlMode(result))) } if service.flags.AmtInfo.OpState { - result, err := cmd.GetChangeEnabled() + majorVersion, err := GetMajorVersion(dataStruct["amt"].(string)) if err != nil { log.Error(err) } - if result.IsNewInterfaceVersion() { - opStateValue := "disabled" - if result.IsAMTEnabled() { - opStateValue = "enabled" + const minimumAMTVersion = 11 + // Check if the AMT major version is greater than 11 + if majorVersion > minimumAMTVersion { + result, err := cmd.GetChangeEnabled() + if err != nil { + log.Error(err) } + if result.IsNewInterfaceVersion() { + opStateValue := "disabled" + if result.IsAMTEnabled() { + opStateValue = "enabled" + } - dataStruct["operationalState"] = opStateValue - service.PrintOutput("Operational State : " + opStateValue) + dataStruct["operationalState"] = opStateValue + service.PrintOutput("Operational State : " + opStateValue) + } + } else { + log.Debug("OpState will not work on AMT versions 11 and below.") } } if service.flags.AmtInfo.DNS { @@ -419,3 +429,16 @@ func DecodeAMT(version, SKU string) string { } return result } +func GetMajorVersion(version string) (int, error) { + amtParts := strings.Split(version, ".") + if len(amtParts) <= 1 { + return 0, fmt.Errorf("invalid AMT version format") + } + + majorVersion, err := strconv.Atoi(amtParts[0]) + if err != nil { + return 0, fmt.Errorf("invalid AMT version") + } + + return majorVersion, nil +} diff --git a/internal/local/info_test.go b/internal/local/info_test.go index e354ab42..d4c57475 100644 --- a/internal/local/info_test.go +++ b/internal/local/info_test.go @@ -183,6 +183,37 @@ func TestDecodeAMT(t *testing.T) { } } +func TestGetMajorVersion(t *testing.T) { + testCases := []struct { + version string + want int + wantErr bool + }{ + {"1.2.3", 1, false}, + {"11.8.55", 11, false}, + {"12.5.2", 12, false}, + {"16.1.25", 16, false}, + {"18.2.10", 18, false}, + {"", 0, true}, + {"abc", 0, true}, + {"1", 0, true}, + {"1.2.3.4.5", 1, false}, + } + + for _, tc := range testCases { + got, err := GetMajorVersion(tc.version) + + if (err != nil) != tc.wantErr { + t.Errorf("GetMajorVersion(%q) error = %v, wantErr %v", tc.version, err, tc.wantErr) + continue + } + + if !tc.wantErr && got != tc.want { + t.Errorf("GetMajorVersion(%q) = %v; want %v", tc.version, got, tc.want) + } + } +} + var testNetEnumerator1 = flags.NetEnumerator{ Interfaces: func() ([]net.Interface, error) { return []net.Interface{ diff --git a/internal/local/opstate.go b/internal/local/opstate.go index 7181b2e4..3b086e82 100644 --- a/internal/local/opstate.go +++ b/internal/local/opstate.go @@ -26,6 +26,10 @@ func (service *ProvisioningService) CheckAndEnableAMT(skipIPRenewal bool) (bool, resp, err := service.amtCommand.GetChangeEnabled() tlsIsEnforced := false if err != nil { + if err.Error() == "wait timeout while sending data" { + log.Debug("Operation timed out while sending data. This may occur on systems with AMT version 11 and below.") + return tlsIsEnforced, nil + } log.Error(err) return tlsIsEnforced, utils.AMTConnectionFailed } diff --git a/internal/local/opstate_test.go b/internal/local/opstate_test.go index 29a55120..e3098392 100644 --- a/internal/local/opstate_test.go +++ b/internal/local/opstate_test.go @@ -6,6 +6,7 @@ package local import ( + "errors" "rpc/internal/amt" "rpc/internal/flags" "rpc/pkg/utils" @@ -16,7 +17,7 @@ import ( ) func TestCheckAndEnableAMT(t *testing.T) { - + var errMockTimeout = errors.New("wait timeout while sending data") tests := []struct { name string skipIPRenewal bool @@ -67,6 +68,12 @@ func TestCheckAndEnableAMT(t *testing.T) { skipIPRenewal: true, renewDHCPLeaseRC: utils.WiredConfigurationFailed, }, + { + name: "expect tlsIsEnforced false when operation times out", + expectedRC: nil, + expectedTLS: false, + errChangeEnabled: errMockTimeout, + }, } for _, tc := range tests { @@ -86,6 +93,11 @@ func TestCheckAndEnableAMT(t *testing.T) { tlsForced, err := lps.CheckAndEnableAMT(tc.skipIPRenewal) assert.Equal(t, tc.expectedTLS, tlsForced) assert.Equal(t, tc.expectedRC, err) + if tc.name == "expect tlsIsEnforced false when operation times out" { + assert.False(t, tlsForced) + assert.Nil(t, err) + } + // Reset mocks mockChangeEnabledResponse = origRsp errMockChangeEnabled = origChangeEnabledErr mockEnableAMTErr = origEnableAMTErr