Skip to content

Commit

Permalink
refactor: simplify condition checks and improve upload handling
Browse files Browse the repository at this point in the history
  • Loading branch information
divyam234 committed Jan 2, 2025
1 parent a296b36 commit 77e463b
Show file tree
Hide file tree
Showing 12 changed files with 105 additions and 110 deletions.
4 changes: 4 additions & 0 deletions .goreleaser.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ version: 2
project_name: teldrive
env:
- GO111MODULE=on

before:
hooks:
- go generate ./...

builds:
- env:
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ GOARCH ?= $(shell go env GOARCH)
VERSION:= $(GIT_TAG)
BINARY_EXTENSION :=

.PHONY: all build run clean frontend backend run sync-ui retag patch-version minor-version generate
.PHONY: all build run clean frontend backend run sync-ui retag patch-version minor-version gen

all: build

Expand Down
60 changes: 0 additions & 60 deletions internal/chizap/chizap.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,8 @@ package chizap

import (
"context"
"net"
"net/http"
"net/http/httputil"
"os"
"regexp"
"runtime/debug"
"strings"
"time"

"github.com/go-chi/chi/v5/middleware"
Expand Down Expand Up @@ -115,58 +110,3 @@ func ChizapWithConfig(logger ZapLogger, conf *Config) func(next http.Handler) ht
})
}
}

func defaultHandleRecovery(w http.ResponseWriter, r *http.Request, err interface{}) {
w.WriteHeader(http.StatusInternalServerError)
}

func RecoveryWithZap(logger ZapLogger, stack bool) func(next http.Handler) http.Handler {
return CustomRecoveryWithZap(logger, stack, defaultHandleRecovery)
}

func CustomRecoveryWithZap(logger ZapLogger, stack bool, recovery func(w http.ResponseWriter, r *http.Request, err interface{})) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
var brokenPipe bool
if ne, ok := err.(*net.OpError); ok {
if se, ok := ne.Err.(*os.SyscallError); ok {
if strings.Contains(strings.ToLower(se.Error()), "broken pipe") ||
strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
brokenPipe = true
}
}
}

httpRequest, _ := httputil.DumpRequest(r, false)
if brokenPipe {
logger.Error(r.URL.Path,
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
http.Error(w, "connection broken", http.StatusInternalServerError)
return
}

if stack {
logger.Error("[Recovery from panic]",
zap.Time("time", time.Now()),
zap.Any("error", err),
zap.String("request", string(httpRequest)),
zap.String("stack", string(debug.Stack())),
)
} else {
logger.Error("[Recovery from panic]",
zap.Time("time", time.Now()),
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
}
recovery(w, r, err)
}
}()
next.ServeHTTP(w, r)
})
}
}
10 changes: 9 additions & 1 deletion internal/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,21 @@ func SPAHandler(filesystem fs.FS) http.HandlerFunc {
logging.DefaultLogger().Fatal(err.Error())
}
return func(w http.ResponseWriter, r *http.Request) {
f, err := spaFS.Open(strings.TrimPrefix(path.Clean(r.URL.Path), "/"))
filePath := strings.TrimPrefix(path.Clean(r.URL.Path), "/")
f, err := spaFS.Open(filePath)
if err == nil {
defer f.Close()
}
if os.IsNotExist(err) {
r.URL.Path = "/"
filePath = "index.html"
}
if filePath == "index.html" {
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
} else {
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
}

http.FileServer(http.FS(spaFS)).ServeHTTP(w, r)
}
}
2 changes: 1 addition & 1 deletion internal/reader/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (r *LinearReader) getPartReader() (io.ReadCloser, error) {
reader io.ReadCloser
err error
)
if r.file.Encrypted.IsSet() && r.file.Encrypted.Value {
if r.file.Encrypted.Value {
salt := r.parts[r.ranges[r.pos].PartNo].Salt
cipher, _ := crypt.NewCipher(r.config.Uploads.EncryptionKey, salt)
reader, err = cipher.DecryptDataSeek(r.ctx,
Expand Down
17 changes: 17 additions & 0 deletions pkg/mapper/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,20 @@ func ToFileOut(file models.File, extended bool) *api.File {
}
return res
}

func ToUploadOut(parts []models.Upload) []api.UploadPart {
res := []api.UploadPart{}
for _, part := range parts {
res = append(res, api.UploadPart{
Name: part.Name,
PartId: part.PartId,
ChannelId: part.ChannelID,
PartNo: part.PartNo,
Size: part.Size,
Encrypted: part.Encrypted,
Salt: api.NewOptString(part.Salt),
})

}
return res
}
2 changes: 1 addition & 1 deletion pkg/services/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (a *apiService) AuthLogout(ctx context.Context) (*api.AuthLogoutNoContent,
}

func (a *apiService) AuthSession(ctx context.Context, params api.AuthSessionParams) (api.AuthSessionRes, error) {
if !params.AccessToken.IsSet() {
if params.AccessToken.Value == "" {
return &api.AuthSessionNoContent{}, nil
}
claims, err := auth.VerifyUser(a.db, a.cache, a.cnf.JWT.Secret, params.AccessToken.Value)
Expand Down
2 changes: 1 addition & 1 deletion pkg/services/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func getParts(ctx context.Context, client *telegram.Client, cache cache.Cacher,
Size: document.Size,
Salt: file.Parts[i].Salt.Value,
}
if file.Encrypted.IsSet() && file.Encrypted.Value {
if file.Encrypted.Value {
part.DecryptedSize, _ = crypt.DecryptedSize(document.Size)
}
parts = append(parts, part)
Expand Down
68 changes: 46 additions & 22 deletions pkg/services/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
dbFile.Name = req.NewName.Or(file.Name)
dbFile.Size = utils.Ptr(file.Size.Value)
dbFile.Type = string(file.Type)
dbFile.MimeType = file.MimeType.Value
dbFile.MimeType = file.MimeType.Or(defaultContentType)
dbFile.Parts = datatypes.NewJSONSlice(newIds)
dbFile.UserID = userId
dbFile.Status = "active"
Expand Down Expand Up @@ -226,15 +226,15 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
channelId int64
)

if fileIn.Path.IsSet() {
if fileIn.Path.Value != "" {
path = strings.TrimSpace(fileIn.Path.Value)
path = strings.ReplaceAll(path, "//", "/")
if path != "/" {
path = strings.TrimSuffix(path, "/")
}
}

if path != "" && !fileIn.ParentId.IsSet() {
if path != "" && fileIn.ParentId.Value == "" {
parent, err = a.getFileFromPath(path, userId)
if err != nil {
return nil, &apiError{err: err, code: 404}
Expand All @@ -243,7 +243,7 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
String: parent.Id,
Valid: true,
}
} else if fileIn.ParentId.IsSet() {
} else if fileIn.ParentId.Value != "" {
fileDB.ParentID = sql.NullString{
String: fileIn.ParentId.Value,
Valid: true,
Expand All @@ -257,7 +257,7 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
fileDB.MimeType = "drive/folder"
fileDB.Parts = nil
} else if fileIn.Type == "file" {
if !fileIn.ChannelId.IsSet() {
if fileIn.ChannelId.Value == 0 {
channelId, err = getDefaultChannel(a.db, a.cache, userId)
if err != nil {
return nil, &apiError{err: err}
Expand All @@ -266,18 +266,26 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
channelId = fileIn.ChannelId.Value
}
fileDB.ChannelID = &channelId
fileDB.MimeType = fileIn.MimeType.Or("application/octet-stream")
fileDB.MimeType = fileIn.MimeType.Value
fileDB.Category = string(category.GetCategory(fileIn.Name))
if len(fileIn.Parts) > 0 {
fileDB.Parts = datatypes.NewJSONSlice(fileIn.Parts)
parts := []api.Part{}
for _, part := range fileIn.Parts {
p := api.Part{ID: part.ID}
if part.Salt.Value != "" {
p.Salt = part.Salt
}
parts = append(parts, p)
}
fileDB.Parts = datatypes.NewJSONSlice(parts)
}
fileDB.Size = utils.Ptr(fileIn.Size.Or(0))
}
fileDB.Name = fileIn.Name
fileDB.Type = string(fileIn.Type)
fileDB.UserID = userId
fileDB.Status = "active"
fileDB.Encrypted = fileIn.Encrypted.Or(false)
fileDB.Encrypted = fileIn.Encrypted.Value
if err := a.db.Create(&fileDB).Error; err != nil {
if database.IsKeyConflictErr(err) {
return nil, &apiError{err: errors.New("file already exists"), code: 409}
Expand All @@ -292,7 +300,7 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre

var fileShare models.FileShare

if req.Password.IsSet() {
if req.Password.Value != "" {
bytes, err := bcrypt.GenerateFromPassword([]byte(req.Password.Value), bcrypt.MinCost)
if err != nil {
return &apiError{err: err}
Expand All @@ -315,14 +323,14 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre

func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error {
userId, _ := auth.GetUser(ctx)
if !req.Source.IsSet() && len(req.Ids) == 0 {
if req.Source.Value == "" && len(req.Ids) == 0 {
return &apiError{err: errors.New("source or ids is required"), code: 409}
}
if req.Source.IsSet() && len(req.Ids) == 0 {
if req.Source.Value != "" && len(req.Ids) == 0 {
if err := a.db.Exec("call teldrive.delete_folder_recursive($1 , $2)", req.Source.Value, userId).Error; err != nil {
return &apiError{err: err}
}
} else if !req.Source.IsSet() && len(req.Ids) > 0 {
} else if req.Source.Value == "" && len(req.Ids) > 0 {
if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil {
return &apiError{err: err}
}
Expand Down Expand Up @@ -351,7 +359,7 @@ func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreat

var fileShareUpdate models.FileShare

if req.Password.IsSet() {
if req.Password.Value != "" {
bytes, err := bcrypt.GenerateFromPassword([]byte(req.Password.Value), bcrypt.MinCost)
if err != nil {
return &apiError{err: err}
Expand Down Expand Up @@ -409,15 +417,15 @@ func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error {

func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error {
userId, _ := auth.GetUser(ctx)
if !req.Source.IsSet() && len(req.Ids) == 0 {
if req.Source.Value == "" && len(req.Ids) == 0 {
return &apiError{err: errors.New("source or ids is required"), code: 409}
}
if !req.Source.IsSet() && len(req.Ids) > 0 {
if req.Source.Value != "" && len(req.Ids) > 0 {
if err := a.db.Exec("select * from teldrive.move_items($1 , $2 , $3)", req.Ids, req.Destination, userId).Error; err != nil {
return &apiError{err: err}
}
}
if req.Source.IsSet() && len(req.Ids) == 0 {
if req.Source.Value == "" && len(req.Ids) == 0 {
if err := a.db.Exec("select * from teldrive.move_directory(? , ? , ?)", req.Source.Value,
req.Destination, userId).Error; err != nil {
return &apiError{err: err}
Expand Down Expand Up @@ -466,13 +474,21 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param
chain *gorm.DB
)
updateDb := models.File{}
if req.Name.IsSet() {
if req.Name.Value != "" {
updateDb.Name = req.Name.Value
}
if len(req.Parts) > 0 {
updateDb.Parts = datatypes.NewJSONSlice(req.Parts)
parts := []api.Part{}
for _, part := range req.Parts {
p := api.Part{ID: part.ID}
if part.Salt.Value != "" {
p.Salt = part.Salt
}
parts = append(parts, p)
}
updateDb.Parts = datatypes.NewJSONSlice(parts)
}
if req.Size.IsSet() {
if req.Size.Value != 0 {
updateDb.Size = utils.Ptr(req.Size.Value)
}
if req.UpdatedAt.IsSet() {
Expand Down Expand Up @@ -503,7 +519,7 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
Size: utils.Ptr(req.Size),
}

if !req.ChannelId.IsSet() {
if req.ChannelId.Value == 0 {
channelId, err := getDefaultChannel(a.db, a.cache, userId)
if err != nil {
return &apiError{err: err}
Expand All @@ -513,7 +529,15 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
updatePayload.ChannelID = &req.ChannelId.Value
}
if len(req.Parts) > 0 {
updatePayload.Parts = datatypes.NewJSONSlice(req.Parts)
parts := []api.Part{}
for _, part := range req.Parts {
p := api.Part{ID: part.ID}
if part.Salt.Value != "" {
p.Salt = part.Salt
}
parts = append(parts, p)
}
updatePayload.Parts = datatypes.NewJSONSlice(parts)
}
err := a.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Where("id = ?", params.ID).First(&file).Error; err != nil {
Expand All @@ -522,7 +546,7 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
if err := tx.Model(models.File{}).Where("id = ?", params.ID).Updates(updatePayload).Error; err != nil {
return err
}
if req.UploadId.IsSet() {
if req.UploadId.Value != "" {
if err := tx.Where("upload_id = ?", req.UploadId.Value).Delete(&models.Upload{}).Error; err != nil {
return err
}
Expand Down
Loading

0 comments on commit 77e463b

Please sign in to comment.