nv50/ir/nir: implement nir_alu_instr handling
authorKarol Herbst <kherbst@redhat.com>
Tue, 12 Dec 2017 20:05:30 +0000 (21:05 +0100)
committerKarol Herbst <kherbst@redhat.com>
Sun, 17 Mar 2019 09:33:28 +0000 (10:33 +0100)
v2: user bitfield_insert instead of bfi
    rework switch helper macros
    remove some lowering code (LoweringHelper is now used for this)
v3: add pack_half_2x16_split
    add unpack_half_2x16_split_x/y
v5: replace first argument with nullptr in loadImm calls
    prefer getSSA over getScratch
v8: fix setting precise modifier for first instruction inside a block
    add guard in case no instruction gets inserted into an empty block
    don't require C++11 features
v9: use CC_NE for integer compares
    convert to C++ style comments
    fix b2f for doubles
    remove macros around nir ops to make it easier to grep them
    add handling for fpow

Signed-off-by: Karol Herbst <kherbst@redhat.com>
src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp

index a99f3bbbc051d41fb0fd5e4a9bdf116dcb4d6a9d..a553e42e08a861f1b4cc36fafb382b0a9bd667a0 100644 (file)
@@ -114,9 +114,17 @@ private:
    std::vector<DataType> getSTypes(nir_alu_instr *);
    DataType getSType(nir_src &, bool isFloat, bool isSigned);
 
+   operation getOperation(nir_op);
+   operation preOperationNeeded(nir_op);
+
+   int getSubOp(nir_op);
+
+   CondCode getCondCode(nir_op);
+
    bool assignSlots();
    bool parseNIR();
 
+   bool visit(nir_alu_instr *);
    bool visit(nir_block *);
    bool visit(nir_cf_node *);
    bool visit(nir_function *);
@@ -135,6 +143,7 @@ private:
    unsigned int curLoopDepth;
 
    BasicBlock *exit;
+   Value *zero;
 
    union {
       struct {
@@ -146,7 +155,10 @@ private:
 Converter::Converter(Program *prog, nir_shader *nir, nv50_ir_prog_info *info)
    : ConverterCommon(prog, info),
      nir(nir),
-     curLoopDepth(0) {}
+     curLoopDepth(0)
+{
+   zero = mkImm((uint32_t)0);
+}
 
 BasicBlock *
 Converter::convert(nir_block *block)
@@ -275,6 +287,191 @@ Converter::getSType(nir_src &src, bool isFloat, bool isSigned)
    return ty;
 }
 
+operation
+Converter::getOperation(nir_op op)
+{
+   switch (op) {
+   // basic ops with float and int variants
+   case nir_op_fabs:
+   case nir_op_iabs:
+      return OP_ABS;
+   case nir_op_fadd:
+   case nir_op_iadd:
+      return OP_ADD;
+   case nir_op_fand:
+   case nir_op_iand:
+      return OP_AND;
+   case nir_op_ifind_msb:
+   case nir_op_ufind_msb:
+      return OP_BFIND;
+   case nir_op_fceil:
+      return OP_CEIL;
+   case nir_op_fcos:
+      return OP_COS;
+   case nir_op_f2f32:
+   case nir_op_f2f64:
+   case nir_op_f2i32:
+   case nir_op_f2i64:
+   case nir_op_f2u32:
+   case nir_op_f2u64:
+   case nir_op_i2f32:
+   case nir_op_i2f64:
+   case nir_op_i2i32:
+   case nir_op_i2i64:
+   case nir_op_u2f32:
+   case nir_op_u2f64:
+   case nir_op_u2u32:
+   case nir_op_u2u64:
+      return OP_CVT;
+   case nir_op_fddx:
+   case nir_op_fddx_coarse:
+   case nir_op_fddx_fine:
+      return OP_DFDX;
+   case nir_op_fddy:
+   case nir_op_fddy_coarse:
+   case nir_op_fddy_fine:
+      return OP_DFDY;
+   case nir_op_fdiv:
+   case nir_op_idiv:
+   case nir_op_udiv:
+      return OP_DIV;
+   case nir_op_fexp2:
+      return OP_EX2;
+   case nir_op_ffloor:
+      return OP_FLOOR;
+   case nir_op_ffma:
+      return OP_FMA;
+   case nir_op_flog2:
+      return OP_LG2;
+   case nir_op_fmax:
+   case nir_op_imax:
+   case nir_op_umax:
+      return OP_MAX;
+   case nir_op_pack_64_2x32_split:
+      return OP_MERGE;
+   case nir_op_fmin:
+   case nir_op_imin:
+   case nir_op_umin:
+      return OP_MIN;
+   case nir_op_fmod:
+   case nir_op_imod:
+   case nir_op_umod:
+   case nir_op_frem:
+   case nir_op_irem:
+      return OP_MOD;
+   case nir_op_fmul:
+   case nir_op_imul:
+   case nir_op_imul_high:
+   case nir_op_umul_high:
+      return OP_MUL;
+   case nir_op_fneg:
+   case nir_op_ineg:
+      return OP_NEG;
+   case nir_op_fnot:
+   case nir_op_inot:
+      return OP_NOT;
+   case nir_op_for:
+   case nir_op_ior:
+      return OP_OR;
+   case nir_op_fpow:
+      return OP_POW;
+   case nir_op_frcp:
+      return OP_RCP;
+   case nir_op_frsq:
+      return OP_RSQ;
+   case nir_op_fsat:
+      return OP_SAT;
+   case nir_op_feq32:
+   case nir_op_ieq32:
+   case nir_op_fge32:
+   case nir_op_ige32:
+   case nir_op_uge32:
+   case nir_op_flt32:
+   case nir_op_ilt32:
+   case nir_op_ult32:
+   case nir_op_fne32:
+   case nir_op_ine32:
+      return OP_SET;
+   case nir_op_ishl:
+      return OP_SHL;
+   case nir_op_ishr:
+   case nir_op_ushr:
+      return OP_SHR;
+   case nir_op_fsin:
+      return OP_SIN;
+   case nir_op_fsqrt:
+      return OP_SQRT;
+   case nir_op_fsub:
+   case nir_op_isub:
+      return OP_SUB;
+   case nir_op_ftrunc:
+      return OP_TRUNC;
+   case nir_op_fxor:
+   case nir_op_ixor:
+      return OP_XOR;
+   default:
+      ERROR("couldn't get operation for op %s\n", nir_op_infos[op].name);
+      assert(false);
+      return OP_NOP;
+   }
+}
+
+operation
+Converter::preOperationNeeded(nir_op op)
+{
+   switch (op) {
+   case nir_op_fcos:
+   case nir_op_fsin:
+      return OP_PRESIN;
+   default:
+      return OP_NOP;
+   }
+}
+
+int
+Converter::getSubOp(nir_op op)
+{
+   switch (op) {
+   case nir_op_imul_high:
+   case nir_op_umul_high:
+      return NV50_IR_SUBOP_MUL_HIGH;
+   default:
+      return 0;
+   }
+}
+
+CondCode
+Converter::getCondCode(nir_op op)
+{
+   switch (op) {
+   case nir_op_feq32:
+   case nir_op_ieq32:
+      return CC_EQ;
+   case nir_op_fge32:
+   case nir_op_ige32:
+   case nir_op_uge32:
+      return CC_GE;
+   case nir_op_flt32:
+   case nir_op_ilt32:
+   case nir_op_ult32:
+      return CC_LT;
+   case nir_op_fne32:
+      return CC_NEU;
+   case nir_op_ine32:
+      return CC_NE;
+   default:
+      ERROR("couldn't get CondCode for op %s\n", nir_op_infos[op].name);
+      assert(false);
+      return CC_FL;
+   }
+}
+
+Converter::LValues&
+Converter::convert(nir_alu_dest *dest)
+{
+   return convert(&dest->dest);
+}
+
 Converter::LValues&
 Converter::convert(nir_dest *dest)
 {
@@ -1314,6 +1511,8 @@ bool
 Converter::visit(nir_instr *insn)
 {
    switch (insn->type) {
+   case nir_instr_type_alu:
+      return visit(nir_instr_as_alu(insn));
    case nir_instr_type_intrinsic:
       return visit(nir_instr_as_intrinsic(insn));
    case nir_instr_type_jump:
@@ -1393,6 +1592,367 @@ Converter::visit(nir_load_const_instr *insn)
    return true;
 }
 
+#define DEFAULT_CHECKS \
+      if (insn->dest.dest.ssa.num_components > 1) { \
+         ERROR("nir_alu_instr only supported with 1 component!\n"); \
+         return false; \
+      } \
+      if (insn->dest.write_mask != 1) { \
+         ERROR("nir_alu_instr only with write_mask of 1 supported!\n"); \
+         return false; \
+      }
+bool
+Converter::visit(nir_alu_instr *insn)
+{
+   const nir_op op = insn->op;
+   const nir_op_info &info = nir_op_infos[op];
+   DataType dType = getDType(insn);
+   const std::vector<DataType> sTypes = getSTypes(insn);
+
+   Instruction *oldPos = this->bb->getExit();
+
+   switch (op) {
+   case nir_op_fabs:
+   case nir_op_iabs:
+   case nir_op_fadd:
+   case nir_op_iadd:
+   case nir_op_fand:
+   case nir_op_iand:
+   case nir_op_fceil:
+   case nir_op_fcos:
+   case nir_op_fddx:
+   case nir_op_fddx_coarse:
+   case nir_op_fddx_fine:
+   case nir_op_fddy:
+   case nir_op_fddy_coarse:
+   case nir_op_fddy_fine:
+   case nir_op_fdiv:
+   case nir_op_idiv:
+   case nir_op_udiv:
+   case nir_op_fexp2:
+   case nir_op_ffloor:
+   case nir_op_ffma:
+   case nir_op_flog2:
+   case nir_op_fmax:
+   case nir_op_imax:
+   case nir_op_umax:
+   case nir_op_fmin:
+   case nir_op_imin:
+   case nir_op_umin:
+   case nir_op_fmod:
+   case nir_op_imod:
+   case nir_op_umod:
+   case nir_op_fmul:
+   case nir_op_imul:
+   case nir_op_imul_high:
+   case nir_op_umul_high:
+   case nir_op_fneg:
+   case nir_op_ineg:
+   case nir_op_fnot:
+   case nir_op_inot:
+   case nir_op_for:
+   case nir_op_ior:
+   case nir_op_pack_64_2x32_split:
+   case nir_op_fpow:
+   case nir_op_frcp:
+   case nir_op_frem:
+   case nir_op_irem:
+   case nir_op_frsq:
+   case nir_op_fsat:
+   case nir_op_ishr:
+   case nir_op_ushr:
+   case nir_op_fsin:
+   case nir_op_fsqrt:
+   case nir_op_fsub:
+   case nir_op_isub:
+   case nir_op_ftrunc:
+   case nir_op_ishl:
+   case nir_op_fxor:
+   case nir_op_ixor: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      operation preOp = preOperationNeeded(op);
+      if (preOp != OP_NOP) {
+         assert(info.num_inputs < 2);
+         Value *tmp = getSSA(typeSizeof(dType));
+         Instruction *i0 = mkOp(preOp, dType, tmp);
+         Instruction *i1 = mkOp(getOperation(op), dType, newDefs[0]);
+         if (info.num_inputs) {
+            i0->setSrc(0, getSrc(&insn->src[0]));
+            i1->setSrc(0, tmp);
+         }
+         i1->subOp = getSubOp(op);
+      } else {
+         Instruction *i = mkOp(getOperation(op), dType, newDefs[0]);
+         for (unsigned s = 0u; s < info.num_inputs; ++s) {
+            i->setSrc(s, getSrc(&insn->src[s]));
+         }
+         i->subOp = getSubOp(op);
+      }
+      break;
+   }
+   case nir_op_ifind_msb:
+   case nir_op_ufind_msb: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      dType = sTypes[0];
+      mkOp1(getOperation(op), dType, newDefs[0], getSrc(&insn->src[0]));
+      break;
+   }
+   case nir_op_fround_even: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkCvt(OP_CVT, dType, newDefs[0], dType, getSrc(&insn->src[0]))->rnd = ROUND_NI;
+      break;
+   }
+   // convert instructions
+   case nir_op_f2f32:
+   case nir_op_f2i32:
+   case nir_op_f2u32:
+   case nir_op_i2f32:
+   case nir_op_i2i32:
+   case nir_op_u2f32:
+   case nir_op_u2u32:
+   case nir_op_f2f64:
+   case nir_op_f2i64:
+   case nir_op_f2u64:
+   case nir_op_i2f64:
+   case nir_op_i2i64:
+   case nir_op_u2f64:
+   case nir_op_u2u64: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *i = mkOp1(getOperation(op), dType, newDefs[0], getSrc(&insn->src[0]));
+      if (op == nir_op_f2i32 || op == nir_op_f2i64 || op == nir_op_f2u32 || op == nir_op_f2u64)
+         i->rnd = ROUND_Z;
+      i->sType = sTypes[0];
+      break;
+   }
+   // compare instructions
+   case nir_op_feq32:
+   case nir_op_ieq32:
+   case nir_op_fge32:
+   case nir_op_ige32:
+   case nir_op_uge32:
+   case nir_op_flt32:
+   case nir_op_ilt32:
+   case nir_op_ult32:
+   case nir_op_fne32:
+   case nir_op_ine32: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *i = mkCmp(getOperation(op),
+                             getCondCode(op),
+                             dType,
+                             newDefs[0],
+                             dType,
+                             getSrc(&insn->src[0]),
+                             getSrc(&insn->src[1]));
+      if (info.num_inputs == 3)
+         i->setSrc(2, getSrc(&insn->src[2]));
+      i->sType = sTypes[0];
+      break;
+   }
+   // those are weird ALU ops and need special handling, because
+   //   1. they are always componend based
+   //   2. they basically just merge multiple values into one data type
+   case nir_op_imov:
+   case nir_op_fmov:
+   case nir_op_vec2:
+   case nir_op_vec3:
+   case nir_op_vec4: {
+      LValues &newDefs = convert(&insn->dest);
+      for (LValues::size_type c = 0u; c < newDefs.size(); ++c) {
+         mkMov(newDefs[c], getSrc(&insn->src[c]), dType);
+      }
+      break;
+   }
+   // (un)pack
+   case nir_op_pack_64_2x32: {
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *merge = mkOp(OP_MERGE, dType, newDefs[0]);
+      merge->setSrc(0, getSrc(&insn->src[0], 0));
+      merge->setSrc(1, getSrc(&insn->src[0], 1));
+      break;
+   }
+   case nir_op_pack_half_2x16_split: {
+      LValues &newDefs = convert(&insn->dest);
+      Value *tmpH = getSSA();
+      Value *tmpL = getSSA();
+
+      mkCvt(OP_CVT, TYPE_F16, tmpL, TYPE_F32, getSrc(&insn->src[0]));
+      mkCvt(OP_CVT, TYPE_F16, tmpH, TYPE_F32, getSrc(&insn->src[1]));
+      mkOp3(OP_INSBF, TYPE_U32, newDefs[0], tmpH, mkImm(0x1010), tmpL);
+      break;
+   }
+   case nir_op_unpack_half_2x16_split_x:
+   case nir_op_unpack_half_2x16_split_y: {
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *cvt = mkCvt(OP_CVT, TYPE_F32, newDefs[0], TYPE_F16, getSrc(&insn->src[0]));
+      if (op == nir_op_unpack_half_2x16_split_y)
+         cvt->subOp = 1;
+      break;
+   }
+   case nir_op_unpack_64_2x32: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp1(OP_SPLIT, dType, newDefs[0], getSrc(&insn->src[0]))->setDef(1, newDefs[1]);
+      break;
+   }
+   case nir_op_unpack_64_2x32_split_x: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp1(OP_SPLIT, dType, newDefs[0], getSrc(&insn->src[0]))->setDef(1, getSSA());
+      break;
+   }
+   case nir_op_unpack_64_2x32_split_y: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp1(OP_SPLIT, dType, getSSA(), getSrc(&insn->src[0]))->setDef(1, newDefs[0]);
+      break;
+   }
+   // special instructions
+   case nir_op_fsign:
+   case nir_op_isign: {
+      DEFAULT_CHECKS;
+      DataType iType;
+      if (::isFloatType(dType))
+         iType = TYPE_F32;
+      else
+         iType = TYPE_S32;
+
+      LValues &newDefs = convert(&insn->dest);
+      LValue *val0 = getScratch();
+      LValue *val1 = getScratch();
+      mkCmp(OP_SET, CC_GT, iType, val0, dType, getSrc(&insn->src[0]), zero);
+      mkCmp(OP_SET, CC_LT, iType, val1, dType, getSrc(&insn->src[0]), zero);
+
+      if (dType == TYPE_F64) {
+         mkOp2(OP_SUB, iType, val0, val0, val1);
+         mkCvt(OP_CVT, TYPE_F64, newDefs[0], iType, val0);
+      } else if (dType == TYPE_S64 || dType == TYPE_U64) {
+         mkOp2(OP_SUB, iType, val0, val1, val0);
+         mkOp2(OP_SHR, iType, val1, val0, loadImm(NULL, 31));
+         mkOp2(OP_MERGE, dType, newDefs[0], val0, val1);
+      } else if (::isFloatType(dType))
+         mkOp2(OP_SUB, iType, newDefs[0], val0, val1);
+      else
+         mkOp2(OP_SUB, iType, newDefs[0], val1, val0);
+      break;
+   }
+   case nir_op_fcsel:
+   case nir_op_b32csel: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkCmp(OP_SLCT, CC_NE, dType, newDefs[0], sTypes[0], getSrc(&insn->src[1]), getSrc(&insn->src[2]), getSrc(&insn->src[0]));
+      break;
+   }
+   case nir_op_ibitfield_extract:
+   case nir_op_ubitfield_extract: {
+      DEFAULT_CHECKS;
+      Value *tmp = getSSA();
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, tmp, getSrc(&insn->src[2]), loadImm(NULL, 0x808), getSrc(&insn->src[1]));
+      mkOp2(OP_EXTBF, dType, newDefs[0], getSrc(&insn->src[0]), tmp);
+      break;
+   }
+   case nir_op_bfm: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, newDefs[0], getSrc(&insn->src[0]), loadImm(NULL, 0x808), getSrc(&insn->src[1]));
+      break;
+   }
+   case nir_op_bitfield_insert: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      LValue *temp = getSSA();
+      mkOp3(OP_INSBF, TYPE_U32, temp, getSrc(&insn->src[3]), mkImm(0x808), getSrc(&insn->src[2]));
+      mkOp3(OP_INSBF, dType, newDefs[0], getSrc(&insn->src[1]), temp, getSrc(&insn->src[0]));
+      break;
+   }
+   case nir_op_bit_count: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_POPCNT, dType, newDefs[0], getSrc(&insn->src[0]), getSrc(&insn->src[0]));
+      break;
+   }
+   case nir_op_bitfield_reverse: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_EXTBF, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
+      break;
+   }
+   case nir_op_find_lsb: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Value *tmp = getSSA();
+      mkOp2(OP_EXTBF, TYPE_U32, tmp, getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
+      mkOp1(OP_BFIND, TYPE_U32, newDefs[0], tmp)->subOp = NV50_IR_SUBOP_BFIND_SAMT;
+      break;
+   }
+   // boolean conversions
+   case nir_op_b2f32: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_AND, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), loadImm(NULL, 1.0f));
+      break;
+   }
+   case nir_op_b2f64: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Value *tmp = getSSA(4);
+      mkOp2(OP_AND, TYPE_U32, tmp, getSrc(&insn->src[0]), loadImm(NULL, 0x3ff00000));
+      mkOp2(OP_MERGE, TYPE_U64, newDefs[0], loadImm(NULL, 0), tmp);
+      break;
+   }
+   case nir_op_f2b32:
+   case nir_op_i2b32: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Value *src1;
+      if (typeSizeof(sTypes[0]) == 8) {
+         src1 = loadImm(getSSA(8), 0.0);
+      } else {
+         src1 = zero;
+      }
+      CondCode cc = op == nir_op_f2b32 ? CC_NEU : CC_NE;
+      mkCmp(OP_SET, cc, TYPE_U32, newDefs[0], sTypes[0], getSrc(&insn->src[0]), src1);
+      break;
+   }
+   case nir_op_b2i32: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_AND, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), loadImm(NULL, 1));
+      break;
+   }
+   case nir_op_b2i64: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      LValue *def = getScratch();
+      mkOp2(OP_AND, TYPE_U32, def, getSrc(&insn->src[0]), loadImm(NULL, 1));
+      mkOp2(OP_MERGE, TYPE_S64, newDefs[0], def, loadImm(NULL, 0));
+      break;
+   }
+   default:
+      ERROR("unknown nir_op %s\n", info.name);
+      return false;
+   }
+
+   if (!oldPos) {
+      oldPos = this->bb->getEntry();
+      oldPos->precise = insn->exact;
+   }
+
+   if (unlikely(!oldPos))
+      return true;
+
+   while (oldPos->next) {
+      oldPos = oldPos->next;
+      oldPos->precise = insn->exact;
+   }
+   oldPos->saturate = insn->dest.saturate;
+
+   return true;
+}
+#undef DEFAULT_CHECKS
+
 bool
 Converter::run()
 {