diff --git a/packages.go b/packages.go index 20324140c..33696c5cc 100644 --- a/packages.go +++ b/packages.go @@ -3,15 +3,16 @@ package swag import ( "fmt" "go/ast" + "go/build" goparser "go/parser" "go/token" + "golang.org/x/tools/go/loader" "os" "path/filepath" "runtime" "sort" "strings" - - "golang.org/x/tools/go/loader" + "sync" ) // PackagesDefinitions map[package import path]*PackageDefinitions. @@ -436,6 +437,52 @@ func (pkgDefs *PackagesDefinitions) findTypeSpec(pkgPath string, typeName string return nil } +// PackageCache 用于缓存已查找的包 +var packageCache = make(map[string]*build.Package) +var cacheMutex sync.RWMutex +var packageLocks sync.Map + +// FindPackage 实现 +func FindPackage(ctxt *build.Context, importPath, fromDir string, mode build.ImportMode) (*build.Package, error) { + // 检查缓存 + cacheMutex.RLock() + pkg, found := packageCache[importPath] + cacheMutex.RUnlock() + + if found { + return pkg, nil + } + + // 获取或创建锁 + lockInterface, _ := packageLocks.LoadOrStore(importPath, &sync.Mutex{}) + lock := lockInterface.(*sync.Mutex) + + lock.Lock() // 独占锁 + defer lock.Unlock() // 确保在函数结束时解锁 + + // 再次检查缓存 + cacheMutex.RLock() + pkg, found = packageCache[importPath] + cacheMutex.RUnlock() + + if found { + return pkg, nil + } + + // 查找包 + pkg, err := ctxt.Import(importPath, fromDir, mode) + if err != nil { + return nil, err + } + + // 更新缓存 + cacheMutex.Lock() + packageCache[importPath] = pkg + cacheMutex.Unlock() + + return pkg, nil +} + func (pkgDefs *PackagesDefinitions) loadExternalPackage(importPath string) error { cwd, err := os.Getwd() if err != nil { @@ -443,8 +490,9 @@ func (pkgDefs *PackagesDefinitions) loadExternalPackage(importPath string) error } conf := loader.Config{ - ParserMode: goparser.ParseComments, - Cwd: cwd, + ParserMode: goparser.ParseComments, + Cwd: cwd, + FindPackage: FindPackage, } conf.Import(importPath) @@ -526,8 +574,7 @@ func (pkgDefs *PackagesDefinitions) findPackagePathFromImports(pkg string, file } } } - - if len(pkg) == 0 || file.Name.Name == pkg { + if (len(pkg) == 0 || file.Name.Name == pkg) && pkgDefs.files[file] != nil { matchedPkgPaths = append(matchedPkgPaths, pkgDefs.files[file].PackagePath) }