Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
kiyoon committed Jun 8, 2024
1 parent e679a98 commit a5011c8
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 153 deletions.
161 changes: 8 additions & 153 deletions lua/python_import/api.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,155 +6,10 @@ end
local lookup_table = require "python_import.lookup_table"
local ts_utils = require "python_import.ts_utils"
local health = require "python_import.health"
local utils = require "python_import.utils"

M = {}

---Return line after the first comments and docstring.
---It iterates e.g. 50 first lines and obtains treesitter nodes to check the syntax (string or comment)
---@param max_lines integer?
---@return integer?
local function find_line_after_module_docstring(max_lines)
max_lines = max_lines or 50
local bufnr = vim.fn.bufnr()
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, max_lines, false)
for i, line in ipairs(lines) do
local node = vim.treesitter.get_node { pos = { i - 1, 0 } }
-- if node == nil or node:type() == "module" then
-- local stripped = line:match "^%s*(.*)%s*$"
-- if stripped == "" then
-- return i
-- end
-- elseif node:type() == "import_statement" or node:type() == "import_from_statement" then
-- return i
if
node ~= nil
and node:type() ~= "comment"
and node:type() ~= "string"
and node:type() ~= "string_start"
and node:type() ~= "string_content"
and node:type() ~= "string_end"
then
return i
end
end
return nil
end

---Find the first import statement in a python file.
---@param max_lines integer?
local function find_line_first_import(max_lines)
max_lines = max_lines or 50
local bufnr = vim.fn.bufnr()
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, max_lines, false)
for i, line in ipairs(lines) do
local node = vim.treesitter.get_node { pos = { i - 1, 0 } }
if node ~= nil and (node:type() == "import_statement" or node:type() == "import_from_statement") then
-- additional check whether the node is top-level.
-- if not, it's probably an import inside a function
if node:parent():type() == "module" then
return i
end
end
end
return nil
end

---Find the last import statement in a python file.
---@param max_lines integer?
local function find_line_last_import(max_lines)
max_lines = max_lines or 50
local bufnr = vim.fn.bufnr()
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, max_lines, false)
-- iterate backwards
for i = #lines, 1, -1 do
local node = vim.treesitter.get_node { pos = { i - 1, 0 } }
if node ~= nil and (node:type() == "import_statement" or node:type() == "import_from_statement") then
-- additional check whether the node is top-level.
-- if not, it's probably an import inside a function
if node:parent():type() == "module" then
return i
end
end
end
return nil
end

---Find src/module_name in git root
---@param bufnr integer?
---@return string[]?
local function find_python_first_party_modules(bufnr)
bufnr = bufnr or vim.api.nvim_get_current_buf()
-- local git_root = vim.fn.systemlist "git rev-parse --show-toplevel"
local git_root = vim.fs.root(bufnr, { ".git", "pyproject.toml" })
if git_root == nil then
return nil
end

local src_dir = git_root .. "/src"
if vim.fn.isdirectory(src_dir) == 0 then
return nil
end

local modules = {}
local function find_modules(dir)
local files = vim.fn.readdir(dir)
for _, file in ipairs(files) do
local path = dir .. "/" .. file
local stat = vim.loop.fs_stat(path)
if stat.type == "directory" then
-- no egg-info
if file:match "%.egg%-info$" == nil then
modules[#modules + 1] = file
end
end
end
end
find_modules(src_dir)

if #modules == 0 then
return nil
end

return modules
end

---Get current word in a buffer
---It is aware of the insert mode (move column by -1 if the mode is insert).
---@return string
local function get_current_word()
local line = vim.fn.getline "."
local col = vim.fn.col "."
local mode = vim.fn.mode "."
if mode == "i" then
-- insert mode has cursor one char to the right
col = col - 1
end
local finish = line:find("[^a-zA-Z0-9_]", col)
-- look forward
while finish == col do
col = col + 1
finish = line:find("[^a-zA-Z0-9_]", col)
end

if finish == nil then
finish = #line + 1
end
local start = vim.fn.match(line:sub(1, col), [[\k*$]])
return line:sub(start + 1, finish - 1)
end

local buf_to_first_party_modules = {}

local function get_cached_first_party_modules(bufnr)
bufnr = bufnr or vim.api.nvim_get_current_buf()

if buf_to_first_party_modules[bufnr] == nil then
buf_to_first_party_modules[bufnr] = find_python_first_party_modules(bufnr)
end

return buf_to_first_party_modules[bufnr]
end

---@param bufnr integer
---@param statement string
---@param ts_node TSNode?
Expand Down Expand Up @@ -205,8 +60,8 @@ local function get_import(bufnr, statement, ts_node)
end

-- extend from .. import *
if get_cached_first_party_modules(bufnr) ~= nil then
local first_module = get_cached_first_party_modules(bufnr)[1]
if utils.get_cached_first_party_modules(bufnr) ~= nil then
local first_module = utils.get_cached_first_party_modules(bufnr)[1]
-- if statement ends with _DIR, import from the first module (from project import PROJECT_DIR)
if statement:match "_DIR$" then
return { "from " .. first_module .. " import " .. statement }
Expand Down Expand Up @@ -298,10 +153,10 @@ local function add_import(bufnr, module, ts_node)

local import_statements = nil
-- prefer to add after last import
local line_number = find_line_last_import()
local line_number = utils.find_line_last_import()
if line_number == nil then
-- if no import, add to first empty line
line_number = find_line_after_module_docstring()
line_number = utils.find_line_after_module_docstring()
if line_number == nil then
line_number = 1
end
Expand Down Expand Up @@ -331,7 +186,7 @@ local function add_import_current_word(winnr)
winnr = winnr or vim.api.nvim_get_current_win()
local bufnr = vim.api.nvim_win_get_buf(winnr)

local module = get_current_word()
local module = utils.get_current_word()
local node = ts_utils.get_node_at_cursor(winnr)
-- local module = vim.fn.expand "<cword>"
return add_import(bufnr, module, node)
Expand Down Expand Up @@ -397,10 +252,10 @@ end
M.add_rich_traceback = function()
local statements = { "import rich.traceback", "", "rich.traceback.install(show_locals=True)", "" }

local line_number = find_line_first_import() ---@type integer | nil
local line_number = utils.find_line_first_import() ---@type integer | nil

if line_number == nil then
line_number = find_line_after_module_docstring()
line_number = utils.find_line_after_module_docstring()
if line_number == nil then
line_number = 1
end
Expand Down
155 changes: 155 additions & 0 deletions lua/python_import/utils.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
M = {}

---Return line after the first comments and docstring.
---It iterates e.g. 50 first lines and obtains treesitter nodes to check the syntax (string or comment)
---@param bufnr integer?
---@param max_lines integer?
---@return integer?
function M.find_line_after_module_docstring(bufnr, max_lines)
bufnr = bufnr or vim.api.nvim_get_current_buf()
max_lines = max_lines or 50

local lines = vim.api.nvim_buf_get_lines(bufnr, 0, max_lines, false)
for i, line in ipairs(lines) do
local node = vim.treesitter.get_node { pos = { i - 1, 0 } }
-- if node == nil or node:type() == "module" then
-- local stripped = line:match "^%s*(.*)%s*$"
-- if stripped == "" then
-- return i
-- end
-- elseif node:type() == "import_statement" or node:type() == "import_from_statement" then
-- return i
if
node ~= nil
and node:type() ~= "comment"
and node:type() ~= "string"
and node:type() ~= "string_start"
and node:type() ~= "string_content"
and node:type() ~= "string_end"
then
return i
end
end
return nil
end

---Find the first import statement in a python file.
---@param bufnr integer?
---@param max_lines integer?
function M.find_line_first_import(bufnr, max_lines)
bufnr = bufnr or vim.api.nvim_get_current_buf()
max_lines = max_lines or 50

local lines = vim.api.nvim_buf_get_lines(bufnr, 0, max_lines, false)
for i, line in ipairs(lines) do
local node = vim.treesitter.get_node { pos = { i - 1, 0 } }
if node ~= nil and (node:type() == "import_statement" or node:type() == "import_from_statement") then
-- additional check whether the node is top-level.
-- if not, it's probably an import inside a function
if node:parent():type() == "module" then
return i
end
end
end
return nil
end

---Find the last import statement in a python file.
---@param bufnr integer?
---@param max_lines integer?
function M.find_line_last_import(bufnr, max_lines)
bufnr = bufnr or vim.api.nvim_get_current_buf()
max_lines = max_lines or 50

local lines = vim.api.nvim_buf_get_lines(bufnr, 0, max_lines, false)
-- iterate backwards
for i = #lines, 1, -1 do
local node = vim.treesitter.get_node { pos = { i - 1, 0 } }
if node ~= nil and (node:type() == "import_statement" or node:type() == "import_from_statement") then
-- additional check whether the node is top-level.
-- if not, it's probably an import inside a function
if node:parent():type() == "module" then
return i
end
end
end
return nil
end

---Find src/module_name in git root
---@param bufnr integer?
---@return string[]?
function M.find_python_first_party_modules(bufnr)
bufnr = bufnr or vim.api.nvim_get_current_buf()
-- local git_root = vim.fn.systemlist "git rev-parse --show-toplevel"
local git_root = vim.fs.root(bufnr, { ".git", "pyproject.toml" })
if git_root == nil then
return nil
end

local src_dir = git_root .. "/src"
if vim.fn.isdirectory(src_dir) == 0 then
return nil
end

local modules = {}
local function find_modules(dir)
local files = vim.fn.readdir(dir)
for _, file in ipairs(files) do
local path = dir .. "/" .. file
local stat = vim.loop.fs_stat(path)
if stat.type == "directory" then
-- no egg-info
if file:match "%.egg%-info$" == nil then
modules[#modules + 1] = file
end
end
end
end
find_modules(src_dir)

if #modules == 0 then
return nil
end

return modules
end

---Get current word in a buffer
---It is aware of the insert mode (move column by -1 if the mode is insert).
---@return string
function M.get_current_word()
local line = vim.fn.getline "."
local col = vim.fn.col "."
local mode = vim.fn.mode "."
if mode == "i" then
-- insert mode has cursor one char to the right
col = col - 1
end
local finish = line:find("[^a-zA-Z0-9_]", col)
-- look forward
while finish == col do
col = col + 1
finish = line:find("[^a-zA-Z0-9_]", col)
end

if finish == nil then
finish = #line + 1
end
local start = vim.fn.match(line:sub(1, col), [[\k*$]])
return line:sub(start + 1, finish - 1)
end

local buf_to_first_party_modules = {}

function M.get_cached_first_party_modules(bufnr)
bufnr = bufnr or vim.api.nvim_get_current_buf()

if buf_to_first_party_modules[bufnr] == nil then
buf_to_first_party_modules[bufnr] = M.find_python_first_party_modules(bufnr)
end

return buf_to_first_party_modules[bufnr]
end

return M

0 comments on commit a5011c8

Please sign in to comment.