GPU: Shader: Add support for templated struct

This does a few things:
- Add support for templated struct.
- Change parsing of template scope.
  Now all template scope `<..>` are parsed properly.
- Rework to support better match syntax.
- Avoid warning from scope guard processing. Now initialize
  the return value to zero.

Pull Request: https://projects.blender.org/blender/blender/pulls/145132
This commit is contained in:
Clément Foucault
2025-08-26 10:10:43 +02:00
committed by Clément Foucault
parent 0a2c84d48d
commit 505e4fc3ae
6 changed files with 422 additions and 61 deletions

View File

@@ -21,6 +21,8 @@
namespace blender::gpu::shader {
#define ERROR_TOK(token) (token).line_number(), (token).char_number(), (token).line_str()
/* Metadata extracted from shader source file.
* These are then converted to their GPU module equivalent. */
/* TODO(fclem): Make GPU enums standalone and directly use them instead of using separate enums
@@ -220,7 +222,9 @@ class Preprocessor {
str = swizzle_function_mutation(str, report_error);
str = enum_macro_injection(str, language == CPP, report_error);
if (language == BLENDER_GLSL) {
str = template_struct_mutation(str, report_error);
str = struct_method_mutation(str, report_error);
str = empty_struct_mutation(str, report_error);
str = method_call_mutation(str, report_error);
str = stage_function_mutation(str);
str = resource_guard_mutation(str, report_error);
@@ -350,6 +354,170 @@ class Preprocessor {
return std::regex_replace(out_str, regex, "\n");
}
std::string template_struct_mutation(const std::string &str, report_callback &report_error)
{
using namespace std;
using namespace shader::parser;
std::string out_str = str;
{
Parser parser(out_str, report_error);
parser.foreach_match("w<..>(..)", [&](const vector<Token> &tokens) {
const Scope template_args = tokens[1].scope();
template_args.foreach_match("w<..>", [&parser](const vector<Token> &tokens) {
string args_concat;
tokens[1].scope().foreach_scope(ScopeType::TemplateArg, [&](const Scope &scope) {
args_concat += '_' + scope.start().str();
});
/* This is already contained in a template. Don't output trailing underscore as double
* underscore is reserved in GLSL. */
parser.replace(tokens[1].scope(), args_concat);
});
});
parser.apply_mutations();
/* Replace full specialization by simple struct. */
parser.foreach_match("t<>sw<..>", [&](const std::vector<Token> &tokens) {
const Scope template_args = tokens[5].scope();
const Token struct_name = tokens[4];
string struct_name_str = struct_name.str() + "_";
template_args.foreach_scope(ScopeType::TemplateArg, [&](Scope arg) {
struct_name_str += arg.start().str() + "_";
});
parser.erase(template_args);
parser.erase(tokens[0], tokens[2]);
parser.replace(struct_name, struct_name_str);
});
out_str = parser.result_get();
}
{
Parser parser(out_str, report_error);
parser.foreach_scope(ScopeType::Template, [&](Scope temp) {
/* Parse template declaration. */
Token struct_start = temp.end().next();
if (struct_start != Struct) {
return;
}
Token struct_name = struct_start.next();
Scope struct_body = struct_name.next().scope();
bool error = false;
temp.foreach_match("=", [&](const std::vector<Token> &tokens) {
report_error(ERROR_TOK(tokens[0]),
"Default arguments are not supported inside template declaration");
error = true;
});
if (error) {
return;
}
string arg_pattern;
vector<string> arg_list;
temp.foreach_scope(ScopeType::TemplateArg, [&](Scope arg) {
const Token type = arg.start();
const Token name = type.next();
const string name_str = name.str();
const string type_str = type.str();
arg_list.emplace_back(name_str);
if (type_str == "typename") {
arg_pattern += ",w";
}
else if (type_str == "enum" || type_str == "bool") {
arg_pattern += ",w";
}
else if (type_str == "int" || type_str == "uint") {
arg_pattern += ",0";
}
else {
report_error(ERROR_TOK(type), "Invalid template argument type");
}
});
Token struct_end = struct_body.end();
const string fn_decl = parser.substr_range_inclusive(struct_start.str_index_start(),
struct_end.str_index_last());
/* Remove declaration. */
Token template_keyword = temp.start().prev();
parser.erase(template_keyword.str_index_start(), struct_end.line_end());
/* Replace instantiations. */
Scope parent_scope = temp.scope();
string specialization_pattern = "tsw<" + arg_pattern.substr(1) + ">";
parent_scope.foreach_match(specialization_pattern, [&](const std::vector<Token> &tokens) {
if (struct_name.str() != tokens[2].str()) {
return;
}
/* Parse template values. */
vector<pair<string, string>> arg_name_value_pairs;
for (int i = 0; i < arg_list.size(); i++) {
arg_name_value_pairs.emplace_back(arg_list[i], tokens[4 + 2 * i].str());
}
/* Specialize template content. */
Parser instance_parser(fn_decl, report_error, true);
instance_parser.foreach_match("w", [&](const std::vector<Token> &tokens) {
string token_str = tokens[0].str();
for (const auto &arg_name_value : arg_name_value_pairs) {
if (token_str == arg_name_value.first) {
instance_parser.replace(tokens[0], arg_name_value.second);
}
}
});
const string template_args = parser.substr_range_inclusive(
tokens[3], tokens[3 + arg_pattern.size()]);
size_t pos = fn_decl.find(" " + struct_name.str());
instance_parser.insert_after(pos + struct_name.str().size(), template_args);
/* Paste template content in place of instantiation. */
Token end_of_instantiation = tokens.back();
string instance = instance_parser.result_get();
parser.insert_line_number(tokens.front().str_index_start() - 1,
struct_start.line_number());
parser.replace(tokens.front().str_index_start(),
end_of_instantiation.str_index_last_no_whitespace(),
instance);
parser.insert_line_number(end_of_instantiation.line_end() + 1,
end_of_instantiation.line_number() + 1);
});
});
out_str = parser.result_get();
}
{
Parser parser(out_str, report_error);
/* This rely on our codestyle that do not put spaces between template name and the opening
* angle bracket. */
parser.foreach_match("sw<", [&](const std::vector<Token> &tokens) {
Token token = tokens[2];
parser.replace(token, "_");
token = token.next();
while (token != '>') {
if (token == ',') {
/* Also replace and skip the space after the comma. */
Token next_token = token.next_not_whitespace();
parser.replace(token, next_token.prev(), "_");
token = next_token;
}
else {
token = token.next();
}
}
/* Replace closing angle bracket. */
parser.replace(token, "_");
});
out_str = parser.result_get();
}
return out_str;
}
std::string template_definition_mutation(const std::string &str, report_callback &report_error)
{
if (str.find("template") == std::string::npos) {
@@ -383,11 +551,14 @@ class Preprocessor {
{
Parser parser(out_str, report_error);
parser.foreach_scope(ScopeType::Template, [&](Scope temp) {
parser.foreach_match("t<..>ww(..)c?{..}", [&](const vector<Token> &tokens) {
/* Parse template declaration. */
Token fn_start = temp.end().next();
Token fn_name = (fn_start == Static) ? fn_start.next().next() : fn_start.next();
Scope fn_args = fn_name.next().scope();
Token fn_start = tokens[5];
Token fn_name = tokens[6];
Scope fn_args = tokens[7].scope();
Scope temp = tokens[1].scope();
Scope fn_body = tokens[13].scope();
Token fn_end = fn_body.end();
bool error = false;
temp.foreach_match("=", [&](const std::vector<Token> &tokens) {
@@ -435,16 +606,18 @@ class Preprocessor {
all_template_args_in_function_signature = false;
}
else {
report_error(type.line_number(),
type.char_number(),
type.line_str(),
"Invalid template argument type");
report_error(ERROR_TOK(type), "Invalid template argument type");
}
});
Token after_args = fn_name.next().scope().end().next();
Scope fn_body = (after_args == Const) ? after_args.next().scope() : after_args.scope();
Token fn_end = fn_body.end();
Token fn_args_start = fn_name.next();
if (fn_args_start != '(') {
report_error(ERROR_TOK(fn_args_start),
"Expected open parenthesis after template function name");
return;
}
const string fn_decl = parser.substr_range_inclusive(fn_start.str_index_start(),
fn_end.line_end());
@@ -454,7 +627,7 @@ class Preprocessor {
/* Replace instantiations. */
Scope parent_scope = temp.scope();
string specialization_pattern = "tww<" + arg_pattern.substr(1) + ">(";
string specialization_pattern = "tww<" + arg_pattern.substr(1) + ">(..);";
parent_scope.foreach_match(specialization_pattern, [&](const std::vector<Token> &tokens) {
if (fn_name.str() != tokens[2].str()) {
return;
@@ -482,7 +655,7 @@ class Preprocessor {
instance_parser.insert_after(pos + fn_name.str().size(), template_args);
}
/* Paste template content in place of instantiation. */
Token end_of_instantiation = tokens.back().scope().end().next();
Token end_of_instantiation = tokens.back();
string instance = instance_parser.result_get();
parser.insert_line_number(tokens.front().str_index_start() - 1, fn_start.line_number());
parser.replace(tokens.front().str_index_start(),
@@ -1331,6 +1504,24 @@ class Preprocessor {
return parser.result_get();
}
/* Add padding member to empty structs.
* Empty structs are useful for templating. */
std::string empty_struct_mutation(const std::string &str, report_callback report_error)
{
using namespace std;
using namespace shader::parser;
Parser parser(str, report_error);
parser.foreach_scope(ScopeType::Global, [&](Scope scope) {
scope.foreach_match("sw{};", [&](const std::vector<Token> &tokens) {
parser.insert_after(tokens[2], "int _pad;");
});
});
return parser.result_get();
}
/* Transform `a.fn(b)` into `fn(a, b)`. */
std::string method_call_mutation(const std::string &str, report_callback report_error)
{
@@ -1447,9 +1638,10 @@ class Preprocessor {
parser.foreach_function([&](bool, Token fn_type, Token, Scope, bool, Scope fn_body) {
fn_body.foreach_match("w(w,", [&](const std::vector<Token> &tokens) {
string func_name = tokens[0].str();
if (func_name != "specialization_constant_get" && func_name != "push_constant_get" &&
func_name != "interface_get" && func_name != "attribute_get" &&
func_name != "buffer_get" && func_name != "sampler_get" && func_name != "image_get")
if (func_name != "specialization_constant_get" && func_name != "shared_variable_get" &&
func_name != "push_constant_get" && func_name != "interface_get" &&
func_name != "attribute_get" && func_name != "buffer_get" &&
func_name != "sampler_get" && func_name != "image_get")
{
return;
}
@@ -1486,10 +1678,25 @@ class Preprocessor {
string guard_start = "#if defined(CREATE_INFO_" + info + ")\n";
string guard_else;
if (fn_type.is_valid() && fn_type.str() != "void") {
string type = fn_type.str();
bool is_trivial = false;
if (type == "float" || type == "float2" || type == "float3" || type == "float4" ||
/**/
type == "int" || type == "int2" || type == "int3" || type == "int4" ||
/**/
type == "uint" || type == "uint2" || type == "uint3" || type == "uint4" ||
/**/
type == "float2x2" || type == "float2x3" || type == "float2x4" ||
/**/
type == "float3x2" || type == "float3x3" || type == "float3x4" ||
/**/
type == "float4x2" || type == "float4x3" || type == "float4x4")
{
is_trivial = true;
}
guard_else += "#else\n";
guard_else += line_start;
guard_else += " " + fn_type.str() + " result;\n";
guard_else += " return result;\n";
guard_else += " return " + type + (is_trivial ? "(0)" : "::zero()") + ";\n";
}
string guard_end = "#endif\n";

View File

@@ -634,7 +634,7 @@ struct Scope {
int str_start = data->token_offsets[index_range.start].start;
int str_end = data->token_offsets[index_range.last()].last();
return {std::string_view(data->token_types).substr(index_range.start, index_range.size),
std::string_view(data->str).substr(str_start, str_end - str_start),
std::string_view(data->str).substr(str_start, str_end - str_start + 1),
data,
index};
}
@@ -691,22 +691,64 @@ struct Scope {
void foreach_match(const std::string &pattern,
std::function<void(const std::vector<Token>)> callback) const
{
const std::string scope_tokens = data->token_types.substr(range().start, range().size);
assert(!pattern.empty());
const std::string_view scope_tokens =
std::string_view(data->token_types).substr(range().start, range().size);
auto count_match = [](const std::string_view &s, const std::string_view &pattern) {
size_t pos = 0, occurrences = 0;
while ((pos = s.find(pattern, pos)) != std::string::npos) {
occurrences += 1;
pos += pattern.length();
}
return occurrences;
};
const int control_token_count = count_match(pattern, "?") * 2 + count_match(pattern, "..") * 2;
if (range().size < pattern.size() - control_token_count) {
return;
}
const size_t searchable_range = scope_tokens.size() -
(pattern.size() - 1 - control_token_count);
std::vector<Token> match;
match.resize(pattern.size());
size_t pos = 0;
while ((pos = scope_tokens.find(pattern, pos)) != std::string::npos) {
match[0] = Token::from_position(data, range().start + pos);
/* Do not match preprocessor directive by default. */
if (match[0].scope().type() != ScopeType::Preprocessor) {
for (int i = 1; i < pattern.size(); i++) {
match[i] = Token::from_position(data, range().start + pos + i);
for (size_t pos = 0; pos < searchable_range; pos++) {
size_t cursor = range().start + pos;
for (int i = 0; i < pattern.size(); i++) {
bool is_last_token = i == pattern.size() - 1;
TokenType token_type = TokenType(data->token_types[cursor]);
TokenType curr_search_token = TokenType(pattern[i]);
TokenType next_search_token = TokenType(is_last_token ? '\0' : pattern[i + 1]);
/* Scope skipping. */
if (!is_last_token && curr_search_token == '.' && next_search_token == '.') {
cursor = match[i - 1].scope().end().index;
i++;
continue;
}
/* Regular token. */
if (curr_search_token == token_type) {
match[i] = Token::from_position(data, cursor++);
if (is_last_token) {
callback(match);
}
}
else if (!is_last_token && curr_search_token != '?' && next_search_token == '?') {
/* This was and optional token. Continue scanning. */
match[i] = Token::invalid();
i++;
}
else {
/* Token mismatch. Test next position. */
break;
}
callback(match);
}
pos += 1;
}
}
@@ -764,7 +806,7 @@ inline void ParserData::parse_scopes(report_callback &report_error)
enter_scope(ScopeType::Global, 0);
bool in_template = false;
int in_template = 0;
int tok_id = -1;
for (char &c : token_types) {
@@ -826,16 +868,19 @@ inline void ParserData::parse_scopes(report_callback &report_error)
enter_scope(ScopeType::Subscript, tok_id);
break;
case AngleOpen:
if ((tok_id >= 1 && token_types[tok_id - 1] == Template) ||
/* Catch case of specialized declaration. */
ScopeType(scope_types.back()) == ScopeType::Template)
{
enter_scope(ScopeType::Template, tok_id);
in_template = true;
if (tok_id >= 1) {
char prev_char = str[token_offsets[tok_id - 1].last()];
/* Rely on the fact that template are formatted without spaces but comparison isn't. */
if ((prev_char != ' ' && prev_char != '\n' && prev_char != '<') ||
token_types[tok_id - 1] == Template)
{
enter_scope(ScopeType::Template, tok_id);
in_template++;
}
}
break;
case AngleClose:
if (in_template && scopes.top().type == ScopeType::Assignment) {
if (in_template > 0 && scopes.top().type == ScopeType::Assignment) {
exit_scope(tok_id - 1);
}
if (scopes.top().type == ScopeType::TemplateArg) {
@@ -843,6 +888,7 @@ inline void ParserData::parse_scopes(report_callback &report_error)
}
if (scopes.top().type == ScopeType::Template) {
exit_scope(tok_id);
in_template--;
}
break;
case BracketClose:
@@ -881,6 +927,10 @@ inline void ParserData::parse_scopes(report_callback &report_error)
}
}
if (scopes.top().type == ScopeType::Preprocessor) {
exit_scope(tok_id - 1);
}
if (scopes.top().type != ScopeType::Global) {
ScopeItem scope_item = scopes.top();
Token token = Token::from_position(this, scope_ranges[scope_item.index].start);
@@ -971,18 +1021,13 @@ struct Parser {
std::function<void(
bool is_static, Token type, Token name, Scope args, bool is_const, Scope body)> callback)
{
foreach_scope(ScopeType::FunctionArgs, [&](const Scope args) {
const bool is_const = args.end().next() == Const;
Token next = (is_const ? args.end().next() : args.end()).next();
if (next != '{') {
/* Function Prototype. */
return;
}
const bool is_static = args.start().prev().prev().prev() == Static;
Token type = args.start().prev().prev();
Token name = args.start().prev();
Scope body = next.scope();
callback(is_static, type, name, args, is_const, body);
foreach_match("m?ww(..)c?{..}", [&](const std::vector<Token> matches) {
callback(matches[0] == Static,
matches[2],
matches[3],
matches[4].scope(),
matches[8] == Const,
matches[10].scope());
});
}
@@ -1032,6 +1077,11 @@ struct Parser {
{
replace(tok.str_index_start(), tok.str_index_last(), replacement);
}
/* Replace Scope by string. */
void replace(Scope scope, const std::string &replacement)
{
replace(scope.start(), scope.end(), replacement);
}
/* Replace the content from `from` to `to` (inclusive) by whitespaces without changing
* line count and keep the remaining indentation spaces. */

View File

@@ -150,6 +150,8 @@ void ShaderCreateInfo::finalize(const bool recursive)
specialization_constants_.extend_non_duplicates(info.specialization_constants_);
compilation_constants_.extend_non_duplicates(info.compilation_constants_);
shared_variables_.extend(info.shared_variables_);
validate_vertex_attributes(&info);
/* Insert with duplicate check. */
@@ -360,6 +362,16 @@ std::string ShaderCreateInfo::check_error() const
}
}
/* Validate shared variables. */
for (int i = 0; i < shared_variables_.size(); i++) {
for (int j = i + 1; j < shared_variables_.size(); j++) {
if (shared_variables_[i].name == shared_variables_[j].name) {
error += this->name_ + " contains two specialization constants with the name: " +
std::string(shared_variables_[i].name);
}
}
}
return error;
}

View File

@@ -922,9 +922,11 @@ const int gpu_ViewportIndex = 0;
} // namespace gl_FragmentShader
/* Outside of namespace to be used in create infos. */
constexpr uint3 gl_WorkGroupSize = uint3(16, 16, 16);
namespace gl_ComputeShader {
constexpr uint3 gl_WorkGroupSize = uint3(16, 16, 16);
extern const uint3 gl_NumWorkGroups;
extern const uint3 gl_WorkGroupID;
extern const uint3 gl_LocalInvocationID;
@@ -1094,13 +1096,14 @@ void groupMemoryBarrier() {}
#endif
/* Resource accessor. */
#define specialization_constant_get(create_info, _res) _res
#define push_constant_get(create_info, _res) _res
#define interface_get(create_info, _res) _res
#define attribute_get(create_info, _res) _res
#define buffer_get(create_info, _res) _res
#define sampler_get(create_info, _res) _res
#define image_get(create_info, _res) _res
#define specialization_constant_get(create_info, _res) create_info::_res
#define shared_variable_get(create_info, _res) create_info::_res
#define push_constant_get(create_info, _res) create_info::_res
#define interface_get(create_info, _res) create_info::_res
#define attribute_get(create_info, _res) create_info::_res
#define buffer_get(create_info, _res) create_info::_res
#define sampler_get(create_info, _res) create_info::_res
#define image_get(create_info, _res) create_info::_res
#include "GPU_shader_shared_utils.hh"

View File

@@ -1241,6 +1241,7 @@ float4x4 __mat4x4(float3x3 a) { return to_float4x4(a); }
/* Resource accessor. */
#define specialization_constant_get(create_info, _res) _res
#define shared_variable_get(create_info, _res) _res
#define push_constant_get(create_info, _res) _res
#define interface_get(create_info, _res) _res
#define attribute_get(create_info, _res) _res

View File

@@ -303,6 +303,62 @@ template void func(float a);
}
GPU_TEST(preprocess_template);
static void test_preprocess_template_struct()
{
using namespace shader;
using namespace std;
{
string input = R"(
template<typename T>
struct A { T a; };
template struct A<float>;
)";
string expect = R"(
#line 3
struct A_float_{ float a; };
#line 4
#line 5
)";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(output, expect);
EXPECT_EQ(error, "");
}
{
string input = R"(
template<> struct A<float>{
float a;
};
)";
string expect = R"(
struct A_float_{
float a;
};
#line 5
)";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(output, expect);
EXPECT_EQ(error, "");
}
{
string input = R"(
void func(A<float> a) {}
)";
string expect = R"(
void func(A_float_ a) {}
)";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(output, expect);
EXPECT_EQ(error, "");
}
}
GPU_TEST(preprocess_template_struct);
static void test_preprocess_reference()
{
using namespace shader;
@@ -511,7 +567,7 @@ int func2(int a)
)";
string expect = R"(
struct A_S {};
struct A_S {int _pad;};
#line 4
int A_func(int a)
{
@@ -649,7 +705,7 @@ void test() {
string expect = R"(
void A_B_func() {}
struct A_B_S {};
struct A_B_S {int _pad;};
#line 5
@@ -892,8 +948,7 @@ uint my_func() {
return i;
#else
#line 3
uint result;
return result;
return uint(0);
#endif
#line 6
}
@@ -968,6 +1023,39 @@ uint my_func() {
}
GPU_TEST(preprocess_resource_guard);
static void test_preprocess_empty_struct()
{
using namespace shader;
using namespace std;
{
string input = R"(
class S {};
struct T {};
struct U {
static void fn() {}
};
)";
string expect = R"(
struct S {int _pad;};
#line 3
struct T {int _pad;};
#line 4
struct U {
int _pad;};
#line 5
static void U_fn() {}
#line 7
)";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(output, expect);
EXPECT_EQ(error, "");
}
}
GPU_TEST(preprocess_empty_struct);
static void test_preprocess_struct_methods()
{
using namespace shader;