From c1cbeff6b3c6fb3f4d928a0e96c9ac9e07a047b3 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Wed, 8 Jan 2025 15:25:42 +0700 Subject: [PATCH] Config: Fix config version downgrade --- cmd/config/config.go | 114 +++++++++++++------------------ config/config.go | 6 +- config/config_encryption.go | 20 +++--- config/config_encryption_test.go | 18 ++--- config/versions/v3.go | 1 - config/versions/versions.go | 57 ++++++++++++---- config/versions/versions_test.go | 20 +++--- 7 files changed, 123 insertions(+), 113 deletions(-) diff --git a/cmd/config/config.go b/cmd/config/config.go index 0f3e29d482a..13d53568c5e 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -1,7 +1,7 @@ package main import ( - "errors" + "context" "flag" "fmt" "os" @@ -11,9 +11,10 @@ import ( "github.com/buger/jsonparser" "github.com/thrasher-corp/gocryptotrader/common/file" "github.com/thrasher-corp/gocryptotrader/config" + "github.com/thrasher-corp/gocryptotrader/config/versions" ) -var commands = []string{"upgrade", "encrypt", "decrypt"} +var commands = []string{"upgrade", "downgrade", "encrypt", "decrypt"} func main() { fmt.Println("GoCryptoTrader: config-helper tool") @@ -22,6 +23,7 @@ func main() { var in, out, keyStr string var inplace bool + var version int fs := flag.NewFlagSet("config", flag.ExitOnError) fs.Usage = func() { usage(fs) } @@ -29,6 +31,7 @@ func main() { fs.StringVar(&out, "out", "[in].out", "The config output file") fs.BoolVar(&inplace, "edit", false, "Edit; Save result to the original file") fs.StringVar(&keyStr, "key", "", "The key to use for AES encryption") + fs.IntVar(&version, "version", 0, "The version to downgrade to") cmd, args := parseCommand(os.Args[1:]) if cmd == "" { @@ -46,83 +49,59 @@ func main() { out = in + ".out" } - key := []byte(keyStr) var err error - switch cmd { - case "upgrade": - err = upgradeFile(in, out, key) - case "decrypt": - err = encryptWrapper(in, out, key, false, decryptFile) - case "encrypt": - err = encryptWrapper(in, out, key, true, encryptFile) - } + key := []byte(keyStr) + data := readFile(in) + isEncrypted := config.IsEncrypted(data) - if err != nil { - fatal(err.Error()) + if cmd == "encrypt" && isEncrypted { + fatal("Error: File is already encrypted") } - fmt.Println("Success! File written to " + out) -} - -func upgradeFile(in, out string, key []byte) error { - c := &config.Config{ - EncryptionKeyProvider: func(_ bool) ([]byte, error) { - if len(key) != 0 { - return key, nil - } - return config.PromptForConfigKey(false) - }, + if len(key) == 0 && (isEncrypted || cmd == "encrypt") { + if key, err = config.PromptForConfigKey(cmd == "encrypt"); err != nil { + fatal(err.Error()) + } } - if err := c.ReadConfigFromFile(in, true); err != nil { - return err + if config.IsEncrypted(data) { + if data, err = config.DecryptConfigData(data, key); err != nil { + fatal(err.Error()) + } } - return c.SaveConfigToFile(out) -} - -type encryptFunc func(string, []byte) ([]byte, error) - -func encryptWrapper(in, out string, key []byte, confirmKey bool, fn encryptFunc) error { - if len(key) == 0 { - var err error - if key, err = config.PromptForConfigKey(confirmKey); err != nil { - return err + switch cmd { + case "decrypt": + if data, err = jsonparser.Set(data, []byte("-1"), "encryptConfig"); err != nil { + fatal("Unable to decrypt config data; Error: " + err.Error()) + } + case "downgrade", "upgrade": + if version == 0 { + if cmd == "downgrade" { + fmt.Fprintln(os.Stderr, "Error: downgrade requires a version") + usage(fs) + os.Exit(3) + } + version = -1 + } + if data, err = versions.Manager.Deploy(context.Background(), data, version); err != nil { + fatal("Unable to " + cmd + " config; Error: " + err.Error()) + } + if !isEncrypted { + break + } + fallthrough + case "encrypt": + if data, err = config.EncryptConfigData(data, key); err != nil { + fatal("Unable to encrypt config data; Error: " + err.Error()) } } - outData, err := fn(in, key) - if err != nil { - return err - } - if err := file.Write(out, outData); err != nil { - return fmt.Errorf("unable to write output file %s; Error: %w", out, err) - } - return nil -} -func encryptFile(in string, key []byte) ([]byte, error) { - if config.IsFileEncrypted(in) { - return nil, errors.New("file is already encrypted") - } - outData, err := config.EncryptConfigFile(readFile(in), key) - if err != nil { - return nil, fmt.Errorf("unable to encrypt config data. Error: %w", err) + if err := file.Write(out, data); err != nil { + fatal("Unable to write output file `" + out + "`; Error: " + err.Error()) } - return outData, nil -} -func decryptFile(in string, key []byte) ([]byte, error) { - if !config.IsFileEncrypted(in) { - return nil, errors.New("file is already decrypted") - } - outData, err := config.DecryptConfigFile(readFile(in), key) - if err != nil { - return nil, fmt.Errorf("unable to decrypt config data. Error: %w", err) - } - if outData, err = jsonparser.Set(outData, []byte("-1"), "encryptConfig"); err != nil { - return nil, fmt.Errorf("unable to decrypt config data. Error: %w", err) - } - return outData, nil + fmt.Println("Success! File written to " + out) } func readFile(in string) []byte { @@ -152,7 +131,7 @@ func parseCommand(a []string) (cmd string, args []string) { switch len(cmds) { case 0: fmt.Fprintln(os.Stderr, "No command provided") - case 1: // + case 1: return cmds[0], rem default: fmt.Fprintln(os.Stderr, "Too many commands provided: "+strings.Join(cmds, ", ")) @@ -171,6 +150,7 @@ The commands are: encrypt encrypt infile and write to outfile decrypt decrypt infile and write to outfile upgrade upgrade the version of a decrypted config file + downgrade downgrade the version of a decrypted config file to a specific version The arguments are:`) fs.PrintDefaults() diff --git a/config/config.go b/config/config.go index 5e14f884553..10ba840fb30 100644 --- a/config/config.go +++ b/config/config.go @@ -1502,7 +1502,7 @@ func (c *Config) readConfig(d io.Reader) error { } } - if j, err = versions.Manager.Deploy(context.Background(), j); err != nil { + if j, err = versions.Manager.Deploy(context.Background(), j, -1); err != nil { return err } @@ -1536,7 +1536,7 @@ func (c *Config) decryptConfig(j []byte) ([]byte, error) { log.Errorf(log.ConfigMgr, "PromptForConfigKey err: %s", err) continue } - d, err := c.decryptConfigData(j, key) + d, err := c.DecryptConfigData(j, key) if err != nil { log.Errorln(log.ConfigMgr, "Could not decrypt and deserialise data with given key. Invalid password?", err) continue @@ -1593,7 +1593,7 @@ func (c *Config) Save(writerProvider func() (io.Writer, error)) error { } c.sessionDK, c.storedSalt = sessionDK, storedSalt } - payload, err = c.encryptConfigFile(payload) + payload, err = c.encryptConfigData(payload) if err != nil { return err } diff --git a/config/config_encryption.go b/config/config_encryption.go index b438ceef334..cd2ff8a1d96 100644 --- a/config/config_encryption.go +++ b/config/config_encryption.go @@ -95,8 +95,8 @@ func getSensitiveInput(prompt string) (resp []byte, err error) { return bytes.TrimRight(resp, "\r\n"), err } -// EncryptConfigFile encrypts json config data with a key -func EncryptConfigFile(configData, key []byte) ([]byte, error) { +// EncryptConfigData encrypts json config data with a key +func EncryptConfigData(configData, key []byte) ([]byte, error) { sessionDK, salt, err := makeNewSessionDK(key) if err != nil { return nil, err @@ -105,12 +105,12 @@ func EncryptConfigFile(configData, key []byte) ([]byte, error) { sessionDK: sessionDK, storedSalt: salt, } - return c.encryptConfigFile(configData) + return c.encryptConfigData(configData) } -// encryptConfigFile encrypts json config data with a key +// encryptConfigData encrypts json config data with a key // The EncryptConfig field is set to config enabled (1) -func (c *Config) encryptConfigFile(configData []byte) ([]byte, error) { +func (c *Config) encryptConfigData(configData []byte) ([]byte, error) { configData, err := jsonparser.Set(configData, []byte("1"), "encryptConfig") if err != nil { return nil, fmt.Errorf("%w: %w", ErrSettingEncryptConfig, err) @@ -135,13 +135,13 @@ func (c *Config) encryptConfigFile(configData []byte) ([]byte, error) { return appendedFile, nil } -// DecryptConfigFile decrypts config data with a key -func DecryptConfigFile(d, key []byte) ([]byte, error) { - return (&Config{}).decryptConfigData(d, key) +// DecryptConfigData decrypts config data with a key +func DecryptConfigData(d, key []byte) ([]byte, error) { + return (&Config{}).DecryptConfigData(d, key) } -// decryptConfigData decrypts config data with a key -func (c *Config) decryptConfigData(d, key []byte) ([]byte, error) { +// DecryptConfigData decrypts config data with a key +func (c *Config) DecryptConfigData(d, key []byte) ([]byte, error) { if !bytes.HasPrefix(d, encryptionPrefix) { return d, errNoPrefix } diff --git a/config/config_encryption_test.go b/config/config_encryption_test.go index 3fad053f7f8..5bb20c0b054 100644 --- a/config/config_encryption_test.go +++ b/config/config_encryption_test.go @@ -59,16 +59,16 @@ func TestPromptForConfigKey(t *testing.T) { func TestEncryptConfigFile(t *testing.T) { t.Parallel() - _, err := EncryptConfigFile([]byte("test"), nil) + _, err := EncryptConfigData([]byte("test"), nil) require.ErrorIs(t, err, errKeyIsEmpty) c := &Config{ sessionDK: []byte("a"), } - _, err = c.encryptConfigFile([]byte(`test`)) + _, err = c.encryptConfigData([]byte(`test`)) require.ErrorIs(t, err, ErrSettingEncryptConfig) - _, err = c.encryptConfigFile([]byte(`{"test":1}`)) + _, err = c.encryptConfigData([]byte(`{"test":1}`)) require.Error(t, err) require.IsType(t, aes.KeySizeError(1), err) @@ -79,26 +79,26 @@ func TestEncryptConfigFile(t *testing.T) { sessionDK: sessDk, storedSalt: salt, } - _, err = c.encryptConfigFile([]byte(`{"test":1}`)) + _, err = c.encryptConfigData([]byte(`{"test":1}`)) require.NoError(t, err) } func TestDecryptConfigFile(t *testing.T) { t.Parallel() - e, err := EncryptConfigFile([]byte(`{"test":1}`), []byte("key")) + e, err := EncryptConfigData([]byte(`{"test":1}`), []byte("key")) require.NoError(t, err) - d, err := DecryptConfigFile(e, []byte("key")) + d, err := DecryptConfigData(e, []byte("key")) require.NoError(t, err) assert.Equal(t, `{"test":1,"encryptConfig":1}`, string(d), "encryptConfig should be set to 1 after first encryption") - _, err = DecryptConfigFile(e, nil) + _, err = DecryptConfigData(e, nil) require.ErrorIs(t, err, errKeyIsEmpty) - _, err = DecryptConfigFile([]byte("test"), nil) + _, err = DecryptConfigData([]byte("test"), nil) require.ErrorIs(t, err, errNoPrefix) - _, err = DecryptConfigFile(encryptionPrefix, []byte("AAAAAAAAAAAAAAAA")) + _, err = DecryptConfigData(encryptionPrefix, []byte("AAAAAAAAAAAAAAAA")) require.ErrorIs(t, err, errAESBlockSize) } diff --git a/config/versions/v3.go b/config/versions/v3.go index d4700ef3cde..4f7f6d436f1 100644 --- a/config/versions/v3.go +++ b/config/versions/v3.go @@ -62,7 +62,6 @@ func (v *Version3) UpgradeExchange(_ context.Context, e []byte) ([]byte, error) // DowngradeExchange moves AssetEnabled assets into AssetType field func (v *Version3) DowngradeExchange(_ context.Context, e []byte) ([]byte, error) { assetTypes := []string{} - assetEnabledFn := func(asset []byte, v []byte, _ jsonparser.ValueType, _ int) error { if b, err := jsonparser.GetBoolean(v, "assetEnabled"); err == nil { if b { diff --git a/config/versions/versions.go b/config/versions/versions.go index 13db0d7fcf6..38958de40a3 100644 --- a/config/versions/versions.go +++ b/config/versions/versions.go @@ -13,9 +13,11 @@ package versions import ( "bytes" "context" + "encoding/json" "errors" "fmt" "log" + "os" "slices" "strconv" "sync" @@ -55,16 +57,24 @@ type manager struct { var Manager = &manager{} // Deploy upgrades or downgrades the config between versions -func (m *manager) Deploy(ctx context.Context, j []byte) ([]byte, error) { +// version param -1 defaults to the latest version +// Prints an error an exits if the config file version or version param is not registered +func (m *manager) Deploy(ctx context.Context, j []byte, version int) ([]byte, error) { if err := m.checkVersions(); err != nil { return j, err } - target, err := m.latest() + latest, err := m.latest() if err != nil { return j, err } + target := latest + + if version != -1 { + target = int(version) + } + m.m.RLock() defer m.m.RUnlock() @@ -77,47 +87,59 @@ func (m *manager) Deploy(ctx context.Context, j []byte) ([]byte, error) { return j, fmt.Errorf("%w `version`: %w", common.ErrGettingField, err) case target == current: return j, nil + case latest < current: + errVersionNotRegistered(current, latest, "Version in config file") + case target > latest: + errVersionNotRegistered(target, latest, "Target downgrade version") } for current != target { - next := current + 1 - action := "upgrade" + patchVersion := current + 1 + action := "upgrade to" configMethod := ConfigVersion.UpgradeConfig exchMethod := ExchangeVersion.UpgradeExchange if target < current { - next = current - 1 - action = "downgrade" + patchVersion = current + action = "downgrade from" configMethod = ConfigVersion.DowngradeConfig exchMethod = ExchangeVersion.DowngradeExchange } - log.Printf("Running %s to config version %v\n", action, next) + log.Printf("Running %s config version %v\n", action, patchVersion) - patch := m.versions[next] + patch := m.versions[patchVersion] if cPatch, ok := patch.(ConfigVersion); ok { if j, err = configMethod(cPatch, ctx, j); err != nil { - return j, fmt.Errorf("%w %s to %v: %w", errApplyingVersion, action, next, err) + return j, fmt.Errorf("%w %s %v: %w", errApplyingVersion, action, patchVersion, err) } } if ePatch, ok := patch.(ExchangeVersion); ok { if j, err = exchangeDeploy(ctx, ePatch, exchMethod, j); err != nil { - return j, fmt.Errorf("%w %s to %v: %w", errApplyingVersion, action, next, err) + return j, fmt.Errorf("%w %s %v: %w", errApplyingVersion, action, patchVersion, err) } } - current = next + current = patchVersion + if target < current { + current = patchVersion - 1 + } if j, err = jsonparser.Set(j, []byte(strconv.Itoa(current)), "version"); err != nil { - return j, fmt.Errorf("%w `version` during %s to %v: %w", common.ErrSettingField, action, next, err) + return j, fmt.Errorf("%w `version` during %s %v: %w", common.ErrSettingField, action, patchVersion, err) } } + var out bytes.Buffer + if err = json.Indent(&out, j, "", " "); err != nil { + return j, fmt.Errorf("error formatting json: %w", err) + } + log.Println("Version management finished") - return j, nil + return out.Bytes(), nil } func exchangeDeploy(ctx context.Context, patch ExchangeVersion, method func(ExchangeVersion, context.Context, []byte) ([]byte, error), j []byte) ([]byte, error) { @@ -196,3 +218,12 @@ func (m *manager) checkVersions() error { } return nil } + +func errVersionNotRegistered(current, latest int, msg string) { + fmt.Fprintf(os.Stderr, ` +%s '%d' is higher than latest available version '%d' +Switch back to the version of GoCryptoTrader containing config version '%d' and run: +$ cmd/config downgrade %d +`, msg, current, latest, current, latest) + os.Exit(1) +} diff --git a/config/versions/versions_test.go b/config/versions/versions_test.go index 7c7473a8774..e5eb4bd3445 100644 --- a/config/versions/versions_test.go +++ b/config/versions/versions_test.go @@ -13,45 +13,45 @@ import ( func TestDeploy(t *testing.T) { t.Parallel() m := manager{} - _, err := m.Deploy(context.Background(), []byte(``)) + _, err := m.Deploy(context.Background(), []byte(``), -1) assert.ErrorIs(t, err, errNoVersions) m.registerVersion(1, &TestVersion1{}) - _, err = m.Deploy(context.Background(), []byte(``)) + _, err = m.Deploy(context.Background(), []byte(``), -1) require.ErrorIs(t, err, errVersionIncompatible) m = manager{} m.registerVersion(0, &Version0{}) - _, err = m.Deploy(context.Background(), []byte(`not an object`)) + _, err = m.Deploy(context.Background(), []byte(`not an object`), -1) require.ErrorIs(t, err, jsonparser.KeyPathNotFoundError, "Must throw the correct error trying to add version to bad json") require.ErrorIs(t, err, common.ErrSettingField, "Must throw the correct error trying to add version to bad json") require.ErrorContains(t, err, "version", "Must throw the correct error trying to add version to bad json") - _, err = m.Deploy(context.Background(), []byte(`{"version":"not an int"}`)) + _, err = m.Deploy(context.Background(), []byte(`{"version":"not an int"}`), -1) require.ErrorIs(t, err, common.ErrGettingField, "Must throw the correct error trying to get version from bad json") in := []byte(`{"version":0,"exchanges":[{"name":"Juan"}]}`) - j, err := m.Deploy(context.Background(), in) + j, err := m.Deploy(context.Background(), in, -1) require.NoError(t, err) require.Equal(t, string(in), string(j)) m.registerVersion(1, &Version1{}) - j, err = m.Deploy(context.Background(), in) + j, err = m.Deploy(context.Background(), in, -1) require.NoError(t, err) require.Contains(t, string(j), `"version":1`) m.versions = m.versions[:1] - j, err = m.Deploy(context.Background(), j) + j, err = m.Deploy(context.Background(), j, -1) require.NoError(t, err) require.Contains(t, string(j), `"version":0`) m.versions = append(m.versions, &TestVersion2{ConfigErr: true, ExchErr: false}) // Bit hacky, but this will actually work - _, err = m.Deploy(context.Background(), j) + _, err = m.Deploy(context.Background(), j, -1) require.ErrorIs(t, err, errUpgrade) m.versions[1] = &TestVersion2{ConfigErr: false, ExchErr: true} - _, err = m.Deploy(context.Background(), in) + _, err = m.Deploy(context.Background(), in, -1) require.Implements(t, (*ExchangeVersion)(nil), m.versions[1]) require.ErrorIs(t, err, errUpgrade) } @@ -61,7 +61,7 @@ func TestDeploy(t *testing.T) { func TestExchangeDeploy(t *testing.T) { t.Parallel() m := manager{} - _, err := m.Deploy(context.Background(), []byte(``)) + _, err := m.Deploy(context.Background(), []byte(``), -1) assert.ErrorIs(t, err, errNoVersions) v := &TestVersion2{}