gallivm/nir: add support for multiple vertex streams
[mesa.git] / src / gallium / auxiliary / gallivm / lp_bld_nir_soa.c
index dc2a7405ae3dfd12d61db10781c762d1b84fbb92..e1e476e35ce9569fcbe695cc307a77446ed8f70c 100644 (file)
@@ -1318,7 +1318,7 @@ static void emit_vertex(struct lp_build_nir_context *bld_base, uint32_t stream_i
 
    assert(bld->gs_iface->emit_vertex);
    LLVMValueRef total_emitted_vertices_vec =
-      LLVMBuildLoad(builder, bld->total_emitted_vertices_vec_ptr, "");
+      LLVMBuildLoad(builder, bld->total_emitted_vertices_vec_ptr[stream_id], "");
    LLVMValueRef mask = mask_vec(bld_base);
    mask = clamp_mask_to_max_output_vertices(bld, mask,
                                             total_emitted_vertices_vec);
@@ -1327,38 +1327,39 @@ static void emit_vertex(struct lp_build_nir_context *bld_base, uint32_t stream_i
                               total_emitted_vertices_vec,
                               lp_build_const_int_vec(bld->bld_base.base.gallivm, bld->bld_base.base.type, stream_id));
 
-   increment_vec_ptr_by_mask(bld_base, bld->emitted_vertices_vec_ptr,
+   increment_vec_ptr_by_mask(bld_base, bld->emitted_vertices_vec_ptr[stream_id],
                              mask);
-   increment_vec_ptr_by_mask(bld_base, bld->total_emitted_vertices_vec_ptr,
+   increment_vec_ptr_by_mask(bld_base, bld->total_emitted_vertices_vec_ptr[stream_id],
                              mask);
 }
 
 static void
 end_primitive_masked(struct lp_build_nir_context * bld_base,
-                     LLVMValueRef mask)
+                     LLVMValueRef mask, uint32_t stream_id)
 {
    struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
    LLVMBuilderRef builder = bld->bld_base.base.gallivm->builder;
 
    struct lp_build_context *uint_bld = &bld_base->uint_bld;
    LLVMValueRef emitted_vertices_vec =
-      LLVMBuildLoad(builder, bld->emitted_vertices_vec_ptr, "");
+      LLVMBuildLoad(builder, bld->emitted_vertices_vec_ptr[stream_id], "");
    LLVMValueRef emitted_prims_vec =
-      LLVMBuildLoad(builder, bld->emitted_prims_vec_ptr, "");
+      LLVMBuildLoad(builder, bld->emitted_prims_vec_ptr[stream_id], "");
    LLVMValueRef total_emitted_vertices_vec =
-      LLVMBuildLoad(builder, bld->total_emitted_vertices_vec_ptr, "");
+      LLVMBuildLoad(builder, bld->total_emitted_vertices_vec_ptr[stream_id], "");
 
    LLVMValueRef emitted_mask = lp_build_cmp(uint_bld,
                                             PIPE_FUNC_NOTEQUAL,
                                             emitted_vertices_vec,
                                             uint_bld->zero);
    mask = LLVMBuildAnd(builder, mask, emitted_mask, "");
-   bld->gs_iface->end_primitive(bld->gs_iface, &bld->bld_base.base,
-                                total_emitted_vertices_vec,
-                                emitted_vertices_vec, emitted_prims_vec, mask_vec(bld_base));
-   increment_vec_ptr_by_mask(bld_base, bld->emitted_prims_vec_ptr,
+   if (stream_id == 0)
+      bld->gs_iface->end_primitive(bld->gs_iface, &bld->bld_base.base,
+                                   total_emitted_vertices_vec,
+                                   emitted_vertices_vec, emitted_prims_vec, mask_vec(bld_base));
+   increment_vec_ptr_by_mask(bld_base, bld->emitted_prims_vec_ptr[stream_id],
                              mask);
-   clear_uint_vec_ptr_from_mask(bld_base, bld->emitted_vertices_vec_ptr,
+   clear_uint_vec_ptr_from_mask(bld_base, bld->emitted_vertices_vec_ptr[stream_id],
                                 mask);
 }
 
@@ -1369,7 +1370,7 @@ static void end_primitive(struct lp_build_nir_context *bld_base, uint32_t stream
    assert(bld->gs_iface->end_primitive);
 
    LLVMValueRef mask = mask_vec(bld_base);
-   end_primitive_masked(bld_base, mask);
+   end_primitive_masked(bld_base, mask, stream_id);
 }
 
 static void
@@ -1400,6 +1401,63 @@ emit_prologue(struct lp_build_nir_soa_context *bld)
    }
 }
 
+static void emit_vote(struct lp_build_nir_context *bld_base, LLVMValueRef src, nir_intrinsic_instr *instr, LLVMValueRef result[4])
+{
+   struct gallivm_state * gallivm = bld_base->base.gallivm;
+   LLVMBuilderRef builder = gallivm->builder;
+
+   LLVMValueRef exec_mask = mask_vec(bld_base);
+   struct lp_build_loop_state loop_state;
+
+   LLVMValueRef outer_cond = LLVMBuildICmp(builder, LLVMIntNE, exec_mask, bld_base->uint_bld.zero, "");
+
+   LLVMValueRef res_store = lp_build_alloca(gallivm, bld_base->int_bld.elem_type, "");
+   LLVMValueRef init_val;
+   if (instr->intrinsic == nir_intrinsic_vote_ieq) {
+      /* for equal we unfortunately have to loop and find the first valid one. */
+      lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0));
+      LLVMValueRef if_cond = LLVMBuildExtractElement(gallivm->builder, outer_cond, loop_state.counter, "");
+
+      struct lp_build_if_state ifthen;
+      lp_build_if(&ifthen, gallivm, if_cond);
+      LLVMValueRef value_ptr = LLVMBuildExtractElement(gallivm->builder, src,
+                                                       loop_state.counter, "");
+      LLVMBuildStore(builder, value_ptr, res_store);
+      lp_build_endif(&ifthen);
+      lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, bld_base->uint_bld.type.length),
+                            NULL, LLVMIntUGE);
+      lp_build_print_value(gallivm, "init_val is ", LLVMBuildLoad(builder, res_store, ""));
+      init_val = LLVMBuildLoad(builder, res_store, "");
+   } else {
+      LLVMBuildStore(builder, lp_build_const_int32(gallivm, instr->intrinsic == nir_intrinsic_vote_any ? 0 : -1), res_store);
+   }
+
+   LLVMValueRef res;
+   lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0));
+   LLVMValueRef value_ptr = LLVMBuildExtractElement(gallivm->builder, src,
+                                                       loop_state.counter, "");
+   struct lp_build_if_state ifthen;
+   LLVMValueRef if_cond;
+   if_cond = LLVMBuildExtractElement(gallivm->builder, outer_cond, loop_state.counter, "");
+
+   lp_build_if(&ifthen, gallivm, if_cond);
+   res = LLVMBuildLoad(builder, res_store, "");
+
+   if (instr->intrinsic == nir_intrinsic_vote_ieq) {
+      LLVMValueRef tmp = LLVMBuildICmp(builder, LLVMIntEQ, init_val, value_ptr, "");
+      tmp = LLVMBuildSExt(builder, tmp, bld_base->uint_bld.elem_type, "");
+      res = LLVMBuildOr(builder, res, tmp, "");
+   } else if (instr->intrinsic == nir_intrinsic_vote_any)
+      res = LLVMBuildOr(builder, res, value_ptr, "");
+   else
+      res = LLVMBuildAnd(builder, res, value_ptr, "");
+   LLVMBuildStore(builder, res, res_store);
+   lp_build_endif(&ifthen);
+   lp_build_loop_end_cond(&loop_state, lp_build_const_int32(gallivm, bld_base->uint_bld.type.length),
+                         NULL, LLVMIntUGE);
+   result[0] = lp_build_broadcast_scalar(&bld_base->uint_bld, LLVMBuildLoad(builder, res_store, ""));
+}
+
 void lp_build_nir_soa(struct gallivm_state *gallivm,
                       struct nir_shader *shader,
                       const struct lp_build_tgsi_params *params,
@@ -1494,6 +1552,7 @@ void lp_build_nir_soa(struct gallivm_state *gallivm,
    bld.bld_base.barrier = emit_barrier;
    bld.bld_base.image_op = emit_image_op;
    bld.bld_base.image_size = emit_image_size;
+   bld.bld_base.vote = emit_vote;
 
    bld.mask = params->mask;
    bld.inputs = params->inputs;
@@ -1521,12 +1580,14 @@ void lp_build_nir_soa(struct gallivm_state *gallivm,
 
       bld.max_output_vertices_vec = lp_build_const_int_vec(gallivm, bld.bld_base.int_bld.type,
                                                            shader->info.gs.vertices_out);
-      bld.emitted_prims_vec_ptr =
-         lp_build_alloca(gallivm, uint_bld->vec_type, "emitted_prims_ptr");
-      bld.emitted_vertices_vec_ptr =
-         lp_build_alloca(gallivm, uint_bld->vec_type, "emitted_vertices_ptr");
-      bld.total_emitted_vertices_vec_ptr =
-         lp_build_alloca(gallivm, uint_bld->vec_type, "total_emitted_vertices_ptr");
+      for (int i = 0; i < PIPE_MAX_VERTEX_STREAMS; i++) {
+         bld.emitted_prims_vec_ptr[i] =
+            lp_build_alloca(gallivm, uint_bld->vec_type, "emitted_prims_ptr");
+         bld.emitted_vertices_vec_ptr[i] =
+            lp_build_alloca(gallivm, uint_bld->vec_type, "emitted_vertices_ptr");
+         bld.total_emitted_vertices_vec_ptr[i] =
+            lp_build_alloca(gallivm, uint_bld->vec_type, "total_emitted_vertices_ptr");
+      }
    }
    lp_exec_mask_init(&bld.exec_mask, &bld.bld_base.int_bld);
 
@@ -1541,15 +1602,18 @@ void lp_build_nir_soa(struct gallivm_state *gallivm,
       LLVMBuilderRef builder = bld.bld_base.base.gallivm->builder;
       LLVMValueRef total_emitted_vertices_vec;
       LLVMValueRef emitted_prims_vec;
-      end_primitive_masked(&bld.bld_base, lp_build_mask_value(bld.mask));
-      total_emitted_vertices_vec =
-         LLVMBuildLoad(builder, bld.total_emitted_vertices_vec_ptr, "");
-      emitted_prims_vec =
-         LLVMBuildLoad(builder, bld.emitted_prims_vec_ptr, "");
-
-      bld.gs_iface->gs_epilogue(bld.gs_iface,
-                                 total_emitted_vertices_vec,
-                                 emitted_prims_vec);
+
+      end_primitive_masked(&bld.bld_base, lp_build_mask_value(bld.mask), 0);
+      for (int i = 0; i < PIPE_MAX_VERTEX_STREAMS; i++) {
+         total_emitted_vertices_vec =
+            LLVMBuildLoad(builder, bld.total_emitted_vertices_vec_ptr[i], "");
+
+         emitted_prims_vec =
+            LLVMBuildLoad(builder, bld.emitted_prims_vec_ptr[i], "");
+         bld.gs_iface->gs_epilogue(bld.gs_iface,
+                                   total_emitted_vertices_vec,
+                                   emitted_prims_vec, i);
+      }
    }
    lp_exec_mask_fini(&bld.exec_mask);
 }