# Copyright (C) Jonathan P Dawson 2013
# 2013-12-12
-from nmigen import Signal, Cat, Const, Mux, Module
+from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
from math import log
from operator import or_
from functools import reduce
+from singlepipe import PrevControl, NextControl
+from pipeline import ObjectProxy
+
+
class MultiShiftR:
def __init__(self, width):
return res
-class FPNumBase:
+class FPNumBase: #(Elaboratable):
""" Floating-point Base Number Class
"""
def __init__(self, width, m_extra=True):
self.s = Signal(reset_less=True) # Sign bit
self.mzero = Const(0, (m_width, False))
+ m_msb = 1<<(self.m_width-2)
+ self.msb1 = Const(m_msb, (m_width, False))
self.m1s = Const(-1, (m_width, False))
self.P128 = Const(e_max, (e_width, True))
self.P127 = Const(e_max-1, (e_width, True))
self.is_overflowed = Signal(reset_less=True)
self.is_denormalised = Signal(reset_less=True)
self.exp_128 = Signal(reset_less=True)
+ self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
+ self.exp_lt_n126 = Signal(reset_less=True)
+ self.exp_gt_n126 = Signal(reset_less=True)
self.exp_gt127 = Signal(reset_less=True)
self.exp_n127 = Signal(reset_less=True)
self.exp_n126 = Signal(reset_less=True)
m.d.comb += self.is_overflowed.eq(self._is_overflowed())
m.d.comb += self.is_denormalised.eq(self._is_denormalised())
m.d.comb += self.exp_128.eq(self.e == self.P128)
+ m.d.comb += self.exp_sub_n126.eq(self.e - self.N126)
+ m.d.comb += self.exp_gt_n126.eq(self.exp_sub_n126 > 0)
+ m.d.comb += self.exp_lt_n126.eq(self.exp_sub_n126 < 0)
m.d.comb += self.exp_gt127.eq(self.e > self.P127)
m.d.comb += self.exp_n127.eq(self.e == self.N127)
m.d.comb += self.exp_n126.eq(self.e == self.N126)
def _is_denormalised(self):
return (self.exp_n126) & (self.m_msbzero)
- def copy(self, inp):
+ def __iter__(self):
+ yield self.s
+ yield self.e
+ yield self.m
+
+ def eq(self, inp):
return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
def zero(self, s):
return self.create(s, self.N127, 0)
+ def create2(self, s, e, m):
+ """ creates a value from sign / exponent / mantissa
+
+ bias is added here, to the exponent
+ """
+ e = e + self.P127 # exp (add on bias)
+ return Cat(m[0:self.e_start],
+ e[0:self.e_end-self.e_start],
+ s)
+
+ def nan2(self, s):
+ return self.create2(s, self.P128, self.msb1)
+
+ def inf2(self, s):
+ return self.create2(s, self.P128, self.mzero)
+
+ def zero2(self, s):
+ return self.create2(s, self.N127, self.mzero)
+
+
+class MultiShiftRMerge(Elaboratable):
+ """ shifts down (right) and merges lower bits into m[0].
+ m[0] is the "sticky" bit, basically
+ """
+ def __init__(self, width, s_max=None):
+ if s_max is None:
+ s_max = int(log(width) / log(2))
+ self.smax = s_max
+ self.m = Signal(width, reset_less=True)
+ self.inp = Signal(width, reset_less=True)
+ self.diff = Signal(s_max, reset_less=True)
+ self.width = width
+
+ def elaborate(self, platform):
+ m = Module()
+
+ rs = Signal(self.width, reset_less=True)
+ m_mask = Signal(self.width, reset_less=True)
+ smask = Signal(self.width, reset_less=True)
+ stickybit = Signal(reset_less=True)
+ maxslen = Signal(self.smax, reset_less=True)
+ maxsleni = Signal(self.smax, reset_less=True)
+
+ sm = MultiShift(self.width-1)
+ m0s = Const(0, self.width-1)
+ mw = Const(self.width-1, len(self.diff))
+ m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
+ maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
+ ]
+
+ m.d.comb += [
+ # shift mantissa by maxslen, mask by inverse
+ rs.eq(sm.rshift(self.inp[1:], maxslen)),
+ m_mask.eq(sm.rshift(~m0s, maxsleni)),
+ smask.eq(self.inp[1:] & m_mask),
+ # sticky bit combines all mask (and mantissa low bit)
+ stickybit.eq(smask.bool() | self.inp[0]),
+ # mantissa result contains m[0] already.
+ self.m.eq(Cat(stickybit, rs))
+ ]
+ return m
-class FPNumShift(FPNumBase):
+
+class FPNumShift(FPNumBase, Elaboratable):
""" Floating-point Number Class for shifting
"""
def __init__(self, mainm, op, inv, width, m_extra=True):
return m
- def shift_down(self):
+ def shift_down(self, inp):
""" shifts a mantissa down by one. exponent is increased to compensate
accuracy is lost as a result in the mantissa however there are 3
guard bits (the latter of which is the "sticky" bit)
"""
- return [self.e.eq(self.e + 1),
- self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
+ return [self.e.eq(inp.e + 1),
+ self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
]
def shift_down_multi(self, diff):
self.m.eq(sm.lshift(self.m, maxslen))
]
-class FPNumIn(FPNumBase):
+
+class FPNumDecode(FPNumBase):
""" Floating-point Number Class
Contains signals for an incoming copy of the value, decoded into
"""
def __init__(self, op, width, m_extra=True):
FPNumBase.__init__(self, width, m_extra)
- self.latch_in = Signal()
self.op = op
def elaborate(self, platform):
m = FPNumBase.elaborate(self, platform)
- #m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb)
- #with m.If(self.latch_in):
- # m.d.sync += self.decode(self.v)
+ m.d.comb += self.decode(self.v)
return m
self.s.eq(v[-1]), # sign
]
- def shift_down(self):
+class FPNumIn(FPNumBase):
+ """ Floating-point Number Class
+
+ Contains signals for an incoming copy of the value, decoded into
+ sign / exponent / mantissa.
+ Also contains encoding functions, creation and recognition of
+ zero, NaN and inf (all signed)
+
+ Four extra bits are included in the mantissa: the top bit
+ (m[-1]) is effectively a carry-overflow. The other three are
+ guard (m[2]), round (m[1]), and sticky (m[0])
+ """
+ def __init__(self, op, width, m_extra=True):
+ FPNumBase.__init__(self, width, m_extra)
+ self.latch_in = Signal()
+ self.op = op
+
+ def decode2(self, m):
+ """ decodes a latched value into sign / exponent / mantissa
+
+ bias is subtracted here, from the exponent. exponent
+ is extended to 10 bits so that subtract 127 is done on
+ a 10-bit number
+ """
+ v = self.v
+ args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+ #print ("decode", self.e_end)
+ res = ObjectProxy(m, pipemode=False)
+ res.m = Cat(*args) # mantissa
+ res.e = v[self.e_start:self.e_end] - self.P127 # exp
+ res.s = v[-1] # sign
+ return res
+
+ def decode(self, v):
+ """ decodes a latched value into sign / exponent / mantissa
+
+ bias is subtracted here, from the exponent. exponent
+ is extended to 10 bits so that subtract 127 is done on
+ a 10-bit number
+ """
+ args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+ #print ("decode", self.e_end)
+ return [self.m.eq(Cat(*args)), # mantissa
+ self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
+ self.s.eq(v[-1]), # sign
+ ]
+
+ def shift_down(self, inp):
""" shifts a mantissa down by one. exponent is increased to compensate
accuracy is lost as a result in the mantissa however there are 3
guard bits (the latter of which is the "sticky" bit)
"""
- return [self.e.eq(self.e + 1),
- self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
+ return [self.e.eq(inp.e + 1),
+ self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
]
- def shift_down_multi(self, diff):
+ def shift_down_multi(self, diff, inp=None):
""" shifts a mantissa down. exponent is increased to compensate
accuracy is lost as a result in the mantissa however there are 3
inverted and used as a mask to get the LSBs of the mantissa.
those are then |'d into the sticky bit.
"""
+ if inp is None:
+ inp = self
sm = MultiShift(self.width)
mw = Const(self.m_width-1, len(diff))
maxslen = Mux(diff > mw, mw, diff)
- rs = sm.rshift(self.m[1:], maxslen)
+ rs = sm.rshift(inp.m[1:], maxslen)
maxsleni = mw - maxslen
m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
- stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
- return [self.e.eq(self.e + diff),
- self.m.eq(Cat(stickybits, rs))
+ #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
+ stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
+ return [self.e.eq(inp.e + diff),
+ self.m.eq(Cat(stickybit, rs))
]
def shift_up_multi(self, diff):
self.m.eq(sm.lshift(self.m, maxslen))
]
-class FPOp:
- def __init__(self, width):
- self.width = width
+class Trigger(Elaboratable):
+ def __init__(self):
- self.v = Signal(width)
self.stb = Signal(reset=0)
self.ack = Signal()
+ self.trigger = Signal(reset_less=True)
+
+ def elaborate(self, platform):
+ m = Module()
+ m.d.comb += self.trigger.eq(self.stb & self.ack)
+ return m
+
+ def eq(self, inp):
+ return [self.stb.eq(inp.stb),
+ self.ack.eq(inp.ack)
+ ]
+
+ def ports(self):
+ return [self.stb, self.ack]
+
+
+class FPOpIn(PrevControl):
+ def __init__(self, width):
+ PrevControl.__init__(self)
+ self.width = width
+ self.v = Signal(width)
+ self.i_data = self.v
def chain_inv(self, in_op, extra=None):
stb = in_op.stb
in_op.ack.eq(self.ack), # send ACK
]
- def ports(self):
- return [self.v, self.stb, self.ack]
+class FPOpOut(NextControl):
+ def __init__(self, width):
+ NextControl.__init__(self)
+ self.width = width
+ self.v = Signal(width)
+ self.o_data = self.v
+
+ def chain_inv(self, in_op, extra=None):
+ stb = in_op.stb
+ if extra is not None:
+ stb = stb & extra
+ return [self.v.eq(in_op.v), # receive value
+ self.stb.eq(stb), # receive STB
+ in_op.ack.eq(~self.ack), # send ACK
+ ]
-class Overflow:
+ def chain_from(self, in_op, extra=None):
+ stb = in_op.stb
+ if extra is not None:
+ stb = stb & extra
+ return [self.v.eq(in_op.v), # receive value
+ self.stb.eq(stb), # receive STB
+ in_op.ack.eq(self.ack), # send ACK
+ ]
+
+
+class Overflow(Elaboratable):
def __init__(self):
self.guard = Signal(reset_less=True) # tot[2]
self.round_bit = Signal(reset_less=True) # tot[1]
self.roundz = Signal(reset_less=True)
- def copy(self, inp):
+ def __iter__(self):
+ yield self.guard
+ yield self.round_bit
+ yield self.sticky
+ yield self.m0
+
+ def eq(self, inp):
return [self.guard.eq(inp.guard),
self.round_bit.eq(inp.round_bit),
self.sticky.eq(inp.sticky),
when both stb and ack are 1.
acknowledgement is sent by setting ack to ZERO.
"""
- with m.If((op.ack) & (op.stb)):
+ res = v.decode2(m)
+ ack = Signal()
+ with m.If((op.ready_o) & (op.valid_i_test)):
m.next = next_state
- m.d.sync += [
- # op is latched in from FPNumIn class on same ack/stb
- v.decode(op.v),
- op.ack.eq(0)
- ]
+ # op is latched in from FPNumIn class on same ack/stb
+ m.d.comb += ack.eq(0)
with m.Else():
- m.d.sync += op.ack.eq(1)
+ m.d.comb += ack.eq(1)
+ return [res, ack]
def denormalise(self, m, a):
""" denormalises a number. this is probably the wrong name for
both cases *effectively multiply the number stored by 2*,
which has to be taken into account when extracting the result.
"""
- with m.If(a.e == a.N127):
+ with m.If(a.exp_n127):
m.d.sync += a.e.eq(a.N126) # limit a exponent
with m.Else():
m.d.sync += a.m[-1].eq(1) # set top mantissa bit
with m.Else():
m.next = next_state
- def roundz(self, m, z, out_z, roundz):
+ def roundz(self, m, z, roundz):
""" performs rounding on the output. TODO: different kinds of rounding
"""
- m.d.comb += out_z.copy(z) # copies input to output first
with m.If(roundz):
- m.d.comb += out_z.m.eq(z.m + 1) # mantissa rounds up
+ m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
with m.If(z.m == z.m1s): # all 1s
- m.d.comb += out_z.e.eq(z.e + 1) # exponent rounds up
+ m.d.sync += z.e.eq(z.e + 1) # exponent rounds up
def corrections(self, m, z, next_state):
""" denormalisation and sign-bug corrections
m.d.sync += [
out_z.v.eq(z.v)
]
- with m.If(out_z.stb & out_z.ack):
- m.d.sync += out_z.stb.eq(0)
+ with m.If(out_z.o_valid & out_z.i_ready_test):
+ m.d.sync += out_z.o_valid.eq(0)
m.next = next_state
with m.Else():
- m.d.sync += out_z.stb.eq(1)
+ m.d.sync += out_z.o_valid.eq(1)
+
+
+class FPState(FPBase):
+ def __init__(self, state_from):
+ self.state_from = state_from
+
+ def set_inputs(self, inputs):
+ self.inputs = inputs
+ for k,v in inputs.items():
+ setattr(self, k, v)
+
+ def set_outputs(self, outputs):
+ self.outputs = outputs
+ for k,v in outputs.items():
+ setattr(self, k, v)
+
+
+class FPID:
+ def __init__(self, id_wid):
+ self.id_wid = id_wid
+ if self.id_wid:
+ self.in_mid = Signal(id_wid, reset_less=True)
+ self.out_mid = Signal(id_wid, reset_less=True)
+ else:
+ self.in_mid = None
+ self.out_mid = None
+
+ def idsync(self, m):
+ if self.id_wid is not None:
+ m.d.sync += self.out_mid.eq(self.in_mid)