from nmutil.latch import SRLatch, latchregister
 from nmutil.byterev import byte_reverse
+from nmutil.extend import exts
 
 from soc.experiment.compalu_multi import go_record, CompUnitRecord
 from soc.experiment.l0_cache import PortInterface
         p_st_go = Signal(reset_less=True)
         sync += p_st_go.eq(self.st.go_i)
 
+        # decode bits of operand (latched)
+        oper_r = CompLDSTOpSubset(name="oper_r")  # Dest register
+        comb += op_is_st.eq(oper_r.insn_type == MicrOp.OP_STORE)  # ST
+        comb += op_is_ld.eq(oper_r.insn_type == MicrOp.OP_LOAD)  # LD
+        op_is_update = oper_r.ldst_mode == LDSTMode.update           # UPDATE
+        op_is_cix = oper_r.ldst_mode == LDSTMode.cix           # cache-inhibit
+        comb += self.load_mem_o.eq(op_is_ld & self.go_ad_i)
+        comb += self.stwd_mem_o.eq(op_is_st & self.go_st_i)
+        comb += self.ld_o.eq(op_is_ld)
+        comb += self.st_o.eq(op_is_st)
+
         ##########################
         # FSM implemented through sequence of latches.  approximately this:
         # - opc_l       : opcode
 
         # dest operand latch
         comb += wri_l.s.eq(issue_i)
-        sync += wri_l.r.eq(reset_w | Repl(self.done_o, self.n_dst))
+        sync += wri_l.r.eq(reset_w | Repl(self.done_o |
+                                          (self.pi.busy_o & op_is_update),
+                                          self.n_dst))
 
         # update-mode operand latch (EA written to reg 2)
         sync += upd_l.s.eq(reset_i)
         comb += rst_l.r.eq(issue_i)
 
         # create a latch/register for the operand
-        oper_r = CompLDSTOpSubset(name="oper_r")  # Dest register
         with m.If(self.issue_i):
             sync += oper_r.eq(self.oper_i)
         with m.If(self.done_o):
         comb += alu_o.eq(src1_or_z + src2_or_imm)  # actual EA
         m.d.sync += alu_ok.eq(alu_valid)             # keep ack in sync with EA
 
-        # decode bits of operand (latched)
-        comb += op_is_st.eq(oper_r.insn_type == MicrOp.OP_STORE)  # ST
-        comb += op_is_ld.eq(oper_r.insn_type == MicrOp.OP_LOAD)  # LD
-        op_is_update = oper_r.ldst_mode == LDSTMode.update           # UPDATE
-        op_is_cix = oper_r.ldst_mode == LDSTMode.cix           # cache-inhibit
-        comb += self.load_mem_o.eq(op_is_ld & self.go_ad_i)
-        comb += self.stwd_mem_o.eq(op_is_st & self.go_st_i)
-        comb += self.ld_o.eq(op_is_ld)
-        comb += self.st_o.eq(op_is_st)
-
         ############################
         # Control Signal calculation
 
         comb += addr_ok.eq(self.pi.addr_ok_o)  # no exc, address fine
 
         # byte-reverse on LD
+        revnorev = Signal(64, reset_less=True)
         with m.If(oper_r.byte_reverse):
             # byte-reverse the data based on ld/st width (turn it to LE)
             data_len = oper_r.data_len
             lddata_r = byte_reverse(m, 'lddata_r', pi.ld.data, data_len)
-            comb += ldd_o.eq(lddata_r)  # put reversed- data out
+            comb += revnorev.eq(lddata_r)  # put reversed- data out
         with m.Else():
-            comb += ldd_o.eq(pi.ld.data)  # put data out, straight (as BE)
+            comb += revnorev.eq(pi.ld.data)  # put data out, straight (as BE)
+
+        # then check sign-extend
+        with m.If(oper_r.sign_extend):
+            comb += ldd_o.eq(exts(revnorev, 32, 64))  # sign-extend
+        with m.Else():
+            comb += ldd_o.eq(revnorev)
+
         # ld - ld gets latched in via lod_l
         comb += ld_ok.eq(pi.ld.ok)  # ld.ok *closes* (freezes) ld data
 
 
         self.add_case(Program(lst, bigendian), initial_regs,
                              initial_mem=initial_mem)
 
+    def case_9_load_algebraic_1(self):
+        lst = ["lwax 3, 4, 2"]
+        initial_regs = [0] * 32
+        initial_regs[1] = 0x5678
+        initial_regs[2] = 0x001c
+        initial_regs[4] = 0x0008
+        initial_mem = {0x0000: (0x5432123412345678, 8),
+                       0x0008: (0xabcdef0187654321, 8),
+                       0x0020: (0xf000000f0000ffff, 8),
+                        }
+        self.add_case(Program(lst, bigendian), initial_regs,
+                             initial_mem=initial_mem)
+
+    def case_9_load_algebraic_2(self):
+        lst = ["lwax 3, 4, 2"]
+        initial_regs = [0] * 32
+        initial_regs[1] = 0x5678
+        initial_regs[2] = 0x001c
+        initial_regs[4] = 0x0008
+        initial_mem = {0x0000: (0x5432123412345678, 8),
+                       0x0008: (0xabcdef0187654321, 8),
+                       0x0020: (0x7000000f0000ffff, 8),
+                        }
+        self.add_case(Program(lst, bigendian), initial_regs,
+                             initial_mem=initial_mem)
+
+    def case_9_load_algebraic_3(self):
+        lst = ["lwaux 3, 4, 2"]
+        initial_regs = [0] * 32
+        initial_regs[1] = 0x5678
+        initial_regs[2] = 0x001c
+        initial_regs[4] = 0x0008
+        initial_mem = {0x0000: (0x5432123412345678, 8),
+                       0x0008: (0xabcdef0187654321, 8),
+                       0x0020: (0xf000000f0000ffff, 8),
+                        }
+        self.add_case(Program(lst, bigendian), initial_regs,
+                             initial_mem=initial_mem)
+