Skip to content

Commit

Permalink
Add --delete-thread flag for thread management
Browse files Browse the repository at this point in the history
  • Loading branch information
kardolus committed Apr 13, 2024
1 parent 321727f commit dcdba7b
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 59 deletions.
14 changes: 14 additions & 0 deletions client/configmocks_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 0 additions & 14 deletions client/historymocks_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 12 additions & 7 deletions cmd/chatgpt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func main() {
rootCmd.PersistentFlags().BoolVarP(&listThreads, "list-threads", "", false, "List available threads")
rootCmd.PersistentFlags().StringVar(&modelName, "set-model", "", "Set a new default GPT model by specifying the model name")
rootCmd.PersistentFlags().StringVar(&threadName, "set-thread", "", "Set a new active thread by specifying the thread name")
rootCmd.PersistentFlags().StringVar(&threadName, "delete-thread", "", "Delete the specified thread")
rootCmd.PersistentFlags().StringVar(&shell, "set-completions", "", "Generate autocompletion script for your current shell")
rootCmd.PersistentFlags().IntVar(&maxTokens, "set-max-tokens", 0, "Set a new default max token size by specifying the max tokens")
rootCmd.PersistentFlags().IntVar(&contextWindow, "set-context-window", 0, "Set a new default context window size")
Expand Down Expand Up @@ -120,6 +121,16 @@ func run(cmd *cobra.Command, args []string) error {
return nil
}

if cmd.Flag("delete-thread").Changed {
cm := configmanager.New(config.New())

if err := cm.DeleteThread(threadName); err != nil {
return err
}
fmt.Printf("Successfully deleted thead %s\n", threadName)
return nil
}

if listThreads {
cm := configmanager.New(config.New())

Expand All @@ -135,15 +146,9 @@ func run(cmd *cobra.Command, args []string) error {
}

if clearHistory {
historyHandler, err := history.New()
if err != nil {
return err
}

cm := configmanager.New(config.New())
historyHandler.SetThread(cm.Config.Thread)

if err := historyHandler.Delete(); err != nil {
if err := cm.DeleteThread(cm.Config.Thread); err != nil {
return err
}

Expand Down
10 changes: 10 additions & 0 deletions config/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ const (
)

type ConfigStore interface {
Delete(string) error
List() ([]string, error)
Read() (types.Config, error)
ReadDefaults() types.Config
Expand Down Expand Up @@ -63,6 +64,15 @@ func (f *FileIO) WithHistoryPath(historyPath string) *FileIO {
return f
}

func (f *FileIO) Delete(name string) error {
path := filepath.Join(f.historyFilePath, name+".json")

if _, err := os.Stat(path); err == nil {
return os.Remove(path)
}
return nil
}

func (f *FileIO) List() ([]string, error) {
var result []string

Expand Down
6 changes: 6 additions & 0 deletions configmanager/configmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ func (c *ConfigManager) APIKeyEnvVarName() string {
return strings.ToUpper(c.Config.Name) + "_" + "API_KEY"
}

// DeleteThread removes the specified thread from the configuration store.
// This operation is idempotent; non-existent threads do not cause errors.
func (c *ConfigManager) DeleteThread(thread string) error {
return c.configStore.Delete(thread)
}

// ListThreads retrieves a list of all threads stored in the configuration.
// It marks the current thread with an asterisk (*) and returns the list sorted alphabetically.
// If an error occurs while retrieving the threads from the config store, it returns the error.
Expand Down
47 changes: 40 additions & 7 deletions configmanager/configmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
PresencePenalty: 5.5,
}

mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).AnyTimes()
mockConfigStore.EXPECT().Read().Return(userConfig, nil).AnyTimes()
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(userConfig, nil).Times(1)

subject := configmanager.New(mockConfigStore).WithEnvironment()

Expand Down Expand Up @@ -166,8 +166,8 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
os.Setenv(envPrefix+"FREQUENCY_PENALTY", "4.4")
os.Setenv(envPrefix+"PRESENCE_PENALTY", "5.5")

mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).AnyTimes()
mockConfigStore.EXPECT().Read().Return(types.Config{}, errors.New("config error")).AnyTimes()
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{}, errors.New("config error")).Times(1)

subject := configmanager.New(mockConfigStore).WithEnvironment()

Expand Down Expand Up @@ -226,8 +226,8 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
PresencePenalty: 4.5,
}

mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).AnyTimes()
mockConfigStore.EXPECT().Read().Return(userConfig, nil).AnyTimes()
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(userConfig, nil).Times(1)

subject := configmanager.New(mockConfigStore).WithEnvironment()

Expand Down Expand Up @@ -281,13 +281,46 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
})
})

when("DeleteThread()", func() {
var subject *configmanager.ConfigManager

threadName := "non-active-thread"

it.Before(func() {
userConfig := types.Config{Thread: threadName}

mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(userConfig, nil).Times(1)

subject = configmanager.New(mockConfigStore).WithEnvironment()
})

it("propagates the error from the config store", func() {
expectedMsg := "expected-error"

mockConfigStore.EXPECT().Delete(threadName).Return(errors.New(expectedMsg)).Times(1)

err := subject.DeleteThread(threadName)

Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(expectedMsg))
})
it("completes successfully the config store throws no error", func() {
mockConfigStore.EXPECT().Delete(threadName).Return(nil).Times(1)

err := subject.DeleteThread(threadName)

Expect(err).NotTo(HaveOccurred())
})
})

when("ListThreads()", func() {
activeThread := "active-thread"

it.Before(func() {
userConfig := types.Config{Thread: activeThread}

mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).AnyTimes()
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(userConfig, nil).Times(1)
})

Expand Down
14 changes: 14 additions & 0 deletions configmanager/configmocks_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 0 additions & 8 deletions history/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ const (
)

type HistoryStore interface {
Delete() error
Read() ([]types.Message, error)
Write([]types.Message) error
SetThread(thread string)
Expand Down Expand Up @@ -63,13 +62,6 @@ func (f *FileIO) WithDirectory(historyDir string) *FileIO {
return f
}

func (f *FileIO) Delete() error {
if _, err := os.Stat(f.getPath()); err == nil {
return os.Remove(f.getPath())
}
return nil
}

func (f *FileIO) Read() ([]types.Message, error) {
return parseFile(f.getPath())
}
Expand Down
78 changes: 55 additions & 23 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
RegisterTestingT(t)
})

when("Read, Write and Delete History", func() {
when("Read and Write History", func() {
const threadName = "default-thread"

var (
Expand Down Expand Up @@ -93,17 +93,9 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
Expect(err).NotTo(HaveOccurred())
Expect(readMessages).To(Equal(messages))
})

it("deletes the file", func() {
err = fileIO.Delete()
Expect(err).NotTo(HaveOccurred())

_, err = os.Stat(threadName + ".json")
Expect(os.IsNotExist(err)).To(BeTrue())
})
})

when("Read, Write, List Config", func() {
when("Read, Write, List, Delete Config", func() {
var (
tmpDir string
tmpFile *os.File
Expand Down Expand Up @@ -236,6 +228,26 @@ max_tokens: 100
Expect(result[2]).To(Equal("thread3.json"))
})

it("deletes the thread", func() {
files := []string{"thread1.json", "thread2.json", "thread3.json"}

for _, file := range files {
file, err := os.Create(filepath.Join(historyDir, file))
Expect(err).NotTo(HaveOccurred())

Expect(file.Close()).To(Succeed())
}

err = configIO.Delete("thread2")
Expect(err).NotTo(HaveOccurred())

_, err = os.Stat(filepath.Join(historyDir, "thread2.json"))
Expect(os.IsNotExist(err)).To(BeTrue())

_, err = os.Stat(filepath.Join(historyDir, "thread3.json"))
Expect(os.IsNotExist(err)).To(BeFalse())
})

// Since we don't have a Delete method in the config, we will test if we can overwrite the configuration.
it("overwrites the existing config", func() {
newConfig := types.Config{
Expand Down Expand Up @@ -333,19 +345,6 @@ max_tokens: 100
Eventually(session).Should(gexec.Exit(exitSuccess))
})

it("should require a hidden folder for the --clear-history flag", func() {
Expect(os.Unsetenv(apiKeyEnvVar)).To(Succeed())

command := exec.Command(binaryPath, "--clear-history")
session, err := gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred())

Eventually(session).Should(gexec.Exit(exitFailure))

output := string(session.Out.Contents())
Expect(output).To(ContainSubstring(".chatgpt-cli: no such file or directory"))
})

it("should require a hidden folder for the --list-threads flag", func() {
command := exec.Command(binaryPath, "--list-threads")
session, err := gexec.Start(command, io.Discard, io.Discard)
Expand Down Expand Up @@ -631,6 +630,39 @@ max_tokens: 100
Expect(output).To(ContainSubstring("- thread3"))
})

it("should delete the expected thread using the --delete-threads flag", func() {
historyDir := path.Join(filePath, "history")
Expect(os.Mkdir(historyDir, 0755)).To(Succeed())

files := []string{"thread1.json", "thread2.json", "thread3.json", "default.json"}

os.Mkdir(historyDir, 7555)

for _, file := range files {
file, err := os.Create(filepath.Join(historyDir, file))
Expect(err).NotTo(HaveOccurred())

Expect(file.Close()).To(Succeed())
}

runCommand("--delete-thread", "thread2")

output := runCommand("--list-threads")

Expect(output).To(ContainSubstring("* default (current)"))
Expect(output).To(ContainSubstring("- thread1"))
Expect(output).NotTo(ContainSubstring("- thread2"))
Expect(output).To(ContainSubstring("- thread3"))
})

it("should not throw an error when a non-existent thread is deleted using the --delete-threads flag", func() {
command := exec.Command(binaryPath, "--delete-thread", "does-not-exist")
session, err := gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred())

Eventually(session).Should(gexec.Exit(exitSuccess))
})

when("configurable flags are set", func() {
it.Before(func() {
configFile = path.Join(filePath, "config.yaml")
Expand Down

0 comments on commit dcdba7b

Please sign in to comment.