Tags: python, parsing, lua, neovim, treesitter

Neovim command to create print statement for Python function arguments


As the title suggests, I want a command that automatically writes a print statement for all arguments of the function the cursor is in. I'm starting with Python functions, because I often do this by hand when debugging.

For example, take the Python function:

def my_func(a: int, b: str) -> None:
  ...

If my cursor is inside the function and I run :PythonPrintParams, I want the following to be inserted as the first line of the function body:

print(
  f"a={a}, "
  f"b={b}"
)
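Once the parameter names are known, the multi-line layout above can be produced with a small pure-Lua helper. This is a sketch: build_print_lines is a hypothetical name, and it assumes the caller passes the body's leading whitespace as indent. It returns a list of strings, which is the shape vim.api.nvim_buf_set_lines expects.

```lua
-- Hypothetical helper: build the multi-line print() shown above from a
-- list of parameter names. `indent` is the function body's leading
-- whitespace. Returns a list of lines for nvim_buf_set_lines.
local function build_print_lines(names, indent)
  local lines = { indent .. "print(" }
  for i, name in ipairs(names) do
    -- Every entry except the last ends with ", " inside the f-string,
    -- matching f"a={a}, " followed by f"b={b}"
    local sep = i < #names and ", " or ""
    table.insert(lines, string.format('%s  f"%s={%s}%s"', indent, name, name, sep))
  end
  table.insert(lines, indent .. ")")
  return lines
end

-- Example: parameters of def my_func(a: int, b: str), 4-space body indent
for _, line in ipairs(build_print_lines({ "a", "b" }, "    ")) do
  print(line)
end
```

In Neovim the result could be passed to vim.api.nvim_buf_set_lines(0, row, row, false, lines), where row is the 0-based row of the first statement in the body.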

I want to write a Lua function that does this using Tree-sitter and then expose it as a Neovim user command. So far I have:

local ts_utils = require('nvim-treesitter.ts_utils')

local function list_fn_params()
  local node = ts_utils.get_node_at_cursor()

  while node and node:type() ~= 'function_definition' do
    node = node:parent()
  end

  if not node then
    print("Could not find params, cursor not inside a function.")
    return
  end

  local param_nodes = node:field('parameters')
  if param_nodes then
    for i, param_node in ipairs(param_nodes) do
      print("  Node", i, "type:", param_node:type())
      print("  Node text:", vim.inspect(ts_utils.get_node_text(param_node)))
    end
  end
end

vim.api.nvim_create_user_command('PythonPrintParams', function()
  list_fn_params()
end, {})

This ends up printing:

  Node 1 type: parameters
  Node text: { "(", "    a: int, b: str", ")" }

This is a good start, but it still leaves me parsing the parameter text by hand, which I expect to get messy. How can I walk deeper into the tree to get each parameter name and then write my print statement? Alternatively, is there a route through the LSP API that makes identifying the arguments easier?

I did spend time with the py-tree-sitter docs https://github.com/tree-sitter/py-tree-sitter/tree/master.


Solution

  • Get the syntax tree and the node at your cursor:

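      -- Function name is illustrative; register it with
      -- nvim_create_user_command as in the question
      local function print_fn_params()
        -- nvim_win_get_cursor returns a 1-based row and a 0-based byte column;
        -- Tree-sitter rows are 0-based, so only the row needs shifting
        local cursor_row, cursor_col = unpack(vim.api.nvim_win_get_cursor(0))
        cursor_row = cursor_row - 1

        -- Parse the current buffer and get the root of its syntax tree
        local parser = vim.treesitter.get_parser(0)
        local tree = parser:parse()[1]
        local root = tree:root()

        -- Smallest named node covering the cursor position
        local node = root:named_descendant_for_range(cursor_row, cursor_col, cursor_row, cursor_col)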
    

    Walk up from that node to the enclosing function_definition, collect its parameter names, and insert a print as the first line of the function body:

        while node do
          if node:type() == "function_definition" then
            -- Every ancestor of the cursor node contains the cursor, so the
            -- first function_definition reached is the enclosing function
            for child in node:iter_children() do
              if child:type() == "parameters" then
                local args_list = {}

                for param in child:iter_children() do
                  -- Only bare names, as in def f(a, b), are handled here
                  if param:type() == "identifier" then
                    local arg_name = vim.treesitter.get_node_text(param, 0)
                    table.insert(args_list, string.format("%s={%s}", arg_name, arg_name))
                  end
                end

                if #args_list ~= 0 then
                  -- Insert above the first statement of the body, reusing its
                  -- indentation; this also works for multi-line signatures
                  local body_row = node:field("body")[1]:range()
                  local body_line = vim.api.nvim_buf_get_lines(0, body_row, body_row + 1, false)[1]
                  local indent = body_line:match("^%s*")
                  local args_string = table.concat(args_list, ", ")
                  local print_statement = string.format('%sprint(f"%s")', indent, args_string)

                  vim.api.nvim_buf_set_lines(0, body_row, body_row, false, { print_statement })
                end

                return
              end
            end
          end

          -- Not the function yet: move one level up the tree
          node = node:parent()
        end

        print("Cursor is not inside a function.")
      end
    

    This handles only the simple case where every parameter is a bare name:

    def add_numbers(a, b):
        return a + b
    

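    To support other kinds of parameters, run :InspectTree on a sample function and check which node types you need. For example, default parameters (b=1) are default_parameter nodes, and the annotated parameters from the question (a: int) are typed_parameter nodes; both start with the parameter's identifier: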

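    -- default_parameter (b=1) and typed_parameter (a: int) both start with
    -- the parameter's identifier, so child(0) is the name node
    if param:type() == "default_parameter" or param:type() == "typed_parameter" then
      local arg_name = vim.treesitter.get_node_text(param:child(0), 0)
      table.insert(args_list, string.format("%s={%s}", arg_name, arg_name))
    end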
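If you end up supporting several parameter types, the per-type checks can be collected in one lookup table instead of a growing if chain. The following is a hypothetical helper (param_name and name_child are illustrative names, not from any plugin); the node types and the "name is child(0)" rule are my reading of the tree-sitter-python grammar, so confirm them with :InspectTree. It only needs an object with :type() and :child(i), so it can be exercised outside Neovim with a fake node.

```lua
-- Assumed mapping from tree-sitter-python parameter node types to the
-- node that holds the parameter's name
local name_child = {
  identifier = function(p) return p end,                       -- a
  typed_parameter = function(p) return p:child(0) end,         -- a: int
  default_parameter = function(p) return p:child(0) end,       -- b=1
  typed_default_parameter = function(p) return p:child(0) end, -- b: int = 1
}

-- Return the parameter name, or nil for punctuation ("(", ",", ")"),
-- *args / **kwargs, and tuple patterns, which this sketch ignores.
-- `get_text` turns a node into its source text.
local function param_name(param, get_text)
  local pick = name_child[param:type()]
  local name_node = pick and pick(param)
  if name_node and name_node:type() == "identifier" then
    return get_text(name_node)
  end
  return nil
end

-- Minimal stand-in for a TSNode so the helper runs outside Neovim
local function fake(kind, text, children)
  children = children or {}
  return {
    type = function() return kind end,
    child = function(_, i) return children[i + 1] end,
    text = text,
  }
end

-- Children of the parameters node for def my_func(a: int, b=1)
local params = {
  fake("(", "("),
  fake("typed_parameter", "a: int", { fake("identifier", "a") }),
  fake(",", ","),
  fake("default_parameter", "b=1", { fake("identifier", "b") }),
  fake(")", ")"),
}

local names = {}
for _, p in ipairs(params) do
  local n = param_name(p, function(node) return node.text end)
  if n then table.insert(names, n) end
end
print(table.concat(names, ", ")) -- a, b
```

Inside Neovim the same helper would be called from the parameter loop as param_name(param, function(n) return vim.treesitter.get_node_text(n, 0) end), skipping nil results.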