diff --git a/db/chat_test.go b/db/chat_test.go index a09e02e27..9fa718d6d 100644 --- a/db/chat_test.go +++ b/db/chat_test.go @@ -1,8 +1,11 @@ package db import ( - "github.com/stretchr/testify/assert" + "strings" "testing" + "time" + + "github.com/stretchr/testify/assert" ) func TestGetChatsForWorkspace(t *testing.T) { @@ -110,3 +113,431 @@ func TestGetChatsForWorkspace(t *testing.T) { }) } } + +func TestGetChatMessagesForChatID(t *testing.T) { + InitTestDB() + currentTime := time.Now() + + tests := []struct { + name string + setup func() string + expected []ChatMessage + expectError bool + }{ + { + name: "Successfully geting messages for chat", + setup: func() string { + chatID := "chat123" + messages := []ChatMessage{ + { + ID: "msg1", + ChatID: chatID, + Message: "Hello", + Role: UserRole, + Timestamp: currentTime, + Status: SentStatus, + Source: UserSource, + }, + { + ID: "msg2", + ChatID: chatID, + Message: "Hi there", + Role: AssistantRole, + Timestamp: currentTime.Add(time.Minute), + Status: SentStatus, + Source: AgentSource, + }, + } + for _, msg := range messages { + TestDB.db.Create(&msg) + } + return chatID + }, + expected: []ChatMessage{ + { + ID: "msg1", + ChatID: "chat123", + Message: "Hello", + Role: UserRole, + Status: SentStatus, + Source: UserSource, + }, + { + ID: "msg2", + ChatID: "chat123", + Message: "Hi there", + Role: AssistantRole, + Status: SentStatus, + Source: AgentSource, + }, + }, + expectError: false, + }, + { + name: "Empty chat ID", + setup: func() string { + return "" + }, + expected: []ChatMessage{}, + expectError: false, + }, + { + name: "Non-existent chat ID", + setup: func() string { + return "nonexistent123" + }, + expected: []ChatMessage{}, + expectError: false, + }, + { + name: "Chat with special characters in messages", + setup: func() string { + chatID := "chat456" + messages := []ChatMessage{ + { + ID: "msg3", + ChatID: chatID, + Message: "Hello !@#$%^&*()", + Role: UserRole, + Timestamp: currentTime, + Status: SentStatus, + Source: UserSource, + }, + } + for _, msg := range messages { + TestDB.db.Create(&msg) + } + return chatID + }, + expected: []ChatMessage{ + { + ID: "msg3", + ChatID: "chat456", + Message: "Hello !@#$%^&*()", + Role: UserRole, + Status: SentStatus, + Source: UserSource, + }, + }, + expectError: false, + }, + { + name: "Chat with the Unicode messages", + setup: func() string { + chatID := "chat789" + messages := []ChatMessage{ + { + ID: "msg4", + ChatID: chatID, + Message: "你好 👋 Привет", + Role: UserRole, + Timestamp: currentTime, + Status: SentStatus, + Source: UserSource, + }, + } + for _, msg := range messages { + TestDB.db.Create(&msg) + } + return chatID + }, + expected: []ChatMessage{ + { + ID: "msg4", + ChatID: "chat789", + Message: "你好 👋 Привет", + Role: UserRole, + Status: SentStatus, + Source: UserSource, + }, + }, + expectError: false, + }, + { + name: "Chat with large message", + setup: func() string { + chatID := "chat101112" + messages := []ChatMessage{ + { + ID: "msg5", + ChatID: chatID, + Message: strings.Repeat("a", 1000), + Role: UserRole, + Timestamp: currentTime, + Status: SentStatus, + Source: UserSource, + }, + } + for _, msg := range messages { + TestDB.db.Create(&msg) + } + return chatID + }, + expected: []ChatMessage{ + { + ID: "msg5", + ChatID: "chat101112", + Message: strings.Repeat("a", 1000), + Role: UserRole, + Status: SentStatus, + Source: UserSource, + }, + }, + expectError: false, + }, + { + name: "SQL injection attempt in chat ID", + setup: func() string { + return "chat123'; DROP TABLE chat_messages; --" + }, + expected: []ChatMessage{}, + expectError: false, + }, + { + name: "Messages ordered by timestamp", + setup: func() string { + chatID := "chatOrdered" + messages := []ChatMessage{ + { + ID: "msg6", + ChatID: chatID, + Message: "Second", + Role: UserRole, + Timestamp: currentTime.Add(time.Minute), + Status: SentStatus, + Source: UserSource, + }, + { + ID: "msg7", + ChatID: chatID, + Message: "First", + Role: UserRole, + Timestamp: currentTime, + Status: SentStatus, + Source: UserSource, + }, + } + for _, msg := range messages { + TestDB.db.Create(&msg) + } + return chatID + }, + expected: []ChatMessage{ + { + ID: "msg7", + ChatID: "chatOrdered", + Message: "First", + Role: UserRole, + Status: SentStatus, + Source: UserSource, + }, + { + ID: "msg6", + ChatID: "chatOrdered", + Message: "Second", + Role: UserRole, + Status: SentStatus, + Source: UserSource, + }, + }, + expectError: false, + }, + { + name: "Messages with different statuses", + setup: func() string { + chatID := "chatStatus" + messages := []ChatMessage{ + { + ID: "msg8", + ChatID: chatID, + Message: "Sending", + Role: UserRole, + Timestamp: currentTime, + Status: SendingStatus, + Source: UserSource, + }, + { + ID: "msg9", + ChatID: chatID, + Message: "Error", + Role: AssistantRole, + Timestamp: currentTime.Add(time.Minute), + Status: ErrorStatus, + Source: AgentSource, + }, + } + for _, msg := range messages { + TestDB.db.Create(&msg) + } + return chatID + }, + expected: []ChatMessage{ + { + ID: "msg8", + ChatID: "chatStatus", + Message: "Sending", + Role: UserRole, + Status: SendingStatus, + Source: UserSource, + }, + { + ID: "msg9", + ChatID: "chatStatus", + Message: "Error", + Role: AssistantRole, + Status: ErrorStatus, + Source: AgentSource, + }, + }, + expectError: false, + }, + { + name: "Valid Chat ID with Messages", + setup: func() string { + chatID := "valid-chat-123" + messages := []ChatMessage{ + { + ID: "valid-msg-1", + ChatID: chatID, + Message: "Test message 1", + Role: UserRole, + Timestamp: currentTime, + Status: SentStatus, + Source: UserSource, + }, + } + for _, msg := range messages { + TestDB.db.Create(&msg) + } + return chatID + }, + expected: []ChatMessage{ + { + ID: "valid-msg-1", + ChatID: "valid-chat-123", + Message: "Test message 1", + Role: UserRole, + Status: SentStatus, + Source: UserSource, + }, + }, + expectError: false, + }, + { + name: "Chat ID with Maximum Length", + setup: func() string { + chatID := strings.Repeat("a", 255) + messages := []ChatMessage{ + { + ID: "max-length-msg-1", + ChatID: chatID, + Message: "Max length chat ID", + Role: UserRole, + Timestamp: currentTime, + Status: SentStatus, + Source: UserSource, + }, + } + for _, msg := range messages { + TestDB.db.Create(&msg) + } + return chatID + }, + expected: []ChatMessage{ + { + ID: "max-length-msg-1", + ChatID: strings.Repeat("a", 255), + Message: "Max length chat ID", + Role: UserRole, + Status: SentStatus, + Source: UserSource, + }, + }, + expectError: false, + }, + { + name: "Case Sensitivity", + setup: func() string { + chatID := "UPPERCASE-CHAT-ID" + messages := []ChatMessage{ + { + ID: "case-sensitive-msg-1", + ChatID: chatID, + Message: "Case sensitive test", + Role: UserRole, + Timestamp: currentTime, + Status: SentStatus, + Source: UserSource, + }, + } + for _, msg := range messages { + TestDB.db.Create(&msg) + } + return "uppercase-chat-id" + }, + expected: []ChatMessage{}, + expectError: false, + }, + { + name: "Chat ID with Special Characters", + setup: func() string { + chatID := "chat!@#$%^&*()_+" + messages := []ChatMessage{ + { + ID: "special-char-msg-1", + ChatID: chatID, + Message: "Special characters in chat ID", + Role: UserRole, + Timestamp: currentTime, + Status: SentStatus, + Source: UserSource, + }, + } + for _, msg := range messages { + TestDB.db.Create(&msg) + } + return chatID + }, + expected: []ChatMessage{ + { + ID: "special-char-msg-1", + ChatID: "chat!@#$%^&*()_+", + Message: "Special characters in chat ID", + Role: UserRole, + Status: SentStatus, + Source: UserSource, + }, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + TestDB.db.Exec("DELETE FROM chat_messages") + + chatID := tt.setup() + messages, err := TestDB.GetChatMessagesForChatID(chatID) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, len(tt.expected), len(messages)) + + for i, expectedMsg := range tt.expected { + assert.Equal(t, expectedMsg.ID, messages[i].ID) + assert.Equal(t, expectedMsg.ChatID, messages[i].ChatID) + assert.Equal(t, expectedMsg.Message, messages[i].Message) + assert.Equal(t, expectedMsg.Role, messages[i].Role) + assert.Equal(t, expectedMsg.Status, messages[i].Status) + assert.Equal(t, expectedMsg.Source, messages[i].Source) + if len(expectedMsg.ContextTags) > 0 { + assert.Equal(t, expectedMsg.ContextTags, messages[i].ContextTags) + } + } + } + }) + } +}