From: whitequark Date: Mon, 8 Jul 2019 09:23:33 +0000 (+0000) Subject: hdl.ast: use keyword-only arguments as appropriate. X-Git-Tag: locally_working~98 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=dac62754;p=nmigen.git hdl.ast: use keyword-only arguments as appropriate. As a motivation/related refactor, make sure each AST node exposes src_loc_at in the constructor. --- diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index 622c0fd..9a88fa3 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -41,7 +41,7 @@ class Value(metaclass=ABCMeta): else: raise TypeError("Object '{!r}' is not an nMigen value".format(obj)) - def __init__(self, src_loc_at=0): + def __init__(self, *, src_loc_at=0): super().__init__() self.src_loc = tracer.get_src_loc(1 + src_loc_at) @@ -242,6 +242,7 @@ class Const(Value): return value def __init__(self, value, shape=None): + # We deliberately do not call Value.__init__ here. self.value = int(value) if shape is None: shape = bits_for(self.value), self.value < 0 @@ -270,8 +271,8 @@ C = Const # shorthand class AnyValue(Value, DUID): - def __init__(self, shape): - super().__init__(src_loc_at=0) + def __init__(self, shape, *, src_loc_at=0): + super().__init__(src_loc_at=src_loc_at) if isinstance(shape, int): shape = shape, False self.nbits, self.signed = shape @@ -300,7 +301,7 @@ class AnySeq(AnyValue): @final class Operator(Value): - def __init__(self, op, operands, src_loc_at=0): + def __init__(self, op, operands, *, src_loc_at=0): super().__init__(src_loc_at=1 + src_loc_at) self.op = op self.operands = [Value.wrap(o) for o in operands] @@ -395,7 +396,7 @@ def Mux(sel, val1, val0): @final class Slice(Value): - def __init__(self, value, start, end): + def __init__(self, value, start, end, *, src_loc_at=0): if not isinstance(start, int): raise TypeError("Slice start must be an integer, not '{!r}'".format(start)) if not isinstance(end, int): @@ -413,7 +414,7 @@ class Slice(Value): if start > end: raise IndexError("Slice start {} must be less than slice end {}".format(start, end)) - super().__init__() + super().__init__(src_loc_at=src_loc_at) self.value = Value.wrap(value) self.start = start self.end = end @@ -433,11 +434,11 @@ class Slice(Value): @final class Part(Value): - def __init__(self, value, offset, width): + def __init__(self, value, offset, width, *, src_loc_at=0): if not isinstance(width, int) or width < 0: raise TypeError("Part width must be a non-negative integer, not '{!r}'".format(width)) - super().__init__() + super().__init__(src_loc_at=src_loc_at) self.value = value self.offset = Value.wrap(offset) self.width = width @@ -480,8 +481,8 @@ class Cat(Value): Value, inout Resulting ``Value`` obtained by concatentation. """ - def __init__(self, *args): - super().__init__() + def __init__(self, *args, src_loc_at=0): + super().__init__(src_loc_at=src_loc_at) self.parts = [Value.wrap(v) for v in flatten(args)] def shape(self): @@ -525,12 +526,12 @@ class Repl(Value): Repl, out Replicated value. """ - def __init__(self, value, count): + def __init__(self, value, count, *, src_loc_at=0): if not isinstance(count, int) or count < 0: raise TypeError("Replication count must be a non-negative integer, not '{!r}'" .format(count)) - super().__init__() + super().__init__(src_loc_at=src_loc_at) self.value = Value.wrap(value) self.count = count @@ -592,7 +593,7 @@ class Signal(Value, DUID): attrs : dict """ - def __init__(self, shape=None, name=None, reset=0, reset_less=False, min=None, max=None, + def __init__(self, shape=None, name=None, *, reset=0, reset_less=False, min=None, max=None, attrs=None, decoder=None, src_loc_at=0): super().__init__(src_loc_at=src_loc_at) @@ -641,7 +642,7 @@ class Signal(Value, DUID): self.decoder = decoder @classmethod - def like(cls, other, name=None, name_suffix=None, src_loc_at=0, **kwargs): + def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs): """Create Signal based on another. Parameters @@ -688,8 +689,8 @@ class ClockSignal(Value): domain : str Clock domain to obtain a clock signal for. Defaults to ``"sync"``. """ - def __init__(self, domain="sync"): - super().__init__() + def __init__(self, domain="sync", *, src_loc_at=0): + super().__init__(src_loc_at=src_loc_at) if not isinstance(domain, str): raise TypeError("Clock domain name must be a string, not '{!r}'".format(domain)) self.domain = domain @@ -722,8 +723,8 @@ class ResetSignal(Value): allow_reset_less : bool If the clock domain is reset-less, act as a constant ``0`` instead of reporting an error. """ - def __init__(self, domain="sync", allow_reset_less=False): - super().__init__() + def __init__(self, domain="sync", allow_reset_less=False, *, src_loc_at=0): + super().__init__(src_loc_at=src_loc_at) if not isinstance(domain, str): raise TypeError("Clock domain name must be a string, not '{!r}'".format(domain)) self.domain = domain @@ -832,8 +833,8 @@ class Array(MutableSequence): @final class ArrayProxy(Value): - def __init__(self, elems, index): - super().__init__(src_loc_at=1) + def __init__(self, elems, index, *, src_loc_at=0): + super().__init__(src_loc_at=1 + src_loc_at) self.elems = elems self.index = Value.wrap(index) @@ -885,7 +886,7 @@ class UserValue(Value): * Indexing or iterating through individual bits; * Adding an assignment to the value to a ``Module`` using ``m.d. +=``. """ - def __init__(self, src_loc_at=1): + def __init__(self, *, src_loc_at=0): super().__init__(src_loc_at=1 + src_loc_at) self.__lowered = None @@ -917,8 +918,8 @@ class Sample(Value): of the ``domain`` clock back. If that moment is before the beginning of time, it is equal to the value of the expression calculated as if each signal had its reset value. """ - def __init__(self, expr, clocks, domain): - super().__init__(src_loc_at=1) + def __init__(self, expr, clocks, domain, *, src_loc_at=0): + super().__init__(src_loc_at=1 + src_loc_at) self.value = Value.wrap(expr) self.clocks = int(clocks) self.domain = domain @@ -962,6 +963,9 @@ class _StatementList(list): class Statement: + def __init__(self, *, src_loc_at=0): + self.src_loc = tracer.get_src_loc(1 + src_loc_at) + @staticmethod def wrap(obj): if isinstance(obj, Iterable): @@ -975,9 +979,8 @@ class Statement: @final class Assign(Statement): - def __init__(self, lhs, rhs, src_loc_at=0): - self.src_loc = tracer.get_src_loc(src_loc_at) - + def __init__(self, lhs, rhs, *, src_loc_at=0): + super().__init__(src_loc_at=src_loc_at) self.lhs = Value.wrap(lhs) self.rhs = Value.wrap(rhs) @@ -992,17 +995,14 @@ class Assign(Statement): class Property(Statement): - def __init__(self, test, _check=None, _en=None): - self.src_loc = tracer.get_src_loc() - - self.test = Value.wrap(test) - + def __init__(self, test, *, _check=None, _en=None, src_loc_at=0): + super().__init__(src_loc_at=src_loc_at) + self.test = Value.wrap(test) self._check = _check + self._en = _en if self._check is None: self._check = Signal(reset_less=True, name="${}$check".format(self._kind)) self._check.src_loc = self.src_loc - - self._en = _en if _en is None: self._en = Signal(reset_less=True, name="${}$en".format(self._kind)) self._en.src_loc = self.src_loc @@ -1029,9 +1029,8 @@ class Assume(Property): # @final class Switch(Statement): - def __init__(self, test, cases, src_loc_at=0): - self.src_loc = tracer.get_src_loc(src_loc_at) - + def __init__(self, test, cases, *, src_loc_at=0): + super().__init__(src_loc_at=src_loc_at) self.test = Value.wrap(test) self.cases = OrderedDict() for keys, stmts in cases.items(): @@ -1081,7 +1080,8 @@ class Switch(Statement): @final class Delay(Statement): - def __init__(self, interval=None): + def __init__(self, interval=None, *, src_loc_at=0): + super().__init__(src_loc_at=src_loc_at) self.interval = None if interval is None else float(interval) def _rhs_signals(self): @@ -1096,7 +1096,8 @@ class Delay(Statement): @final class Tick(Statement): - def __init__(self, domain="sync"): + def __init__(self, domain="sync", *, src_loc_at=0): + super().__init__(src_loc_at=src_loc_at) self.domain = str(domain) def _rhs_signals(self): @@ -1108,6 +1109,9 @@ class Tick(Statement): @final class Passive(Statement): + def __init__(self, *, src_loc_at=0): + super().__init__(src_loc_at=src_loc_at) + def _rhs_signals(self): return ValueSet() diff --git a/nmigen/hdl/xfrm.py b/nmigen/hdl/xfrm.py index 93e5d0a..298a449 100644 --- a/nmigen/hdl/xfrm.py +++ b/nmigen/hdl/xfrm.py @@ -207,7 +207,7 @@ class StatementVisitor(metaclass=ABCMeta): new_stmt = self.on_statements(stmt) else: new_stmt = self.on_unknown_statement(stmt) - if hasattr(stmt, "src_loc") and hasattr(new_stmt, "src_loc"): + if isinstance(new_stmt, Statement): new_stmt.src_loc = stmt.src_loc return new_stmt