# License: BSD
 
 from migen import *
-from migen.fhdl.specials import Special
+from migen.fhdl.specials import Special, Tristate
 
 # Differential Input/Output ------------------------------------------------------------------------
 
 
 class SDRIO(Special):
     def __init__(self, i, o, clk=ClockSignal()):
-        assert len(i) == len(o)
+        assert len(i) == len(o) == 1
         Special.__init__(self)
-        print(o)
         self.i            = wrap(i)
         self.o            = wrap(o)
         self.clk          = wrap(clk)
 class SDRInput(SDRIO):  pass
 class SDROutput(SDRIO): pass
 
+# SDR Tristate -------------------------------------------------------------------------------------
+
+class InferedSDRTristate(Module):
+    def __init__(self, io, o, oe, i, clk, clk_domain):
+        if clk_domain is None:
+            raise NotImplementedError("Attempted to use an SDRTristate but no clk_domain specified.")
+        _o  = Signal()
+        _oe = Signal()
+        _i  = Signal()
+        self.specials += SDROutput(o, _o)
+        self.specials += SDRInput(_i, i)
+        self.submodules += InferedSDRIO(oe, _oe, clk, clk_domain)
+        self.specials += Tristate(io, _o, _oe, _i)
+
+class SDRTristate(Special):
+    def __init__(self, io, o, oe, i, clk=ClockSignal()):
+        assert len(i) == len(o) == len(oe)
+        Special.__init__(self)
+        self.io           = wrap(io)
+        self.o            = wrap(o)
+        self.oe           = wrap(oe)
+        self.i            = wrap(i)
+        self.clk          = wrap(clk)
+        self.clk_domain   = None if not hasattr(clk, "cd") else clk.cd
+
+    def iter_expressions(self):
+        yield self, "io",  SPECIAL_INOUT
+        yield self, "o",   SPECIAL_INPUT
+        yield self, "oe",  SPECIAL_INPUT
+        yield self, "i",   SPECIAL_OUTPUT
+        yield self, "clk", SPECIAL_INPUT
+
+    @staticmethod
+    def lower(dr):
+        return InferedSDRTristate(dr.io, dr.o, dr.oe, dr.i, dr.clk, dr.clk_domain)
+
 # DDR Input/Output ---------------------------------------------------------------------------------
 
 class DDRInput(Special):