GPU: Shader: Use parser for loop unrolling
This refactor the loop unrolling mechanism and replaces the use of regex by the new parser. Pull Request: https://projects.blender.org/blender/blender/pulls/145956
This commit is contained in:
committed by
Clément Foucault
parent
ece97ef3dc
commit
805e037df3
@@ -755,236 +755,221 @@ class Preprocessor {
|
||||
return str;
|
||||
}
|
||||
|
||||
struct Loop {
|
||||
/* `[[gpu::unroll]] for (int i = 0; i < 10; i++)` */
|
||||
std::string definition;
|
||||
/* `{ some_computation(i); }` */
|
||||
std::string body;
|
||||
/* `int i = 0` */
|
||||
std::string init_statement;
|
||||
/* `i < 10` */
|
||||
std::string test_statement;
|
||||
/* `i++` */
|
||||
std::string iter_statement;
|
||||
/* Spaces and newline between loop start and body. */
|
||||
std::string body_prefix;
|
||||
/* Spaces before the loop definition. */
|
||||
std::string indent;
|
||||
/* `10` */
|
||||
int64_t iter_count;
|
||||
/* Line at which the loop was defined. */
|
||||
int64_t definition_line;
|
||||
/* Line at which the body starts. */
|
||||
int64_t body_line;
|
||||
/* Line at which the body ends. */
|
||||
int64_t end_line;
|
||||
};
|
||||
using namespace std;
|
||||
using namespace shader::parser;
|
||||
|
||||
std::vector<Loop> loops;
|
||||
Parser parser(str, report_error);
|
||||
|
||||
auto add_loop = [&](Loop &loop,
|
||||
const std::smatch &match,
|
||||
int64_t line,
|
||||
int64_t lines_in_content) {
|
||||
std::string suffix = match.suffix().str();
|
||||
loop.body = get_content_between_balanced_pair(loop.definition + suffix, '{', '}');
|
||||
loop.body = '{' + loop.body + '}';
|
||||
loop.definition_line = line - lines_in_content;
|
||||
loop.body_line = line;
|
||||
loop.end_line = loop.body_line + line_count(loop.body);
|
||||
auto parse_for_args =
|
||||
[&](const Scope loop_args, Scope &r_init, Scope &r_condition, Scope &r_iter) {
|
||||
r_init = r_condition = r_iter = Scope::invalid();
|
||||
loop_args.foreach_scope(ScopeType::LoopArg, [&](const Scope arg) {
|
||||
if (arg.start().prev() == '(' && arg.end().next() == ';') {
|
||||
r_init = arg;
|
||||
}
|
||||
else if (arg.start().prev() == ';' && arg.end().next() == ';') {
|
||||
r_condition = arg;
|
||||
}
|
||||
else if (arg.start().prev() == ';' && arg.end().next() == ')') {
|
||||
r_iter = arg;
|
||||
}
|
||||
else {
|
||||
report_error(ERROR_TOK(arg.start()), "Invalid loop declaration.");
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
auto add_loop = [&](const Token loop_start,
|
||||
const int iter_count,
|
||||
const bool condition_is_trivial,
|
||||
const Scope init,
|
||||
const Scope cond,
|
||||
const Scope iter,
|
||||
const Scope body) {
|
||||
string body_str = body.str();
|
||||
/* Check that there is no unsupported keywords in the loop body. */
|
||||
if (loop.body.find(" break;") != std::string::npos ||
|
||||
loop.body.find(" continue;") != std::string::npos)
|
||||
if (body_str.find(" break;") != std::string::npos ||
|
||||
body_str.find(" continue;") != std::string::npos)
|
||||
{
|
||||
/* Expensive check. Remove other loops and switch scopes inside the unrolled loop scope and
|
||||
* check again to avoid false positive. */
|
||||
std::string modified_body = loop.body;
|
||||
string modified_body = body_str;
|
||||
|
||||
std::regex regex_loop(R"( (for|while|do) )");
|
||||
regex_global_search(loop.body, regex_loop, [&](const std::smatch &match) {
|
||||
regex_global_search(modified_body, regex_loop, [&](const std::smatch &match) {
|
||||
std::string inner_scope = get_content_between_balanced_pair(match.suffix(), '{', '}');
|
||||
replace_all(modified_body, inner_scope, "");
|
||||
});
|
||||
|
||||
/* Checks if `continue` exists, even in switch statement inside the unrolled loop scope. */
|
||||
if (modified_body.find(" continue;") != std::string::npos) {
|
||||
report_error(line_number(match),
|
||||
char_number(match),
|
||||
line_str(match),
|
||||
"Error: Unrolled loop cannot contain \"continue\" statement.");
|
||||
report_error(ERROR_TOK(loop_start),
|
||||
"Unrolled loop cannot contain \"continue\" statement.");
|
||||
}
|
||||
|
||||
std::regex regex_switch(R"( switch )");
|
||||
regex_global_search(loop.body, regex_switch, [&](const std::smatch &match) {
|
||||
regex_global_search(modified_body, regex_switch, [&](const std::smatch &match) {
|
||||
std::string inner_scope = get_content_between_balanced_pair(match.suffix(), '{', '}');
|
||||
replace_all(modified_body, inner_scope, "");
|
||||
});
|
||||
|
||||
/* Checks if `break` exists inside the unrolled loop scope. */
|
||||
if (modified_body.find(" break;") != std::string::npos) {
|
||||
report_error(line_number(match),
|
||||
char_number(match),
|
||||
line_str(match),
|
||||
"Error: Unrolled loop cannot contain \"break\" statement.");
|
||||
report_error(ERROR_TOK(loop_start), "Unrolled loop cannot contain \"break\" statement.");
|
||||
}
|
||||
}
|
||||
loops.emplace_back(loop);
|
||||
|
||||
if (!parser.replace_try(loop_start, body.end(), "", true)) {
|
||||
/* This is the case of nested loops. This loop will be processed in another parser pass. */
|
||||
return;
|
||||
}
|
||||
|
||||
string indent_init, indent_cond, indent_iter;
|
||||
if (init.is_valid()) {
|
||||
indent_init = string(init.start().char_number() - 1, ' ');
|
||||
}
|
||||
if (cond.is_valid()) {
|
||||
indent_cond = string(cond.start().char_number() - 3, ' ');
|
||||
}
|
||||
if (iter.is_valid()) {
|
||||
indent_iter = string(iter.start().char_number(), ' ');
|
||||
}
|
||||
string indent_body = string(body.start().char_number(), ' ');
|
||||
string indent_end = string(body.end().char_number(), ' ');
|
||||
|
||||
parser.insert_after(body.end(), "\n");
|
||||
if (init.is_valid()) {
|
||||
parser.insert_line_number(body.end(), init.start().line_number());
|
||||
parser.insert_after(body.end(), indent_init + "{" + init.str() + ";\n");
|
||||
}
|
||||
else {
|
||||
parser.insert_after(body.end(), "{\n");
|
||||
}
|
||||
for (int64_t i = 0; i < iter_count; i++) {
|
||||
if (cond.is_valid() && !condition_is_trivial) {
|
||||
parser.insert_line_number(body.end(), cond.start().line_number());
|
||||
parser.insert_after(body.end(), indent_cond + "if(" + cond.str() + ")\n");
|
||||
}
|
||||
parser.insert_line_number(body.end(), body.start().line_number());
|
||||
parser.insert_after(body.end(), indent_body + body_str + "\n");
|
||||
if (iter.is_valid()) {
|
||||
parser.insert_line_number(body.end(), iter.start().line_number());
|
||||
parser.insert_after(body.end(), indent_iter + iter.str() + ";\n");
|
||||
}
|
||||
}
|
||||
parser.insert_line_number(body.end(), body.end().line_number());
|
||||
parser.insert_after(body.end(), indent_end + body.end().str_with_whitespace());
|
||||
};
|
||||
|
||||
/* Parse the loop syntax. */
|
||||
{
|
||||
/* [[gpu::unroll]]. */
|
||||
std::regex regex(R"(( *))"
|
||||
R"(\[\[gpu::unroll\]\])"
|
||||
R"(\s*for\s*\()"
|
||||
R"(\s*((?:uint|int)\s+(\w+)\s+=\s+(-?\d+));)" /* Init statement. */
|
||||
R"(\s*((\w+)\s+(>|<)(=?)\s+(-?\d+)))" /* Conditional statement. */
|
||||
R"(\s*(?:&&)?\s*([^;)]+)?;)" /* Extra conditional statement. */
|
||||
R"(\s*(((\w+)(\+\+|\-\-))[^\)]*))" /* Iteration statement. */
|
||||
R"(\)(\s*))");
|
||||
|
||||
int64_t line = 0;
|
||||
|
||||
regex_global_search(str, regex, [&](const std::smatch &match) {
|
||||
std::string counter_1 = match[3].str();
|
||||
std::string counter_2 = match[6].str();
|
||||
std::string counter_3 = match[13].str();
|
||||
|
||||
std::string content = match[0].str();
|
||||
int64_t lines_in_content = line_count(content);
|
||||
|
||||
line += line_count(match.prefix().str()) + lines_in_content;
|
||||
|
||||
if ((counter_1 != counter_2) || (counter_1 != counter_3)) {
|
||||
report_error(line_number(match),
|
||||
char_number(match),
|
||||
line_str(match),
|
||||
"Error: Non matching loop counter variable.");
|
||||
return;
|
||||
}
|
||||
|
||||
Loop loop;
|
||||
|
||||
int64_t init = std::stol(match[4].str());
|
||||
int64_t end = std::stol(match[9].str());
|
||||
/* TODO(fclem): Support arbitrary strides (aka, arbitrary iter statement). */
|
||||
loop.iter_count = std::abs(end - init);
|
||||
|
||||
std::string condition = match[7].str();
|
||||
if (condition.empty()) {
|
||||
report_error(line_number(match),
|
||||
char_number(match),
|
||||
line_str(match),
|
||||
"Error: Unsupported condition in unrolled loop.");
|
||||
}
|
||||
|
||||
std::string equal = match[8].str();
|
||||
if (equal == "=") {
|
||||
loop.iter_count += 1;
|
||||
}
|
||||
|
||||
std::string iter = match[14].str();
|
||||
if (iter == "++") {
|
||||
if (condition == ">") {
|
||||
report_error(line_number(match),
|
||||
char_number(match),
|
||||
line_str(match),
|
||||
"Error: Unsupported condition in unrolled loop.");
|
||||
do {
|
||||
/* Parse the loop syntax. */
|
||||
{
|
||||
/* [[gpu::unroll]]. */
|
||||
parser.foreach_match("[[w::w]]f(..){..}", [&](const std::vector<Token> tokens) {
|
||||
if (tokens[1].scope().str() != "[gpu::unroll]") {
|
||||
return;
|
||||
}
|
||||
}
|
||||
else if (iter == "--") {
|
||||
if (condition == "<") {
|
||||
report_error(line_number(match),
|
||||
char_number(match),
|
||||
line_str(match),
|
||||
"Error: Unsupported condition in unrolled loop.");
|
||||
const Token for_tok = tokens[8];
|
||||
const Scope loop_args = tokens[9].scope();
|
||||
const Scope loop_body = tokens[13].scope();
|
||||
|
||||
Scope init, cond, iter;
|
||||
parse_for_args(loop_args, init, cond, iter);
|
||||
|
||||
/* Init statement. */
|
||||
const Token var_type = init[0];
|
||||
const Token var_name = init[1];
|
||||
const Token var_init = init[2];
|
||||
if (var_type.str() != "int" && var_type.str() != "uint") {
|
||||
report_error(ERROR_TOK(var_init), "Can only unroll integer based loop.");
|
||||
return;
|
||||
}
|
||||
if (var_init != '=') {
|
||||
report_error(ERROR_TOK(var_init), "Expecting assignment here.");
|
||||
return;
|
||||
}
|
||||
if (init[3] != '0' && init[3] != '-') {
|
||||
report_error(ERROR_TOK(init[3]), "Expecting integer literal here.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
else {
|
||||
report_error(line_number(match),
|
||||
char_number(match),
|
||||
line_str(match),
|
||||
"Error: Unsupported for loop expression. Expecting ++ or --");
|
||||
}
|
||||
|
||||
loop.definition = content;
|
||||
loop.indent = match[1].str();
|
||||
loop.init_statement = match[2].str();
|
||||
if (!match[10].str().empty()) {
|
||||
loop.test_statement = "if (" + match[10].str() + ") ";
|
||||
}
|
||||
loop.iter_statement = match[11].str();
|
||||
loop.body_prefix = match[15].str();
|
||||
/* Conditional statement. */
|
||||
const Token cond_var = cond[0];
|
||||
const Token cond_type = cond[1];
|
||||
const Token cond_sign = (cond[2] == '+' || cond[2] == '-') ? cond[2] : Token::invalid();
|
||||
const Token cond_end = cond_sign.is_valid() ? cond[3] : cond[2];
|
||||
if (cond_var.str() != var_name.str()) {
|
||||
report_error(ERROR_TOK(cond_var), "Non matching loop counter variable.");
|
||||
return;
|
||||
}
|
||||
if (cond_end != '0') {
|
||||
report_error(ERROR_TOK(cond_end), "Expecting integer literal here.");
|
||||
return;
|
||||
}
|
||||
|
||||
add_loop(loop, match, line, lines_in_content);
|
||||
});
|
||||
}
|
||||
{
|
||||
/* [[gpu::unroll(n)]]. */
|
||||
std::regex regex(R"(( *))"
|
||||
R"(\[\[gpu::unroll\((\d+)\)\]\])"
|
||||
R"(\s*for\s*\()"
|
||||
R"(\s*([^;]*);)"
|
||||
R"(\s*([^;]*);)"
|
||||
R"(\s*([^)]*))"
|
||||
R"(\)(\s*))");
|
||||
/* Iteration statement. */
|
||||
const Token iter_var = iter[0];
|
||||
const Token iter_type = iter[1];
|
||||
if (iter_var.str() != var_name.str()) {
|
||||
report_error(ERROR_TOK(iter_var), "Non matching loop counter variable.");
|
||||
return;
|
||||
}
|
||||
if (iter_type == Increment) {
|
||||
if (cond_type == '>') {
|
||||
report_error(ERROR_TOK(for_tok), "Unsupported condition in unrolled loop.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
else if (iter_type == Decrement) {
|
||||
if (cond_type == '<') {
|
||||
report_error(ERROR_TOK(for_tok), "Unsupported condition in unrolled loop.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
else {
|
||||
report_error(ERROR_TOK(iter_type), "Unsupported loop expression. Expecting ++ or --.");
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t line = 0;
|
||||
int64_t init_value = std::stol(
|
||||
parser.substr_range_inclusive(var_init.next(), var_init.scope().end()));
|
||||
int64_t end_value = std::stol(parser.substr_range_inclusive(
|
||||
cond_sign.is_valid() ? cond_sign : cond_end, cond_end));
|
||||
/* TODO(fclem): Support arbitrary strides (aka, arbitrary iter statement). */
|
||||
int iter_count = std::abs(end_value - init_value);
|
||||
if (cond_type == GEqual || cond_type == LEqual) {
|
||||
iter_count += 1;
|
||||
}
|
||||
|
||||
regex_global_search(str, regex, [&](const std::smatch &match) {
|
||||
std::string content = match[0].str();
|
||||
bool condition_is_trivial = (cond_end == cond.end());
|
||||
|
||||
int64_t lines_in_content = line_count(content);
|
||||
|
||||
line += line_count(match.prefix().str()) + lines_in_content;
|
||||
|
||||
Loop loop;
|
||||
loop.iter_count = std::stol(match[2].str());
|
||||
loop.definition = content;
|
||||
loop.indent = match[1].str();
|
||||
loop.init_statement = match[3].str();
|
||||
loop.test_statement = "if (" + match[4].str() + ") ";
|
||||
loop.iter_statement = match[5].str();
|
||||
loop.body_prefix = match[13].str();
|
||||
|
||||
add_loop(loop, match, line, lines_in_content);
|
||||
});
|
||||
}
|
||||
|
||||
std::string out = str;
|
||||
|
||||
/* Copy paste loop iterations. */
|
||||
for (const Loop &loop : loops) {
|
||||
std::string replacement = loop.indent + "{ " + loop.init_statement + ";";
|
||||
for (int64_t i = 0; i < loop.iter_count; i++) {
|
||||
replacement += std::string("\n#line ") + std::to_string(loop.body_line + 1) + "\n";
|
||||
replacement += loop.indent + loop.test_statement + loop.body;
|
||||
replacement += std::string("\n#line ") + std::to_string(loop.definition_line + 1) + "\n";
|
||||
replacement += loop.indent + loop.iter_statement + ";";
|
||||
if (i == loop.iter_count - 1) {
|
||||
replacement += std::string("\n#line ") + std::to_string(loop.end_line + 1) + "\n";
|
||||
replacement += loop.indent + "}";
|
||||
}
|
||||
add_loop(tokens[0], iter_count, condition_is_trivial, init, cond, iter, loop_body);
|
||||
});
|
||||
}
|
||||
{
|
||||
/* [[gpu::unroll(n)]]. */
|
||||
parser.foreach_match("[[w::w(0)]]f(..){..}", [&](const std::vector<Token> tokens) {
|
||||
const Scope loop_args = tokens[12].scope();
|
||||
const Scope loop_body = tokens[16].scope();
|
||||
|
||||
std::string replaced = loop.definition + loop.body;
|
||||
Scope init, cond, iter;
|
||||
parse_for_args(loop_args, init, cond, iter);
|
||||
|
||||
/* Replace all occurrences in case of recursive unrolling. */
|
||||
replace_all(out, replaced, replacement);
|
||||
}
|
||||
int iter_count = std::stol(tokens[7].str());
|
||||
|
||||
add_loop(tokens[0], iter_count, false, init, cond, iter, loop_body);
|
||||
});
|
||||
}
|
||||
} while (parser.apply_mutations());
|
||||
|
||||
/* Check for remaining keywords. */
|
||||
if (out.find("[[gpu::unroll") != std::string::npos) {
|
||||
regex_global_search(str, std::regex(R"(\[\[gpu::unroll)"), [&](const std::smatch &match) {
|
||||
report_error(line_number(match),
|
||||
char_number(match),
|
||||
line_str(match),
|
||||
"Error: Incompatible format for [[gpu::unroll]].");
|
||||
});
|
||||
}
|
||||
parser.foreach_match("[[w::w", [&](const std::vector<Token> tokens) {
|
||||
if (tokens[2].str() == "gpu" && tokens[5].str() == "unroll") {
|
||||
report_error(ERROR_TOK(tokens[0]), "Incompatible loop format for [[gpu::unroll]].");
|
||||
}
|
||||
});
|
||||
|
||||
return out;
|
||||
return parser.result_get();
|
||||
}
|
||||
|
||||
std::string namespace_mutation(const std::string &str, report_callback report_error)
|
||||
|
||||
@@ -114,6 +114,7 @@ enum class ScopeType : char {
|
||||
Namespace = 'N',
|
||||
Struct = 'S',
|
||||
Function = 'F',
|
||||
LoopArgs = 'l',
|
||||
FunctionArgs = 'f',
|
||||
FunctionCall = 'c',
|
||||
Template = 'T',
|
||||
@@ -125,6 +126,9 @@ enum class ScopeType : char {
|
||||
Local = 'L',
|
||||
/* Added scope inside FunctionArgs. */
|
||||
FunctionArg = 'g',
|
||||
/* Added scope inside LoopArgs. */
|
||||
LoopArg = 'r',
|
||||
|
||||
};
|
||||
|
||||
/* Poor man's IndexRange. */
|
||||
@@ -321,6 +325,7 @@ struct ParserData {
|
||||
token_offsets.offsets.emplace_back(offset);
|
||||
}
|
||||
}
|
||||
offset++;
|
||||
token_offsets.offsets.emplace_back(offset);
|
||||
}
|
||||
{
|
||||
@@ -741,12 +746,39 @@ struct Scope {
|
||||
return start().prev().scope();
|
||||
}
|
||||
|
||||
static Scope invalid()
|
||||
{
|
||||
return {"", "", nullptr, size_t(-1)};
|
||||
}
|
||||
|
||||
bool is_valid() const
|
||||
{
|
||||
return data != nullptr && index >= 0;
|
||||
}
|
||||
bool is_invalid() const
|
||||
{
|
||||
return !is_valid();
|
||||
}
|
||||
|
||||
std::string str() const
|
||||
{
|
||||
if (this->is_invalid()) {
|
||||
return "";
|
||||
}
|
||||
return data->str.substr(start().str_index_start(),
|
||||
end().str_index_last() - start().str_index_start() + 1);
|
||||
}
|
||||
|
||||
/* Return the content without the first and last characters. */
|
||||
std::string str_exclusive() const
|
||||
{
|
||||
if (this->is_invalid()) {
|
||||
return "";
|
||||
}
|
||||
return data->str.substr(start().str_index_start() + 1,
|
||||
end().str_index_last() - start().str_index_start() - 1);
|
||||
}
|
||||
|
||||
Token find_token(const char token_type) const
|
||||
{
|
||||
size_t pos = data->token_types.substr(range().start, range().size).find(token_type);
|
||||
@@ -992,7 +1024,12 @@ inline void ParserData::parse_scopes(report_callback &report_error)
|
||||
break;
|
||||
}
|
||||
case ParOpen:
|
||||
if (scopes.top().type == ScopeType::Global) {
|
||||
if ((tok_id >= 1 && token_types[tok_id - 1] == For) ||
|
||||
(tok_id >= 1 && token_types[tok_id - 1] == While))
|
||||
{
|
||||
enter_scope(ScopeType::LoopArgs, tok_id);
|
||||
}
|
||||
else if (scopes.top().type == ScopeType::Global) {
|
||||
enter_scope(ScopeType::FunctionArgs, tok_id);
|
||||
}
|
||||
else if (scopes.top().type == ScopeType::Struct) {
|
||||
@@ -1043,12 +1080,28 @@ inline void ParserData::parse_scopes(report_callback &report_error)
|
||||
if (scopes.top().type == ScopeType::FunctionArg) {
|
||||
exit_scope(tok_id - 1);
|
||||
}
|
||||
if (scopes.top().type == ScopeType::LoopArg) {
|
||||
exit_scope(tok_id - 1);
|
||||
}
|
||||
exit_scope(tok_id);
|
||||
break;
|
||||
case SquareClose:
|
||||
exit_scope(tok_id);
|
||||
break;
|
||||
case SemiColon:
|
||||
if (scopes.top().type == ScopeType::Assignment) {
|
||||
exit_scope(tok_id - 1);
|
||||
}
|
||||
if (scopes.top().type == ScopeType::FunctionArg) {
|
||||
exit_scope(tok_id - 1);
|
||||
}
|
||||
if (scopes.top().type == ScopeType::TemplateArg) {
|
||||
exit_scope(tok_id - 1);
|
||||
}
|
||||
if (scopes.top().type == ScopeType::LoopArg) {
|
||||
exit_scope(tok_id - 1);
|
||||
}
|
||||
break;
|
||||
case Comma:
|
||||
if (scopes.top().type == ScopeType::Assignment) {
|
||||
exit_scope(tok_id - 1);
|
||||
@@ -1064,6 +1117,9 @@ inline void ParserData::parse_scopes(report_callback &report_error)
|
||||
if (scopes.top().type == ScopeType::FunctionArgs) {
|
||||
enter_scope(ScopeType::FunctionArg, tok_id);
|
||||
}
|
||||
if (scopes.top().type == ScopeType::LoopArgs) {
|
||||
enter_scope(ScopeType::LoopArg, tok_id);
|
||||
}
|
||||
if (scopes.top().type == ScopeType::Template) {
|
||||
enter_scope(ScopeType::TemplateArg, tok_id);
|
||||
}
|
||||
@@ -1221,8 +1277,14 @@ struct Parser {
|
||||
}
|
||||
/* Replace everything from `from` to `to` (inclusive).
|
||||
* Return true on success. */
|
||||
bool replace_try(Token from, Token to, const std::string &replacement)
|
||||
bool replace_try(Token from,
|
||||
Token to,
|
||||
const std::string &replacement,
|
||||
bool keep_trailing_whitespaces = false)
|
||||
{
|
||||
if (keep_trailing_whitespaces) {
|
||||
return replace_try(from.str_index_start(), to.str_index_last_no_whitespace(), replacement);
|
||||
}
|
||||
return replace_try(from.str_index_start(), to.str_index_last(), replacement);
|
||||
}
|
||||
|
||||
@@ -1307,6 +1369,10 @@ struct Parser {
|
||||
{
|
||||
insert_after(at, "#line " + std::to_string(line) + "\n");
|
||||
}
|
||||
void insert_line_number(Token at, int line)
|
||||
{
|
||||
insert_line_number(at.str_index_last(), line);
|
||||
}
|
||||
|
||||
void insert_before(size_t at, const std::string &content)
|
||||
{
|
||||
|
||||
@@ -87,92 +87,112 @@ static void test_preprocess_unroll()
|
||||
using namespace std;
|
||||
|
||||
{
|
||||
string input = R"([[gpu::unroll]] for (int i = 2; i < 4; i++, y++) { content += i; })";
|
||||
string expect = R"({ int i = 2;
|
||||
#line 1
|
||||
{ content += i; }
|
||||
#line 1
|
||||
i++, y++;
|
||||
#line 1
|
||||
{ content += i; }
|
||||
#line 1
|
||||
i++, y++;
|
||||
#line 1
|
||||
})";
|
||||
string input = R"(
|
||||
[[gpu::unroll]] for (int i = 2; i < 4; i++, y++) { content += i; })";
|
||||
string expect = R"(
|
||||
|
||||
#line 2
|
||||
{int i = 2;
|
||||
#line 2
|
||||
{ content += i; }
|
||||
#line 2
|
||||
i++, y++;
|
||||
#line 2
|
||||
{ content += i; }
|
||||
#line 2
|
||||
i++, y++;
|
||||
#line 2
|
||||
})";
|
||||
string error;
|
||||
string output = process_test_string(input, error);
|
||||
EXPECT_EQ(output, expect);
|
||||
EXPECT_EQ(error, "");
|
||||
}
|
||||
{
|
||||
string input = R"([[gpu::unroll]] for (int i = 2; i < 4 && i < y; i++, y++) { cont += i; })";
|
||||
string expect = R"({ int i = 2;
|
||||
#line 1
|
||||
if (i < y) { cont += i; }
|
||||
#line 1
|
||||
i++, y++;
|
||||
#line 1
|
||||
if (i < y) { cont += i; }
|
||||
#line 1
|
||||
i++, y++;
|
||||
#line 1
|
||||
})";
|
||||
string input = R"(
|
||||
[[gpu::unroll]] for (int i = 2; i < 4 && i < y; i++, y++) { cont += i; })";
|
||||
string expect = R"(
|
||||
|
||||
#line 2
|
||||
{int i = 2;
|
||||
#line 2
|
||||
if(i < 4 && i < y)
|
||||
#line 2
|
||||
{ cont += i; }
|
||||
#line 2
|
||||
i++, y++;
|
||||
#line 2
|
||||
if(i < 4 && i < y)
|
||||
#line 2
|
||||
{ cont += i; }
|
||||
#line 2
|
||||
i++, y++;
|
||||
#line 2
|
||||
})";
|
||||
string error;
|
||||
string output = process_test_string(input, error);
|
||||
EXPECT_EQ(output, expect);
|
||||
EXPECT_EQ(error, "");
|
||||
}
|
||||
{
|
||||
string input = R"([[gpu::unroll(2)]] for (; i < j;) { content += i; })";
|
||||
string expect = R"({ ;
|
||||
#line 1
|
||||
if (i < j) { content += i; }
|
||||
#line 1
|
||||
;
|
||||
#line 1
|
||||
if (i < j) { content += i; }
|
||||
#line 1
|
||||
;
|
||||
#line 1
|
||||
})";
|
||||
string input = R"(
|
||||
[[gpu::unroll(2)]] for (; i < j;) { content += i; })";
|
||||
string expect = R"(
|
||||
|
||||
{
|
||||
#line 2
|
||||
if(i < j)
|
||||
#line 2
|
||||
{ content += i; }
|
||||
#line 2
|
||||
if(i < j)
|
||||
#line 2
|
||||
{ content += i; }
|
||||
#line 2
|
||||
})";
|
||||
string error;
|
||||
string output = process_test_string(input, error);
|
||||
EXPECT_EQ(output, expect);
|
||||
EXPECT_EQ(error, "");
|
||||
}
|
||||
{
|
||||
string input = R"([[gpu::unroll(2)]] for (; i < j;) { [[gpu::unroll(2)]] for (; j < k;) {} })";
|
||||
string expect = R"({ ;
|
||||
#line 1
|
||||
if (i < j) { { ;
|
||||
#line 1
|
||||
if (j < k) {}
|
||||
#line 1
|
||||
;
|
||||
#line 1
|
||||
if (j < k) {}
|
||||
#line 1
|
||||
;
|
||||
#line 1
|
||||
} }
|
||||
#line 1
|
||||
;
|
||||
#line 1
|
||||
if (i < j) { { ;
|
||||
#line 1
|
||||
if (j < k) {}
|
||||
#line 1
|
||||
;
|
||||
#line 1
|
||||
if (j < k) {}
|
||||
#line 1
|
||||
;
|
||||
#line 1
|
||||
} }
|
||||
#line 1
|
||||
;
|
||||
#line 1
|
||||
})";
|
||||
string input = R"(
|
||||
[[gpu::unroll(2)]] for (; i < j;) { [[gpu::unroll(2)]] for (; j < k;) {} })";
|
||||
string expect = R"(
|
||||
|
||||
{
|
||||
#line 2
|
||||
if(i < j)
|
||||
#line 2
|
||||
{
|
||||
{
|
||||
#line 2
|
||||
if(j < k)
|
||||
#line 2
|
||||
{}
|
||||
#line 2
|
||||
if(j < k)
|
||||
#line 2
|
||||
{}
|
||||
#line 2
|
||||
} }
|
||||
#line 2
|
||||
if(i < j)
|
||||
#line 2
|
||||
{
|
||||
{
|
||||
#line 2
|
||||
if(j < k)
|
||||
#line 2
|
||||
{}
|
||||
#line 2
|
||||
if(j < k)
|
||||
#line 2
|
||||
{}
|
||||
#line 2
|
||||
} }
|
||||
#line 2
|
||||
})";
|
||||
string error;
|
||||
string output = process_test_string(input, error);
|
||||
EXPECT_EQ(output, expect);
|
||||
@@ -182,27 +202,30 @@ if (i < j) { { ;
|
||||
string input = R"([[gpu::unroll(2)]] for (; i < j;) { break; })";
|
||||
string error;
|
||||
string output = process_test_string(input, error);
|
||||
EXPECT_EQ(error, "Error: Unrolled loop cannot contain \"break\" statement.");
|
||||
EXPECT_EQ(error, "Unrolled loop cannot contain \"break\" statement.");
|
||||
}
|
||||
{
|
||||
string input = R"([[gpu::unroll(2)]] for (; i < j;) { continue; })";
|
||||
string error;
|
||||
string output = process_test_string(input, error);
|
||||
EXPECT_EQ(error, "Error: Unrolled loop cannot contain \"continue\" statement.");
|
||||
EXPECT_EQ(error, "Unrolled loop cannot contain \"continue\" statement.");
|
||||
}
|
||||
{
|
||||
string input = R"([[gpu::unroll(2)]] for (; i < j;) { for (; j < k;) {break;continue;} })";
|
||||
string expect = R"({ ;
|
||||
#line 1
|
||||
if (i < j) { for (; j < k;) {break;continue;} }
|
||||
#line 1
|
||||
;
|
||||
#line 1
|
||||
if (i < j) { for (; j < k;) {break;continue;} }
|
||||
#line 1
|
||||
;
|
||||
#line 1
|
||||
})";
|
||||
string input = R"(
|
||||
[[gpu::unroll(2)]] for (; i < j;) { for (; j < k;) {break;continue;} })";
|
||||
string expect = R"(
|
||||
|
||||
{
|
||||
#line 2
|
||||
if(i < j)
|
||||
#line 2
|
||||
{ for (; j < k;) {break;continue;} }
|
||||
#line 2
|
||||
if(i < j)
|
||||
#line 2
|
||||
{ for (; j < k;) {break;continue;} }
|
||||
#line 2
|
||||
})";
|
||||
string error;
|
||||
string output = process_test_string(input, error);
|
||||
EXPECT_EQ(output, expect);
|
||||
@@ -212,7 +235,7 @@ if (i < j) { for (; j < k;) {break;continue;} }
|
||||
string input = R"([[gpu::unroll]] for (int i = 3; i > 2; i++) {})";
|
||||
string error;
|
||||
string output = process_test_string(input, error);
|
||||
EXPECT_EQ(error, "Error: Unsupported condition in unrolled loop.");
|
||||
EXPECT_EQ(error, "Unsupported condition in unrolled loop.");
|
||||
}
|
||||
}
|
||||
GPU_TEST(preprocess_unroll);
|
||||
|
||||
Reference in New Issue
Block a user