Remove Constant
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Wed, 28 Nov 2012 22:18:43 +0000 (23:18 +0100)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Wed, 28 Nov 2012 22:18:43 +0000 (23:18 +0100)
16 files changed:
migen/actorlib/structuring.py
migen/bank/csrgen.py
migen/bus/asmibus.py
migen/bus/dfi.py
migen/bus/wishbone.py
migen/corelogic/divider.py
migen/corelogic/fsm.py
migen/corelogic/misc.py
migen/corelogic/record.py
migen/corelogic/roundrobin.py
migen/fhdl/structure.py
migen/fhdl/tools.py
migen/fhdl/verilog.py
migen/fhdl/visit.py
migen/pytholite/expr.py
migen/pytholite/reg.py

index 9b798ee5e6a5fd035de48e342eed838e82464fe4..3402e2f08af9276772f66bbfb5431bb2b13e7049 100644 (file)
@@ -50,7 +50,7 @@ class Unpack(Actor):
                                )
                        )
                ]
-               cases = [(Constant(i, BV(muxbits)) if i else Default(),
+               cases = [(i if i else Default(),
                        Cat(*self.token("source").flatten()).eq(Cat(*self.token("sink").subrecord("chunk{0}".format(i)).flatten())))
                        for i in range(self.n)]
                comb.append(Case(mux, *cases))
@@ -69,7 +69,7 @@ class Pack(Actor):
                
                load_part = Signal()
                strobe_all = Signal()
-               cases = [(Constant(i, BV(demuxbits)),
+               cases = [(i,
                        Cat(*self.token("source").subrecord("chunk{0}".format(i)).flatten()).eq(*self.token("sink").flatten()))
                        for i in range(self.n)]
                comb = [
index b49ea56a84949fa33f445a9addd1893e95e5e64b..ae942a486fbe0c7b1826f8a54f49fb15ef07d568 100644 (file)
@@ -15,7 +15,7 @@ class Bank:
                sync = []
                
                sel = Signal()
-               comb.append(sel.eq(self.interface.adr[9:] == Constant(self.address, BV(5))))
+               comb.append(sel.eq(self.interface.adr[9:] == self.address))
                
                desc_exp = expand_description(self.description, csr.data_width)
                nbits = bits_for(len(desc_exp)-1)
@@ -27,9 +27,9 @@ class Bank:
                                comb.append(reg.r.eq(self.interface.dat_w[:reg.size]))
                                comb.append(reg.re.eq(sel & \
                                        self.interface.we & \
-                                       (self.interface.adr[:nbits] == Constant(i, BV(nbits)))))
+                                       (self.interface.adr[:nbits] == i)))
                        elif isinstance(reg, RegisterFields):
-                               bwra = [Constant(i, BV(nbits))]
+                               bwra = [i]
                                offset = 0
                                for field in reg.fields:
                                        if field.access_bus == WRITE_ONLY or field.access_bus == READ_WRITE:
@@ -51,7 +51,7 @@ class Bank:
                brcases = []
                for i, reg in enumerate(desc_exp):
                        if isinstance(reg, RegisterRaw):
-                               brcases.append([Constant(i, BV(nbits)), self.interface.dat_r.eq(reg.w)])
+                               brcases.append([i, self.interface.dat_r.eq(reg.w)])
                        elif isinstance(reg, RegisterFields):
                                brs = []
                                reg_readable = False
@@ -60,9 +60,9 @@ class Bank:
                                                brs.append(field.storage)
                                                reg_readable = True
                                        else:
-                                               brs.append(Constant(0, BV(field.size)))
+                                               brs.append(Replicate(0, field.size))
                                if reg_readable:
-                                       brcases.append([Constant(i, BV(nbits)), self.interface.dat_r.eq(Cat(*brs))])
+                                       brcases.append([i, self.interface.dat_r.eq(Cat(*brs))])
                        else:
                                raise TypeError
                if brcases:
index ff2080c0bc952fbb36d1cef5e010850b94fba0e6..1b47a86bd1ff3225efefbda9a3674255c9f5002f 100644 (file)
@@ -83,7 +83,7 @@ class Port:
                if not self.finalized:
                        raise FinalizeError
                return self.call \
-                       & (self.tag_call == Constant(self.base + slotn, BV(self.tagbits)))
+                       & (self.tag_call == (self.base + slotn))
                
        def get_fragment(self):
                if not self.finalized:
index f965ecdc94b9824bd6e20b3189b0482ad7a09dae..d1c2cf7202f66834895a80d46f1888f4722766dc 100644 (file)
@@ -25,10 +25,10 @@ class Interface:
                self.pdesc = phase_description(a, ba, d)
                self.phases = [SimpleInterface(self.pdesc) for i in range(nphases)]
                for p in self.phases:
-                       p.cas_n.reset = Constant(1)
-                       p.cs_n.reset = Constant(1)
-                       p.ras_n.reset = Constant(1)
-                       p.we_n.reset = Constant(1)
+                       p.cas_n.reset = 1
+                       p.cs_n.reset = 1
+                       p.ras_n.reset = 1
+                       p.we_n.reset = 1
        
        # Returns pairs (DFI-mandated signal name, Migen signal object)
        def get_standard_names(self, m2s=True, s2m=True):
index e55da5de142dd7ad60e66fd656890c38c0733bf8..1fc699ac246d1028edc65301702c6d87cf2b6aa5 100644 (file)
@@ -47,7 +47,7 @@ class Arbiter:
                        for i, m in enumerate(self.masters):
                                dest = getattr(m, name)
                                if name == "ack" or name == "err":
-                                       comb.append(dest.eq(source & (self.rr.grant == Constant(i, self.rr.grant.bv))))
+                                       comb.append(dest.eq(source & (self.rr.grant == i)))
                                else:
                                        comb.append(dest.eq(source))
                
@@ -59,27 +59,15 @@ class Arbiter:
 
 class Decoder:
        # slaves is a list of pairs:
-       # 0) structure.Constant defining address (always decoded on the upper bits)
-       #    Slaves can have differing numbers of address bits, but addresses 
-       #    must not conflict.
-       # 1) wishbone.Slave reference
-       # Addresses are decoded from bit 31-offset and downwards.
+       # 0) function that takes the address signal and returns a FHDL expression
+       #    that evaluates to 1 when the slave is selected and 0 otherwise.
+       # 1) wishbone.Slave reference.
        # register adds flip-flops after the address comparators. Improves timing,
        # but breaks Wishbone combinatorial feedback.
-       def __init__(self, master, slaves, offset=0, register=False):
+       def __init__(self, master, slaves, register=False):
                self.master = master
                self.slaves = slaves
-               self.offset = offset
                self.register = register
-               
-               addresses = [slave[0] for slave in self.slaves]
-               maxbits = max([bits_for(addr) for addr in addresses])
-               def mkconst(x):
-                       if isinstance(x, int):
-                               return Constant(x, BV(maxbits))
-                       else:
-                               return x
-               self.addresses = list(map(mkconst, addresses))
 
        def get_fragment(self):
                comb = []
@@ -90,9 +78,8 @@ class Decoder:
                slave_sel_r = Signal(BV(ns))
                
                # decode slave addresses
-               hi = len(self.master.adr) - self.offset
-               comb += [slave_sel[i].eq(self.master.adr[hi-len(addr):hi] == addr)
-                       for i, addr in enumerate(self.addresses)]
+               comb += [slave_sel[i].eq(fun(self.master.adr))
+                       for i, (fun, bus) in enumerate(self.slaves)]
                if self.register:
                        sync.append(slave_sel_r.eq(slave_sel))
                else:
@@ -120,11 +107,10 @@ class Decoder:
                return Fragment(comb, sync)
 
 class InterconnectShared:
-       def __init__(self, masters, slaves, offset=0, register=False):
+       def __init__(self, masters, slaves, register=False):
                self._shared = Interface()
                self._arbiter = Arbiter(masters, self._shared)
-               self._decoder = Decoder(self._shared, slaves, offset, register)
-               self.addresses = self._decoder.addresses
+               self._decoder = Decoder(self._shared, slaves, register)
        
        def get_fragment(self):
                return self._arbiter.get_fragment() + self._decoder.get_fragment()
index 50b4af18f7d30cc17f326a45a7886d0e47c2c2fd..edac22e835aeeb8e5110759085479e372b5b9231 100644 (file)
@@ -22,7 +22,7 @@ class Divider:
                comb = [
                        self.quotient_o.eq(qr[:w]),
                        self.remainder_o.eq(qr[w:]),
-                       self.ready_o.eq(counter == Constant(0, counter.bv)),
+                       self.ready_o.eq(counter == 0),
                        diff.eq(self.remainder_o - divisor_r)
                ]
                sync = [
@@ -36,7 +36,7 @@ class Divider:
                                        ).Else(
                                                qr.eq(Cat(1, qr[:w-1], diff[:w]))
                                        ),
-                                       counter.eq(counter - Constant(1, counter.bv))
+                                       counter.eq(counter - 1)
                        )
                ]
                return Fragment(comb, sync)
index 715dcb5f1297f239adeaebc48c48ce87f513f3d6..dbd4c02a0c33335a03d2b252b9d210c8e3396650 100644 (file)
@@ -8,16 +8,16 @@ class FSM:
                self._state = Signal(self._state_bv)
                self._next_state = Signal(self._state_bv)
                for n, state in enumerate(states):
-                       setattr(self, state, Constant(n, self._state_bv))
+                       setattr(self, state, n)
                self.actions = [[] for i in range(len(states))]
                
                for name, target, delay in delayed_enters:
                        target_state = getattr(self, target)
                        if delay:
                                name_state = len(self.actions)
-                               setattr(self, name, Constant(name_state, self._state_bv))
+                               setattr(self, name, name_state)
                                for i in range(delay-1):
-                                       self.actions.append([self.next_state(Constant(name_state+i+1, self._state_bv))])
+                                       self.actions.append([self.next_state(name_state+i+1)])
                                self.actions.append([self.next_state(target_state)])
                        else:
                                # alias
@@ -30,11 +30,10 @@ class FSM:
                return self._next_state.eq(state)
        
        def act(self, state, *statements):
-               self.actions[state.n] += statements
+               self.actions[state] += statements
        
        def get_fragment(self):
-               cases = [[Constant(s, self._state_bv)] + a
-                       for s, a in enumerate(self.actions) if a]
+               cases = [[s] + a for s, a in enumerate(self.actions) if a]
                comb = [
                        self._next_state.eq(self._state),
                        Case(self._state, *cases)
index fd71983f8aa422bf2097d1d66015ac9ce85d3062..a4da7dcf2c55078f7af924e25676b746d5536f0a 100644 (file)
@@ -49,7 +49,7 @@ def chooser(signal, shift, output, n=None, reverse=False):
                        s = n - i - 1
                else:
                        s = i
-               cases.append([Constant(i, shift.bv), output.eq(signal[s*w:(s+1)*w])])
+               cases.append([i, output.eq(signal[s*w:(s+1)*w])])
        cases[n-1][0] = Default()
        return Case(shift, *cases)
 
@@ -57,25 +57,25 @@ def timeline(trigger, events):
        lastevent = max([e[0] for e in events])
        counter = Signal(BV(bits_for(lastevent)))
        
-       counterlogic = If(counter != Constant(0, counter.bv),
-               counter.eq(counter + Constant(1, counter.bv))
+       counterlogic = If(counter != 0,
+               counter.eq(counter + 1)
        ).Elif(trigger,
-               counter.eq(Constant(1, counter.bv))
+               counter.eq(1)
        )
        # insert counter reset if it doesn't naturally overflow
        # (test if lastevent+1 is a power of 2)
        if (lastevent & (lastevent + 1)) != 0:
                counterlogic = If(counter == lastevent,
-                       counter.eq(Constant(0, counter.bv))
+                       counter.eq(0)
                ).Else(
                        counterlogic
                )
        
        def get_cond(e):
                if e[0] == 0:
-                       return trigger & (counter == Constant(0, counter.bv))
+                       return trigger & (counter == 0)
                else:
-                       return counter == Constant(e[0], counter.bv)
+                       return counter == e[0]
        sync = [If(get_cond(e), *e[1]) for e in events]
        sync.append(counterlogic)
        return sync
index 7dc6b0d65eb9eeee780cc46eac1ff49058b1f148..d8fd82bdd709a926e20caf60e89eb0da4d0bcff5 100644 (file)
@@ -1,4 +1,5 @@
 from migen.fhdl.structure import *
+from migen.fhdl.tools import value_bv
 
 class Record:
        def __init__(self, layout, name=""):
@@ -76,7 +77,7 @@ class Record:
                        if align:
                                pad_size = alignment - (offset % alignment)
                                if pad_size < alignment:
-                                       l.append(Constant(0, BV(pad_size)))
+                                       l.append(Replicate(0, pad_size))
                                        offset += pad_size
                        
                        e = self.__dict__[key]
@@ -87,7 +88,7 @@ class Record:
                        else:
                                raise TypeError
                        for x in added:
-                               offset += len(x)
+                               offset += value_bv(x).width
                        l += added
                if return_offset:
                        return (l, offset)
index aa310204e2641c82377f7b6f0d62ed940066794d..a0d7c73482372df70ae35297a2fad558dd9ac953 100644 (file)
@@ -21,7 +21,7 @@ class RoundRobin:
                                        t = j % self.n
                                        switch = [
                                                If(self.request[t],
-                                                       self.grant.eq(Constant(t, BV(self.bn)))
+                                                       self.grant.eq(t)
                                                ).Else(
                                                        *switch
                                                )
@@ -30,7 +30,7 @@ class RoundRobin:
                                        case = [If(~self.request[i], *switch)]
                                else:
                                        case = switch
-                               cases.append([Constant(i, BV(self.bn))] + case)
+                               cases.append([i] + case)
                        statement = Case(self.grant, *cases)
                        if self.switch_policy == SP_CE:
                                statement = If(self.ce, statement)
index 443ddea8eeacade14c8a67191a99eb320d3d1284..16faebc809a8b0231fd967121c4f25d8e390fefa 100644 (file)
@@ -15,8 +15,6 @@ def log2_int(n, need_pow2=True):
        return r
 
 def bits_for(n, require_sign_bit=False):
-       if isinstance(n, Constant):
-               return len(n)
        if n > 0:
                r = log2_int(n + 1, False)
        else:
@@ -126,7 +124,7 @@ class _Operator(Value):
        def __init__(self, op, operands):
                super().__init__()
                self.op = op
-               self.operands = list(map(_cst, operands))
+               self.operands = operands
 
 class _Slice(Value):
        def __init__(self, value, start, stop):
@@ -138,49 +136,21 @@ class _Slice(Value):
 class Cat(Value):
        def __init__(self, *args):
                super().__init__()
-               self.l = list(map(_cst, args))
+               self.l = args
 
 class Replicate(Value):
        def __init__(self, v, n):
                super().__init__()
-               self.v = _cst(v)
+               self.v = v
                self.n = n
 
-class Constant(Value):
-       def __init__(self, n, bv=None):
-               super().__init__()
-               self.bv = bv or BV(bits_for(n), n < 0)
-               self.n = n
-       
-       def __len__(self):
-               return self.bv.width
-       
-       def __repr__(self):
-               return str(self.bv) + str(self.n)
-       
-       def __eq__(self, other):
-               return self.bv == other.bv and self.n == other.n
-       
-       def __hash__(self):
-               return super().__hash__()
-
-
-def binc(x, signed=False):
-       return Constant(int(x, 2), BV(len(x), signed))
-
-def _cst(x):
-       if isinstance(x, int):
-               return Constant(x)
-       else:
-               return x
-
 class Signal(Value):
        def __init__(self, bv=BV(), name=None, variable=False, reset=0, name_override=None):
                super().__init__()
                assert(isinstance(bv, BV))
                self.bv = bv
                self.variable = variable
-               self.reset = Constant(reset, bv)
+               self.reset = reset
                self.name_override = name_override
                self.backtrace = tracer.trace_back(name)
 
@@ -195,7 +165,7 @@ class Signal(Value):
 class _Assign:
        def __init__(self, l, r):
                self.l = l
-               self.r = _cst(r)
+               self.r = r
 
 class If:
        def __init__(self, cond, *t):
@@ -274,8 +244,6 @@ class Instance(HUID):
                        self.name = name
                        if isinstance(expr, BV):
                                self.expr = Signal(expr, name)
-                       elif isinstance(expr, int):
-                               self.expr = Constant(expr)
                        else:
                                self.expr = expr
        class Input(_IO):
index 0824e46690ef7e0fa1ecbf19f37af9d36f7b6737..b26aa441512e8c0b98c3fd721cb7439b53440eb9 100644 (file)
@@ -105,8 +105,10 @@ def insert_reset(rst, sl):
        return If(rst, *resetcode).Else(*sl)
 
 def value_bv(v):
-       if isinstance(v, Constant):
-               return v.bv
+       if isinstance(v, bool):
+               return BV(1, False)
+       elif isinstance(v, int):
+               return BV(bits_for(v), v < 0)
        elif isinstance(v, Signal):
                return v.bv
        elif isinstance(v, _Operator):
@@ -152,7 +154,7 @@ class _ArrayLowerer(NodeTransformer):
                        cases = []
                        for n, choice in enumerate(node.l.choices):
                                assign = self.visit_Assign(_Assign(choice, node.r))
-                               cases.append([Constant(n), assign])
+                               cases.append([n, assign])
                        cases[-1][0] = Default()
                        return Case(k, *cases)
                else:
@@ -160,7 +162,7 @@ class _ArrayLowerer(NodeTransformer):
        
        def visit_ArrayProxy(self, node):
                array_muxed = Signal(value_bv(node))
-               cases = [[Constant(n), _Assign(array_muxed, self.visit(choice))]
+               cases = [[n, _Assign(array_muxed, self.visit(choice))]
                        for n, choice in enumerate(node.choices)]
                cases[-1][0] = Default()
                self.comb.append(Case(self.visit(node.key), *cases))
index 3d497c2f5cbb0e119b21d1157d0d29a1a71e1dde..ec878c0cad9ec423d0478c79249bffd2395b3e5f 100644 (file)
@@ -17,12 +17,23 @@ def _printsig(ns, s):
        n += ns.get_name(s)
        return n
 
-def _printexpr(ns, node):
-       if isinstance(node, Constant):
-               if node.n >= 0:
-                       return str(node.bv) + str(node.n)
+def _printintbool(node):
+       if isinstance(node, bool):
+               if node:
+                       return "1'd1"
+               else:
+                       return "1'd0"
+       elif isinstance(node, int):
+               if node >= 0:
+                       return str(bits_for(node)) + "'d" + str(node)
                else:
-                       return "-" + str(node.bv) + str(-node.n)
+                       return "-" + str(bits_for(node)) + "'sd" + str(-node)
+       else:
+               raise TypeError
+
+def _printexpr(ns, node):
+       if isinstance(node, (int, bool)):
+               return _printintbool(node)
        elif isinstance(node, Signal):
                return ns.get_name(node)
        elif isinstance(node, _Operator):
@@ -146,7 +157,7 @@ def _printcomb(f, ns, display_run):
                dummy_s = Signal(name_override="dummy_s")
                r += syn_off
                r += "reg " + _printsig(ns, dummy_s) + ";\n"
-               r += "initial " + ns.get_name(dummy_s) + " <= 1'b0;\n"
+               r += "initial " + ns.get_name(dummy_s) + " <= 1'd0;\n"
                r += syn_on
                
                groups = group_by_targets(f.comb)
@@ -164,7 +175,7 @@ def _printcomb(f, ns, display_run):
                                if display_run:
                                        r += "\t$display(\"Running comb block #" + str(n) + "\");\n"
                                for t in g[0]:
-                                       r += "\t" + ns.get_name(t) + " <= " + str(t.reset) + ";\n"
+                                       r += "\t" + ns.get_name(t) + " <= " + _printexpr(ns, t.reset) + ";\n"
                                r += _printnode(ns, _AT_NONBLOCKING, 1, g[1])
                                r += syn_off
                                r += "\t" + ns.get_name(dummy_d) + " <= " + ns.get_name(dummy_s) + ";\n"
@@ -194,7 +205,9 @@ def _printinstances(f, ns, clock_domains):
                                        r += ",\n"
                                firstp = False
                                r += "\t." + p.name + "("
-                               if isinstance(p.value, int) or isinstance(p.value, float) or isinstance(p.value, Constant):
+                               if isinstance(p.value, (int, bool)):
+                                       r += _printintbool(p.value)
+                               elif isinstance(p.value, float):
                                        r += str(p.value)
                                elif isinstance(p.value, str):
                                        r += "\"" + p.value + "\""
index 595f40052ecedfcb148d1d5ef6a2d1cee3dfd179..44abc77de29e72fb7138b7051d0f6c5721fc3ddb 100644 (file)
@@ -5,8 +5,8 @@ from migen.fhdl.structure import _Operator, _Slice, _Assign, _ArrayProxy
 
 class NodeVisitor:
        def visit(self, node):
-               if isinstance(node, Constant):
-                       self.visit_Constant(node)
+               if isinstance(node, (int, bool)):
+                       self.visit_constant(node)
                elif isinstance(node, Signal):
                        self.visit_Signal(node)
                elif isinstance(node, _Operator):
@@ -34,7 +34,7 @@ class NodeVisitor:
                elif node is not None:
                        self.visit_unknown(node)
        
-       def visit_Constant(self, node):
+       def visit_constant(self, node):
                pass
        
        def visit_Signal(self, node):
@@ -90,15 +90,14 @@ class NodeVisitor:
                pass
 
 # Default methods always copy the node, except for:
-# - Constants
 # - Signals
 # - Unknown objects
 # - All fragment fields except comb and sync
 # In those cases, the original node is returned unchanged.
 class NodeTransformer:
        def visit(self, node):
-               if isinstance(node, Constant):
-                       return self.visit_Constant(node)
+               if isinstance(node, (int, bool)):
+                       return self.visit_constant(node)
                elif isinstance(node, Signal):
                        return self.visit_Signal(node)
                elif isinstance(node, _Operator):
@@ -128,7 +127,7 @@ class NodeTransformer:
                else:
                        return None
        
-       def visit_Constant(self, node):
+       def visit_constant(self, node):
                return node
        
        def visit_Signal(self, node):
index 3dba08a7943a1e1a87cf437da1b91ca11cef2921..17f587f6dedb88de8d0d8472f11d1bfb16e977dc 100644 (file)
@@ -95,18 +95,16 @@ class ExprCompiler:
        
        def visit_expr_name(self, node):
                if node.id == "True":
-                       return Constant(1)
+                       return 1
                if node.id == "False":
-                       return Constant(0)
+                       return 0
                r = self.symdict[node.id]
                if isinstance(r, ImplRegister):
                        r = r.storage
-               if isinstance(r, int):
-                       r = Constant(r)
                return r
        
        def visit_expr_num(self, node):
-               return Constant(node.n)
+               return node.n
        
        def visit_expr_attribute(self, node):
                raise NotImplementedError
index 610f2330786f25f540441ea9725f09ab4d02bd1b..543ac064a33ce9e16e414e4331f48c7ec63c4a0d 100644 (file)
@@ -46,7 +46,6 @@ class ImplRegister:
                        raise FinalizeError
                # do nothing when sel == 0
                items = sorted(self.source_encoding.items(), key=itemgetter(1))
-               cases = [(Constant(v, self.sel.bv),
-                       self.storage.eq(k)) for k, v in items]
+               cases = [(v, self.storage.eq(k)) for k, v in items]
                sync = [Case(self.sel, *cases)]
                return Fragment(sync=sync)