diff --git a/api/handlers.go b/api/handlers.go index 0dc451e..cdac0fa 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -14,6 +14,23 @@ import ( "github.com/micpst/minisearch/pkg/tokenizer" ) +type SearchRequest struct { + Query string `json:"query" binding:"required"` + Properties []string `json:"properties"` + Exact bool `json:"exact"` + Tolerance int `json:"tolerance"` + Relevance BM25Params `json:"relevance"` + Offset int `json:"offset"` + Limit int `json:"limit"` + Language tokenizer.Language `json:"lang"` +} + +type BM25Params struct { + K float64 `json:"k"` + B float64 `json:"b"` + D float64 `json:"d"` +} + type DocumentResponse struct { Id string `json:"id"` Title string `json:"title"` @@ -33,6 +50,10 @@ type SearchDocumentResponse struct { Elapsed int64 `json:"elapsed"` } +type ErrorResponse struct { + Message string `json:"message"` +} + type UploadDocumentsResponse struct { Total int `json:"total"` Success int `json:"success"` @@ -61,12 +82,11 @@ func (s *Server) uploadDocuments(c *gin.Context) { return } - params := store.InsertBatchParams[Document]{ + errs := s.db.InsertBatch(&store.InsertBatchParams[Document]{ Documents: dump.Documents, BatchSize: 10000, Language: tokenizer.Language(strings.ToLower(c.Query("lang"))), - } - errs := s.db.InsertBatch(¶ms) + }) total += len(dump.Documents) failed += len(errs) @@ -85,18 +105,24 @@ func (s *Server) createDocument(c *gin.Context) { return } - params := store.InsertParams[Document]{ + doc, err := s.db.Insert(&store.InsertParams[Document]{ Document: body, Language: tokenizer.Language(strings.ToLower(c.Query("lang"))), - } + }) - doc, err := s.db.Insert(¶ms) - if err != nil { - c.Status(http.StatusInternalServerError) - return + switch err.(type) { + case nil: + c.JSON(http.StatusCreated, DocumentResponse{ + Id: doc.Id, + Title: doc.Data.Title, + Url: doc.Data.Url, + Abstract: doc.Data.Abstract, + }) + default: + c.JSON(http.StatusBadRequest, ErrorResponse{ + Message: err.Error(), + }) } - - c.JSON(http.StatusCreated, documentFromRecord(doc)) } func (s *Server) updateDocument(c *gin.Context) { @@ -105,40 +131,56 @@ func (s *Server) updateDocument(c *gin.Context) { return } - params := store.UpdateParams[Document]{ + doc, err := s.db.Update(&store.UpdateParams[Document]{ Id: c.Param("id"), Document: body, Language: tokenizer.Language(strings.ToLower(c.Query("lang"))), - } + }) - doc, err := s.db.Update(¶ms) - if err != nil { - c.Status(http.StatusNotFound) - return + switch err.(type) { + case nil: + c.JSON(http.StatusOK, DocumentResponse{ + Id: doc.Id, + Title: doc.Data.Title, + Url: doc.Data.Url, + Abstract: doc.Data.Abstract, + }) + case *store.DocumentNotFoundError: + c.JSON(http.StatusNotFound, ErrorResponse{ + Message: err.Error(), + }) + default: + c.JSON(http.StatusBadRequest, ErrorResponse{ + Message: err.Error(), + }) } - - c.JSON(http.StatusOK, documentFromRecord(doc)) } func (s *Server) deleteDocument(c *gin.Context) { - params := store.DeleteParams[Document]{ + err := s.db.Delete(&store.DeleteParams[Document]{ Id: c.Param("id"), Language: tokenizer.Language(strings.ToLower(c.Query("lang"))), - } + }) - if err := s.db.Delete(¶ms); err != nil { - c.Status(http.StatusNotFound) - return + switch err.(type) { + case nil: + c.Status(http.StatusOK) + case *store.DocumentNotFoundError: + c.JSON(http.StatusNotFound, ErrorResponse{ + Message: err.Error(), + }) + default: + c.JSON(http.StatusBadRequest, ErrorResponse{ + Message: err.Error(), + }) } - - c.Status(http.StatusOK) } func (s *Server) searchDocuments(c *gin.Context) { - params := store.SearchParams{ + params := SearchRequest{ Properties: []string{}, Limit: 10, - Relevance: store.BM25Params{ + Relevance: BM25Params{ K: 1.2, B: 0.75, D: 0.5, @@ -149,27 +191,28 @@ func (s *Server) searchDocuments(c *gin.Context) { } start := time.Now() - result, err := s.db.Search(¶ms) - elapsed := time.Since(start) - - if err != nil { - c.Status(http.StatusBadRequest) - return - } - - c.JSON(http.StatusOK, SearchDocumentResponse{ - Count: result.Count, - Hits: *(*[]SearchDocument)(unsafe.Pointer(&result.Hits)), - Elapsed: elapsed.Microseconds(), + result, err := s.db.Search(&store.SearchParams{ + Query: params.Query, + Properties: params.Properties, + Exact: params.Exact, + Tolerance: params.Tolerance, + Relevance: store.BM25Params(params.Relevance), + Offset: params.Offset, + Limit: params.Limit, }) -} + elapsed := time.Since(start) -func documentFromRecord(d store.Record[Document]) DocumentResponse { - return DocumentResponse{ - Id: d.Id, - Title: d.Data.Title, - Url: d.Data.Url, - Abstract: d.Data.Abstract, + switch err.(type) { + case nil: + c.JSON(http.StatusOK, SearchDocumentResponse{ + Count: result.Count, + Hits: *(*[]SearchDocument)(unsafe.Pointer(&result.Hits)), + Elapsed: elapsed.Microseconds(), + }) + default: + c.JSON(http.StatusBadRequest, ErrorResponse{ + Message: err.Error(), + }) } } diff --git a/pkg/store/errors.go b/pkg/store/errors.go index 441f445..f683d01 100644 --- a/pkg/store/errors.go +++ b/pkg/store/errors.go @@ -10,10 +10,18 @@ type DocumentAlreadyExistsError struct { Id string } +type WrongSearchPropertyType struct { + Property string +} + func (e *DocumentNotFoundError) Error() string { - return fmt.Sprintf("Document with id %s not found", e.Id) + return fmt.Sprintf("Document with id '%s' not found", e.Id) } func (e *DocumentAlreadyExistsError) Error() string { - return fmt.Sprintf("Document with id %s already exists", e.Id) + return fmt.Sprintf("Document with id '%s' already exists", e.Id) +} + +func (e *WrongSearchPropertyType) Error() string { + return fmt.Sprintf("Property '%s' is not searchable", e.Property) } diff --git a/pkg/store/index.go b/pkg/store/index.go index 3df5247..6c8cae0 100644 --- a/pkg/store/index.go +++ b/pkg/store/index.go @@ -121,7 +121,7 @@ func (idx *index[K, S]) delete(params *indexParams[K, S]) { } } -func (idx *index[K, S]) find(params *findParams) map[K]float64 { +func (idx *index[K, S]) find(params *findParams) (map[K]float64, error) { idScores := make(map[K]float64) if index, ok := idx.indexes[params.property]; ok { @@ -142,9 +142,11 @@ func (idx *index[K, S]) find(params *findParams) map[K]float64 { params.relevance.D, ) } + } else { + return nil, &WrongSearchPropertyType{Property: params.property} } - return idScores + return idScores, nil } func flattenSchema(obj any, prefix ...string) map[string]any { diff --git a/pkg/store/store.go b/pkg/store/store.go index d64b774..5ceca69 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -10,13 +10,6 @@ import ( "github.com/micpst/minisearch/pkg/tokenizer" ) -const ( - AND Mode = "AND" - OR Mode = "OR" -) - -type Mode string - type Schema any type Record[S Schema] struct { @@ -47,20 +40,20 @@ type DeleteParams[S Schema] struct { } type SearchParams struct { - Query string `json:"query" binding:"required"` - Properties []string `json:"properties"` - Exact bool `json:"exact"` - Tolerance int `json:"tolerance"` - Relevance BM25Params `json:"relevance"` - Offset int `json:"offset"` - Limit int `json:"limit"` - Language tokenizer.Language `json:"lang"` + Query string + Properties []string + Exact bool + Tolerance int + Relevance BM25Params + Offset int + Limit int + Language tokenizer.Language } type BM25Params struct { - K float64 `json:"k"` - B float64 `json:"b"` - D float64 `json:"d"` + K float64 + B float64 + D float64 } type SearchResult[S Schema] struct { @@ -273,7 +266,7 @@ func (db *MemDB[S]) Search(params *SearchParams) (SearchResult[S], error) { for _, prop := range properties { for _, token := range tokens { - idScores := db.index.find(&findParams{ + idScores, err := db.index.find(&findParams{ term: token, property: prop, exact: params.Exact, @@ -281,6 +274,9 @@ func (db *MemDB[S]) Search(params *SearchParams) (SearchResult[S], error) { relevance: params.Relevance, docsCount: len(db.documents), }) + if err != nil { + return SearchResult[S]{}, err + } for id, score := range idScores { allIdScores[id] += score } diff --git a/pkg/tokenizer/errors.go b/pkg/tokenizer/errors.go index e9bc252..fd3ee3d 100644 --- a/pkg/tokenizer/errors.go +++ b/pkg/tokenizer/errors.go @@ -7,5 +7,5 @@ type LanguageNotSupportedError struct { } func (e *LanguageNotSupportedError) Error() string { - return fmt.Sprintf("Language %s is not supported", e.Language) + return fmt.Sprintf("Language '%s' is not supported", e.Language) }