ARM: Implement signed saturating add and/or subtract instructions.
authorGabe Black <gblack@eecs.umich.edu>
Wed, 2 Jun 2010 17:58:05 +0000 (12:58 -0500)
committerGabe Black <gblack@eecs.umich.edu>
Wed, 2 Jun 2010 17:58:05 +0000 (12:58 -0500)
src/arch/arm/insts/static_inst.hh
src/arch/arm/isa/insts/data.isa

index 485d6997ed14648daa978a4b76d9eea0821d6761..634cf08127dcec0846967430becdbfa4db2570ed 100644 (file)
@@ -60,6 +60,23 @@ class ArmStaticInst : public StaticInst
     bool shift_carry_rs(uint32_t base, uint32_t shamt,
                         uint32_t type, uint32_t cfval) const;
 
+    template<int width>
+    static bool
+    saturateOp(int32_t &res, int64_t op1, int64_t op2, bool sub=false)
+    {
+        int64_t midRes = sub ? (op1 - op2) : (op1 + op2);
+        if (bits(midRes, width) != bits(midRes, width - 1)) {
+            if (midRes > 0)
+                res = (1LL << (width - 1)) - 1;
+            else
+                res = -(1LL << (width - 1));
+            return true;
+        } else {
+            res = midRes;
+            return false;
+        }
+    }
+
     // Constructor
     ArmStaticInst(const char *mnem, ExtMachInst _machInst,
                   OpClass __opClass)
index 69e813d2541a4d559556ff9a2e91280300640a08..9de42807e660547770ec5f52eaf5ab3f7c404b4b 100644 (file)
@@ -44,8 +44,7 @@ let {{
     exec_output = ""
 
     calcQCode = '''
-        cprintf("canOverflow: %%d\\n", Dest < resTemp);
-        replaceBits(CondCodes, 27, Dest < resTemp);
+        CondCodes = CondCodes | ((resTemp & 1) << 27);
     '''
 
     calcCcCode = '''
@@ -68,6 +67,7 @@ let {{
     carryCode = {
         "none": (oldC, oldC, oldC),
         "llbit": (oldC, oldC, oldC),
+        "saturate": ('0', '0', '0'),
         "overflow": ('0', '0', '0'),
         "add": ('findCarry(32, resTemp, Op1, secondOp)',
                 'findCarry(32, resTemp, Op1, secondOp)',
@@ -86,6 +86,7 @@ let {{
     overflowCode = {
         "none": oldV,
         "llbit": oldV,
+        "saturate": '0',
         "overflow": '0',
         "add": 'findOverflow(32, resTemp, Op1, secondOp)',
         "sub": 'findOverflow(32, resTemp, Op1, ~secondOp)',
@@ -98,14 +99,14 @@ let {{
     regOp2 = "shift_rm_imm(Op2, shiftAmt, shiftType, CondCodes<29:>)"
     regRegOp2 = "shift_rm_rs(Op2, Shift<7:0>, shiftType, CondCodes<29:>)"
 
-    def buildImmDataInst(mnem, code, flagType = "logic", \
-                         suffix = "Imm", buildCc = True):
+    def buildImmDataInst(mnem, code, flagType = "logic", suffix = "Imm", \
+                         buildCc = True, buildNonCc = True):
         cCode = carryCode[flagType]
         vCode = overflowCode[flagType]
         negBit = 31
         if flagType == "llbit":
             negBit = 63
-        if flagType == "overflow":
+        if flagType == "saturate":
             immCcCode = calcQCode
         else:
             immCcCode = calcCcCode % {
@@ -128,18 +129,19 @@ let {{
             decoder_output += DataImmConstructor.subst(iop)
             exec_output += PredOpExecute.subst(iop)
 
-        subst(immIop)
+        if buildNonCc:
+            subst(immIop)
         if buildCc:
             subst(immIopCc)
 
-    def buildRegDataInst(mnem, code, flagType = "logic", \
-                         suffix = "Reg", buildCc = True):
+    def buildRegDataInst(mnem, code, flagType = "logic", suffix = "Reg", \
+                         buildCc = True, buildNonCc = True):
         cCode = carryCode[flagType]
         vCode = overflowCode[flagType]
         negBit = 31
         if flagType == "llbit":
             negBit = 63
-        if flagType == "overflow":
+        if flagType == "saturate":
             regCcCode = calcQCode
         else:
             regCcCode = calcCcCode % {
@@ -162,18 +164,20 @@ let {{
             decoder_output += DataRegConstructor.subst(iop)
             exec_output += PredOpExecute.subst(iop)
 
-        subst(regIop)
+        if buildNonCc:
+            subst(regIop)
         if buildCc:
             subst(regIopCc)
 
     def buildRegRegDataInst(mnem, code, flagType = "logic", \
-                            suffix = "RegReg", buildCc = True):
+                            suffix = "RegReg", \
+                            buildCc = True, buildNonCc = True):
         cCode = carryCode[flagType]
         vCode = overflowCode[flagType]
         negBit = 31
         if flagType == "llbit":
             negBit = 63
-        if flagType == "overflow":
+        if flagType == "saturate":
             regRegCcCode = calcQCode
         else:
             regRegCcCode = calcCcCode % {
@@ -198,7 +202,8 @@ let {{
             decoder_output += DataRegRegConstructor.subst(iop)
             exec_output += PredOpExecute.subst(iop)
 
-        subst(regRegIop)
+        if buildNonCc:
+            subst(regRegIop)
         if buildCc:
             subst(regRegIopCc)
 
@@ -250,4 +255,99 @@ let {{
     buildDataInst("movt",
                   "Dest = resTemp = insertBits(Op1, 31, 16, secondOp);",
                   aiw = False)
+
+    buildRegDataInst("qadd", '''
+            int32_t midRes;
+            resTemp = saturateOp<32>(midRes, Op1.sw, Op2.sw);
+                                     Dest = midRes;
+        ''', flagType="saturate", buildNonCc=False)
+    buildRegDataInst("qadd16", '''
+            int32_t midRes;
+            for (unsigned i = 0; i < 2; i++) {
+                int high = (i + 1) * 16 - 1;
+                int low = i * 16;
+                int64_t arg1 = sext<16>(bits(Op1.sw, high, low));
+                int64_t arg2 = sext<16>(bits(Op2.sw, high, low));
+                saturateOp<16>(midRes, arg1, arg2);
+                replaceBits(resTemp, high, low, midRes);
+            }
+            Dest = resTemp;
+        ''', flagType="none", buildCc=False)
+    buildRegDataInst("qadd8", '''
+            int32_t midRes;
+            for (unsigned i = 0; i < 4; i++) {
+                int high = (i + 1) * 8 - 1;
+                int low = i * 8;
+                int64_t arg1 = sext<8>(bits(Op1.sw, high, low));
+                int64_t arg2 = sext<8>(bits(Op2.sw, high, low));
+                saturateOp<8>(midRes, arg1, arg2);
+                replaceBits(resTemp, high, low, midRes);
+            }
+            Dest = resTemp;
+        ''', flagType="none", buildCc=False)
+    buildRegDataInst("qdadd", '''
+            int32_t midRes;
+            resTemp = saturateOp<32>(midRes, Op2.sw, Op2.sw) |
+                      saturateOp<32>(midRes, Op1.sw, midRes);
+            Dest = midRes;
+        ''', flagType="saturate", buildNonCc=False)
+    buildRegDataInst("qsub", '''
+            int32_t midRes;
+            resTemp = saturateOp<32>(midRes, Op1.sw, Op2.sw, true);
+            Dest = midRes;
+        ''', flagType="saturate")
+    buildRegDataInst("qsub16", '''
+            int32_t midRes;
+            for (unsigned i = 0; i < 2; i++) {
+                 int high = (i + 1) * 16 - 1;
+                 int low = i * 16;
+                 int64_t arg1 = sext<16>(bits(Op1.sw, high, low));
+                 int64_t arg2 = sext<16>(bits(Op2.sw, high, low));
+                 saturateOp<16>(midRes, arg1, arg2, true);
+                 replaceBits(resTemp, high, low, midRes);
+            }
+            Dest = resTemp;
+        ''', flagType="none", buildCc=False)
+    buildRegDataInst("qsub8", '''
+            int32_t midRes;
+            for (unsigned i = 0; i < 4; i++) {
+                 int high = (i + 1) * 8 - 1;
+                 int low = i * 8;
+                 int64_t arg1 = sext<8>(bits(Op1.sw, high, low));
+                 int64_t arg2 = sext<8>(bits(Op2.sw, high, low));
+                 saturateOp<8>(midRes, arg1, arg2, true);
+                 replaceBits(resTemp, high, low, midRes);
+            }
+            Dest = resTemp;
+        ''', flagType="none", buildCc=False)
+    buildRegDataInst("qdsub", '''
+            int32_t midRes;
+            resTemp = saturateOp<32>(midRes, Op2.sw, Op2.sw) |
+                      saturateOp<32>(midRes, Op1.sw, midRes, true);
+            Dest = midRes;
+        ''', flagType="saturate", buildNonCc=False)
+    buildRegDataInst("qasx", '''
+            int32_t midRes;
+            int64_t arg1Low = sext<16>(bits(Op1.sw, 15, 0));
+            int64_t arg1High = sext<16>(bits(Op1.sw, 31, 16));
+            int64_t arg2Low = sext<16>(bits(Op2.sw, 15, 0));
+            int64_t arg2High = sext<16>(bits(Op2.sw, 31, 16));
+            saturateOp<16>(midRes, arg1Low, arg2High, true);
+            replaceBits(resTemp, 15, 0, midRes);
+            saturateOp<16>(midRes, arg1High, arg2Low);
+            replaceBits(resTemp, 31, 16, midRes);
+            Dest = resTemp;
+        ''', flagType="none", buildCc=False)
+    buildRegDataInst("qsax", '''
+            int32_t midRes;
+            int64_t arg1Low = sext<16>(bits(Op1.sw, 15, 0));
+            int64_t arg1High = sext<16>(bits(Op1.sw, 31, 16));
+            int64_t arg2Low = sext<16>(bits(Op2.sw, 15, 0));
+            int64_t arg2High = sext<16>(bits(Op2.sw, 31, 16));
+            saturateOp<16>(midRes, arg1Low, arg2High);
+            replaceBits(resTemp, 15, 0, midRes);
+            saturateOp<16>(midRes, arg1High, arg2Low, true);
+            replaceBits(resTemp, 31, 16, midRes);
+            Dest = resTemp;
+        ''', flagType="none", buildCc=False)
 }};