swr/rasterizer: Better implementation of scatter
[mesa.git] / src/gallium/drivers/swr/rasterizer/jitter/functionpasses/lower_x86.cpp
/****************************************************************************
* Copyright (C) 2014-2018 Intel Corporation. All Rights Reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice (including the next
* paragraph) shall be included in all copies or substantial portions of the
* Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*
* @file lower_x86.cpp
*
* @brief llvm pass to lower meta code to x86
*
* Notes:
*
******************************************************************************/

#include "jit_pch.hpp"
#include "passes.h"
#include "JitManager.h"

#include "common/simdlib.hpp"

#include <unordered_map>

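// C helper that performs an 8-wide masked scatter; used by VSCATTER_EMU below for
// targets without native AVX512 scatter support. Its address is registered with the
// JIT's dynamic library symbols in the LowerX86 constructor so generated calls resolve.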
extern "C" void ScatterPS_256(uint8_t*, SIMD256::Integer, SIMD256::Float, uint8_t, uint32_t);

namespace llvm
{
    // forward declare the initializer
    void initializeLowerX86Pass(PassRegistry&);
} // namespace llvm

namespace SwrJit
{
    using namespace llvm;

    enum TargetArch
    {
        AVX    = 0,
        AVX2   = 1,
        AVX512 = 2
    };

    enum TargetWidth
    {
        W256       = 0,
        W512       = 1,
        NUM_WIDTHS = 2
    };

    struct LowerX86;

    typedef std::function<Instruction*(LowerX86*, TargetArch, TargetWidth, CallInst*)> EmuFunc;

    struct X86Intrinsic
    {
        Intrinsic::ID intrin[NUM_WIDTHS];
        EmuFunc       emuFunc;
    };

    // Map of intrinsics that haven't been moved to the new mechanism yet. If used, these get the
    // previous behavior of mapping directly to avx/avx2 intrinsics.
    static std::map<std::string, Intrinsic::ID> intrinsicMap = {
        {"meta.intrinsic.BEXTR_32", Intrinsic::x86_bmi_bextr_32},
        {"meta.intrinsic.VPSHUFB", Intrinsic::x86_avx2_pshuf_b},
        {"meta.intrinsic.VCVTPS2PH", Intrinsic::x86_vcvtps2ph_256},
        {"meta.intrinsic.VPTESTC", Intrinsic::x86_avx_ptestc_256},
        {"meta.intrinsic.VPTESTZ", Intrinsic::x86_avx_ptestz_256},
        {"meta.intrinsic.VPHADDD", Intrinsic::x86_avx2_phadd_d},
        {"meta.intrinsic.PDEP32", Intrinsic::x86_bmi_pdep_32},
        {"meta.intrinsic.RDTSC", Intrinsic::x86_rdtsc},
    };

    // Forward decls
    Instruction* NO_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VPERM_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VGATHER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VSCATTER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VROUND_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VHSUB_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VCONVERT_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);

    Instruction* DOUBLE_EMU(LowerX86* pThis,
                            TargetArch    arch,
                            TargetWidth   width,
                            CallInst*     pCallInst,
                            Intrinsic::ID intrin);

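    // Sentinel intrinsic ID: table entries set to DOUBLE have their 512-wide form emulated
    // by double pumping the corresponding 256-wide intrinsic (see DOUBLE_EMU).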
    static Intrinsic::ID DOUBLE = (Intrinsic::ID)-1;

    // clang-format off
    static std::map<std::string, X86Intrinsic> intrinsicMap2[] = {
        //                                256 wide                                     512 wide
        {
            // AVX
            {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx_rcp_ps_256, DOUBLE}, NO_EMU}},
            {"meta.intrinsic.VPERMPS",    {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VPERM_EMU}},
            {"meta.intrinsic.VPERMD",     {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VPERM_EMU}},
            {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VGATHER_EMU}},
            {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VGATHER_EMU}},
            {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VGATHER_EMU}},
            {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VSCATTER_EMU}},
            {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx_cvt_pd2_ps_256, Intrinsic::not_intrinsic}, NO_EMU}},
            {"meta.intrinsic.VCVTPH2PS",  {{Intrinsic::x86_vcvtph2ps_256, Intrinsic::not_intrinsic}, NO_EMU}},
            {"meta.intrinsic.VROUND",     {{Intrinsic::x86_avx_round_ps_256, DOUBLE}, NO_EMU}},
            {"meta.intrinsic.VHSUBPS",    {{Intrinsic::x86_avx_hsub_ps_256, DOUBLE}, NO_EMU}},
        },
        {
            // AVX2
            {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx_rcp_ps_256, DOUBLE}, NO_EMU}},
            {"meta.intrinsic.VPERMPS",    {{Intrinsic::x86_avx2_permps, Intrinsic::not_intrinsic}, VPERM_EMU}},
            {"meta.intrinsic.VPERMD",     {{Intrinsic::x86_avx2_permd, Intrinsic::not_intrinsic}, VPERM_EMU}},
            {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VGATHER_EMU}},
            {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VGATHER_EMU}},
            {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VGATHER_EMU}},
            {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VSCATTER_EMU}},
            {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx_cvt_pd2_ps_256, DOUBLE}, NO_EMU}},
            {"meta.intrinsic.VCVTPH2PS",  {{Intrinsic::x86_vcvtph2ps_256, Intrinsic::not_intrinsic}, NO_EMU}},
            {"meta.intrinsic.VROUND",     {{Intrinsic::x86_avx_round_ps_256, DOUBLE}, NO_EMU}},
            {"meta.intrinsic.VHSUBPS",    {{Intrinsic::x86_avx_hsub_ps_256, DOUBLE}, NO_EMU}},
        },
        {
            // AVX512
            {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx512_rcp14_ps_256, Intrinsic::x86_avx512_rcp14_ps_512}, NO_EMU}},
#if LLVM_VERSION_MAJOR < 7
            {"meta.intrinsic.VPERMPS",    {{Intrinsic::x86_avx512_mask_permvar_sf_256, Intrinsic::x86_avx512_mask_permvar_sf_512}, NO_EMU}},
            {"meta.intrinsic.VPERMD",     {{Intrinsic::x86_avx512_mask_permvar_si_256, Intrinsic::x86_avx512_mask_permvar_si_512}, NO_EMU}},
#else
            {"meta.intrinsic.VPERMPS",    {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VPERM_EMU}},
            {"meta.intrinsic.VPERMD",     {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VPERM_EMU}},
#endif
            {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VGATHER_EMU}},
            {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VGATHER_EMU}},
            {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VGATHER_EMU}},
            {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VSCATTER_EMU}},
#if LLVM_VERSION_MAJOR < 7
            {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx512_mask_cvtpd2ps_256, Intrinsic::x86_avx512_mask_cvtpd2ps_512}, NO_EMU}},
#else
            {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VCONVERT_EMU}},
#endif
            {"meta.intrinsic.VCVTPH2PS",  {{Intrinsic::x86_avx512_mask_vcvtph2ps_256, Intrinsic::x86_avx512_mask_vcvtph2ps_512}, NO_EMU}},
            {"meta.intrinsic.VROUND",     {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VROUND_EMU}},
            {"meta.intrinsic.VHSUBPS",    {{Intrinsic::not_intrinsic, Intrinsic::not_intrinsic}, VHSUB_EMU}},
        }};
    // clang-format on

    struct LowerX86 : public FunctionPass
    {
        LowerX86(Builder* b = nullptr) : FunctionPass(ID), B(b)
        {
            initializeLowerX86Pass(*PassRegistry::getPassRegistry());

            // Determine target arch
            if (JM()->mArch.AVX512F())
            {
                mTarget = AVX512;
            }
            else if (JM()->mArch.AVX2())
            {
                mTarget = AVX2;
            }
            else if (JM()->mArch.AVX())
            {
                mTarget = AVX;
            }
            else
            {
                SWR_ASSERT(false, "Unsupported AVX architecture.");
                mTarget = AVX;
            }

            // Setup scatter function for 256 wide
            uint32_t curWidth = B->mVWidth;
            B->SetTargetWidth(8);
            std::vector<Type*> args = {
                B->mInt8PtrTy,   // pBase
                B->mSimdInt32Ty, // vIndices
                B->mSimdFP32Ty,  // vSrc
                B->mInt8Ty,      // mask
                B->mInt32Ty      // scale
            };

            FunctionType* pfnScatterTy = FunctionType::get(B->mVoidTy, args, false);
            mPfnScatter256 = cast<Function>(
                B->JM()->mpCurrentModule->getOrInsertFunction("ScatterPS_256", pfnScatterTy));
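            // Register the helper's address with the JIT so the generated calls to
            // ScatterPS_256 can be resolved at link time.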
            if (sys::DynamicLibrary::SearchForAddressOfSymbol("ScatterPS_256") == nullptr)
            {
                sys::DynamicLibrary::AddSymbol("ScatterPS_256", (void*)&ScatterPS_256);
            }

            B->SetTargetWidth(curWidth);
        }

        // Try to decipher the vector type of the instruction. This does not work properly
        // across all intrinsics, and will have to be rethought. Probably need something
        // similar to llvm's getDeclaration() utility to map a set of inputs to a specific typed
        // intrinsic.
        void GetRequestedWidthAndType(CallInst*       pCallInst,
                                      const StringRef intrinName,
                                      TargetWidth*    pWidth,
                                      Type**          pTy)
        {
            Type* pVecTy = pCallInst->getType();

            // Check for intrinsic specific types
            // VCVTPD2PS type comes from src, not dst
            if (intrinName.equals("meta.intrinsic.VCVTPD2PS"))
            {
                pVecTy = pCallInst->getOperand(0)->getType();
            }

            if (!pVecTy->isVectorTy())
            {
                for (auto& op : pCallInst->arg_operands())
                {
                    if (op.get()->getType()->isVectorTy())
                    {
                        pVecTy = op.get()->getType();
                        break;
                    }
                }
            }
            SWR_ASSERT(pVecTy->isVectorTy(), "Couldn't determine vector size");

            uint32_t width = cast<VectorType>(pVecTy)->getBitWidth();
            switch (width)
            {
            case 256:
                *pWidth = W256;
                break;
            case 512:
                *pWidth = W512;
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width %d", width);
                *pWidth = W256;
            }

            *pTy = pVecTy->getScalarType();
        }

        Value* GetZeroVec(TargetWidth width, Type* pTy)
        {
            uint32_t numElem = 0;
            switch (width)
            {
            case W256:
                numElem = 8;
                break;
            case W512:
                numElem = 16;
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width type %d\n", width);
            }

            return ConstantVector::getNullValue(VectorType::get(pTy, numElem));
        }

        Value* GetMask(TargetWidth width)
        {
            Value* mask;
            switch (width)
            {
            case W256:
                mask = B->C((uint8_t)-1);
                break;
            case W512:
                mask = B->C((uint16_t)-1);
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width type %d\n", width);
            }
            return mask;
        }

        // Convert <N x i1> mask to <N x i32> x86 mask
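        // The AVX2 gather forms only examine the sign bit of each 32-bit lane, so a
        // sign extension yields the required all-ones / all-zeros lanes.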
        Value* VectorMask(Value* vi1Mask)
        {
            uint32_t numElem = vi1Mask->getType()->getVectorNumElements();
            return B->S_EXT(vi1Mask, VectorType::get(B->mInt32Ty, numElem));
        }

        Instruction* ProcessIntrinsicAdvanced(CallInst* pCallInst)
        {
            Function*   pFunc     = pCallInst->getCalledFunction();
            auto&       intrinsic = intrinsicMap2[mTarget][pFunc->getName()];
            TargetWidth vecWidth;
            Type*       pElemTy;
            GetRequestedWidthAndType(pCallInst, pFunc->getName(), &vecWidth, &pElemTy);

            // Check if there is a native intrinsic for this instruction
            Intrinsic::ID id = intrinsic.intrin[vecWidth];
            if (id == DOUBLE)
            {
                // Double pump the next smaller SIMD intrinsic
                SWR_ASSERT(vecWidth != 0, "Cannot double pump smallest SIMD width.");
                Intrinsic::ID id2 = intrinsic.intrin[vecWidth - 1];
                SWR_ASSERT(id2 != Intrinsic::not_intrinsic,
                           "Cannot find intrinsic to double pump.");
                return DOUBLE_EMU(this, mTarget, vecWidth, pCallInst, id2);
            }
            else if (id != Intrinsic::not_intrinsic)
            {
                Function* pIntrin = Intrinsic::getDeclaration(B->JM()->mpCurrentModule, id);
                SmallVector<Value*, 8> args;
                for (auto& arg : pCallInst->arg_operands())
                {
                    args.push_back(arg.get());
                }

                // If AVX512, all instructions add a src operand and mask. We'll pass in a zero
                // src and a full mask for now, assuming the intrinsics are consistent and place
                // the src operand and mask last in the argument list.
                if (mTarget == AVX512)
                {
                    if (pFunc->getName().equals("meta.intrinsic.VCVTPD2PS"))
                    {
                        args.push_back(GetZeroVec(W256, pCallInst->getType()->getScalarType()));
                        args.push_back(GetMask(W256));
                        // for AVX512 VCVTPD2PS, we also have to add rounding mode
                        args.push_back(B->C(_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
                    }
                    else
                    {
                        args.push_back(GetZeroVec(vecWidth, pElemTy));
                        args.push_back(GetMask(vecWidth));
                    }
                }

                return B->CALLA(pIntrin, args);
            }
            else
            {
                // No native intrinsic, call emulation function
                return intrinsic.emuFunc(this, mTarget, vecWidth, pCallInst);
            }

            SWR_ASSERT(false);
            return nullptr;
        }

        Instruction* ProcessIntrinsic(CallInst* pCallInst)
        {
            Function* pFunc = pCallInst->getCalledFunction();

            // Forward to the advanced support if found
            if (intrinsicMap2[mTarget].find(pFunc->getName()) != intrinsicMap2[mTarget].end())
            {
                return ProcessIntrinsicAdvanced(pCallInst);
            }

            SWR_ASSERT(intrinsicMap.find(pFunc->getName()) != intrinsicMap.end(),
                       "Unimplemented intrinsic %s.",
                       pFunc->getName());

            Intrinsic::ID x86Intrinsic = intrinsicMap[pFunc->getName()];
            Function*     pX86IntrinFunc =
                Intrinsic::getDeclaration(B->JM()->mpCurrentModule, x86Intrinsic);

            SmallVector<Value*, 8> args;
            for (auto& arg : pCallInst->arg_operands())
            {
                args.push_back(arg.get());
            }
            return B->CALLA(pX86IntrinFunc, args);
        }

        //////////////////////////////////////////////////////////////////////////
        /// @brief LLVM function pass run method.
        /// @param F - The function we're working on with this pass.
        virtual bool runOnFunction(Function& F)
        {
            std::vector<Instruction*> toRemove;
            std::vector<BasicBlock*>  bbs;

            // Make temp copy of the basic blocks and instructions, as the intrinsic
            // replacement code might invalidate the iterators
            for (auto& b : F.getBasicBlockList())
            {
                bbs.push_back(&b);
            }

            for (auto* BB : bbs)
            {
                std::vector<Instruction*> insts;
                for (auto& i : BB->getInstList())
                {
                    insts.push_back(&i);
                }

                for (auto* I : insts)
                {
                    if (CallInst* pCallInst = dyn_cast<CallInst>(I))
                    {
                        Function* pFunc = pCallInst->getCalledFunction();
                        if (pFunc)
                        {
                            if (pFunc->getName().startswith("meta.intrinsic"))
                            {
                                B->IRB()->SetInsertPoint(I);
                                Instruction* pReplace = ProcessIntrinsic(pCallInst);
                                toRemove.push_back(pCallInst);
                                if (pReplace)
                                {
                                    pCallInst->replaceAllUsesWith(pReplace);
                                }
                            }
                        }
                    }
                }
            }

            for (auto* pInst : toRemove)
            {
                pInst->eraseFromParent();
            }

            JitManager::DumpToFile(&F, "lowerx86");

            return true;
        }

        virtual void getAnalysisUsage(AnalysisUsage& AU) const {}

        JitManager* JM() { return B->JM(); }
        Builder*    B;
        TargetArch  mTarget;
        Function*   mPfnScatter256;

        static char ID; ///< Needed by LLVM to generate ID for FunctionPass.
    };

    char LowerX86::ID = 0; // LLVM uses address of ID as the actual ID.

    FunctionPass* createLowerX86Pass(Builder* b) { return new LowerX86(b); }

    Instruction* NO_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(false, "Unimplemented intrinsic emulation.");
        return nullptr;
    }

    Instruction* VPERM_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        // Only need vperm emulation for AVX
        SWR_ASSERT(arch == AVX);

        Builder* B = pThis->B;
        auto     v32A      = pCallInst->getArgOperand(0);
        auto     vi32Index = pCallInst->getArgOperand(1);

        Value* v32Result;
        if (isa<Constant>(vi32Index))
        {
            // Can use llvm shuffle vector directly with constant shuffle indices
            v32Result = B->VSHUFFLE(v32A, v32A, vi32Index);
        }
        else
        {
            v32Result = UndefValue::get(v32A->getType());
            for (uint32_t l = 0; l < v32A->getType()->getVectorNumElements(); ++l)
            {
                auto i32Index = B->VEXTRACT(vi32Index, B->C(l));
                auto val      = B->VEXTRACT(v32A, i32Index);
                v32Result     = B->VINSERT(v32Result, val, B->C(l));
            }
        }
        return cast<Instruction>(v32Result);
    }

    Instruction*
    VGATHER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        Builder* B           = pThis->B;
        auto     vSrc        = pCallInst->getArgOperand(0);
        auto     pBase       = pCallInst->getArgOperand(1);
        auto     vi32Indices = pCallInst->getArgOperand(2);
        auto     vi1Mask     = pCallInst->getArgOperand(3);
        auto     i8Scale     = pCallInst->getArgOperand(4);

        pBase             = B->POINTER_CAST(pBase, PointerType::get(B->mInt8Ty, 0));
        uint32_t numElem  = vSrc->getType()->getVectorNumElements();
        auto     i32Scale = B->Z_EXT(i8Scale, B->mInt32Ty);
        auto     srcTy    = vSrc->getType()->getVectorElementType();
        Value*   v32Gather;
        if (arch == AVX)
        {
            // Full emulation for AVX
            // Store source on stack to provide a valid address to load from inactive lanes
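            // Inactive lanes reload their own source element, matching the masked-gather
            // semantics of leaving those lanes unchanged.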
            auto pStack = B->STACKSAVE();
            auto pTmp   = B->ALLOCA(vSrc->getType());
            B->STORE(vSrc, pTmp);

            v32Gather        = UndefValue::get(vSrc->getType());
            auto vi32Scale   = ConstantVector::getSplat(numElem, cast<ConstantInt>(i32Scale));
            auto vi32Offsets = B->MUL(vi32Indices, vi32Scale);

            for (uint32_t i = 0; i < numElem; ++i)
            {
                auto i32Offset          = B->VEXTRACT(vi32Offsets, B->C(i));
                auto pLoadAddress       = B->GEP(pBase, i32Offset);
                pLoadAddress            = B->BITCAST(pLoadAddress, PointerType::get(srcTy, 0));
                auto pMaskedLoadAddress = B->GEP(pTmp, {0, i});
                auto i1Mask             = B->VEXTRACT(vi1Mask, B->C(i));
                auto pValidAddress      = B->SELECT(i1Mask, pLoadAddress, pMaskedLoadAddress);
                auto val                = B->LOAD(pValidAddress);
                v32Gather               = B->VINSERT(v32Gather, val, B->C(i));
            }

            B->STACKRESTORE(pStack);
        }
        else if (arch == AVX2 || (arch == AVX512 && width == W256))
        {
            Function* pX86IntrinFunc;
            if (srcTy == B->mFP32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_ps_256);
            }
            else if (srcTy == B->mInt32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_d_256);
            }
            else if (srcTy == B->mDoubleTy)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_q_256);
            }
            else
            {
                SWR_ASSERT(false, "Unsupported vector element type for gather.");
            }

            if (width == W256)
            {
                auto v32Mask = B->BITCAST(pThis->VectorMask(vi1Mask), vSrc->getType());
                v32Gather    = B->CALL(pX86IntrinFunc, {vSrc, pBase, vi32Indices, v32Mask, i8Scale});
            }
            else if (width == W512)
            {
                // Double pump 4-wide for 64bit elements
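                // The d_q gather returns 64-bit elements, so sources and masks are bitcast
                // to i64 vectors around the calls and the two 4-wide results are rejoined.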
                if (vSrc->getType()->getVectorElementType() == B->mDoubleTy)
                {
                    auto v64Mask = pThis->VectorMask(vi1Mask);
                    v64Mask      = B->S_EXT(
                        v64Mask,
                        VectorType::get(B->mInt64Ty, v64Mask->getType()->getVectorNumElements()));
                    v64Mask = B->BITCAST(v64Mask, vSrc->getType());

                    Value* src0 = B->VSHUFFLE(vSrc, vSrc, B->C({0, 1, 2, 3}));
                    Value* src1 = B->VSHUFFLE(vSrc, vSrc, B->C({4, 5, 6, 7}));

                    Value* indices0 = B->VSHUFFLE(vi32Indices, vi32Indices, B->C({0, 1, 2, 3}));
                    Value* indices1 = B->VSHUFFLE(vi32Indices, vi32Indices, B->C({4, 5, 6, 7}));

                    Value* mask0 = B->VSHUFFLE(v64Mask, v64Mask, B->C({0, 1, 2, 3}));
                    Value* mask1 = B->VSHUFFLE(v64Mask, v64Mask, B->C({4, 5, 6, 7}));

                    src0 = B->BITCAST(
                        src0,
                        VectorType::get(B->mInt64Ty, src0->getType()->getVectorNumElements()));
                    mask0 = B->BITCAST(
                        mask0,
                        VectorType::get(B->mInt64Ty, mask0->getType()->getVectorNumElements()));
                    Value* gather0 =
                        B->CALL(pX86IntrinFunc, {src0, pBase, indices0, mask0, i8Scale});
                    src1 = B->BITCAST(
                        src1,
                        VectorType::get(B->mInt64Ty, src1->getType()->getVectorNumElements()));
                    mask1 = B->BITCAST(
                        mask1,
                        VectorType::get(B->mInt64Ty, mask1->getType()->getVectorNumElements()));
                    Value* gather1 =
                        B->CALL(pX86IntrinFunc, {src1, pBase, indices1, mask1, i8Scale});

                    v32Gather = B->VSHUFFLE(gather0, gather1, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                    v32Gather = B->BITCAST(v32Gather, vSrc->getType());
                }
                else
                {
                    // Double pump 8-wide for 32bit elements
                    auto v32Mask = pThis->VectorMask(vi1Mask);
                    v32Mask      = B->BITCAST(v32Mask, vSrc->getType());
                    Value* src0  = B->EXTRACT_16(vSrc, 0);
                    Value* src1  = B->EXTRACT_16(vSrc, 1);

                    Value* indices0 = B->EXTRACT_16(vi32Indices, 0);
                    Value* indices1 = B->EXTRACT_16(vi32Indices, 1);

                    Value* mask0 = B->EXTRACT_16(v32Mask, 0);
                    Value* mask1 = B->EXTRACT_16(v32Mask, 1);

                    Value* gather0 =
                        B->CALL(pX86IntrinFunc, {src0, pBase, indices0, mask0, i8Scale});
                    Value* gather1 =
                        B->CALL(pX86IntrinFunc, {src1, pBase, indices1, mask1, i8Scale});

                    v32Gather = B->JOIN_16(gather0, gather1);
                }
            }
        }
        else if (arch == AVX512)
        {
            Value*    iMask;
            Function* pX86IntrinFunc;
            if (srcTy == B->mFP32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dps_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            }
            else if (srcTy == B->mInt32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dpi_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            }
            else if (srcTy == B->mDoubleTy)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dpd_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt8Ty);
            }
            else
            {
                SWR_ASSERT(false, "Unsupported vector element type for gather.");
            }

            auto i32Scale = B->Z_EXT(i8Scale, B->mInt32Ty);
            v32Gather     = B->CALL(pX86IntrinFunc, {vSrc, pBase, vi32Indices, iMask, i32Scale});
        }

        return cast<Instruction>(v32Gather);
    }

    Instruction*
    VSCATTER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        Builder* B           = pThis->B;
        auto     pBase       = pCallInst->getArgOperand(0);
        auto     vi1Mask     = pCallInst->getArgOperand(1);
        auto     vi32Indices = pCallInst->getArgOperand(2);
        auto     v32Src      = pCallInst->getArgOperand(3);
        auto     i32Scale    = pCallInst->getArgOperand(4);

        if (arch != AVX512)
        {
            // Call into C function to do the scatter. This has significantly better compile perf
            // compared to jitting scatter loops for every scatter
            if (width == W256)
            {
                auto mask = B->BITCAST(vi1Mask, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, vi32Indices, v32Src, mask, i32Scale});
            }
            else
            {
                // Need to break up 512 wide scatter to two 256 wide
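                // A VSHUFFLE with an 8-entry constant index extracts one 256-bit half of
                // each 512-bit operand.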
                auto maskLo = B->VSHUFFLE(vi1Mask, vi1Mask, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                auto indicesLo =
                    B->VSHUFFLE(vi32Indices, vi32Indices, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                auto srcLo = B->VSHUFFLE(v32Src, v32Src, B->C({0, 1, 2, 3, 4, 5, 6, 7}));

                auto mask = B->BITCAST(maskLo, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, indicesLo, srcLo, mask, i32Scale});

                auto maskHi = B->VSHUFFLE(vi1Mask, vi1Mask, B->C({8, 9, 10, 11, 12, 13, 14, 15}));
                auto indicesHi =
                    B->VSHUFFLE(vi32Indices, vi32Indices, B->C({8, 9, 10, 11, 12, 13, 14, 15}));
                auto srcHi = B->VSHUFFLE(v32Src, v32Src, B->C({8, 9, 10, 11, 12, 13, 14, 15}));

                mask = B->BITCAST(maskHi, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, indicesHi, srcHi, mask, i32Scale});
            }
            return nullptr;
        }

        Value*    iMask;
        Function* pX86IntrinFunc;
        if (width == W256)
        {
            // No direct intrinsic supported in llvm to scatter 8 elem with 32bit indices, but we
            // can use the scatter of 8 elements with 64bit indices
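            // The qps scatter stores 8 float elements addressed by 8 64-bit indices, so
            // zero-extending the 32-bit indices preserves the addressing.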
            pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                       Intrinsic::x86_avx512_scatter_qps_512);

            auto vi32IndicesExt = B->Z_EXT(vi32Indices, B->mSimdInt64Ty);
            iMask               = B->BITCAST(vi1Mask, B->mInt8Ty);
            B->CALL(pX86IntrinFunc, {pBase, iMask, vi32IndicesExt, v32Src, i32Scale});
        }
        else if (width == W512)
        {
            pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                       Intrinsic::x86_avx512_scatter_dps_512);
            iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            B->CALL(pX86IntrinFunc, {pBase, iMask, vi32Indices, v32Src, i32Scale});
        }
        return nullptr;
    }

    // No support for vroundps in avx512 (it is available in kncni), so emulate with avx
    // instructions
    Instruction*
    VROUND_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B       = pThis->B;
        auto vf32Src = pCallInst->getOperand(0);
        auto i8Round = pCallInst->getOperand(1);
        auto pfnFunc =
            Intrinsic::getDeclaration(B->JM()->mpCurrentModule, Intrinsic::x86_avx_round_ps_256);

        if (width == W256)
        {
            return cast<Instruction>(B->CALL2(pfnFunc, vf32Src, i8Round));
        }
        else if (width == W512)
        {
            auto v8f32SrcLo = B->EXTRACT_16(vf32Src, 0);
            auto v8f32SrcHi = B->EXTRACT_16(vf32Src, 1);

            auto v8f32ResLo = B->CALL2(pfnFunc, v8f32SrcLo, i8Round);
            auto v8f32ResHi = B->CALL2(pfnFunc, v8f32SrcHi, i8Round);

            return cast<Instruction>(B->JOIN_16(v8f32ResLo, v8f32ResHi));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
        }

        return nullptr;
    }

    Instruction*
    VCONVERT_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B       = pThis->B;
        auto vf32Src = pCallInst->getOperand(0);

        if (width == W256)
        {
            auto vf32SrcRound = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                          Intrinsic::x86_avx_round_ps_256);
            return cast<Instruction>(B->FP_TRUNC(vf32SrcRound, B->mFP32Ty));
        }
        else if (width == W512)
        {
            // 512 can use intrinsic
            auto pfnFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                     Intrinsic::x86_avx512_mask_cvtpd2ps_512);
            return cast<Instruction>(B->CALL(pfnFunc, vf32Src));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
        }

        return nullptr;
    }

    // No support for hsub in AVX512
    Instruction* VHSUB_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B    = pThis->B;
        auto src0 = pCallInst->getOperand(0);
        auto src1 = pCallInst->getOperand(1);

        // 256b hsub can just use avx intrinsic
        if (width == W256)
        {
            auto pX86IntrinFunc =
                Intrinsic::getDeclaration(B->JM()->mpCurrentModule, Intrinsic::x86_avx_hsub_ps_256);
            return cast<Instruction>(B->CALL2(pX86IntrinFunc, src0, src1));
        }
        else if (width == W512)
        {
            // 512b hsub can be accomplished with shuf/sub combo
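            // Adjacent element pairs are subtracted; the shuffle index patterns interleave
            // src0 and src1 to mirror the lane ordering of hsubps.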
            auto minuend    = B->VSHUFFLE(src0, src1, B->C({0, 2, 8, 10, 4, 6, 12, 14}));
            auto subtrahend = B->VSHUFFLE(src0, src1, B->C({1, 3, 9, 11, 5, 7, 13, 15}));
            return cast<Instruction>(B->SUB(minuend, subtrahend));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
            return nullptr;
        }
    }

    // Double pump the input through the given 256-wide intrinsic. This blindly extracts the lower
    // and upper 256 bits of each vector argument, calls the 256 wide intrinsic on each half, then
    // merges the results into a 512 wide value.
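    // Non-vector arguments (e.g. immediate round modes) are forwarded unchanged to both calls.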
    Instruction* DOUBLE_EMU(LowerX86* pThis,
                            TargetArch    arch,
                            TargetWidth   width,
                            CallInst*     pCallInst,
                            Intrinsic::ID intrin)
    {
        auto B = pThis->B;
        SWR_ASSERT(width == W512);
        Value*    result[2];
        Function* pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule, intrin);
        for (uint32_t i = 0; i < 2; ++i)
        {
            SmallVector<Value*, 8> args;
            for (auto& arg : pCallInst->arg_operands())
            {
                auto argType = arg.get()->getType();
                if (argType->isVectorTy())
                {
                    uint32_t vecWidth  = argType->getVectorNumElements();
                    Value*   lanes     = B->CInc<int>(i * vecWidth / 2, vecWidth / 2);
                    Value*   argToPush = B->VSHUFFLE(
                        arg.get(), B->VUNDEF(argType->getVectorElementType(), vecWidth), lanes);
                    args.push_back(argToPush);
                }
                else
                {
                    args.push_back(arg.get());
                }
            }
            result[i] = B->CALLA(pX86IntrinFunc, args);
        }
        uint32_t vecWidth;
        if (result[0]->getType()->isVectorTy())
        {
            assert(result[1]->getType()->isVectorTy());
            vecWidth = result[0]->getType()->getVectorNumElements() +
                       result[1]->getType()->getVectorNumElements();
        }
        else
        {
            vecWidth = 2;
        }
        Value* lanes = B->CInc<int>(0, vecWidth);
        return cast<Instruction>(B->VSHUFFLE(result[0], result[1], lanes));
    }

} // namespace SwrJit

using namespace SwrJit;

INITIALIZE_PASS_BEGIN(LowerX86, "LowerX86", "LowerX86", false, false)
INITIALIZE_PASS_END(LowerX86, "LowerX86", "LowerX86", false, false)