diff --git a/main.cpp b/main.cpp index d8aff1523..cd7b63868 100644 --- a/main.cpp +++ b/main.cpp @@ -1554,6 +1554,30 @@ static string compile_iteration(const CLIArguments &args, std::vector return ret; } +static MSLShaderVariableFormat parse_format(const char *text) +{ + MSLShaderVariableFormat format; + if (strcmp(text, "i8") == 0) + format = MSL_SHADER_VARIABLE_FORMAT_INT8; + else if (strcmp(text, "i16") == 0) + format = MSL_SHADER_VARIABLE_FORMAT_INT16; + else if (strcmp(text, "i32") == 0) + format = MSL_SHADER_VARIABLE_FORMAT_INT32; + else if (strcmp(text, "u8") == 0) + format = MSL_SHADER_VARIABLE_FORMAT_UINT8; + else if (strcmp(text, "u16") == 0) + format = MSL_SHADER_VARIABLE_FORMAT_UINT16; + else if (strcmp(text, "u32") == 0) + format = MSL_SHADER_VARIABLE_FORMAT_UINT32; + else if (strcmp(text, "float") == 0) + format = MSL_SHADER_VARIABLE_FORMAT_FLOAT; + else if (strcmp(text, "half") == 0) + format = MSL_SHADER_VARIABLE_FORMAT_HALF; + else + format = MSL_SHADER_VARIABLE_FORMAT_OTHER; + return format; +} + static int main_inner(int argc, char *argv[]) { CLIArguments args; @@ -1685,16 +1709,7 @@ static int main_inner(int argc, char *argv[]) // Make sure next_uint() is called in-order. input.location = parser.next_uint(); const char *format = parser.next_value_string("other"); - if (strcmp(format, "any32") == 0) - input.format = MSL_SHADER_VARIABLE_FORMAT_ANY32; - else if (strcmp(format, "any16") == 0) - input.format = MSL_SHADER_VARIABLE_FORMAT_ANY16; - else if (strcmp(format, "u16") == 0) - input.format = MSL_SHADER_VARIABLE_FORMAT_UINT16; - else if (strcmp(format, "u8") == 0) - input.format = MSL_SHADER_VARIABLE_FORMAT_UINT8; - else - input.format = MSL_SHADER_VARIABLE_FORMAT_OTHER; + input.format = parse_format(format); input.vecsize = parser.next_uint(); const char *rate = parser.next_value_string("vertex"); if (strcmp(rate, "primitive") == 0) @@ -1710,16 +1725,7 @@ static int main_inner(int argc, char *argv[]) // Make sure next_uint() is called in-order. output.location = parser.next_uint(); const char *format = parser.next_value_string("other"); - if (strcmp(format, "any32") == 0) - output.format = MSL_SHADER_VARIABLE_FORMAT_ANY32; - else if (strcmp(format, "any16") == 0) - output.format = MSL_SHADER_VARIABLE_FORMAT_ANY16; - else if (strcmp(format, "u16") == 0) - output.format = MSL_SHADER_VARIABLE_FORMAT_UINT16; - else if (strcmp(format, "u8") == 0) - output.format = MSL_SHADER_VARIABLE_FORMAT_UINT8; - else - output.format = MSL_SHADER_VARIABLE_FORMAT_OTHER; + output.format = parse_format(format); output.vecsize = parser.next_uint(); const char *rate = parser.next_value_string("vertex"); if (strcmp(rate, "primitive") == 0) @@ -1735,16 +1741,7 @@ static int main_inner(int argc, char *argv[]) // Make sure next_uint() is called in-order. input.location = parser.next_uint(); const char *format = parser.next_value_string("other"); - if (strcmp(format, "any32") == 0) - input.format = MSL_SHADER_VARIABLE_FORMAT_ANY32; - else if (strcmp(format, "any16") == 0) - input.format = MSL_SHADER_VARIABLE_FORMAT_ANY16; - else if (strcmp(format, "u16") == 0) - input.format = MSL_SHADER_VARIABLE_FORMAT_UINT16; - else if (strcmp(format, "u8") == 0) - input.format = MSL_SHADER_VARIABLE_FORMAT_UINT8; - else - input.format = MSL_SHADER_VARIABLE_FORMAT_OTHER; + input.format = parse_format(format); input.vecsize = parser.next_uint(); args.msl_shader_inputs.push_back(input); }); @@ -1753,16 +1750,7 @@ static int main_inner(int argc, char *argv[]) // Make sure next_uint() is called in-order. output.location = parser.next_uint(); const char *format = parser.next_value_string("other"); - if (strcmp(format, "any32") == 0) - output.format = MSL_SHADER_VARIABLE_FORMAT_ANY32; - else if (strcmp(format, "any16") == 0) - output.format = MSL_SHADER_VARIABLE_FORMAT_ANY16; - else if (strcmp(format, "u16") == 0) - output.format = MSL_SHADER_VARIABLE_FORMAT_UINT16; - else if (strcmp(format, "u8") == 0) - output.format = MSL_SHADER_VARIABLE_FORMAT_UINT8; - else - output.format = MSL_SHADER_VARIABLE_FORMAT_OTHER; + output.format = parse_format(format); output.vecsize = parser.next_uint(); args.msl_shader_outputs.push_back(output); }); diff --git a/spirv_cross_c.h b/spirv_cross_c.h index 0d8e6e10a..195d73b6f 100644 --- a/spirv_cross_c.h +++ b/spirv_cross_c.h @@ -293,22 +293,24 @@ typedef enum spvc_msl_index_type /* Maps to C++ API. */ typedef enum spvc_msl_shader_variable_format { + SPVC_MSL_SHADER_VARIABLE_FORMAT_OTHER = 0, SPVC_MSL_SHADER_VARIABLE_FORMAT_UINT8 = 1, SPVC_MSL_SHADER_VARIABLE_FORMAT_UINT16 = 2, - SPVC_MSL_SHADER_VARIABLE_FORMAT_ANY16 = 3, - SPVC_MSL_SHADER_VARIABLE_FORMAT_ANY32 = 4, - - /* Deprecated names. */ + SPVC_MSL_SHADER_VARIABLE_FORMAT_UINT32 = 3, + SPVC_MSL_SHADER_VARIABLE_FORMAT_FLOAT = 4, + SPVC_MSL_SHADER_VARIABLE_FORMAT_INT8 = 5, + SPVC_MSL_SHADER_VARIABLE_FORMAT_INT16 = 6, + SPVC_MSL_SHADER_VARIABLE_FORMAT_INT32 = 7, + SPVC_MSL_SHADER_VARIABLE_FORMAT_HALF = 8, + + // Deprecated aliases. SPVC_MSL_VERTEX_FORMAT_OTHER = SPVC_MSL_SHADER_VARIABLE_FORMAT_OTHER, SPVC_MSL_VERTEX_FORMAT_UINT8 = SPVC_MSL_SHADER_VARIABLE_FORMAT_UINT8, SPVC_MSL_VERTEX_FORMAT_UINT16 = SPVC_MSL_SHADER_VARIABLE_FORMAT_UINT16, SPVC_MSL_SHADER_INPUT_FORMAT_OTHER = SPVC_MSL_SHADER_VARIABLE_FORMAT_OTHER, SPVC_MSL_SHADER_INPUT_FORMAT_UINT8 = SPVC_MSL_SHADER_VARIABLE_FORMAT_UINT8, SPVC_MSL_SHADER_INPUT_FORMAT_UINT16 = SPVC_MSL_SHADER_VARIABLE_FORMAT_UINT16, - SPVC_MSL_SHADER_INPUT_FORMAT_ANY16 = SPVC_MSL_SHADER_VARIABLE_FORMAT_ANY16, - SPVC_MSL_SHADER_INPUT_FORMAT_ANY32 = SPVC_MSL_SHADER_VARIABLE_FORMAT_ANY32, - SPVC_MSL_SHADER_INPUT_FORMAT_INT_MAX = 0x7fffffff } spvc_msl_shader_variable_format, spvc_msl_shader_input_format, spvc_msl_vertex_format; diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 8471e34f3..d9cfd23d0 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -36,6 +36,21 @@ static const uint32_t k_unknown_location = ~0u; static const uint32_t k_unknown_component = ~0u; static const char *force_inline = "static inline __attribute__((always_inline))"; + +static bool builtin_is_per_primitive_mesh_output(BuiltIn builtin) { + switch (builtin) + { + case BuiltInLayer: + case BuiltInViewportIndex: + case BuiltInPrimitiveId: + case BuiltInCullPrimitiveEXT: + return true; + default: break; + } + + return false; +} + CompilerMSL::CompilerMSL(std::vector spirv_) : CompilerGLSL(std::move(spirv_)) { @@ -244,11 +259,13 @@ void CompilerMSL::build_implicit_builtins() active_input_builtins.get(BuiltInSubgroupGtMask)); bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index && msl_options.multiview_layered_rendering && - (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex)); + (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex)) && + !msl_options.for_mesh_pipeline; bool need_dispatch_base = msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute && (active_input_builtins.get(BuiltInWorkgroupId) || active_input_builtins.get(BuiltInGlobalInvocationId)); - bool need_grid_params = get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation; + bool need_grid_params = get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation && + !msl_options.for_mesh_pipeline; bool need_vertex_base_params = need_grid_params && (active_input_builtins.get(BuiltInVertexId) || active_input_builtins.get(BuiltInVertexIndex) || @@ -934,9 +951,9 @@ void CompilerMSL::build_implicit_builtins() } // If we're returning a struct from a vertex-like entry point, we must return a position attribute. - bool need_position = (get_execution_model() == ExecutionModelVertex || is_tese_shader()) && + bool need_position = (get_execution_model() == ExecutionModelVertex || is_tese_shader() || get_execution_model() == ExecutionModelGeometry) && !capture_output_to_buffer && !get_is_rasterization_disabled() && - !active_output_builtins.get(BuiltInPosition); + !active_output_builtins.get(BuiltInPosition) && !msl_options.for_mesh_pipeline; if (need_position) { @@ -1463,9 +1480,406 @@ void CompilerMSL::emit_entry_point_declarations() } } +static int vertex_count_in_primitive(CompilerMSL::Options::PrimitiveTopology top) { + if (top == CompilerMSL::Options::PrimitiveTopology::TriangleStrip) { + return 3; + } else if (top == CompilerMSL::Options::PrimitiveTopology::Triangles) { + return 3; + } else if (top == CompilerMSL::Options::PrimitiveTopology::Points) { + return 1; + } + + return 0; +} + +static const char* get_vertex_loader_component_suffix(uint32_t elements) { + switch (elements) { + case 1: return ""; + case 2: return "2"; + case 3: return "3"; + case 4: return "4"; + default: + SPIRV_CROSS_THROW("Invalid component count: " + std::to_string(elements)); + return "INVALID_COMPONENT_COUNT"; + } +} + +static const char* get_normalization_string(MSLShaderVariableFormat type, bool normalized) { + if (!normalized) + return ""; + switch (type) { + case MSL_SHADER_VARIABLE_FORMAT_UINT16: return " * (1.f/65535.f)"; + case MSL_SHADER_VARIABLE_FORMAT_UINT8: return " * (1.f/255.f)"; + // TODO: Proper positive sint normalization + case MSL_SHADER_VARIABLE_FORMAT_INT16: return " * (1.f/32768.f)"; + case MSL_SHADER_VARIABLE_FORMAT_INT8: return " * (1.f/128.f)"; + default: return ""; + } +} + +static const char *get_variable_format_string(MSLShaderVariableFormat format) { + switch (format) { + case MSL_SHADER_VARIABLE_FORMAT_INT8: return "byte"; + case MSL_SHADER_VARIABLE_FORMAT_UINT8: return "ubyte"; + case MSL_SHADER_VARIABLE_FORMAT_INT16: return "short"; + case MSL_SHADER_VARIABLE_FORMAT_UINT16: return "ushort"; + case MSL_SHADER_VARIABLE_FORMAT_INT32: return "int"; + case MSL_SHADER_VARIABLE_FORMAT_UINT32: return "uint"; + case MSL_SHADER_VARIABLE_FORMAT_HALF: return "half"; + case MSL_SHADER_VARIABLE_FORMAT_FLOAT: return "float"; + default: + SPIRV_CROSS_THROW("Format not handled: " + std::to_string(format)); + return "INVALID_TYPE"; + } +} + +static const char *get_variable_format_string(SPIRType::BaseType type) { + switch (type) { + case SPIRType::SByte: return "byte"; + case SPIRType::UByte: return "ubyte"; + case SPIRType::Short: return "short"; + case SPIRType::UShort: return "ushort"; + case SPIRType::Int: return "int"; + case SPIRType::UInt: return "uint"; + case SPIRType::Int64: return "long"; + case SPIRType::UInt64: return "ulong"; + case SPIRType::Half: return "half"; + case SPIRType::Float: return "float"; + case SPIRType::Double: return "double"; + default: + SPIRV_CROSS_THROW("Type not handled: " + std::to_string(type)); + return "INVALID_TYPE"; + } +} + +void CompilerMSL::emit_mesh_wrapper() { + auto &execution = get_entry_point(); + + if (execution.model == ExecutionModelVertex) { + auto out_var_type = get_variable_data_type(get(stage_out_var_id)); + + // Emit the payload struct + statement("struct Payload"); + begin_scope(); + statement(join(type_to_glsl(out_var_type), " vertices[", std::to_string(vertex_count_in_primitive(msl_options.input_primitive_type)), "];")); + end_scope(";"); + + // Emit struct with info about the draw call. + statement("struct DrawInfo"); + begin_scope(); + + statement("int32_t indexed;"); + statement("int32_t indexSize;"); + statement("int64_t indexBuffer;"); + + end_scope(";"); + + // Object entry point. + statement("[[object]] void ", execution.name, "(object_data Payload &payload [[payload]], mesh_grid_properties meshGridProperties, constant DrawInfo *drawInfo [[buffer(", std::to_string(get_msl_options().draw_info_index),")]],"); + + bool vertex_bindings[32] = {false}; + for (auto si: inputs_by_location) { + if (si.second.builtin != spv::BuiltInMax) continue; + vertex_bindings[si.second.binding] = true; + } + + for (int i = 0; i < 32; ++i) { + if (!vertex_bindings[i]) continue; + std::string binding = std::to_string(i); + statement("device uchar *vb", binding, " [[buffer(", binding, ")]],"); + } + + // Disable for_mesh_pipeline temporarily so that args get their [[attributes]]. + msl_options.for_mesh_pipeline = false; + string object_arguments; + entry_point_args_discrete_descriptors(object_arguments); + msl_options.for_mesh_pipeline = true; + + if (!object_arguments.empty()) object_arguments += ","; + statement(object_arguments); + + statement("uint3 positionInGrid [[thread_position_in_grid]])"); + + begin_scope(); + + if (msl_options.input_primitive_type == Options::PrimitiveTopology::TriangleStrip) { + statement("int startingIndex = positionInGrid.x;"); + statement("int vertexCount = 3;"); + } else if (msl_options.input_primitive_type == Options::PrimitiveTopology::Triangles) { + statement("int startingIndex = positionInGrid.x * 3;"); + statement("int vertexCount = 3;"); + } else if (msl_options.input_primitive_type == Options::PrimitiveTopology::Points) { + statement("int startingIndex = positionInGrid.x;"); + statement("int vertexCount = 1;"); + } else { + SPIRV_CROSS_THROW("Input primitive type not supported"); + } + + statement("int instanceIndex = positionInGrid.y;"); + + statement("for (int i = 0; i < vertexCount; ++i)"); + begin_scope(); + statement("uint vertexIndex;"); + + statement("if (drawInfo->indexed)"); + begin_scope(); + + statement(R"END( + if (drawInfo->indexSize == 1) { + vertexIndex = ((constant uchar *)drawInfo->indexBuffer)[startingIndex + i]; + if (vertexIndex == 0xff) { + return; + } + } else if (drawInfo->indexSize == 2) { + vertexIndex = ((constant ushort *)drawInfo->indexBuffer)[startingIndex + i]; + if (vertexIndex == 0xffff) { + return; + } + } else { + vertexIndex = ((constant uint *)drawInfo->indexBuffer)[startingIndex + i]; + if (vertexIndex == 0xffffffff) { + return; + } + })END"); + + end_scope(); + statement("else vertexIndex = startingIndex + i;"); + + if (stage_in_var_id) { + auto in_var_type = get_variable_data_type(get(stage_in_var_id)); + statement(type_to_glsl(in_var_type), " ", to_name(stage_in_var_id), ";"); + } + + ir.for_each_typed_id([&](uint32_t id, SPIRVariable &var) { + if (var.storage != StorageClassInput) + return; + + auto &type = get(var.basetype); + if (has_decoration(type.self, DecorationBlock)) + return; + + if (!interface_variable_exists_in_entry_point(var.self)) + return; + + if (is_hidden_variable(var, true)) + return; + + if (is_builtin_variable(var)) + return; + + uint32_t location = get_decoration(id, DecorationLocation); + + LocationComponentPair key; + key.location = location; + key.component = 0; + + auto payload_it = inputs_by_location.find(key); + if (payload_it == inputs_by_location.end()) + return; + + auto variable = payload_it->second; + std::string name = to_name(id); + + SPIRType& parent_type = get(type.parent_type); + std::string parent_type_str = type_to_glsl(get(type.parent_type), var.self); + uint32_t load_elements = std::min(variable.vecsize, parent_type.vecsize); + + std::string type_name = get_variable_format_string(variable.format); + std::string component = type_name + get_vertex_loader_component_suffix(load_elements); + std::string packed = variable.vecsize > 1 ? "packed_" : ""; + std::string load_string = "*(device " + packed + component + " *)(vb" + std::to_string(variable.binding) + " + " + std::to_string(variable.offset) + " + vertexIndex * " + std::to_string(variable.stride) + ")"; + if (load_elements != parent_type.vecsize) { + component = type_name + get_vertex_loader_component_suffix(parent_type.vecsize); + load_string = component + "(" + load_string; + for (uint32_t i = load_elements; i < parent_type.vecsize; i++) + load_string += (i == 3) ? ", 1" : ", 0"; + load_string += ")"; + } + if (parent_type_str != component) + load_string = parent_type_str + "(" + load_string + ")"; + statement(to_name(var.self), " = ", load_string, get_normalization_string(variable.format, variable.normalized), ";"); + }); + + statement("payload.vertices[i] = ", execution.name, "("); + bool need_comma = false; + if (stage_in_var_id) { + need_comma = true; + statement(to_name(stage_in_var_id)); + } + + auto resources = get_sorted_entry_point_args(false); + + for (auto &resource : resources) { + statement(need_comma ? ", " : "", resource.name); + need_comma = true; + } + + ir.for_each_typed_id([&](uint32_t var_id, SPIRVariable &var) { + if (var.storage == StorageClassInput && is_builtin_variable(var)) { + uint32_t builtin = get_decoration(var_id, DecorationBuiltIn); + + switch (builtin) { + case BuiltInInstanceIndex: { + statement(need_comma ? ", " : "", "instanceIndex"); + need_comma = true; + break; + } + case BuiltInBaseInstance: { + statement(need_comma ? ", " : "", "0"); + need_comma = true; + break; + } + case BuiltInVertexIndex: { + statement(need_comma ? ", " : "", "vertexIndex"); + need_comma = true; + break; + } + default: { + statement(need_comma ? ", " : "", "0 /* Unhandled builtin ", builtin, " */"); + need_comma = true; + } break; + } + } + }); + + statement(");"); + + end_scope(); + + statement("meshGridProperties.set_threadgroups_per_grid(uint3(1, 1, 1));"); + + end_scope(); + } else { + assert(execution.model == ExecutionModelGeometry); + + // Emit the payload struct + statement("struct Payload"); + begin_scope(); + statement("struct payload_vertex"); + begin_scope(); + + for (auto kv: inputs_by_location) { + std::string type_name = get_variable_format_string(kv.second.type); + if (kv.second.vecsize) type_name += get_vertex_loader_component_suffix(kv.second.vecsize); + std::string location = std::to_string(kv.second.location); + std::string attribute; + + switch (kv.second.builtin) { + case BuiltInPosition: + attribute = "[[position]]"; + break; + default: + fprintf(stderr, "Builtin not handled: %u\n", kv.second.builtin); + case BuiltInMax: + attribute = join("[[user(locn", std::to_string(kv.second.location), ")]]"); break; + break; + } + + statement(join(type_name, " in", location, " ", attribute, ";")); + } + + end_scope(";"); + + statement(join("payload_vertex vertices[", std::to_string(vertex_count_in_primitive(msl_options.input_primitive_type)), "];")); + end_scope(";"); + + // Mesh entry point. + + statement("[[mesh]] void ", execution.name, "(mesh_stream_t::mesh_t outputMesh, const object_data Payload &payload [[payload]],"); + + // Geometry bindings + + msl_options.for_mesh_pipeline = false; + + string mesh_arguments; + entry_point_args_discrete_descriptors(mesh_arguments); + msl_options.for_mesh_pipeline = true; + if (!mesh_arguments.empty()) mesh_arguments += ","; + statement(mesh_arguments); + + statement("uint lid [[thread_index_in_threadgroup]], uint tid [[threadgroup_position_in_grid]])"); + + begin_scope(); + + auto in_var_type = get_variable_data_type(get(stage_in_var_id)); + statement(type_to_glsl(in_var_type), " ", to_name(stage_in_var_id), ";"); + + + if (msl_options.input_primitive_type == Options::PrimitiveTopology::TriangleStrip) { + statement("const int vertexCount = 3;"); + } else if (msl_options.input_primitive_type == Options::PrimitiveTopology::Triangles) { + statement("const int vertexCount = 3;"); + } else if (msl_options.input_primitive_type == Options::PrimitiveTopology::Points) { + statement("const int vertexCount = 1;"); + } else { + SPIRV_CROSS_THROW("Input primitive type not supported"); + } + + statement("for (int i = 0; i < vertexCount; ++i)"); + begin_scope(); + + statement("auto out = payload.vertices[i];"); + + ir.for_each_typed_id([&](uint32_t id, SPIRVariable &var) { + if (var.storage != StorageClassInput) + return; + + auto &type = get(var.basetype); + if (has_decoration(type.self, DecorationBlock)) + return; + + if (!interface_variable_exists_in_entry_point(var.self)) + return; + + if (is_hidden_variable(var, true)) + return; + + uint32_t location = get_decoration(id, DecorationLocation); + + bool is_builtin = is_builtin_variable(var); + std::string name = to_name(id); + MSLShaderInterfaceVariable variable; + bool found; + + if (is_builtin) { + auto payload_it = inputs_by_builtin.find(get_decoration(id, DecorationBuiltIn)); + if ((found = (payload_it != inputs_by_builtin.end()))) + variable = payload_it->second; + } else { + LocationComponentPair key; + key.location = location; + key.component = 0; + auto payload_it = inputs_by_location.find(key); + if ((found = (payload_it != inputs_by_location.end()))) + variable = payload_it->second; + } + + if (found) { + statement("if (i < sizeof(", name, ") / sizeof(", name, "[0]))"); + statement("\t", name, "[i] = out.in", std::to_string(variable.location), ";"); + } + }); + + end_scope(); + + statement(join(execution.name, "(outputMesh, ", to_name(stage_in_var_id))); + + auto resources = get_sorted_entry_point_args(false); + + for (auto &resource : resources) { + statement(", ", resource.name); + } + + statement(");"); + end_scope(); + } +} + + string CompilerMSL::compile() { replace_illegal_entry_point_names(); + ir.fixup_reserved_names(); // Do not deal with GLES-isms like precision, older extensions and such. @@ -1563,10 +1977,16 @@ string CompilerMSL::compile() if (builtin_sample_mask_id) add_active_interface_variable(builtin_sample_mask_id); + auto &execution = get_entry_point(); + // Create structs to hold input, output and uniform variables. // Do output first to ensure out. is declared at top of entry function. qual_pos_var_name = ""; stage_out_var_id = add_interface_block(StorageClassOutput); + if (execution.model == ExecutionModelGeometry) { + stage_out_mesh_primitive_var_id = add_interface_block(StorageClassOutput, false, true); + } + patch_stage_out_var_id = add_interface_block(StorageClassOutput, true); stage_in_var_id = add_interface_block(StorageClassInput); if (is_tese_shader()) @@ -1593,6 +2013,14 @@ string CompilerMSL::compile() // the loop, so the hooks aren't added multiple times. fix_up_shader_inputs_outputs(); + if (execution.model == ExecutionModelGeometry) { + auto &entry_func = get(ir.default_entry_point); + + entry_func.fixup_hooks_in.push_back([=]() { + statement("mesh_stream_t meshStream(spvMeshOut, ", to_name(stage_out_var_id),", ", to_name(stage_out_mesh_primitive_var_id) ,");"); + }); + } + // If we are using argument buffers, we create argument buffer structures for them here. // These buffers will be used in the entry point, not the individual resources. if (msl_options.argument_buffers) @@ -1622,8 +2050,24 @@ string CompilerMSL::compile() emit_custom_functions(); emit_specialization_constants_and_structs(); emit_resources(); + + if (execution.model == ExecutionModelGeometry) { + auto output_primitives = execution.output_primitives; + if (!output_primitives) output_primitives = execution.output_vertices - 2; + + auto vertex_type = type_to_glsl(get_variable_data_type(get(stage_out_var_id))); + auto prim_type = type_to_glsl(get_variable_data_type(get(stage_out_mesh_primitive_var_id))); + + statement("enum { VERTEX_COUNT = ", std::to_string(execution.output_vertices), ", PRIMITIVE_COUNT = ", std::to_string(output_primitives), "};"); + statement("using mesh_stream_t = spvMeshStream<", vertex_type, ", ", prim_type, ", VERTEX_COUNT, PRIMITIVE_COUNT, metal::topology::triangle>;"); // TODO fill out actual topology + } + emit_function(get(ir.default_entry_point), Bitset()); + if (msl_options.for_mesh_pipeline) { + emit_mesh_wrapper(); + } + pass_count++; } while (is_forcing_recompilation()); @@ -3686,7 +4130,7 @@ void CompilerMSL::add_variable_to_interface_block(StorageClass storage, const st is_builtin = false; // MSL does not allow matrices or arrays in input or output variables, so need to handle it specially. - if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type) + if (!msl_options.for_mesh_pipeline && (!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type) { add_composite_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta); } @@ -3746,11 +4190,11 @@ void CompilerMSL::fix_up_interface_member_indices(StorageClass storage, uint32_t // Add an interface structure for the type of storage, which is either StorageClassInput or StorageClassOutput. // Returns the ID of the newly added variable, or zero if no variable was added. -uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch) +uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch, bool mesh_primitive) { // Accumulate the variables that should appear in the interface struct. SmallVector vars; - bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader(); + bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader() || get_execution_model() == ExecutionModelGeometry; bool has_seen_barycentric = false; InterfaceBlockMeta meta; @@ -3792,6 +4236,14 @@ uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch) bi_type == BuiltInFragDepth || bi_type == BuiltInFragStencilRefEXT || bi_type == BuiltInSampleMask; + if (get_execution_model() == ExecutionModelGeometry) + { + if (mesh_primitive && (!is_builtin || !builtin_is_per_primitive_mesh_output(bi_type))) + return; + + if (!mesh_primitive && is_builtin && builtin_is_per_primitive_mesh_output(bi_type)) return; + } + // These builtins are part of the stage in/out structs. bool is_interface_block_builtin = builtin_is_stage_in_out || (is_tese_shader() && !msl_options.raw_buffer_tese_input && @@ -3912,7 +4364,7 @@ uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch) // If no variables qualify, leave. // For patch input in a tessellation evaluation shader, the per-vertex stage inputs // are included in a special patch control point array. - if (vars.empty() && + if (vars.empty() && !mesh_primitive && !(!msl_options.raw_buffer_tese_input && storage == StorageClassInput && patch && stage_in_var_id)) return 0; @@ -4013,8 +4465,10 @@ uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch) for (auto &blk_id : entry_func.blocks) { auto &blk = get(blk_id); - if (blk.terminator == SPIRBlock::Return || (blk.terminator == SPIRBlock::Kill && blk_id == entry_func.blocks.back())) - blk.return_value = rtn_id; + if (blk.terminator == SPIRBlock::Return || (blk.terminator == SPIRBlock::Kill && blk_id == entry_func.blocks.back())) { + if (get_execution_model() != ExecutionModelGeometry) // Geometry shaders don't return the output structure, it's emitted to the mesh stream. + blk.return_value = rtn_id; + } } vars_needing_early_declaration.push_back(ib_var_id); } @@ -4154,11 +4608,11 @@ uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch) switch (input.second.format) { case MSL_SHADER_VARIABLE_FORMAT_UINT16: - case MSL_SHADER_VARIABLE_FORMAT_ANY16: + case MSL_SHADER_VARIABLE_FORMAT_INT16: + case MSL_SHADER_VARIABLE_FORMAT_HALF: type.basetype = SPIRType::UShort; type.width = 16; break; - case MSL_SHADER_VARIABLE_FORMAT_ANY32: default: type.basetype = SPIRType::UInt; type.width = 32; @@ -4212,11 +4666,11 @@ uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch) switch (output.second.format) { case MSL_SHADER_VARIABLE_FORMAT_UINT16: - case MSL_SHADER_VARIABLE_FORMAT_ANY16: + case MSL_SHADER_VARIABLE_FORMAT_INT16: + case MSL_SHADER_VARIABLE_FORMAT_HALF: type.basetype = SPIRType::UShort; type.width = 16; break; - case MSL_SHADER_VARIABLE_FORMAT_ANY32: default: type.basetype = SPIRType::UInt; type.width = 32; @@ -4471,6 +4925,26 @@ uint32_t CompilerMSL::ensure_correct_input_type(uint32_t type_id, uint32_t locat } } + case MSL_SHADER_VARIABLE_FORMAT_UINT32: + { + switch (type.basetype) + { + case SPIRType::UShort: + case SPIRType::UInt: + if (num_components > type.vecsize) + return build_extended_vector_type(type_id, num_components); + else + return type_id; + + case SPIRType::Int: + return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize, + SPIRType::UInt); + + default: + SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader"); + } + } + default: if (num_components > type.vecsize) type_id = build_extended_vector_type(type_id, num_components); @@ -5234,8 +5708,10 @@ void CompilerMSL::emit_header() if (!pragma_lines.empty() || suppress_missing_prototypes) statement(""); - statement("#include "); - statement("#include "); + if (!msl_options.for_mesh_pipeline) { + statement("#include "); + statement("#include "); + } for (auto &header : header_lines) statement(header); @@ -5433,6 +5909,64 @@ void CompilerMSL::emit_custom_templates() statement(""); break; + case SPVFuncImplEmitVertex: + statement("template"); + statement("struct spvMeshStream"); + begin_scope(); + + statement("using mesh_t = metal::mesh;"); + statement("thread mesh_t &meshOut;"); + + statement("int currentVertex = 0;"); + statement("int currentIndex = 0;"); + statement("int currentVertexInPrimitive = 0;"); + statement("int currentPrimitive = 0;"); + statement("thread P &primitiveData;"); + statement("thread V &vertexData;"); + + statement("spvMeshStream(thread mesh_t &_meshOut, thread V &_v, thread P &_p) : meshOut(_meshOut), primitiveData(_p), vertexData(_v)"); + begin_scope(); + end_scope(); + + statement("~spvMeshStream()"); + begin_scope(); + statement("meshOut.set_primitive_count(currentPrimitive);"); + end_scope(); + + statement("int VperP()"); + begin_scope(); + statement("if (T == metal::topology::triangle) return 3;"); + statement("else if (T == metal::topology::line) return 2;"); + statement("else /* if (T == metal::topology::point) */ return 1;"); + end_scope(); + + statement("void EndPrimitive()"); + begin_scope(); + statement("currentVertexInPrimitive = 0;"); + end_scope(); + + statement("void EmitVertex()"); + begin_scope(); + if (options.vertex.flip_vert_y) { + statement("V v = vertexData;"); + statement("v.gl_Position.y = -v.gl_Position.y; // Invert Y-axis for Metal"); + statement("meshOut.set_vertex(currentVertex++, v);"); + } else { + statement("meshOut.set_vertex(currentVertex++, vertexData);"); + } + statement("currentVertexInPrimitive++;"); + statement("if (currentVertexInPrimitive >= VperP())"); + begin_scope(); + statement("if (T == metal::topology::triangle) meshOut.set_index(currentIndex++, currentVertex-3);"); + statement("if (T == metal::topology::triangle || T == metal::topology::line) meshOut.set_index(currentIndex++, currentVertex-2);"); + statement("meshOut.set_index(currentIndex++, currentVertex-1);"); + statement("meshOut.set_primitive(currentPrimitive++, primitiveData);"); + end_scope(); + end_scope(); + end_scope(";"); + + break; + default: break; } @@ -7411,6 +7945,8 @@ void CompilerMSL::emit_resources() // Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created. emit_interface_block(stage_out_var_id); + if (stage_out_mesh_primitive_var_id) + emit_interface_block(stage_out_mesh_primitive_var_id); emit_interface_block(patch_stage_out_var_id); emit_interface_block(stage_in_var_id); emit_interface_block(patch_stage_in_var_id); @@ -9479,6 +10015,20 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) break; } + case OpEmitVertex: + { + add_spv_func_and_recompile(SPVFuncImplEmitVertex); + statement("meshStream.EmitVertex();"); + break; + } + + case OpEndPrimitive: + { + add_spv_func_and_recompile(SPVFuncImplEmitVertex); + statement("meshStream.EndPrimitive();"); + break; + } + default: CompilerGLSL::emit_instruction(instruction); break; @@ -10420,6 +10970,13 @@ void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &) if (processing_entry_point) { + auto &execution = get_entry_point(); + if (execution.model == ExecutionModelGeometry) { + auto output_primitives = execution.output_primitives; + if (!output_primitives) output_primitives = execution.output_vertices - 2; + decl += "mesh_stream_t::mesh_t spvMeshOut, "; + } + if (msl_options.argument_buffers) decl += entry_point_args_argument_buffer(!func.arguments.empty()); else @@ -11930,8 +12487,9 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in case BuiltInInstanceId: case BuiltInInstanceIndex: case BuiltInBaseInstance: - if (msl_options.vertex_for_tessellation) + if (msl_options.vertex_for_tessellation || msl_options.for_mesh_pipeline) return ""; + return string(" [[") + builtin_qualifier(builtin) + "]]"; case BuiltInDrawIndex: @@ -11948,12 +12506,19 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in else locn = get_member_location(type.self, index); - if (locn != k_unknown_location) + if (locn != k_unknown_location) { + if (msl_options.for_mesh_pipeline) + return ""; return string(" [[attribute(") + convert_to_string(locn) + ")]]"; + } + + + if (msl_options.for_mesh_pipeline) + return ""; } // Vertex and tessellation evaluation function outputs - if (((execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation) || is_tese_shader()) && + if (((execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation) || is_tese_shader() || (execution.model == ExecutionModelGeometry)) && type.storage == StorageClassOutput) { if (is_builtin) @@ -12392,6 +12957,10 @@ uint32_t CompilerMSL::get_or_allocate_builtin_output_member_location(spv::BuiltI // entry type if the current function is the entry point function string CompilerMSL::func_type_decl(SPIRType &type) { + auto &execution = get_entry_point(); + if (execution.model == ExecutionModelGeometry) + return "void"; + // The regular function return type. If not processing the entry point function, that's all we need string return_type = type_to_glsl(type) + type_to_array_glsl(type); if (!processing_entry_point) @@ -12404,10 +12973,15 @@ string CompilerMSL::func_type_decl(SPIRType &type) // Prepend a entry type, based on the execution model string entry_type; - auto &execution = get_entry_point(); switch (execution.model) { case ExecutionModelVertex: + if (msl_options.for_mesh_pipeline) { + if (!msl_options.supports_msl_version(3, 0)) + SPIRV_CROSS_THROW("Mesh pipelines require MSL 3.0."); + entry_type = ""; + break; + } if (msl_options.vertex_for_tessellation && !msl_options.supports_msl_version(1, 2)) SPIRV_CROSS_THROW("Tessellation requires Metal 1.2."); entry_type = msl_options.vertex_for_tessellation ? "kernel" : "vertex"; @@ -12441,6 +13015,9 @@ string CompilerMSL::func_type_decl(SPIRType &type) break; } + if (entry_type.empty()) + return return_type; + return entry_type + " " + return_type; } @@ -12621,7 +13198,9 @@ string CompilerMSL::entry_point_arg_stage_in() auto &type = get_variable_data_type(var); add_resource_name(var.self); - decl = join(type_to_glsl(type), " ", to_name(var.self), " [[stage_in]]"); + decl = join(type_to_glsl(type), " ", to_name(var.self)); + if (!msl_options.for_mesh_pipeline) + decl += " [[stage_in]]"; } return decl; @@ -12640,7 +13219,8 @@ bool CompilerMSL::is_direct_input_builtin(BuiltIn bi_type) case BuiltInInstanceId: case BuiltInInstanceIndex: case BuiltInBaseInstance: - return get_execution_model() != ExecutionModelVertex || !msl_options.vertex_for_tessellation; + return get_execution_model() != ExecutionModelVertex || !msl_options.vertex_for_tessellation || + msl_options.for_mesh_pipeline; // Tess. control function in case BuiltInPosition: case BuiltInPointSize: @@ -12741,16 +13321,18 @@ void CompilerMSL::entry_point_args_builtin(string &ep_args) else ep_args += builtin_type_decl(bi_type, var_id) + " " + to_expression(var_id); - ep_args += string(" [[") + builtin_qualifier(bi_type); - if (bi_type == BuiltInSampleMask && get_entry_point().flags.get(ExecutionModePostDepthCoverage)) - { - if (!msl_options.supports_msl_version(2)) - SPIRV_CROSS_THROW("Post-depth coverage requires MSL 2.0."); - if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3)) - SPIRV_CROSS_THROW("Post-depth coverage on Mac requires MSL 2.3."); - ep_args += ", post_depth_coverage"; + if (!msl_options.for_mesh_pipeline) { + ep_args += string(" [[") + builtin_qualifier(bi_type); + if (bi_type == BuiltInSampleMask && get_entry_point().flags.get(ExecutionModePostDepthCoverage)) + { + if (!msl_options.supports_msl_version(2)) + SPIRV_CROSS_THROW("Post-depth coverage requires MSL 2.0."); + if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3)) + SPIRV_CROSS_THROW("Post-depth coverage on Mac requires MSL 2.3."); + ep_args += ", post_depth_coverage"; + } + ep_args += "]]"; } - ep_args += "]]"; builtin_declaration = false; } } @@ -13003,7 +13585,8 @@ string CompilerMSL::entry_point_args_argument_buffer(bool append_comma) claimed_bindings.set(buffer_binding); ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id); - ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]"; + if (!msl_options.for_mesh_pipeline) + ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]"; next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1); } @@ -13039,23 +13622,12 @@ const MSLConstexprSampler *CompilerMSL::find_constexpr_sampler(uint32_t id) cons return nullptr; } -void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) +SmallVector CompilerMSL::get_sorted_entry_point_args(bool add_names) { // Output resources, sorted by resource index & type // We need to sort to work around a bug on macOS 10.13 with NVidia drivers where switching between shaders // with different order of buffers can result in issues with buffer assignments inside the driver. - struct Resource - { - SPIRVariable *var; - SPIRVariable *descriptor_alias; - string name; - SPIRType::BaseType basetype; - uint32_t index; - uint32_t plane; - uint32_t secondary_index; - }; - - SmallVector resources; + SmallVector resources; entry_point_bindings.clear(); ir.for_each_typed_id([&](uint32_t var_id, SPIRVariable &var) { @@ -13092,7 +13664,7 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) // and it's being used as an alias (so we can emit void* instead). resource.descriptor_alias = resource.var; // Need to promote interlocked usage so that the primary declaration is correct. - if (interlocked_resources.count(var_id)) + if (add_names && interlocked_resources.count(var_id)) interlocked_resources.insert(resource.var->self); break; } @@ -13119,7 +13691,7 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) if (type.basetype == SPIRType::SampledImage) { - add_resource_name(var_id); + if (add_names) add_resource_name(var_id); uint32_t plane_count = 1; if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable) @@ -13139,7 +13711,7 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) else if (!constexpr_sampler) { // constexpr samplers are not declared as resources. - add_resource_name(var_id); + if (add_names) add_resource_name(var_id); // Don't allocate resource indices for aliases. uint32_t resource_index = ~0u; @@ -13154,9 +13726,16 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) }); stable_sort(resources.begin(), resources.end(), - [](const Resource &lhs, const Resource &rhs) + [](const Entry_Point_Resource &lhs, const Entry_Point_Resource &rhs) { return tie(lhs.basetype, lhs.index) < tie(rhs.basetype, rhs.index); }); + return resources; +} + +void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) +{ + auto resources = get_sorted_entry_point_args(); + for (auto &r : resources) { auto &var = *r.var; @@ -13223,10 +13802,13 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) type_to_glsl(type) + "*>* "; } ep_args += to_restrict(var_id, true) + r.name + "_"; - ep_args += " [[buffer(" + convert_to_string(r.index) + ")"; - if (interlocked_resources.count(var_id)) - ep_args += ", raster_order_group(0)"; - ep_args += "]]"; + if (!msl_options.for_mesh_pipeline) + { + ep_args += " [[buffer(" + convert_to_string(r.index) + ")"; + if (interlocked_resources.count(var_id)) + ep_args += ", raster_order_group(0)"; + ep_args += "]]"; + } } else { @@ -13250,10 +13832,13 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) ep_args += ", "; ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(var_id, true) + r.name; - ep_args += " [[buffer(" + convert_to_string(r.index) + ")"; - if (interlocked_resources.count(var_id)) - ep_args += ", raster_order_group(0)"; - ep_args += "]]"; + if (!msl_options.for_mesh_pipeline) + { + ep_args += " [[buffer(" + convert_to_string(r.index) + ")"; + if (interlocked_resources.count(var_id)) + ep_args += ", raster_order_group(0)"; + ep_args += "]]"; + } } break; } @@ -13261,10 +13846,17 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) if (!ep_args.empty()) ep_args += ", "; ep_args += sampler_type(type, var_id) + " " + r.name; + if (is_runtime_size_array(type)) - ep_args += "_ [[buffer(" + convert_to_string(r.index) + ")]]"; - else - ep_args += " [[sampler(" + convert_to_string(r.index) + ")]]"; + ep_args += "_"; + + if (!msl_options.for_mesh_pipeline) + { + if (is_runtime_size_array(type)) + ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]"; + else + ep_args += " [[sampler(" + convert_to_string(r.index) + ")]]"; + } break; case SPIRType::Image: { @@ -13280,13 +13872,18 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) ep_args += join(plane_name_suffix, r.plane); if (is_runtime_size_array(type)) - ep_args += "_ [[buffer(" + convert_to_string(r.index) + ")"; - else - ep_args += " [[texture(" + convert_to_string(r.index) + ")"; + ep_args += "_"; - if (interlocked_resources.count(var_id)) - ep_args += ", raster_order_group(0)"; - ep_args += "]]"; + if (!msl_options.for_mesh_pipeline) { + if (is_runtime_size_array(type)) + ep_args += " [[buffer(" + convert_to_string(r.index) + ")"; + else + ep_args += " [[texture(" + convert_to_string(r.index) + ")"; + + if (interlocked_resources.count(var_id)) + ep_args += ", raster_order_group(0)"; + ep_args += "]]"; + } } else { @@ -13301,10 +13898,13 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) { ep_args += ", device atomic_" + type_to_glsl(get(basetype.image.type), 0); ep_args += "* " + r.name + "_atomic"; - ep_args += " [[buffer(" + convert_to_string(r.secondary_index) + ")"; - if (interlocked_resources.count(var_id)) - ep_args += ", raster_order_group(0)"; - ep_args += "]]"; + if (!msl_options.for_mesh_pipeline) + { + ep_args += " [[buffer(" + convert_to_string(r.secondary_index) + ")"; + if (interlocked_resources.count(var_id)) + ep_args += ", raster_order_group(0)"; + ep_args += "]]"; + } } break; } @@ -13333,10 +13933,13 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args) type_to_glsl(type, var_id) + "& " + r.name; else ep_args += type_to_glsl(type, var_id) + " " + r.name; - ep_args += " [[buffer(" + convert_to_string(r.index) + ")"; - if (interlocked_resources.count(var_id)) - ep_args += ", raster_order_group(0)"; - ep_args += "]]"; + if (!msl_options.for_mesh_pipeline) + { + ep_args += " [[buffer(" + convert_to_string(r.index) + ")"; + if (interlocked_resources.count(var_id)) + ep_args += ", raster_order_group(0)"; + ep_args += "]]"; + } break; } } @@ -15816,8 +16419,21 @@ string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage) if (is_tesc_shader()) break; if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) && - !is_stage_output_builtin_masked(builtin)) + !is_stage_output_builtin_masked(builtin)) { + if (builtin_is_per_primitive_mesh_output(builtin) && get_execution_model() == ExecutionModelGeometry) { + return stage_out_mesh_primitive_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage); + } + return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage); + } + + if (current_function && get_execution_model() == ExecutionModelGeometry) { + if (storage == StorageClassInput) + return stage_in_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage); + else if (storage == StorageClassGeneric) + return CompilerGLSL::builtin_to_glsl(builtin, storage); + } + break; case BuiltInSampleMask: @@ -16060,7 +16676,8 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin) } else if (execution.model == ExecutionModelKernel || execution.model == ExecutionModelGLCompute || execution.model == ExecutionModelTessellationControl || - (execution.model == ExecutionModelVertex && msl_options.vertex_for_tessellation)) + (execution.model == ExecutionModelVertex && + (msl_options.vertex_for_tessellation || msl_options.for_mesh_pipeline))) { // We are generating a Metal kernel function. if (!msl_options.supports_msl_version(2)) diff --git a/spirv_msl.hpp b/spirv_msl.hpp index 26167f673..d53d70b85 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -42,8 +42,12 @@ enum MSLShaderVariableFormat MSL_SHADER_VARIABLE_FORMAT_OTHER = 0, MSL_SHADER_VARIABLE_FORMAT_UINT8 = 1, MSL_SHADER_VARIABLE_FORMAT_UINT16 = 2, - MSL_SHADER_VARIABLE_FORMAT_ANY16 = 3, - MSL_SHADER_VARIABLE_FORMAT_ANY32 = 4, + MSL_SHADER_VARIABLE_FORMAT_UINT32 = 3, + MSL_SHADER_VARIABLE_FORMAT_FLOAT = 4, + MSL_SHADER_VARIABLE_FORMAT_INT8 = 5, + MSL_SHADER_VARIABLE_FORMAT_INT16 = 6, + MSL_SHADER_VARIABLE_FORMAT_INT32 = 7, + MSL_SHADER_VARIABLE_FORMAT_HALF = 8, // Deprecated aliases. MSL_VERTEX_FORMAT_OTHER = MSL_SHADER_VARIABLE_FORMAT_OTHER, @@ -52,8 +56,6 @@ enum MSLShaderVariableFormat MSL_SHADER_INPUT_FORMAT_OTHER = MSL_SHADER_VARIABLE_FORMAT_OTHER, MSL_SHADER_INPUT_FORMAT_UINT8 = MSL_SHADER_VARIABLE_FORMAT_UINT8, MSL_SHADER_INPUT_FORMAT_UINT16 = MSL_SHADER_VARIABLE_FORMAT_UINT16, - MSL_SHADER_INPUT_FORMAT_ANY16 = MSL_SHADER_VARIABLE_FORMAT_ANY16, - MSL_SHADER_INPUT_FORMAT_ANY32 = MSL_SHADER_VARIABLE_FORMAT_ANY32, MSL_SHADER_VARIABLE_FORMAT_INT_MAX = 0x7fffffff }; @@ -81,6 +83,11 @@ struct MSLShaderInterfaceVariable spv::BuiltIn builtin = spv::BuiltInMax; uint32_t vecsize = 0; MSLShaderVariableRate rate = MSL_SHADER_VARIABLE_RATE_PER_VERTEX; + SPIRType::BaseType type = SPIRType::Unknown; + uint32_t offset = 0; + uint32_t stride = 0; + uint32_t binding = 0; + bool normalized = false; }; // Matches the binding index of a MSL resource for a binding within a descriptor set. @@ -319,6 +326,7 @@ class CompilerMSL : public CompilerGLSL uint32_t shader_input_buffer_index = 22; uint32_t shader_index_buffer_index = 21; uint32_t shader_patch_input_buffer_index = 20; + uint32_t draw_info_index = 20; uint32_t shader_input_wg_index = 0; uint32_t device_index = 0; uint32_t enable_frag_output_mask = 0xffffffff; @@ -505,6 +513,14 @@ class CompilerMSL : public CompilerGLSL // Note: Only Apple's GPU compiler takes advantage of the lack of coherency, so make sure to test on Apple GPUs if you disable this. bool readwrite_texture_fences = true; + // Compile for use with a geometry shader. If set, vertex shaders will be compiled as [[object]] + // functions, and geometry shaders as [[mesh]]. + bool for_mesh_pipeline = false; + + enum class PrimitiveTopology { + Triangles, TriangleStrip, Points + } input_primitive_type; + bool is_ios() const { return platform == iOS; @@ -808,6 +824,7 @@ class CompilerMSL : public CompilerGLSL SPVFuncImplVariableDescriptor, SPVFuncImplVariableSizedDescriptor, SPVFuncImplVariableDescriptorArray, + SPVFuncImplEmitVertex, }; // If the underlying resource has been used for comparison then duplicate loads of that resource must be too @@ -902,7 +919,7 @@ class CompilerMSL : public CompilerGLSL void extract_global_variables_from_function(uint32_t func_id, std::set &added_arg_ids, std::unordered_set &global_var_ids, std::unordered_set &processed_func_ids); - uint32_t add_interface_block(spv::StorageClass storage, bool patch = false); + uint32_t add_interface_block(spv::StorageClass storage, bool patch = false, bool mesh_primitive = false); uint32_t add_interface_block_pointer(uint32_t ib_var_id, spv::StorageClass storage); struct InterfaceBlockMeta @@ -964,6 +981,7 @@ class CompilerMSL : public CompilerGLSL void emit_interface_block(uint32_t ib_var_id); bool maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs); uint32_t get_resource_array_size(uint32_t id) const; + void emit_mesh_wrapper(); void fix_up_shader_inputs_outputs(); @@ -973,6 +991,19 @@ class CompilerMSL : public CompilerGLSL std::string entry_point_arg_stage_in(); void entry_point_args_builtin(std::string &args); void entry_point_args_discrete_descriptors(std::string &args); + + struct Entry_Point_Resource + { + SPIRVariable *var; + SPIRVariable *descriptor_alias; + std::string name; + SPIRType::BaseType basetype; + uint32_t index; + uint32_t plane; + uint32_t secondary_index; + }; + + SmallVector get_sorted_entry_point_args(bool add_name = true); std::string append_member_name(const std::string &qualifier, const SPIRType &type, uint32_t index); std::string ensure_valid_name(std::string name, std::string pfx); std::string to_sampler_expression(uint32_t id); @@ -1140,6 +1171,7 @@ class CompilerMSL : public CompilerGLSL VariableID tess_level_inner_var_id = 0; VariableID tess_level_outer_var_id = 0; VariableID stage_out_masked_builtin_type_id = 0; + VariableID stage_out_mesh_primitive_var_id = 0; // Handle HLSL-style 0-based vertex/instance index. enum class TriState @@ -1169,6 +1201,7 @@ class CompilerMSL : public CompilerGLSL std::string qual_pos_var_name; std::string stage_in_var_name = "in"; std::string stage_out_var_name = "out"; + std::string stage_out_mesh_primitive_var_name = "out_1"; std::string patch_stage_in_var_name = "patchIn"; std::string patch_stage_out_var_name = "patchOut"; std::string sampler_name_suffix = "Smplr";