fhdl/verilog: make signal behave as integers in arithmetic (MyHDL style)
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Thu, 29 Nov 2012 21:59:54 +0000 (22:59 +0100)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Thu, 29 Nov 2012 21:59:54 +0000 (22:59 +0100)
See http://jandecaluwe.com/hdldesign/counting.html

migen/fhdl/verilog.py

index b334762933c5ff121704a0f41a11430a2200ba5f..49c31e9eb477a683a9e62915d49956896863aae8 100644 (file)
@@ -20,14 +20,14 @@ def _printsig(ns, s):
 def _printintbool(node):
        if isinstance(node, bool):
                if node:
-                       return "1'd1"
+                       return "1'd1", False
                else:
-                       return "1'd0"
+                       return "1'd0", False
        elif isinstance(node, int):
                if node >= 0:
-                       return str(bits_for(node)) + "'d" + str(node)
+                       return str(bits_for(node)) + "'d" + str(node), False
                else:
-                       return "-" + str(bits_for(node)) + "'sd" + str(-node)
+                       return "-" + str(bits_for(node)) + "'sd" + str(-node), True
        else:
                raise TypeError
 
@@ -35,16 +35,32 @@ def _printexpr(ns, node):
        if isinstance(node, (int, bool)):
                return _printintbool(node)
        elif isinstance(node, Signal):
-               return ns.get_name(node)
+               return ns.get_name(node), node.signed
        elif isinstance(node, _Operator):
                arity = len(node.operands)
+               r1, s1 = _printexpr(ns, node.operands[0])
                if arity == 1:
-                       r = node.op + _printexpr(ns, node.operands[0])
+                       if node.op == "-":
+                               if s1:
+                                       r = node.op + r1
+                               else:
+                                       r = "-$signed({1'd0, " + r1 + "})"
+                               s = True
+                       else:
+                               r = node.op + r1
+                               s = s1
                elif arity == 2:
-                       r = _printexpr(ns, node.operands[0]) + " " + node.op + " " + _printexpr(ns, node.operands[1])
+                       r2, s2 = _printexpr(ns, node.operands[1])
+                       if node.op in ["+", "-", "*", "&", "^", "|"]:
+                               if s2 and not s1:
+                                       r1 = "$signed({1'd0, " + r1 + "})"
+                               if s1 and not s2:
+                                       r2 = "$signed({1'd0, " + r2 + "})"
+                       r = r1 + " " + node.op + " " + r2
+                       s = s1 or s2
                else:
                        raise TypeError
-               return "(" + r + ")"
+               return "(" + r + ")", s
        elif isinstance(node, _Slice):
                # Verilog does not like us slicing non-array signals...
                if isinstance(node.value, Signal) \
@@ -56,13 +72,13 @@ def _printexpr(ns, node):
                        sr = "[" + str(node.start) + "]"
                else:
                        sr = "[" + str(node.stop-1) + ":" + str(node.start) + "]"
-               return _printexpr(ns, node.value) + sr
+               r, s = _printexpr(ns, node.value)
+               return r + sr, s
        elif isinstance(node, Cat):
-               l = list(map(partial(_printexpr, ns), node.l))
-               l.reverse()
-               return "{" + ", ".join(l) + "}"
+               l = [_printexpr(ns, v)[0] for v in reversed(node.l)]
+               return "{" + ", ".join(l) + "}", False
        elif isinstance(node, Replicate):
-               return "{" + str(node.n) + "{" + _printexpr(ns, node.v) + "}}"
+               return "{" + str(node.n) + "{" + _printexpr(ns, node.v) + "}}", False
        else:
                raise TypeError
 
@@ -80,11 +96,11 @@ def _printnode(ns, at, level, node):
                        assignment = " = "
                else:
                        assignment = " <= "
-               return "\t"*level + _printexpr(ns, node.l) + assignment + _printexpr(ns, node.r) + ";\n"
+               return "\t"*level + _printexpr(ns, node.l)[0] + assignment + _printexpr(ns, node.r)[0] + ";\n"
        elif isinstance(node, list):
                return "".join(list(map(partial(_printnode, ns, at, level), node)))
        elif isinstance(node, If):
-               r = "\t"*level + "if (" + _printexpr(ns, node.cond) + ") begin\n"
+               r = "\t"*level + "if (" + _printexpr(ns, node.cond)[0] + ") begin\n"
                r += _printnode(ns, at, level + 1, node.t)
                if node.f:
                        r += "\t"*level + "end else begin\n"
@@ -93,10 +109,10 @@ def _printnode(ns, at, level, node):
                return r
        elif isinstance(node, Case):
                if node.cases:
-                       r = "\t"*level + "case (" + _printexpr(ns, node.test) + ")\n"
+                       r = "\t"*level + "case (" + _printexpr(ns, node.test)[0] + ")\n"
                        css = sorted([(k, v) for (k, v) in node.cases.items() if k != "default"], key=itemgetter(0))
                        for choice, statements in css:
-                               r += "\t"*(level + 1) + _printexpr(ns, choice) + ": begin\n"
+                               r += "\t"*(level + 1) + _printexpr(ns, choice)[0] + ": begin\n"
                                r += _printnode(ns, at, level + 2, statements)
                                r += "\t"*(level + 1) + "end\n"
                        if "default" in node.cases:
@@ -176,7 +192,7 @@ def _printcomb(f, ns, display_run):
                                if display_run:
                                        r += "\t$display(\"Running comb block #" + str(n) + "\");\n"
                                for t in g[0]:
-                                       r += "\t" + ns.get_name(t) + " <= " + _printexpr(ns, t.reset) + ";\n"
+                                       r += "\t" + ns.get_name(t) + " <= " + _printexpr(ns, t.reset)[0] + ";\n"
                                r += _printnode(ns, _AT_NONBLOCKING, 1, g[1])
                                r += syn_off
                                r += "\t" + ns.get_name(dummy_d) + " <= " + ns.get_name(dummy_s) + ";\n"
@@ -223,7 +239,7 @@ def _printinstances(f, ns, clock_domains):
                for p in x.items:
                        if isinstance(p, Instance._IO):
                                name_inst = p.name
-                               name_design = _printexpr(ns, p.expr)
+                               name_design = _printexpr(ns, p.expr)[0]
                        elif isinstance(p, Instance.ClockPort):
                                name_inst = p.name_inst
                                name_design = ns.get_name(clock_domains[p.domain].clk)
@@ -259,7 +275,7 @@ def _printinit(f, ios, ns):
        if signals:
                r += "initial begin\n"
                for s in sorted(signals, key=lambda x: x.huid):
-                       r += "\t" + ns.get_name(s) + " <= " + _printexpr(ns, s.reset) + ";\n"
+                       r += "\t" + ns.get_name(s) + " <= " + _printexpr(ns, s.reset)[0] + ";\n"
                r += "end\n\n"
        return r