Implement pow/rcp and sub opcodes
authorZack Rusin <zack@tungstengraphics.com>
Wed, 17 Oct 2007 17:27:25 +0000 (13:27 -0400)
committerZack Rusin <zack@tungstengraphics.com>
Wed, 24 Oct 2007 15:21:04 +0000 (11:21 -0400)
src/mesa/pipe/llvm/instructions.cpp
src/mesa/pipe/llvm/instructions.h
src/mesa/pipe/llvm/llvmtgsi.cpp

index 147a1b64f21491708450fa1bc9d49a740686a416..31729d0f5861102a1d59c40d1ccab31519bc7525 100644 (file)
@@ -12,8 +12,10 @@ Instructions::Instructions(llvm::Module *mod, llvm::BasicBlock *block)
    :  m_mod(mod), m_block(block), m_idx(0)
 {
    m_floatVecType = VectorType::get(Type::FloatTy, 4);
+
    m_llvmFSqrt = 0;
    m_llvmFAbs = 0;
+   m_llvmPow = 0;
 }
 
 llvm::Value * Instructions::add(llvm::Value *in1, llvm::Value *in2)
@@ -160,3 +162,67 @@ llvm::Value * Instructions::lit(llvm::Value *in1)
    return in1;
 }
 
+llvm::Value * Instructions::sub(llvm::Value *in1, llvm::Value *in2)
+{
+   BinaryOperator *res = BinaryOperator::create(Instruction::Sub, in1, in2,
+                                                name("sub"),
+                                                m_block);
+   return res;
+}
+
+llvm::Value * Instructions::callPow(llvm::Value *val1, llvm::Value *val2)
+{
+   if (!m_llvmPow) {
+      // predeclare the intrinsic
+      std::vector<const Type*> powArgs;
+      powArgs.push_back(Type::FloatTy);
+      powArgs.push_back(Type::FloatTy);
+      ParamAttrsList *powPal = 0;
+      FunctionType* powType = FunctionType::get(
+         /*Result=*/Type::FloatTy,
+         /*Params=*/powArgs,
+         /*isVarArg=*/false,
+         /*ParamAttrs=*/powPal);
+      m_llvmPow = new Function(
+         /*Type=*/powType,
+         /*Linkage=*/GlobalValue::ExternalLinkage,
+         /*Name=*/"llvm.pow.f32", m_mod);
+      m_llvmPow->setCallingConv(CallingConv::C);
+   }
+   std::vector<Value*> params;
+   params.push_back(val1);
+   params.push_back(val2);
+   CallInst *call = new CallInst(m_llvmPow, params.begin(), params.end(),
+                                 name("pow"),
+                                 m_block);
+   call->setCallingConv(CallingConv::C);
+   call->setTailCall(false);
+   return call;
+}
+
+llvm::Value * Instructions::pow(llvm::Value *in1, llvm::Value *in2)
+{
+   ExtractElementInst *x1 = new ExtractElementInst(in1, unsigned(0),
+                                                   name("x1"),
+                                                   m_block);
+   ExtractElementInst *x2 = new ExtractElementInst(in2, unsigned(0),
+                                                   name("x2"),
+                                                   m_block);
+   llvm::Value *val = callPow(x1, x2);
+   return vectorFromVals(val, val, val, val);
+}
+
+llvm::Value * Instructions::rcp(llvm::Value *in1)
+{
+   ExtractElementInst *x1 = new ExtractElementInst(in1, unsigned(0),
+                                                   name("x1"),
+                                                   m_block);
+   BinaryOperator *res = BinaryOperator::create(Instruction::FDiv,
+                                                ConstantFP::get(Type::FloatTy,
+                                                                APFloat(1.f)),
+                                                x1,
+                                                name("rcp"),
+                                                m_block);
+   return vectorFromVals(res, res, res, res);
+}
+
index c6e77710ea0ed8f20f159053ac48f42a2b4d4271..18b5f91131ec0a11332dfa07cb55fd37e4ad3034 100644 (file)
@@ -20,12 +20,16 @@ public:
    llvm::Value *madd(llvm::Value *in1, llvm::Value *in2,
                      llvm::Value *in2);
    llvm::Value *mul(llvm::Value *in1, llvm::Value *in2);
+   llvm::Value *pow(llvm::Value *in1, llvm::Value *in2);
+   llvm::Value *rcp(llvm::Value *in1);
    llvm::Value *rsq(llvm::Value *in1);
+   llvm::Value *sub(llvm::Value *in1, llvm::Value *in2);
 private:
    const char *name(const char *prefix);
 
    llvm::Value *callFSqrt(llvm::Value *val);
    llvm::Value *callFAbs(llvm::Value *val);
+   llvm::Value *callPow(llvm::Value *val1, llvm::Value *val2);
 
    llvm::Value *vectorFromVals(llvm::Value *x, llvm::Value *y,
                                llvm::Value *z, llvm::Value *w=0);
@@ -34,10 +38,12 @@ private:
    char        m_name[32];
    llvm::BasicBlock *m_block;
    int               m_idx;
-   llvm::Function   *m_llvmFSqrt;
-   llvm::Function   *m_llvmFAbs;
 
    llvm::VectorType *m_floatVecType;
+
+   llvm::Function   *m_llvmFSqrt;
+   llvm::Function   *m_llvmFAbs;
+   llvm::Function   *m_llvmPow;
 };
 
 #endif
index c934c002f08c2dbdc009e7cc4d55dcfeedf4c0fd..45114abe4e32209ed088221c7ce3d6b370a391b9 100644 (file)
@@ -125,6 +125,7 @@ translate_instruction(llvm::Module *module,
                       struct tgsi_full_instruction *fi)
 {
    llvm::Value *inputs[4];
+   printf("translate instr START\n");
    for (int i = 0; i < inst->Instruction.NumSrcRegs; ++i) {
       struct tgsi_full_src_register *src = &inst->FullSrcRegisters[i];
       llvm::Value *val = 0;
@@ -136,6 +137,7 @@ translate_instruction(llvm::Module *module,
          val = storage->tempElement(src->SrcRegister.Index);
       } else {
          fprintf(stderr, "ERROR: not support llvm source\n");
+         printf("translate instr END\n");
          return;
       }
 
@@ -154,6 +156,9 @@ translate_instruction(llvm::Module *module,
                  src->SrcRegister.SwizzleY != TGSI_SWIZZLE_Y ||
                  src->SrcRegister.SwizzleZ != TGSI_SWIZZLE_Z ||
                  src->SrcRegister.SwizzleW != TGSI_SWIZZLE_W) {
+         fprintf(stderr, "SWIZZLE is %d %d %d %d\n",
+                 src->SrcRegister.SwizzleX, src->SrcRegister.SwizzleY,
+                 src->SrcRegister.SwizzleZ, src->SrcRegister.SwizzleW);
          int swizzle = src->SrcRegister.SwizzleX * 1000;
          swizzle += src->SrcRegister.SwizzleY  * 100;
          swizzle += src->SrcRegister.SwizzleZ  * 10;
@@ -176,7 +181,9 @@ translate_instruction(llvm::Module *module,
       return;
    }
       break;
-   case TGSI_OPCODE_RCP:
+   case TGSI_OPCODE_RCP: {
+      out = instr->rcp(inputs[0]);
+   }
       break;
    case TGSI_OPCODE_RSQ: {
       out = instr->rsq(inputs[0]);
@@ -214,7 +221,9 @@ translate_instruction(llvm::Module *module,
       out = instr->madd(inputs[0], inputs[1], inputs[2]);
    }
       break;
-   case TGSI_OPCODE_SUB:
+   case TGSI_OPCODE_SUB: {
+      out = instr->sub(inputs[0], inputs[1]);
+   }
       break;
    case TGSI_OPCODE_LERP:
       break;
@@ -240,7 +249,9 @@ translate_instruction(llvm::Module *module,
       break;
    case TGSI_OPCODE_LOGBASE2:
       break;
-   case TGSI_OPCODE_POWER:
+   case TGSI_OPCODE_POWER: {
+      out = instr->pow(inputs[0], inputs[1]);
+   }
       break;
    case TGSI_OPCODE_CROSSPRODUCT:
       break;
@@ -449,6 +460,7 @@ translate_instruction(llvm::Module *module,
    case TGSI_OPCODE_KIL:
       break;
    case TGSI_OPCODE_END:
+      printf("translate instr END\n");
       return;
       break;
    default:
@@ -481,6 +493,7 @@ translate_instruction(llvm::Module *module,
       struct tgsi_full_dst_register *dst = &inst->FullDstRegisters[i];
 
       if (dst->DstRegister.File == TGSI_FILE_OUTPUT) {
+         printf("--- storing to %d %p\n", dst->DstRegister.Index, out);
          storage->store(dst->DstRegister.Index, out);
       } else if (dst->DstRegister.File == TGSI_FILE_TEMPORARY) {
          storage->setTempElement(dst->DstRegister.Index, out);
@@ -501,6 +514,7 @@ translate_instruction(llvm::Module *module,
       }
 #endif
    }
+   printf("translate instr END\n");
 }