def ls_op(self, op1, op2, carry):
op1 = getsig(op1)
if isinstance(op2, Const) or isinstance(op2, Signal):
+ scalar = True
shape = op1.shape()
pa = PartitionedScalarShift(shape[0], self.partpoints)
else:
+ scalar = False
op2 = getsig(op2)
shape = op1.shape()
pa = PartitionedDynamicShift(shape[0], self.partpoints)
setattr(self.m.submodules, self.get_modname('ls'), pa)
comb = self.m.d.comb
- if isinstance(op2, Const) or isinstance(op2, Signal):
- comb += pa.a.eq(op1)
- comb += pa.b.eq(op2)
- else:
+ if scalar:
comb += pa.data.eq(op1)
comb += pa.shifter.eq(op2)
+ else:
+ comb += pa.a.eq(op1)
+ comb += pa.b.eq(op2)
# XXX TODO: carry-in, carry-out
#comb += pa.carry_in.eq(carry)
return (pa.output, 0)
self.partpoints = partpoints
self.a = PartitionedSignal(partpoints, width)
self.b = PartitionedSignal(partpoints, width)
+ self.bsig = Signal(width)
self.add_output = Signal(width)
self.ls_output = Signal(width) # left shift
+ self.ls_scal_output = Signal(width) # left shift
self.sub_output = Signal(width)
self.eq_output = Signal(len(partpoints)+1)
self.gt_output = Signal(len(partpoints)+1)
comb += self.ls_output.eq(self.a << self.b)
ppts = self.partpoints
comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
+ # scalar left shift
+ comb += self.bsig.eq(self.b.sig)
+ comb += self.ls_scal_output.eq(self.a << self.bsig)
return m
def async_process():
+ def test_ls_scal_fn(carry_in, a, b, mask):
+ # reduce range of b
+ bits = count_bits(mask)
+ newb = b & ((bits-1))
+ print ("%x %x %x bits %d trunc %x" % \
+ (a, b, mask, bits, newb))
+ b = newb
+ # TODO: carry
+ carry_in = 0
+ lsb = mask & ~(mask-1) if carry_in else 0
+ sum = ((a & mask) << b)
+ result = mask & sum
+ carry = (sum & mask) != sum
+ carry = 0
+ print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
+ return result, carry
+
def test_ls_fn(carry_in, a, b, mask):
# reduce range of b
bits = count_bits(mask)
result = mask & sum
carry = (sum & mask) != sum
carry = 0
- print("result", hex(a), hex(b), hex(sum), hex(mask), hex(result))
+ print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
return result, carry
def test_add_fn(carry_in, a, b, mask):
self.assertEqual(carry_result, c_outval, msg)
for (test_fn, mod_attr) in (
+ (test_ls_scal_fn, "ls_scal"),
(test_ls_fn, "ls"),
(test_add_fn, "add"),
(test_sub_fn, "sub"),