GPU: Shader: Use parser for loop unrolling

This refactor the loop unrolling mechanism and
replaces the use of regex by the new parser.

Pull Request: https://projects.blender.org/blender/blender/pulls/145956
This commit is contained in:
Clément Foucault
2025-09-10 11:59:52 +02:00
committed by Clément Foucault
parent ece97ef3dc
commit 805e037df3
3 changed files with 348 additions and 274 deletions

View File

@@ -755,236 +755,221 @@ class Preprocessor {
return str;
}
struct Loop {
/* `[[gpu::unroll]] for (int i = 0; i < 10; i++)` */
std::string definition;
/* `{ some_computation(i); }` */
std::string body;
/* `int i = 0` */
std::string init_statement;
/* `i < 10` */
std::string test_statement;
/* `i++` */
std::string iter_statement;
/* Spaces and newline between loop start and body. */
std::string body_prefix;
/* Spaces before the loop definition. */
std::string indent;
/* `10` */
int64_t iter_count;
/* Line at which the loop was defined. */
int64_t definition_line;
/* Line at which the body starts. */
int64_t body_line;
/* Line at which the body ends. */
int64_t end_line;
};
using namespace std;
using namespace shader::parser;
std::vector<Loop> loops;
Parser parser(str, report_error);
auto add_loop = [&](Loop &loop,
const std::smatch &match,
int64_t line,
int64_t lines_in_content) {
std::string suffix = match.suffix().str();
loop.body = get_content_between_balanced_pair(loop.definition + suffix, '{', '}');
loop.body = '{' + loop.body + '}';
loop.definition_line = line - lines_in_content;
loop.body_line = line;
loop.end_line = loop.body_line + line_count(loop.body);
auto parse_for_args =
[&](const Scope loop_args, Scope &r_init, Scope &r_condition, Scope &r_iter) {
r_init = r_condition = r_iter = Scope::invalid();
loop_args.foreach_scope(ScopeType::LoopArg, [&](const Scope arg) {
if (arg.start().prev() == '(' && arg.end().next() == ';') {
r_init = arg;
}
else if (arg.start().prev() == ';' && arg.end().next() == ';') {
r_condition = arg;
}
else if (arg.start().prev() == ';' && arg.end().next() == ')') {
r_iter = arg;
}
else {
report_error(ERROR_TOK(arg.start()), "Invalid loop declaration.");
}
});
};
auto add_loop = [&](const Token loop_start,
const int iter_count,
const bool condition_is_trivial,
const Scope init,
const Scope cond,
const Scope iter,
const Scope body) {
string body_str = body.str();
/* Check that there is no unsupported keywords in the loop body. */
if (loop.body.find(" break;") != std::string::npos ||
loop.body.find(" continue;") != std::string::npos)
if (body_str.find(" break;") != std::string::npos ||
body_str.find(" continue;") != std::string::npos)
{
/* Expensive check. Remove other loops and switch scopes inside the unrolled loop scope and
* check again to avoid false positive. */
std::string modified_body = loop.body;
string modified_body = body_str;
std::regex regex_loop(R"( (for|while|do) )");
regex_global_search(loop.body, regex_loop, [&](const std::smatch &match) {
regex_global_search(modified_body, regex_loop, [&](const std::smatch &match) {
std::string inner_scope = get_content_between_balanced_pair(match.suffix(), '{', '}');
replace_all(modified_body, inner_scope, "");
});
/* Checks if `continue` exists, even in switch statement inside the unrolled loop scope. */
if (modified_body.find(" continue;") != std::string::npos) {
report_error(line_number(match),
char_number(match),
line_str(match),
"Error: Unrolled loop cannot contain \"continue\" statement.");
report_error(ERROR_TOK(loop_start),
"Unrolled loop cannot contain \"continue\" statement.");
}
std::regex regex_switch(R"( switch )");
regex_global_search(loop.body, regex_switch, [&](const std::smatch &match) {
regex_global_search(modified_body, regex_switch, [&](const std::smatch &match) {
std::string inner_scope = get_content_between_balanced_pair(match.suffix(), '{', '}');
replace_all(modified_body, inner_scope, "");
});
/* Checks if `break` exists inside the unrolled loop scope. */
if (modified_body.find(" break;") != std::string::npos) {
report_error(line_number(match),
char_number(match),
line_str(match),
"Error: Unrolled loop cannot contain \"break\" statement.");
report_error(ERROR_TOK(loop_start), "Unrolled loop cannot contain \"break\" statement.");
}
}
loops.emplace_back(loop);
if (!parser.replace_try(loop_start, body.end(), "", true)) {
/* This is the case of nested loops. This loop will be processed in another parser pass. */
return;
}
string indent_init, indent_cond, indent_iter;
if (init.is_valid()) {
indent_init = string(init.start().char_number() - 1, ' ');
}
if (cond.is_valid()) {
indent_cond = string(cond.start().char_number() - 3, ' ');
}
if (iter.is_valid()) {
indent_iter = string(iter.start().char_number(), ' ');
}
string indent_body = string(body.start().char_number(), ' ');
string indent_end = string(body.end().char_number(), ' ');
parser.insert_after(body.end(), "\n");
if (init.is_valid()) {
parser.insert_line_number(body.end(), init.start().line_number());
parser.insert_after(body.end(), indent_init + "{" + init.str() + ";\n");
}
else {
parser.insert_after(body.end(), "{\n");
}
for (int64_t i = 0; i < iter_count; i++) {
if (cond.is_valid() && !condition_is_trivial) {
parser.insert_line_number(body.end(), cond.start().line_number());
parser.insert_after(body.end(), indent_cond + "if(" + cond.str() + ")\n");
}
parser.insert_line_number(body.end(), body.start().line_number());
parser.insert_after(body.end(), indent_body + body_str + "\n");
if (iter.is_valid()) {
parser.insert_line_number(body.end(), iter.start().line_number());
parser.insert_after(body.end(), indent_iter + iter.str() + ";\n");
}
}
parser.insert_line_number(body.end(), body.end().line_number());
parser.insert_after(body.end(), indent_end + body.end().str_with_whitespace());
};
/* Parse the loop syntax. */
{
/* [[gpu::unroll]]. */
std::regex regex(R"(( *))"
R"(\[\[gpu::unroll\]\])"
R"(\s*for\s*\()"
R"(\s*((?:uint|int)\s+(\w+)\s+=\s+(-?\d+));)" /* Init statement. */
R"(\s*((\w+)\s+(>|<)(=?)\s+(-?\d+)))" /* Conditional statement. */
R"(\s*(?:&&)?\s*([^;)]+)?;)" /* Extra conditional statement. */
R"(\s*(((\w+)(\+\+|\-\-))[^\)]*))" /* Iteration statement. */
R"(\)(\s*))");
int64_t line = 0;
regex_global_search(str, regex, [&](const std::smatch &match) {
std::string counter_1 = match[3].str();
std::string counter_2 = match[6].str();
std::string counter_3 = match[13].str();
std::string content = match[0].str();
int64_t lines_in_content = line_count(content);
line += line_count(match.prefix().str()) + lines_in_content;
if ((counter_1 != counter_2) || (counter_1 != counter_3)) {
report_error(line_number(match),
char_number(match),
line_str(match),
"Error: Non matching loop counter variable.");
return;
}
Loop loop;
int64_t init = std::stol(match[4].str());
int64_t end = std::stol(match[9].str());
/* TODO(fclem): Support arbitrary strides (aka, arbitrary iter statement). */
loop.iter_count = std::abs(end - init);
std::string condition = match[7].str();
if (condition.empty()) {
report_error(line_number(match),
char_number(match),
line_str(match),
"Error: Unsupported condition in unrolled loop.");
}
std::string equal = match[8].str();
if (equal == "=") {
loop.iter_count += 1;
}
std::string iter = match[14].str();
if (iter == "++") {
if (condition == ">") {
report_error(line_number(match),
char_number(match),
line_str(match),
"Error: Unsupported condition in unrolled loop.");
do {
/* Parse the loop syntax. */
{
/* [[gpu::unroll]]. */
parser.foreach_match("[[w::w]]f(..){..}", [&](const std::vector<Token> tokens) {
if (tokens[1].scope().str() != "[gpu::unroll]") {
return;
}
}
else if (iter == "--") {
if (condition == "<") {
report_error(line_number(match),
char_number(match),
line_str(match),
"Error: Unsupported condition in unrolled loop.");
const Token for_tok = tokens[8];
const Scope loop_args = tokens[9].scope();
const Scope loop_body = tokens[13].scope();
Scope init, cond, iter;
parse_for_args(loop_args, init, cond, iter);
/* Init statement. */
const Token var_type = init[0];
const Token var_name = init[1];
const Token var_init = init[2];
if (var_type.str() != "int" && var_type.str() != "uint") {
report_error(ERROR_TOK(var_init), "Can only unroll integer based loop.");
return;
}
if (var_init != '=') {
report_error(ERROR_TOK(var_init), "Expecting assignment here.");
return;
}
if (init[3] != '0' && init[3] != '-') {
report_error(ERROR_TOK(init[3]), "Expecting integer literal here.");
return;
}
}
else {
report_error(line_number(match),
char_number(match),
line_str(match),
"Error: Unsupported for loop expression. Expecting ++ or --");
}
loop.definition = content;
loop.indent = match[1].str();
loop.init_statement = match[2].str();
if (!match[10].str().empty()) {
loop.test_statement = "if (" + match[10].str() + ") ";
}
loop.iter_statement = match[11].str();
loop.body_prefix = match[15].str();
/* Conditional statement. */
const Token cond_var = cond[0];
const Token cond_type = cond[1];
const Token cond_sign = (cond[2] == '+' || cond[2] == '-') ? cond[2] : Token::invalid();
const Token cond_end = cond_sign.is_valid() ? cond[3] : cond[2];
if (cond_var.str() != var_name.str()) {
report_error(ERROR_TOK(cond_var), "Non matching loop counter variable.");
return;
}
if (cond_end != '0') {
report_error(ERROR_TOK(cond_end), "Expecting integer literal here.");
return;
}
add_loop(loop, match, line, lines_in_content);
});
}
{
/* [[gpu::unroll(n)]]. */
std::regex regex(R"(( *))"
R"(\[\[gpu::unroll\((\d+)\)\]\])"
R"(\s*for\s*\()"
R"(\s*([^;]*);)"
R"(\s*([^;]*);)"
R"(\s*([^)]*))"
R"(\)(\s*))");
/* Iteration statement. */
const Token iter_var = iter[0];
const Token iter_type = iter[1];
if (iter_var.str() != var_name.str()) {
report_error(ERROR_TOK(iter_var), "Non matching loop counter variable.");
return;
}
if (iter_type == Increment) {
if (cond_type == '>') {
report_error(ERROR_TOK(for_tok), "Unsupported condition in unrolled loop.");
return;
}
}
else if (iter_type == Decrement) {
if (cond_type == '<') {
report_error(ERROR_TOK(for_tok), "Unsupported condition in unrolled loop.");
return;
}
}
else {
report_error(ERROR_TOK(iter_type), "Unsupported loop expression. Expecting ++ or --.");
return;
}
int64_t line = 0;
int64_t init_value = std::stol(
parser.substr_range_inclusive(var_init.next(), var_init.scope().end()));
int64_t end_value = std::stol(parser.substr_range_inclusive(
cond_sign.is_valid() ? cond_sign : cond_end, cond_end));
/* TODO(fclem): Support arbitrary strides (aka, arbitrary iter statement). */
int iter_count = std::abs(end_value - init_value);
if (cond_type == GEqual || cond_type == LEqual) {
iter_count += 1;
}
regex_global_search(str, regex, [&](const std::smatch &match) {
std::string content = match[0].str();
bool condition_is_trivial = (cond_end == cond.end());
int64_t lines_in_content = line_count(content);
line += line_count(match.prefix().str()) + lines_in_content;
Loop loop;
loop.iter_count = std::stol(match[2].str());
loop.definition = content;
loop.indent = match[1].str();
loop.init_statement = match[3].str();
loop.test_statement = "if (" + match[4].str() + ") ";
loop.iter_statement = match[5].str();
loop.body_prefix = match[13].str();
add_loop(loop, match, line, lines_in_content);
});
}
std::string out = str;
/* Copy paste loop iterations. */
for (const Loop &loop : loops) {
std::string replacement = loop.indent + "{ " + loop.init_statement + ";";
for (int64_t i = 0; i < loop.iter_count; i++) {
replacement += std::string("\n#line ") + std::to_string(loop.body_line + 1) + "\n";
replacement += loop.indent + loop.test_statement + loop.body;
replacement += std::string("\n#line ") + std::to_string(loop.definition_line + 1) + "\n";
replacement += loop.indent + loop.iter_statement + ";";
if (i == loop.iter_count - 1) {
replacement += std::string("\n#line ") + std::to_string(loop.end_line + 1) + "\n";
replacement += loop.indent + "}";
}
add_loop(tokens[0], iter_count, condition_is_trivial, init, cond, iter, loop_body);
});
}
{
/* [[gpu::unroll(n)]]. */
parser.foreach_match("[[w::w(0)]]f(..){..}", [&](const std::vector<Token> tokens) {
const Scope loop_args = tokens[12].scope();
const Scope loop_body = tokens[16].scope();
std::string replaced = loop.definition + loop.body;
Scope init, cond, iter;
parse_for_args(loop_args, init, cond, iter);
/* Replace all occurrences in case of recursive unrolling. */
replace_all(out, replaced, replacement);
}
int iter_count = std::stol(tokens[7].str());
add_loop(tokens[0], iter_count, false, init, cond, iter, loop_body);
});
}
} while (parser.apply_mutations());
/* Check for remaining keywords. */
if (out.find("[[gpu::unroll") != std::string::npos) {
regex_global_search(str, std::regex(R"(\[\[gpu::unroll)"), [&](const std::smatch &match) {
report_error(line_number(match),
char_number(match),
line_str(match),
"Error: Incompatible format for [[gpu::unroll]].");
});
}
parser.foreach_match("[[w::w", [&](const std::vector<Token> tokens) {
if (tokens[2].str() == "gpu" && tokens[5].str() == "unroll") {
report_error(ERROR_TOK(tokens[0]), "Incompatible loop format for [[gpu::unroll]].");
}
});
return out;
return parser.result_get();
}
std::string namespace_mutation(const std::string &str, report_callback report_error)

View File

@@ -114,6 +114,7 @@ enum class ScopeType : char {
Namespace = 'N',
Struct = 'S',
Function = 'F',
LoopArgs = 'l',
FunctionArgs = 'f',
FunctionCall = 'c',
Template = 'T',
@@ -125,6 +126,9 @@ enum class ScopeType : char {
Local = 'L',
/* Added scope inside FunctionArgs. */
FunctionArg = 'g',
/* Added scope inside LoopArgs. */
LoopArg = 'r',
};
/* Poor man's IndexRange. */
@@ -321,6 +325,7 @@ struct ParserData {
token_offsets.offsets.emplace_back(offset);
}
}
offset++;
token_offsets.offsets.emplace_back(offset);
}
{
@@ -741,12 +746,39 @@ struct Scope {
return start().prev().scope();
}
static Scope invalid()
{
return {"", "", nullptr, size_t(-1)};
}
bool is_valid() const
{
return data != nullptr && index >= 0;
}
bool is_invalid() const
{
return !is_valid();
}
std::string str() const
{
if (this->is_invalid()) {
return "";
}
return data->str.substr(start().str_index_start(),
end().str_index_last() - start().str_index_start() + 1);
}
/* Return the content without the first and last characters. */
std::string str_exclusive() const
{
if (this->is_invalid()) {
return "";
}
return data->str.substr(start().str_index_start() + 1,
end().str_index_last() - start().str_index_start() - 1);
}
Token find_token(const char token_type) const
{
size_t pos = data->token_types.substr(range().start, range().size).find(token_type);
@@ -992,7 +1024,12 @@ inline void ParserData::parse_scopes(report_callback &report_error)
break;
}
case ParOpen:
if (scopes.top().type == ScopeType::Global) {
if ((tok_id >= 1 && token_types[tok_id - 1] == For) ||
(tok_id >= 1 && token_types[tok_id - 1] == While))
{
enter_scope(ScopeType::LoopArgs, tok_id);
}
else if (scopes.top().type == ScopeType::Global) {
enter_scope(ScopeType::FunctionArgs, tok_id);
}
else if (scopes.top().type == ScopeType::Struct) {
@@ -1043,12 +1080,28 @@ inline void ParserData::parse_scopes(report_callback &report_error)
if (scopes.top().type == ScopeType::FunctionArg) {
exit_scope(tok_id - 1);
}
if (scopes.top().type == ScopeType::LoopArg) {
exit_scope(tok_id - 1);
}
exit_scope(tok_id);
break;
case SquareClose:
exit_scope(tok_id);
break;
case SemiColon:
if (scopes.top().type == ScopeType::Assignment) {
exit_scope(tok_id - 1);
}
if (scopes.top().type == ScopeType::FunctionArg) {
exit_scope(tok_id - 1);
}
if (scopes.top().type == ScopeType::TemplateArg) {
exit_scope(tok_id - 1);
}
if (scopes.top().type == ScopeType::LoopArg) {
exit_scope(tok_id - 1);
}
break;
case Comma:
if (scopes.top().type == ScopeType::Assignment) {
exit_scope(tok_id - 1);
@@ -1064,6 +1117,9 @@ inline void ParserData::parse_scopes(report_callback &report_error)
if (scopes.top().type == ScopeType::FunctionArgs) {
enter_scope(ScopeType::FunctionArg, tok_id);
}
if (scopes.top().type == ScopeType::LoopArgs) {
enter_scope(ScopeType::LoopArg, tok_id);
}
if (scopes.top().type == ScopeType::Template) {
enter_scope(ScopeType::TemplateArg, tok_id);
}
@@ -1221,8 +1277,14 @@ struct Parser {
}
/* Replace everything from `from` to `to` (inclusive).
* Return true on success. */
bool replace_try(Token from, Token to, const std::string &replacement)
bool replace_try(Token from,
Token to,
const std::string &replacement,
bool keep_trailing_whitespaces = false)
{
if (keep_trailing_whitespaces) {
return replace_try(from.str_index_start(), to.str_index_last_no_whitespace(), replacement);
}
return replace_try(from.str_index_start(), to.str_index_last(), replacement);
}
@@ -1307,6 +1369,10 @@ struct Parser {
{
insert_after(at, "#line " + std::to_string(line) + "\n");
}
void insert_line_number(Token at, int line)
{
insert_line_number(at.str_index_last(), line);
}
void insert_before(size_t at, const std::string &content)
{

View File

@@ -87,92 +87,112 @@ static void test_preprocess_unroll()
using namespace std;
{
string input = R"([[gpu::unroll]] for (int i = 2; i < 4; i++, y++) { content += i; })";
string expect = R"({ int i = 2;
#line 1
{ content += i; }
#line 1
i++, y++;
#line 1
{ content += i; }
#line 1
i++, y++;
#line 1
})";
string input = R"(
[[gpu::unroll]] for (int i = 2; i < 4; i++, y++) { content += i; })";
string expect = R"(
#line 2
{int i = 2;
#line 2
{ content += i; }
#line 2
i++, y++;
#line 2
{ content += i; }
#line 2
i++, y++;
#line 2
})";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(output, expect);
EXPECT_EQ(error, "");
}
{
string input = R"([[gpu::unroll]] for (int i = 2; i < 4 && i < y; i++, y++) { cont += i; })";
string expect = R"({ int i = 2;
#line 1
if (i < y) { cont += i; }
#line 1
i++, y++;
#line 1
if (i < y) { cont += i; }
#line 1
i++, y++;
#line 1
})";
string input = R"(
[[gpu::unroll]] for (int i = 2; i < 4 && i < y; i++, y++) { cont += i; })";
string expect = R"(
#line 2
{int i = 2;
#line 2
if(i < 4 && i < y)
#line 2
{ cont += i; }
#line 2
i++, y++;
#line 2
if(i < 4 && i < y)
#line 2
{ cont += i; }
#line 2
i++, y++;
#line 2
})";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(output, expect);
EXPECT_EQ(error, "");
}
{
string input = R"([[gpu::unroll(2)]] for (; i < j;) { content += i; })";
string expect = R"({ ;
#line 1
if (i < j) { content += i; }
#line 1
;
#line 1
if (i < j) { content += i; }
#line 1
;
#line 1
})";
string input = R"(
[[gpu::unroll(2)]] for (; i < j;) { content += i; })";
string expect = R"(
{
#line 2
if(i < j)
#line 2
{ content += i; }
#line 2
if(i < j)
#line 2
{ content += i; }
#line 2
})";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(output, expect);
EXPECT_EQ(error, "");
}
{
string input = R"([[gpu::unroll(2)]] for (; i < j;) { [[gpu::unroll(2)]] for (; j < k;) {} })";
string expect = R"({ ;
#line 1
if (i < j) { { ;
#line 1
if (j < k) {}
#line 1
;
#line 1
if (j < k) {}
#line 1
;
#line 1
} }
#line 1
;
#line 1
if (i < j) { { ;
#line 1
if (j < k) {}
#line 1
;
#line 1
if (j < k) {}
#line 1
;
#line 1
} }
#line 1
;
#line 1
})";
string input = R"(
[[gpu::unroll(2)]] for (; i < j;) { [[gpu::unroll(2)]] for (; j < k;) {} })";
string expect = R"(
{
#line 2
if(i < j)
#line 2
{
{
#line 2
if(j < k)
#line 2
{}
#line 2
if(j < k)
#line 2
{}
#line 2
} }
#line 2
if(i < j)
#line 2
{
{
#line 2
if(j < k)
#line 2
{}
#line 2
if(j < k)
#line 2
{}
#line 2
} }
#line 2
})";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(output, expect);
@@ -182,27 +202,30 @@ if (i < j) { { ;
string input = R"([[gpu::unroll(2)]] for (; i < j;) { break; })";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(error, "Error: Unrolled loop cannot contain \"break\" statement.");
EXPECT_EQ(error, "Unrolled loop cannot contain \"break\" statement.");
}
{
string input = R"([[gpu::unroll(2)]] for (; i < j;) { continue; })";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(error, "Error: Unrolled loop cannot contain \"continue\" statement.");
EXPECT_EQ(error, "Unrolled loop cannot contain \"continue\" statement.");
}
{
string input = R"([[gpu::unroll(2)]] for (; i < j;) { for (; j < k;) {break;continue;} })";
string expect = R"({ ;
#line 1
if (i < j) { for (; j < k;) {break;continue;} }
#line 1
;
#line 1
if (i < j) { for (; j < k;) {break;continue;} }
#line 1
;
#line 1
})";
string input = R"(
[[gpu::unroll(2)]] for (; i < j;) { for (; j < k;) {break;continue;} })";
string expect = R"(
{
#line 2
if(i < j)
#line 2
{ for (; j < k;) {break;continue;} }
#line 2
if(i < j)
#line 2
{ for (; j < k;) {break;continue;} }
#line 2
})";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(output, expect);
@@ -212,7 +235,7 @@ if (i < j) { for (; j < k;) {break;continue;} }
string input = R"([[gpu::unroll]] for (int i = 3; i > 2; i++) {})";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(error, "Error: Unsupported condition in unrolled loop.");
EXPECT_EQ(error, "Unsupported condition in unrolled loop.");
}
}
GPU_TEST(preprocess_unroll);