Skip to content

Commit

Permalink
Merge pull request #256 from tigergraph/GML-1809-feedback-analysis-ac…
Browse files Browse the repository at this point in the history
…cess-control

Gml 1809 feedback analysis access control
  • Loading branch information
luzhoutg authored Aug 2, 2024
2 parents fdd9560 + 0ca2da8 commit 9b920dc
Show file tree
Hide file tree
Showing 15 changed files with 536 additions and 142 deletions.
2 changes: 1 addition & 1 deletion chat-history/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ clean:

run: clean test build
clear
CONFIG="config.json" DEV=true ./chat-history
CONFIG_FILES="chat_config.json,db_config.json" DEV=true ./chat-history


79 changes: 55 additions & 24 deletions chat-history/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,83 @@ package config

import (
"encoding/json"
"fmt"
"os"
)

type LLMConfig struct {
ModelName string `json:"model_name"`
}

type DbConfig struct {
Port string `json:"apiPort"`
DbPath string `json:"dbPath"`
DbLogPath string `json:"dbLogPath"`
LogPath string `json:"logPath"`
// DbHostname string `json:"hostname"`
// Username string `json:"username"`
// Password string `json:"password"`
type ChatDbConfig struct {
Port string `json:"apiPort"`
DbPath string `json:"dbPath"`
DbLogPath string `json:"dbLogPath"`
LogPath string `json:"logPath"`
ConversationAccessRoles []string `json:"conversationAccessRoles"`
}

type TgDbConfig struct {
Hostname string `json:"hostname"`
Username string `json:"username"`
Password string `json:"password"`
GsPort string `json:"gsPort"`
TgCloud bool `json:"tgCloud"`
// GetToken string `json:"getToken"`
// DefaultTimeout string `json:"default_timeout"`
// DefaultMemThreshold string `json:"default_mem_threshold"`
// DefaultThreadLimit string `json:"default_thread_limit"`
}

type Config struct {
DbConfig
ChatDbConfig
TgDbConfig
// LLMConfig
}

func LoadConfig(path string) (Config, error) {
var b []byte
if _, err := os.Stat(path); os.IsNotExist(err) {
// file doesn't exist read from env
cfg := os.Getenv("CONFIG")
if cfg == "" {
fmt.Println("CONFIG path is not found nor is the CONFIG json env variable defined")
os.Exit(1)
func LoadConfig(paths map[string]string) (Config, error) {
var config Config

// Load database config
if dbConfigPath, ok := paths["chatdb"]; ok {
dbConfig, err := loadChatDbConfig(dbConfigPath)
if err != nil {
return Config{}, err
}
b = []byte(cfg)
} else {
b, err = os.ReadFile(path)
config.ChatDbConfig = dbConfig
}

// Load TigerGraph config
if tgConfigPath, ok := paths["tgdb"]; ok {
tgConfig, err := loadTgDbConfig(tgConfigPath)
if err != nil {
return Config{}, err
}
config.TgDbConfig = tgConfig
}

var cfg Config
json.Unmarshal(b, &cfg)
return config, nil
}

return cfg, nil
func loadChatDbConfig(path string) (ChatDbConfig, error) {
var dbConfig ChatDbConfig
b, err := os.ReadFile(path)
if err != nil {
return ChatDbConfig{}, err
}
if err := json.Unmarshal(b, &dbConfig); err != nil {
return ChatDbConfig{}, err
}
return dbConfig, nil
}

func loadTgDbConfig(path string) (TgDbConfig, error) {
var tgConfig TgDbConfig
b, err := os.ReadFile(path)
if err != nil {
return TgDbConfig{}, err
}
if err := json.Unmarshal(b, &tgConfig); err != nil {
return TgDbConfig{}, err
}
return tgConfig, nil
}
68 changes: 43 additions & 25 deletions chat-history/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,59 @@ import (
)

func TestLoadConfig(t *testing.T) {
pth := setup(t)
cfg, err := LoadConfig(pth)
chatConfigPath, tgConfigPath := setup(t)

cfg, err := LoadConfig(map[string]string{
"chatdb": chatConfigPath,
"tgdb": tgConfigPath,
})
if err != nil {
t.Fatal(err)
}

if cfg.Port != "8000" ||
cfg.DbPath != "chats.db" ||
cfg.DbLogPath != "db.log" ||
cfg.LogPath != "requestLogs.jsonl" {
t.Fatalf("config is wrong, %v", cfg)
if cfg.ChatDbConfig.Port != "8002" ||
cfg.ChatDbConfig.DbPath != "chats.db" ||
cfg.ChatDbConfig.DbLogPath != "db.log" ||
cfg.ChatDbConfig.LogPath != "requestLogs.jsonl" {
t.Fatalf("config is wrong, %v", cfg.ChatDbConfig)
}

if cfg.TgDbConfig.Hostname != "https://tg-0cdef603-3760-41c3-af6f-41e95afc40de.us-east-1.i.tgcloud.io" ||
cfg.TgDbConfig.GsPort != "14240" ||
cfg.TgDbConfig.TgCloud != true {
t.Fatalf("TigerGraph config is wrong, %v", cfg.TgDbConfig)
}
}

func setup(t *testing.T) string {
func setup(t *testing.T) (string, string) {
tmp := t.TempDir()
pth := fmt.Sprintf("%s/%s", tmp, "config.json")
dat := `

chatConfigPath := fmt.Sprintf("%s/%s", tmp, "chat_config.json")
chatConfigData := `
{
"apiPort":"8000",
"hostname": "http://localhost:14240",
"dbPath": "chats.db",
"dbLogPath": "db.log",
"logPath": "requestLogs.jsonl",
"username": "tigergraph",
"password": "tigergraph",
"getToken": false,
"default_timeout": 300,
"default_mem_threshold": 5000,
"default_thread_limit": 8
"apiPort":"8002",
"dbPath": "chats.db",
"dbLogPath": "db.log",
"logPath": "requestLogs.jsonl",
"conversationAccessRoles": ["superuser", "globaldesigner"]
}`
err := os.WriteFile(pth, []byte(dat), 0644)
if err != nil {
t.Fatal("error setting up config.json")

if err := os.WriteFile(chatConfigPath, []byte(chatConfigData), 0644); err != nil {
t.Fatal("error setting up chat_config.json")
}
return pth

tgConfigPath := fmt.Sprintf("%s/%s", tmp, "db_config.json")
tgConfigData := `
{
"hostname": "https://tg-0cdef603-3760-41c3-af6f-41e95afc40de.us-east-1.i.tgcloud.io",
"gsPort": "14240",
"username": "supportai",
"password": "supportai",
"tgCloud": true
}`
if err := os.WriteFile(tgConfigPath, []byte(tgConfigData), 0644); err != nil {
t.Fatal("error setting up tg_config.json")
}

return chatConfigPath, tgConfigPath
}
31 changes: 28 additions & 3 deletions chat-history/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,28 @@ func UpdateConversationById(message structs.Message) (*structs.Conversation, err
return &convo, nil
}

// GetAllMessages retrieves all messages from the database
func GetAllMessages() ([]structs.Message, error) {
var messages []structs.Message

// Use GORM to query all messages
if err := db.Find(&messages).Error; err != nil {
return nil, err
}

return messages, nil
}

func populateDB() {
mu.Lock()
defer mu.Unlock()

// init convos
conv1 := uuid.MustParse("601529eb-4927-4e24-b285-bd6b9519a951")
conv2 := uuid.MustParse("601529eb-4927-4e24-b285-bd6b9519a952")
db.Create(&structs.Conversation{UserId: "sam_pull", ConversationId: conv1, Name: "conv1"})
db.Create(&structs.Conversation{UserId: "sam_pull", ConversationId: uuid.New(), Name: "conv2"})
db.Create(&structs.Conversation{UserId: "Miss_Take", ConversationId: uuid.New(), Name: "conv3"})
db.Create(&structs.Conversation{UserId: "Miss_Take", ConversationId: conv2, Name: "conv2"})
// db.Create(&structs.Conversation{UserId: "Miss_Take", ConversationId: uuid.New(), Name: "conv3"})

// add message to convos
message := structs.Message{
Expand All @@ -152,8 +165,8 @@ func populateDB() {
Feedback: structs.NoFeedback,
Comment: "",
}

db.Create(&message)

m2 := structs.Message{
ConversationId: conv1,
MessageId: uuid.New(),
Expand All @@ -165,4 +178,16 @@ func populateDB() {
Comment: "",
}
db.Create(&m2)

m3 := structs.Message{
ConversationId: conv2,
MessageId: uuid.New(),
ParentId: &message.MessageId,
ModelName: "GPT-4o",
Content: "How many transactions?",
Role: structs.SystemRole,
Feedback: structs.NoFeedback,
Comment: "",
}
db.Create(&m3)
}
23 changes: 23 additions & 0 deletions chat-history/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,29 @@ func TestParallelWrites(t *testing.T) {
}
}

func TestGetAllMessages(t *testing.T) {
setupTest(t, true)

messages, err := GetAllMessages()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

// Ensure that messages are returned
if len(messages) == 0 {
t.Fatalf("Expected some messages, got none")
}

// Validate the structure of the messages
for _, m := range messages {
if uuid.Validate(m.ConversationId.String()) != nil ||
uuid.Validate(m.MessageId.String()) != nil ||
(m.Role != "system" && m.Role != "user") {
t.Fatalf("Invaid message structure: %v", m)
}
}
}

/*
helper functions
*/
Expand Down
4 changes: 2 additions & 2 deletions chat-history/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ module chat-history
go 1.22.3

require (
github.com/go-chi/chi/v5 v5.0.12
github.com/go-chi/httplog/v2 v2.0.11
github.com/google/uuid v1.6.0
gorm.io/driver/sqlite v1.5.5
gorm.io/gorm v1.25.10
)

require (
github.com/go-chi/httplog/v2 v2.0.11 // indirect
github.com/go-chi/chi/v5 v5.0.12 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
Expand Down
2 changes: 2 additions & 0 deletions chat-history/go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/GenericP3rson/TigerGo v0.0.4 h1:xI7d/cLJ6sRP4fzanInakARE0XGk1YAmvn5KrH1fwFU=
github.com/GenericP3rson/TigerGo v0.0.4/go.mod h1:PGpAFO9vNA7l34WSGYCtWb/eqVKHuIq1xqvizBlNhRM=
github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s=
github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/httplog/v2 v2.0.11 h1:eu6kYksMEJzBcOP+ba/iYudc0m5rv4VvBAzroJMkaY4=
Expand Down
17 changes: 12 additions & 5 deletions chat-history/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@ import (
)

func main() {
configPath:= os.Getenv("CONFIG")
config, err := config.LoadConfig(configPath)
configPath := os.Getenv("CONFIG_FILES")
// Split the paths into a slice
configPaths := strings.Split(configPath, ",")

cfg, err := config.LoadConfig(map[string]string{
"chatdb": configPaths[0],
"tgdb": configPaths[1],
})
if err != nil {
panic(err)
}
db.InitDB(config.DbPath, config.DbLogPath)
db.InitDB(cfg.ChatDbConfig.DbPath, cfg.ChatDbConfig.DbLogPath)

// make router
router := http.NewServeMux()
Expand All @@ -30,14 +36,15 @@ func main() {
router.HandleFunc("GET /user/{userId}", routes.GetUserConversations)
router.HandleFunc("GET /conversation/{conversationId}", routes.GetConversation)
router.HandleFunc("POST /conversation", routes.UpdateConversation)
router.HandleFunc("GET /get_feedback", routes.GetFeedback(cfg.TgDbConfig.Hostname, cfg.TgDbConfig.GsPort, cfg.ChatDbConfig.ConversationAccessRoles, cfg.TgDbConfig.TgCloud))

// create server with middleware
dev := strings.ToLower(os.Getenv("DEV")) == "true"
var port string
if dev {
port = fmt.Sprintf("localhost:%s", config.Port)
port = fmt.Sprintf("localhost:%s", cfg.ChatDbConfig.Port)
} else {
port = fmt.Sprintf(":%s", config.Port)
port = fmt.Sprintf(":%s", cfg.ChatDbConfig.Port)
}

handler := middleware.ChainMiddleware(router,
Expand Down
Loading

0 comments on commit 9b920dc

Please sign in to comment.