Jack's blog

Fun with trees in Python 🌳

I recently discovered this neat little trick in Python: using a recursive defaultdict to easily build a tree data structure.

>>> from collections import defaultdict
>>> def tree(): return defaultdict(tree)

This allows you to very quickly build up a tree:

>>> t = tree()
>>> t[1][2][3]
>>> t[4][5]
>>> t[6][7]
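Because every lookup of a missing key creates an empty subtree, adding a whole path is just a chain of lookups. Below is a minimal sketch, assuming paths arrive as sequences of keys. The names add_path and has_path are hypothetical helpers, not part of any library. The second one walks the tree with `in` rather than indexing, since indexing a missing key would silently create it.

```python
from collections import defaultdict
from functools import reduce


def tree():
    # Recursive defaultdict: every missing key gets a fresh subtree.
    return defaultdict(tree)


def add_path(t, path):
    # Hypothetical helper: index one level per key, creating missing nodes on the way down.
    reduce(lambda node, key: node[key], path, t)


def has_path(t, path):
    # Hypothetical helper: test membership with `in` so a failed lookup creates nothing.
    node = t
    for key in path:
        if key not in node:
            return False
        node = node[key]
    return True


t = tree()
for path in [(1, 2, 3), (4, 5), (6, 7)]:
    add_path(t, path)

print(has_path(t, (4, 5)))  # True
print(has_path(t, (4, 9)))  # False
print(9 in t[4])            # False: the check above did not create key 9
```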

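If you try to print a tree like this in your terminal, though, it looks pretty horrible, because every level shows up as a nested defaultdict repr. To make it readable you can use another little trick: round-trip it through JSON, which turns the nested defaultdicts into plain dicts (and, since JSON object keys must be strings, turns the integer keys into strings), then pretty-print the result: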

>>> import json
>>> from pprint import pprint
>>> pprint(json.loads(json.dumps(t)), width=1)
{'1': {'2': {'3': {}}},
 '4': {'5': {}},
 '6': {'7': {}}}
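If you would rather keep the original keys instead of getting strings back from JSON, a small recursive conversion to plain dicts does the same job. This is a sketch; to_dict is a hypothetical name, not a library function.

```python
from collections import defaultdict
from pprint import pprint


def tree():
    return defaultdict(tree)


def to_dict(t):
    # Hypothetical helper: recursively replace each defaultdict with a plain dict, keeping keys as-is.
    return {k: to_dict(v) for k, v in t.items()}


t = tree()
t[1][2][3]
t[4][5]
t[6][7]

# Prints the same nested layout as the JSON version, but with integer keys.
pprint(to_dict(t), width=1)
```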

Little helpers

From there it's pretty straightforward to write a bunch of little utility functions for working with this structure.

Like getting all the nodes in the tree, depth-first, with each node listed before its children:

>>> def nodes(tree):
...     n = []
...     for k, v in tree.items():
...         n.append(k)
...         n.extend(nodes(v))
...     return n
...
>>> nodes(t)
[1, 2, 3, 4, 5, 6, 7]
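The same depth-first, pre-order walk can also be written as a generator with `yield from`, which avoids building an intermediate list at every level and lets the caller stop early. A sketch; iter_nodes is a hypothetical name.

```python
from collections import defaultdict
from itertools import islice


def tree():
    return defaultdict(tree)


def iter_nodes(t):
    # Yield each key before descending into its subtree (pre-order DFS).
    for k, v in t.items():
        yield k
        yield from iter_nodes(v)


t = tree()
t[1][2][3]
t[4][5]
t[6][7]

print(list(iter_nodes(t)))             # [1, 2, 3, 4, 5, 6, 7]
print(list(islice(iter_nodes(t), 3)))  # [1, 2, 3], without visiting the rest
```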

Or just the leaf nodes:

>>> def leaves(tree):
...     n = []
...     for k, v in tree.items():
...         if not v:
...             n.append(k)
...         n.extend(leaves(v))
...     return n
...
>>> leaves(t)
[3, 5, 7]
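A close relative of leaves returns the full sequence of keys from the top level down to each leaf, which is handy when you need to know how to reach a leaf and not just which one it is. A sketch; paths is a hypothetical helper.

```python
from collections import defaultdict


def tree():
    return defaultdict(tree)


def paths(t, prefix=()):
    # Hypothetical helper: collect the key tuple leading to every leaf (a node with no children).
    result = []
    for k, v in t.items():
        path = prefix + (k,)
        if not v:
            result.append(path)
        else:
            result.extend(paths(v, path))
    return result


t = tree()
t[1][2][3]
t[4][5]
t[6][7]

print(paths(t))  # [(1, 2, 3), (4, 5), (6, 7)]
```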

Or even a breadth-first traversal:

>>> from collections import deque
>>> def bfs(tree):
...     n = []
...     q = deque(tree.items())
...     while q:
...         k, v = q.popleft()
...         n.append(k)
...         if v:
...             q.extend(v.items())
...     return n
...
>>> bfs(t)
[1, 4, 6, 2, 5, 7, 3]
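A small variation on bfs groups the nodes by depth instead of returning one flat list, by handling one whole level at a time rather than popping single nodes off a queue. Flattening the result gives the same order as bfs. A sketch; levels is a hypothetical name.

```python
from collections import defaultdict


def tree():
    return defaultdict(tree)


def levels(t):
    # Hypothetical helper: each pass over `frontier` handles exactly one depth of the tree.
    result = []
    frontier = list(t.items())
    while frontier:
        result.append([k for k, _ in frontier])
        # The children of every node at this depth form the next level.
        frontier = [item for _, v in frontier for item in v.items()]
    return result


t = tree()
t[1][2][3]
t[4][5]
t[6][7]

print(levels(t))  # [[1, 4, 6], [2, 5, 7], [3]]
```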

Conclusion

All of this has been super relevant for me recently because I've been doing a lot of stuff with tree structures in this year's Advent of Code.

My previous implementation took the more classical, class-based approach:

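from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Self  # Self requires Python 3.11+

# Assumption: Point is defined elsewhere in the AoC code; a 2D grid coordinate is used here.
Point = tuple[int, int]

@dataclass
class Tree:
    node: Point
    children: list[Self] = field(default_factory=list)

    @property
    def nodes(self) -> Iterator[Point]:
        # Pre-order walk: this node first, then every descendant.
        yield self.node
        for child in self.children:
            yield from child.nodes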
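To compare the two approaches directly, the dictionary tree can be converted into class-based nodes. Because the dictionary version can have several top-level keys, this sketch returns a list of Tree roots (a forest). It assumes a loosened node type (any hashable key) in place of Point, and to_forest is a hypothetical helper.

```python
from __future__ import annotations

from collections import defaultdict
from collections.abc import Hashable, Iterator
from dataclasses import dataclass, field


def tree():
    return defaultdict(tree)


@dataclass
class Tree:
    # Same shape as the class above, but with the node type loosened to any hashable key.
    node: Hashable
    children: list[Tree] = field(default_factory=list)

    @property
    def nodes(self) -> Iterator[Hashable]:
        yield self.node
        for child in self.children:
            yield from child.nodes


def to_forest(t) -> list[Tree]:
    # Hypothetical helper: each top-level key becomes the root of its own Tree.
    return [Tree(k, to_forest(v)) for k, v in t.items()]


t = tree()
t[1][2][3]
t[4][5]
t[6][7]

forest = to_forest(t)
print([n for root in forest for n in root.nodes])  # [1, 2, 3, 4, 5, 6, 7]
```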