add WIP text_tree_graph.py
[openpower-isa.git] / src / openpower / util / text_tree_graph.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 """ Draw Textual Tree from list of operations
5
6 https://bugs.libre-soc.org/show_bug.cgi?id=697
7 """
8
9
10 from dataclasses import dataclass, field
11 from enum import Enum
12 from re import T
13 from typing import Iterable
14 from cached_property import cached_property
15
16
class Op:
    """Generic N-in M-out operation.

    outs and ins are iterables of register numbers; each entry is converted
    with int() and stored as a tuple.
    """

    def __init__(self, outs, ins):
        self.outs = tuple(int(reg) for reg in outs)
        self.ins = tuple(int(reg) for reg in ins)

    @property
    def name(self):
        """Display name of the op: the name of its (sub)class."""
        return type(self).__name__

    def __str__(self):
        def fmt(regs):
            # a lone register is shown bare, anything else as a tuple
            return repr(regs[0]) if len(regs) == 1 else repr(regs)

        return f"{self.name} {fmt(self.outs)} <= {fmt(self.ins)}"

    def __repr__(self):
        return "{}({!r}, {!r})".format(self.name, self.outs, self.ins)
35
36
37 @dataclass(frozen=True, unsafe_hash=True)
38 class _SSAReg:
39 reg: int
40 counter: int
41
42
@dataclass
class _RegState:
    """Mutable tracking state for one register while scanning a program."""
    ssa_reg: _SSAReg  # SSA value the register currently holds
    written_by: "_Cell | None"  # cell that produced that value, or None

    @property
    def tree_depth(self):
        """Depth of the writing cell, or 0 when nothing has written it."""
        writer = self.written_by
        return 0 if writer is None else writer.tree_depth
53
54
55 @dataclass
56 class _Cell:
57 op: "Op | None"
58 outs: "tuple[_SSAReg, ...]"
59 ins: "tuple[_SSAReg, ...]"
60 tree_depth: int
61
62 @cached_property
63 def __op_text(self):
64 # only cache if op is set, otherwise debuggers could cache the empty
65 # value prematurely, causing the code to fail when debugged
66 assert self.op is not None
67 return str(self.op)
68
69 @property
70 def text(self):
71 if self.op is None:
72 return ""
73 return self.__op_text
74
75 @property
76 def io_coords_count(self):
77 return max(len(self.outs), len(self.ins))
78
79 @property
80 def cell_part_text_width(self):
81 """return the terminal width used by text"""
82 # Python doesn't have the right function needed to implement this,
83 # the correct function is something like:
84 # https://docs.rs/unicode-width/0.1.9/unicode_width/trait.UnicodeWidthStr.html#tymethod.width
85 # so, we just return something kinda sorta ok, sorry non-ascii people
86 text_width = len(self.text)
87 io_text_width = self.io_coords_count
88 return max(text_width, io_text_width)
89
90 @property
91 def cell_part_text_height(self):
92 return 1
93
94 @property
95 def grid_x(self):
96 if len(self.outs):
97 return self.outs[0].reg
98 if len(self.ins):
99 return self.ins[0].reg
100 return 0
101
102 @property
103 def grid_y(self):
104 return self.tree_depth
105
106 @property
107 def grid_pos(self):
108 return self.grid_x, self.grid_y
109
110 def clear_after_size_bump(self):
111 # TODO: add clearing locals with routing info
112 pass
113
114
115 class _RestartWithBiggerChannel(Exception):
116 pass
117
118
119 class _CheckIfFitsResult(Enum):
120 FITS = True
121 KEEP_LOOKING = False
122 CANCEL = "cancel"
123
124
@dataclass
class _RoutingChannelBase:
    """Lane bookkeeping shared by horizontal and vertical routing channels.

    cell_coord: index of the grid row/column this channel belongs to.
    used: (x, y) `_Coord` pairs already occupied by allocated segments.
    size: number of lanes (subcells) the channel currently has; only lanes
        ``0 .. size - 1`` exist in the grid's coordinate lists.
    """
    cell_coord: int
    used: "set[tuple[_Coord, _Coord]]" = field(default_factory=set, init=False)
    size: int = 0

    def clear_after_size_bump(self):
        """Forget every allocation; ``size`` is kept."""
        self.used.clear()

    def _allocate_segment(self, coord_range, flip_coords, check_if_fits=None):
        """allocate a segment of a horizontal line extending to every x
        in coord_range.
        returns the allocated y _Coord, or None if check_if_fits returned
        _CheckIfFitsResult.CANCEL.
        flip_coords: bool
            true if we're allocating a vertical line segment rather than
            horizontal. exchanges x and y.
        raises _RestartWithBiggerChannel (after incrementing self.size)
            when none of the existing lanes can take the segment.
        """
        coord_range = list(coord_range)
        for x in coord_range:
            assert isinstance(x, _Coord)
        # Only lanes below self.size exist. The previous version always
        # tried lane 0 first, so a still-empty channel (size == 0) handed
        # out a lane that has no grid coordinate and no text position.
        for subcell_coord in range(self.size):
            y = _Coord(cell_coord=self.cell_coord, in_routing_channel=True,
                       subcell_coord=subcell_coord)
            positions = [(y, x) if flip_coords else (x, y)
                         for x in coord_range]
            if any(pos in self.used for pos in positions):
                continue
            if check_if_fits is not None:
                result = check_if_fits(y)
                if result == _CheckIfFitsResult.CANCEL:
                    return None
                if result == _CheckIfFitsResult.KEEP_LOOKING:
                    continue
                assert result == _CheckIfFitsResult.FITS
            self.used.update(positions)
            return y
        # no existing lane fits: add one and make the caller redo the layout
        self.size += 1
        raise _RestartWithBiggerChannel
173
174
@dataclass
class _HorizontalRoutingChannel(_RoutingChannelBase):
    """Routing channel whose lanes are horizontal lines (one text row each)."""

    def alloc_h_seg(self, x_range, check_if_fits=None):
        """Reserve one lane for a horizontal segment covering every x in
        x_range.

        Returns the lane's y `_Coord` (None if ``check_if_fits`` cancelled);
        may raise `_RestartWithBiggerChannel` when the channel had to grow.
        """
        # flip_coords=False: x varies along the segment, the lane gives y
        return self._allocate_segment(x_range, False, check_if_fits)
184
185
@dataclass
class _VerticalRoutingChannel(_RoutingChannelBase):
    """Routing channel whose lanes are vertical lines (one text column each)."""

    def alloc_v_seg(self, y_range, check_if_fits=None):
        """Reserve one lane for a vertical segment covering every y in
        y_range.

        Returns the lane's x `_Coord` (None if ``check_if_fits`` cancelled);
        may raise `_RestartWithBiggerChannel` when the channel had to grow.
        """
        # flip_coords=True: y varies along the segment, the lane gives x
        return self._allocate_segment(y_range, True, check_if_fits)
195
196
197 @dataclass(frozen=True, unsafe_hash=True)
198 class _Coord:
199 r"""
200 Coordinates:
201 cell_x
202 /---------------------------^-------------------------\
203 | |
204 rx=0 rx=1 rx=2 rx=3 ox=0 ox=1 ox=2 ox=3
205 /- +--------------------------+--------------------------+
206 ry=0 | | Routing Channel -------- | Routing Channel -------- | ry=0
207 | | horizontal coord: | | horizontal coord: | |
208 ry=1 | | cell_coord=cell_x ------ | cell_coord=cell_x ------ | ry=1
209 | | in_routing_channel=True | in_routing_channel=False |
210 ry=2 | | subcell_coord=rx ------- | subcell_coord=ox ------- | ry=2
211 | | vertical coord: | | vertical coord: | |
212 ry=3 | | cell_coord=cell_y ------ | cell_coord=cell_y ------ | ry=3
213 | | in_routing_channel=True | in_routing_channel=True |
214 ry=4 | | subcell_coord=ry ------- | subcell_coord=ry ------- | ry=4
215 | | | | | | | | | | | |
216 | +--------------------------+--------------------------+
217 cell_y < | Routing Channel | | horizontal coord: | |
218 | | horizontal coord: | | cell_coord=cell_x | |
219 | | cell_coord=cell_x | | in_routing_channel=False |
220 | | in_routing_channel=True | subcell_coord=ox | |
221 | | subcell_coord=rx | | vertical coord: | |
222 | | vertical coord: | | cell_coord=cell_y | |
223 | | cell_coord=cell_y | | in_routing_channel=False |
224 | | in_routing_channel=False | subcell_coord=oy | |
225 | | subcell_coord=oy | | | | | | |
226 | | | | | | | V V V V |
227 oy=0 | | | | | | | +-In0--In1--In2--In3--+ | oy=0
228 | | | | | | | | Op | |
229 oy=1 | | | | | | | +-Out0-Out1-Out2------+ | oy=1
230 | | | | | | | | | | |
231 \- +--------------------------+--------------------------+
232 rx=0 rx=1 rx=2 rx=3 ox=0 ox=1 ox=2 ox=3
233 """
234 cell_coord: int
235 in_routing_channel: bool
236 subcell_coord: int
237
238 def __lt__(self, other):
239 if not isinstance(other, _Coord):
240 return NotImplemented
241 if self.cell_coord < other.cell_coord:
242 return True
243 if self.cell_coord > other.cell_coord:
244 return False
245 if self.in_routing_channel and not other.in_routing_channel:
246 return True
247 if not self.in_routing_channel and other.in_routing_channel:
248 return False
249 return self.subcell_coord < other.subcell_coord
250
251 def __le__(self, other):
252 if not isinstance(other, _Coord):
253 return NotImplemented
254 return not other.__lt__(self)
255
256 def __gt__(self, other):
257 if not isinstance(other, _Coord):
258 return NotImplemented
259 return other.__lt__(self)
260
261 def __ge__(self, other):
262 if not isinstance(other, _Coord):
263 return NotImplemented
264 return not self.__lt__(other)
265
266
267 @dataclass
268 class _Route:
269 """route, made of horizontal and vertical lines,
270 from some op's output to another op's input.
271 """
272 coords: "list[_Coord]" = field(default_factory=list)
273 """alternating x and y coords for the route, starting with y"""
274
275 def __len__(self):
276 """number of points in this route"""
277 return max(len(self.coords) - 1, 0)
278
279 def __getitem__(self, index):
280 assert isinstance(index, int)
281 if index < 0:
282 index += len(self)
283 assert 0 <= index < len(self)
284 c0 = self.coords[index]
285 c1 = self.coords[index + 1]
286 if index % 2 != 0:
287 return c0, c1
288 return c1, c0
289
290 def __iter__(self):
291 for i in range(len(self)):
292 yield self[i]
293
294 @property
295 def start_pos(self):
296 return self[0]
297
298 @property
299 def end_pos(self):
300 return self[-1]
301
302 def __str__(self):
303 return f"Route{{{' -> '.join(map(repr, self))}}}"
304
305
@dataclass
class _GridRow:
    """One row of the layout grid: its cells, the horizontal routing channel
    that precedes them, and the row's computed text size and position."""
    cells: "list[_Cell | None]"
    routing_channel: _HorizontalRoutingChannel
    # tallest cell text in this row, in terminal lines
    cell_part_text_height: int = 0
    # first terminal line of this row, once positions have been computed
    text_y_start: "int | None" = None

    def __init__(self, cell_y, x_size):
        assert isinstance(x_size, int)
        self.cells = [None for _ in range(x_size)]
        self.routing_channel = _HorizontalRoutingChannel(cell_coord=cell_y)
        # cell_part_text_height and text_y_start are not assigned here: they
        # read as the class-level defaults above until first written

    @property
    def text_height(self):
        """Total height in terminal lines: routing lanes plus cell text."""
        return self.routing_channel.size + self.cell_part_text_height

    def clear_after_size_bump(self):
        """Reset layout state computed for the previous channel sizes."""
        for cell in self.cells:
            if cell is None:
                continue
            cell.clear_after_size_bump()
        self.routing_channel.clear_after_size_bump()
        self.cell_part_text_height = 0
        self.text_y_start = None
329
330
@dataclass
class _GridCol:
    """One column of the layout grid: the vertical routing channel that
    precedes its cells, plus the column's computed text size and position."""
    routing_channel: _VerticalRoutingChannel
    # widest cell text in this column, in terminal columns
    cell_part_text_width: int = 0
    # most input/output connection points of any cell in this column
    io_coords_count: int = 0
    # first terminal column of this grid column, once positions are computed
    text_x_start: "int | None" = None

    @property
    def text_width(self):
        """Total width in terminal columns: routing lanes plus cell text."""
        return self.cell_part_text_width + self.routing_channel.size

    def __init__(self, cell_x):
        self.routing_channel = _VerticalRoutingChannel(cell_coord=cell_x)
        # the remaining fields read as their class-level defaults until set

    def clear_after_size_bump(self):
        """Reset layout state computed for the previous channel sizes."""
        self.routing_channel.clear_after_size_bump()
        self.cell_part_text_width = 0
        # io_coords_count is max-accumulated by the grid exactly like
        # cell_part_text_width, so it must be reset too or it can go stale
        self.io_coords_count = 0
        self.text_x_start = None
349
350
@dataclass
class _Grid:
    """Layout grid: ``rows[y].cells[x]`` holds the `_Cell` (or None) at cell
    position ``(x, y)``; every column and row owns a routing channel.

    After `calc_positions_and_sizes`, ``x_coords``/``y_coords`` list every
    x/y `_Coord` of the grid in ascending order, and
    ``x_coords_indexes``/``y_coords_indexes`` map each coordinate back to
    its index in those lists.
    """
    cols: "list[_GridCol]"
    rows: "list[_GridRow]"
    x_coords: "list[_Coord]"
    x_coords_indexes: "dict[_Coord, int]"
    y_coords: "list[_Coord]"
    y_coords_indexes: "dict[_Coord, int]"

    def __init__(self, x_size, y_size):
        """Create an empty grid of x_size columns by y_size rows."""
        self.cols = [_GridCol(x) for x in range(x_size)]
        self.rows = [_GridRow(y, x_size) for y in range(y_size)]
        self.x_coords, self.x_coords_indexes = [], {}
        self.y_coords, self.y_coords_indexes = [], {}

    def clear_after_size_bump(self):
        """Reset per-layout state of every column, then of every row."""
        for part in [*self.cols, *self.rows]:
            part.clear_after_size_bump()

    def calc_positions_and_sizes(self):
        """Compute each row's/column's text size and start offset, and
        rebuild the sorted coordinate lists and their index lookups.

        Sizes are folded in with max() onto the current values, so they are
        expected to be freshly reset (new grid or after
        `clear_after_size_bump`).
        """
        self.x_coords = []
        self.y_coords = []

        # rows, top to bottom; this pass also folds every cell's size into
        # its column, which the column pass below relies on
        next_text_y = 0
        for y_index, row in enumerate(self.rows):
            row.text_y_start = next_text_y
            for x_index, cell in enumerate(row.cells):
                if cell is None:
                    continue
                col = self.cols[x_index]
                col.cell_part_text_width = max(col.cell_part_text_width,
                                               cell.cell_part_text_width)
                row.cell_part_text_height = max(row.cell_part_text_height,
                                                cell.cell_part_text_height)
                col.io_coords_count = max(col.io_coords_count,
                                          cell.io_coords_count)
            # routing lanes first, then the op's input (0) and output (1) rows
            self.y_coords += [
                _Coord(cell_coord=y_index, in_routing_channel=True,
                       subcell_coord=lane)
                for lane in range(row.routing_channel.size)]
            self.y_coords += [
                _Coord(cell_coord=y_index, in_routing_channel=False,
                       subcell_coord=io_row)
                for io_row in (0, 1)]
            next_text_y += row.text_height

        # columns, left to right
        next_text_x = 0
        for x_index, col in enumerate(self.cols):
            col.text_x_start = next_text_x
            self.x_coords += [
                _Coord(cell_coord=x_index, in_routing_channel=True,
                       subcell_coord=lane)
                for lane in range(col.routing_channel.size)]
            self.x_coords += [
                _Coord(cell_coord=x_index, in_routing_channel=False,
                       subcell_coord=io_index)
                for io_index in range(col.io_coords_count)]
            next_text_x += col.text_width

        # generation order must agree with _Coord's ordering
        for coords in (self.x_coords, self.y_coords):
            assert coords == sorted(coords), \
                "mismatch with _Coord comparison"
        self.x_coords_indexes = {
            coord: index for index, coord in enumerate(self.x_coords)}
        self.y_coords_indexes = {
            coord: index for index, coord in enumerate(self.y_coords)}

    def text_x(self, x_coord):
        """Terminal column of x_coord (positions must be computed)."""
        assert isinstance(x_coord, _Coord)
        col = self.cols[x_coord.cell_coord]
        assert col.text_x_start is not None
        base = col.text_x_start
        if not x_coord.in_routing_channel:
            # the op part starts after the column's routing lanes
            base += col.routing_channel.size
        return base + x_coord.subcell_coord

    def text_y(self, y_coord):
        """Terminal line of y_coord (positions must be computed)."""
        assert isinstance(y_coord, _Coord)
        row = self.rows[y_coord.cell_coord]
        assert row.text_y_start is not None
        base = row.text_y_start
        if not y_coord.in_routing_channel:
            # the op part starts after the row's routing lanes
            base += row.routing_channel.size
        return base + y_coord.subcell_coord

    def __getitem__(self, pos):
        cell_x, cell_y = pos
        assert isinstance(cell_x, int)
        assert isinstance(cell_y, int)
        return self.rows[cell_y].cells[cell_x]

    def __setitem__(self, pos, value):
        assert value is None or isinstance(value, _Cell)
        cell_x, cell_y = pos
        assert isinstance(cell_x, int)
        assert isinstance(cell_y, int)
        self.rows[cell_y].cells[cell_x] = value

    @staticmethod
    def _inclusive_span(coords, indexes, first, last):
        # slice between the two coordinates' positions, whichever comes first
        low, high = sorted((indexes[first], indexes[last]))
        return coords[low:high + 1]

    def range_x_coord(self, first_x, last_x):
        """return all x `_Coord`s in first_x to last_x inclusive"""
        assert isinstance(first_x, _Coord)
        assert isinstance(last_x, _Coord)
        return self._inclusive_span(self.x_coords, self.x_coords_indexes,
                                    first_x, last_x)

    def range_y_coord(self, first_y, last_y):
        """return all y `_Coord`s in first_y to last_y inclusive"""
        assert isinstance(first_y, _Coord)
        assert isinstance(last_y, _Coord)
        return self._inclusive_span(self.y_coords, self.y_coords_indexes,
                                    first_y, last_y)

    def alloc_h_seg(self, src_x, dest_x, cell_y, check_if_fits=None):
        """Allocate a horizontal segment spanning src_x..dest_x in the
        routing channel of row cell_y; returns its y `_Coord`."""
        assert isinstance(src_x, _Coord)
        assert isinstance(dest_x, _Coord)
        assert isinstance(cell_y, int)
        channel = self.rows[cell_y].routing_channel
        span = self.range_x_coord(src_x, dest_x)
        return channel.alloc_h_seg(span, check_if_fits=check_if_fits)

    def alloc_v_seg(self, cell_x, src_y, dest_y, check_if_fits=None):
        """Allocate a vertical segment spanning src_y..dest_y in the
        routing channel of column cell_x; returns its x `_Coord`."""
        assert isinstance(cell_x, int)
        assert isinstance(src_y, _Coord)
        assert isinstance(dest_y, _Coord)
        channel = self.cols[cell_x].routing_channel
        span = self.range_y_coord(src_y, dest_y)
        return channel.alloc_v_seg(span, check_if_fits=check_if_fits)

    def allocate_route(self, dest_op_input_index, dest_cell_pos,
                       src_op_output_index, src_cell_pos):
        """Build a `_Route` from output ``src_op_output_index`` of the cell
        at src_cell_pos to input ``dest_op_input_index`` of the cell at
        dest_cell_pos. The destination must be in a lower row.

        Only adjacent rows are supported so far; other cases raise
        NotImplementedError.
        """
        assert isinstance(dest_op_input_index, int)
        dest_cell_x, dest_cell_y = dest_cell_pos
        assert isinstance(dest_cell_x, int)
        assert isinstance(dest_cell_y, int)
        assert isinstance(src_op_output_index, int)
        src_cell_x, src_cell_y = src_cell_pos
        assert isinstance(src_cell_x, int)
        assert isinstance(src_cell_y, int)
        assert dest_cell_y > src_cell_y, "bad route passed in"

        def op_coord(cell_coord, subcell_coord):
            return _Coord(cell_coord=cell_coord, in_routing_channel=False,
                          subcell_coord=subcell_coord)

        src_x = op_coord(src_cell_x, src_op_output_index)
        src_y = op_coord(src_cell_y, 1)  # source op's output row
        dest_x = op_coord(dest_cell_x, dest_op_input_index)
        dest_y = op_coord(dest_cell_y, 0)  # destination op's input row

        if dest_cell_y != src_cell_y + 1:
            # TODO: routes that pass intervening rows are not implemented.
            # Unfinished sketch from the original WIP code (it sat after the
            # raise and never ran):
            #     def check_if_fits(y):
            #         raise NotImplementedError
            #     todo_x = ...  # FIXME finish
            #     src_horiz_rc_y = self.alloc_h_seg(
            #         src_x, todo_x, dest_cell_y, check_if_fits=check_if_fits)
            raise NotImplementedError

        if src_x == dest_x:
            # straight vertical line from src to dest
            return _Route([src_y, src_x, dest_y])
        # jog sideways through a lane of the destination row's channel
        lane_y = self.alloc_h_seg(src_x, dest_x, dest_cell_y)
        assert lane_y is not None
        return _Route([src_y, src_x, lane_y, dest_x, dest_y])
537
538
@dataclass
class _Regs:
    """Per-register `_RegState` table indexed by register number; grows on
    demand, filling new slots with never-written state (SSA counter 0)."""
    __regs: "list[_RegState]" = field(default_factory=list)

    def get(self, reg):
        """Return the state of register ``reg``, creating it if needed."""
        assert isinstance(reg, int) and reg >= 0
        missing = range(len(self.__regs), reg + 1)
        self.__regs.extend(_RegState(_SSAReg(i, 0), None) for i in missing)
        return self.__regs[reg]

    def __len__(self):
        """Number of registers created so far."""
        return len(self.__regs)
551
552
def render_tree(program, indent_str=""):
    """draw a tree of operations. returns a string with the rendered tree.
    program: Iterable[Op]
    indent_str: str
        not used yet; presumably meant to prefix each output line
        (TODO: confirm once rendering is implemented).

    Rendering is unfinished: after laying the cells out in a grid this
    currently raises NotImplementedError.
    """
    # build ops_graph: maps each SSA value to the cell that produces it
    ops_graph: "dict[_SSAReg, _Cell]" = {}
    regs = _Regs()
    cells: "list[_Cell]" = []
    for op in program:
        assert isinstance(op, Op)
        ins = tuple(regs.get(reg).ssa_reg for reg in op.ins)
        # never-written registers have depth 0; default=0 keeps ops with no
        # inputs from crashing max() on an empty sequence
        tree_depth = max((regs.get(reg).tree_depth for reg in op.ins),
                         default=0) + 1
        out_states = [regs.get(reg) for reg in op.outs]
        assert len(set(op.outs)) == len(op.outs), \
            f"duplicate output registers on the same instruction: {op}"
        # Each output defines a *new* SSA value. The cell must record those
        # (not the values being overwritten) so that later readers' `ins`
        # match this cell's `outs` and the `ops_graph` keys.
        outs = tuple(_SSAReg(reg, state.ssa_reg.counter + 1)
                     for reg, state in zip(op.outs, out_states))
        cell = _Cell(op=op, outs=outs, ins=ins, tree_depth=tree_depth)
        for state, ssa_reg in zip(out_states, outs):
            state.ssa_reg = ssa_reg
            state.written_by = cell
            ops_graph[ssa_reg] = cell
        cells.append(cell)

    # generate output grid. Size it from *all* cells: ops without outputs
    # never appear in ops_graph, and an empty program must not crash max().
    y_size = max((cell.grid_y for cell in cells), default=0) + 1
    # grid_x falls back to column 0 for ops with no registers at all, so
    # always keep at least one column
    grid = _Grid(x_size=max(len(regs), 1), y_size=y_size)
    for cell in cells:
        # TODO(review): two cells can share a grid_pos (e.g. two ops writing
        # the same register at the same depth); the later one silently
        # replaces the earlier one here.
        grid[cell.grid_pos] = cell
    raise NotImplementedError
583
584
def print_tree(program, indent_str=""):
    """Render a tree of operations with render_tree and write it to stdout.

    program: Iterable[Op]
    indent_str: passed through unchanged to render_tree.
    """
    rendered = render_tree(program, indent_str=indent_str)
    print(rendered)