zink: implement some more trivial opcodes
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
index 47d93efc91071212bdfc6817156c597593297806..09312dae406c2eb67a4135d91c7d737929713b20 100644 (file)
@@ -710,6 +710,15 @@ emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
                                       op, args, ARRAY_SIZE(args));
 }
 
+static SpvId
+emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
+                   SpvId src0, SpvId src1, SpvId src2)
+{
+   SpvId args[] = { src0, src1, src2 };
+   return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
+                                      op, args, ARRAY_SIZE(args));
+}
+
 static SpvId
 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
                   unsigned num_components, float value)
@@ -897,6 +906,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
       result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
       break;
 
+   BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
    BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
    BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
    BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
@@ -939,6 +949,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
    BINOP(nir_op_imul, SpvOpIMul)
    BINOP(nir_op_idiv, SpvOpSDiv)
    BINOP(nir_op_udiv, SpvOpUDiv)
+   BINOP(nir_op_umod, SpvOpUMod)
    BINOP(nir_op_fadd, SpvOpFAdd)
    BINOP(nir_op_fsub, SpvOpFSub)
    BINOP(nir_op_fmul, SpvOpFMul)
@@ -948,6 +959,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
    BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
    BINOP(nir_op_ieq, SpvOpIEqual)
    BINOP(nir_op_ine, SpvOpINotEqual)
+   BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
    BINOP(nir_op_flt, SpvOpFOrdLessThan)
    BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
    BINOP(nir_op_feq, SpvOpFOrdEqual)
@@ -1013,6 +1025,12 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
       }
       break;
 
+   case nir_op_flrp:
+      assert(nir_op_infos[alu->op].num_inputs == 3);
+      result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
+                                  src[0], src[1], src[2]);
+      break;
+
    case nir_op_fcsel:
       result = emit_binop(ctx, SpvOpFOrdGreaterThan,
                           get_bvec_type(ctx, num_components),
@@ -1028,6 +1046,54 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
       result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
       break;
 
+   case nir_op_bany_fnequal2:
+   case nir_op_bany_fnequal3:
+   case nir_op_bany_fnequal4:
+      assert(nir_op_infos[alu->op].num_inputs == 2);
+      assert(alu_instr_src_components(alu, 0) ==
+             alu_instr_src_components(alu, 1));
+      result = emit_binop(ctx, SpvOpFOrdNotEqual,
+                          get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
+                          src[0], src[1]);
+      result = emit_unop(ctx, SpvOpAny, dest_type, result);
+      break;
+
+   case nir_op_ball_fequal2:
+   case nir_op_ball_fequal3:
+   case nir_op_ball_fequal4:
+      assert(nir_op_infos[alu->op].num_inputs == 2);
+      assert(alu_instr_src_components(alu, 0) ==
+             alu_instr_src_components(alu, 1));
+      result = emit_binop(ctx, SpvOpFOrdEqual,
+                          get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
+                          src[0], src[1]);
+      result = emit_unop(ctx, SpvOpAll, dest_type, result);
+      break;
+
+   case nir_op_bany_inequal2:
+   case nir_op_bany_inequal3:
+   case nir_op_bany_inequal4:
+      assert(nir_op_infos[alu->op].num_inputs == 2);
+      assert(alu_instr_src_components(alu, 0) ==
+             alu_instr_src_components(alu, 1));
+      result = emit_binop(ctx, SpvOpINotEqual,
+                          get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
+                          src[0], src[1]);
+      result = emit_unop(ctx, SpvOpAny, dest_type, result);
+      break;
+
+   case nir_op_ball_iequal2:
+   case nir_op_ball_iequal3:
+   case nir_op_ball_iequal4:
+      assert(nir_op_infos[alu->op].num_inputs == 2);
+      assert(alu_instr_src_components(alu, 0) ==
+             alu_instr_src_components(alu, 1));
+      result = emit_binop(ctx, SpvOpIEqual,
+                          get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
+                          src[0], src[1]);
+      result = emit_unop(ctx, SpvOpAll, dest_type, result);
+      break;
+
    case nir_op_vec2:
    case nir_op_vec3:
    case nir_op_vec4: {
@@ -1256,21 +1322,35 @@ get_src_float(struct ntv_context *ctx, nir_src *src)
    return bitcast_to_fvec(ctx, def, bit_size, num_components);
 }
 
+static SpvId
+get_src_int(struct ntv_context *ctx, nir_src *src)
+{
+   SpvId def = get_src_uint(ctx, src);
+   unsigned num_components = nir_src_num_components(*src);
+   unsigned bit_size = nir_src_bit_size(*src);
+   return bitcast_to_ivec(ctx, def, bit_size, num_components);
+}
+
 static void
 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
 {
    assert(tex->op == nir_texop_tex ||
           tex->op == nir_texop_txb ||
-          tex->op == nir_texop_txl);
+          tex->op == nir_texop_txl ||
+          tex->op == nir_texop_txd ||
+          tex->op == nir_texop_txf);
    assert(nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
    assert(tex->texture_index == tex->sampler_index);
 
-   SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0;
-   unsigned coord_components;
+   SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0;
+   unsigned coord_components = 0;
    for (unsigned i = 0; i < tex->num_srcs; i++) {
       switch (tex->src[i].src_type) {
       case nir_tex_src_coord:
-         coord = get_src_float(ctx, &tex->src[i].src);
+         if (tex->op == nir_texop_txf)
+            coord = get_src_int(ctx, &tex->src[i].src);
+         else
+            coord = get_src_float(ctx, &tex->src[i].src);
          coord_components = nir_src_num_components(tex->src[i].src);
          break;
 
@@ -1288,7 +1368,10 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
 
       case nir_tex_src_lod:
          assert(nir_src_num_components(tex->src[i].src) == 1);
-         lod = get_src_float(ctx, &tex->src[i].src);
+         if (tex->op == nir_texop_txf)
+            lod = get_src_int(ctx, &tex->src[i].src);
+         else
+            lod = get_src_float(ctx, &tex->src[i].src);
          assert(lod != 0);
          break;
 
@@ -1298,6 +1381,16 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
          assert(dref != 0);
          break;
 
+      case nir_tex_src_ddx:
+         dx = get_src_float(ctx, &tex->src[i].src);
+         assert(dx != 0);
+         break;
+
+      case nir_tex_src_ddy:
+         dy = get_src_float(ctx, &tex->src[i].src);
+         assert(dy != 0);
+         break;
+
       default:
          fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
          unreachable("unknown texture source");
@@ -1351,11 +1444,19 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
    if (dref)
       actual_dest_type = float_type;
 
-   SpvId result = spirv_builder_emit_image_sample(&ctx->builder,
-                                                  actual_dest_type, load,
-                                                  coord,
-                                                  proj != 0,
-                                                  lod, bias, dref);
+   SpvId result;
+   if (tex->op == nir_texop_txf) {
+      SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
+      result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
+                                              image, coord, lod);
+   } else {
+      result = spirv_builder_emit_image_sample(&ctx->builder,
+                                               actual_dest_type, load,
+                                               coord,
+                                               proj != 0,
+                                               lod, bias, dref, dx, dy);
+   }
+
    spirv_builder_emit_decoration(&ctx->builder, result,
                                  SpvDecorationRelaxedPrecision);