1 /****************************************************************************
2 * Copyright (C) 2014-2015 Intel Corporation. All Rights Reserved.
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
23 * @file builder_misc.cpp
25 * @brief Implementation for miscellaneous builder functions
29 ******************************************************************************/
#include "jit_pch.hpp"
#include "common/rdtsc_buckets.h"

#include <cstring>
// Host-side printf-style sink. Builder::PRINT registers its address as the
// "CallPrint" symbol so JIT-generated code can call it.
extern "C" void CallPrint(const char* fmt, ...);
//////////////////////////////////////////////////////////////////////////
/// @brief Convert an IEEE 754 32-bit single precision float to an
///        IEEE 754 16-bit half precision float (1 sign bit, 5 exponent
///        bits, 10 mantissa bits), returned as its raw bit pattern.
/// @param val - 32-bit float
/// @return fp16 bit pattern
/// Special cases:
///  - NaN            -> 0xFE00 (sign bit forced on, quiet mantissa 0x200)
///  - +/-Inf         -> +/-Inf (0x7C00 / 0xFC00)
///  - |val| >= 2^16  -> +/- largest finite half (65504: 0x7BFF / 0xFBFF)
///  - |val| <  2^-25 -> +/- zero
///  - values in the fp16 denorm range are encoded as (truncated) denorms
/// Normal values are truncated toward zero, except that a value whose 13
/// discarded mantissa bits are all ones is rounded up by one LSB.
/// @todo Maybe move this outside of this file into a header?
static uint16_t ConvertFloat32ToFloat16(float val)
{
    // Extract the sign, exponent, and mantissa. Copy the bits out with memcpy:
    // reading a float through a uint32_t* violates strict aliasing (UB).
    uint32_t bits;
    std::memcpy(&bits, &val, sizeof(bits));
    uint32_t sign     = (bits & 0x80000000) >> 31;
    uint32_t exponent = (bits & 0x7F800000) >> 23; // biased fp32 exponent
    uint32_t mantissa = bits & 0x007FFFFF;

    // 0x70 (112) is the fp32 exponent bias (127) minus the fp16 bias (15).
    if (std::isnan(val))
    {
        exponent = 0x1F;
        mantissa = 0x200;
        sign     = 1; // set the sign bit for NANs
    }
    else if (std::isinf(val))
    {
        exponent = 0x1F;
        mantissa = 0x0;
    }
    else if (exponent > (0x70 + 0x1E)) // Too big to represent -> max representable value
    {
        exponent = 0x1E;
        mantissa = 0x3FF;
    }
    else if ((exponent <= 0x70) && (exponent >= 0x66)) // It's a denorm
    {
        // Restore the implicit leading one, then shift right once for every
        // step the exponent sits at or below the fp16 denorm threshold.
        mantissa |= 0x00800000;
        for (; exponent <= 0x70; mantissa >>= 1, exponent++)
        {
        }
        exponent = 0;
        mantissa = mantissa >> 13;
    }
    else if (exponent < 0x66) // Too small to represent -> Zero
    {
        exponent = 0;
        mantissa = 0;
    }
    else
    {
        // Saves bits that will be shifted off for rounding
        const uint32_t roundBits = mantissa & 0x1FFFu;
        // convert exponent and mantissa to 16 bit format
        exponent = exponent - 0x70;
        mantissa = mantissa >> 13;

        // Essentially RTZ, but round up if off by only 1 lsb
        if (roundBits == 0x1FFFu)
        {
            mantissa++;
            // a carry out of the 10-bit mantissa bumps the exponent
            if ((mantissa & 0xC00u) != 0)
            {
                exponent++;
            }
            // make sure only the needed bits are used
            mantissa &= 0x3FF;
        }
    }

    const uint32_t halfBits = (sign << 15) | (exponent << 10) | mantissa;
    return static_cast<uint16_t>(halfBits);
}
111 Constant
* Builder::C(bool i
) { return ConstantInt::get(IRB()->getInt1Ty(), (i
? 1 : 0)); }
113 Constant
* Builder::C(char i
) { return ConstantInt::get(IRB()->getInt8Ty(), i
); }
115 Constant
* Builder::C(uint8_t i
) { return ConstantInt::get(IRB()->getInt8Ty(), i
); }
117 Constant
* Builder::C(int i
) { return ConstantInt::get(IRB()->getInt32Ty(), i
); }
119 Constant
* Builder::C(int64_t i
) { return ConstantInt::get(IRB()->getInt64Ty(), i
); }
121 Constant
* Builder::C(uint16_t i
) { return ConstantInt::get(mInt16Ty
, i
); }
123 Constant
* Builder::C(uint32_t i
) { return ConstantInt::get(IRB()->getInt32Ty(), i
); }
125 Constant
* Builder::C(uint64_t i
) { return ConstantInt::get(IRB()->getInt64Ty(), i
); }
127 Constant
* Builder::C(float i
) { return ConstantFP::get(IRB()->getFloatTy(), i
); }
129 Constant
* Builder::PRED(bool pred
)
131 return ConstantInt::get(IRB()->getInt1Ty(), (pred
? 1 : 0));
134 Value
* Builder::VIMMED1(uint64_t i
)
136 #if LLVM_VERSION_MAJOR > 10
137 return ConstantVector::getSplat(ElementCount::get(mVWidth
, false), cast
<ConstantInt
>(C(i
)));
139 return ConstantVector::getSplat(mVWidth
, cast
<ConstantInt
>(C(i
)));
143 Value
* Builder::VIMMED1_16(uint64_t i
)
145 #if LLVM_VERSION_MAJOR > 10
146 return ConstantVector::getSplat(ElementCount::get(mVWidth16
, false), cast
<ConstantInt
>(C(i
)));
148 return ConstantVector::getSplat(mVWidth16
, cast
<ConstantInt
>(C(i
)));
152 Value
* Builder::VIMMED1(int i
)
154 #if LLVM_VERSION_MAJOR > 10
155 return ConstantVector::getSplat(ElementCount::get(mVWidth
, false), cast
<ConstantInt
>(C(i
)));
157 return ConstantVector::getSplat(mVWidth
, cast
<ConstantInt
>(C(i
)));
161 Value
* Builder::VIMMED1_16(int i
)
163 #if LLVM_VERSION_MAJOR > 10
164 return ConstantVector::getSplat(ElementCount::get(mVWidth16
, false), cast
<ConstantInt
>(C(i
)));
166 return ConstantVector::getSplat(mVWidth16
, cast
<ConstantInt
>(C(i
)));
170 Value
* Builder::VIMMED1(uint32_t i
)
172 #if LLVM_VERSION_MAJOR > 10
173 return ConstantVector::getSplat(ElementCount::get(mVWidth
, false), cast
<ConstantInt
>(C(i
)));
175 return ConstantVector::getSplat(mVWidth
, cast
<ConstantInt
>(C(i
)));
179 Value
* Builder::VIMMED1_16(uint32_t i
)
181 #if LLVM_VERSION_MAJOR > 10
182 return ConstantVector::getSplat(ElementCount::get(mVWidth16
, false), cast
<ConstantInt
>(C(i
)));
184 return ConstantVector::getSplat(mVWidth16
, cast
<ConstantInt
>(C(i
)));
188 Value
* Builder::VIMMED1(float i
)
190 #if LLVM_VERSION_MAJOR > 10
191 return ConstantVector::getSplat(ElementCount::get(mVWidth
, false), cast
<ConstantFP
>(C(i
)));
193 return ConstantVector::getSplat(mVWidth
, cast
<ConstantFP
>(C(i
)));
197 Value
* Builder::VIMMED1_16(float i
)
199 #if LLVM_VERSION_MAJOR > 10
200 return ConstantVector::getSplat(ElementCount::get(mVWidth16
, false), cast
<ConstantFP
>(C(i
)));
202 return ConstantVector::getSplat(mVWidth16
, cast
<ConstantFP
>(C(i
)));
206 Value
* Builder::VIMMED1(bool i
)
208 #if LLVM_VERSION_MAJOR > 10
209 return ConstantVector::getSplat(ElementCount::get(mVWidth
, false), cast
<ConstantInt
>(C(i
)));
211 return ConstantVector::getSplat(mVWidth
, cast
<ConstantInt
>(C(i
)));
215 Value
* Builder::VIMMED1_16(bool i
)
217 #if LLVM_VERSION_MAJOR > 10
218 return ConstantVector::getSplat(ElementCount::get(mVWidth16
, false), cast
<ConstantInt
>(C(i
)));
220 return ConstantVector::getSplat(mVWidth16
, cast
<ConstantInt
>(C(i
)));
224 Value
* Builder::VUNDEF_IPTR() { return UndefValue::get(getVectorType(mInt32PtrTy
, mVWidth
)); }
226 Value
* Builder::VUNDEF(Type
* t
) { return UndefValue::get(getVectorType(t
, mVWidth
)); }
228 Value
* Builder::VUNDEF_I() { return UndefValue::get(getVectorType(mInt32Ty
, mVWidth
)); }
230 Value
* Builder::VUNDEF_I_16() { return UndefValue::get(getVectorType(mInt32Ty
, mVWidth16
)); }
232 Value
* Builder::VUNDEF_F() { return UndefValue::get(getVectorType(mFP32Ty
, mVWidth
)); }
234 Value
* Builder::VUNDEF_F_16() { return UndefValue::get(getVectorType(mFP32Ty
, mVWidth16
)); }
236 Value
* Builder::VUNDEF(Type
* ty
, uint32_t size
)
238 return UndefValue::get(getVectorType(ty
, size
));
241 Value
* Builder::VBROADCAST(Value
* src
, const llvm::Twine
& name
)
243 // check if src is already a vector
244 if (src
->getType()->isVectorTy())
249 return VECTOR_SPLAT(mVWidth
, src
, name
);
252 Value
* Builder::VBROADCAST_16(Value
* src
)
254 // check if src is already a vector
255 if (src
->getType()->isVectorTy())
260 return VECTOR_SPLAT(mVWidth16
, src
);
263 uint32_t Builder::IMMED(Value
* v
)
265 SWR_ASSERT(isa
<ConstantInt
>(v
));
266 ConstantInt
* pValConst
= cast
<ConstantInt
>(v
);
267 return pValConst
->getZExtValue();
270 int32_t Builder::S_IMMED(Value
* v
)
272 SWR_ASSERT(isa
<ConstantInt
>(v
));
273 ConstantInt
* pValConst
= cast
<ConstantInt
>(v
);
274 return pValConst
->getSExtValue();
277 CallInst
* Builder::CALL(Value
* Callee
,
278 const std::initializer_list
<Value
*>& argsList
,
279 const llvm::Twine
& name
)
281 std::vector
<Value
*> args
;
282 for (auto arg
: argsList
)
284 #if LLVM_VERSION_MAJOR >= 11
285 // see comment to CALLA(Callee) function in the header
286 return CALLA(FunctionCallee(cast
<Function
>(Callee
)), args
, name
);
288 return CALLA(Callee
, args
, name
);
292 CallInst
* Builder::CALL(Value
* Callee
, Value
* arg
)
294 std::vector
<Value
*> args
;
296 #if LLVM_VERSION_MAJOR >= 11
297 // see comment to CALLA(Callee) function in the header
298 return CALLA(FunctionCallee(cast
<Function
>(Callee
)), args
);
300 return CALLA(Callee
, args
);
304 CallInst
* Builder::CALL2(Value
* Callee
, Value
* arg1
, Value
* arg2
)
306 std::vector
<Value
*> args
;
307 args
.push_back(arg1
);
308 args
.push_back(arg2
);
309 #if LLVM_VERSION_MAJOR >= 11
310 // see comment to CALLA(Callee) function in the header
311 return CALLA(FunctionCallee(cast
<Function
>(Callee
)), args
);
313 return CALLA(Callee
, args
);
317 CallInst
* Builder::CALL3(Value
* Callee
, Value
* arg1
, Value
* arg2
, Value
* arg3
)
319 std::vector
<Value
*> args
;
320 args
.push_back(arg1
);
321 args
.push_back(arg2
);
322 args
.push_back(arg3
);
323 #if LLVM_VERSION_MAJOR >= 11
324 // see comment to CALLA(Callee) function in the header
325 return CALLA(FunctionCallee(cast
<Function
>(Callee
)), args
);
327 return CALLA(Callee
, args
);
331 Value
* Builder::VRCP(Value
* va
, const llvm::Twine
& name
)
333 return FDIV(VIMMED1(1.0f
), va
, name
); // 1 / a
336 Value
* Builder::VPLANEPS(Value
* vA
, Value
* vB
, Value
* vC
, Value
*& vX
, Value
*& vY
)
338 Value
* vOut
= FMADDPS(vA
, vX
, vC
);
339 vOut
= FMADDPS(vB
, vY
, vOut
);
343 //////////////////////////////////////////////////////////////////////////
344 /// @brief insert a JIT call to CallPrint
345 /// - outputs formatted string to both stdout and VS output window
346 /// - DEBUG builds only
348 /// PRINT("index %d = 0x%p\n",{C(lane), pIndex});
349 /// where C(lane) creates a constant value to print, and pIndex is the Value*
350 /// result from a GEP, printing out the pointer to memory
351 /// @param printStr - constant string to print, which includes format specifiers
352 /// @param printArgs - initializer list of Value*'s to print to std out
353 CallInst
* Builder::PRINT(const std::string
& printStr
,
354 const std::initializer_list
<Value
*>& printArgs
)
356 // push the arguments to CallPrint into a vector
357 std::vector
<Value
*> printCallArgs
;
358 // save room for the format string. we still need to modify it for vectors
359 printCallArgs
.resize(1);
361 // search through the format string for special processing
363 std::string
tempStr(printStr
);
364 pos
= tempStr
.find('%', pos
);
365 auto v
= printArgs
.begin();
367 while ((pos
!= std::string::npos
) && (v
!= printArgs
.end()))
370 Type
* pType
= pArg
->getType();
372 if (pType
->isVectorTy())
374 Type
* pContainedType
= pType
->getContainedType(0);
375 #if LLVM_VERSION_MAJOR >= 11
376 VectorType
* pVectorType
= cast
<VectorType
>(pType
);
378 if (toupper(tempStr
[pos
+ 1]) == 'X')
381 tempStr
[pos
+ 1] = 'x';
382 tempStr
.insert(pos
+ 2, "%08X ");
385 printCallArgs
.push_back(VEXTRACT(pArg
, C(0)));
387 std::string vectorFormatStr
;
388 #if LLVM_VERSION_MAJOR >= 11
389 for (uint32_t i
= 1; i
< pVectorType
->getNumElements(); ++i
)
391 for (uint32_t i
= 1; i
< pType
->getVectorNumElements(); ++i
)
394 vectorFormatStr
+= "0x%08X ";
395 printCallArgs
.push_back(VEXTRACT(pArg
, C(i
)));
398 tempStr
.insert(pos
, vectorFormatStr
);
399 pos
+= vectorFormatStr
.size();
401 else if ((tempStr
[pos
+ 1] == 'f') && (pContainedType
->isFloatTy()))
404 #if LLVM_VERSION_MAJOR >= 11
405 for (; i
< pVectorType
->getNumElements() - 1; i
++)
407 for (; i
< pType
->getVectorNumElements() - 1; i
++)
410 tempStr
.insert(pos
, std::string("%f "));
412 printCallArgs
.push_back(
413 FP_EXT(VEXTRACT(pArg
, C(i
)), Type::getDoubleTy(JM()->mContext
)));
415 printCallArgs
.push_back(
416 FP_EXT(VEXTRACT(pArg
, C(i
)), Type::getDoubleTy(JM()->mContext
)));
418 else if ((tempStr
[pos
+ 1] == 'd') && (pContainedType
->isIntegerTy()))
421 #if LLVM_VERSION_MAJOR >= 11
422 for (; i
< pVectorType
->getNumElements() - 1; i
++)
424 for (; i
< pType
->getVectorNumElements() - 1; i
++)
427 tempStr
.insert(pos
, std::string("%d "));
429 printCallArgs
.push_back(
430 S_EXT(VEXTRACT(pArg
, C(i
)), Type::getInt32Ty(JM()->mContext
)));
432 printCallArgs
.push_back(
433 S_EXT(VEXTRACT(pArg
, C(i
)), Type::getInt32Ty(JM()->mContext
)));
435 else if ((tempStr
[pos
+ 1] == 'u') && (pContainedType
->isIntegerTy()))
438 #if LLVM_VERSION_MAJOR >= 11
439 for (; i
< pVectorType
->getNumElements() - 1; i
++)
441 for (; i
< pType
->getVectorNumElements() - 1; i
++)
444 tempStr
.insert(pos
, std::string("%d "));
446 printCallArgs
.push_back(
447 Z_EXT(VEXTRACT(pArg
, C(i
)), Type::getInt32Ty(JM()->mContext
)));
449 printCallArgs
.push_back(
450 Z_EXT(VEXTRACT(pArg
, C(i
)), Type::getInt32Ty(JM()->mContext
)));
455 if (toupper(tempStr
[pos
+ 1]) == 'X')
458 tempStr
.insert(pos
+ 1, "x%08");
459 printCallArgs
.push_back(pArg
);
462 // for %f we need to cast float Values to doubles so that they print out correctly
463 else if ((tempStr
[pos
+ 1] == 'f') && (pType
->isFloatTy()))
465 printCallArgs
.push_back(FP_EXT(pArg
, Type::getDoubleTy(JM()->mContext
)));
470 printCallArgs
.push_back(pArg
);
474 // advance to the next arguement
476 pos
= tempStr
.find('%', ++pos
);
479 // create global variable constant string
480 Constant
* constString
= ConstantDataArray::getString(JM()->mContext
, tempStr
, true);
481 GlobalVariable
* gvPtr
= new GlobalVariable(
482 constString
->getType(), true, GlobalValue::InternalLinkage
, constString
, "printStr");
483 JM()->mpCurrentModule
->getGlobalList().push_back(gvPtr
);
485 // get a pointer to the first character in the constant string array
486 std::vector
<Constant
*> geplist
{C(0), C(0)};
487 Constant
* strGEP
= ConstantExpr::getGetElementPtr(nullptr, gvPtr
, geplist
, false);
489 // insert the pointer to the format string in the argument vector
490 printCallArgs
[0] = strGEP
;
492 // get pointer to CallPrint function and insert decl into the module if needed
493 std::vector
<Type
*> args
;
494 args
.push_back(PointerType::get(mInt8Ty
, 0));
495 FunctionType
* callPrintTy
= FunctionType::get(Type::getVoidTy(JM()->mContext
), args
, true);
496 Function
* callPrintFn
=
497 #if LLVM_VERSION_MAJOR >= 9
498 cast
<Function
>(JM()->mpCurrentModule
->getOrInsertFunction("CallPrint", callPrintTy
).getCallee());
500 cast
<Function
>(JM()->mpCurrentModule
->getOrInsertFunction("CallPrint", callPrintTy
));
503 // if we haven't yet added the symbol to the symbol table
504 if ((sys::DynamicLibrary::SearchForAddressOfSymbol("CallPrint")) == nullptr)
506 sys::DynamicLibrary::AddSymbol("CallPrint", (void*)&CallPrint
);
509 // insert a call to CallPrint
510 return CALLA(callPrintFn
, printCallArgs
);
513 //////////////////////////////////////////////////////////////////////////
514 /// @brief Wrapper around PRINT with initializer list.
515 CallInst
* Builder::PRINT(const std::string
& printStr
) { return PRINT(printStr
, {}); }
517 Value
* Builder::EXTRACT_16(Value
* x
, uint32_t imm
)
521 return VSHUFFLE(x
, UndefValue::get(x
->getType()), {0, 1, 2, 3, 4, 5, 6, 7});
525 return VSHUFFLE(x
, UndefValue::get(x
->getType()), {8, 9, 10, 11, 12, 13, 14, 15});
529 Value
* Builder::JOIN_16(Value
* a
, Value
* b
)
531 return VSHUFFLE(a
, b
, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
534 //////////////////////////////////////////////////////////////////////////
535 /// @brief convert x86 <N x float> mask to llvm <N x i1> mask
536 Value
* Builder::MASK(Value
* vmask
)
538 Value
* src
= BITCAST(vmask
, mSimdInt32Ty
);
539 return ICMP_SLT(src
, VIMMED1(0));
542 Value
* Builder::MASK_16(Value
* vmask
)
544 Value
* src
= BITCAST(vmask
, mSimd16Int32Ty
);
545 return ICMP_SLT(src
, VIMMED1_16(0));
548 //////////////////////////////////////////////////////////////////////////
549 /// @brief convert llvm <N x i1> mask to x86 <N x i32> mask
550 Value
* Builder::VMASK(Value
* mask
) { return S_EXT(mask
, mSimdInt32Ty
); }
552 Value
* Builder::VMASK_16(Value
* mask
) { return S_EXT(mask
, mSimd16Int32Ty
); }
554 /// @brief Convert <Nxi1> llvm mask to integer
555 Value
* Builder::VMOVMSK(Value
* mask
)
557 #if LLVM_VERSION_MAJOR >= 11
558 VectorType
* pVectorType
= cast
<VectorType
>(mask
->getType());
559 SWR_ASSERT(pVectorType
->getElementType() == mInt1Ty
);
560 uint32_t numLanes
= pVectorType
->getNumElements();
562 SWR_ASSERT(mask
->getType()->getVectorElementType() == mInt1Ty
);
563 uint32_t numLanes
= mask
->getType()->getVectorNumElements();
568 i32Result
= BITCAST(mask
, mInt8Ty
);
570 else if (numLanes
== 16)
572 i32Result
= BITCAST(mask
, mInt16Ty
);
576 SWR_ASSERT("Unsupported vector width");
577 i32Result
= BITCAST(mask
, mInt8Ty
);
579 return Z_EXT(i32Result
, mInt32Ty
);
582 //////////////////////////////////////////////////////////////////////////
583 /// @brief Generate a VPSHUFB operation in LLVM IR. If not
584 /// supported on the underlying platform, emulate it
585 /// @param a - 256bit SIMD(32x8bit) of 8bit integer values
586 /// @param b - 256bit SIMD(32x8bit) of 8bit integer mask values
587 /// Byte masks in lower 128 lane of b selects 8 bit values from lower
588 /// 128bits of a, and vice versa for the upper lanes. If the mask
589 /// value is negative, '0' is inserted.
590 Value
* Builder::PSHUFB(Value
* a
, Value
* b
)
593 // use avx2 pshufb instruction if available
594 if (JM()->mArch
.AVX2())
600 Constant
* cB
= dyn_cast
<Constant
>(b
);
601 assert(cB
!= nullptr);
602 // number of 8 bit elements in b
603 uint32_t numElms
= cast
<VectorType
>(cB
->getType())->getNumElements();
605 Value
* vShuf
= UndefValue::get(getVectorType(mInt8Ty
, numElms
));
607 // insert an 8 bit value from the high and low lanes of a per loop iteration
609 for (uint32_t i
= 0; i
< numElms
; i
++)
611 ConstantInt
* cLow128b
= cast
<ConstantInt
>(cB
->getAggregateElement(i
));
612 ConstantInt
* cHigh128b
= cast
<ConstantInt
>(cB
->getAggregateElement(i
+ numElms
));
614 // extract values from constant mask
615 char valLow128bLane
= (char)(cLow128b
->getSExtValue());
616 char valHigh128bLane
= (char)(cHigh128b
->getSExtValue());
618 Value
* insertValLow128b
;
619 Value
* insertValHigh128b
;
621 // if the mask value is negative, insert a '0' in the respective output position
622 // otherwise, lookup the value at mask position (bits 3..0 of the respective mask
623 // byte) in a and insert in output vector
625 (valLow128bLane
< 0) ? C((char)0) : VEXTRACT(a
, C((valLow128bLane
& 0xF)));
626 insertValHigh128b
= (valHigh128bLane
< 0)
628 : VEXTRACT(a
, C((valHigh128bLane
& 0xF) + numElms
));
630 vShuf
= VINSERT(vShuf
, insertValLow128b
, i
);
631 vShuf
= VINSERT(vShuf
, insertValHigh128b
, (i
+ numElms
));
638 //////////////////////////////////////////////////////////////////////////
639 /// @brief Generate a VPSHUFB operation (sign extend 8 8bit values to 32
640 /// bits)in LLVM IR. If not supported on the underlying platform, emulate it
641 /// @param a - 128bit SIMD lane(16x8bit) of 8bit integer values. Only
642 /// lower 8 values are used.
643 Value
* Builder::PMOVSXBD(Value
* a
)
645 // VPMOVSXBD output type
646 Type
* v8x32Ty
= getVectorType(mInt32Ty
, 8);
647 // Extract 8 values from 128bit lane and sign extend
648 return S_EXT(VSHUFFLE(a
, a
, C
<int>({0, 1, 2, 3, 4, 5, 6, 7})), v8x32Ty
);
651 //////////////////////////////////////////////////////////////////////////
652 /// @brief Generate a VPSHUFB operation (sign extend 8 16bit values to 32
653 /// bits)in LLVM IR. If not supported on the underlying platform, emulate it
654 /// @param a - 128bit SIMD lane(8x16bit) of 16bit integer values.
655 Value
* Builder::PMOVSXWD(Value
* a
)
657 // VPMOVSXWD output type
658 Type
* v8x32Ty
= getVectorType(mInt32Ty
, 8);
659 // Extract 8 values from 128bit lane and sign extend
660 return S_EXT(VSHUFFLE(a
, a
, C
<int>({0, 1, 2, 3, 4, 5, 6, 7})), v8x32Ty
);
663 //////////////////////////////////////////////////////////////////////////
664 /// @brief Generate a VCVTPH2PS operation (float16->float32 conversion)
665 /// in LLVM IR. If not supported on the underlying platform, emulate it
666 /// @param a - 128bit SIMD lane(8x16bit) of float16 in int16 format.
667 Value
* Builder::CVTPH2PS(Value
* a
, const llvm::Twine
& name
)
669 // Bitcast Nxint16 to Nxhalf
670 #if LLVM_VERSION_MAJOR >= 11
671 uint32_t numElems
= cast
<VectorType
>(a
->getType())->getNumElements();
673 uint32_t numElems
= a
->getType()->getVectorNumElements();
675 Value
* input
= BITCAST(a
, getVectorType(mFP16Ty
, numElems
));
677 return FP_EXT(input
, getVectorType(mFP32Ty
, numElems
), name
);
680 //////////////////////////////////////////////////////////////////////////
681 /// @brief Generate a VCVTPS2PH operation (float32->float16 conversion)
682 /// in LLVM IR. If not supported on the underlying platform, emulate it
683 /// @param a - 128bit SIMD lane(8x16bit) of float16 in int16 format.
684 Value
* Builder::CVTPS2PH(Value
* a
, Value
* rounding
)
686 if (JM()->mArch
.F16C())
688 return VCVTPS2PH(a
, rounding
);
692 // call scalar C function for now
693 FunctionType
* pFuncTy
= FunctionType::get(mInt16Ty
, mFP32Ty
);
694 Function
* pCvtPs2Ph
= cast
<Function
>(
695 #if LLVM_VERSION_MAJOR >= 9
696 JM()->mpCurrentModule
->getOrInsertFunction("ConvertFloat32ToFloat16", pFuncTy
).getCallee());
698 JM()->mpCurrentModule
->getOrInsertFunction("ConvertFloat32ToFloat16", pFuncTy
));
701 if (sys::DynamicLibrary::SearchForAddressOfSymbol("ConvertFloat32ToFloat16") == nullptr)
703 sys::DynamicLibrary::AddSymbol("ConvertFloat32ToFloat16",
704 (void*)&ConvertFloat32ToFloat16
);
707 Value
* pResult
= UndefValue::get(mSimdInt16Ty
);
708 for (uint32_t i
= 0; i
< mVWidth
; ++i
)
710 Value
* pSrc
= VEXTRACT(a
, C(i
));
711 Value
* pConv
= CALL(pCvtPs2Ph
, std::initializer_list
<Value
*>{pSrc
});
712 pResult
= VINSERT(pResult
, pConv
, C(i
));
719 Value
* Builder::PMAXSD(Value
* a
, Value
* b
)
721 Value
* cmp
= ICMP_SGT(a
, b
);
722 return SELECT(cmp
, a
, b
);
725 Value
* Builder::PMINSD(Value
* a
, Value
* b
)
727 Value
* cmp
= ICMP_SLT(a
, b
);
728 return SELECT(cmp
, a
, b
);
731 Value
* Builder::PMAXUD(Value
* a
, Value
* b
)
733 Value
* cmp
= ICMP_UGT(a
, b
);
734 return SELECT(cmp
, a
, b
);
737 Value
* Builder::PMINUD(Value
* a
, Value
* b
)
739 Value
* cmp
= ICMP_ULT(a
, b
);
740 return SELECT(cmp
, a
, b
);
743 // Helper function to create alloca in entry block of function
744 Value
* Builder::CreateEntryAlloca(Function
* pFunc
, Type
* pType
)
746 auto saveIP
= IRB()->saveIP();
747 IRB()->SetInsertPoint(&pFunc
->getEntryBlock(), pFunc
->getEntryBlock().begin());
748 Value
* pAlloca
= ALLOCA(pType
);
750 IRB()->restoreIP(saveIP
);
754 Value
* Builder::CreateEntryAlloca(Function
* pFunc
, Type
* pType
, Value
* pArraySize
)
756 auto saveIP
= IRB()->saveIP();
757 IRB()->SetInsertPoint(&pFunc
->getEntryBlock(), pFunc
->getEntryBlock().begin());
758 Value
* pAlloca
= ALLOCA(pType
, pArraySize
);
760 IRB()->restoreIP(saveIP
);
764 Value
* Builder::VABSPS(Value
* a
)
766 Value
* asInt
= BITCAST(a
, mSimdInt32Ty
);
767 Value
* result
= BITCAST(AND(asInt
, VIMMED1(0x7fffffff)), mSimdFP32Ty
);
771 Value
* Builder::ICLAMP(Value
* src
, Value
* low
, Value
* high
, const llvm::Twine
& name
)
773 Value
* lowCmp
= ICMP_SLT(src
, low
);
774 Value
* ret
= SELECT(lowCmp
, low
, src
);
776 Value
* highCmp
= ICMP_SGT(ret
, high
);
777 ret
= SELECT(highCmp
, high
, ret
, name
);
782 Value
* Builder::FCLAMP(Value
* src
, Value
* low
, Value
* high
)
784 Value
* lowCmp
= FCMP_OLT(src
, low
);
785 Value
* ret
= SELECT(lowCmp
, low
, src
);
787 Value
* highCmp
= FCMP_OGT(ret
, high
);
788 ret
= SELECT(highCmp
, high
, ret
);
793 Value
* Builder::FCLAMP(Value
* src
, float low
, float high
)
795 Value
* result
= VMAXPS(src
, VIMMED1(low
));
796 result
= VMINPS(result
, VIMMED1(high
));
801 Value
* Builder::FMADDPS(Value
* a
, Value
* b
, Value
* c
)
804 // This maps to LLVM fmuladd intrinsic
805 vOut
= VFMADDPS(a
, b
, c
);
809 //////////////////////////////////////////////////////////////////////////
810 /// @brief pop count on vector mask (e.g. <8 x i1>)
811 Value
* Builder::VPOPCNT(Value
* a
) { return POPCNT(VMOVMSK(a
)); }
813 //////////////////////////////////////////////////////////////////////////
814 /// @brief Float / Fixed-point conversions
815 //////////////////////////////////////////////////////////////////////////
816 Value
* Builder::VCVT_F32_FIXED_SI(Value
* vFloat
,
818 uint32_t numFracBits
,
819 const llvm::Twine
& name
)
821 SWR_ASSERT((numIntBits
+ numFracBits
) <= 32, "Can only handle 32-bit fixed-point values");
822 Value
* fixed
= nullptr;
824 #if 0 // This doesn't work for negative numbers!!
826 fixed
= FP_TO_SI(VROUND(FMUL(vFloat
, VIMMED1(float(1 << numFracBits
))),
827 C(_MM_FROUND_TO_NEAREST_INT
)),
833 // Do round to nearest int on fractional bits first
834 // Not entirely perfect for negative numbers, but close enough
835 vFloat
= VROUND(FMUL(vFloat
, VIMMED1(float(1 << numFracBits
))),
836 C(_MM_FROUND_TO_NEAREST_INT
));
837 vFloat
= FMUL(vFloat
, VIMMED1(1.0f
/ float(1 << numFracBits
)));
839 // TODO: Handle INF, NAN, overflow / underflow, etc.
841 Value
* vSgn
= FCMP_OLT(vFloat
, VIMMED1(0.0f
));
842 Value
* vFloatInt
= BITCAST(vFloat
, mSimdInt32Ty
);
843 Value
* vFixed
= AND(vFloatInt
, VIMMED1((1 << 23) - 1));
844 vFixed
= OR(vFixed
, VIMMED1(1 << 23));
845 vFixed
= SELECT(vSgn
, NEG(vFixed
), vFixed
);
847 Value
* vExp
= LSHR(SHL(vFloatInt
, VIMMED1(1)), VIMMED1(24));
848 vExp
= SUB(vExp
, VIMMED1(127));
850 Value
* vExtraBits
= SUB(VIMMED1(23 - numFracBits
), vExp
);
852 fixed
= ASHR(vFixed
, vExtraBits
, name
);
858 Value
* Builder::VCVT_FIXED_SI_F32(Value
* vFixed
,
860 uint32_t numFracBits
,
861 const llvm::Twine
& name
)
863 SWR_ASSERT((numIntBits
+ numFracBits
) <= 32, "Can only handle 32-bit fixed-point values");
864 uint32_t extraBits
= 32 - numIntBits
- numFracBits
;
865 if (numIntBits
&& extraBits
)
868 Value
* shftAmt
= VIMMED1(extraBits
);
869 vFixed
= ASHR(SHL(vFixed
, shftAmt
), shftAmt
);
872 Value
* fVal
= VIMMED1(0.0f
);
873 Value
* fFrac
= VIMMED1(0.0f
);
876 fVal
= SI_TO_FP(ASHR(vFixed
, VIMMED1(numFracBits
)), mSimdFP32Ty
, name
);
881 fFrac
= UI_TO_FP(AND(vFixed
, VIMMED1((1 << numFracBits
) - 1)), mSimdFP32Ty
);
882 fFrac
= FDIV(fFrac
, VIMMED1(float(1 << numFracBits
)), name
);
885 return FADD(fVal
, fFrac
, name
);
888 Value
* Builder::VCVT_F32_FIXED_UI(Value
* vFloat
,
890 uint32_t numFracBits
,
891 const llvm::Twine
& name
)
893 SWR_ASSERT((numIntBits
+ numFracBits
) <= 32, "Can only handle 32-bit fixed-point values");
894 Value
* fixed
= nullptr;
895 #if 1 // KNOB_SIM_FAST_MATH? Below works correctly from a precision
898 fixed
= FP_TO_UI(VROUND(FMUL(vFloat
, VIMMED1(float(1 << numFracBits
))),
899 C(_MM_FROUND_TO_NEAREST_INT
)),
904 // Do round to nearest int on fractional bits first
905 vFloat
= VROUND(FMUL(vFloat
, VIMMED1(float(1 << numFracBits
))),
906 C(_MM_FROUND_TO_NEAREST_INT
));
907 vFloat
= FMUL(vFloat
, VIMMED1(1.0f
/ float(1 << numFracBits
)));
909 // TODO: Handle INF, NAN, overflow / underflow, etc.
911 Value
* vSgn
= FCMP_OLT(vFloat
, VIMMED1(0.0f
));
912 Value
* vFloatInt
= BITCAST(vFloat
, mSimdInt32Ty
);
913 Value
* vFixed
= AND(vFloatInt
, VIMMED1((1 << 23) - 1));
914 vFixed
= OR(vFixed
, VIMMED1(1 << 23));
916 Value
* vExp
= LSHR(SHL(vFloatInt
, VIMMED1(1)), VIMMED1(24));
917 vExp
= SUB(vExp
, VIMMED1(127));
919 Value
* vExtraBits
= SUB(VIMMED1(23 - numFracBits
), vExp
);
921 fixed
= LSHR(vFixed
, vExtraBits
, name
);
927 Value
* Builder::VCVT_FIXED_UI_F32(Value
* vFixed
,
929 uint32_t numFracBits
,
930 const llvm::Twine
& name
)
932 SWR_ASSERT((numIntBits
+ numFracBits
) <= 32, "Can only handle 32-bit fixed-point values");
933 uint32_t extraBits
= 32 - numIntBits
- numFracBits
;
934 if (numIntBits
&& extraBits
)
937 Value
* shftAmt
= VIMMED1(extraBits
);
938 vFixed
= ASHR(SHL(vFixed
, shftAmt
), shftAmt
);
941 Value
* fVal
= VIMMED1(0.0f
);
942 Value
* fFrac
= VIMMED1(0.0f
);
945 fVal
= UI_TO_FP(LSHR(vFixed
, VIMMED1(numFracBits
)), mSimdFP32Ty
, name
);
950 fFrac
= UI_TO_FP(AND(vFixed
, VIMMED1((1 << numFracBits
) - 1)), mSimdFP32Ty
);
951 fFrac
= FDIV(fFrac
, VIMMED1(float(1 << numFracBits
)), name
);
954 return FADD(fVal
, fFrac
, name
);
957 //////////////////////////////////////////////////////////////////////////
958 /// @brief C functions called by LLVM IR
959 //////////////////////////////////////////////////////////////////////////
961 Value
* Builder::VEXTRACTI128(Value
* a
, Constant
* imm8
)
963 bool flag
= !imm8
->isZeroValue();
964 SmallVector
<Constant
*, 8> idx
;
965 for (unsigned i
= 0; i
< mVWidth
/ 2; i
++)
967 idx
.push_back(C(flag
? i
+ mVWidth
/ 2 : i
));
969 return VSHUFFLE(a
, VUNDEF_I(), ConstantVector::get(idx
));
972 Value
* Builder::VINSERTI128(Value
* a
, Value
* b
, Constant
* imm8
)
974 bool flag
= !imm8
->isZeroValue();
975 SmallVector
<Constant
*, 8> idx
;
976 for (unsigned i
= 0; i
< mVWidth
; i
++)
980 Value
* inter
= VSHUFFLE(b
, VUNDEF_I(), ConstantVector::get(idx
));
982 SmallVector
<Constant
*, 8> idx2
;
983 for (unsigned i
= 0; i
< mVWidth
/ 2; i
++)
985 idx2
.push_back(C(flag
? i
: i
+ mVWidth
));
987 for (unsigned i
= mVWidth
/ 2; i
< mVWidth
; i
++)
989 idx2
.push_back(C(flag
? i
+ mVWidth
/ 2 : i
));
991 return VSHUFFLE(a
, inter
, ConstantVector::get(idx2
));
994 // rdtsc buckets macros
995 void Builder::RDTSC_START(Value
* pBucketMgr
, Value
* pId
)
997 // @todo due to an issue with thread local storage propagation in llvm, we can only safely
998 // call into buckets framework when single threaded
999 if (KNOB_SINGLE_THREADED
)
1001 std::vector
<Type
*> args
{
1002 PointerType::get(mInt32Ty
, 0), // pBucketMgr
1006 FunctionType
* pFuncTy
= FunctionType::get(Type::getVoidTy(JM()->mContext
), args
, false);
1007 Function
* pFunc
= cast
<Function
>(
1008 #if LLVM_VERSION_MAJOR >= 9
1009 JM()->mpCurrentModule
->getOrInsertFunction("BucketManager_StartBucket", pFuncTy
).getCallee());
1011 JM()->mpCurrentModule
->getOrInsertFunction("BucketManager_StartBucket", pFuncTy
));
1013 if (sys::DynamicLibrary::SearchForAddressOfSymbol("BucketManager_StartBucket") ==
1016 sys::DynamicLibrary::AddSymbol("BucketManager_StartBucket",
1017 (void*)&BucketManager_StartBucket
);
1020 CALL(pFunc
, {pBucketMgr
, pId
});
1024 void Builder::RDTSC_STOP(Value
* pBucketMgr
, Value
* pId
)
1026 // @todo due to an issue with thread local storage propagation in llvm, we can only safely
1027 // call into buckets framework when single threaded
1028 if (KNOB_SINGLE_THREADED
)
1030 std::vector
<Type
*> args
{
1031 PointerType::get(mInt32Ty
, 0), // pBucketMgr
1035 FunctionType
* pFuncTy
= FunctionType::get(Type::getVoidTy(JM()->mContext
), args
, false);
1036 Function
* pFunc
= cast
<Function
>(
1037 #if LLVM_VERSION_MAJOR >= 9
1038 JM()->mpCurrentModule
->getOrInsertFunction("BucketManager_StopBucket", pFuncTy
).getCallee());
1040 JM()->mpCurrentModule
->getOrInsertFunction("BucketManager_StopBucket", pFuncTy
));
1042 if (sys::DynamicLibrary::SearchForAddressOfSymbol("BucketManager_StopBucket") ==
1045 sys::DynamicLibrary::AddSymbol("BucketManager_StopBucket",
1046 (void*)&BucketManager_StopBucket
);
1049 CALL(pFunc
, {pBucketMgr
, pId
});
1053 uint32_t Builder::GetTypeSize(Type
* pType
)
1055 if (pType
->isStructTy())
1057 uint32_t numElems
= pType
->getStructNumElements();
1058 Type
* pElemTy
= pType
->getStructElementType(0);
1059 return numElems
* GetTypeSize(pElemTy
);
1062 if (pType
->isArrayTy())
1064 uint32_t numElems
= pType
->getArrayNumElements();
1065 Type
* pElemTy
= pType
->getArrayElementType();
1066 return numElems
* GetTypeSize(pElemTy
);
1069 if (pType
->isIntegerTy())
1071 uint32_t bitSize
= pType
->getIntegerBitWidth();
1075 if (pType
->isFloatTy())
1080 if (pType
->isHalfTy())
1085 if (pType
->isDoubleTy())
1090 SWR_ASSERT(false, "Unimplemented type.");
1093 } // namespace SwrJit