gallium/swr: Fix compilation with LLVM 12
[mesa.git] / src / gallium / drivers / swr / rasterizer / jitter / builder_misc.cpp
index 6687ead02d308efb17ad5bfe466c6afea0f73ac3..530752850c611802978fa37942136c7ef14e6b4f 100644 (file)
@@ -108,45 +108,6 @@ namespace SwrJit
         return (uint16_t)tmpVal;
     }
 
-    //////////////////////////////////////////////////////////////////////////
-    /// @brief Convert an IEEE 754 16-bit float to an 32-bit single precision
-    ///        float
-    /// @param val - 16-bit float
-    /// @todo Maybe move this outside of this file into a header?
-    static float ConvertFloat16ToFloat32(uint32_t val)
-    {
-        uint32_t result;
-        if ((val & 0x7fff) == 0)
-        {
-            result = ((uint32_t)(val & 0x8000)) << 16;
-        }
-        else if ((val & 0x7c00) == 0x7c00)
-        {
-            result = ((val & 0x3ff) == 0) ? 0x7f800000 : 0x7fc00000;
-            result |= ((uint32_t)val & 0x8000) << 16;
-        }
-        else
-        {
-            uint32_t sign = (val & 0x8000) << 16;
-            uint32_t mant = (val & 0x3ff) << 13;
-            uint32_t exp  = (val >> 10) & 0x1f;
-            if ((exp == 0) && (mant != 0)) // Adjust exponent and mantissa for denormals
-            {
-                mant <<= 1;
-                while (mant < (0x400 << 13))
-                {
-                    exp--;
-                    mant <<= 1;
-                }
-                mant &= (0x3ff << 13);
-            }
-            exp    = ((exp - 15 + 127) & 0xff) << 23;
-            result = sign | exp | mant;
-        }
-
-        return *(float*)&result;
-    }
-
     Constant* Builder::C(bool i) { return ConstantInt::get(IRB()->getInt1Ty(), (i ? 1 : 0)); }
 
     Constant* Builder::C(char i) { return ConstantInt::get(IRB()->getInt8Ty(), i); }
@@ -172,69 +133,109 @@ namespace SwrJit
 
     Value* Builder::VIMMED1(uint64_t i)
     {
+#if LLVM_VERSION_MAJOR > 10
+        return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantInt>(C(i)));
+#else
         return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));
+#endif
     }
 
     Value* Builder::VIMMED1_16(uint64_t i)
     {
+#if LLVM_VERSION_MAJOR > 10
+        return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantInt>(C(i)));
+#else
         return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));
+#endif
     }
 
     Value* Builder::VIMMED1(int i)
     {
+#if LLVM_VERSION_MAJOR > 10
+        return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantInt>(C(i)));
+#else
         return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));
+#endif
     }
 
     Value* Builder::VIMMED1_16(int i)
     {
+#if LLVM_VERSION_MAJOR > 10
+        return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantInt>(C(i)));
+#else
         return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));
+#endif
     }
 
     Value* Builder::VIMMED1(uint32_t i)
     {
+#if LLVM_VERSION_MAJOR > 10
+        return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantInt>(C(i)));
+#else
         return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));
+#endif
     }
 
     Value* Builder::VIMMED1_16(uint32_t i)
     {
+#if LLVM_VERSION_MAJOR > 10
+        return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantInt>(C(i)));
+#else
         return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));
+#endif
     }
 
     Value* Builder::VIMMED1(float i)
     {
+#if LLVM_VERSION_MAJOR > 10
+        return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantFP>(C(i)));
+#else
         return ConstantVector::getSplat(mVWidth, cast<ConstantFP>(C(i)));
+#endif
     }
 
     Value* Builder::VIMMED1_16(float i)
     {
+#if LLVM_VERSION_MAJOR > 10
+        return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantFP>(C(i)));
+#else
         return ConstantVector::getSplat(mVWidth16, cast<ConstantFP>(C(i)));
+#endif
     }
 
     Value* Builder::VIMMED1(bool i)
     {
+#if LLVM_VERSION_MAJOR > 10
+        return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantInt>(C(i)));
+#else
         return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));
+#endif
     }
 
     Value* Builder::VIMMED1_16(bool i)
     {
+#if LLVM_VERSION_MAJOR > 10
+        return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantInt>(C(i)));
+#else
         return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));
+#endif
     }
 
-    Value* Builder::VUNDEF_IPTR() { return UndefValue::get(VectorType::get(mInt32PtrTy, mVWidth)); }
+    Value* Builder::VUNDEF_IPTR() { return UndefValue::get(getVectorType(mInt32PtrTy, mVWidth)); }
 
-    Value* Builder::VUNDEF(Type* t) { return UndefValue::get(VectorType::get(t, mVWidth)); }
+    Value* Builder::VUNDEF(Type* t) { return UndefValue::get(getVectorType(t, mVWidth)); }
 
-    Value* Builder::VUNDEF_I() { return UndefValue::get(VectorType::get(mInt32Ty, mVWidth)); }
+    Value* Builder::VUNDEF_I() { return UndefValue::get(getVectorType(mInt32Ty, mVWidth)); }
 
-    Value* Builder::VUNDEF_I_16() { return UndefValue::get(VectorType::get(mInt32Ty, mVWidth16)); }
+    Value* Builder::VUNDEF_I_16() { return UndefValue::get(getVectorType(mInt32Ty, mVWidth16)); }
 
-    Value* Builder::VUNDEF_F() { return UndefValue::get(VectorType::get(mFP32Ty, mVWidth)); }
+    Value* Builder::VUNDEF_F() { return UndefValue::get(getVectorType(mFP32Ty, mVWidth)); }
 
-    Value* Builder::VUNDEF_F_16() { return UndefValue::get(VectorType::get(mFP32Ty, mVWidth16)); }
+    Value* Builder::VUNDEF_F_16() { return UndefValue::get(getVectorType(mFP32Ty, mVWidth16)); }
 
     Value* Builder::VUNDEF(Type* ty, uint32_t size)
     {
-        return UndefValue::get(VectorType::get(ty, size));
+        return UndefValue::get(getVectorType(ty, size));
     }
 
     Value* Builder::VBROADCAST(Value* src, const llvm::Twine& name)
@@ -280,14 +281,24 @@ namespace SwrJit
         std::vector<Value*> args;
         for (auto arg : argsList)
             args.push_back(arg);
+#if LLVM_VERSION_MAJOR >= 11
+        // see comment to CALLA(Callee) function in the header
+        return CALLA(FunctionCallee(cast<Function>(Callee)), args, name);
+#else
         return CALLA(Callee, args, name);
+#endif
     }
 
     CallInst* Builder::CALL(Value* Callee, Value* arg)
     {
         std::vector<Value*> args;
         args.push_back(arg);
+#if LLVM_VERSION_MAJOR >= 11
+        // see comment to CALLA(Callee) function in the header
+        return CALLA(FunctionCallee(cast<Function>(Callee)), args);
+#else
         return CALLA(Callee, args);
+#endif
     }
 
     CallInst* Builder::CALL2(Value* Callee, Value* arg1, Value* arg2)
@@ -295,7 +306,12 @@ namespace SwrJit
         std::vector<Value*> args;
         args.push_back(arg1);
         args.push_back(arg2);
+#if LLVM_VERSION_MAJOR >= 11
+        // see comment to CALLA(Callee) function in the header
+        return CALLA(FunctionCallee(cast<Function>(Callee)), args);
+#else
         return CALLA(Callee, args);
+#endif
     }
 
     CallInst* Builder::CALL3(Value* Callee, Value* arg1, Value* arg2, Value* arg3)
@@ -304,7 +320,12 @@ namespace SwrJit
         args.push_back(arg1);
         args.push_back(arg2);
         args.push_back(arg3);
+#if LLVM_VERSION_MAJOR >= 11
+        // see comment to CALLA(Callee) function in the header
+        return CALLA(FunctionCallee(cast<Function>(Callee)), args);
+#else
         return CALLA(Callee, args);
+#endif
     }
 
     Value* Builder::VRCP(Value* va, const llvm::Twine& name)
@@ -351,7 +372,9 @@ namespace SwrJit
             if (pType->isVectorTy())
             {
                 Type* pContainedType = pType->getContainedType(0);
-
+#if LLVM_VERSION_MAJOR >= 11
+                VectorType* pVectorType = cast<VectorType>(pType);
+#endif
                 if (toupper(tempStr[pos + 1]) == 'X')
                 {
                     tempStr[pos]     = '0';
@@ -362,7 +385,11 @@ namespace SwrJit
                     printCallArgs.push_back(VEXTRACT(pArg, C(0)));
 
                     std::string vectorFormatStr;
+#if LLVM_VERSION_MAJOR >= 11
+                    for (uint32_t i = 1; i < pVectorType->getNumElements(); ++i)
+#else
                     for (uint32_t i = 1; i < pType->getVectorNumElements(); ++i)
+#endif
                     {
                         vectorFormatStr += "0x%08X ";
                         printCallArgs.push_back(VEXTRACT(pArg, C(i)));
@@ -374,7 +401,11 @@ namespace SwrJit
                 else if ((tempStr[pos + 1] == 'f') && (pContainedType->isFloatTy()))
                 {
                     uint32_t i = 0;
-                    for (; i < (pArg->getType()->getVectorNumElements()) - 1; i++)
+#if LLVM_VERSION_MAJOR >= 11
+                    for (; i < pVectorType->getNumElements() - 1; i++)
+#else
+                    for (; i < pType->getVectorNumElements() - 1; i++)
+#endif
                     {
                         tempStr.insert(pos, std::string("%f "));
                         pos += 3;
@@ -387,7 +418,11 @@ namespace SwrJit
                 else if ((tempStr[pos + 1] == 'd') && (pContainedType->isIntegerTy()))
                 {
                     uint32_t i = 0;
-                    for (; i < (pArg->getType()->getVectorNumElements()) - 1; i++)
+#if LLVM_VERSION_MAJOR >= 11
+                    for (; i < pVectorType->getNumElements() - 1; i++)
+#else
+                    for (; i < pType->getVectorNumElements() - 1; i++)
+#endif
                     {
                         tempStr.insert(pos, std::string("%d "));
                         pos += 3;
@@ -400,7 +435,11 @@ namespace SwrJit
                 else if ((tempStr[pos + 1] == 'u') && (pContainedType->isIntegerTy()))
                 {
                     uint32_t i = 0;
-                    for (; i < (pArg->getType()->getVectorNumElements()) - 1; i++)
+#if LLVM_VERSION_MAJOR >= 11
+                    for (; i < pVectorType->getNumElements() - 1; i++)
+#else
+                    for (; i < pType->getVectorNumElements() - 1; i++)
+#endif
                     {
                         tempStr.insert(pos, std::string("%d "));
                         pos += 3;
@@ -455,7 +494,11 @@ namespace SwrJit
         args.push_back(PointerType::get(mInt8Ty, 0));
         FunctionType* callPrintTy = FunctionType::get(Type::getVoidTy(JM()->mContext), args, true);
         Function*     callPrintFn =
+#if LLVM_VERSION_MAJOR >= 9
+            cast<Function>(JM()->mpCurrentModule->getOrInsertFunction("CallPrint", callPrintTy).getCallee());
+#else
             cast<Function>(JM()->mpCurrentModule->getOrInsertFunction("CallPrint", callPrintTy));
+#endif
 
         // if we haven't yet added the symbol to the symbol table
         if ((sys::DynamicLibrary::SearchForAddressOfSymbol("CallPrint")) == nullptr)
@@ -511,8 +554,14 @@ namespace SwrJit
     /// @brief Convert <Nxi1> llvm mask to integer
     Value* Builder::VMOVMSK(Value* mask)
     {
+#if LLVM_VERSION_MAJOR >= 11
+        VectorType* pVectorType = cast<VectorType>(mask->getType());
+        SWR_ASSERT(pVectorType->getElementType() == mInt1Ty);
+        uint32_t numLanes = pVectorType->getNumElements();
+#else
         SWR_ASSERT(mask->getType()->getVectorElementType() == mInt1Ty);
         uint32_t numLanes = mask->getType()->getVectorNumElements();
+#endif
         Value*   i32Result;
         if (numLanes == 8)
         {
@@ -549,10 +598,11 @@ namespace SwrJit
         else
         {
             Constant* cB = dyn_cast<Constant>(b);
+            assert(cB != nullptr);
             // number of 8 bit elements in b
             uint32_t numElms = cast<VectorType>(cB->getType())->getNumElements();
             // output vector
-            Value* vShuf = UndefValue::get(VectorType::get(mInt8Ty, numElms));
+            Value* vShuf = UndefValue::get(getVectorType(mInt8Ty, numElms));
 
             // insert an 8 bit value from the high and low lanes of a per loop iteration
             numElms /= 2;
@@ -593,7 +643,7 @@ namespace SwrJit
     Value* Builder::PMOVSXBD(Value* a)
     {
         // VPMOVSXBD output type
-        Type* v8x32Ty = VectorType::get(mInt32Ty, 8);
+        Type* v8x32Ty = getVectorType(mInt32Ty, 8);
         // Extract 8 values from 128bit lane and sign extend
         return S_EXT(VSHUFFLE(a, a, C<int>({0, 1, 2, 3, 4, 5, 6, 7})), v8x32Ty);
     }
@@ -605,7 +655,7 @@ namespace SwrJit
     Value* Builder::PMOVSXWD(Value* a)
     {
         // VPMOVSXWD output type
-        Type* v8x32Ty = VectorType::get(mInt32Ty, 8);
+        Type* v8x32Ty = getVectorType(mInt32Ty, 8);
         // Extract 8 values from 128bit lane and sign extend
         return S_EXT(VSHUFFLE(a, a, C<int>({0, 1, 2, 3, 4, 5, 6, 7})), v8x32Ty);
     }
@@ -616,33 +666,15 @@ namespace SwrJit
     /// @param a - 128bit SIMD lane(8x16bit) of float16 in int16 format.
     Value* Builder::CVTPH2PS(Value* a, const llvm::Twine& name)
     {
-        if (JM()->mArch.F16C())
-        {
-            return VCVTPH2PS(a, name);
-        }
-        else
-        {
-            FunctionType* pFuncTy   = FunctionType::get(mFP32Ty, mInt16Ty);
-            Function*     pCvtPh2Ps = cast<Function>(
-                JM()->mpCurrentModule->getOrInsertFunction("ConvertFloat16ToFloat32", pFuncTy));
-
-            if (sys::DynamicLibrary::SearchForAddressOfSymbol("ConvertFloat16ToFloat32") == nullptr)
-            {
-                sys::DynamicLibrary::AddSymbol("ConvertFloat16ToFloat32",
-                                               (void*)&ConvertFloat16ToFloat32);
-            }
-
-            Value* pResult = UndefValue::get(mSimdFP32Ty);
-            for (uint32_t i = 0; i < mVWidth; ++i)
-            {
-                Value* pSrc  = VEXTRACT(a, C(i));
-                Value* pConv = CALL(pCvtPh2Ps, std::initializer_list<Value*>{pSrc});
-                pResult      = VINSERT(pResult, pConv, C(i));
-            }
+        // Bitcast Nxint16 to Nxhalf
+#if LLVM_VERSION_MAJOR >= 11
+        uint32_t numElems = cast<VectorType>(a->getType())->getNumElements();
+#else
+        uint32_t numElems = a->getType()->getVectorNumElements();
+#endif
+        Value*   input    = BITCAST(a, getVectorType(mFP16Ty, numElems));
 
-            pResult->setName(name);
-            return pResult;
-        }
+        return FP_EXT(input, getVectorType(mFP32Ty, numElems), name);
     }
 
     //////////////////////////////////////////////////////////////////////////
@@ -660,7 +692,11 @@ namespace SwrJit
             // call scalar C function for now
             FunctionType* pFuncTy   = FunctionType::get(mInt16Ty, mFP32Ty);
             Function*     pCvtPs2Ph = cast<Function>(
+#if LLVM_VERSION_MAJOR >= 9
+                JM()->mpCurrentModule->getOrInsertFunction("ConvertFloat32ToFloat16", pFuncTy).getCallee());
+#else
                 JM()->mpCurrentModule->getOrInsertFunction("ConvertFloat32ToFloat16", pFuncTy));
+#endif
 
             if (sys::DynamicLibrary::SearchForAddressOfSymbol("ConvertFloat32ToFloat16") == nullptr)
             {
@@ -969,7 +1005,11 @@ namespace SwrJit
 
             FunctionType* pFuncTy = FunctionType::get(Type::getVoidTy(JM()->mContext), args, false);
             Function*     pFunc   = cast<Function>(
+#if LLVM_VERSION_MAJOR >= 9
+                JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StartBucket", pFuncTy).getCallee());
+#else
                 JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StartBucket", pFuncTy));
+#endif
             if (sys::DynamicLibrary::SearchForAddressOfSymbol("BucketManager_StartBucket") ==
                 nullptr)
             {
@@ -994,7 +1034,11 @@ namespace SwrJit
 
             FunctionType* pFuncTy = FunctionType::get(Type::getVoidTy(JM()->mContext), args, false);
             Function*     pFunc   = cast<Function>(
+#if LLVM_VERSION_MAJOR >= 9
+                JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StopBucket", pFuncTy).getCallee());
+#else
                 JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StopBucket", pFuncTy));
+#endif
             if (sys::DynamicLibrary::SearchForAddressOfSymbol("BucketManager_StopBucket") ==
                 nullptr)
             {