# TODO: detect if the 2nd operand is a Const, a Signal or a
# PartitionedSignal. if it's a Const or a Signal, a global shift
# can occur. if it's a PartitionedSignal, that's much more interesting.
- def ls_op(self, op1, op2, carry):
+ def ls_op(self, op1, op2, carry, shr_flag=0):
op1 = getsig(op1)
if isinstance(op2, Const) or isinstance(op2, Signal):
scalar = True
if scalar:
comb += pa.data.eq(op1)
comb += pa.shifter.eq(op2)
+ comb += pa.shift_right.eq(shr_flag)
else:
comb += pa.a.eq(op1)
comb += pa.b.eq(op2)
+ comb += pa.shift_right.eq(shr_flag)
# XXX TODO: carry-in, carry-out
#comb += pa.carry_in.eq(carry)
return (pa.output, 0)
return Operator("<<", [other, self])
def __rshift__(self, other):
- raise NotImplementedError
- return Operator(">>", [self, other])
+ z = Const(0, len(self.partpoints)+1)
+ result, _ = self.ls_op(self, other, carry=z, shr_flag=1) # TODO, carry
+ return result
def __rrshift__(self, other):
raise NotImplementedError
self.add_output = Signal(width)
self.ls_output = Signal(width) # left shift
self.ls_scal_output = Signal(width) # left shift
+ self.rs_output = Signal(width) # left shift
+ self.rs_scal_output = Signal(width) # left shift
self.sub_output = Signal(width)
self.eq_output = Signal(len(partpoints)+1)
self.gt_output = Signal(len(partpoints)+1)
sync += self.neg_output.eq(-self.a)
# left shift
sync += self.ls_output.eq(self.a << self.b)
+ sync += self.rs_output.eq(self.a >> self.b)
ppts = self.partpoints
sync += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
# scalar left shift
comb += self.bsig.eq(self.b.sig)
sync += self.ls_scal_output.eq(self.a << self.bsig)
+ sync += self.rs_scal_output.eq(self.a >> self.bsig)
return m
self.add_output = Signal(width)
self.ls_output = Signal(width) # left shift
self.ls_scal_output = Signal(width) # left shift
+ self.rs_output = Signal(width) # left shift
+ self.rs_scal_output = Signal(width) # left shift
self.sub_output = Signal(width)
self.eq_output = Signal(len(partpoints)+1)
self.gt_output = Signal(len(partpoints)+1)
comb += self.neg_output.eq(-self.a)
# left shift
comb += self.ls_output.eq(self.a << self.b)
+ comb += self.rs_output.eq(self.a >> self.b)
ppts = self.partpoints
comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
# scalar left shift
comb += self.bsig.eq(self.b.sig)
comb += self.ls_scal_output.eq(self.a << self.bsig)
+ comb += self.rs_scal_output.eq(self.a >> self.bsig)
return m
print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
return result, carry
+ def test_rs_scal_fn(carry_in, a, b, mask):
+ # reduce range of b
+ bits = count_bits(mask)
+ newb = b & ((bits-1))
+ print ("%x %x %x bits %d trunc %x" % \
+ (a, b, mask, bits, newb))
+ b = newb
+ # TODO: carry
+ carry_in = 0
+ lsb = mask & ~(mask-1) if carry_in else 0
+ sum = ((a & mask) >> b)
+ result = mask & sum
+ carry = (sum & mask) != sum
+ carry = 0
+ print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
+ return result, carry
+
def test_ls_fn(carry_in, a, b, mask):
# reduce range of b
bits = count_bits(mask)
print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
return result, carry
+ def test_rs_fn(carry_in, a, b, mask):
+ # reduce range of b
+ bits = count_bits(mask)
+ fz = first_zero(mask)
+ newb = b & ((bits-1)<<fz)
+ print ("%x %x %x bits %d zero %d trunc %x" % \
+ (a, b, mask, bits, fz, newb))
+ b = newb
+ # TODO: carry
+ carry_in = 0
+ lsb = mask & ~(mask-1) if carry_in else 0
+ b = (b & mask)
+ b = b >>fz
+ sum = ((a & mask) >> b)
+ result = mask & sum
+ carry = (sum & mask) != sum
+ carry = 0
+ print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
+ return result, carry
+
def test_add_fn(carry_in, a, b, mask):
lsb = mask & ~(mask-1) if carry_in else 0
sum = (a & mask) + (b & mask) + lsb
for (test_fn, mod_attr) in (
(test_ls_scal_fn, "ls_scal"),
(test_ls_fn, "ls"),
+ (test_rs_scal_fn, "rs_scal"),
+ (test_rs_fn, "rs"),
(test_add_fn, "add"),
(test_sub_fn, "sub"),
(test_neg_fn, "neg"),