treeviz: improve layout of unbalanced trees
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Wed, 7 Aug 2013 16:32:02 +0000 (18:32 +0200)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Wed, 7 Aug 2013 16:32:02 +0000 (18:32 +0200)
migen/graph/treeviz.py

index 4e4c7750b88e22d39c1beb01b6da020c89c358f0..9e5a44fd7a342918061ece4fcb56fda96ecb61b7 100644 (file)
@@ -1,9 +1,11 @@
 import cairo
 import math
 
-def _cairo_draw_node(ctx, radius, color, outer_color, s):
+def _cairo_draw_node(ctx, dx, radius, color, outer_color, s):
        ctx.save()
 
+       ctx.translate(dx, 0)
+
        ctx.set_line_width(0.0)
        gradient_color = cairo.RadialGradient(0, 0, 0, 0, 0, radius)
        gradient_color.add_color_stop_rgb(0, *color)
@@ -40,36 +42,44 @@ class RenderNode:
                self.radius = radius
                self.pitch = self.radius*3
 
-       def get_extents(self):
+       def get_dimensions(self):
                if self.children:
-                       cw, ch = zip(*[c.get_extents() for c in self.children])
-                       w = max(cw)*len(self.children)
-                       h = self.pitch + max(ch)
+                       cws, chs, cdxs = zip(*[c.get_dimensions() for c in self.children])
+                       w = sum(cws)
+                       h = self.pitch + max(chs)
+                       dx = cws[0]/4 - cws[-1]/4
                else:
                        w = h = self.pitch
-               return w, h
+                       dx = 0
+               return w, h, dx
 
        def render(self, ctx):
-               _cairo_draw_node(ctx, self.radius, self.color, self.outer_color, self.label)
                if self.children:
-                       cpitch = max([c.get_extents()[0] for c in self.children])
-                       first_child_x = -(cpitch*(len(self.children) - 1))/2
+                       cws, chs, cdxs = zip(*[c.get_dimensions() for c in self.children])
+                       first_child_x = -sum(cws)/2
 
                        ctx.save()
                        ctx.translate(first_child_x, self.pitch)
-                       for c in self.children:
+                       for c, w in zip(self.children, cws):
+                               ctx.translate(w/2, 0)
                                c.render(ctx)
-                               ctx.translate(cpitch, 0)
+                               ctx.translate(w/2, 0)
                        ctx.restore()
 
+                       dx = cws[0]/4 - cws[-1]/4
+
                        current_x = first_child_x
-                       for c in self.children:
+                       for c, w, cdx in zip(self.children, cws, cdxs):
                                current_y = self.pitch - c.radius
-                               _cairo_draw_connection(ctx, 0, self.radius, self.outer_color, current_x, current_y, c.outer_color)
-                               current_x += cpitch
+                               current_x += w/2
+                               _cairo_draw_connection(ctx, dx, self.radius, self.outer_color, current_x+cdx, current_y, c.outer_color)
+                               current_x += w/2
+               else:
+                       dx = 0
+               _cairo_draw_node(ctx, dx, self.radius, self.color, self.outer_color, self.label)
 
        def to_svg(self, name):
-               w, h = self.get_extents()
+               w, h, dx = self.get_dimensions()
                surface = cairo.SVGSurface(name, w, h)
                ctx = cairo.Context(surface)
                ctx.translate(w/2, self.pitch/2)