Refactoring Python with 🌳 Tree-sitter & Jedi
I was toying around with a refactor the other day that would have taken me ages by hand, as it involved hundreds of files.
I wanted to rename every instance of a pytest fixture from database -> db across my entire repo (silly, I know). Unfortunately this isn't something my editor of choice can magically refactor. I could have installed PyCharm, but where's the fun in that?
Here's how my test files looked before:
@pytest.fixture()
def test_a(database): ...
def test_b(database): ...
def test_c(database, x): ...
def test_d(x, database): ...
def test_e(x, database, y): ...
After the refactor, this is how they look:
@pytest.fixture()
def test_a(db): ...
def test_b(db): ...
def test_c(db, x): ...
def test_d(x, db): ...
def test_e(x, db, y): ...
After struggling to achieve what I wanted with the tools I'd typically reach for (grep + sed), I decided to try something a bit fancier.
Parsing nodes with Tree-sitter
The first thing to do is to find the row and column of each database identifier:
from pathlib import Path

import tree_sitter_python as tspython
from tree_sitter import Language, Parser

PY_LANGUAGE = Language(tspython.language())
parser = Parser(PY_LANGUAGE)


def parse_func(node):
    # Yield the (row, column) of each parameter named "database".
    for child in node.children:
        if child.type == "parameters":
            for sub_child in child.children:
                if sub_child.type == "identifier" and sub_child.text == b"database":
                    yield sub_child.start_point


def parse_file(path):
    # Only top-level function definitions are inspected here.
    tree = parser.parse(path.read_bytes())
    for child in tree.root_node.children:
        if child.type == "function_definition":
            yield from parse_func(child)


def process_file(path):
    for match in parse_file(path):
        print(match)
This prints the location of each database parameter in the file's top-level functions (rows and columns are zero-based):
Point(row=7, column=11)
Point(row=10, column=11)
Point(row=13, column=14)
Point(row=16, column=14)
Handling decorated functions
The above code doesn't support decorated functions, for example:
@pytest.fixture()
def test_a(database): ...
Decorators require a bit more effort to handle correctly:
def parse_file(path):
    tree = parser.parse(path.read_bytes())
    for child in tree.root_node.children:
        if child.type == "function_definition":
            yield from parse_func(child)
        elif child.type == "decorated_definition":
            # The function_definition is nested inside the decorated_definition node.
            for sub_child in child.children:
                if sub_child.type == "function_definition":
                    yield from parse_func(sub_child)
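The version above still only looks one or two levels below the root, so methods inside test classes or nested functions would be missed. A possible generalisation (a sketch, not part of the script above) is to walk the whole tree and yield every function_definition node at any depth. iter_function_nodes and parse_tree are hypothetical helper names; they rely only on the type and children attributes of tree-sitter nodes.

```python
def iter_function_nodes(node):
    """Yield every function_definition node at or below `node`, in source order."""
    stack = [node]
    while stack:
        current = stack.pop()
        if current.type == "function_definition":
            yield current
        # Push children in reverse so the leftmost child is popped first.
        stack.extend(reversed(current.children))


def parse_tree(root_node, parse_func):
    """Apply a per-function parser (such as parse_func above) to every function."""
    for func in iter_function_nodes(root_node):
        yield from parse_func(func)
```

With something like this, parse_file could yield from parse_tree(tree.root_node, parse_func) and the decorated_definition branch would no longer be needed, since a decorated function is just a function_definition one level further down.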
Renaming with Jedi
Now for each row/col I can use Jedi to rename the identifier:
from jedi import Script


def process_file(path):
    for match in parse_file(path):
        script = Script(code=path.read_text(), path=str(path))
        # Jedi lines are 1-based, while tree-sitter rows are 0-based.
        result = script.rename(line=match.row + 1, column=match.column, new_name="db")
        result.apply()
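To run this across a whole repo, process_file has to be called once per test file. A minimal driver might look like the sketch below. The test_*.py glob pattern and the names iter_test_files and run are assumptions for illustration, so adjust them to however your project names its tests.

```python
from pathlib import Path


def iter_test_files(root, pattern="test_*.py"):
    """Yield files under `root` (recursively) whose name matches `pattern`."""
    for path in sorted(Path(root).rglob(pattern)):
        if path.is_file():
            yield path


def run(root, process_file):
    """Call `process_file` on every matching file and return how many were visited."""
    count = 0
    for path in iter_test_files(root):
        process_file(path)
        count += 1
    return count
```

Calling run(".", process_file) from the repo root would then apply the rename everywhere. Doing this on a clean git checkout makes the result easy to review or throw away.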
Conclusion
Ironically, I ended up not merging this change, but it was a fun learning exercise. I found both jedi and tree-sitter relatively easy to learn, and I'll certainly be keeping them in my toolbelt for situations where grep + sed don't quite cut it.
I do wish tree-sitter had a mechanism to directly manipulate the AST. I was unable to simply rename/delete nodes and then write the AST back to disk. Instead, I had to use Jedi or manually edit the source (and then deal with nasty offset re-parsing logic).
Note: The astute amongst you will notice that this script does a lot of re-parsing. I could probably optimise this further, but for a quick project-wide refactor I found this to be plenty fast enough.