GPU: Shader Preprocess: Add basic support for struct methods

This adds the following features:
- `class` keyword support: checked by C++, mutated to struct for shader.
- `private` and `public` keywords: checked by C++, removed for shader.
- `static` methods.
- `const` and non-const methods.

What is not supported:
- Constructors
- Destructors
- Operators
- Method definition outside of class definition
- Member reference without the `this` keyword.

This is implemented using a very simple lexer/parser allowing semantic traversal.

Pull Request: https://projects.blender.org/blender/blender/pulls/144025
This commit is contained in:
Clément Foucault
2025-08-08 16:49:15 +02:00
committed by Clément Foucault
parent 55942f5fbe
commit 628a10a9fb
10 changed files with 1468 additions and 65 deletions

View File

@@ -15,10 +15,11 @@ void main()
* or tweak the levels of the matte. */
bool is_edge = false;
#if defined(COMPUTE_EDGES)
if (true) {
bool compute_edges = true;
#else
if (black_level != 0.0f || white_level != 1.0f) {
bool compute_edges = black_level != 0.0f || white_level != 1.0f;
#endif
if (compute_edges) {
/* Count the number of neighbors whose matte is sufficiently similar to the current matte,
* as controlled by the edge_tolerance factor. */
int count = 0;

View File

@@ -75,27 +75,9 @@ struct DofGatherData {
float transparency;
float layer_opacity;
#if defined(GPU_METAL) || defined(GLSL_CPP_STUBS)
/* Explicit constructors -- To support GLSL syntax. */
inline DofGatherData() = default;
inline DofGatherData(float4 in_color,
float in_weight,
float in_dist,
float in_coc,
float in_coc_sqr,
float in_transparency,
float in_layer_opacity)
: color(in_color),
weight(in_weight),
dist(in_dist),
coc(in_coc),
coc_sqr(in_coc_sqr),
transparency(in_transparency),
layer_opacity(in_layer_opacity)
{
}
#endif
/* clang-format off */
METAL_CONSTRUCTOR_7(DofGatherData, float4, color, float, weight, float, dist, float, coc, float, coc_sqr, float, transparency, float, layer_opacity)
/* clang-format on */
};
#define GATHER_DATA_INIT DofGatherData(float4(0.0f), 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f)

View File

@@ -17,10 +17,7 @@ struct FilterSample {
float4 color;
float weight;
#if defined(GPU_METAL) || defined(GLSL_CPP_STUBS)
inline FilterSample() = default;
inline FilterSample(float4 in_color, float in_weight) : color(in_color), weight(in_weight) {}
#endif
METAL_CONSTRUCTOR_2(FilterSample, float4, color, float, weight)
};
/* -------------------------------------------------------------------- */

View File

@@ -30,11 +30,7 @@ struct DofSample {
float4 color;
float coc;
#if defined(GPU_METAL) || defined(GLSL_CPP_STUBS)
/* Explicit constructors -- To support GLSL syntax. */
inline DofSample() = default;
inline DofSample(float4 in_color, float in_coc) : color(in_color), coc(in_coc) {}
#endif
METAL_CONSTRUCTOR_2(DofSample, float4, color, float, coc)
};
/* -------------------------------------------------------------------- */
@@ -160,11 +156,7 @@ struct DofNeighborhoodMinMax {
DofSample min;
DofSample max;
#if defined(GPU_METAL) || defined(GLSL_CPP_STUBS)
/* Explicit constructors -- To support GLSL syntax. */
inline DofNeighborhoodMinMax() = default;
inline DofNeighborhoodMinMax(DofSample in_min, DofSample in_max) : min(in_min), max(in_max) {}
#endif
METAL_CONSTRUCTOR_2(DofNeighborhoodMinMax, DofSample, min, DofSample, max)
};
/* Return history clipping bounding box in YCoCg color space. */

View File

@@ -8,13 +8,17 @@
#pragma once
#include <cctype>
#include <cstdint>
#include <functional>
#include <iostream>
#include <regex>
#include <sstream>
#include <string>
#include <vector>
#include "shader_parser.hh"
namespace blender::gpu::shader {
/* Metadata extracted from shader source file.
@@ -214,6 +218,8 @@ class Preprocessor {
str = preprocessor_directive_mutation(str);
str = swizzle_function_mutation(str);
if (language == BLENDER_GLSL) {
str = struct_method_mutation(str, report_error);
str = method_call_mutation(str, report_error);
str = stage_function_mutation(str);
str = resource_guard_mutation(str, report_error);
str = loop_unroll(str, report_error);
@@ -235,9 +241,9 @@ class Preprocessor {
str = namespace_mutation(str, report_error);
str = namespace_separator_mutation(str);
}
str = argument_reference_mutation(str);
str = enum_macro_injection(str);
str = default_argument_mutation(str);
str = argument_reference_mutation(str);
str = variable_reference_mutation(str, report_error);
str = template_definition_mutation(str, report_error);
str = template_call_mutation(str);
@@ -263,6 +269,7 @@ class Preprocessor {
private:
using regex_callback = std::function<void(const std::smatch &)>;
using regex_callback_with_line_count = std::function<void(const std::smatch &, int64_t)>;
/* Helper to make the code more readable in parsing functions. */
void regex_global_search(const std::string &str,
@@ -276,6 +283,19 @@ class Preprocessor {
}
}
void regex_global_search(const std::string &str,
                         const std::regex &regex,
                         regex_callback_with_line_count callback)
{
  /* Running line counter. Starts at 1 and is advanced using `line_count()` (defined elsewhere in
   * this class, presumably counting newlines). */
  int64_t current_line = 1;

  auto forward_with_line = [&current_line, &callback](const std::smatch &match) {
    /* Account for the text skipped before this match. */
    current_line += line_count(match.prefix().str());
    /* Hand over the match together with the line it was found at. */
    callback(match, current_line);
    /* Account for the text covered by the match itself. */
    current_line += line_count(match[0].str());
  };

  /* Delegate the iteration to the overload that only provides the match. */
  regex_global_search(str, regex, forward_with_line);
}
template<typename ReportErrorF>
std::string remove_comments(const std::string &str, const ReportErrorF &report_error)
{
@@ -787,7 +807,7 @@ class Preprocessor {
out_str + '}', '{', '}', true);
if (parent_scope.empty()) {
report_error(match, "The `using` keyword is not allowed in global scope.");
break;
return str;
}
/* Ensure we are bringing symbols from the same namespace.
* Otherwise we can have different shadowing outcome between shader and C++. */
@@ -795,7 +815,7 @@ class Preprocessor {
size_t pos = out_str.rfind(ns_keyword, out_str.size() - parent_scope.size());
if (pos == string::npos) {
report_error(match, "Couldn't find `namespace` keyword at beginning of scope.");
break;
return str;
}
size_t start = pos + ns_keyword.size();
size_t end = out_str.size() - parent_scope.size() - start - 2;
@@ -806,7 +826,7 @@ class Preprocessor {
"The `using` keyword is only allowed in namespace scope to make visible symbols "
"from the same namespace declared in another scope, potentially from another "
"file.");
break;
return str;
}
}
/** IMPORTANT: `match` is invalid after the assignment. */
@@ -1037,6 +1057,178 @@ class Preprocessor {
return str;
}
/* Move all method definitions outside of struct definition blocks.
 *
 * The source is scanned with a `shader::parser::Parser` in two passes:
 * - First, every `class` keyword token is replaced by `struct `.
 * - Then, for every `struct <name> {` found in the global scope, the access specifier lines are
 *   erased and each method (`<type> <name>(` inside the struct) is erased from the struct and
 *   re-inserted, rewritten, after the end of the struct definition.
 *
 * Static methods get their name prefixed with `<struct>::`. Other methods receive an extra first
 * argument named `this` (`const <struct> this` for const methods, `<struct> &this` otherwise).
 * Inside the moved methods `*this` becomes `this` and `this->` becomes `this.`.
 *
 * Unsupported syntax is reported through `report_error` (with an empty `smatch`) and the
 * offending struct or method is left untouched.
 *
 * \return the mutated source string. */
std::string struct_method_mutation(const std::string &str, report_callback report_error)
{
using namespace std;
using namespace shader::parser;
Parser parser(str);
/* First pass: the pattern "S" matches a single `Class` keyword token. */
parser.foreach_scope(ScopeType::Global, [&](Scope scope) {
/* `class` -> `struct` */
scope.foreach_match("S", [&](const std::vector<Token> &tokens) {
parser.replace(tokens[0], tokens[0], "struct ");
});
});
/* NOTE(review): `apply_mutations()` is defined outside of this view. Presumably it applies the
 * recorded replacements and re-parses the result so that the pass below sees `struct` tokens --
 * confirm. */
parser.apply_mutations();
/* Second pass: the pattern "sw" matches the `struct` keyword followed by a word (the name). */
parser.foreach_scope(ScopeType::Global, [&](Scope scope) {
scope.foreach_match("sw", [&](const std::vector<Token> &tokens) {
const Token struct_name = tokens[1];
/* Only plain `struct <name> {` definitions are processed. */
if (struct_name.next() == ':') {
/* TODO(fclem): Good report. */
report_error(smatch(), "class inheritance is not supported");
return;
}
if (struct_name.next() == '<') {
/* TODO(fclem): Good report. */
report_error(smatch(), "class template is not supported");
return;
}
if (struct_name.next() != '{') {
/* TODO(fclem): Good report. */
report_error(smatch(), "Expected `{`");
return;
}
/* Scope opened by the `{` following the name, and the token right after its closing `}`. */
const Scope struct_scope = struct_name.next().scope();
const Token struct_end = struct_scope.end().next();
/* Erase `public:` and `private:` keywords. */
/* "v:" is the `private` keyword followed by a colon, "V:" is the same for `public`. The whole
 * line(s), from the start of the keyword line to the end of the colon line, are erased. */
struct_scope.foreach_match("v:", [&](const std::vector<Token> &tokens) {
parser.erase(tokens[0].line_start(), tokens[1].line_end());
});
struct_scope.foreach_match("V:", [&](const std::vector<Token> &tokens) {
parser.erase(tokens[0].line_start(), tokens[1].line_end());
});
/* "ww(" matches two words followed by an opening parenthesis: return type, method name and the
 * start of the argument list. */
struct_scope.foreach_match("ww(", [&](const std::vector<Token> &tokens) {
if (tokens[0].prev() == Const) {
/* TODO(fclem): Good report. */
report_error(smatch(),
"function return type is marked `const` but it makes no sense for values "
"and returning reference is not supported");
return;
}
/* A `static` keyword right before the return type marks a static method and becomes the first
 * token of the function. */
const bool is_static = tokens[0].prev() == Static;
const Token fn_start = is_static ? tokens[0].prev() : tokens[0];
const Scope fn_args = tokens[2].scope();
/* A `const` keyword right after the argument list marks a const method. */
const Token after_args = fn_args.end().next();
const bool is_const = after_args == Const;
/* The body is the scope of the token following the argument list (and the optional `const`). */
const Scope fn_body = (is_const ? after_args.next() : after_args).scope();
/* Copy whole lines: from the start of the first line of the function to one character past the
 * end of the last line of the body (`line_end()` excludes the newline character). */
string fn_content = parser.substr_range_inclusive(fn_start.line_start(),
fn_body.end().line_end() + 1);
/* The extracted function is mutated in isolation using its own parser. */
Parser fn_parser(fn_content);
fn_parser.foreach_scope(ScopeType::Global, [&](Scope scope) {
if (is_static) {
/* "mww(": `static`, return type, name, `(`. Prefix the name with `<struct>::`. */
scope.foreach_match("mww(", [&](const std::vector<Token> &tokens) {
const Token fn_name = tokens[2];
fn_parser.replace(
fn_name, fn_name, struct_name.str_no_whitespace() + "::" + fn_name.str());
});
}
else {
scope.foreach_match("ww(", [&](const std::vector<Token> &tokens) {
const Scope args = tokens[2].scope();
/* Only `(` and `)` in the argument scope: no separator needed after the added argument. */
const bool has_no_args = args.token_count() == 2;
const char *suffix = (has_no_args ? "" : ", ");
if (is_const) {
/* Erase the token following the argument list (the `const` qualifier) and pass `this` by
 * const value. */
fn_parser.erase(args.end().next());
fn_parser.insert_after(
args.start(), "const " + struct_name.str_no_whitespace() + " this" + suffix);
}
else {
/* Non-const methods receive `this` as a reference. */
fn_parser.insert_after(args.start(),
struct_name.str_no_whitespace() + " &this" + suffix);
}
});
}
/* `*this` -> `this` */
scope.foreach_match("*T", [&](const std::vector<Token> &tokens) {
fn_parser.replace(tokens[0], tokens[1], tokens[1].str());
});
/* `this->` -> `this.` */
scope.foreach_match("TD", [&](const std::vector<Token> &tokens) {
fn_parser.replace(tokens[0], tokens[1], tokens[0].str() + ".");
});
});
/* Make the moved function report the line numbers of its original location. */
string line_directive = "#line " + std::to_string(fn_start.line_number()) + '\n';
/* Blank the method inside the struct and re-insert the mutated copy after the line holding the
 * token that follows the closing `}` of the struct. */
parser.erase(fn_start.line_start(), fn_body.end().line_end());
parser.insert_after(struct_end.line_end() + 1, line_directive + fn_parser.result_get());
});
/* Line directive for the code that follows the struct definition.
 * NOTE(review): it is inserted at the same position as the moved methods; the resulting order
 * depends on how `Parser` orders insertions at equal positions (not visible here) -- confirm. */
string line_directive = "#line " + std::to_string(struct_end.line_number() + 1) + '\n';
parser.insert_after(struct_end.line_end() + 1, line_directive);
});
});
return parser.result_get();
}
/* Transform `a.fn(b)` into `fn(a, b)`.
 *
 * Every `.<word>(` sequence found inside a function scope is treated as a method call. The
 * expression in front of the dot is moved inside the parentheses as first argument.
 * Replacements that overlap an already recorded one are refused by `replace_try()`; the whole
 * scan is repeated for as long as `apply_mutations()` returns true so that these get another
 * chance on the next pass.
 *
 * \return the mutated source string. */
std::string method_call_mutation(const std::string &str, report_callback report_error)
{
using namespace std;
using namespace shader::parser;
Parser parser(str);
do {
parser.foreach_scope(ScopeType::Function, [&](Scope scope) {
/* ".w(": a dot, a word (the method name) and an opening parenthesis. */
scope.foreach_match(".w(", [&](const std::vector<Token> &tokens) {
const Token dot = tokens[0];
const Token func = tokens[1];
const Token par_open = tokens[2];
/* Last token of the expression the method is called on. */
const Token end_of_this = dot.prev();
/* Walk backwards to find the first token of that expression. */
Token start_of_this = end_of_this;
while (true) {
if (start_of_this == ')') {
/* Function call. Take argument scope and function name. No recursion. */
start_of_this = start_of_this.scope().start().prev();
break;
}
if (start_of_this == ']') {
/* Array subscript. Take scope and continue. */
start_of_this = start_of_this.scope().start().prev();
continue;
}
if (start_of_this == Word) {
/* Member. */
if (start_of_this.prev() == '.') {
start_of_this = start_of_this.prev().prev();
/* Continue until we find root member. */
continue;
}
/* End of chain. */
break;
}
/* Any other token type is unexpected. The error is reported and the walk stops, but the
 * replacement below is still attempted with the current `start_of_this`. */
std::string error = "method_call_mutation parsing error : " + start_of_this.str() +
to_string(start_of_this.type());
report_error(smatch(), error.c_str());
break;
}
string this_str = parser.substr_range_inclusive(start_of_this, end_of_this);
string func_str = func.str();
/* No separator is needed if the call had no argument. */
const bool has_no_arg = par_open.next() == ')';
/* `a.fn(b)` -> `fn(a, b)` */
/* The return value is deliberately ignored: a refused (overlapping) replacement is left for a
 * later iteration of the enclosing loop. */
parser.replace_try(
start_of_this, par_open, func_str + "(" + this_str + (has_no_arg ? "" : ", "));
});
});
/* NOTE(review): `apply_mutations()` is defined outside of this view; presumably it returns
 * whether any mutation was applied -- confirm. */
} while (parser.apply_mutations());
return parser.result_get();
}
std::string stage_function_mutation(const std::string &str)
{
using namespace std;
@@ -1239,7 +1431,7 @@ class Preprocessor {
const bool has_non_void_return_type = return_type != "void";
string line_directive = "#line " + to_string(line - lines_in_content + 2) + "\n";
string line_directive = "#line " + std::to_string(line - lines_in_content + 2) + "\n";
vector<string> args_split = split_string_not_between_balanced_pair(args, ',', '(', ')');
string overloads;
@@ -1287,7 +1479,7 @@ class Preprocessor {
"}\n";
string last_line_directive =
"#line " + to_string(line - lines_in_content + line_count(body_content) + 3) + "\n";
"#line " + std::to_string(line - lines_in_content + line_count(body_content) + 3) + "\n";
mutations.emplace_back(with_default + body_content,
no_default + body_content + overloads + last_line_directive);
@@ -1318,29 +1510,29 @@ class Preprocessor {
/* To be run before `argument_decorator_macro_injection()`. */
std::string argument_reference_mutation(std::string &str)
{
/* Next two REGEX checks are expensive. Check if they are needed at all. */
bool valid_match = false;
reference_search(str, [&](int parenthesis_depth, int bracket_depth, char &c) {
/* Check if inside a function signature.
* Check parenthesis_depth == 2 for array references. */
if ((parenthesis_depth == 1 || parenthesis_depth == 2) && bracket_depth == 0) {
valid_match = true;
/* Modify the & into @ to make sure we only match these references in the regex
* below. @ being forbidden in the shader language, it is safe to use a temp
* character. */
c = '@';
using namespace std;
using namespace shader::parser;
Parser parser(str);
auto add_mutation = [&](Token type, Token arg_name, Token last_tok) {
if (type.prev() == Const) {
parser.replace(type.prev(), last_tok, type.str() + arg_name.str());
}
else {
parser.replace(type, last_tok, "inout " + type.str() + arg_name.str());
}
};
parser.foreach_scope(ScopeType::FunctionArgs, [&](const Scope scope) {
scope.foreach_match(
"w(&w)", [&](const vector<Token> toks) { add_mutation(toks[0], toks[3], toks[4]); });
scope.foreach_match(
"w&w", [&](const vector<Token> toks) { add_mutation(toks[0], toks[2], toks[2]); });
scope.foreach_match(
"w&T", [&](const vector<Token> toks) { add_mutation(toks[0], toks[2], toks[2]); });
});
if (!valid_match) {
return str;
}
/* Remove parenthesis first. */
/* Example: `float (&var)[2]` > `float &var[2]` */
std::regex regex_parenthesis(R"((\w+ )\(@(\w+)\))");
std::string out = std::regex_replace(str, regex_parenthesis, "$1@$2");
/* Example: `const float &var[2]` > `inout float var[2]` */
std::regex regex(R"((?:const)?(\s*)(\w+)\s+\@(\w+)(\[\d*\])?)");
return std::regex_replace(out, regex, "$1 inout $2 $3$4");
return parser.result_get();
}
/* To be run after `argument_reference_mutation()`. */

View File

@@ -0,0 +1,1006 @@
/* SPDX-FileCopyrightText: 2025 Blender Authors
*
* SPDX-License-Identifier: GPL-2.0-or-later */
/** \file
* \ingroup glsl_preprocess
*
* Very simple parsing of our shader files, which are written in a subset of C++. It allows
* traversing the semantics using tokens and scopes instead of trying to match string patterns
* throughout the whole input string.
*
* The goal of this representation is to output code that doesn't modify the style of the input
* string and keep the same line numbers (to match compilation error with input source).
*
* The `Parser` class contains a copy of the given string to apply string substitutions (called
* `Mutation`). It is usually faster to record all of them and apply them all at once after
* scanning through the whole semantic representation. In the rare case where mutations need to
* overlap (recursive processing), it is better to do them in passes until there is no mutation
* left to do.
*
* `Token` and `Scope` are read only interfaces to the data stored inside the `ParserData`.
* The data is stored as SoA (Structure of Arrays) for fast traversal.
* The types of token and scopes are defined as readable chars to easily create sequences of token
* type.
*
* The `Parser` object needs to be fed a well-formed source (without preprocessor directives, see
* below), otherwise a crash can occur. The `Parser` doesn't apply any preprocessing. All
* preprocessor directives are parsed as `Preprocessor` scopes but they are not expanded.
*
* By default, whitespaces are merged with the previous token. Only a handful of processing steps
* require access to whitespaces as individual tokens.
*/
#pragma once
#include <cassert>
#include <cctype>
#include <chrono>
#include <cstdint>
#include <functional>
#include <iostream>
#include <stack>
#include <string>
#include <vector>
namespace blender::gpu::shader::parser {
/* Type of a token. Each type is stored as a single char inside `ParserData::token_types`, which
 * allows describing sequences of tokens as plain strings (e.g. "ww(" for two words followed by an
 * opening parenthesis). Every enumerator must therefore map to a unique character. */
enum TokenType : char {
  /* Use ascii chars to store them in string, and for easy debugging / testing. */
  Word = 'w',
  NewLine = '\n',
  Space = ' ',
  Dot = '.',
  Hash = '#',
  Ampersand = '&',
  Literal = '0',
  ParOpen = '(',
  ParClose = ')',
  BracketOpen = '{',
  BracketClose = '}',
  SquareOpen = '[',
  SquareClose = ']',
  AngleOpen = '<',
  AngleClose = '>',
  Assign = '=',
  SemiColon = ';',
  Question = '?',
  Not = '!',
  Colon = ':',
  Comma = ',',
  Star = '*',
  Plus = '+',
  Minus = '-',
  Divide = '/',
  Tilde = '~',
  Backslash = '\\',
  /* Keywords */
  Namespace = 'n',
  Struct = 's',
  Class = 'S',
  Const = 'c',
  Constexpr = 'C',
  Return = 'r',
  Switch = 'h',
  Case = 'H',
  If = 'i',
  Else = 'I',
  For = 'f',
  While = 'F',
  Do = 'd',
  Template = 't',
  This = 'T',
  /* The `->` operator. */
  Deref = 'D',
  Static = 'm',
  /* A backslash directly followed by a newline. */
  PreprocessorNewline = 'N',
  Equal = 'E',
  NotEqual = 'e',
  GEqual = 'G',
  LEqual = 'L',
  Increment = 'P',
  /* Must differ from `Deref`, otherwise a pattern such as "TD" (`this->`) cannot tell the two
   * apart. */
  Decrement = 'p',
  Private = 'v',
  Public = 'V',
};
/* Type of a scope, as assigned by `ParserData::parse_scopes()`. The types are stored as chars
 * inside `ParserData::scope_types`. */
enum class ScopeType : char {
/* Use ascii chars to store them in string, and for easy debugging / testing. */
/* Root scope. Opened on the first token and closed after the last one. */
Global = 'G',
/* `{` whose second previous token is the `namespace` keyword. */
Namespace = 'N',
/* Intended for the body of a `struct` definition. */
Struct = 'S',
/* `{` opened directly inside a `Global` or `Struct` scope. */
Function = 'F',
/* `(` opened directly inside a `Global` or `Struct` scope. */
FunctionArgs = 'f',
/* `<` directly following the `template` keyword, up to the closing `>`. */
Template = 'T',
/* `[` up to the matching `]`. */
Subscript = 'A',
/* From a `#` up to the next newline token. */
Preprocessor = 'P',
/* From a `=` up to the token before the next `;`, `,`, `=`, closing `)` or `}` (or closing `>`
 * while inside a template). */
Assignment = 'a',
/* Added scope inside function body. */
/* Also the fallback type for any `{` or `(` not classified as one of the other types. */
Local = 'L',
};
/* Poor man's IndexRange: `size` consecutive indices starting at `start`. */
struct IndexRange {
  size_t start;
  size_t size;

  IndexRange(size_t start, size_t size) : start(start), size(size) {}

  /* Return true if both ranges have at least one index in common, or if one of them is empty and
   * located strictly inside the other one.
   * Two non-empty ranges starting at the same index do overlap. An empty range located at the
   * start or at the end of another range does not overlap it. */
  bool overlaps(IndexRange other) const
  {
    return (start < (other.start + other.size)) && (other.start < (start + size));
  }

  /* Last index contained in the range. Only meaningful for non-empty ranges. */
  size_t last() const
  {
    return start + size - 1;
  }
};
/* Poor man's OffsetIndices: element `i` spans from `offsets[i]` up to (excluding)
 * `offsets[i + 1]`. */
struct OffsetIndices {
  std::vector<size_t> offsets;

  /* Range covered by the element at `index`. */
  IndexRange operator[](const int64_t index) const
  {
    const size_t begin = offsets[index];
    const size_t end = offsets[index + 1];
    return IndexRange(begin, end - begin);
  }

  /* Remove all offsets. */
  void clear()
  {
    offsets.clear();
  }
};
struct Scope;
struct ParserData {
std::string str;
std::string token_types;
std::string scope_types;
/* Ranges of characters per token. */
OffsetIndices token_offsets;
/* Index of bottom most scope per token. */
std::vector<int> token_scope;
/* Range of token per scope. */
std::vector<IndexRange> scope_ranges;
/* If keep_whitespace is false, whitespaces are merged with the previous token. */
void tokenize(const bool keep_whitespace)
{
{
/* Tokenization. */
token_types.clear();
token_offsets.clear();
token_types += char(to_type(str[0]));
token_offsets.offsets.emplace_back(0);
/* When doing whitespace merging, keep knowledge about whether previous char was whitespace.
* This allows to still split words on spaces. */
bool prev_was_whitespace = (token_types[0] == NewLine || token_types[0] == Space);
bool inside_preprocessor_directive = false;
int offset = 0;
for (const char &c : str.substr(1)) {
offset++;
TokenType type = to_type(c);
TokenType prev = TokenType(token_types.back());
/* Detect preprocessor directive newlines `\\\n`. */
if (prev == Backslash && type == NewLine) {
token_types.back() = PreprocessorNewline;
continue;
}
/* Make sure to keep the ending newline for a preprocessor directive. */
if (inside_preprocessor_directive && type == NewLine) {
inside_preprocessor_directive = false;
token_types += char(type);
token_offsets.offsets.emplace_back(offset);
continue;
}
if (type == Hash) {
inside_preprocessor_directive = true;
}
/* Merge newlines and spaces with previous token. */
if (!keep_whitespace && (type == NewLine || type == Space)) {
prev_was_whitespace = true;
continue;
}
/* Merge '=='. */
if (prev == Assign && type == Assign) {
token_types.back() = Equal;
continue;
}
/* Merge '!='. */
if (prev == '!' && type == Assign) {
token_types.back() = NotEqual;
continue;
}
/* Merge '>='. */
if (prev == '>' && type == Assign) {
token_types.back() = GEqual;
continue;
}
/* Merge '<='. */
if (prev == '<' && type == Assign) {
token_types.back() = LEqual;
continue;
}
/* Merge '->'. */
if (prev == '-' && type == '>') {
token_types.back() = Deref;
continue;
}
/* If digit is part of word. */
if (type == Literal && prev == Word) {
continue;
}
/* If 'x' is part of hex literal. */
if (c == 'x' && prev == Literal) {
continue;
}
/* If 'A-F' is part of hex literal. */
if (c >= 'A' && c <= 'F' && prev == Literal) {
continue;
}
/* If 'a-f' is part of hex literal. */
if (c >= 'a' && c <= 'f' && prev == Literal) {
continue;
}
/* If 'u' is part of unsigned int literal. */
if (c == 'u' && prev == Literal) {
continue;
}
/* If dot is part of float literal. */
if (type == Dot && prev == Literal) {
continue;
}
/* If 'f' suffix is part of float literal. */
if (c == 'f' && prev == Literal) {
continue;
}
/* If 'e' is part of float literal. */
if (c == 'e' && prev == Literal) {
continue;
}
/* If sign is part of float literal after exponent. */
if ((c == '+' || c == '-') && prev == Literal) {
continue;
}
/* Detect increment. */
if (type == '+' && prev == '+') {
token_types.back() = Increment;
continue;
}
/* Detect decrement. */
if (type == '+' && prev == '+') {
token_types.back() = Decrement;
continue;
}
/* Only merge these token. Otherwise, always emit a token. */
if (type != Word && type != NewLine && type != Space && type != Literal) {
prev = Word;
}
/* Split words on whitespaces even when merging. */
if (!keep_whitespace && type == Word && prev_was_whitespace) {
prev = Space;
prev_was_whitespace = false;
}
/* Emit a token if we don't merge. */
if (type != prev) {
token_types += char(type);
token_offsets.offsets.emplace_back(offset);
}
}
}
{
/* Keywords detection. */
int tok_id = -1;
for (char &c : token_types) {
tok_id++;
if (TokenType(c) == Word) {
IndexRange range = token_offsets[tok_id];
std::string word = str.substr(range.start, range.size);
if (!keep_whitespace) {
size_t last_non_whitespace = word.find_last_not_of(" \n");
if (last_non_whitespace != std::string::npos) {
word = word.substr(0, last_non_whitespace + 1);
}
}
if (word == "namespace") {
c = Namespace;
}
else if (word == "struct") {
c = Struct;
}
else if (word == "class") {
c = Class;
}
else if (word == "const") {
c = Const;
}
else if (word == "constexpr") {
c = Constexpr;
}
else if (word == "return") {
c = Return;
}
else if (word == "case") {
c = Case;
}
else if (word == "switch") {
c = Switch;
}
else if (word == "if") {
c = If;
}
else if (word == "else") {
c = Else;
}
else if (word == "while") {
c = While;
}
else if (word == "do") {
c = Do;
}
else if (word == "for") {
c = For;
}
else if (word == "template") {
c = Template;
}
else if (word == "this") {
c = This;
}
else if (word == "static") {
c = Static;
}
else if (word == "private") {
c = Private;
}
else if (word == "public") {
c = Public;
}
}
}
}
}
void parse_scopes()
{
{
/* Scope detection. */
scope_ranges.clear();
scope_types.clear();
struct ScopeItem {
ScopeType type;
size_t start;
int index;
};
int scope_index = 0;
std::stack<ScopeItem> scopes;
auto enter_scope = [&](ScopeType type, size_t start_tok_id) {
scopes.emplace(ScopeItem{type, start_tok_id, scope_index++});
scope_ranges.emplace_back(start_tok_id, 1);
scope_types += char(type);
};
auto exit_scope = [&](int end_tok_id) {
ScopeItem scope = scopes.top();
scope_ranges[scope.index].size = end_tok_id - scope.start + 1;
scopes.pop();
};
enter_scope(ScopeType::Global, 0);
bool in_template = false;
int tok_id = -1;
for (char &c : token_types) {
tok_id++;
if (scopes.top().type == ScopeType::Preprocessor) {
if (TokenType(c) == NewLine) {
exit_scope(tok_id);
}
else {
/* Do nothing. Enclose all preprocessor lines together. */
continue;
}
}
switch (TokenType(c)) {
case Hash:
enter_scope(ScopeType::Preprocessor, tok_id);
break;
case Assign:
if (scopes.top().type == ScopeType::Assignment) {
/* Chained assignments. */
exit_scope(tok_id - 1);
}
enter_scope(ScopeType::Assignment, tok_id);
break;
case BracketOpen:
if (token_types[tok_id - 2] == Struct) {
enter_scope(ScopeType::Local, tok_id);
}
else if (token_types[tok_id - 2] == Namespace) {
enter_scope(ScopeType::Namespace, tok_id);
}
else if (scopes.top().type == ScopeType::Global) {
enter_scope(ScopeType::Function, tok_id);
}
else if (scopes.top().type == ScopeType::Struct) {
enter_scope(ScopeType::Function, tok_id);
}
else {
enter_scope(ScopeType::Local, tok_id);
}
break;
case ParOpen:
if (scopes.top().type == ScopeType::Global) {
enter_scope(ScopeType::FunctionArgs, tok_id);
}
else if (scopes.top().type == ScopeType::Struct) {
enter_scope(ScopeType::FunctionArgs, tok_id);
}
else {
enter_scope(ScopeType::Local, tok_id);
}
break;
case SquareOpen:
enter_scope(ScopeType::Subscript, tok_id);
break;
case AngleOpen:
if (token_types[tok_id - 1] == Template) {
enter_scope(ScopeType::Template, tok_id);
in_template = true;
}
break;
case AngleClose:
if (in_template && scopes.top().type == ScopeType::Assignment) {
exit_scope(tok_id - 1);
}
if (scopes.top().type == ScopeType::Template) {
exit_scope(tok_id);
}
break;
case BracketClose:
case ParClose:
if (scopes.top().type == ScopeType::Assignment) {
exit_scope(tok_id - 1);
}
exit_scope(tok_id);
break;
case SquareClose:
exit_scope(tok_id);
break;
case SemiColon:
case Comma:
if (scopes.top().type == ScopeType::Assignment) {
exit_scope(tok_id - 1);
}
break;
default:
break;
}
}
exit_scope(tok_id);
/* Some syntax confuses the parser. Bisect the error by removing things in the source file
* until the error is found. Then either fix the unsupported syntax in the parser or use
* alternative syntax. */
assert(scopes.empty());
}
{
token_scope.clear();
token_scope.resize(scope_ranges[0].size);
int scope_id = -1;
for (const IndexRange &range : scope_ranges) {
scope_id++;
for (int i = 0; i < range.size; i++) {
int j = range.start + i;
token_scope[j] = scope_id;
}
}
}
}
private:
/* Return the token type of a single character.
 * Decimal digits map to `Literal`, the characters listed below map to their own type and
 * everything else is considered part of a `Word`. */
TokenType to_type(const char c)
{
  /* All ten decimal digits. */
  if (c >= '0' && c <= '9') {
    return TokenType::Literal;
  }
  switch (c) {
    case '\n':
      return TokenType::NewLine;
    case ' ':
      return TokenType::Space;
    case '#':
      return TokenType::Hash;
    case '&':
      return TokenType::Ampersand;
    case '.':
      return TokenType::Dot;
    case '(':
      return TokenType::ParOpen;
    case ')':
      return TokenType::ParClose;
    case '{':
      return TokenType::BracketOpen;
    case '}':
      return TokenType::BracketClose;
    case '[':
      return TokenType::SquareOpen;
    case ']':
      return TokenType::SquareClose;
    case '<':
      return TokenType::AngleOpen;
    case '>':
      return TokenType::AngleClose;
    case '=':
      return TokenType::Assign;
    case '!':
      return TokenType::Not;
    case '*':
      return TokenType::Star;
    case '-':
      return TokenType::Minus;
    case '+':
      return TokenType::Plus;
    case '/':
      return TokenType::Divide;
    case '~':
      return TokenType::Tilde;
    case '\\':
      return TokenType::Backslash;
    case '?':
      return TokenType::Question;
    case ':':
      return TokenType::Colon;
    case ',':
      return TokenType::Comma;
    case ';':
      return TokenType::SemiColon;
    default:
      return TokenType::Word;
  }
}
};
struct Token {
const ParserData *data;
size_t index;
static Token invalid()
{
return {nullptr, 0};
}
bool is_valid() const
{
return data != nullptr;
}
/* String index range. */
IndexRange index_range() const
{
return data->token_offsets[index];
}
Token prev() const
{
return {data, index - 1};
}
Token next() const
{
return {data, index + 1};
}
/* Returns the scope that contains this token. */
Scope scope() const;
size_t str_index_start() const
{
return index_range().start;
}
size_t str_index_last() const
{
return index_range().last();
}
/* Index of the first character of the line this token is. */
size_t line_start() const
{
size_t pos = data->str.rfind('\n', str_index_start());
return (pos == std::string::npos) ? 0 : (pos + 1);
}
/* Index of the last character of the line this token is, excluding `\n`. */
size_t line_end() const
{
size_t pos = data->str.find('\n', str_index_start());
return (pos == std::string::npos) ? (data->str.size() - 1) : (pos - 1);
}
std::string str() const
{
return data->str.substr(index_range().start, index_range().size);
}
std::string str_no_whitespace() const
{
std::string str = this->str();
return str.substr(0, str.find_last_not_of(" \n") + 1);
}
/* Return the line number this token is found at. Take into account the #line directives. */
size_t line_number() const
{
std::string directive = "#line ";
/* String to count the number of line. */
std::string sub_str = data->str.substr(0, str_index_start());
size_t nearest_line_directive = sub_str.rfind(directive);
size_t line_count = 1;
if (nearest_line_directive != std::string::npos) {
sub_str = sub_str.substr(nearest_line_directive + directive.size());
line_count = std::stoll(sub_str) - 1;
}
return line_count + std::count(sub_str.begin(), sub_str.end(), '\n');
}
TokenType type() const
{
return TokenType(*this);
}
operator TokenType() const
{
return TokenType(data->token_types[index]);
}
bool operator==(TokenType type) const
{
return TokenType(*this) == type;
}
bool operator!=(TokenType type) const
{
return !(*this == type);
}
};
struct Scope {
const ParserData *data;
size_t index;
Token start() const
{
return {data, range().start};
}
Token end() const
{
return {data, range().last()};
}
IndexRange range() const
{
return data->scope_ranges[index];
}
size_t token_count() const
{
return range().size;
}
ScopeType type() const
{
return ScopeType(data->scope_types[index]);
}
/* Returns the scope that contains this scope. */
Scope scope() const
{
return start().prev().scope();
}
std::string str() const
{
return data->str.substr(start().str_index_start(),
end().str_index_last() - start().str_index_start());
}
void foreach_match(const std::string &pattern,
std::function<void(const std::vector<Token>)> callback) const
{
const std::string scope_tokens = data->token_types.substr(range().start, range().size);
std::vector<Token> match;
match.resize(pattern.size());
size_t pos = 0;
while ((pos = scope_tokens.find(pattern, pos)) != std::string::npos) {
match[0] = {data, range().start + pos};
/* Do not match preprocessor directive by default. */
if (match[0].scope().type() != ScopeType::Preprocessor) {
for (int i = 1; i < pattern.size(); i++) {
match[i] = Token{data, range().start + pos + i};
}
callback(match);
}
pos += 1;
}
}
};
/* Defined after `Scope` since it needs the complete type. Looks up the scope index recorded for
 * this token in `ParserData::token_scope`. */
inline Scope Token::scope() const
{
  const size_t scope_index = size_t(data->token_scope[index]);
  return Scope{data, scope_index};
}
struct Parser {
private:
ParserData data_;
/* If false, the whitespaces are fused with the tokens. Otherwise they are kept as separate space
* and newline tokens. */
bool keep_whitespace_;
/* A recorded string substitution. */
struct Mutation {
  /* Span of the original string that gets substituted. */
  IndexRange src_range;
  /* Text written in place of that span. */
  std::string replacement;

  Mutation(IndexRange src_range, std::string replacement)
      : src_range(src_range), replacement(replacement)
  {
  }

  /* Order mutations by the start of the span they replace. Sorting them by position is required
   * to apply all of them in a single pass. */
  friend bool operator<(const Mutation &a, const Mutation &b)
  {
    const size_t a_start = a.src_range.start;
    const size_t b_start = b.src_range.start;
    return a_start < b_start;
  }
};
std::vector<Mutation> mutations_;
public:
/* Store a copy of `input` and parse it right away.
 * `keep_whitespace` selects whether whitespaces are kept as separate tokens (true) or fused with
 * the tokens (false, the default).
 * NOTE(review): `parse()` is defined outside of this view; presumably it runs the tokenization
 * and the scope detection -- confirm. */
Parser(const std::string &input, bool keep_whitespace = false)
: keep_whitespace_(keep_whitespace)
{
data_.str = input;
parse();
}
/* Invoke `callback` once for each scope whose type is `type`, in the order the scopes are
 * stored. */
void foreach_scope(ScopeType type, std::function<void(Scope)> callback)
{
  const char wanted = char(type);
  for (size_t scope_id = data_.scope_types.find(wanted); scope_id != std::string::npos;
       scope_id = data_.scope_types.find(wanted, scope_id + 1))
  {
    callback(Scope{&data_, scope_id});
  }
}
/* Run a callback for all existing function scopes.
 *
 * A function is detected as a `FunctionArgs` scope that is followed, after an optional `const`
 * keyword, by a `{`. Function prototypes are skipped.
 * The callback receives:
 * - `is_static`: whether the token before the return type is the `static` keyword.
 * - `type`: the token before the name (return type).
 * - `name`: the token before the opening parenthesis.
 * - `args`: the argument scope.
 * - `is_const`: whether the token after the arguments is the `const` keyword.
 * - `body`: the scope of the `{` token. */
void foreach_function(
    std::function<void(
        bool is_static, Token type, Token name, Scope args, bool is_const, Scope body)> callback)
{
  foreach_scope(ScopeType::FunctionArgs, [&](const Scope args) {
    const Token par_open = args.start();
    if (par_open.index < 2) {
      /* Not enough tokens before the parenthesis for a return type and a name. Looking them up
       * would read before the first token. */
      return;
    }
    const Token after_args = args.end().next();
    const bool is_const = after_args == Const;
    const Token body_start = is_const ? after_args.next() : after_args;
    if (body_start != '{') {
      /* Function Prototype. */
      return;
    }
    const Token name = par_open.prev();
    const Token type = name.prev();
    /* The return type can be the very first token, in which case nothing precedes it. */
    const bool is_static = (type.index > 0) && (type.prev() == Static);
    const Scope body = body_start.scope();
    callback(is_static, type, name, args, is_const, body);
  });
}
std::string substr_range_inclusive(size_t start, size_t end)
{
return data_.str.substr(start, end - start + 1);
}
std::string substr_range_inclusive(Token start, Token end)
{
return substr_range_inclusive(start.str_index_start(), end.str_index_last());
}
/* Replace everything from `from` to `to` (inclusive).
* Return true on success. */
bool replace_try(size_t from, size_t to, const std::string &replacement)
{
IndexRange range = IndexRange(from, to + 1 - from);
for (const Mutation &mut : mutations_) {
if (mut.src_range.overlaps(range)) {
return false;
}
}
mutations_.emplace_back(range, replacement);
return true;
}
/* Replace everything from `from` to `to` (inclusive).
* Return true on success. */
bool replace_try(Token from, Token to, const std::string &replacement)
{
return replace_try(from.str_index_start(), to.str_index_last(), replacement);
}
/* Replace everything from `from` to `to` (inclusive). */
void replace(size_t from, size_t to, const std::string &replacement)
{
bool success = replace_try(from, to, replacement);
assert(success);
(void)success;
}
/* Replace everything from `from` to `to` (inclusive). */
void replace(Token from, Token to, const std::string &replacement)
{
replace(from.str_index_start(), to.str_index_last(), replacement);
}
/* Replace the content from `from` to `to` (inclusive) by whitespaces without changing
* line count and keep the remaining indentation spaces. */
void erase(size_t from, size_t to)
{
IndexRange range = IndexRange(from, to + 1 - from);
std::string content = data_.str.substr(range.start, range.size);
size_t lines = std::count(content.begin(), content.end(), '\n');
size_t spaces = content.find_last_not_of(" ");
if (spaces != std::string::npos) {
spaces = content.length() - (spaces + 1);
}
replace(from, to, std::string(lines, '\n') + std::string(spaces, ' '));
}
/* Replace the content from `from` to `to` (inclusive) by whitespaces without changing
* line count and keep the remaining indentation spaces. */
void erase(Token from, Token to)
{
erase(from.str_index_start(), to.str_index_last());
}
/* Replace the content from `from` to `to` (inclusive) by whitespaces without changing
* line count and keep the remaining indentation spaces. */
void erase(Token tok)
{
erase(tok, tok);
}
void insert_after(size_t at, const std::string &content)
{
IndexRange range = IndexRange(at + 1, 0);
mutations_.emplace_back(range, content);
}
void insert_after(Token at, const std::string &content)
{
insert_after(at.str_index_last(), content);
}
void insert_before(size_t at, const std::string &content)
{
IndexRange range = IndexRange(at, 0);
mutations_.emplace_back(range, content);
}
void insert_before(Token at, const std::string &content)
{
insert_after(at.str_index_start(), content);
}
/* Return true if any mutation was applied. */
bool apply_mutations()
{
if (mutations_.empty()) {
return false;
}
/* Order mutations so that they can be applied in one pass. */
std::sort(mutations_.begin(), mutations_.end());
int64_t offset = 0;
for (const Mutation &mut : mutations_) {
data_.str.replace(mut.src_range.start + offset, mut.src_range.size, mut.replacement);
offset += mut.replacement.size() - mut.src_range.size;
}
mutations_.clear();
this->parse();
return true;
}
/* Apply mutations if any and get resulting string. */
const std::string &result_get()
{
apply_mutations();
return data_.str;
}
/* For testing. */
const ParserData &data_get()
{
return data_;
}
/* For testing. */
std::string serialize_mutations() const
{
std::string out;
for (const Mutation &mut : mutations_) {
out += "Replace \"";
out += data_.str.substr(mut.src_range.start, mut.src_range.size);
out += "\" by \"";
out += mut.replacement;
out += "\"\n";
}
return out;
}
private:
using Duration = std::chrono::microseconds;
Duration tokenize_time;
Duration parse_scope_time;
struct TimeIt {
Duration &time;
std::chrono::high_resolution_clock::time_point start;
TimeIt(Duration &time) : time(time)
{
start = std::chrono::high_resolution_clock::now();
}
~TimeIt()
{
auto end = std::chrono::high_resolution_clock::now();
time = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
}
};
void parse()
{
{
TimeIt time_it(parse_scope_time);
data_.tokenize(keep_whitespace_);
}
{
TimeIt time_it(tokenize_time);
data_.parse_scopes();
}
}
public:
void print_stats()
{
std::cout << "Tokenize time: " << tokenize_time.count() << " µs" << std::endl;
std::cout << "Parser time: " << parse_scope_time.count() << " µs" << std::endl;
std::cout << "String len: " << std::to_string(data_.str.size()) << std::endl;
std::cout << "Token len: " << std::to_string(data_.token_types.size()) << std::endl;
std::cout << "Scope len: " << std::to_string(data_.scope_types.size()) << std::endl;
}
void debug_print()
{
std::cout << "Input: \n" << data_.str << " \nEnd of Input\n" << std::endl;
std::cout << "Token Types: \"" << data_.token_types << "\"" << std::endl;
std::cout << "Scope Types: \"" << data_.scope_types << "\"" << std::endl;
}
};
} // namespace blender::gpu::shader::parser

View File

@@ -1003,6 +1003,21 @@ void groupMemoryBarrier() {}
class_name(t1 m1##_, t2 m2##_, t3 m3##_, t4 m4##_) \
: m1(m1##_), m2(m2##_), m3(m3##_), m4(m4##_){};
/* Constructor boilerplate for a struct with 5 members, to be placed inside the struct body.
 * Expands to a defaulted default constructor and a constructor taking one argument per member.
 * `tN` is the type and `mN` the name of the N-th member. Arguments are named `mN_`. */
#define METAL_CONSTRUCTOR_5(class_name, t1, m1, t2, m2, t3, m3, t4, m4, t5, m5) \
class_name() = default; \
class_name(t1 m1##_, t2 m2##_, t3 m3##_, t4 m4##_, t5 m5##_) \
: m1(m1##_), m2(m2##_), m3(m3##_), m4(m4##_), m5(m5##_){};
/* Same as above for a struct with 6 members. */
#define METAL_CONSTRUCTOR_6(class_name, t1, m1, t2, m2, t3, m3, t4, m4, t5, m5, t6, m6) \
class_name() = default; \
class_name(t1 m1##_, t2 m2##_, t3 m3##_, t4 m4##_, t5 m5##_, t6 m6##_) \
: m1(m1##_), m2(m2##_), m3(m3##_), m4(m4##_), m5(m5##_), m6(m6##_){};
/* Same as above for a struct with 7 members. */
#define METAL_CONSTRUCTOR_7(class_name, t1, m1, t2, m2, t3, m3, t4, m4, t5, m5, t6, m6, t7, m7) \
class_name() = default; \
class_name(t1 m1##_, t2 m2##_, t3 m3##_, t4 m4##_, t5 m5##_, t6 m6##_, t7 m7##_) \
: m1(m1##_), m2(m2##_), m3(m3##_), m4(m4##_), m5(m5##_), m6(m6##_), m7(m7##_){};
/** \} */
/* Use to suppress `-Wimplicit-fallthrough` (in place of `break`). */

View File

@@ -1181,6 +1181,21 @@ float4x4 __mat4x4(float3x3 a) { return to_float4x4(a); }
class_name(t1 m1##_, t2 m2##_, t3 m3##_, t4 m4##_) \
: m1(m1##_), m2(m2##_), m3(m3##_), m4(m4##_){};
/* Constructor boilerplate for a struct with 5 members, to be placed inside the struct body.
 * Expands to a defaulted default constructor and a constructor taking one argument per member.
 * `tN` is the type and `mN` the name of the N-th member. Arguments are named `mN_`. */
#define METAL_CONSTRUCTOR_5(class_name, t1, m1, t2, m2, t3, m3, t4, m4, t5, m5) \
class_name() = default; \
class_name(t1 m1##_, t2 m2##_, t3 m3##_, t4 m4##_, t5 m5##_) \
: m1(m1##_), m2(m2##_), m3(m3##_), m4(m4##_), m5(m5##_){};
/* Same as above for a struct with 6 members. */
#define METAL_CONSTRUCTOR_6(class_name, t1, m1, t2, m2, t3, m3, t4, m4, t5, m5, t6, m6) \
class_name() = default; \
class_name(t1 m1##_, t2 m2##_, t3 m3##_, t4 m4##_, t5 m5##_, t6 m6##_) \
: m1(m1##_), m2(m2##_), m3(m3##_), m4(m4##_), m5(m5##_), m6(m6##_){};
/* Same as above for a struct with 7 members. */
#define METAL_CONSTRUCTOR_7(class_name, t1, m1, t2, m2, t3, m3, t4, m4, t5, m5, t6, m6, t7, m7) \
class_name() = default; \
class_name(t1 m1##_, t2 m2##_, t3 m3##_, t4 m4##_, t5 m5##_, t6 m6##_, t7 m7##_) \
: m1(m1##_), m2(m2##_), m3(m3##_), m4(m4##_), m5(m5##_), m6(m6##_), m7(m7##_){};
#undef ENABLE_IF
/* Array syntax compatibility. */

View File

@@ -146,6 +146,9 @@ RESHAPE(float3x3, mat3x3, mat3x4)
/* In this file the `METAL_CONSTRUCTOR_N` macros expand to nothing: no constructor is declared
 * by their use. One variant exists per member count (`tN` is the type and `mN` the name of the
 * N-th member). */
#define METAL_CONSTRUCTOR_2(class_name, t1, m1, t2, m2)
#define METAL_CONSTRUCTOR_3(class_name, t1, m1, t2, m2, t3, m3)
#define METAL_CONSTRUCTOR_4(class_name, t1, m1, t2, m2, t3, m3, t4, m4)
#define METAL_CONSTRUCTOR_5(class_name, t1, m1, t2, m2, t3, m3, t4, m4, t5, m5)
#define METAL_CONSTRUCTOR_6(class_name, t1, m1, t2, m2, t3, m3, t4, m4, t5, m5, t6, m6)
#define METAL_CONSTRUCTOR_7(class_name, t1, m1, t2, m2, t3, m3, t4, m4, t5, m5, t6, m6, t7, m7)
/* Both markers expand to nothing in this file.
 * NOTE(review): presumably they delimit `in` qualified arguments in the preprocessed source
 * (similar to the `_inout_sta` / `_inout_end` pair) -- confirm against the preprocessor. */
#define _in_sta
#define _in_end

View File

@@ -255,7 +255,7 @@ func_TEMPLATE(float, 1)/*float a*/)";
EXPECT_EQ(error, "");
}
{
string input = R"(template<typename T, int i = 0> void func(T a) {a;)";
string input = R"(template<typename T, int i = 0> void func(T a) {a;})";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(error, "Template declaration unsupported syntax");
@@ -486,6 +486,7 @@ int func2(int a)
string expect = R"(
struct A_S {};
#line 4
int A_func(int a)
{
A_S s;
@@ -623,6 +624,7 @@ void test() {
void A_B_func() {}
struct A_B_S {};
#line 5
@@ -897,4 +899,202 @@ uint my_func() {
}
GPU_TEST(preprocess_resource_guard);
/* Check the output of `process_test_string` for a class that contains methods and for method
 * call syntax, by comparing it against an expected string. No error is expected. */
static void test_preprocess_struct_methods()
{
using namespace shader;
using namespace std;
{
/* Input: a class with `private` / `public` sections, a static method, a non-const method and a
 * const method, followed by a `main` containing several forms of chained member access and
 * method calls. */
string input = R"(
class S {
private:
int member;
int this_member;
public:
static S construct()
{
S a;
a.member = 0;
a.this_member = 0;
return a;
}
int another_member;
S function(int i)
{
this->member = i;
this_member++;
return *this;
}
int size() const
{
return this->member;
}
};
void main()
{
S s = S::construct();
a.b();
a(0).b();
a().b();
a.b.c();
a.b(0).c();
a.b().c();
a[0].b();
a.b[0].c();
a.b().c[0];
}
)";
/* Expected output: `class` becomes `struct` and only the data members stay inside it, the
 * methods are emitted after the struct, each preceded by a `#line` directive. The static method
 * is prefixed by the struct name, the other methods receive a `this` argument (`inout` for the
 * non-const method, `const` for the const one) and `this->` / `*this` become `this.` / `this`.
 * The `this_member++` statement is expected to be left as is.
 * Method calls of the form `a.b()` are expected to become `b(a)`. */
string expect = R"(
struct S {
int member;
int this_member;
int another_member;
};
#line 8
static S S_construct()
{
S a;
a.member = 0;
a.this_member = 0;
return a;
}
#line 18
S function(inout S _inout_sta this _inout_end, int i)
{
this.member = i;
this_member++;
return this;
}
#line 25
int size(const S this)
{
return this.member;
}
#line 30
void main()
{
S s = S_construct();
b(a);
b(a(0));
b(a());
c(a.b);
c(b(a, 0));
c(b(a));
b(a[0]);
c(a.b[0]);
b(a).c[0];
}
)";
string error;
string output = process_test_string(input, error);
EXPECT_EQ(output, expect);
EXPECT_EQ(error, "");
}
}
GPU_TEST(preprocess_struct_methods);
/* Check the `Parser` class: the token type string produced for a few inputs, the result of
 * inserting text, and the line number reported by tokens. */
static void test_preprocess_parser()
{
using namespace std;
using namespace shader::parser;
{
/* Number literals (integer, float, exponent, hexadecimal, with or without suffix): each one is
 * expected to produce a single `0` token type. */
string input = R"(
1;
1.0;
2e10;
2e10f;
2.e10f;
2.0e-1f;
2.0e-1;
2.0e-1f;
0xFF;
0xFFu;
)";
string expect = R"(
0;0;0;0;0;0;0;0;0;0;)";
EXPECT_EQ(Parser(input).data_get().token_types, expect);
}
{
/* `struct` and `class` keywords are expected to produce the `s` and `S` token types
 * respectively, and the other words the `w` token type. */
string input = R"(
struct T {
int t = 1;
};
class B {
T t;
};
)";
string expect = R"(
sw{ww=0;};Sw{ww;};)";
EXPECT_EQ(Parser(input).data_get().token_types, expect);
}
{
/* Function definition with a default argument, nested scopes, `if`, `return`, and the `++` and
 * `==` operators. */
string input = R"(
void f(int t = 0) {
int i = 0, u = 2, v = {1.0f};
{
v = i = u, v++;
if (v == i) {
return;
}
}
}
)";
string expect = R"(
ww(ww=0){ww=0,w=0,w={0};{w=w=w,wP;i(wEw){r;}}})";
EXPECT_EQ(Parser(input).data_get().token_types, expect);
}
{
/* Two insertions after the same token are expected to show up in the order they were made. */
Parser parser("float i;");
parser.insert_after(Token{&parser.data_get(), 0}, "A ");
parser.insert_after(Token{&parser.data_get(), 0}, "B ");
EXPECT_EQ(parser.result_get(), "float A B i;");
}
{
/* The line number of a token is expected to take a preceding `#line` directive into account. */
string input = R"(
A
#line 100
B
)";
Parser parser(input);
Token A = {&parser.data_get(), 1};
Token B = {&parser.data_get(), 5};
EXPECT_EQ(A.str_no_whitespace(), "A");
EXPECT_EQ(B.str_no_whitespace(), "B");
EXPECT_EQ(A.line_number(), 2);
EXPECT_EQ(B.line_number(), 100);
}
}
GPU_TEST(preprocess_parser);
} // namespace blender::gpu::tests