Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
kiyoon committed Jun 8, 2024
1 parent a5011c8 commit fe01a19
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 29 deletions.
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,14 @@ NOTE: This is work-in-progress and not yet ready for public. There isn't much cu
},
opts = {
-- Example 1:
-- Default behaviour for `tqdm` is `from tqdm.auto import tqdm`.
-- If you want to change it to `import tqdm`, you can set `import = {"tqdm"}` and `import_from = {tqdm = nil}` here.
-- If you want to change it to `from tqdm import tqdm`, you can set `import_from = {tqdm = "tqdm"}` here.
-- Default behaviour for `tqdm` is `from tqdm.auto import tqdm`.
-- If you want to change it to `import tqdm`, you can set `import = {"tqdm"}` and `import_from = {tqdm = nil}` here.
-- If you want to change it to `from tqdm import tqdm`, you can set `import_from = {tqdm = "tqdm"}` here.

-- Example 2:
-- Default behaviour for `logger` is `import logging`, ``, `logger = logging.getLogger(__name__)`.
-- If you want to change it to `import my_custom_logger`, ``, `logger = my_custom_logger.get_logger()`,
-- you can set `statement_after_imports = {logger = {"import my_custom_logger", "", "logger = my_custom_logger.get_logger()"}}` here.
extend_lookup_table = {
import = {
-- "tqdm",
Expand All @@ -116,6 +121,9 @@ NOTE: This is work-in-progress and not yet ready for public. There isn't much cu
-- tqdm = nil,
-- tqdm = "tqdm",
},
statement_after_imports = {
-- logger = { "import my_custom_logger", "", "logger = my_custom_logger.get_logger()" },
},
},
},
},
Expand Down
6 changes: 6 additions & 0 deletions lua/python_import.lua
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ function M.setup(opts)
vim.tbl_deep_extend("force", {}, lookup_table.default_import_as, config.opts.extend_lookup_table.import_as)
lookup_table.import_from =
vim.tbl_deep_extend("force", {}, lookup_table.default_import_from, config.opts.extend_lookup_table.import_from)
lookup_table.statement_after_imports = vim.tbl_deep_extend(
"force",
{},
lookup_table.default_statement_after_imports,
config.opts.extend_lookup_table.statement_after_imports
)

-- make lookup table for faster lookup
for _, v in ipairs(lookup_table.import) do
Expand Down
36 changes: 17 additions & 19 deletions lua/python_import/api.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ local utils = require "python_import.utils"
M = {}

---@param bufnr integer
---@param statement string
---@param word string
---@param ts_node TSNode?
---@return string[]?
local function get_import(bufnr, statement, ts_node)
local function get_import(bufnr, word, ts_node)
bufnr = bufnr or vim.api.nvim_get_current_buf()
if statement == nil then
if word == nil then
return nil
end

Expand Down Expand Up @@ -55,31 +55,31 @@ local function get_import(bufnr, statement, ts_node)
end
end

if statement == "logger" then
return { "import logging", "", "logger = logging.getLogger(__name__)" }
if lookup_table.statement_after_imports[word] ~= nil then
return lookup_table.statement_after_imports[word]
end

-- extend from .. import *
if utils.get_cached_first_party_modules(bufnr) ~= nil then
local first_module = utils.get_cached_first_party_modules(bufnr)[1]
-- if statement ends with _DIR, import from the first module (from project import PROJECT_DIR)
if statement:match "_DIR$" then
return { "from " .. first_module .. " import " .. statement }
elseif statement == "setup_logging" then
if word:match "_DIR$" then
return { "from " .. first_module .. " import " .. word }
elseif word == "setup_logging" then
return { "from " .. first_module .. ".utils.log import setup_logging" }
end
end

if lookup_table.is_import[statement] then
return { "import " .. statement }
if lookup_table.is_import[word] then
return { "import " .. word }
end

if lookup_table.import_as[statement] ~= nil then
return { "import " .. lookup_table.import_as[statement] .. " as " .. statement }
if lookup_table.import_as[word] ~= nil then
return { "import " .. lookup_table.import_as[word] .. " as " .. word }
end

if lookup_table.import_from[statement] ~= nil then
return { "from " .. lookup_table.import_from[statement] .. " import " .. statement }
if lookup_table.import_from[word] ~= nil then
return { "from " .. lookup_table.import_from[word] .. " import " .. word }
end

-- Can't find from pre-defined tables.
Expand All @@ -92,10 +92,8 @@ local function get_import(bufnr, statement, ts_node)
if requirements_installed then
local project_root = vim.fs.root(0, { ".git", "pyproject.toml" })
if project_root ~= nil then
local find_import_outputs = vim.api.nvim_exec(
[[w !python-import count ']] .. project_root .. [[' ']] .. statement .. [[']],
{ output = true }
)
local find_import_outputs =
vim.api.nvim_exec([[w !python-import count ']] .. project_root .. [[' ']] .. word .. [[']], { output = true })

if find_import_outputs ~= nil then
-- strip
Expand Down Expand Up @@ -133,7 +131,7 @@ local function get_import(bufnr, statement, ts_node)
end
end

return { "import " .. statement }
return { "import " .. word }
end

---@param bufnr integer
Expand Down
3 changes: 3 additions & 0 deletions lua/python_import/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ M.default_opts = {
-- tqdm = "tqdm.auto",
-- nn = "torch",
},
statement_after_imports = {
-- logger = { "import my_custom_logger", "", "logger = my_custom_logger.get_logger()" },
},
},
}

Expand Down
15 changes: 10 additions & 5 deletions lua/python_import/lookup_table.lua
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ M.default_import_from = {
WebDriver = "selenium.webdriver.remote.webdriver",
}

M.default_statement_after_imports = {
logger = { "import logging", "", "logger = logging.getLogger(__name__)" },
}

M.python_keywords = {
"False",
"None",
Expand Down Expand Up @@ -404,16 +408,17 @@ M.python_builtins = {
"_",
}

M.import = {}
M.is_import = {}
M.import = {} ---@type string[]
M.is_import = {} ---@type table<string, boolean>
-- for _, v in ipairs(M.import) do
-- M.is_import[v] = true
-- end

M.import_as = {}
M.import_from = {}
M.import_as = {} ---@type table<string, string>
M.import_from = {} ---@type table<string, string>
M.statement_after_imports = {} ---@type table<string, table<string>>

M.ban_from_import = {}
M.ban_from_import = {} ---@type table<string, boolean>
for _, v in ipairs(M.python_keywords) do
M.ban_from_import[v] = true
end
Expand Down
4 changes: 2 additions & 2 deletions lua/python_import/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ end
---Find src/module_name in git root
---@param bufnr integer?
---@return string[]?
function M.find_python_first_party_modules(bufnr)
local function find_python_first_party_modules(bufnr)
bufnr = bufnr or vim.api.nvim_get_current_buf()
-- local git_root = vim.fn.systemlist "git rev-parse --show-toplevel"
local git_root = vim.fs.root(bufnr, { ".git", "pyproject.toml" })
Expand Down Expand Up @@ -146,7 +146,7 @@ function M.get_cached_first_party_modules(bufnr)
bufnr = bufnr or vim.api.nvim_get_current_buf()

if buf_to_first_party_modules[bufnr] == nil then
buf_to_first_party_modules[bufnr] = M.find_python_first_party_modules(bufnr)
buf_to_first_party_modules[bufnr] = find_python_first_party_modules(bufnr)
end

return buf_to_first_party_modules[bufnr]
Expand Down

0 comments on commit fe01a19

Please sign in to comment.