Get basic function calls working in the shaders.
authorZack Rusin <zack@tungstengraphics.com>
Fri, 26 Oct 2007 18:52:10 +0000 (14:52 -0400)
committerZack Rusin <zack@tungstengraphics.com>
Fri, 26 Oct 2007 18:59:38 +0000 (14:59 -0400)
src/mesa/pipe/llvm/instructions.cpp
src/mesa/pipe/llvm/instructions.h
src/mesa/pipe/llvm/llvm_base_shader.cpp
src/mesa/pipe/llvm/llvmtgsi.cpp
src/mesa/pipe/llvm/storage.cpp
src/mesa/pipe/llvm/storage.h

index 3fca522324da8c4f7fed769ff2fa3a98fec6a444..645ab9106f4c74747dc470cb8a64e0aaa611175a 100644 (file)
@@ -32,6 +32,8 @@
 
 #include "instructions.h"
 
+#include "storage.h"
+
 #include <llvm/CallingConv.h>
 #include <llvm/Constants.h>
 #include <llvm/DerivedTypes.h>
 #include <llvm/InstrTypes.h>
 #include <llvm/Instructions.h>
 
-using namespace llvm;
+#include <sstream>
+#include <fstream>
+#include <iostream>
 
+using namespace llvm;
 
 Function* makeLitFunction(Module *mod);
 
+static inline std::string createFuncName(int label)
+{
+   std::ostringstream stream;
+   stream << "function";
+   stream << label;
+   return stream.str();
+}
+
 Instructions::Instructions(llvm::Module *mod, llvm::Function *func, llvm::BasicBlock *block)
    :  m_mod(mod), m_func(func), m_block(block), m_idx(0)
 {
@@ -623,18 +636,18 @@ void Instructions::printVector(llvm::Value *val)
 llvm::Function * Instructions::declarePrintf()
 {
    std::vector<const Type*> args;
-  ParamAttrsList *params = 0;
-  FunctionType* funcTy = FunctionType::get(
-     /*Result=*/IntegerType::get(32),
-     /*Params=*/args,
-     /*isVarArg=*/true,
-     /*ParamAttrs=*/params);
-  Function* func_printf = new Function(
-     /*Type=*/funcTy,
-     /*Linkage=*/GlobalValue::ExternalLinkage,
-     /*Name=*/"printf", m_mod);
-  func_printf->setCallingConv(CallingConv::C);
-  return func_printf;
+   ParamAttrsList *params = 0;
+   FunctionType* funcTy = FunctionType::get(
+      /*Result=*/IntegerType::get(32),
+      /*Params=*/args,
+      /*isVarArg=*/true,
+      /*ParamAttrs=*/params);
+   Function* func_printf = new Function(
+      /*Type=*/funcTy,
+      /*Linkage=*/GlobalValue::ExternalLinkage,
+      /*Name=*/"printf", m_mod);
+   func_printf->setCallingConv(CallingConv::C);
+   return func_printf;
 }
 
 
@@ -822,8 +835,6 @@ Function* makeLitFunction(Module *mod) {
       /*isVarArg=*/false,
       /*ParamAttrs=*/FuncTy_0_PAL);
 
-   PointerType* PointerTy_1 = PointerType::get(FuncTy_0);
-
    VectorType* VectorTy_2 = VectorType::get(Type::FloatTy, 4);
 
    std::vector<const Type*>FuncTy_3_args;
@@ -1085,3 +1096,72 @@ void Instructions::end()
    new ReturnInst(m_block);
 }
 
+void Instructions::cal(int label, llvm::Value *out, llvm::Value *in,
+                    llvm::Value *cst)
+{
+   std::vector<Value*> params;
+   params.push_back(out);
+   params.push_back(in);
+   params.push_back(cst);
+   llvm::Function *func = findFunction(label);
+
+   new CallInst(func, params.begin(), params.end(), std::string(), m_block);
+}
+
+llvm::Function * Instructions::declareFunc(int label)
+{
+   PointerType *vecPtr = PointerType::get(m_floatVecType);
+   std::vector<const Type*> args;
+   args.push_back(vecPtr);
+   args.push_back(vecPtr);
+   args.push_back(vecPtr);
+   ParamAttrsList *params = 0;
+   FunctionType *funcType = FunctionType::get(
+      /*Result=*/Type::VoidTy,
+      /*Params=*/args,
+      /*isVarArg=*/false,
+      /*ParamAttrs=*/params);
+   std::string name = createFuncName(label);
+   Function *func = new Function(
+      /*Type=*/funcType,
+      /*Linkage=*/GlobalValue::ExternalLinkage,
+      /*Name=*/name.c_str(), m_mod);
+   func->setCallingConv(CallingConv::C);
+   return func;
+}
+
+void Instructions::bgnSub(unsigned label, Storage *storage)
+{
+   llvm::Function *func = findFunction(label);
+
+   Function::arg_iterator args = func->arg_begin();
+   Value *ptr_OUT = args++;
+   ptr_OUT->setName("OUT");
+   Value *ptr_IN = args++;
+   ptr_IN->setName("IN");
+   Value *ptr_CONST = args++;
+   ptr_CONST->setName("CONST");
+   storage->pushArguments(ptr_OUT, ptr_IN, ptr_CONST);
+
+   llvm::BasicBlock *entry = new BasicBlock("entry", func, 0);
+
+   m_func = func;
+   m_block = entry;
+}
+
+void Instructions::endSub()
+{
+   m_func = 0;
+   m_block = 0;
+}
+
+llvm::Function * Instructions::findFunction(int label)
+{
+   llvm::Function *func = m_functions[label];
+   if (!func) {
+      func = declareFunc(label);
+      m_functions[label] = func;
+   }
+   return func;
+}
+
index 82d871d41067222716b4eca52778f53fb549ba34..85feb1665db82ea9eec5e744b645e4378092c876 100644 (file)
@@ -37,6 +37,7 @@
 #include <llvm/Module.h>
 #include <llvm/Value.h>
 
+#include <map>
 #include <stack>
 
 namespace llvm {
@@ -44,6 +45,8 @@ namespace llvm {
    class Function;
 }
 
+class Storage;
+
 class Instructions
 {
 public:
@@ -55,7 +58,10 @@ public:
    llvm::Value *arl(llvm::Value *in1);
    llvm::Value *add(llvm::Value *in1, llvm::Value *in2);
    void         beginLoop();
+   void         bgnSub(unsigned, Storage *);
    void         brk();
+   void         cal(int label, llvm::Value *out, llvm::Value *in,
+                    llvm::Value *cst);
    llvm::Value *cross(llvm::Value *in1, llvm::Value *in2);
    llvm::Value *dp3(llvm::Value *in1, llvm::Value *in2);
    llvm::Value *dp4(llvm::Value *in1, llvm::Value *in2);
@@ -65,6 +71,7 @@ public:
    void         endif();
    void         endLoop();
    void         end();
+   void         endSub();
    llvm::Value *ex2(llvm::Value *in);
    llvm::Value *floor(llvm::Value *in);
    llvm::Value *frc(llvm::Value *in);
@@ -101,6 +108,9 @@ private:
                                llvm::Value *z, llvm::Value *w=0);
 
    llvm::Function *declarePrintf();
+   llvm::Function *declareFunc(int label);
+
+   llvm::Function *findFunction(int label);
 private:
    llvm::Module *m_mod;
    llvm::Function *m_func;
@@ -125,6 +135,7 @@ private:
       llvm::BasicBlock *end;
    };
    std::stack<Loop> m_loopStack;
+   std::map<int, llvm::Function*> m_functions;
 };
 
 #endif
index f6fc83be9a6afccba4916b330763cad9d4b847e7..3f058258eebdee967f611e4ad6ea6a0bfc4c6353 100644 (file)
@@ -634,7 +634,7 @@ Module* createBaseShader() {
     BinaryOperator* int32_inc_103 = BinaryOperator::create(Instruction::Add, int32_i_0_reg2mem_0_100, const_int32_21, "inc", label_forbody_71);
     ICmpInst* int1_cmp21 = new ICmpInst(ICmpInst::ICMP_SLT, int32_inc_103, int32_num_vertices, "cmp21", label_forbody_71);
     new BranchInst(label_forbody_71, label_afterfor_72, int1_cmp21, label_forbody_71);
-    
+
     // Block afterfor (label_afterfor_72)
     new ReturnInst(label_afterfor_72);
     
index 6dfd7926fb53420e464d6a506592149e0f65107f..cfeb19e4ba8867cc0c652351fc434ee0452cebf5 100644 (file)
@@ -171,7 +171,8 @@ translate_instruction(llvm::Module *module,
                       Storage *storage,
                       Instructions *instr,
                       struct tgsi_full_instruction *inst,
-                      struct tgsi_full_instruction *fi)
+                      struct tgsi_full_instruction *fi,
+                      unsigned instno)
 {
    llvm::Value *inputs[4];
    inputs[0] = 0;
@@ -400,9 +401,18 @@ translate_instruction(llvm::Module *module,
       break;
    case TGSI_OPCODE_BRA:
       break;
-   case TGSI_OPCODE_CAL:
+   case TGSI_OPCODE_CAL: {
+      instr->cal(inst->InstructionExtLabel.Label,
+                 storage->outputPtr(),
+                 storage->inputPtr(),
+                 storage->constPtr());
+      return;
+   }
       break;
-   case TGSI_OPCODE_RET:
+   case TGSI_OPCODE_RET: {
+      instr->end();
+      return;
+   }
       break;
    case TGSI_OPCODE_SSG:
       break;
@@ -495,15 +505,24 @@ translate_instruction(llvm::Module *module,
       return;
    }
       break;
-   case TGSI_OPCODE_BGNSUB:
+   case TGSI_OPCODE_BGNSUB: {
+      instr->bgnSub(instno, storage);
+      storage->setCurrentBlock(instr->currentBlock());
+      return;
+   }
       break;
    case TGSI_OPCODE_ENDLOOP2: {
       instr->endLoop();
       storage->setCurrentBlock(instr->currentBlock());
+      storage->popArguments();
       return;
    }
       break;
-   case TGSI_OPCODE_ENDSUB:
+   case TGSI_OPCODE_ENDSUB: {
+      instr->endSub();
+      storage->setCurrentBlock(instr->currentBlock());
+      return;
+   }
       break;
    case TGSI_OPCODE_NOISE1:
       break;
@@ -620,7 +639,7 @@ tgsi_to_llvm(struct gallivm_prog *prog, const struct tgsi_token *tokens)
    struct tgsi_parse_context parse;
    struct tgsi_full_instruction fi;
    struct tgsi_full_declaration fd;
-
+   unsigned instno = 0;
    Function* shader = mod->getFunction("execute_shader");
    std::ostringstream stream;
    stream << "execute_shader";
@@ -662,7 +681,8 @@ tgsi_to_llvm(struct gallivm_prog *prog, const struct tgsi_token *tokens)
       case TGSI_TOKEN_TYPE_INSTRUCTION:
          translate_instruction(mod, &storage, &instr,
                                &parse.FullToken.FullInstruction,
-                               &fi);
+                               &fi, instno);
+         ++instno;
          break;
 
       default:
@@ -776,7 +796,7 @@ void gallivm_prog_dump(struct gallivm_prog *prog, const char *file_prefix)
       llvm::Function *func = mod->getFunction(func_name.c_str());
       assert(func);
       std::cout<<"; ---------- Start shader "<<prog->id<<std::endl;
-      std::cout<<*func<<std::endl;
+      std::cout<<*mod<<std::endl;
       std::cout<<"; ---------- End shader "<<prog->id<<std::endl;
    }
 }
index cba719a8becb6ef98f307e82908bd3ff1a9f0628..88ef6711cfb1b1ffb4d149145ddd650178e1f0e6 100644 (file)
@@ -46,7 +46,7 @@
 using namespace llvm;
 
 Storage::Storage(llvm::BasicBlock *block, llvm::Value *out,
-                                         llvm::Value *in, llvm::Value *consts)
+                 llvm::Value *in, llvm::Value *consts)
    : m_block(block), m_OUT(out),
      m_IN(in), m_CONST(consts),
      m_temps(32), m_addrs(32),
@@ -331,3 +331,41 @@ llvm::Value * Storage::outputElement(int idx, llvm::Value *indIdx )
 
    return load;
 }
+
+llvm::Value * Storage::inputPtr() const
+{
+   return m_IN;
+}
+
+llvm::Value * Storage::outputPtr() const
+{
+   return m_OUT;
+}
+
+llvm::Value * Storage::constPtr() const
+{
+   return m_CONST;
+}
+
+void Storage::pushArguments(llvm::Value *out, llvm::Value *in,
+                            llvm::Value *constPtr)
+{
+   Args arg;
+   arg.out = m_OUT;
+   arg.in  = m_IN;
+   arg.cst = m_CONST;
+   m_argStack.push(arg);
+
+   m_OUT = out;
+   m_IN = in;
+   m_CONST = constPtr;
+}
+
+void Storage::popArguments()
+{
+   Args arg = m_argStack.top();
+   m_OUT = arg.out;
+   m_IN = arg.in;
+   m_CONST = arg.cst;
+   m_argStack.pop();
+}
index ebdfcdefd601973af10805577d9259e2175d8fa2..b8d6eb06049f061001c93356009a7287684f1485 100644 (file)
@@ -35,6 +35,7 @@
 
 #include <map>
 #include <set>
+#include <stack>
 #include <vector>
 
 namespace llvm {
@@ -53,6 +54,10 @@ public:
            llvm::Value *out,
            llvm::Value *in, llvm::Value *consts);
 
+   llvm::Value *inputPtr() const;
+   llvm::Value *outputPtr() const;
+   llvm::Value *constPtr() const;
+
    void setCurrentBlock(llvm::BasicBlock *block);
 
    llvm::ConstantInt *constantInt(int);
@@ -76,6 +81,10 @@ public:
 
    int numConsts() const;
 
+   void pushArguments(llvm::Value *out, llvm::Value *in,
+                      llvm::Value *constPtr);
+   void popArguments();
+
 private:
    llvm::Value *maskWrite(llvm::Value *src, int mask, llvm::Value *templ);
    const char *name(const char *prefix);
@@ -106,6 +115,13 @@ private:
    int         m_numConsts;
 
    std::map<int, bool > m_destWriteMap;
+
+   struct Args {
+      llvm::Value *out;
+      llvm::Value *in;
+      llvm::Value *cst;
+   };
+   std::stack<Args> m_argStack;
 };
 
 #endif