swr: Fix crashes on non-AVX hardware
[mesa.git] src/gallium/drivers/swr/rasterizer/jitter/functionpasses/lower_x86.cpp
/****************************************************************************
 * Copyright (C) 2014-2018 Intel Corporation. All Rights Reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * @file lower_x86.cpp
 *
 * @brief llvm pass to lower meta code to x86
 *
 * Notes:
 *
 ******************************************************************************/

#include "jit_pch.hpp"
#include "passes.h"
#include "JitManager.h"

#include "common/simdlib.hpp"

#include <unordered_map>

extern "C" void ScatterPS_256(uint8_t*, SIMD256::Integer, SIMD256::Float, uint8_t, uint32_t);

namespace llvm
{
    // forward declare the initializer
    void initializeLowerX86Pass(PassRegistry&);
} // namespace llvm

namespace SwrJit
{
    using namespace llvm;

    enum TargetArch
    {
        AVX    = 0,
        AVX2   = 1,
        AVX512 = 2
    };

    enum TargetWidth
    {
        W256       = 0,
        W512       = 1,
        NUM_WIDTHS = 2
    };

    struct LowerX86;

    typedef std::function<Instruction*(LowerX86*, TargetArch, TargetWidth, CallInst*)> EmuFunc;

    struct X86Intrinsic
    {
        IntrinsicID intrin[NUM_WIDTHS];
        EmuFunc     emuFunc;
    };

    // Map of intrinsics that haven't been moved to the new mechanism yet. If used, these get the
    // previous behavior of mapping directly to avx/avx2 intrinsics.
    using intrinsicMap_t = std::map<std::string, IntrinsicID>;
    static intrinsicMap_t& getIntrinsicMap() {
        static std::map<std::string, IntrinsicID> intrinsicMap = {
            {"meta.intrinsic.BEXTR_32",  Intrinsic::x86_bmi_bextr_32},
            {"meta.intrinsic.VPSHUFB",   Intrinsic::x86_avx2_pshuf_b},
            {"meta.intrinsic.VCVTPS2PH", Intrinsic::x86_vcvtps2ph_256},
            {"meta.intrinsic.VPTESTC",   Intrinsic::x86_avx_ptestc_256},
            {"meta.intrinsic.VPTESTZ",   Intrinsic::x86_avx_ptestz_256},
            {"meta.intrinsic.VPHADDD",   Intrinsic::x86_avx2_phadd_d},
            {"meta.intrinsic.PDEP32",    Intrinsic::x86_bmi_pdep_32},
            {"meta.intrinsic.RDTSC",     Intrinsic::x86_rdtsc}
        };
        return intrinsicMap;
    }

    // Forward decls
    Instruction* NO_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VPERM_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VGATHER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VSCATTER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VROUND_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VHSUB_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VCONVERT_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);

    Instruction* DOUBLE_EMU(LowerX86*     pThis,
                            TargetArch    arch,
                            TargetWidth   width,
                            CallInst*     pCallInst,
                            Intrinsic::ID intrin);

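    // Sentinel intrinsic ID: entries tagged with DOUBLE are emulated by double pumping the
    // next smaller SIMD width rather than mapping to a single native intrinsic.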
    static Intrinsic::ID DOUBLE = (Intrinsic::ID)-1;

    using intrinsicMapAdvanced_t = std::vector<std::map<std::string, X86Intrinsic>>;

    static intrinsicMapAdvanced_t& getIntrinsicMapAdvanced()
    {
        // clang-format off
        static intrinsicMapAdvanced_t intrinsicMapAdvanced = {
            //                                 256 wide                                    512 wide
            {
                // AVX
                {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx_rcp_ps_256,            DOUBLE},                     NO_EMU}},
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VPERM_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VPERM_EMU}},
                {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VSCATTER_EMU}},
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx_cvt_pd2_ps_256,        Intrinsic::not_intrinsic},   NO_EMU}},
                {"meta.intrinsic.VROUND",     {{Intrinsic::x86_avx_round_ps_256,          DOUBLE},                     NO_EMU}},
                {"meta.intrinsic.VHSUBPS",    {{Intrinsic::x86_avx_hsub_ps_256,           DOUBLE},                     NO_EMU}},
            },
            {
                // AVX2
                {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx_rcp_ps_256,            DOUBLE},                     NO_EMU}},
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::x86_avx2_permps,               Intrinsic::not_intrinsic},   VPERM_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::x86_avx2_permd,                Intrinsic::not_intrinsic},   VPERM_EMU}},
                {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VSCATTER_EMU}},
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx_cvt_pd2_ps_256,        DOUBLE},                     NO_EMU}},
                {"meta.intrinsic.VROUND",     {{Intrinsic::x86_avx_round_ps_256,          DOUBLE},                     NO_EMU}},
                {"meta.intrinsic.VHSUBPS",    {{Intrinsic::x86_avx_hsub_ps_256,           DOUBLE},                     NO_EMU}},
            },
            {
                // AVX512
                {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx512_rcp14_ps_256,       Intrinsic::x86_avx512_rcp14_ps_512},       NO_EMU}},
#if LLVM_VERSION_MAJOR < 7
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::x86_avx512_mask_permvar_sf_256, Intrinsic::x86_avx512_mask_permvar_sf_512}, NO_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::x86_avx512_mask_permvar_si_256, Intrinsic::x86_avx512_mask_permvar_si_512}, NO_EMU}},
#else
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VPERM_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VPERM_EMU}},
#endif
                {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VSCATTER_EMU}},
#if LLVM_VERSION_MAJOR < 7
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx512_mask_cvtpd2ps_256,  Intrinsic::x86_avx512_mask_cvtpd2ps_512},  NO_EMU}},
#else
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VCONVERT_EMU}},
#endif
                {"meta.intrinsic.VROUND",     {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VROUND_EMU}},
                {"meta.intrinsic.VHSUBPS",    {{Intrinsic::not_intrinsic,                 Intrinsic::not_intrinsic},   VHSUB_EMU}}
            }};
        // clang-format on
        return intrinsicMapAdvanced;
    }

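    // Total bit width of a vector type. LLVM 11+ no longer exposes VectorType::getBitWidth(),
    // so compute it from the element count and element size instead.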
    static uint32_t getBitWidth(VectorType *pVTy)
    {
#if LLVM_VERSION_MAJOR >= 11
        return pVTy->getNumElements() * pVTy->getElementType()->getPrimitiveSizeInBits();
#else
        return pVTy->getBitWidth();
#endif
    }

    struct LowerX86 : public FunctionPass
    {
        LowerX86(Builder* b = nullptr) : FunctionPass(ID), B(b)
        {
            initializeLowerX86Pass(*PassRegistry::getPassRegistry());

            // Determine target arch
            if (JM()->mArch.AVX512F())
            {
                mTarget = AVX512;
            }
            else if (JM()->mArch.AVX2())
            {
                mTarget = AVX2;
            }
            else if (JM()->mArch.AVX())
            {
                mTarget = AVX;
            }
            else
            {
                SWR_ASSERT(false, "Unsupported AVX architecture.");
                mTarget = AVX;
            }

            // Setup scatter function for 256 wide
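            // Temporarily switch the builder to 8 lanes so the SIMD types used below
            // (mSimdInt32Ty / mSimdFP32Ty) describe 256-bit vectors; the original width is
            // restored once the scatter stub has been created.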
            uint32_t curWidth = B->mVWidth;
            B->SetTargetWidth(8);
            std::vector<Type*> args = {
                B->mInt8PtrTy,   // pBase
                B->mSimdInt32Ty, // vIndices
                B->mSimdFP32Ty,  // vSrc
                B->mInt8Ty,      // mask
                B->mInt32Ty      // scale
            };

            FunctionType* pfnScatterTy = FunctionType::get(B->mVoidTy, args, false);
            mPfnScatter256 = cast<Function>(
#if LLVM_VERSION_MAJOR >= 9
                B->JM()->mpCurrentModule->getOrInsertFunction("ScatterPS_256", pfnScatterTy).getCallee());
#else
                B->JM()->mpCurrentModule->getOrInsertFunction("ScatterPS_256", pfnScatterTy));
#endif
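            // Make sure the JIT can resolve the external ScatterPS_256 helper at runtime;
            // register the symbol manually if the dynamic loader can't find it.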
            if (sys::DynamicLibrary::SearchForAddressOfSymbol("ScatterPS_256") == nullptr)
            {
                sys::DynamicLibrary::AddSymbol("ScatterPS_256", (void*)&ScatterPS_256);
            }

            B->SetTargetWidth(curWidth);
        }

        // Try to decipher the vector type of the instruction. This does not work properly
        // across all intrinsics, and will have to be rethought. Probably need something
        // similar to llvm's getDeclaration() utility to map a set of inputs to a specific typed
        // intrinsic.
        void GetRequestedWidthAndType(CallInst*       pCallInst,
                                      const StringRef intrinName,
                                      TargetWidth*    pWidth,
                                      Type**          pTy)
        {
            assert(pCallInst);
            Type* pVecTy = pCallInst->getType();

            // Check for intrinsic specific types
            // VCVTPD2PS type comes from src, not dst
            if (intrinName.equals("meta.intrinsic.VCVTPD2PS"))
            {
                Value* pOp = pCallInst->getOperand(0);
                assert(pOp);
                pVecTy = pOp->getType();
            }

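            // Otherwise, if the return type isn't a vector, fall back to the first vector
            // argument to determine the requested width.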
            if (!pVecTy->isVectorTy())
            {
                for (auto& op : pCallInst->arg_operands())
                {
                    if (op.get()->getType()->isVectorTy())
                    {
                        pVecTy = op.get()->getType();
                        break;
                    }
                }
            }
            SWR_ASSERT(pVecTy->isVectorTy(), "Couldn't determine vector size");

            uint32_t width = getBitWidth(cast<VectorType>(pVecTy));
            switch (width)
            {
            case 256:
                *pWidth = W256;
                break;
            case 512:
                *pWidth = W512;
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width %d", width);
                *pWidth = W256;
            }

            *pTy = pVecTy->getScalarType();
        }

        Value* GetZeroVec(TargetWidth width, Type* pTy)
        {
            uint32_t numElem = 0;
            switch (width)
            {
            case W256:
                numElem = 8;
                break;
            case W512:
                numElem = 16;
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width type %d\n", width);
            }

            return ConstantVector::getNullValue(getVectorType(pTy, numElem));
        }

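        // Build an all-lanes-enabled scalar mask for the given width: i8 for 8 lanes (256-bit),
        // i16 for 16 lanes (512-bit).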
        Value* GetMask(TargetWidth width)
        {
            Value* mask;
            switch (width)
            {
            case W256:
                mask = B->C((uint8_t)-1);
                break;
            case W512:
                mask = B->C((uint16_t)-1);
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width type %d\n", width);
            }
            return mask;
        }

        // Convert <N x i1> mask to <N x i32> x86 mask
        Value* VectorMask(Value* vi1Mask)
        {
#if LLVM_VERSION_MAJOR >= 11
            uint32_t numElem = cast<VectorType>(vi1Mask->getType())->getNumElements();
#else
            uint32_t numElem = vi1Mask->getType()->getVectorNumElements();
#endif
            return B->S_EXT(vi1Mask, getVectorType(B->mInt32Ty, numElem));
        }

        Instruction* ProcessIntrinsicAdvanced(CallInst* pCallInst)
        {
            Function* pFunc = pCallInst->getCalledFunction();
            assert(pFunc);

            auto&       intrinsic = getIntrinsicMapAdvanced()[mTarget][pFunc->getName().str()];
            TargetWidth vecWidth;
            Type*       pElemTy;
            GetRequestedWidthAndType(pCallInst, pFunc->getName(), &vecWidth, &pElemTy);

            // Check if there is a native intrinsic for this instruction
            IntrinsicID id = intrinsic.intrin[vecWidth];
            if (id == DOUBLE)
            {
                // Double pump the next smaller SIMD intrinsic
                SWR_ASSERT(vecWidth != 0, "Cannot double pump smallest SIMD width.");
                Intrinsic::ID id2 = intrinsic.intrin[vecWidth - 1];
                SWR_ASSERT(id2 != Intrinsic::not_intrinsic,
                           "Cannot find intrinsic to double pump.");
                return DOUBLE_EMU(this, mTarget, vecWidth, pCallInst, id2);
            }
            else if (id != Intrinsic::not_intrinsic)
            {
                Function* pIntrin = Intrinsic::getDeclaration(B->JM()->mpCurrentModule, id);
                SmallVector<Value*, 8> args;
                for (auto& arg : pCallInst->arg_operands())
                {
                    args.push_back(arg.get());
                }

                // If AVX512, all instructions add a src operand and mask. We'll pass in 0 src and
                // full mask for now, assuming the intrinsics are consistent and place the src
                // operand and mask last in the argument list.
                if (mTarget == AVX512)
                {
                    if (pFunc->getName().equals("meta.intrinsic.VCVTPD2PS"))
                    {
                        args.push_back(GetZeroVec(W256, pCallInst->getType()->getScalarType()));
                        args.push_back(GetMask(W256));
                        // for AVX512 VCVTPD2PS, we also have to add rounding mode
                        args.push_back(B->C(_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
                    }
                    else
                    {
                        args.push_back(GetZeroVec(vecWidth, pElemTy));
                        args.push_back(GetMask(vecWidth));
                    }
                }

                return B->CALLA(pIntrin, args);
            }
            else
            {
                // No native intrinsic, call emulation function
                return intrinsic.emuFunc(this, mTarget, vecWidth, pCallInst);
            }

            SWR_ASSERT(false);
            return nullptr;
        }

        Instruction* ProcessIntrinsic(CallInst* pCallInst)
        {
            Function* pFunc = pCallInst->getCalledFunction();
            assert(pFunc);

            // Forward to the advanced support if found
            if (getIntrinsicMapAdvanced()[mTarget].find(pFunc->getName().str()) !=
                getIntrinsicMapAdvanced()[mTarget].end())
            {
                return ProcessIntrinsicAdvanced(pCallInst);
            }

            SWR_ASSERT(getIntrinsicMap().find(pFunc->getName().str()) != getIntrinsicMap().end(),
                       "Unimplemented intrinsic %s.",
                       pFunc->getName().str().c_str());

            Intrinsic::ID x86Intrinsic = getIntrinsicMap()[pFunc->getName().str()];
            Function*     pX86IntrinFunc =
                Intrinsic::getDeclaration(B->JM()->mpCurrentModule, x86Intrinsic);

            SmallVector<Value*, 8> args;
            for (auto& arg : pCallInst->arg_operands())
            {
                args.push_back(arg.get());
            }
            return B->CALLA(pX86IntrinFunc, args);
        }

        //////////////////////////////////////////////////////////////////////////
        /// @brief LLVM function pass run method.
        /// @param F - The function we're working on with this pass.
        virtual bool runOnFunction(Function& F)
        {
            std::vector<Instruction*> toRemove;
            std::vector<BasicBlock*>  bbs;

            // Make temp copy of the basic blocks and instructions, as the intrinsic
            // replacement code might invalidate the iterators
            for (auto& b : F.getBasicBlockList())
            {
                bbs.push_back(&b);
            }

            for (auto* BB : bbs)
            {
                std::vector<Instruction*> insts;
                for (auto& i : BB->getInstList())
                {
                    insts.push_back(&i);
                }

                for (auto* I : insts)
                {
                    if (CallInst* pCallInst = dyn_cast<CallInst>(I))
                    {
                        Function* pFunc = pCallInst->getCalledFunction();
                        if (pFunc)
                        {
                            if (pFunc->getName().startswith("meta.intrinsic"))
                            {
                                B->IRB()->SetInsertPoint(I);
                                Instruction* pReplace = ProcessIntrinsic(pCallInst);
                                toRemove.push_back(pCallInst);
                                if (pReplace)
                                {
                                    pCallInst->replaceAllUsesWith(pReplace);
                                }
                            }
                        }
                    }
                }
            }

            for (auto* pInst : toRemove)
            {
                pInst->eraseFromParent();
            }

            JitManager::DumpToFile(&F, "lowerx86");

            return true;
        }

        virtual void getAnalysisUsage(AnalysisUsage& AU) const {}

        JitManager* JM() { return B->JM(); }
        Builder*    B;
        TargetArch  mTarget;
        Function*   mPfnScatter256;

        static char ID; ///< Needed by LLVM to generate ID for FunctionPass.
    };

    char LowerX86::ID = 0; // LLVM uses address of ID as the actual ID.

    FunctionPass* createLowerX86Pass(Builder* b) { return new LowerX86(b); }

    Instruction* NO_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(false, "Unimplemented intrinsic emulation.");
        return nullptr;
    }

    Instruction* VPERM_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        // Only need vperm emulation for AVX
        SWR_ASSERT(arch == AVX);

        Builder* B = pThis->B;
        auto     v32A      = pCallInst->getArgOperand(0);
        auto     vi32Index = pCallInst->getArgOperand(1);

        Value* v32Result;
        if (isa<Constant>(vi32Index))
        {
            // Can use llvm shuffle vector directly with constant shuffle indices
            v32Result = B->VSHUFFLE(v32A, v32A, vi32Index);
        }
        else
        {
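            // Variable shuffle indices: emulate by extracting each index and inserting the
            // selected element into the result, one lane at a time.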
            v32Result = UndefValue::get(v32A->getType());
#if LLVM_VERSION_MAJOR >= 11
            uint32_t numElem = cast<VectorType>(v32A->getType())->getNumElements();
#else
            uint32_t numElem = v32A->getType()->getVectorNumElements();
#endif
            for (uint32_t l = 0; l < numElem; ++l)
            {
                auto i32Index = B->VEXTRACT(vi32Index, B->C(l));
                auto val      = B->VEXTRACT(v32A, i32Index);
                v32Result     = B->VINSERT(v32Result, val, B->C(l));
            }
        }
        return cast<Instruction>(v32Result);
    }

    Instruction*
    VGATHER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        Builder* B           = pThis->B;
        auto     vSrc        = pCallInst->getArgOperand(0);
        auto     pBase       = pCallInst->getArgOperand(1);
        auto     vi32Indices = pCallInst->getArgOperand(2);
        auto     vi1Mask     = pCallInst->getArgOperand(3);
        auto     i8Scale     = pCallInst->getArgOperand(4);

        pBase = B->POINTER_CAST(pBase, PointerType::get(B->mInt8Ty, 0));
#if LLVM_VERSION_MAJOR >= 11
        VectorType* pVectorType = cast<VectorType>(vSrc->getType());
        uint32_t    numElem     = pVectorType->getNumElements();
        auto        srcTy       = pVectorType->getElementType();
#else
        uint32_t numElem = vSrc->getType()->getVectorNumElements();
        auto     srcTy   = vSrc->getType()->getVectorElementType();
#endif
        auto i32Scale = B->Z_EXT(i8Scale, B->mInt32Ty);

        Value* v32Gather = nullptr;
        if (arch == AVX)
        {
            // Full emulation for AVX
            // Store source on stack to provide a valid address to load from inactive lanes
            auto pStack = B->STACKSAVE();
            auto pTmp   = B->ALLOCA(vSrc->getType());
            B->STORE(vSrc, pTmp);

            v32Gather = UndefValue::get(vSrc->getType());
#if LLVM_VERSION_MAJOR > 10
            auto vi32Scale = ConstantVector::getSplat(ElementCount::get(numElem, false),
                                                      cast<ConstantInt>(i32Scale));
#else
            auto vi32Scale = ConstantVector::getSplat(numElem, cast<ConstantInt>(i32Scale));
#endif
            auto vi32Offsets = B->MUL(vi32Indices, vi32Scale);

            for (uint32_t i = 0; i < numElem; ++i)
            {
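                // For lanes that are masked off, load the lane's original value from the stack
                // copy of the source instead of dereferencing pBase.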
                auto i32Offset          = B->VEXTRACT(vi32Offsets, B->C(i));
                auto pLoadAddress       = B->GEP(pBase, i32Offset);
                pLoadAddress            = B->BITCAST(pLoadAddress, PointerType::get(srcTy, 0));
                auto pMaskedLoadAddress = B->GEP(pTmp, {0, i});
                auto i1Mask             = B->VEXTRACT(vi1Mask, B->C(i));
                auto pValidAddress      = B->SELECT(i1Mask, pLoadAddress, pMaskedLoadAddress);
                auto val                = B->LOAD(pValidAddress);
                v32Gather               = B->VINSERT(v32Gather, val, B->C(i));
            }

            B->STACKRESTORE(pStack);
        }
        else if (arch == AVX2 || (arch == AVX512 && width == W256))
        {
            Function* pX86IntrinFunc = nullptr;
            if (srcTy == B->mFP32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_ps_256);
            }
            else if (srcTy == B->mInt32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_d_256);
            }
            else if (srcTy == B->mDoubleTy)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_q_256);
            }
            else
            {
                SWR_ASSERT(false, "Unsupported vector element type for gather.");
            }

            if (width == W256)
            {
                auto v32Mask = B->BITCAST(pThis->VectorMask(vi1Mask), vSrc->getType());
                v32Gather = B->CALL(pX86IntrinFunc, {vSrc, pBase, vi32Indices, v32Mask, i8Scale});
            }
            else if (width == W512)
            {
                // Double pump 4-wide for 64bit elements
#if LLVM_VERSION_MAJOR >= 11
                if (cast<VectorType>(vSrc->getType())->getElementType() == B->mDoubleTy)
#else
                if (vSrc->getType()->getVectorElementType() == B->mDoubleTy)
#endif
                {
                    auto v64Mask = pThis->VectorMask(vi1Mask);
#if LLVM_VERSION_MAJOR >= 11
                    uint32_t numElem = cast<VectorType>(v64Mask->getType())->getNumElements();
#else
                    uint32_t numElem = v64Mask->getType()->getVectorNumElements();
#endif
                    v64Mask = B->S_EXT(v64Mask, getVectorType(B->mInt64Ty, numElem));
                    v64Mask = B->BITCAST(v64Mask, vSrc->getType());

                    Value* src0 = B->VSHUFFLE(vSrc, vSrc, B->C({0, 1, 2, 3}));
                    Value* src1 = B->VSHUFFLE(vSrc, vSrc, B->C({4, 5, 6, 7}));

                    Value* indices0 = B->VSHUFFLE(vi32Indices, vi32Indices, B->C({0, 1, 2, 3}));
                    Value* indices1 = B->VSHUFFLE(vi32Indices, vi32Indices, B->C({4, 5, 6, 7}));

                    Value* mask0 = B->VSHUFFLE(v64Mask, v64Mask, B->C({0, 1, 2, 3}));
                    Value* mask1 = B->VSHUFFLE(v64Mask, v64Mask, B->C({4, 5, 6, 7}));

#if LLVM_VERSION_MAJOR >= 11
                    uint32_t numElemSrc0  = cast<VectorType>(src0->getType())->getNumElements();
                    uint32_t numElemMask0 = cast<VectorType>(mask0->getType())->getNumElements();
                    uint32_t numElemSrc1  = cast<VectorType>(src1->getType())->getNumElements();
                    uint32_t numElemMask1 = cast<VectorType>(mask1->getType())->getNumElements();
#else
                    uint32_t numElemSrc0  = src0->getType()->getVectorNumElements();
                    uint32_t numElemMask0 = mask0->getType()->getVectorNumElements();
                    uint32_t numElemSrc1  = src1->getType()->getVectorNumElements();
                    uint32_t numElemMask1 = mask1->getType()->getVectorNumElements();
#endif
                    src0  = B->BITCAST(src0, getVectorType(B->mInt64Ty, numElemSrc0));
                    mask0 = B->BITCAST(mask0, getVectorType(B->mInt64Ty, numElemMask0));
                    Value* gather0 =
                        B->CALL(pX86IntrinFunc, {src0, pBase, indices0, mask0, i8Scale});
                    src1  = B->BITCAST(src1, getVectorType(B->mInt64Ty, numElemSrc1));
                    mask1 = B->BITCAST(mask1, getVectorType(B->mInt64Ty, numElemMask1));
                    Value* gather1 =
                        B->CALL(pX86IntrinFunc, {src1, pBase, indices1, mask1, i8Scale});
                    v32Gather = B->VSHUFFLE(gather0, gather1, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                    v32Gather = B->BITCAST(v32Gather, vSrc->getType());
                }
                else
                {
                    // Double pump 8-wide for 32bit elements
                    auto v32Mask = pThis->VectorMask(vi1Mask);
                    v32Mask      = B->BITCAST(v32Mask, vSrc->getType());
                    Value* src0  = B->EXTRACT_16(vSrc, 0);
                    Value* src1  = B->EXTRACT_16(vSrc, 1);

                    Value* indices0 = B->EXTRACT_16(vi32Indices, 0);
                    Value* indices1 = B->EXTRACT_16(vi32Indices, 1);

                    Value* mask0 = B->EXTRACT_16(v32Mask, 0);
                    Value* mask1 = B->EXTRACT_16(v32Mask, 1);

                    Value* gather0 =
                        B->CALL(pX86IntrinFunc, {src0, pBase, indices0, mask0, i8Scale});
                    Value* gather1 =
                        B->CALL(pX86IntrinFunc, {src1, pBase, indices1, mask1, i8Scale});

                    v32Gather = B->JOIN_16(gather0, gather1);
                }
            }
        }
        else if (arch == AVX512)
        {
            Value*    iMask          = nullptr;
            Function* pX86IntrinFunc = nullptr;
            if (srcTy == B->mFP32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dps_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            }
            else if (srcTy == B->mInt32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dpi_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            }
            else if (srcTy == B->mDoubleTy)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dpd_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt8Ty);
            }
            else
            {
                SWR_ASSERT(false, "Unsupported vector element type for gather.");
            }

            auto i32Scale = B->Z_EXT(i8Scale, B->mInt32Ty);
            v32Gather     = B->CALL(pX86IntrinFunc, {vSrc, pBase, vi32Indices, iMask, i32Scale});
        }

        return cast<Instruction>(v32Gather);
    }
    Instruction*
    VSCATTER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        Builder* B           = pThis->B;
        auto     pBase       = pCallInst->getArgOperand(0);
        auto     vi1Mask     = pCallInst->getArgOperand(1);
        auto     vi32Indices = pCallInst->getArgOperand(2);
        auto     v32Src      = pCallInst->getArgOperand(3);
        auto     i32Scale    = pCallInst->getArgOperand(4);

        if (arch != AVX512)
        {
            // Call into C function to do the scatter. This has significantly better compile perf
            // compared to jitting scatter loops for every scatter
            if (width == W256)
            {
                auto mask = B->BITCAST(vi1Mask, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, vi32Indices, v32Src, mask, i32Scale});
            }
            else
            {
                // Need to break up 512 wide scatter to two 256 wide
                auto maskLo = B->VSHUFFLE(vi1Mask, vi1Mask, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                auto indicesLo =
                    B->VSHUFFLE(vi32Indices, vi32Indices, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                auto srcLo = B->VSHUFFLE(v32Src, v32Src, B->C({0, 1, 2, 3, 4, 5, 6, 7}));

                auto mask = B->BITCAST(maskLo, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, indicesLo, srcLo, mask, i32Scale});

                auto maskHi = B->VSHUFFLE(vi1Mask, vi1Mask, B->C({8, 9, 10, 11, 12, 13, 14, 15}));
                auto indicesHi =
                    B->VSHUFFLE(vi32Indices, vi32Indices, B->C({8, 9, 10, 11, 12, 13, 14, 15}));
                auto srcHi = B->VSHUFFLE(v32Src, v32Src, B->C({8, 9, 10, 11, 12, 13, 14, 15}));

                mask = B->BITCAST(maskHi, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, indicesHi, srcHi, mask, i32Scale});
            }
            return nullptr;
        }

        Value*    iMask;
        Function* pX86IntrinFunc;
        if (width == W256)
        {
            // No direct intrinsic supported in llvm to scatter 8 elem with 32bit indices, but we
            // can use the scatter of 8 elements with 64bit indices
            pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                       Intrinsic::x86_avx512_scatter_qps_512);

            auto vi32IndicesExt = B->Z_EXT(vi32Indices, B->mSimdInt64Ty);
            iMask               = B->BITCAST(vi1Mask, B->mInt8Ty);
            B->CALL(pX86IntrinFunc, {pBase, iMask, vi32IndicesExt, v32Src, i32Scale});
        }
        else if (width == W512)
        {
            pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                       Intrinsic::x86_avx512_scatter_dps_512);
            iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            B->CALL(pX86IntrinFunc, {pBase, iMask, vi32Indices, v32Src, i32Scale});
        }
        return nullptr;
    }

    // No support for vroundps in avx512 (it is available in kncni), so emulate with avx
    // instructions
    Instruction*
    VROUND_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B       = pThis->B;
        auto vf32Src = pCallInst->getOperand(0);
        assert(vf32Src);
        auto i8Round = pCallInst->getOperand(1);
        assert(i8Round);
        auto pfnFunc =
            Intrinsic::getDeclaration(B->JM()->mpCurrentModule, Intrinsic::x86_avx_round_ps_256);

        if (width == W256)
        {
            return cast<Instruction>(B->CALL2(pfnFunc, vf32Src, i8Round));
        }
        else if (width == W512)
        {
            auto v8f32SrcLo = B->EXTRACT_16(vf32Src, 0);
            auto v8f32SrcHi = B->EXTRACT_16(vf32Src, 1);

            auto v8f32ResLo = B->CALL2(pfnFunc, v8f32SrcLo, i8Round);
            auto v8f32ResHi = B->CALL2(pfnFunc, v8f32SrcHi, i8Round);

            return cast<Instruction>(B->JOIN_16(v8f32ResLo, v8f32ResHi));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
        }

        return nullptr;
    }

    Instruction*
    VCONVERT_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B       = pThis->B;
        auto vf32Src = pCallInst->getOperand(0);

        if (width == W256)
        {
            auto vf32SrcRound = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                          Intrinsic::x86_avx_round_ps_256);
            return cast<Instruction>(B->FP_TRUNC(vf32SrcRound, B->mFP32Ty));
        }
        else if (width == W512)
        {
            // 512 can use intrinsic
            auto pfnFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                     Intrinsic::x86_avx512_mask_cvtpd2ps_512);
            return cast<Instruction>(B->CALL(pfnFunc, vf32Src));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
        }

        return nullptr;
    }

    // No support for hsub in AVX512
    Instruction* VHSUB_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B    = pThis->B;
        auto src0 = pCallInst->getOperand(0);
        auto src1 = pCallInst->getOperand(1);

        // 256b hsub can just use avx intrinsic
        if (width == W256)
        {
            auto pX86IntrinFunc =
                Intrinsic::getDeclaration(B->JM()->mpCurrentModule, Intrinsic::x86_avx_hsub_ps_256);
            return cast<Instruction>(B->CALL2(pX86IntrinFunc, src0, src1));
        }
        else if (width == W512)
        {
            // 512b hsub can be accomplished with shuf/sub combo
            auto minuend    = B->VSHUFFLE(src0, src1, B->C({0, 2, 8, 10, 4, 6, 12, 14}));
            auto subtrahend = B->VSHUFFLE(src0, src1, B->C({1, 3, 9, 11, 5, 7, 13, 15}));
            return cast<Instruction>(B->SUB(minuend, subtrahend));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
            return nullptr;
        }
    }

    // Double pump the input using the given 256-wide intrinsic. This blindly extracts the lower
    // and upper 256 from each vector argument, calls the 256 wide intrinsic on each half, then
    // merges the results to 512 wide.
    Instruction* DOUBLE_EMU(LowerX86*     pThis,
                            TargetArch    arch,
                            TargetWidth   width,
                            CallInst*     pCallInst,
                            Intrinsic::ID intrin)
    {
        auto B = pThis->B;
        SWR_ASSERT(width == W512);
        Value*    result[2];
        Function* pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule, intrin);
        for (uint32_t i = 0; i < 2; ++i)
        {
            SmallVector<Value*, 8> args;
            for (auto& arg : pCallInst->arg_operands())
            {
                auto argType = arg.get()->getType();
                if (argType->isVectorTy())
                {
#if LLVM_VERSION_MAJOR >= 11
                    uint32_t vecWidth = cast<VectorType>(argType)->getNumElements();
                    auto     elemTy   = cast<VectorType>(argType)->getElementType();
#else
                    uint32_t vecWidth = argType->getVectorNumElements();
                    auto     elemTy   = argType->getVectorElementType();
#endif
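                    // Pull out the lower half of the lanes on the first pass (i == 0) and the
                    // upper half on the second pass (i == 1).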
                    Value* lanes     = B->CInc<int>(i * vecWidth / 2, vecWidth / 2);
                    Value* argToPush = B->VSHUFFLE(arg.get(), B->VUNDEF(elemTy, vecWidth), lanes);
                    args.push_back(argToPush);
                }
                else
                {
                    args.push_back(arg.get());
                }
            }
            result[i] = B->CALLA(pX86IntrinFunc, args);
        }
        uint32_t vecWidth;
        if (result[0]->getType()->isVectorTy())
        {
            assert(result[1]->getType()->isVectorTy());
#if LLVM_VERSION_MAJOR >= 11
            vecWidth = cast<VectorType>(result[0]->getType())->getNumElements() +
                       cast<VectorType>(result[1]->getType())->getNumElements();
#else
            vecWidth = result[0]->getType()->getVectorNumElements() +
                       result[1]->getType()->getVectorNumElements();
#endif
        }
        else
        {
            vecWidth = 2;
        }
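        // Merge the two half-width results back into a single full-width vector.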
        Value* lanes = B->CInc<int>(0, vecWidth);
        return cast<Instruction>(B->VSHUFFLE(result[0], result[1], lanes));
    }

} // namespace SwrJit

using namespace SwrJit;

INITIALIZE_PASS_BEGIN(LowerX86, "LowerX86", "LowerX86", false, false)
INITIALIZE_PASS_END(LowerX86, "LowerX86", "LowerX86", false, false)