Fun with trees in Python 🌳
I recently discovered this neat little trick in Python: using a recursive defaultdict
to easily build a tree data structure.
>>> from collections import defaultdict
>>> def tree(): return defaultdict(tree)
This allows you to very quickly build up a tree:
>>> t = tree()
>>> t[1][2][3]
>>> t[4][5]
>>> t[6][7]
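The trick works because defaultdict calls tree() whenever a missing key is read, so every lookup quietly creates a new empty subtree. Here is a minimal sketch of what that means in practice, with arbitrary example keys: subscripting inserts keys, while in and .get() only look.

```python
from collections import defaultdict


def tree():
    # Each missing key is filled in with a fresh, empty tree
    return defaultdict(tree)


t = tree()
t["a"]["b"]        # subscripting creates "a" and, under it, "b"
print("a" in t)    # True
print("z" in t)    # False: membership tests never create keys
print(t.get("z"))  # None: .get() does not trigger the default factory
print("z" in t)    # still False
t["z"]             # a plain lookup, however, does
print("z" in t)    # True
```

So when you only want to check whether a branch exists, reach for in or .get() rather than indexing.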
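If you try to print it in your terminal, though, it looks pretty horrible: a wall of nested defaultdict(...) reprs. To make it readable you can use another little trick, round-tripping it through JSON and pretty-printing the result. Note that JSON object keys are always strings, so the integer keys come back as strings:
>>> import json
>>> from pprint import pprint
>>> pprint(json.loads(json.dumps(t)), width=1)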
{'1': {'2': {'3': {}}},
 '4': {'5': {}},
 '6': {'7': {}}}
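If you'd rather keep the original integer keys, a small recursive conversion to plain dicts does the same job as the JSON round-trip. This is a sketch; to_dict is a hypothetical helper name, not part of the standard library.

```python
from collections import defaultdict
from pprint import pprint


def tree():
    return defaultdict(tree)


def to_dict(t):
    # Rebuild every level as a plain dict; keys are kept as-is (no str conversion)
    return {k: to_dict(v) for k, v in t.items()}


t = tree()
t[1][2][3]
t[4][5]
t[6][7]
pprint(to_dict(t), width=1)
# {1: {2: {3: {}}},
#  4: {5: {}},
#  6: {7: {}}}
```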
Little helpers
It's pretty straightforward to then write a bunch of little utility functions for working with this structure.
Like getting all the nodes in the tree:
>>> def nodes(tree):
...     n = []
...     for k, v in tree.items():
...         n.append(k)
...         n.extend(nodes(v))
...     return n
...
>>> nodes(t)
[1, 2, 3, 4, 5, 6, 7]
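The same pre-order walk can also be written as a generator. That avoids building an intermediate list at every level and lets you stop early. This is a sketch; iter_nodes is a hypothetical name.

```python
from collections import defaultdict
from itertools import islice


def tree():
    return defaultdict(tree)


def iter_nodes(t):
    # Depth-first, pre-order: yield a key, then everything beneath it
    for k, v in t.items():
        yield k
        yield from iter_nodes(v)


t = tree()
t[1][2][3]
t[4][5]
t[6][7]
print(list(iter_nodes(t)))             # [1, 2, 3, 4, 5, 6, 7]
print(list(islice(iter_nodes(t), 2)))  # [1, 2]
```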
Or just the leaf nodes:
>>> def leaves(tree):
...     n = []
...     for k, v in tree.items():
...         if not v:
...             n.append(k)
...         n.extend(leaves(v))
...     return n
...
>>> leaves(t)
[3, 5, 7]
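A close cousin of leaves collects the full path from the top level down to each leaf. Because indexing creates missing nodes, feeding those paths back in rebuilds the tree. This is a sketch; paths and add_path are hypothetical helper names.

```python
from collections import defaultdict


def tree():
    return defaultdict(tree)


def paths(t, prefix=()):
    # Yield a tuple of keys from the top level down to each leaf
    for k, v in t.items():
        path = prefix + (k,)
        if v:
            yield from paths(v, path)
        else:
            yield path


def add_path(t, path):
    # Walking down by index creates any missing nodes along the way
    for key in path:
        t = t[key]


t = tree()
t[1][2][3]
t[4][5]
t[6][7]
print(list(paths(t)))  # [(1, 2, 3), (4, 5), (6, 7)]

rebuilt = tree()
for p in paths(t):
    add_path(rebuilt, p)
print(rebuilt == t)  # True
```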
Or even breadth-first traversal:
>>> from collections import deque
>>> def bfs(tree):
...     n = []
...     q = deque(tree.items())
...     while q:
...         k, v = q.popleft()
...         n.append(k)
...         if v:
...             q.extend(v.items())
...     return n
...
>>> bfs(t)
[1, 4, 6, 2, 5, 7, 3]
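Swapping the queue for a stack turns the same loop into an iterative depth-first traversal. It yields the same order as nodes but without recursion, which can matter for very deep trees that would otherwise hit Python's recursion limit. This is a sketch; dfs is a hypothetical name, and calling reversed() on dict views requires Python 3.8+.

```python
from collections import defaultdict


def tree():
    return defaultdict(tree)


def dfs(t):
    n = []
    # Push siblings in reverse so the first one ends up on top of the stack
    stack = list(reversed(t.items()))
    while stack:
        k, v = stack.pop()
        n.append(k)
        stack.extend(reversed(v.items()))
    return n


t = tree()
t[1][2][3]
t[4][5]
t[6][7]
print(dfs(t))  # [1, 2, 3, 4, 5, 6, 7]
```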
Conclusion
All of this has been super relevant for me recently because I've been doing a lot of stuff with tree structures in this year's Advent of Code.
My previous implementation took the more classical, class-based approach:
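from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Self  # requires Python 3.11+

Point = tuple[int, int]  # assumption: stand-in for the Point type from the puzzle code

@dataclass
class Tree:
    node: Point
    children: list[Self] = field(default_factory=list)

    @property
    def nodes(self: Self) -> Iterator[Point]:
        # Pre-order: yield this node, then recurse into each child
        yield self.node
        for child in self.children:
            yield from child.nodes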
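The two representations map onto each other directly: each key becomes a Tree node and each nested dict becomes its list of children. Since the defaultdict version can have several top-level keys, each one becomes its own root. This is a sketch under a few assumptions: Point is taken to be a plain (x, y) tuple, a string forward reference replaces Self so it also runs before Python 3.11, and to_tree is a hypothetical helper.

```python
from collections import defaultdict
from collections.abc import Iterator
from dataclasses import dataclass, field

Point = tuple[int, int]  # assumption: a simple (x, y) pair


def tree():
    return defaultdict(tree)


@dataclass
class Tree:
    node: Point
    children: list["Tree"] = field(default_factory=list)

    @property
    def nodes(self) -> Iterator[Point]:
        yield self.node
        for child in self.children:
            yield from child.nodes


def to_tree(node: Point, subtree) -> Tree:
    # Hypothetical helper: the key becomes a Tree, its nested dict the children
    return Tree(node, [to_tree(k, v) for k, v in subtree.items()])


t = tree()
t[(0, 0)][(0, 1)][(1, 1)]
t[(0, 0)][(1, 0)]

roots = [to_tree(k, v) for k, v in t.items()]
print(list(roots[0].nodes))  # [(0, 0), (0, 1), (1, 1), (1, 0)]
```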