GPU: Shader: Rewrite default_argument_mutation using parser

This avoid failure cases from the regex.

Pull Request: https://projects.blender.org/blender/blender/pulls/144386
This commit is contained in:
Clément Foucault
2025-08-12 10:10:12 +02:00
committed by Clément Foucault
parent 33cc0eb805
commit 831969f4f0
3 changed files with 82 additions and 87 deletions

View File

@@ -1420,6 +1420,11 @@ class Preprocessor {
return str;
}
std::string strip_whitespace(const std::string &str) const
{
return str.substr(0, str.find_last_not_of(" \n") + 1);
}
/**
* Expand functions with default arguments to function overloads.
* Expects formatted input and that function bodies are followed by newline.
@@ -1427,95 +1432,63 @@ class Preprocessor {
std::string default_argument_mutation(std::string str)
{
using namespace std;
int match = 0;
default_argument_search(
str, [&](int /*parenthesis_depth*/, int /*bracket_depth*/, char & /*c*/) { match++; });
using namespace shader::parser;
if (match == 0) {
/* No mutation to do. Early out as the following regex is expensive. */
return str;
}
Parser parser(str);
vector<pair<string, string>> mutations;
int64_t line = 0;
/* Matches function definition. */
regex regex_func(R"(\n((\w+)\s+(\w+)\s*\()([^{]+))");
regex_global_search(str, regex_func, [&](const smatch &match) {
const string prefix = match[1].str();
const string return_type = match[2].str();
const string func_name = match[3].str();
const string args = get_content_between_balanced_pair('(' + match[4].str(), '(', ')');
const string suffix = ")\n{";
int64_t lines_in_content = line_count(match[0].str());
line += line_count(match.prefix().str()) + lines_in_content;
if (args.find('=') == string::npos) {
return;
}
const bool has_non_void_return_type = return_type != "void";
string line_directive = "#line " + std::to_string(line - lines_in_content + 2) + "\n";
vector<string> args_split = split_string_not_between_balanced_pair(args, ',', '(', ')');
string overloads;
string args_defined;
string args_called;
/* Rewrite original definition without defaults. */
string with_default = match[0].str();
string no_default = with_default;
for (const string &arg : args_split) {
regex regex(R"(((?:const )?\w+)\s+(\w+)( = (.+))?)");
smatch match;
regex_search(arg, match, regex);
string arg_type = match[1].str();
string arg_name = match[2].str();
string arg_assign = match[3].str();
string arg_value = match[4].str();
if (!arg_value.empty()) {
string body = func_name + "(" + args_called + arg_value + ");";
if (has_non_void_return_type) {
body = " return " + body;
}
else {
body = " " + body;
parser.foreach_function(
[&](bool, Token fn_type, Token fn_name, Scope fn_args, bool, Scope fn_body) {
if (!fn_args.contains_token('=')) {
return;
}
overloads = line_directive + prefix + args_defined + suffix + '\n' + line_directive +
body + "\n}\n" + overloads;
const bool has_non_void_return_type = fn_type.str_no_whitespace() != "void";
replace_all(no_default, arg_assign, "");
}
if (!args_defined.empty()) {
args_defined += ", ";
}
args_defined += arg_type + ' ' + arg_name;
args_called += arg_name + ", ";
}
string args_decl;
string args_names;
/* Get function body to put the overload after it. */
string body_content = '{' +
get_content_between_balanced_pair(match.suffix().str(), '{', '}') +
"}\n";
vector<string> fn_overloads;
string last_line_directive =
"#line " + std::to_string(line - lines_in_content + line_count(body_content) + 3) + "\n";
fn_args.foreach_scope(ScopeType::FunctionArg, [&](Scope arg) {
Token equal = arg.find_token('=');
const char *comma = (args_decl.empty() ? "" : ", ");
if (equal.is_invalid()) {
args_decl += comma + arg.str();
args_names += comma + arg.end().str();
}
else {
string arg_name = equal.prev().str_no_whitespace();
string value = parser.substr_range_inclusive(equal.next(), arg.end());
string decl = parser.substr_range_inclusive(arg.start(), equal.prev());
mutations.emplace_back(with_default + body_content,
no_default + body_content + overloads + last_line_directive);
});
string fn_call = fn_name.str() + '(' + args_names + comma + value + ");";
if (has_non_void_return_type) {
fn_call = "return " + fn_call;
}
string overload;
overload += fn_type.str();
overload += fn_name.str() + '(' + args_decl + ")\n";
overload += "{\n";
overload += "#line " + std::to_string(fn_type.line_number()) + "\n";
overload += " " + fn_call + "\n}\n";
fn_overloads.emplace_back(overload);
for (auto mutation : mutations) {
replace_all(str, mutation.first, mutation.second);
}
return str;
args_decl += comma + strip_whitespace(decl);
args_names += comma + arg_name;
/* Erase the value assignment and keep the declaration. */
parser.erase(equal.scope());
}
});
size_t end_of_fn_char = fn_body.end().line_end() + 1;
/* Have to reverse the declaration order. */
for (auto it = fn_overloads.rbegin(); it != fn_overloads.rend(); ++it) {
parser.insert_line_number(end_of_fn_char, fn_type.line_number());
parser.insert_after(end_of_fn_char, *it);
}
parser.insert_line_number(end_of_fn_char, fn_body.end().line_number() + 1);
});
return parser.result_get();
}
/* Used to make GLSL matrix constructor compatible with MSL in pyGPU shaders.

View File

@@ -299,6 +299,7 @@ struct ParserData {
token_offsets.offsets.emplace_back(offset);
}
}
token_offsets.offsets.emplace_back(offset);
}
{
/* Keywords detection. */
@@ -760,6 +761,18 @@ struct Scope {
end().str_index_last() - start().str_index_start() + 1);
}
Token find_token(const char token_type) const
{
size_t pos = data->token_types.substr(range().start, range().size).find(token_type);
return (pos != std::string::npos) ? Token{data, int64_t(range().start + pos)} :
Token::invalid();
}
bool contains_token(const char token_type) const
{
return find_token(token_type).is_valid();
}
void foreach_match(const std::string &pattern,
std::function<void(const std::vector<Token>)> callback) const
{
@@ -961,6 +974,11 @@ struct Parser {
insert_after(at.str_index_last(), content);
}
void insert_line_number(size_t at, int line)
{
insert_after(at, "#line " + std::to_string(line) + "\n");
}
void insert_before(size_t at, const std::string &content)
{
IndexRange range = IndexRange(at, 0);
@@ -1009,7 +1027,11 @@ struct Parser {
{
std::string out;
for (const Mutation &mut : mutations_) {
out += "Replace \"";
out += "Replace ";
out += std::to_string(mut.src_range.start);
out += " - ";
out += std::to_string(mut.src_range.size);
out += " \"";
out += data_.str.substr(mut.src_range.start, mut.src_range.size);
out += "\" by \"";
out += mut.replacement;

View File

@@ -371,7 +371,7 @@ int func(int a, int b = 0)
}
)";
string expect = R"(
int func(int a, int b)
int func(int a, int b )
{
return a + b;
}
@@ -396,7 +396,7 @@ int func(int a = 0, const int b = 0)
}
)";
string expect = R"(
int func(int a, const int b)
int func(int a , const int b )
{
return a + b;
}
@@ -426,7 +426,7 @@ int2 func(int2 a = int2(0, 0)) {
}
)";
string expect = R"(
int2 func(int2 a) {
int2 func(int2 a ) {
return a;
}
#line 2
@@ -435,7 +435,7 @@ int2 func()
#line 2
return func(int2(0, 0));
}
#line 6
#line 5
)";
string error;
string output = process_test_string(input, error);
@@ -449,7 +449,7 @@ void func(int a = 0) {
}
)";
string expect = R"(
void func(int a) {
void func(int a ) {
a;
}
#line 2
@@ -458,7 +458,7 @@ void func()
#line 2
func(0);
}
#line 6
#line 5
)";
string error;
string output = process_test_string(input, error);