From 77e463b998f29f1a93404133f9deeac61bd83b32 Mon Sep 17 00:00:00 2001 From: divyam234 <47589864+divyam234@users.noreply.github.com> Date: Wed, 1 Jan 2025 18:21:45 +0530 Subject: [PATCH] refactor: simplify condition checks and improve upload handling --- .goreleaser.yml | 4 ++ Makefile | 2 +- internal/chizap/chizap.go | 60 -------------------------- internal/middleware/middleware.go | 10 ++++- internal/reader/reader.go | 2 +- pkg/mapper/mapper.go | 17 ++++++++ pkg/services/auth.go | 2 +- pkg/services/common.go | 2 +- pkg/services/file.go | 68 ++++++++++++++++++++---------- pkg/services/file_query_builder.go | 25 +++++------ pkg/services/upload.go | 19 +++++---- pkg/services/user.go | 4 +- 12 files changed, 105 insertions(+), 110 deletions(-) diff --git a/.goreleaser.yml b/.goreleaser.yml index 184d682e..60307e11 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -2,6 +2,10 @@ version: 2 project_name: teldrive env: - GO111MODULE=on + +before: + hooks: + - go generate ./... builds: - env: diff --git a/Makefile b/Makefile index 6e679091..c5c9a438 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ GOARCH ?= $(shell go env GOARCH) VERSION:= $(GIT_TAG) BINARY_EXTENSION := -.PHONY: all build run clean frontend backend run sync-ui retag patch-version minor-version generate +.PHONY: all build run clean frontend backend run sync-ui retag patch-version minor-version gen all: build diff --git a/internal/chizap/chizap.go b/internal/chizap/chizap.go index ab4aa17e..8763e821 100644 --- a/internal/chizap/chizap.go +++ b/internal/chizap/chizap.go @@ -3,13 +3,8 @@ package chizap import ( "context" - "net" "net/http" - "net/http/httputil" - "os" "regexp" - "runtime/debug" - "strings" "time" "github.com/go-chi/chi/v5/middleware" @@ -115,58 +110,3 @@ func ChizapWithConfig(logger ZapLogger, conf *Config) func(next http.Handler) ht }) } } - -func defaultHandleRecovery(w http.ResponseWriter, r *http.Request, err interface{}) { - w.WriteHeader(http.StatusInternalServerError) -} - -func RecoveryWithZap(logger ZapLogger, stack bool) func(next http.Handler) http.Handler { - return CustomRecoveryWithZap(logger, stack, defaultHandleRecovery) -} - -func CustomRecoveryWithZap(logger ZapLogger, stack bool, recovery func(w http.ResponseWriter, r *http.Request, err interface{})) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer func() { - if err := recover(); err != nil { - var brokenPipe bool - if ne, ok := err.(*net.OpError); ok { - if se, ok := ne.Err.(*os.SyscallError); ok { - if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || - strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { - brokenPipe = true - } - } - } - - httpRequest, _ := httputil.DumpRequest(r, false) - if brokenPipe { - logger.Error(r.URL.Path, - zap.Any("error", err), - zap.String("request", string(httpRequest)), - ) - http.Error(w, "connection broken", http.StatusInternalServerError) - return - } - - if stack { - logger.Error("[Recovery from panic]", - zap.Time("time", time.Now()), - zap.Any("error", err), - zap.String("request", string(httpRequest)), - zap.String("stack", string(debug.Stack())), - ) - } else { - logger.Error("[Recovery from panic]", - zap.Time("time", time.Now()), - zap.Any("error", err), - zap.String("request", string(httpRequest)), - ) - } - recovery(w, r, err) - } - }() - next.ServeHTTP(w, r) - }) - } -} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 3cb2b4c1..eda62d30 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -28,13 +28,21 @@ func SPAHandler(filesystem fs.FS) http.HandlerFunc { logging.DefaultLogger().Fatal(err.Error()) } return func(w http.ResponseWriter, r *http.Request) { - f, err := spaFS.Open(strings.TrimPrefix(path.Clean(r.URL.Path), "/")) + filePath := strings.TrimPrefix(path.Clean(r.URL.Path), "/") + f, err := spaFS.Open(filePath) if err == nil { defer f.Close() } if os.IsNotExist(err) { r.URL.Path = "/" + filePath = "index.html" } + if filePath == "index.html" { + w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") + } else { + w.Header().Set("Cache-Control", "public, max-age=31536000, immutable") + } + http.FileServer(http.FS(spaFS)).ServeHTTP(w, r) } } diff --git a/internal/reader/reader.go b/internal/reader/reader.go index d1c96369..abe2b6e4 100644 --- a/internal/reader/reader.go +++ b/internal/reader/reader.go @@ -140,7 +140,7 @@ func (r *LinearReader) getPartReader() (io.ReadCloser, error) { reader io.ReadCloser err error ) - if r.file.Encrypted.IsSet() && r.file.Encrypted.Value { + if r.file.Encrypted.Value { salt := r.parts[r.ranges[r.pos].PartNo].Salt cipher, _ := crypt.NewCipher(r.config.Uploads.EncryptionKey, salt) reader, err = cipher.DecryptDataSeek(r.ctx, diff --git a/pkg/mapper/mapper.go b/pkg/mapper/mapper.go index 189a268f..f8df2e08 100644 --- a/pkg/mapper/mapper.go +++ b/pkg/mapper/mapper.go @@ -29,3 +29,20 @@ func ToFileOut(file models.File, extended bool) *api.File { } return res } + +func ToUploadOut(parts []models.Upload) []api.UploadPart { + res := []api.UploadPart{} + for _, part := range parts { + res = append(res, api.UploadPart{ + Name: part.Name, + PartId: part.PartId, + ChannelId: part.ChannelID, + PartNo: part.PartNo, + Size: part.Size, + Encrypted: part.Encrypted, + Salt: api.NewOptString(part.Salt), + }) + + } + return res +} diff --git a/pkg/services/auth.go b/pkg/services/auth.go index 9ac12ce4..418c8c28 100644 --- a/pkg/services/auth.go +++ b/pkg/services/auth.go @@ -134,7 +134,7 @@ func (a *apiService) AuthLogout(ctx context.Context) (*api.AuthLogoutNoContent, } func (a *apiService) AuthSession(ctx context.Context, params api.AuthSessionParams) (api.AuthSessionRes, error) { - if !params.AccessToken.IsSet() { + if params.AccessToken.Value == "" { return &api.AuthSessionNoContent{}, nil } claims, err := auth.VerifyUser(a.db, a.cache, a.cnf.JWT.Secret, params.AccessToken.Value) diff --git a/pkg/services/common.go b/pkg/services/common.go index c8b0f738..1118ac22 100644 --- a/pkg/services/common.go +++ b/pkg/services/common.go @@ -49,7 +49,7 @@ func getParts(ctx context.Context, client *telegram.Client, cache cache.Cacher, Size: document.Size, Salt: file.Parts[i].Salt.Value, } - if file.Encrypted.IsSet() && file.Encrypted.Value { + if file.Encrypted.Value { part.DecryptedSize, _ = crypt.DecryptedSize(document.Size) } parts = append(parts, part) diff --git a/pkg/services/file.go b/pkg/services/file.go index aa987d3e..f71b2e4a 100644 --- a/pkg/services/file.go +++ b/pkg/services/file.go @@ -196,7 +196,7 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap dbFile.Name = req.NewName.Or(file.Name) dbFile.Size = utils.Ptr(file.Size.Value) dbFile.Type = string(file.Type) - dbFile.MimeType = file.MimeType.Value + dbFile.MimeType = file.MimeType.Or(defaultContentType) dbFile.Parts = datatypes.NewJSONSlice(newIds) dbFile.UserID = userId dbFile.Status = "active" @@ -226,7 +226,7 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi channelId int64 ) - if fileIn.Path.IsSet() { + if fileIn.Path.Value != "" { path = strings.TrimSpace(fileIn.Path.Value) path = strings.ReplaceAll(path, "//", "/") if path != "/" { @@ -234,7 +234,7 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi } } - if path != "" && !fileIn.ParentId.IsSet() { + if path != "" && fileIn.ParentId.Value == "" { parent, err = a.getFileFromPath(path, userId) if err != nil { return nil, &apiError{err: err, code: 404} @@ -243,7 +243,7 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi String: parent.Id, Valid: true, } - } else if fileIn.ParentId.IsSet() { + } else if fileIn.ParentId.Value != "" { fileDB.ParentID = sql.NullString{ String: fileIn.ParentId.Value, Valid: true, @@ -257,7 +257,7 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi fileDB.MimeType = "drive/folder" fileDB.Parts = nil } else if fileIn.Type == "file" { - if !fileIn.ChannelId.IsSet() { + if fileIn.ChannelId.Value == 0 { channelId, err = getDefaultChannel(a.db, a.cache, userId) if err != nil { return nil, &apiError{err: err} @@ -266,10 +266,18 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi channelId = fileIn.ChannelId.Value } fileDB.ChannelID = &channelId - fileDB.MimeType = fileIn.MimeType.Or("application/octet-stream") + fileDB.MimeType = fileIn.MimeType.Value fileDB.Category = string(category.GetCategory(fileIn.Name)) if len(fileIn.Parts) > 0 { - fileDB.Parts = datatypes.NewJSONSlice(fileIn.Parts) + parts := []api.Part{} + for _, part := range fileIn.Parts { + p := api.Part{ID: part.ID} + if part.Salt.Value != "" { + p.Salt = part.Salt + } + parts = append(parts, p) + } + fileDB.Parts = datatypes.NewJSONSlice(parts) } fileDB.Size = utils.Ptr(fileIn.Size.Or(0)) } @@ -277,7 +285,7 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi fileDB.Type = string(fileIn.Type) fileDB.UserID = userId fileDB.Status = "active" - fileDB.Encrypted = fileIn.Encrypted.Or(false) + fileDB.Encrypted = fileIn.Encrypted.Value if err := a.db.Create(&fileDB).Error; err != nil { if database.IsKeyConflictErr(err) { return nil, &apiError{err: errors.New("file already exists"), code: 409} @@ -292,7 +300,7 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre var fileShare models.FileShare - if req.Password.IsSet() { + if req.Password.Value != "" { bytes, err := bcrypt.GenerateFromPassword([]byte(req.Password.Value), bcrypt.MinCost) if err != nil { return &apiError{err: err} @@ -315,14 +323,14 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error { userId, _ := auth.GetUser(ctx) - if !req.Source.IsSet() && len(req.Ids) == 0 { + if req.Source.Value == "" && len(req.Ids) == 0 { return &apiError{err: errors.New("source or ids is required"), code: 409} } - if req.Source.IsSet() && len(req.Ids) == 0 { + if req.Source.Value != "" && len(req.Ids) == 0 { if err := a.db.Exec("call teldrive.delete_folder_recursive($1 , $2)", req.Source.Value, userId).Error; err != nil { return &apiError{err: err} } - } else if !req.Source.IsSet() && len(req.Ids) > 0 { + } else if req.Source.Value == "" && len(req.Ids) > 0 { if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil { return &apiError{err: err} } @@ -351,7 +359,7 @@ func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreat var fileShareUpdate models.FileShare - if req.Password.IsSet() { + if req.Password.Value != "" { bytes, err := bcrypt.GenerateFromPassword([]byte(req.Password.Value), bcrypt.MinCost) if err != nil { return &apiError{err: err} @@ -409,15 +417,15 @@ func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error { func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error { userId, _ := auth.GetUser(ctx) - if !req.Source.IsSet() && len(req.Ids) == 0 { + if req.Source.Value == "" && len(req.Ids) == 0 { return &apiError{err: errors.New("source or ids is required"), code: 409} } - if !req.Source.IsSet() && len(req.Ids) > 0 { + if req.Source.Value != "" && len(req.Ids) > 0 { if err := a.db.Exec("select * from teldrive.move_items($1 , $2 , $3)", req.Ids, req.Destination, userId).Error; err != nil { return &apiError{err: err} } } - if req.Source.IsSet() && len(req.Ids) == 0 { + if req.Source.Value == "" && len(req.Ids) == 0 { if err := a.db.Exec("select * from teldrive.move_directory(? , ? , ?)", req.Source.Value, req.Destination, userId).Error; err != nil { return &apiError{err: err} @@ -466,13 +474,21 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param chain *gorm.DB ) updateDb := models.File{} - if req.Name.IsSet() { + if req.Name.Value != "" { updateDb.Name = req.Name.Value } if len(req.Parts) > 0 { - updateDb.Parts = datatypes.NewJSONSlice(req.Parts) + parts := []api.Part{} + for _, part := range req.Parts { + p := api.Part{ID: part.ID} + if part.Salt.Value != "" { + p.Salt = part.Salt + } + parts = append(parts, p) + } + updateDb.Parts = datatypes.NewJSONSlice(parts) } - if req.Size.IsSet() { + if req.Size.Value != 0 { updateDb.Size = utils.Ptr(req.Size.Value) } if req.UpdatedAt.IsSet() { @@ -503,7 +519,7 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd Size: utils.Ptr(req.Size), } - if !req.ChannelId.IsSet() { + if req.ChannelId.Value == 0 { channelId, err := getDefaultChannel(a.db, a.cache, userId) if err != nil { return &apiError{err: err} @@ -513,7 +529,15 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd updatePayload.ChannelID = &req.ChannelId.Value } if len(req.Parts) > 0 { - updatePayload.Parts = datatypes.NewJSONSlice(req.Parts) + parts := []api.Part{} + for _, part := range req.Parts { + p := api.Part{ID: part.ID} + if part.Salt.Value != "" { + p.Salt = part.Salt + } + parts = append(parts, p) + } + updatePayload.Parts = datatypes.NewJSONSlice(parts) } err := a.db.Transaction(func(tx *gorm.DB) error { if err := tx.Where("id = ?", params.ID).First(&file).Error; err != nil { @@ -522,7 +546,7 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd if err := tx.Model(models.File{}).Where("id = ?", params.ID).Updates(updatePayload).Error; err != nil { return err } - if req.UploadId.IsSet() { + if req.UploadId.Value != "" { if err := tx.Where("upload_id = ?", req.UploadId.Value).Delete(&models.Upload{}).Error; err != nil { return err } diff --git a/pkg/services/file_query_builder.go b/pkg/services/file_query_builder.go index e4499e7b..d2320d2a 100644 --- a/pkg/services/file_query_builder.go +++ b/pkg/services/file_query_builder.go @@ -63,24 +63,24 @@ func (afb *fileQueryBuilder) execute(filesQuery *api.FilesListParams, userId int } func (afb *fileQueryBuilder) applyListFilters(query *gorm.DB, filesQuery *api.FilesListParams, userId int64) *gorm.DB { - if filesQuery.Path.IsSet() && !filesQuery.ParentId.IsSet() { + if filesQuery.Path.Value != "" && filesQuery.ParentId.Value == "" { query = query.Where("parent_id in (SELECT id FROM teldrive.get_file_from_path(?, ?, ?))", filesQuery.Path.Value, userId, true) } - if filesQuery.ParentId.IsSet() { + if filesQuery.ParentId.Value != "" { query = query.Where("parent_id = ?", filesQuery.ParentId.Value) } return query } func (afb *fileQueryBuilder) applyFindFilters(query *gorm.DB, filesQuery *api.FilesListParams, userId int64) *gorm.DB { - if filesQuery.DeepSearch.IsSet() && filesQuery.DeepSearch.Value && filesQuery.Query.IsSet() && filesQuery.Path.IsSet() { + if filesQuery.DeepSearch.Value && filesQuery.Query.Value != "" && filesQuery.Path.Value != "" { query = query.Where("files.id in (select id from subdirs)") } - if filesQuery.UpdatedAt.IsSet() { + if filesQuery.UpdatedAt.Value != "" { query, _ = afb.applyDateFilters(query, filesQuery.UpdatedAt.Value) } - if filesQuery.Query.IsSet() { + if filesQuery.Query.Value != "" { query = afb.applySearchQuery(query, filesQuery) } @@ -92,23 +92,24 @@ func (afb *fileQueryBuilder) applyFindFilters(query *gorm.DB, filesQuery *api.Fi } func (afb *fileQueryBuilder) applyFileSpecificFilters(query *gorm.DB, filesQuery *api.FilesListParams, userId int64) *gorm.DB { - if filesQuery.Name.IsSet() { + if filesQuery.Name.Value != "" { query = query.Where("name = ?", filesQuery.Name.Value) } - if filesQuery.ParentId.IsSet() { + if filesQuery.ParentId.Value != "" { query = query.Where("parent_id = ?", filesQuery.ParentId.Value) } - if !filesQuery.ParentId.IsSet() && filesQuery.Path.IsSet() && !filesQuery.Query.IsSet() { - query = query.Where("parent_id in (SELECT id FROM teldrive.get_file_from_path(?, ?, ?))", filesQuery.Path.Value, userId, true) + if filesQuery.ParentId.Value == "" && filesQuery.Path.Value != "" && filesQuery.Query.Value == "" { + query = query.Where("parent_id in (SELECT id FROM teldrive.get_file_from_path(?, ?, ?))", + filesQuery.Path.Value, userId, true) } - if filesQuery.Type.IsSet() { + if filesQuery.Type.Value != "" { query = query.Where("type = ?", filesQuery.Type.Value) } - if filesQuery.Shared.IsSet() && filesQuery.Shared.Value { + if filesQuery.Shared.Value { query = query.Where("id in (SELECT file_id FROM teldrive.file_shares where user_id = ?)", userId) } @@ -201,7 +202,7 @@ func (afb *fileQueryBuilder) buildFileQuery(query *gorm.DB, filesQuery *api.File } func (afb *fileQueryBuilder) buildSubqueryCTE(query *gorm.DB, filesQuery *api.FilesListParams, userId int64) *gorm.DB { - if filesQuery.DeepSearch.IsSet() && filesQuery.DeepSearch.Value && filesQuery.Query.IsSet() && filesQuery.Path.IsSet() { + if filesQuery.DeepSearch.Value && filesQuery.Query.Value != "" && filesQuery.Path.Value != "" { return afb.db.Clauses(exclause.With{Recursive: true, CTEs: []exclause.CTE{{Name: "subdirs", Subquery: exclause.Subquery{DB: afb.db.Model(&models.File{}).Select("id", "parent_id"). Where("id in (SELECT id FROM teldrive.get_file_from_path(?, ?, ?))", filesQuery.Path.Value, userId, true). diff --git a/pkg/services/upload.go b/pkg/services/upload.go index 683717ca..64e23b53 100644 --- a/pkg/services/upload.go +++ b/pkg/services/upload.go @@ -24,6 +24,7 @@ import ( "github.com/gotd/td/telegram/message" "github.com/gotd/td/telegram/uploader" "github.com/gotd/td/tg" + "github.com/tgdrive/teldrive/pkg/mapper" "github.com/tgdrive/teldrive/pkg/models" ) @@ -37,13 +38,13 @@ func (a *apiService) UploadsDelete(ctx context.Context, params api.UploadsDelete } func (a *apiService) UploadsPartsById(ctx context.Context, params api.UploadsPartsByIdParams) ([]api.UploadPart, error) { - parts := []api.UploadPart{} + parts := []models.Upload{} if err := a.db.Model(&models.Upload{}).Order("part_no").Where("upload_id = ?", params.ID). Where("created_at < ?", time.Now().UTC().Add(a.cnf.TG.Uploads.Retention)). Find(&parts).Error; err != nil { return nil, &apiError{err: err} } - return parts, nil + return mapper.ToUploadOut(parts), nil } func (a *apiService) UploadsStats(ctx context.Context, params api.UploadsStatsParams) ([]api.UploadStats, error) { @@ -74,7 +75,7 @@ func (a *apiService) UploadsStats(ctx context.Context, params api.UploadsStatsPa return stats, nil } -func (a *apiService) UploadsUpload(ctx context.Context, req api.UploadsUploadReq, params api.UploadsUploadParams) (*api.UploadPart, error) { +func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadReqWithContentType, params api.UploadsUploadParams) (*api.UploadPart, error) { var ( channelId int64 err error @@ -85,17 +86,17 @@ func (a *apiService) UploadsUpload(ctx context.Context, req api.UploadsUploadReq out api.UploadPart ) - if !params.Encrypted.IsSet() && a.cnf.TG.Uploads.EncryptionKey == "" { + if params.Encrypted.Value && a.cnf.TG.Uploads.EncryptionKey == "" { return nil, &apiError{err: errors.New("encryption is not enabled"), code: 400} } userId, session := auth.GetUser(ctx) - fileStream := req.Data + fileStream := req.Content.Data fileSize := params.ContentLength - if !params.ChannelId.IsSet() { + if params.ChannelId.Value == 0 { channelId, err = getDefaultChannel(a.db, a.cache, userId) if err != nil { return nil, err @@ -158,7 +159,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req api.UploadsUploadReq var salt string - if params.Encrypted.IsSet() { + if params.Encrypted.Value { //gen random Salt salt, _ = generateRandomSalt() cipher, err := crypt.NewCipher(a.cnf.TG.Uploads.EncryptionKey, salt) @@ -166,7 +167,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req api.UploadsUploadReq return err } fileSize = crypt.EncryptedSize(fileSize) - fileStream, err = cipher.EncryptData(req.Data) + fileStream, err = cipher.EncryptData(fileStream) if err != nil { return err } @@ -219,7 +220,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req api.UploadsUploadReq Size: fileSize, PartNo: int(params.PartNo), UserId: userId, - Encrypted: params.Encrypted.IsSet(), + Encrypted: params.Encrypted.Value, Salt: salt, } diff --git a/pkg/services/user.go b/pkg/services/user.go index ce7c78ed..cab4edb6 100644 --- a/pkg/services/user.go +++ b/pkg/services/user.go @@ -246,10 +246,10 @@ func (a *apiService) UsersUpdateChannel(ctx context.Context, req *api.ChannelUpd channel := &models.Channel{UserID: userId, Selected: true} - if req.ChannelId.IsSet() { + if req.ChannelId.Value != 0 { channel.ChannelID = req.ChannelId.Value } - if req.ChannelName.IsSet() { + if req.ChannelName.Value != "" { channel.ChannelName = req.ChannelName.Value }