swr/rast: Refactor to improve code sharing.
authorGeorge Kyriazis <george.kyriazis@intel.com>
Mon, 9 Apr 2018 18:35:43 +0000 (13:35 -0500)
committerGeorge Kyriazis <george.kyriazis@intel.com>
Wed, 18 Apr 2018 15:51:38 +0000 (10:51 -0500)
Reviewed-by: Bruce Cherniak <bruce.cherniak@intel.com>
src/gallium/drivers/swr/rasterizer/jitter/fetch_jit.cpp

index 767866f68b1329726f42170fb0665eb3564fb584..af97b83cb2d6e3fb3261a8b3b40c346cd83c21c3 100644 (file)
@@ -63,6 +63,7 @@ struct FetchJit : public BuilderGfxMem
     Value* GetSimdValid32bitIndices(Value* vIndices, Value* pLastIndex);
     Value* GetSimdValid16bitIndices(Value* vIndices, Value* pLastIndex);
     Value* GetSimdValid8bitIndices(Value* vIndices, Value* pLastIndex);
+    template<typename T> Value* GetSimdValidIndicesHelper(Value* pIndices, Value* pLastIndex);
 
     // package up Shuffle*bpcGatherd args into a tuple for convenience
     typedef std::tuple<Value*&, Value*, const Instruction::CastOps, const ConversionType,
@@ -985,37 +986,48 @@ typedef void*(*PFN_TRANSLATEGFXADDRESS_FUNC)(void* pdc, gfxptr_t va);
 extern "C" void GetSimdValid8bitIndicesGfx(gfxptr_t indices, gfxptr_t lastIndex, uint32_t vWidth, PFN_TRANSLATEGFXADDRESS_FUNC pfnTranslate, void* pdc, uint32_t* outIndices);
 extern "C" void GetSimdValid16bitIndicesGfx(gfxptr_t indices, gfxptr_t lastIndex, uint32_t vWidth, PFN_TRANSLATEGFXADDRESS_FUNC pfnTranslate, void* pdc, uint32_t* outIndices);
 
-//////////////////////////////////////////////////////////////////////////
-/// @brief Loads a simd of valid indices. OOB indices are set to 0
-/// *Note* have to do 8bit index checking in scalar until we have AVX-512
-/// support
-/// @param pIndices - pointer to 8 bit indices
-/// @param pLastIndex - pointer to last valid index
-Value* FetchJit::GetSimdValid8bitIndices(Value* pIndices, Value* pLastIndex)
+template<typename T> Value* FetchJit::GetSimdValidIndicesHelper(Value* pIndices, Value* pLastIndex)
 {
     SWR_ASSERT(pIndices->getType() == mInt64Ty && pLastIndex->getType() == mInt64Ty, "Function expects gfxptr_t for both input parameters.");
 
+    Type* Ty = nullptr;
+
+    static_assert(sizeof(T) == sizeof(uint16_t) || sizeof(T) == sizeof(uint8_t), "Unsupported type for use with GetSimdValidIndicesHelper<T>");
+    constexpr bool bSize = (sizeof(T) == sizeof(uint16_t));
+    if (bSize)
+    {
+        Ty = mInt16PtrTy;
+    }
+    else if (sizeof(T) == sizeof(uint8_t))
+    {
+        Ty = mInt8PtrTy;
+    }
+    else
+    {
+        SWR_ASSERT(false, "This should never happen as per static_assert above.");
+    }
+
     Value* vIndices = VUNDEF_I();
 
     {
         // store 0 index on stack to be used to conditionally load from if index address is OOB
-        Value* pZeroIndex = ALLOCA(mInt8Ty);
-        STORE(C((uint8_t)0), pZeroIndex);
+        Value* pZeroIndex = ALLOCA(Ty);
+        STORE(C((T)0), pZeroIndex);
 
         // Load a SIMD of index pointers
         for (int64_t lane = 0; lane < mVWidth; lane++)
         {
             // Calculate the address of the requested index
-            Value *pIndex = GEP(pIndices, C(lane), mInt8PtrTy);
+            Value *pIndex = GEP(pIndices, C(lane), Ty);
 
-            pLastIndex = INT_TO_PTR(pLastIndex, mInt8PtrTy);
+            pLastIndex = INT_TO_PTR(pLastIndex, Ty);
 
             // check if the address is less than the max index, 
             Value* mask = ICMP_ULT(pIndex, pLastIndex);
 
             // if valid, load the index. if not, load 0 from the stack
             Value* pValid = SELECT(mask, pIndex, pZeroIndex);
-            Value *index = LOAD(pValid, "valid index", PointerType::get(mInt8Ty, 0), GFX_MEM_CLIENT_FETCH);
+            Value *index = LOAD(pValid, "valid index", Ty, GFX_MEM_CLIENT_FETCH);
 
             // zero extended index to 32 bits and insert into the correct simd lane
             index = Z_EXT(index, mInt32Ty);
@@ -1026,6 +1038,17 @@ Value* FetchJit::GetSimdValid8bitIndices(Value* pIndices, Value* pLastIndex)
     return vIndices;
 }
 
+//////////////////////////////////////////////////////////////////////////
+/// @brief Loads a simd of valid indices. OOB indices are set to 0
+/// *Note* have to do 8bit index checking in scalar until we have AVX-512
+/// support
+/// @param pIndices - pointer to 8 bit indices
+/// @param pLastIndex - pointer to last valid index
+Value* FetchJit::GetSimdValid8bitIndices(Value* pIndices, Value* pLastIndex)
+{
+    return GetSimdValidIndicesHelper<uint8_t>(pIndices, pLastIndex);
+}
+
 //////////////////////////////////////////////////////////////////////////
 /// @brief Loads a simd of valid indices. OOB indices are set to 0
 /// *Note* have to do 16bit index checking in scalar until we have AVX-512
@@ -1034,37 +1057,7 @@ Value* FetchJit::GetSimdValid8bitIndices(Value* pIndices, Value* pLastIndex)
 /// @param pLastIndex - pointer to last valid index
 Value* FetchJit::GetSimdValid16bitIndices(Value* pIndices, Value* pLastIndex)
 {
-    SWR_ASSERT(pIndices->getType() == mInt64Ty && pLastIndex->getType() == mInt64Ty, "Function expects gfxptr_t for both input parameters.");
-
-    Value* vIndices = VUNDEF_I();
-
-    {
-        // store 0 index on stack to be used to conditionally load from if index address is OOB
-        Value* pZeroIndex = ALLOCA(mInt16Ty);
-        STORE(C((uint16_t)0), pZeroIndex);
-
-        // Load a SIMD of index pointers
-        for (int64_t lane = 0; lane < mVWidth; lane++)
-        {
-            // Calculate the address of the requested index
-            Value *pIndex = GEP(pIndices, C(lane), mInt16PtrTy);
-
-            pLastIndex = INT_TO_PTR(pLastIndex, mInt16PtrTy);
-
-            // check if the address is less than the max index, 
-            Value* mask = ICMP_ULT(pIndex, pLastIndex);
-
-            // if valid, load the index. if not, load 0 from the stack
-            Value* pValid = SELECT(mask, pIndex, pZeroIndex);
-            Value *index = LOAD(pValid, "valid index", PointerType::get(mInt16Ty, 0), GFX_MEM_CLIENT_FETCH);
-
-            // zero extended index to 32 bits and insert into the correct simd lane
-            index = Z_EXT(index, mInt32Ty);
-            vIndices = VINSERT(vIndices, index, lane);
-        }
-    }
-
-    return vIndices;
+    return GetSimdValidIndicesHelper<uint16_t>(pIndices, pLastIndex);
 }
 
 //////////////////////////////////////////////////////////////////////////