nir: Add nir_address_format_32bit_offset_as_64bit
[mesa.git] / src / compiler / nir / nir_lower_io.c
index fd5a713349e9eecf8a07ade297c2fcbbc939ec6b..cc1c456227a4bb6f075666e2b9e2bba60a54f25c 100644 (file)
@@ -692,17 +692,23 @@ build_addr_iadd(nir_builder *b, nir_ssa_def *addr,
                 nir_address_format addr_format, nir_ssa_def *offset)
 {
    assert(offset->num_components == 1);
-   assert(addr->bit_size == offset->bit_size);
 
    switch (addr_format) {
    case nir_address_format_32bit_global:
    case nir_address_format_64bit_global:
    case nir_address_format_32bit_offset:
+      assert(addr->bit_size == offset->bit_size);
       assert(addr->num_components == 1);
       return nir_iadd(b, addr, offset);
 
+   case nir_address_format_32bit_offset_as_64bit:
+      assert(addr->num_components == 1);
+      assert(offset->bit_size == 32);
+      return nir_u2u64(b, nir_iadd(b, nir_u2u32(b, addr), offset));
+
    case nir_address_format_64bit_bounded_global:
       assert(addr->num_components == 4);
+      assert(addr->bit_size == offset->bit_size);
       return nir_vec4(b, nir_channel(b, addr, 0),
                          nir_channel(b, addr, 1),
                          nir_channel(b, addr, 2),
@@ -710,10 +716,12 @@ build_addr_iadd(nir_builder *b, nir_ssa_def *addr,
 
    case nir_address_format_32bit_index_offset:
       assert(addr->num_components == 2);
+      assert(addr->bit_size == offset->bit_size);
       return nir_vec2(b, nir_channel(b, addr, 0),
                          nir_iadd(b, nir_channel(b, addr, 1), offset));
    case nir_address_format_vec2_index_32bit_offset:
       assert(addr->num_components == 3);
+      assert(offset->bit_size == 32);
       return nir_vec3(b, nir_channel(b, addr, 0), nir_channel(b, addr, 1),
                          nir_iadd(b, nir_channel(b, addr, 2), offset));
    case nir_address_format_logical:
@@ -722,12 +730,21 @@ build_addr_iadd(nir_builder *b, nir_ssa_def *addr,
    unreachable("Invalid address format");
 }
 
+static unsigned
+addr_get_offset_bit_size(nir_ssa_def *addr, nir_address_format addr_format)
+{
+   if (addr_format == nir_address_format_32bit_offset_as_64bit)
+      return 32;
+   return addr->bit_size;
+}
+
 static nir_ssa_def *
 build_addr_iadd_imm(nir_builder *b, nir_ssa_def *addr,
                     nir_address_format addr_format, int64_t offset)
 {
    return build_addr_iadd(b, addr, addr_format,
-                             nir_imm_intN_t(b, offset, addr->bit_size));
+                             nir_imm_intN_t(b, offset,
+                                            addr_get_offset_bit_size(addr, addr_format)));
 }
 
 static nir_ssa_def *
@@ -749,14 +766,19 @@ static nir_ssa_def *
 addr_to_offset(nir_builder *b, nir_ssa_def *addr,
                nir_address_format addr_format)
 {
-   if (addr_format == nir_address_format_32bit_index_offset) {
+   switch (addr_format) {
+   case nir_address_format_32bit_index_offset:
       assert(addr->num_components == 2);
       return nir_channel(b, addr, 1);
-   } else if (addr_format == nir_address_format_vec2_index_32bit_offset) {
+   case nir_address_format_vec2_index_32bit_offset:
       assert(addr->num_components == 3);
       return nir_channel(b, addr, 2);
-   } else {
-      unreachable("bad address format for offset");
+   case nir_address_format_32bit_offset:
+      return addr;
+   case nir_address_format_32bit_offset_as_64bit:
+      return nir_u2u32(b, addr);
+   default:
+      unreachable("Invalid address format");
    }
 }
 
@@ -772,7 +794,8 @@ addr_format_is_global(nir_address_format addr_format)
 static bool
 addr_format_is_offset(nir_address_format addr_format)
 {
-   return addr_format == nir_address_format_32bit_offset;
+   return addr_format == nir_address_format_32bit_offset ||
+          addr_format == nir_address_format_32bit_offset_as_64bit;
 }
 
 static nir_ssa_def *
@@ -793,6 +816,7 @@ addr_to_global(nir_builder *b, nir_ssa_def *addr,
    case nir_address_format_32bit_index_offset:
    case nir_address_format_vec2_index_32bit_offset:
    case nir_address_format_32bit_offset:
+   case nir_address_format_32bit_offset_as_64bit:
    case nir_address_format_logical:
       unreachable("Cannot get a 64-bit address with this address format");
    }
@@ -863,9 +887,9 @@ build_explicit_io_load(nir_builder *b, nir_intrinsic_instr *intrin,
 
    if (addr_format_is_global(addr_format)) {
       load->src[0] = nir_src_for_ssa(addr_to_global(b, addr, addr_format));
-   } else if (addr_format == nir_address_format_32bit_offset) {
+   } else if (addr_format_is_offset(addr_format)) {
       assert(addr->num_components == 1);
-      load->src[0] = nir_src_for_ssa(addr);
+      load->src[0] = nir_src_for_ssa(addr_to_offset(b, addr, addr_format));
    } else {
       load->src[0] = nir_src_for_ssa(addr_to_index(b, addr, addr_format));
       load->src[1] = nir_src_for_ssa(addr_to_offset(b, addr, addr_format));
@@ -985,9 +1009,9 @@ build_explicit_io_store(nir_builder *b, nir_intrinsic_instr *intrin,
    store->src[0] = nir_src_for_ssa(value);
    if (addr_format_is_global(addr_format)) {
       store->src[1] = nir_src_for_ssa(addr_to_global(b, addr, addr_format));
-   } else if (addr_format == nir_address_format_32bit_offset) {
+   } else if (addr_format_is_offset(addr_format)) {
       assert(addr->num_components == 1);
-      store->src[1] = nir_src_for_ssa(addr);
+      store->src[1] = nir_src_for_ssa(addr_to_offset(b, addr, addr_format));
    } else {
       store->src[1] = nir_src_for_ssa(addr_to_index(b, addr, addr_format));
       store->src[2] = nir_src_for_ssa(addr_to_offset(b, addr, addr_format));
@@ -1042,7 +1066,7 @@ build_explicit_io_atomic(nir_builder *b, nir_intrinsic_instr *intrin,
       op = global_atomic_for_deref(intrin->intrinsic);
       break;
    case nir_var_mem_shared:
-      assert(addr_format == nir_address_format_32bit_offset);
+      assert(addr_format_is_offset(addr_format));
       op = shared_atomic_for_deref(intrin->intrinsic);
       break;
    default:
@@ -1054,9 +1078,9 @@ build_explicit_io_atomic(nir_builder *b, nir_intrinsic_instr *intrin,
    unsigned src = 0;
    if (addr_format_is_global(addr_format)) {
       atomic->src[src++] = nir_src_for_ssa(addr_to_global(b, addr, addr_format));
-   } else if (addr_format == nir_address_format_32bit_offset) {
+   } else if (addr_format_is_offset(addr_format)) {
       assert(addr->num_components == 1);
-      atomic->src[src++] = nir_src_for_ssa(addr);
+      atomic->src[src++] = nir_src_for_ssa(addr_to_offset(b, addr, addr_format));
    } else {
       atomic->src[src++] = nir_src_for_ssa(addr_to_index(b, addr, addr_format));
       atomic->src[src++] = nir_src_for_ssa(addr_to_offset(b, addr, addr_format));
@@ -1111,6 +1135,7 @@ nir_explicit_io_address_from_deref(nir_builder *b, nir_deref_instr *deref,
          return build_addr_iadd_imm(b, base_addr, addr_format,
                                        deref->var->data.driver_location);
       } else {
+         assert(deref->var->data.driver_location <= UINT32_MAX);
          return nir_imm_intN_t(b, deref->var->data.driver_location,
                                deref->dest.ssa.bit_size);
       }
@@ -1127,14 +1152,14 @@ nir_explicit_io_address_from_deref(nir_builder *b, nir_deref_instr *deref,
       assert(stride > 0);
 
       nir_ssa_def *index = nir_ssa_for_src(b, deref->arr.index, 1);
-      index = nir_i2i(b, index, base_addr->bit_size);
+      index = nir_i2i(b, index, addr_get_offset_bit_size(base_addr, addr_format));
       return build_addr_iadd(b, base_addr, addr_format,
                                 nir_amul_imm(b, index, stride));
    }
 
    case nir_deref_type_ptr_as_array: {
       nir_ssa_def *index = nir_ssa_for_src(b, deref->arr.index, 1);
-      index = nir_i2i(b, index, base_addr->bit_size);
+      index = nir_i2i(b, index, addr_get_offset_bit_size(base_addr, addr_format));
       unsigned stride = nir_deref_instr_ptr_as_array_stride(deref);
       return build_addr_iadd(b, base_addr, addr_format,
                                 nir_amul_imm(b, index, stride));
@@ -1615,6 +1640,7 @@ nir_address_format_null_value(nir_address_format addr_format)
       [nir_address_format_32bit_index_offset] = {{.u32 = ~0}, {.u32 = ~0}},
       [nir_address_format_vec2_index_32bit_offset] = {{.u32 = ~0}, {.u32 = ~0}, {.u32 = ~0}},
       [nir_address_format_32bit_offset] = {{.u32 = ~0}},
+      [nir_address_format_32bit_offset_as_64bit] = {{.u64 = ~0ull}},
       [nir_address_format_logical] = {{.u32 = ~0}},
    };
 
@@ -1635,6 +1661,10 @@ nir_build_addr_ieq(nir_builder *b, nir_ssa_def *addr0, nir_ssa_def *addr1,
    case nir_address_format_32bit_offset:
       return nir_ball_iequal(b, addr0, addr1);
 
+   case nir_address_format_32bit_offset_as_64bit:
+      assert(addr0->num_components == 1 && addr1->num_components == 1);
+      return nir_ieq(b, nir_u2u32(b, addr0), nir_u2u32(b, addr1));
+
    case nir_address_format_logical:
       unreachable("Unsupported address format");
    }
@@ -1654,6 +1684,11 @@ nir_build_addr_isub(nir_builder *b, nir_ssa_def *addr0, nir_ssa_def *addr1,
       assert(addr1->num_components == 1);
       return nir_isub(b, addr0, addr1);
 
+   case nir_address_format_32bit_offset_as_64bit:
+      assert(addr0->num_components == 1);
+      assert(addr1->num_components == 1);
+      return nir_u2u64(b, nir_isub(b, nir_u2u32(b, addr0), nir_u2u32(b, addr1)));
+
    case nir_address_format_64bit_bounded_global:
       return nir_isub(b, addr_to_global(b, addr0, addr_format),
                          addr_to_global(b, addr1, addr_format));