GPU: Shader: Add support for full template specialization

As the title says.
The existing implementation did not support this.

Note that this doesn't support partial specialization.

See #137441 for the original implementation.

This is needed for #143582.

Pull Request: https://projects.blender.org/blender/blender/pulls/144212
This commit is contained in:
Clément Foucault
2025-08-11 14:26:58 +02:00
parent ca57cf0750
commit bbd2dcb02d
3 changed files with 85 additions and 2 deletions

View File

@@ -412,6 +412,28 @@ class Preprocessor {
out_str.replace(start, body_end - start, macro_body);
}
}
{
using namespace std;
using namespace shader::parser;
Parser parser(out_str);
parser.foreach_scope(ScopeType::Global, [&](Scope scope) {
/* Replace full specialization by simple functions. */
scope.foreach_match("t<>ww<", [&](const std::vector<Token> &tokens) {
const Scope template_args = tokens[5].scope();
const Token fn_name = tokens[4];
string fn_name_str = fn_name.str_no_whitespace() + "_";
template_args.foreach_scope(ScopeType::TemplateArg, [&](Scope arg) {
fn_name_str += arg.start().str_no_whitespace() + "_";
});
parser.erase(template_args);
parser.erase(tokens[0], tokens[2]);
parser.replace(fn_name, fn_name_str);
});
});
out_str = parser.result_get();
}
{
/* Replace explicit instantiation by macro call. */
/* Only `template ret_t fn<T>(args);` syntax is supported. */

View File

@@ -111,11 +111,14 @@ enum class ScopeType : char {
Function = 'F',
FunctionArgs = 'f',
Template = 'T',
TemplateArg = 't',
Subscript = 'A',
Preprocessor = 'P',
Assignment = 'a',
/* Added scope inside function body. */
Local = 'L',
/* Added scope inside FunctionArgs. */
FunctionArg = 'g',
};
/* Poor man's IndexRange. */
@@ -458,7 +461,10 @@ struct ParserData {
enter_scope(ScopeType::Subscript, tok_id);
break;
case AngleOpen:
if (token_types[tok_id - 1] == Template) {
if (token_types[tok_id - 1] == Template ||
/* Catch case of specialized declaration. */
ScopeType(scope_types.back()) == ScopeType::Template)
{
enter_scope(ScopeType::Template, tok_id);
in_template = true;
}
@@ -467,6 +473,9 @@ struct ParserData {
if (in_template && scopes.top().type == ScopeType::Assignment) {
exit_scope(tok_id - 1);
}
if (scopes.top().type == ScopeType::TemplateArg) {
exit_scope(tok_id - 1);
}
if (scopes.top().type == ScopeType::Template) {
exit_scope(tok_id);
}
@@ -476,6 +485,9 @@ struct ParserData {
if (scopes.top().type == ScopeType::Assignment) {
exit_scope(tok_id - 1);
}
if (scopes.top().type == ScopeType::FunctionArg) {
exit_scope(tok_id - 1);
}
exit_scope(tok_id);
break;
case SquareClose:
@@ -486,8 +498,20 @@ struct ParserData {
if (scopes.top().type == ScopeType::Assignment) {
exit_scope(tok_id - 1);
}
if (scopes.top().type == ScopeType::FunctionArg) {
exit_scope(tok_id - 1);
}
if (scopes.top().type == ScopeType::TemplateArg) {
exit_scope(tok_id - 1);
}
break;
default:
if (scopes.top().type == ScopeType::FunctionArgs) {
enter_scope(ScopeType::FunctionArg, tok_id);
}
if (scopes.top().type == ScopeType::Template) {
enter_scope(ScopeType::TemplateArg, tok_id);
}
break;
}
}
@@ -721,7 +745,7 @@ struct Scope {
std::string str() const
{
return data->str.substr(start().str_index_start(),
end().str_index_last() - start().str_index_start());
end().str_index_last() - start().str_index_start() + 1);
}
void foreach_match(const std::string &pattern,
@@ -745,6 +769,24 @@ struct Scope {
pos += 1;
}
}
/* Will iterate over all the scopes that are direct children. */
void foreach_scope(ScopeType type, std::function<void(Scope)> callback) const
{
size_t pos = this->index;
while ((pos = data->scope_types.find(char(type), pos)) != std::string::npos) {
Scope scope{data, pos};
if (scope.start().index > this->end().index) {
/* Found scope starts after this scope. End iteration. */
break;
}
/* Make sure found scope is direct child of this scope. */
if (scope.start().scope().scope().index == this->index) {
callback(scope);
}
pos += 1;
}
}
};
inline Scope Token::scope() const
@@ -859,6 +901,11 @@ struct Parser {
{
replace(from.str_index_start(), to.str_index_last(), replacement);
}
/* Replace token by string. */
void replace(Token tok, const std::string &replacement)
{
replace(tok.str_index_start(), tok.str_index_last(), replacement);
}
/* Replace the content from `from` to `to` (inclusive) by whitespaces without changing
* line count and keep the remaining indentation spaces. */
@@ -885,6 +932,12 @@ struct Parser {
{
erase(tok, tok);
}
/* Replace the content of the scope by whitespaces without changing
* line count and keep the remaining indentation spaces. */
void erase(Scope scope)
{
erase(scope.start(), scope.end());
}
void insert_after(size_t at, const std::string &content)
{

View File

@@ -254,6 +254,14 @@ func_TEMPLATE(float, 1)/*float a*/)";
EXPECT_EQ(output, expect);
EXPECT_EQ(error, "");
}
{
string input = R"(template<> void func<T, Q>(T a) {a};)";
string expect = R"( void func_T_Q_(T a) {a};)";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(output, expect);
EXPECT_EQ(error, "");
}
{
string input = R"(template<typename T, int i = 0> void func(T a) {a;})";
string error;