Skip to content

Commit

Permalink
feat: enhance history handling and add timestamps
Browse files Browse the repository at this point in the history
  • Loading branch information
kardolus committed Nov 5, 2024
1 parent fb090c4 commit 64951ac
Show file tree
Hide file tree
Showing 11 changed files with 325 additions and 124 deletions.
75 changes: 54 additions & 21 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"github.com/kardolus/chatgpt-cli/utils"
"strings"
"time"
"unicode/utf8"

"github.com/kardolus/chatgpt-cli/history"
Expand All @@ -23,14 +24,26 @@ const (
gptPrefix = "gpt"
)

type Timer interface {
Now() time.Time
}

type RealTime struct {
}

func (r *RealTime) Now() time.Time {
return time.Now()
}

type Client struct {
Config types.Config
History []types.Message
History []types.History
caller http.Caller
historyStore history.HistoryStore
timer Timer
}

func New(callerFactory http.CallerFactory, hs history.HistoryStore, cfg types.Config, interactiveMode bool) *Client {
func New(callerFactory http.CallerFactory, hs history.HistoryStore, t Timer, cfg types.Config, interactiveMode bool) *Client {
caller := callerFactory(cfg)

if interactiveMode && cfg.AutoCreateNewThread {
Expand All @@ -43,6 +56,7 @@ func New(callerFactory http.CallerFactory, hs history.HistoryStore, cfg types.Co
Config: cfg,
caller: caller,
historyStore: hs,
timer: t,
}
}

Expand Down Expand Up @@ -107,8 +121,8 @@ func (c *Client) ListModels() ([]string, error) {
// characters.
func (c *Client) ProvideContext(context string) {
c.initHistory()
messages := createMessagesFromString(context)
c.History = append(c.History, messages...)
historyEntries := c.createHistoryEntriesFromString(context)
c.History = append(c.History, historyEntries...)
}

// Query sends a query to the API, returning the response as a string along with the token usage.
Expand Down Expand Up @@ -182,8 +196,14 @@ func (c *Client) Stream(input string) error {
}

func (c *Client) createBody(stream bool) ([]byte, error) {
var messages []types.Message

for _, item := range c.History {
messages = append(messages, item.Message)
}

body := types.CompletionsRequest{
Messages: c.History,
Messages: messages,
Model: c.Config.Model,
MaxTokens: c.Config.MaxTokens,
Temperature: c.Config.Temperature,
Expand All @@ -207,8 +227,11 @@ func (c *Client) initHistory() {
}

if len(c.History) == 0 {
c.History = []types.Message{{
Role: SystemRole,
c.History = []types.History{{
Message: types.Message{
Role: SystemRole,
},
Timestamp: c.timer.Now(),
}}
}

Expand All @@ -221,7 +244,10 @@ func (c *Client) addQuery(query string) {
Content: query,
}

c.History = append(c.History, message)
c.History = append(c.History, types.History{
Message: message,
Timestamp: c.timer.Now(),
})
c.truncateHistory()
}

Expand Down Expand Up @@ -270,9 +296,12 @@ func (c *Client) truncateHistory() {
}

func (c *Client) updateHistory(response string) {
c.History = append(c.History, types.Message{
Role: AssistantRole,
Content: response,
c.History = append(c.History, types.History{
Message: types.Message{
Role: AssistantRole,
Content: response,
},
Timestamp: c.timer.Now(),
})

if !c.Config.OmitHistory {
Expand All @@ -286,13 +315,13 @@ func calculateEffectiveContextWindow(window int, bufferPercentage int) int {
return effectiveContextWindow
}

func countTokens(messages []types.Message) (int, []int) {
func countTokens(entries []types.History) (int, []int) {
var result int
var rolling []int

for _, message := range messages {
for _, entry := range entries {
charCount, wordCount := 0, 0
words := strings.Fields(message.Content)
words := strings.Fields(entry.Content)
wordCount += len(words)

for _, word := range words {
Expand All @@ -309,9 +338,10 @@ func countTokens(messages []types.Message) (int, []int) {
return result, rolling
}

func createMessagesFromString(input string) []types.Message {
func (c *Client) createHistoryEntriesFromString(input string) []types.History {
var result []types.History

words := strings.Fields(input)
var messages []types.Message

for i := 0; i < len(words); i += 100 {
end := i + 100
Expand All @@ -321,14 +351,17 @@ func createMessagesFromString(input string) []types.Message {

content := strings.Join(words[i:end], " ")

message := types.Message{
Role: UserRole,
Content: content,
item := types.History{
Message: types.Message{
Role: UserRole,
Content: content,
},
Timestamp: c.timer.Now(),
}
messages = append(messages, message)
result = append(result, item)
}

return messages
return result
}

func (c *Client) printRequestDebugInfo(endpoint string, body []byte) {
Expand Down
Loading

0 comments on commit 64951ac

Please sign in to comment.