From 505e4fc3ae5d96d6443bfe59e3f9a4741f48702c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Foucault?= Date: Tue, 26 Aug 2025 10:10:43 +0200 Subject: [PATCH] GPU: Shader: Add support for templated struct This does a few things: - Add support for templated struct. - Change parsing of template scope. Now all template scope `<..>` are parsed properly. - Rework to support better match syntax. - Avoid warning from scope guard processing. Now initialize the return value to zero. Pull Request: https://projects.blender.org/blender/blender/pulls/145132 --- .../gpu/glsl_preprocess/glsl_preprocess.hh | 243 ++++++++++++++++-- .../gpu/glsl_preprocess/shader_parser.hh | 112 +++++--- .../gpu/intern/gpu_shader_create_info.cc | 12 + .../blender/gpu/shaders/gpu_glsl_cpp_stubs.hh | 19 +- .../gpu/shaders/metal/mtl_shader_defines.msl | 1 + .../gpu/tests/shader_preprocess_test.cc | 96 ++++++- 6 files changed, 422 insertions(+), 61 deletions(-) diff --git a/source/blender/gpu/glsl_preprocess/glsl_preprocess.hh b/source/blender/gpu/glsl_preprocess/glsl_preprocess.hh index ba706ffb6ce..634a51839ba 100644 --- a/source/blender/gpu/glsl_preprocess/glsl_preprocess.hh +++ b/source/blender/gpu/glsl_preprocess/glsl_preprocess.hh @@ -21,6 +21,8 @@ namespace blender::gpu::shader { +#define ERROR_TOK(token) (token).line_number(), (token).char_number(), (token).line_str() + /* Metadata extracted from shader source file. * These are then converted to their GPU module equivalent. 
*/ /* TODO(fclem): Make GPU enums standalone and directly use them instead of using separate enums @@ -220,7 +222,9 @@ class Preprocessor { str = swizzle_function_mutation(str, report_error); str = enum_macro_injection(str, language == CPP, report_error); if (language == BLENDER_GLSL) { + str = template_struct_mutation(str, report_error); str = struct_method_mutation(str, report_error); + str = empty_struct_mutation(str, report_error); str = method_call_mutation(str, report_error); str = stage_function_mutation(str); str = resource_guard_mutation(str, report_error); @@ -350,6 +354,170 @@ class Preprocessor { return std::regex_replace(out_str, regex, "\n"); } + std::string template_struct_mutation(const std::string &str, report_callback &report_error) + { + using namespace std; + using namespace shader::parser; + + std::string out_str = str; + + { + Parser parser(out_str, report_error); + + parser.foreach_match("w<..>(..)", [&](const vector &tokens) { + const Scope template_args = tokens[1].scope(); + template_args.foreach_match("w<..>", [&parser](const vector &tokens) { + string args_concat; + tokens[1].scope().foreach_scope(ScopeType::TemplateArg, [&](const Scope &scope) { + args_concat += '_' + scope.start().str(); + }); + /* This is already contained in a template. Don't output trailing underscore as double + * underscore is reserved in GLSL. */ + parser.replace(tokens[1].scope(), args_concat); + }); + }); + + parser.apply_mutations(); + + /* Replace full specialization by simple struct. 
*/ + parser.foreach_match("t<>sw<..>", [&](const std::vector &tokens) { + const Scope template_args = tokens[5].scope(); + const Token struct_name = tokens[4]; + string struct_name_str = struct_name.str() + "_"; + template_args.foreach_scope(ScopeType::TemplateArg, [&](Scope arg) { + struct_name_str += arg.start().str() + "_"; + }); + parser.erase(template_args); + parser.erase(tokens[0], tokens[2]); + parser.replace(struct_name, struct_name_str); + }); + + out_str = parser.result_get(); + } + { + Parser parser(out_str, report_error); + + parser.foreach_scope(ScopeType::Template, [&](Scope temp) { + /* Parse template declaration. */ + Token struct_start = temp.end().next(); + if (struct_start != Struct) { + return; + } + Token struct_name = struct_start.next(); + Scope struct_body = struct_name.next().scope(); + + bool error = false; + temp.foreach_match("=", [&](const std::vector &tokens) { + report_error(ERROR_TOK(tokens[0]), + "Default arguments are not supported inside template declaration"); + error = true; + }); + if (error) { + return; + } + + string arg_pattern; + vector arg_list; + temp.foreach_scope(ScopeType::TemplateArg, [&](Scope arg) { + const Token type = arg.start(); + const Token name = type.next(); + const string name_str = name.str(); + const string type_str = type.str(); + + arg_list.emplace_back(name_str); + + if (type_str == "typename") { + arg_pattern += ",w"; + } + else if (type_str == "enum" || type_str == "bool") { + arg_pattern += ",w"; + } + else if (type_str == "int" || type_str == "uint") { + arg_pattern += ",0"; + } + else { + report_error(ERROR_TOK(type), "Invalid template argument type"); + } + }); + + Token struct_end = struct_body.end(); + const string fn_decl = parser.substr_range_inclusive(struct_start.str_index_start(), + struct_end.str_index_last()); + + /* Remove declaration. 
*/ + Token template_keyword = temp.start().prev(); + parser.erase(template_keyword.str_index_start(), struct_end.line_end()); + + /* Replace instantiations. */ + Scope parent_scope = temp.scope(); + string specialization_pattern = "tsw<" + arg_pattern.substr(1) + ">"; + parent_scope.foreach_match(specialization_pattern, [&](const std::vector &tokens) { + if (struct_name.str() != tokens[2].str()) { + return; + } + /* Parse template values. */ + vector> arg_name_value_pairs; + for (int i = 0; i < arg_list.size(); i++) { + arg_name_value_pairs.emplace_back(arg_list[i], tokens[4 + 2 * i].str()); + } + /* Specialize template content. */ + Parser instance_parser(fn_decl, report_error, true); + instance_parser.foreach_match("w", [&](const std::vector &tokens) { + string token_str = tokens[0].str(); + for (const auto &arg_name_value : arg_name_value_pairs) { + if (token_str == arg_name_value.first) { + instance_parser.replace(tokens[0], arg_name_value.second); + } + } + }); + + const string template_args = parser.substr_range_inclusive( + tokens[3], tokens[3 + arg_pattern.size()]); + size_t pos = fn_decl.find(" " + struct_name.str()); + instance_parser.insert_after(pos + struct_name.str().size(), template_args); + /* Paste template content in place of instantiation. */ + Token end_of_instantiation = tokens.back(); + string instance = instance_parser.result_get(); + parser.insert_line_number(tokens.front().str_index_start() - 1, + struct_start.line_number()); + parser.replace(tokens.front().str_index_start(), + end_of_instantiation.str_index_last_no_whitespace(), + instance); + parser.insert_line_number(end_of_instantiation.line_end() + 1, + end_of_instantiation.line_number() + 1); + }); + }); + + out_str = parser.result_get(); + } + { + Parser parser(out_str, report_error); + + /* This relies on our code style, which does not put spaces between the template name and + * the opening angle bracket. 
*/ + parser.foreach_match("sw<", [&](const std::vector &tokens) { + Token token = tokens[2]; + parser.replace(token, "_"); + token = token.next(); + while (token != '>') { + if (token == ',') { + /* Also replace and skip the space after the comma. */ + Token next_token = token.next_not_whitespace(); + parser.replace(token, next_token.prev(), "_"); + token = next_token; + } + else { + token = token.next(); + } + } + /* Replace closing angle bracket. */ + parser.replace(token, "_"); + }); + out_str = parser.result_get(); + } + return out_str; + } + std::string template_definition_mutation(const std::string &str, report_callback &report_error) { if (str.find("template") == std::string::npos) { @@ -383,11 +551,14 @@ class Preprocessor { { Parser parser(out_str, report_error); - parser.foreach_scope(ScopeType::Template, [&](Scope temp) { + parser.foreach_match("t<..>ww(..)c?{..}", [&](const vector &tokens) { /* Parse template declaration. */ - Token fn_start = temp.end().next(); - Token fn_name = (fn_start == Static) ? fn_start.next().next() : fn_start.next(); - Scope fn_args = fn_name.next().scope(); + Token fn_start = tokens[5]; + Token fn_name = tokens[6]; + Scope fn_args = tokens[7].scope(); + Scope temp = tokens[1].scope(); + Scope fn_body = tokens[13].scope(); + Token fn_end = fn_body.end(); bool error = false; temp.foreach_match("=", [&](const std::vector &tokens) { @@ -435,16 +606,18 @@ class Preprocessor { all_template_args_in_function_signature = false; } else { - report_error(type.line_number(), - type.char_number(), - type.line_str(), - "Invalid template argument type"); + report_error(ERROR_TOK(type), "Invalid template argument type"); } }); - Token after_args = fn_name.next().scope().end().next(); - Scope fn_body = (after_args == Const) ? 
after_args.next().scope() : after_args.scope(); - Token fn_end = fn_body.end(); + Token fn_args_start = fn_name.next(); + + if (fn_args_start != '(') { + report_error(ERROR_TOK(fn_args_start), + "Expected open parenthesis after template function name"); + return; + } + const string fn_decl = parser.substr_range_inclusive(fn_start.str_index_start(), fn_end.line_end()); @@ -454,7 +627,7 @@ class Preprocessor { /* Replace instantiations. */ Scope parent_scope = temp.scope(); - string specialization_pattern = "tww<" + arg_pattern.substr(1) + ">("; + string specialization_pattern = "tww<" + arg_pattern.substr(1) + ">(..);"; parent_scope.foreach_match(specialization_pattern, [&](const std::vector &tokens) { if (fn_name.str() != tokens[2].str()) { return; @@ -482,7 +655,7 @@ class Preprocessor { instance_parser.insert_after(pos + fn_name.str().size(), template_args); } /* Paste template content in place of instantiation. */ - Token end_of_instantiation = tokens.back().scope().end().next(); + Token end_of_instantiation = tokens.back(); string instance = instance_parser.result_get(); parser.insert_line_number(tokens.front().str_index_start() - 1, fn_start.line_number()); parser.replace(tokens.front().str_index_start(), @@ -1331,6 +1504,24 @@ class Preprocessor { return parser.result_get(); } + /* Add padding member to empty structs. + * Empty structs are useful for templating. */ + std::string empty_struct_mutation(const std::string &str, report_callback report_error) + { + using namespace std; + using namespace shader::parser; + + Parser parser(str, report_error); + + parser.foreach_scope(ScopeType::Global, [&](Scope scope) { + scope.foreach_match("sw{};", [&](const std::vector &tokens) { + parser.insert_after(tokens[2], "int _pad;"); + }); + }); + + return parser.result_get(); + } + /* Transform `a.fn(b)` into `fn(a, b)`. 
*/ std::string method_call_mutation(const std::string &str, report_callback report_error) { @@ -1447,9 +1638,10 @@ class Preprocessor { parser.foreach_function([&](bool, Token fn_type, Token, Scope, bool, Scope fn_body) { fn_body.foreach_match("w(w,", [&](const std::vector &tokens) { string func_name = tokens[0].str(); - if (func_name != "specialization_constant_get" && func_name != "push_constant_get" && - func_name != "interface_get" && func_name != "attribute_get" && - func_name != "buffer_get" && func_name != "sampler_get" && func_name != "image_get") + if (func_name != "specialization_constant_get" && func_name != "shared_variable_get" && + func_name != "push_constant_get" && func_name != "interface_get" && + func_name != "attribute_get" && func_name != "buffer_get" && + func_name != "sampler_get" && func_name != "image_get") { return; } @@ -1486,10 +1678,25 @@ class Preprocessor { string guard_start = "#if defined(CREATE_INFO_" + info + ")\n"; string guard_else; if (fn_type.is_valid() && fn_type.str() != "void") { + string type = fn_type.str(); + bool is_trivial = false; + if (type == "float" || type == "float2" || type == "float3" || type == "float4" || + /**/ + type == "int" || type == "int2" || type == "int3" || type == "int4" || + /**/ + type == "uint" || type == "uint2" || type == "uint3" || type == "uint4" || + /**/ + type == "float2x2" || type == "float2x3" || type == "float2x4" || + /**/ + type == "float3x2" || type == "float3x3" || type == "float3x4" || + /**/ + type == "float4x2" || type == "float4x3" || type == "float4x4") + { + is_trivial = true; + } guard_else += "#else\n"; guard_else += line_start; - guard_else += " " + fn_type.str() + " result;\n"; - guard_else += " return result;\n"; + guard_else += " return " + type + (is_trivial ? 
"(0)" : "::zero()") + ";\n"; } string guard_end = "#endif\n"; diff --git a/source/blender/gpu/glsl_preprocess/shader_parser.hh b/source/blender/gpu/glsl_preprocess/shader_parser.hh index 96bfb2636eb..19ee2c3c542 100644 --- a/source/blender/gpu/glsl_preprocess/shader_parser.hh +++ b/source/blender/gpu/glsl_preprocess/shader_parser.hh @@ -634,7 +634,7 @@ struct Scope { int str_start = data->token_offsets[index_range.start].start; int str_end = data->token_offsets[index_range.last()].last(); return {std::string_view(data->token_types).substr(index_range.start, index_range.size), - std::string_view(data->str).substr(str_start, str_end - str_start), + std::string_view(data->str).substr(str_start, str_end - str_start + 1), data, index}; } @@ -691,22 +691,64 @@ struct Scope { void foreach_match(const std::string &pattern, std::function)> callback) const { - const std::string scope_tokens = data->token_types.substr(range().start, range().size); + assert(!pattern.empty()); + const std::string_view scope_tokens = + std::string_view(data->token_types).substr(range().start, range().size); + + auto count_match = [](const std::string_view &s, const std::string_view &pattern) { + size_t pos = 0, occurrences = 0; + while ((pos = s.find(pattern, pos)) != std::string::npos) { + occurrences += 1; + pos += pattern.length(); + } + return occurrences; + }; + const int control_token_count = count_match(pattern, "?") * 2 + count_match(pattern, "..") * 2; + + if (range().size < pattern.size() - control_token_count) { + return; + } + + const size_t searchable_range = scope_tokens.size() - + (pattern.size() - 1 - control_token_count); std::vector match; match.resize(pattern.size()); - size_t pos = 0; - while ((pos = scope_tokens.find(pattern, pos)) != std::string::npos) { - match[0] = Token::from_position(data, range().start + pos); - /* Do not match preprocessor directive by default. 
*/ - if (match[0].scope().type() != ScopeType::Preprocessor) { - for (int i = 1; i < pattern.size(); i++) { - match[i] = Token::from_position(data, range().start + pos + i); + for (size_t pos = 0; pos < searchable_range; pos++) { + size_t cursor = range().start + pos; + + for (int i = 0; i < pattern.size(); i++) { + bool is_last_token = i == pattern.size() - 1; + TokenType token_type = TokenType(data->token_types[cursor]); + TokenType curr_search_token = TokenType(pattern[i]); + TokenType next_search_token = TokenType(is_last_token ? '\0' : pattern[i + 1]); + + /* Scope skipping. */ + if (!is_last_token && curr_search_token == '.' && next_search_token == '.') { + cursor = match[i - 1].scope().end().index; + i++; + continue; + } + + /* Regular token. */ + if (curr_search_token == token_type) { + match[i] = Token::from_position(data, cursor++); + + if (is_last_token) { + callback(match); + } + } + else if (!is_last_token && curr_search_token != '?' && next_search_token == '?') { + /* This was an optional token. Continue scanning. */ + match[i] = Token::invalid(); + i++; + } + else { + /* Token mismatch. Test next position. */ + break; } - callback(match); } - pos += 1; } } @@ -764,7 +806,7 @@ inline void ParserData::parse_scopes(report_callback &report_error) enter_scope(ScopeType::Global, 0); - bool in_template = false; + int in_template = 0; int tok_id = -1; for (char &c : token_types) { @@ -826,16 +868,19 @@ inline void ParserData::parse_scopes(report_callback &report_error) enter_scope(ScopeType::Subscript, tok_id); break; case AngleOpen: - if ((tok_id >= 1 && token_types[tok_id - 1] == Template) || - /* Catch case of specialized declaration. */ - ScopeType(scope_types.back()) == ScopeType::Template) - { - enter_scope(ScopeType::Template, tok_id); - in_template = true; + if (tok_id >= 1) { + char prev_char = str[token_offsets[tok_id - 1].last()]; + /* Rely on the fact that templates are formatted without spaces but comparisons aren't. 
*/ + if ((prev_char != ' ' && prev_char != '\n' && prev_char != '<') || + token_types[tok_id - 1] == Template) + { + enter_scope(ScopeType::Template, tok_id); + in_template++; + } } break; case AngleClose: - if (in_template && scopes.top().type == ScopeType::Assignment) { + if (in_template > 0 && scopes.top().type == ScopeType::Assignment) { exit_scope(tok_id - 1); } if (scopes.top().type == ScopeType::TemplateArg) { @@ -843,6 +888,7 @@ inline void ParserData::parse_scopes(report_callback &report_error) } if (scopes.top().type == ScopeType::Template) { exit_scope(tok_id); + in_template--; } break; case BracketClose: @@ -881,6 +927,10 @@ inline void ParserData::parse_scopes(report_callback &report_error) } } + if (scopes.top().type == ScopeType::Preprocessor) { + exit_scope(tok_id - 1); + } + if (scopes.top().type != ScopeType::Global) { ScopeItem scope_item = scopes.top(); Token token = Token::from_position(this, scope_ranges[scope_item.index].start); @@ -971,18 +1021,13 @@ struct Parser { std::function callback) { - foreach_scope(ScopeType::FunctionArgs, [&](const Scope args) { - const bool is_const = args.end().next() == Const; - Token next = (is_const ? args.end().next() : args.end()).next(); - if (next != '{') { - /* Function Prototype. */ - return; - } - const bool is_static = args.start().prev().prev().prev() == Static; - Token type = args.start().prev().prev(); - Token name = args.start().prev(); - Scope body = next.scope(); - callback(is_static, type, name, args, is_const, body); + foreach_match("m?ww(..)c?{..}", [&](const std::vector matches) { + callback(matches[0] == Static, + matches[2], + matches[3], + matches[4].scope(), + matches[8] == Const, + matches[10].scope()); }); } @@ -1032,6 +1077,11 @@ struct Parser { { replace(tok.str_index_start(), tok.str_index_last(), replacement); } + /* Replace Scope by string. 
*/ + void replace(Scope scope, const std::string &replacement) + { + replace(scope.start(), scope.end(), replacement); + } /* Replace the content from `from` to `to` (inclusive) by whitespaces without changing * line count and keep the remaining indentation spaces. */ diff --git a/source/blender/gpu/intern/gpu_shader_create_info.cc b/source/blender/gpu/intern/gpu_shader_create_info.cc index 256e5f14942..8fa55529289 100644 --- a/source/blender/gpu/intern/gpu_shader_create_info.cc +++ b/source/blender/gpu/intern/gpu_shader_create_info.cc @@ -150,6 +150,8 @@ void ShaderCreateInfo::finalize(const bool recursive) specialization_constants_.extend_non_duplicates(info.specialization_constants_); compilation_constants_.extend_non_duplicates(info.compilation_constants_); + shared_variables_.extend(info.shared_variables_); + validate_vertex_attributes(&info); /* Insert with duplicate check. */ @@ -360,6 +362,16 @@ std::string ShaderCreateInfo::check_error() const } } + /* Validate shared variables: report every pair of entries sharing the same name. */ + for (int i = 0; i < shared_variables_.size(); i++) { + for (int j = i + 1; j < shared_variables_.size(); j++) { + if (shared_variables_[i].name == shared_variables_[j].name) { + error += this->name_ + " contains two shared variables with the name: " + + std::string(shared_variables_[i].name); + } + } + } + return error; } diff --git a/source/blender/gpu/shaders/gpu_glsl_cpp_stubs.hh b/source/blender/gpu/shaders/gpu_glsl_cpp_stubs.hh index 5a7fa51ac7a..b0c09c277cf 100644 --- a/source/blender/gpu/shaders/gpu_glsl_cpp_stubs.hh +++ b/source/blender/gpu/shaders/gpu_glsl_cpp_stubs.hh @@ -922,9 +922,11 @@ const int gpu_ViewportIndex = 0; } // namespace gl_FragmentShader +/* Outside of namespace to be used in create infos. 
*/ +constexpr uint3 gl_WorkGroupSize = uint3(16, 16, 16); + namespace gl_ComputeShader { -constexpr uint3 gl_WorkGroupSize = uint3(16, 16, 16); extern const uint3 gl_NumWorkGroups; extern const uint3 gl_WorkGroupID; extern const uint3 gl_LocalInvocationID; @@ -1094,13 +1096,14 @@ void groupMemoryBarrier() {} #endif /* Resource accessor. */ -#define specialization_constant_get(create_info, _res) _res -#define push_constant_get(create_info, _res) _res -#define interface_get(create_info, _res) _res -#define attribute_get(create_info, _res) _res -#define buffer_get(create_info, _res) _res -#define sampler_get(create_info, _res) _res -#define image_get(create_info, _res) _res +#define specialization_constant_get(create_info, _res) create_info::_res +#define shared_variable_get(create_info, _res) create_info::_res +#define push_constant_get(create_info, _res) create_info::_res +#define interface_get(create_info, _res) create_info::_res +#define attribute_get(create_info, _res) create_info::_res +#define buffer_get(create_info, _res) create_info::_res +#define sampler_get(create_info, _res) create_info::_res +#define image_get(create_info, _res) create_info::_res #include "GPU_shader_shared_utils.hh" diff --git a/source/blender/gpu/shaders/metal/mtl_shader_defines.msl b/source/blender/gpu/shaders/metal/mtl_shader_defines.msl index ae919c82f64..eee569db141 100644 --- a/source/blender/gpu/shaders/metal/mtl_shader_defines.msl +++ b/source/blender/gpu/shaders/metal/mtl_shader_defines.msl @@ -1241,6 +1241,7 @@ float4x4 __mat4x4(float3x3 a) { return to_float4x4(a); } /* Resource accessor. 
*/ #define specialization_constant_get(create_info, _res) _res +#define shared_variable_get(create_info, _res) _res #define push_constant_get(create_info, _res) _res #define interface_get(create_info, _res) _res #define attribute_get(create_info, _res) _res diff --git a/source/blender/gpu/tests/shader_preprocess_test.cc b/source/blender/gpu/tests/shader_preprocess_test.cc index da9dcb4e96d..cde0b1c58a3 100644 --- a/source/blender/gpu/tests/shader_preprocess_test.cc +++ b/source/blender/gpu/tests/shader_preprocess_test.cc @@ -303,6 +303,62 @@ template void func(float a); } GPU_TEST(preprocess_template); +static void test_preprocess_template_struct() +{ + using namespace shader; + using namespace std; + + { + string input = R"( +template +struct A { T a; }; +template struct A; +)"; + string expect = R"( + + +#line 3 +struct A_float_{ float a; }; +#line 4 +#line 5 +)"; + string error; + string output = process_test_string(input, error); + EXPECT_EQ(output, expect); + EXPECT_EQ(error, ""); + } + { + string input = R"( +template<> struct A{ + float a; +}; +)"; + string expect = R"( + struct A_float_{ + float a; +}; +#line 5 +)"; + string error; + string output = process_test_string(input, error); + EXPECT_EQ(output, expect); + EXPECT_EQ(error, ""); + } + { + string input = R"( +void func(A a) {} +)"; + string expect = R"( +void func(A_float_ a) {} +)"; + string error; + string output = process_test_string(input, error); + EXPECT_EQ(output, expect); + EXPECT_EQ(error, ""); + } +} +GPU_TEST(preprocess_template_struct); + static void test_preprocess_reference() { using namespace shader; @@ -511,7 +567,7 @@ int func2(int a) )"; string expect = R"( -struct A_S {}; +struct A_S {int _pad;}; #line 4 int A_func(int a) { @@ -649,7 +705,7 @@ void test() { string expect = R"( void A_B_func() {} -struct A_B_S {}; +struct A_B_S {int _pad;}; #line 5 @@ -892,8 +948,7 @@ uint my_func() { return i; #else #line 3 - uint result; - return result; + return uint(0); #endif #line 6 } @@ -968,6 
+1023,39 @@ uint my_func() { } GPU_TEST(preprocess_resource_guard); +static void test_preprocess_empty_struct() +{ + using namespace shader; + using namespace std; + + { + string input = R"( +class S {}; +struct T {}; +struct U { + static void fn() {} +}; +)"; + string expect = R"( +struct S {int _pad;}; +#line 3 +struct T {int _pad;}; +#line 4 +struct U { + +int _pad;}; +#line 5 + static void U_fn() {} +#line 7 +)"; + string error; + string output = process_test_string(input, error); + EXPECT_EQ(output, expect); + EXPECT_EQ(error, ""); + } +} +GPU_TEST(preprocess_empty_struct); + static void test_preprocess_struct_methods() { using namespace shader;