X-Git-Url: https://git.libre-soc.org/?p=mesa.git;a=blobdiff_plain;f=src%2Fcompiler%2Fspirv%2Fvtn_opencl.c;h=b5e92dc3b85e1d53ccb321747bf8e74ec2201c73;hp=0dc95a19f6b86f7301e4fdea0c710ab6f1463bbd;hb=608c87afdddb9524bce3d2ccd95c2297b453072b;hpb=6d69ed88f875b1ed0cc0def96067fc2a92f4d0ed diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c index 0dc95a19f6b..b5e92dc3b85 100644 --- a/src/compiler/spirv/vtn_opencl.c +++ b/src/compiler/spirv/vtn_opencl.c @@ -39,21 +39,18 @@ static void handle_instr(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, const uint32_t *w, unsigned count, nir_handler handler) { - const struct glsl_type *dest_type = - vtn_value(b, w[1], vtn_value_type_type)->type->type; + const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type; unsigned num_srcs = count - 5; nir_ssa_def *srcs[3] = { NULL }; vtn_assert(num_srcs <= ARRAY_SIZE(srcs)); for (unsigned i = 0; i < num_srcs; i++) { - srcs[i] = vtn_ssa_value(b, w[i + 5])->def; + srcs[i] = vtn_get_nir_ssa(b, w[i + 5]); } nir_ssa_def *result = handler(b, opcode, num_srcs, srcs, dest_type); if (result) { - struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa); - val->ssa = vtn_create_ssa_value(b, dest_type); - val->ssa->def = result; + vtn_push_nir_ssa(b, w[2], result); } else { vtn_assert(dest_type == glsl_void_type()); } @@ -84,8 +81,15 @@ nir_alu_op_for_opencl_opcode(struct vtn_builder *b, case OpenCLstd_UMin: return nir_op_umin; case OpenCLstd_Fmod: return nir_op_fmod; case OpenCLstd_Mix: return nir_op_flrp; + case OpenCLstd_Native_cos: return nir_op_fcos; + case OpenCLstd_Native_divide: return nir_op_fdiv; case OpenCLstd_Native_exp2: return nir_op_fexp2; case OpenCLstd_Native_log2: return nir_op_flog2; + case OpenCLstd_Native_powr: return nir_op_fpow; + case OpenCLstd_Native_recip: return nir_op_frcp; + case OpenCLstd_Native_rsqrt: return nir_op_frsq; + case OpenCLstd_Native_sin: return nir_op_fsin; + case OpenCLstd_Native_sqrt: return nir_op_fsqrt; case OpenCLstd_SMul_hi: return nir_op_imul_high; case OpenCLstd_UMul_hi: return nir_op_umul_high; case OpenCLstd_Popcount: return nir_op_bit_count; @@ -100,6 +104,7 @@ nir_alu_op_for_opencl_opcode(struct vtn_builder *b, case OpenCLstd_SSub_sat: return nir_op_isub_sat; case OpenCLstd_USub_sat: return nir_op_usub_sat; case OpenCLstd_Trunc: return nir_op_ftrunc; + case OpenCLstd_Rint: return nir_op_fround_even; /* uhm... */ case OpenCLstd_UAbs: return nir_op_mov; default: @@ -112,8 +117,11 @@ handle_alu(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned num_srcs, nir_ssa_def **srcs, const struct glsl_type *dest_type) { - return nir_build_alu(&b->nb, nir_alu_op_for_opencl_opcode(b, opcode), - srcs[0], srcs[1], srcs[2], NULL); + nir_ssa_def *ret = nir_build_alu(&b->nb, nir_alu_op_for_opencl_opcode(b, opcode), + srcs[0], srcs[1], srcs[2], NULL); + if (opcode == OpenCLstd_Popcount) + ret = nir_u2u(&b->nb, ret, glsl_get_bit_size(dest_type)); + return ret; } static nir_ssa_def * @@ -203,6 +211,8 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, return nir_flog(nb, srcs[0]); case OpenCLstd_Native_log10: return nir_fmul_imm(nb, nir_flog2(nb, srcs[0]), log(2) / log(10)); + case OpenCLstd_Native_tan: + return nir_ftan(nb, srcs[0]); default: vtn_fail("No NIR equivalent"); return NULL; @@ -215,15 +225,15 @@ _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, { struct vtn_type *type; if (load) - type = vtn_value(b, w[1], vtn_value_type_type)->type; + type = vtn_get_type(b, w[1]); else - type = vtn_untyped_value(b, w[5])->type; + type = vtn_get_value_type(b, w[5]); unsigned a = load ? 0 : 1; const struct glsl_type *dest_type = type->type; unsigned components = glsl_get_vector_elements(dest_type); - nir_ssa_def *offset = vtn_ssa_value(b, w[5 + a])->def; + nir_ssa_def *offset = vtn_get_nir_ssa(b, w[5 + a]); struct vtn_value *p = vtn_value(b, w[6 + a], vtn_value_type_pointer); struct vtn_ssa_value *comps[NIR_MAX_VEC_COMPONENTS]; @@ -242,14 +252,12 @@ _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, } else { struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, glsl_scalar_type(glsl_get_base_type(dest_type))); struct vtn_ssa_value *val = vtn_ssa_value(b, w[5]); - ssa->def = vtn_vector_extract(b, val->def, i); + ssa->def = nir_channel(&b->nb, val->def, i); vtn_local_store(b, ssa, arr_deref, p->type->access); } } if (load) { - struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, dest_type); - ssa->def = nir_vec(&b->nb, ncomps, components); - vtn_push_ssa(b, w[2], type, ssa); + vtn_push_nir_ssa(b, w[2], nir_vec(&b->nb, ncomps, components)); } } @@ -276,6 +284,21 @@ handle_printf(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, return nir_imm_int(&b->nb, -1); } +static nir_ssa_def * +handle_round(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, + unsigned num_srcs, nir_ssa_def **srcs, + const struct glsl_type *dest_type) +{ + nir_ssa_def *src = srcs[0]; + nir_builder *nb = &b->nb; + nir_ssa_def *half = nir_imm_floatN_t(nb, 0.5, src->bit_size); + nir_ssa_def *truncated = nir_ftrunc(nb, src); + nir_ssa_def *remainder = nir_fsub(nb, src, truncated); + + return nir_bcsel(nb, nir_fge(nb, nir_fabs(nb, remainder), half), + nir_fadd(nb, truncated, nir_fsign(nb, src)), truncated); +} + static nir_ssa_def * handle_shuffle(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned num_srcs, nir_ssa_def **srcs, const struct glsl_type *dest_type) @@ -349,8 +372,15 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode, case OpenCLstd_SMin: case OpenCLstd_UMin: case OpenCLstd_Mix: + case OpenCLstd_Native_cos: + case OpenCLstd_Native_divide: case OpenCLstd_Native_exp2: case OpenCLstd_Native_log2: + case OpenCLstd_Native_powr: + case OpenCLstd_Native_recip: + case OpenCLstd_Native_rsqrt: + case OpenCLstd_Native_sin: + case OpenCLstd_Native_sqrt: case OpenCLstd_Fmod: case OpenCLstd_SMul_hi: case OpenCLstd_UMul_hi: @@ -366,6 +396,7 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode, case OpenCLstd_SSub_sat: case OpenCLstd_USub_sat: case OpenCLstd_Trunc: + case OpenCLstd_Rint: handle_instr(b, cl_opcode, w, count, handle_alu); return true; case OpenCLstd_SAbs_diff: @@ -407,6 +438,7 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode, case OpenCLstd_Native_exp10: case OpenCLstd_Native_log: case OpenCLstd_Native_log10: + case OpenCLstd_Native_tan: handle_instr(b, cl_opcode, w, count, handle_special); return true; case OpenCLstd_Vloadn: @@ -421,6 +453,9 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode, case OpenCLstd_Shuffle2: handle_instr(b, cl_opcode, w, count, handle_shuffle2); return true; + case OpenCLstd_Round: + handle_instr(b, cl_opcode, w, count, handle_round); + return true; case OpenCLstd_Printf: handle_instr(b, cl_opcode, w, count, handle_printf); return true;