Jack's blog

Refactoring Python with 🌳 Tree-sitter & Jedi

I was toying around with a refactor the other day that would have taken me ages by hand, as it touched hundreds of files.

I wanted to rename a pytest fixture from database to db everywhere it was used across my entire repo (silly, I know). Unfortunately, this isn't something my editor of choice can magically refactor.


Here's how my test files looked before:

@pytest.fixture()
def test_a(database): ...

def test_b(database): ...

def test_c(database, x): ...

def test_d(x, database): ...

def test_e(x, database, y): ...

After the refactor, this is how they look:

@pytest.fixture()
def test_a(db): ...

def test_b(db): ...

def test_c(db, x): ...

def test_d(x, db): ...

def test_e(x, db, y): ...

After struggling to achieve what I wanted with the tools I'd typically reach for (grep + sed), I decided to try something a bit fancier.

Parsing nodes with Tree-sitter

The first step is to find the row/column of every database parameter:

from pathlib import Path

import tree_sitter_python as tspython
from tree_sitter import Language, Parser

# Load the Python grammar and build a parser for it
PY_LANGUAGE = Language(tspython.language())

parser = Parser(PY_LANGUAGE)


def parse_func(node):
    # Yield the start point of every bare `database` parameter of a function_definition
    for child in node.children:
        if child.type == "parameters":
            for sub_child in child.children:
                if sub_child.type == "identifier" and sub_child.text == b"database":
                    yield sub_child.start_point


def parse_file(path):
    # Only top-level statements are inspected here
    tree = parser.parse(path.read_bytes())
    for child in tree.root_node.children:
        if child.type == "function_definition":
            yield from parse_func(child)


def process_file(path):
    # For now, just report where each match is
    for match in parse_file(path):
        print(match)

This prints the zero-based row and column of every database parameter it finds in the def test_ functions:

Point(row=7, column=11)
Point(row=10, column=11)
Point(row=13, column=14)
Point(row=16, column=14)

Handling decorated functions

The above code doesn't support decorated functions, for example:

@pytest.fixture()
def test_a(database): ...

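These take a bit more effort to handle correctly: Tree-sitter wraps a decorated function in a decorated_definition node, with the function_definition as one of its children, so parse_file has to look one level deeper: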

def parse_file(path):
    tree = parser.parse(path.read_bytes())
    for child in tree.root_node.children:
        if child.type == "function_definition":
            yield from parse_func(child)
        elif child.type == "decorated_definition":
            for sub_child in child.children:
                if sub_child.type == "function_definition":
                    yield from parse_func(sub_child)

Renaming with Jedi

Now, for each row/column, I can use Jedi to rename the identifier (along with every reference to it inside the function):

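from jedi import Script


def process_file(path):
    for match in parse_file(path):
        # Re-read the file each time: the previous rename may have changed it
        script = Script(code=path.read_text(), path=str(path))
        # Tree-sitter rows are 0-based, Jedi lines are 1-based
        result = script.rename(line=match.row + 1, column=match.column, new_name="db")
        result.apply()  # write the change back to disk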

Conclusion

Ironically, I ended up not merging this change, but it was a fun learning exercise. I found both Jedi and Tree-sitter relatively easy to pick up, and I'll certainly be keeping them in my toolbelt for situations where grep + sed don't quite cut it.

I found myself wishing Tree-sitter had a mechanism to manipulate the syntax tree directly (maybe it does?). I couldn't find a way to rename or delete nodes and then write the tree back to disk, so I was forced either to reach for Jedi or to edit the source by hand (and then deal with nasty offset and re-parsing logic).

Note: the astute amongst you will notice that this script does a lot of re-parsing, since Jedi re-reads and re-parses the file for every single rename. I could probably optimise this further, but for a quick project-wide refactor I found it plenty fast enough.

Here's a video of it in action: