diff --git a/source/blender/nodes/CMakeLists.txt b/source/blender/nodes/CMakeLists.txt index 862e213119f..4af19158798 100644 --- a/source/blender/nodes/CMakeLists.txt +++ b/source/blender/nodes/CMakeLists.txt @@ -98,6 +98,7 @@ set(SRC intern/partial_eval.cc intern/socket_search_link.cc intern/socket_usage_inference.cc + intern/socket_value_inference.cc intern/sync_sockets.cc intern/trace_values.cc intern/value_elem.cc @@ -150,6 +151,7 @@ set(SRC NOD_socket_search_link.hh NOD_socket_usage_inference.hh NOD_socket_usage_inference_fwd.hh + NOD_socket_value_inference.hh NOD_sync_sockets.hh NOD_texture.h NOD_trace_values.hh diff --git a/source/blender/nodes/NOD_socket_usage_inference.hh b/source/blender/nodes/NOD_socket_usage_inference.hh index a1f14c0f460..6f5436090c3 100644 --- a/source/blender/nodes/NOD_socket_usage_inference.hh +++ b/source/blender/nodes/NOD_socket_usage_inference.hh @@ -11,6 +11,7 @@ #include "NOD_geometry_nodes_execute.hh" #include "NOD_socket_usage_inference_fwd.hh" +#include "NOD_socket_value_inference.hh" struct bNodeTree; struct bNodeSocket; @@ -20,53 +21,6 @@ namespace blender::nodes::socket_usage_inference { struct SocketUsageInferencer; -/** - * During socket usage inferencing, some socket values are computed. This class represents such a - * computed value. Not all possible values can be presented here, only "basic" once (like int, but - * not int-field). A value can also be unknown if it can't be determined statically. - */ -class InferenceValue { - private: - /** - * Non-owning pointer to a value of type #bNodeSocketType.base_cpp_type of the corresponding - * socket. If this is null, the value is assumed to be unknown (aka, it can't be determined - * statically). - */ - const void *value_ = nullptr; - - public: - explicit InferenceValue(const void *value) : value_(value) {} - - static InferenceValue Unknown() - { - return InferenceValue(nullptr); - } - - bool is_unknown() const - { - return value_ == nullptr; - } - - const void *data() const - { - return value_; - } - - template T get_known() const - { - BLI_assert(!this->is_unknown()); - return *static_cast(this->value_); - } - - template std::optional get() const - { - if (this->is_unknown()) { - return std::nullopt; - } - return this->get_known(); - } -}; - class InputSocketUsageParams { private: SocketUsageInferencer &inferencer_; diff --git a/source/blender/nodes/NOD_socket_value_inference.hh b/source/blender/nodes/NOD_socket_value_inference.hh new file mode 100644 index 00000000000..a3b834eb5aa --- /dev/null +++ b/source/blender/nodes/NOD_socket_value_inference.hh @@ -0,0 +1,95 @@ +/* SPDX-FileCopyrightText: 2025 Blender Authors + * + * SPDX-License-Identifier: GPL-2.0-or-later */ + +#pragma once + +#include "BLI_generic_pointer.hh" +#include "BLI_resource_scope.hh" + +#include "BKE_compute_context_cache_fwd.hh" + +#include "DNA_material_types.h" +#include "NOD_node_in_compute_context.hh" + +struct bNodeTree; + +namespace blender::nodes { + +/** + * During socket usage inferencing, some socket values are computed. This class represents such a + * computed value. Not all possible values can be presented here, only "basic" once (like int, but + * not int-field). A value can also be unknown if it can't be determined statically. + */ +class InferenceValue { + private: + /** + * Non-owning pointer to a value of type #bNodeSocketType.base_cpp_type of the corresponding + * socket. If this is null, the value is assumed to be unknown (aka, it can't be determined + * statically). + */ + const void *value_ = nullptr; + + public: + explicit InferenceValue(const void *value) : value_(value) {} + + static InferenceValue Unknown() + { + return InferenceValue(nullptr); + } + + bool is_unknown() const + { + return value_ == nullptr; + } + + const void *data() const + { + return value_; + } + + template T get_known() const + { + BLI_assert(!this->is_unknown()); + return *static_cast(this->value_); + } + + template std::optional get() const + { + if (this->is_unknown()) { + return std::nullopt; + } + return this->get_known(); + } +}; + +class SocketValueInferencerImpl; + +class SocketValueInferencer { + private: + SocketValueInferencerImpl &impl_; + + public: + SocketValueInferencer(const bNodeTree &tree, + ResourceScope &scope, + bke::ComputeContextCache &compute_context_cache, + const std::optional> tree_input_values, + const std::optional> top_level_ignored_inputs); + + InferenceValue get_socket_value(const SocketInContext &socket); +}; + +namespace switch_node_inference_utils { + +bool is_socket_selected__switch(const SocketInContext &socket, const InferenceValue &condition); +bool is_socket_selected__index_switch(const SocketInContext &socket, + const InferenceValue &condition); +bool is_socket_selected__menu_switch(const SocketInContext &socket, + const InferenceValue &condition); +bool is_socket_selected__mix_node(const SocketInContext &socket, const InferenceValue &condition); +bool is_socket_selected__shader_mix_node(const SocketInContext &socket, + const InferenceValue &condition); + +} // namespace switch_node_inference_utils + +} // namespace blender::nodes diff --git a/source/blender/nodes/intern/socket_usage_inference.cc b/source/blender/nodes/intern/socket_usage_inference.cc index 3201e5bd530..d05e1ac2dc1 100644 --- a/source/blender/nodes/intern/socket_usage_inference.cc +++ b/source/blender/nodes/intern/socket_usage_inference.cc @@ -34,9 +34,11 @@ struct SocketUsageInferencer { private: friend InputSocketUsageParams; - /** Owns e.g. intermediate evaluated values. */ - ResourceScope scope_; - bke::ComputeContextCache compute_context_cache_; + ResourceScope &scope_; + bke::ComputeContextCache &compute_context_cache_; + + /** Inferences the socket values if possible. */ + SocketValueInferencer value_inferencer_; /** Root node tree. */ const bNodeTree &root_tree_; @@ -45,7 +47,6 @@ struct SocketUsageInferencer { * Stack of tasks that allows depth-first (partial) evaluation of the tree. */ Stack usage_tasks_; - Stack value_tasks_; /** * If the usage of a socket is known, it is added to this map. Sockets not in this map are not @@ -53,54 +54,20 @@ struct SocketUsageInferencer { */ Map all_socket_usages_; - /** - * Once a socket value has been determined, it is added to this map. Note that a socket value may - * be determined to be unknown because it depends on values that are not known statically. - */ - Map all_socket_values_; - - /** - * All sockets that have animation data and thus their value is not fixed statically. This can - * contain sockets from multiple different trees. - */ - Set animated_sockets_; - Set trees_with_handled_animation_data_; - - /** Some inline storage to reduce the number of allocations. */ - AlignedBuffer<1024, 8> scope_buffer_; - - std::optional> top_level_ignored_inputs_; - public: SocketUsageInferencer(const bNodeTree &tree, const std::optional> tree_input_values, + ResourceScope &scope, + bke::ComputeContextCache &compute_context_cache, const std::optional> top_level_ignored_inputs = std::nullopt) - : root_tree_(tree), top_level_ignored_inputs_(top_level_ignored_inputs) + : scope_(scope), + compute_context_cache_(compute_context_cache), + value_inferencer_( + tree, scope_, compute_context_cache_, tree_input_values, top_level_ignored_inputs), + root_tree_(tree) { - scope_.allocator().provide_buffer(scope_buffer_); root_tree_.ensure_topology_cache(); root_tree_.ensure_interface_cache(); - this->ensure_animation_data_processed(root_tree_); - - for (const bNode *node : root_tree_.group_input_nodes()) { - for (const int i : root_tree_.interface_inputs().index_range()) { - const bNodeSocket &socket = node->output_socket(i); - if (!socket.is_directly_linked()) { - /* This socket is not linked, hence it's value is never used. Thus we don't have to add - * it to #all_socket_values_. This optimization helps a lot when the node group has a - * very large number of inputs and group input nodes. */ - continue; - } - const SocketInContext socket_in_context{nullptr, &socket}; - const void *input_value = nullptr; - if (!this->treat_socket_as_unknown(socket_in_context)) { - if (tree_input_values.has_value()) { - input_value = (*tree_input_values)[i].get(); - } - } - all_socket_values_.add_new(socket_in_context, InferenceValue(input_value)); - } - } } void mark_top_level_node_outputs_as_used() @@ -159,27 +126,7 @@ struct SocketUsageInferencer { InferenceValue get_socket_value(const SocketInContext &socket) { - const std::optional value = all_socket_values_.lookup_try(socket); - if (value.has_value()) { - return *value; - } - if (socket->owner_tree().has_available_link_cycle()) { - return InferenceValue::Unknown(); - } - - BLI_assert(value_tasks_.is_empty()); - value_tasks_.push(socket); - - while (!value_tasks_.is_empty()) { - const SocketInContext &socket = value_tasks_.peek(); - this->value_task(socket); - if (&socket == &value_tasks_.peek()) { - /* The task is finished if it hasn't added any new task it depends on. */ - value_tasks_.pop(); - } - } - - return all_socket_values_.lookup(socket); + return value_inferencer_.get_socket_value(socket); } private: @@ -225,23 +172,28 @@ struct SocketUsageInferencer { break; } case GEO_NODE_SWITCH: { - this->usage_task__input__generic_switch(socket, switch__is_socket_selected); + this->usage_task__input__generic_switch( + socket, switch_node_inference_utils::is_socket_selected__switch); break; } case GEO_NODE_INDEX_SWITCH: { - this->usage_task__input__generic_switch(socket, index_switch__is_socket_selected); + this->usage_task__input__generic_switch( + socket, switch_node_inference_utils::is_socket_selected__index_switch); break; } case GEO_NODE_MENU_SWITCH: { - this->usage_task__input__generic_switch(socket, menu_switch__is_socket_selected); + this->usage_task__input__generic_switch( + socket, switch_node_inference_utils::is_socket_selected__menu_switch); break; } case SH_NODE_MIX: { - this->usage_task__input__generic_switch(socket, mix_node__is_socket_selected); + this->usage_task__input__generic_switch( + socket, switch_node_inference_utils::is_socket_selected__mix_node); break; } case SH_NODE_MIX_SHADER: { - this->usage_task__input__generic_switch(socket, shader_mix_node__is_socket_selected); + this->usage_task__input__generic_switch( + socket, switch_node_inference_utils::is_socket_selected__shader_mix_node); break; } case GEO_NODE_SIMULATION_INPUT: { @@ -304,7 +256,7 @@ struct SocketUsageInferencer { return; } const SocketInContext output_socket{socket.context, - this->get_first_available_bsocket(node->output_sockets())}; + get_first_available_bsocket(node->output_sockets())}; const std::optional output_is_used = all_socket_usages_.lookup_try(output_socket); if (!output_is_used.has_value()) { this->push_usage_task(output_socket); @@ -314,8 +266,8 @@ struct SocketUsageInferencer { all_socket_usages_.add_new(socket, false); return; } - const SocketInContext condition_socket{ - socket.context, this->get_first_available_bsocket(node->input_sockets())}; + const SocketInContext condition_socket{socket.context, + get_first_available_bsocket(node->input_sockets())}; if (socket == condition_socket) { all_socket_usages_.add_new(socket, true); return; @@ -330,16 +282,6 @@ struct SocketUsageInferencer { all_socket_usages_.add_new(socket, is_used); } - const bNodeSocket *get_first_available_bsocket(const Span sockets) const - { - for (const bNodeSocket *socket : sockets) { - if (socket->is_available()) { - return socket; - } - } - return nullptr; - } - void usage_task__input__group_node(const SocketInContext &socket) { const NodeInContext node = socket.owner_node(); @@ -353,7 +295,6 @@ struct SocketUsageInferencer { all_socket_usages_.add_new(socket, false); return; } - this->ensure_animation_data_processed(*group); /* The group node input is used if any of the matching group inputs within the group is * used. */ @@ -563,801 +504,19 @@ struct SocketUsageInferencer { all_socket_usages_.add_new(socket, all_condition_inputs_true); } - void value_task(const SocketInContext &socket) - { - if (all_socket_values_.contains(socket)) { - /* Task is done already. */ - return; - } - const bNode &node = socket->owner_node(); - if (node.is_undefined() && !node.is_custom_group()) { - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - const CPPType *base_type = socket->typeinfo->base_cpp_type; - if (!base_type) { - /* The socket type is unknown for some reason (maybe a socket type from the future?). */ - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - if (socket->is_input()) { - this->value_task__input(socket); - } - else { - this->value_task__output(socket); - } - } - - void value_task__output(const SocketInContext &socket) - { - const NodeInContext node = socket.owner_node(); - if (node->is_muted()) { - this->value_task__output__muted_node(socket); - return; - } - switch (node->type_legacy) { - case NODE_GROUP: - case NODE_CUSTOM_GROUP: { - this->value_task__output__group_node(socket); - return; - } - case NODE_GROUP_INPUT: { - this->value_task__output__group_input_node(socket); - return; - } - case NODE_REROUTE: { - this->value_task__output__reroute_node(socket); - return; - } - case GEO_NODE_SWITCH: { - this->value_task__output__generic_switch(socket, switch__is_socket_selected); - return; - } - case GEO_NODE_INDEX_SWITCH: { - this->value_task__output__generic_switch(socket, index_switch__is_socket_selected); - return; - } - case GEO_NODE_MENU_SWITCH: { - this->value_task__output__generic_switch(socket, menu_switch__is_socket_selected); - return; - } - case SH_NODE_MIX: { - this->value_task__output__generic_switch(socket, mix_node__is_socket_selected); - return; - } - case SH_NODE_MIX_SHADER: { - this->value_task__output__generic_switch(socket, shader_mix_node__is_socket_selected); - return; - } - case SH_NODE_MATH: { - this->value_task__output__float_math(socket); - return; - } - case SH_NODE_VECTOR_MATH: { - this->value_task__output__vector_math(socket); - return; - } - case FN_NODE_INTEGER_MATH: { - this->value_task__output__integer_math(socket); - return; - } - case FN_NODE_BOOLEAN_MATH: { - this->value_task__output__boolean_math(socket); - return; - } - default: { - if (node->typeinfo->build_multi_function) { - this->value_task__output__multi_function_node(socket); - return; - } - break; - } - } - /* If none of the above cases work, the socket value is set to null which means that it is - * unknown/dynamic. */ - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - } - - void value_task__output__group_node(const SocketInContext &socket) - { - const NodeInContext node = socket.owner_node(); - const bNodeTree *group = reinterpret_cast(node->id); - if (!group || ID_MISSING(&group->id)) { - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - group->ensure_topology_cache(); - if (group->has_available_link_cycle()) { - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - this->ensure_animation_data_processed(*group); - const bNode *group_output_node = group->group_output_node(); - if (!group_output_node) { - /* Can't compute the value if the group does not have an output node. */ - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - const ComputeContext &group_context = compute_context_cache_.for_group_node( - socket.context, node->identifier, &node->owner_tree()); - const SocketInContext socket_in_group{&group_context, - &group_output_node->input_socket(socket->index())}; - const std::optional value = all_socket_values_.lookup_try(socket_in_group); - if (!value.has_value()) { - this->push_value_task(socket_in_group); - return; - } - all_socket_values_.add_new(socket, *value); - } - - void value_task__output__group_input_node(const SocketInContext &socket) - { - /* Group inputs for the root context should be initialized already. */ - BLI_assert(socket.context != nullptr); - - const bke::GroupNodeComputeContext &group_context = - *static_cast(socket.context); - const SocketInContext group_node_input{group_context.parent(), - &group_context.node()->input_socket(socket->index())}; - const std::optional value = all_socket_values_.lookup_try(group_node_input); - if (!value.has_value()) { - this->push_value_task(group_node_input); - return; - } - all_socket_values_.add_new(socket, *value); - } - - void value_task__output__reroute_node(const SocketInContext &socket) - { - const SocketInContext input_socket = socket.owner_node().input_socket(0); - const std::optional value = all_socket_values_.lookup_try(input_socket); - if (!value.has_value()) { - this->push_value_task(input_socket); - return; - } - all_socket_values_.add_new(socket, *value); - } - - void value_task__output__float_math(const SocketInContext &socket) - { - const NodeInContext node = socket.owner_node(); - const NodeMathOperation operation = NodeMathOperation(node->custom1); - switch (operation) { - case NODE_MATH_MULTIPLY: { - this->value_task__output__generic_eval( - socket, [&](const Span inputs) -> std::optional { - const std::optional a = inputs[0].get(); - const std::optional b = inputs[1].get(); - if (a == 0.0f || b == 0.0f) { - return InferenceValue(&scope_.construct(0.0f)); - } - if (a.has_value() && b.has_value()) { - return InferenceValue(&scope_.construct(*a * *b)); - } - return std::nullopt; - }); - break; - } - default: { - this->value_task__output__multi_function_node(socket); - break; - } - } - } - - void value_task__output__vector_math(const SocketInContext &socket) - { - const NodeInContext node = socket.owner_node(); - const NodeVectorMathOperation operation = NodeVectorMathOperation(node->custom1); - switch (operation) { - case NODE_VECTOR_MATH_MULTIPLY: { - this->value_task__output__generic_eval( - socket, [&](const Span inputs) -> std::optional { - const std::optional a = inputs[0].get(); - const std::optional b = inputs[1].get(); - if (a == float3(0.0f) || b == float3(0.0f)) { - return InferenceValue(&scope_.construct(0.0f)); - } - if (a.has_value() && b.has_value()) { - return InferenceValue(&scope_.construct(*a * *b)); - } - return std::nullopt; - }); - break; - } - case NODE_VECTOR_MATH_SCALE: { - this->value_task__output__generic_eval( - socket, [&](const Span inputs) -> std::optional { - const std::optional a = inputs[0].get(); - const std::optional scale = inputs[3].get(); - if (a == float3(0.0f) || scale == 0.0f) { - return InferenceValue(&scope_.construct(0.0f)); - } - if (a.has_value() && scale.has_value()) { - return InferenceValue(&scope_.construct(*a * *scale)); - } - return std::nullopt; - }); - break; - } - default: { - this->value_task__output__multi_function_node(socket); - break; - } - } - } - - void value_task__output__integer_math(const SocketInContext &socket) - { - const NodeInContext node = socket.owner_node(); - const NodeIntegerMathOperation operation = NodeIntegerMathOperation(node->custom1); - switch (operation) { - case NODE_INTEGER_MATH_MULTIPLY: { - this->value_task__output__generic_eval( - socket, [&](const Span inputs) -> std::optional { - const std::optional a = inputs[0].get(); - const std::optional b = inputs[1].get(); - if (a == 0 || b == 0) { - return InferenceValue(&scope_.construct(0)); - } - if (a.has_value() && b.has_value()) { - return InferenceValue(&scope_.construct(*a * *b)); - } - return std::nullopt; - }); - break; - } - default: { - this->value_task__output__multi_function_node(socket); - break; - } - } - } - - void value_task__output__boolean_math(const SocketInContext &socket) - { - const NodeInContext node = socket.owner_node(); - const NodeBooleanMathOperation operation = NodeBooleanMathOperation(node->custom1); - - const auto handle_binary_op = - [&](FunctionRef(std::optional, std::optional)> fn) { - this->value_task__output__generic_eval( - socket, [&](const Span inputs) -> std::optional { - const std::optional a = inputs[0].get(); - const std::optional b = inputs[1].get(); - const std::optional result = fn(a, b); - if (result.has_value()) { - return InferenceValue(&scope_.construct(*result)); - } - return std::nullopt; - }); - }; - switch (operation) { - case NODE_BOOLEAN_MATH_AND: { - handle_binary_op( - [](const std::optional &a, const std::optional &b) -> std::optional { - if (a == false || b == false) { - return false; - } - if (a.has_value() && b.has_value()) { - return *a && *b; - } - return std::nullopt; - }); - break; - } - case NODE_BOOLEAN_MATH_OR: { - handle_binary_op( - [](const std::optional &a, const std::optional &b) -> std::optional { - if (a == true || b == true) { - return true; - } - if (a.has_value() && b.has_value()) { - return *a || *b; - } - return std::nullopt; - }); - break; - } - case NODE_BOOLEAN_MATH_NAND: { - handle_binary_op( - [](const std::optional &a, const std::optional &b) -> std::optional { - if (a == false || b == false) { - return true; - } - if (a.has_value() && b.has_value()) { - return !(*a && *b); - } - return std::nullopt; - }); - break; - } - case NODE_BOOLEAN_MATH_NOR: { - handle_binary_op( - [](const std::optional &a, const std::optional &b) -> std::optional { - if (a == true || b == true) { - return false; - } - if (a.has_value() && b.has_value()) { - return !(*a || *b); - } - return std::nullopt; - }); - break; - } - case NODE_BOOLEAN_MATH_IMPLY: { - handle_binary_op( - [](const std::optional &a, const std::optional &b) -> std::optional { - if (a == false || b == true) { - return true; - } - if (a.has_value() && b.has_value()) { - return !*a || *b; - } - return std::nullopt; - }); - break; - } - case NODE_BOOLEAN_MATH_NIMPLY: { - handle_binary_op( - [](const std::optional &a, const std::optional &b) -> std::optional { - if (a == false || b == true) { - return false; - } - if (a.has_value() && b.has_value()) { - return *a && !*b; - } - return std::nullopt; - }); - break; - } - default: { - this->value_task__output__multi_function_node(socket); - break; - } - } - } - - /** - * Assumes that the first available input is a condition that selects one of the remaining inputs - * which is then output. - */ - void value_task__output__generic_switch( - const SocketInContext &socket, - const FunctionRef - is_selected_socket) - { - const NodeInContext node = socket.owner_node(); - BLI_assert(node->input_sockets().size() >= 1); - BLI_assert(node->output_sockets().size() >= 1); - - const SocketInContext condition_socket{ - socket.context, this->get_first_available_bsocket(node->input_sockets())}; - const std::optional condition_value = all_socket_values_.lookup_try( - condition_socket); - if (!condition_value.has_value()) { - this->push_value_task(condition_socket); - return; - } - if (condition_value->is_unknown()) { - /* The condition value is not a simple static value, so the output is unknown. */ - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - Vector selected_inputs; - for (const int input_i : - node->input_sockets().index_range().drop_front(condition_socket->index() + 1)) - { - const SocketInContext input_socket = node.input_socket(input_i); - if (!input_socket->is_available()) { - continue; - } - if (input_socket->type == SOCK_CUSTOM && STREQ(input_socket->idname, "NodeSocketVirtual")) { - continue; - } - const bool is_selected = is_selected_socket(input_socket, *condition_value); - if (is_selected) { - selected_inputs.append(input_socket.socket); - } - } - if (selected_inputs.is_empty()) { - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - if (selected_inputs.size() == 1) { - /* A single input is selected, so just pass through this value without regarding others. */ - const SocketInContext selected_input{socket.context, selected_inputs[0]}; - const std::optional input_value = all_socket_values_.lookup_try( - selected_input); - if (!input_value.has_value()) { - this->push_value_task(selected_input); - return; - } - all_socket_values_.add_new(socket, *input_value); - return; - } - - /* Multiple inputs are selected. */ - if (node->typeinfo->build_multi_function) { - /* Try to compute the output value from the multiple selected inputs. */ - this->value_task__output__multi_function_node(socket); - return; - } - /* Can't compute the output value, so set it to be unknown. */ - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - } - - void value_task__output__generic_eval( - const SocketInContext &socket, - const FunctionRef(Span inputs)> eval_fn) - { - const NodeInContext node = socket.owner_node(); - const int inputs_num = node->input_sockets().size(); - - Array input_values(inputs_num, InferenceValue::Unknown()); - std::optional next_unknown_input_index; - for (const int input_i : IndexRange(inputs_num)) { - const SocketInContext input_socket = node.input_socket(input_i); - if (!input_socket->is_available()) { - continue; - } - const std::optional input_value = all_socket_values_.lookup_try( - input_socket); - if (!input_value.has_value()) { - next_unknown_input_index = input_i; - break; - } - input_values[input_i] = *input_value; - } - const std::optional output_value = eval_fn(input_values); - if (output_value.has_value()) { - /* Was able to compute the output value. */ - all_socket_values_.add_new(socket, *output_value); - return; - } - if (!next_unknown_input_index.has_value()) { - /* The output is still unknown even though we know as much about the inputs as possible - * already. */ - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - /* Request the next input socket. */ - const SocketInContext next_input = node.input_socket(*next_unknown_input_index); - this->push_value_task(next_input); - } - - void value_task__output__multi_function_node(const SocketInContext &socket) - { - const NodeInContext node = socket.owner_node(); - const int inputs_num = node->input_sockets().size(); - - /* Gather all input values are return early if any of them is not known. */ - Vector input_values(inputs_num); - for (const int input_i : IndexRange(inputs_num)) { - const SocketInContext input_socket = node.input_socket(input_i); - const std::optional input_value = all_socket_values_.lookup_try( - input_socket); - if (!input_value.has_value()) { - this->push_value_task(input_socket); - return; - } - if (input_value->is_unknown()) { - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - input_values[input_i] = input_value->data(); - } - - /* Get the multi-function for the node. */ - NodeMultiFunctionBuilder builder{*node.node, node->owner_tree()}; - node->typeinfo->build_multi_function(builder); - const mf::MultiFunction &fn = builder.function(); - - /* We only evaluate the node for a single value here. */ - const IndexMask mask(1); - - /* Prepare parameters for the multi-function evaluation. */ - mf::ParamsBuilder params{fn, &mask}; - for (const int input_i : IndexRange(inputs_num)) { - const SocketInContext input_socket = node.input_socket(input_i); - if (!input_socket->is_available()) { - continue; - } - params.add_readonly_single_input( - GPointer(input_socket->typeinfo->base_cpp_type, input_values[input_i])); - } - for (const int output_i : node->output_sockets().index_range()) { - const SocketInContext output_socket = node.output_socket(output_i); - if (!output_socket->is_available()) { - continue; - } - /* Allocate memory for the output value. */ - const CPPType &base_type = *output_socket->typeinfo->base_cpp_type; - void *value = scope_.allocate_owned(base_type); - params.add_uninitialized_single_output(GMutableSpan(base_type, value, 1)); - all_socket_values_.add_new(output_socket, InferenceValue(value)); - } - mf::ContextBuilder context; - /* Actually evaluate the multi-function. The outputs will be written into the memory allocated - * earlier, which has been added to #all_socket_values_ already. */ - fn.call(mask, params, context); - } - - void value_task__output__muted_node(const SocketInContext &socket) - { - const NodeInContext node = socket.owner_node(); - - SocketInContext input_socket; - for (const bNodeLink &internal_link : node->internal_links()) { - if (internal_link.tosock == socket.socket) { - input_socket = SocketInContext{socket.context, internal_link.fromsock}; - break; - } - } - if (!input_socket) { - /* The output does not have an internal link to an input. */ - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - const std::optional input_value = all_socket_values_.lookup_try(input_socket); - if (!input_value.has_value()) { - this->push_value_task(input_socket); - return; - } - const void *converted_value = this->convert_type_if_necessary( - input_value->data(), *input_socket.socket, *socket.socket); - all_socket_values_.add_new(socket, InferenceValue(converted_value)); - } - - void value_task__input(const SocketInContext &socket) - { - if (socket->is_multi_input()) { - /* Can't know the single value of a multi-input. */ - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - const bNodeLink *source_link = nullptr; - const Span connected_links = socket->directly_linked_links(); - for (const bNodeLink *link : connected_links) { - if (!link->is_used()) { - continue; - } - if (link->fromnode->is_dangling_reroute()) { - continue; - } - source_link = link; - break; - } - if (!source_link) { - this->value_task__input__unlinked(socket); - return; - } - this->value_task__input__linked({socket.context, source_link->fromsock}, socket); - } - - void value_task__input__unlinked(const SocketInContext &socket) - { - if (this->treat_socket_as_unknown(socket)) { - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - if (animated_sockets_.contains(socket.socket)) { - /* The value of animated sockets is not known statically. */ - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - if (const SocketDeclaration *socket_decl = socket.socket->runtime->declaration) { - if (socket_decl->input_field_type == InputSocketFieldType::Implicit) { - /* Implicit fields inputs don't have a single static value. */ - all_socket_values_.add_new(socket, InferenceValue::Unknown()); - return; - } - } - - void *value_buffer = scope_.allocate_owned(*socket->typeinfo->base_cpp_type); - socket->typeinfo->get_base_cpp_value(socket->default_value, value_buffer); - all_socket_values_.add_new(socket, InferenceValue(value_buffer)); - } - - void value_task__input__linked(const SocketInContext &from_socket, - const SocketInContext &to_socket) - { - const std::optional from_value = all_socket_values_.lookup_try(from_socket); - if (!from_value.has_value()) { - this->push_value_task(from_socket); - return; - } - const void *converted_value = this->convert_type_if_necessary( - from_value->data(), *from_socket.socket, *to_socket.socket); - all_socket_values_.add_new(to_socket, InferenceValue(converted_value)); - } - - const void *convert_type_if_necessary(const void *src, - const bNodeSocket &from_socket, - const bNodeSocket &to_socket) - { - if (!src) { - return nullptr; - } - const CPPType *from_type = from_socket.typeinfo->base_cpp_type; - const CPPType *to_type = to_socket.typeinfo->base_cpp_type; - if (from_type == to_type) { - return src; - } - if (!to_type) { - return nullptr; - } - const bke::DataTypeConversions &conversions = bke::get_implicit_type_conversions(); - if (!conversions.is_convertible(*from_type, *to_type)) { - return nullptr; - } - void *dst = scope_.allocate_owned(*to_type); - conversions.convert_to_uninitialized(*from_type, *to_type, src, dst); - return dst; - } - - static bool switch__is_socket_selected(const SocketInContext &socket, - const InferenceValue &condition) - { - const bool is_true = condition.get_known(); - const int selected_index = is_true ? 2 : 1; - return socket->index() == selected_index; - } - - static bool index_switch__is_socket_selected(const SocketInContext &socket, - const InferenceValue &condition) - { - const int index = condition.get_known(); - return socket->index() == index + 1; - } - - static bool menu_switch__is_socket_selected(const SocketInContext &socket, - const InferenceValue &condition) - { - const NodeMenuSwitch &storage = *static_cast( - socket->owner_node().storage); - const int menu_value = condition.get_known(); - const NodeEnumItem &item = storage.enum_definition.items_array[socket->index() - 1]; - return menu_value == item.identifier; - } - - static bool mix_node__is_socket_selected(const SocketInContext &socket, - const InferenceValue &condition) - { - const NodeShaderMix &storage = *static_cast( - socket.owner_node()->storage); - if (storage.data_type == SOCK_RGBA && storage.blend_type != MA_RAMP_BLEND) { - return true; - } - - const bool clamp_factor = storage.clamp_factor != 0; - bool only_a = false; - bool only_b = false; - if (storage.data_type == SOCK_VECTOR && storage.factor_mode == NODE_MIX_MODE_NON_UNIFORM) { - const float3 mix_factor = condition.get_known(); - if (clamp_factor) { - only_a = mix_factor.x <= 0.0f && mix_factor.y <= 0.0f && mix_factor.z <= 0.0f; - only_b = mix_factor.x >= 1.0f && mix_factor.y >= 1.0f && mix_factor.z >= 1.0f; - } - else { - only_a = float3{0.0f, 0.0f, 0.0f} == mix_factor; - only_b = float3{1.0f, 1.0f, 1.0f} == mix_factor; - } - } - else { - const float mix_factor = condition.get_known(); - if (clamp_factor) { - only_a = mix_factor <= 0.0f; - only_b = mix_factor >= 1.0f; - } - else { - only_a = mix_factor == 0.0f; - only_b = mix_factor == 1.0f; - } - } - if (only_a) { - if (STREQ(socket->name, "B")) { - return false; - } - } - if (only_b) { - if (STREQ(socket->name, "A")) { - return false; - } - } - return true; - } - - static bool shader_mix_node__is_socket_selected(const SocketInContext &socket, - const InferenceValue &condition) - { - const float mix_factor = condition.get_known(); - if (mix_factor == 0.0f) { - if (STREQ(socket->identifier, "Shader_001")) { - return false; - } - } - else if (mix_factor == 1.0f) { - if (STREQ(socket->identifier, "Shader")) { - return false; - } - } - return true; - } - void push_usage_task(const SocketInContext &socket) { usage_tasks_.push(socket); } - void push_value_task(const SocketInContext &socket) + static const bNodeSocket *get_first_available_bsocket(const Span sockets) { - value_tasks_.push(socket); - } - - void ensure_animation_data_processed(const bNodeTree &tree) - { - if (!trees_with_handled_animation_data_.add(&tree)) { - return; - } - if (!tree.adt) { - return; - } - - static std::regex pattern(R"#(nodes\["(.*)"\].inputs\[(\d+)\].default_value)#"); - MultiValueMap animated_inputs_by_node_name; - auto handle_rna_path = [&](const char *rna_path) { - std::cmatch match; - if (!std::regex_match(rna_path, match, pattern)) { - return; - } - const StringRef node_name{match[1].first, match[1].second - match[1].first}; - const int socket_index = std::stoi(match[2]); - animated_inputs_by_node_name.add(node_name, socket_index); - }; - - /* Gather all inputs controlled by fcurves. */ - if (tree.adt->action) { - animrig::foreach_fcurve_in_action_slot( - tree.adt->action->wrap(), tree.adt->slot_handle, [&](const FCurve &fcurve) { - handle_rna_path(fcurve.rna_path); - }); - } - /* Gather all inputs controlled by drivers. */ - LISTBASE_FOREACH (const FCurve *, driver, &tree.adt->drivers) { - handle_rna_path(driver->rna_path); - } - - /* Actually find the #bNodeSocket for each controlled input. */ - if (!animated_inputs_by_node_name.is_empty()) { - for (const bNode *node : tree.all_nodes()) { - const Span animated_inputs = animated_inputs_by_node_name.lookup(node->name); - const Span input_sockets = node->input_sockets(); - for (const int socket_index : animated_inputs) { - if (socket_index < 0 || socket_index >= input_sockets.size()) { - /* This can happen when the animation data is not immediately updated after a socket is - * removed. */ - continue; - } - const bNodeSocket &socket = *input_sockets[socket_index]; - animated_sockets_.add(&socket); - } + for (const bNodeSocket *socket : sockets) { + if (socket->is_available()) { + return socket; } } - } - - bool treat_socket_as_unknown(const SocketInContext &socket) const - { - if (!top_level_ignored_inputs_.has_value()) { - return false; - } - if (socket.context) { - return false; - } - if (socket->is_output()) { - return false; - } - return (*top_level_ignored_inputs_)[socket->index_in_all_inputs()]; + return nullptr; } }; @@ -1377,9 +536,12 @@ Array infer_all_input_sockets_usage(const bNodeTree &tree) const Span all_input_sockets = tree.all_input_sockets(); Array all_usages(all_input_sockets.size()); + ResourceScope scope; + bke::ComputeContextCache compute_context_cache; + { /* Find actual socket usages. */ - SocketUsageInferencer inferencer{tree, std::nullopt}; + SocketUsageInferencer inferencer{tree, std::nullopt, scope, compute_context_cache}; inferencer.mark_top_level_node_outputs_as_used(); for (const int i : all_input_sockets.index_range()) { const bNodeSocket &socket = *all_input_sockets[i]; @@ -1396,8 +558,10 @@ Array infer_all_input_sockets_usage(const bNodeTree &tree) only_controllers_used[i] = !input_may_affect_visibility(socket); } }); - SocketUsageInferencer inferencer_all_unknown{tree, std::nullopt, all_ignored_inputs}; - SocketUsageInferencer inferencer_only_controllers{tree, std::nullopt, only_controllers_used}; + SocketUsageInferencer inferencer_all_unknown{ + tree, std::nullopt, scope, compute_context_cache, all_ignored_inputs}; + SocketUsageInferencer inferencer_only_controllers{ + tree, std::nullopt, scope, compute_context_cache, only_controllers_used}; inferencer_all_unknown.mark_top_level_node_outputs_as_used(); inferencer_only_controllers.mark_top_level_node_outputs_as_used(); for (const int i : all_input_sockets.index_range()) { @@ -1431,9 +595,12 @@ void infer_group_interface_inputs_usage(const bNodeTree &group, default_usage.is_visible = true; r_input_usages.fill(default_usage); + ResourceScope scope; + bke::ComputeContextCache compute_context_cache; + { /* Detect actually used inputs. */ - SocketUsageInferencer inferencer{group, group_input_values}; + SocketUsageInferencer inferencer{group, group_input_values, scope, compute_context_cache}; for (const bNode *node : group.group_input_nodes()) { for (const int i : group.interface_inputs().index_range()) { const bNodeSocket &socket = node->output_socket(i); @@ -1465,8 +632,10 @@ void infer_group_interface_inputs_usage(const bNodeTree &group, /* If there is no visibility controller inputs, all inputs are always visible. */ return; } - SocketUsageInferencer inferencer_all_unknown{group, inputs_all_unknown}; - SocketUsageInferencer inferencer_only_controllers{group, inputs_only_controllers}; + SocketUsageInferencer inferencer_all_unknown{ + group, inputs_all_unknown, scope, compute_context_cache}; + SocketUsageInferencer inferencer_only_controllers{ + group, inputs_only_controllers, scope, compute_context_cache}; for (const int i : group.interface_inputs().index_range()) { if (r_input_usages[i].is_used) { /* Used inputs are always visible. */ diff --git a/source/blender/nodes/intern/socket_value_inference.cc b/source/blender/nodes/intern/socket_value_inference.cc new file mode 100644 index 00000000000..13c2cf1851f --- /dev/null +++ b/source/blender/nodes/intern/socket_value_inference.cc @@ -0,0 +1,940 @@ +/* SPDX-FileCopyrightText: 2024 Blender Authors + * + * SPDX-License-Identifier: GPL-2.0-or-later */ + +#include + +#include "NOD_menu_value.hh" +#include "NOD_multi_function.hh" +#include "NOD_node_declaration.hh" +#include "NOD_node_in_compute_context.hh" +#include "NOD_socket_usage_inference.hh" + +#include "DNA_anim_types.h" +#include "DNA_material_types.h" +#include "DNA_node_types.h" + +#include "BKE_compute_context_cache.hh" +#include "BKE_compute_contexts.hh" +#include "BKE_node_legacy_types.hh" +#include "BKE_node_runtime.hh" +#include "BKE_type_conversions.hh" + +#include "ANIM_action.hh" +#include "ANIM_action_iterators.hh" + +#include "BLI_listbase.h" +#include "BLI_stack.hh" + +namespace blender::nodes { + +class SocketValueInferencerImpl { + private: + ResourceScope &scope_; + bke::ComputeContextCache &compute_context_cache_; + + Stack value_tasks_; + /** + * Once a socket value has been determined, it is added to this map. Note that a socket value may + * be determined to be unknown because it depends on values that are not known statically. + */ + Map all_socket_values_; + + /** + * All sockets that have animation data and thus their value is not fixed statically. This can + * contain sockets from multiple different trees. + */ + Set animated_sockets_; + Set trees_with_handled_animation_data_; + std::optional> top_level_ignored_inputs_; + + const bNodeTree &root_tree_; + + public: + SocketValueInferencerImpl(const bNodeTree &tree, + ResourceScope &scope, + bke::ComputeContextCache &compute_context_cache, + const std::optional> tree_input_values, + const std::optional> top_level_ignored_inputs) + : scope_(scope), + compute_context_cache_(compute_context_cache), + top_level_ignored_inputs_(top_level_ignored_inputs), + root_tree_(tree) + { + root_tree_.ensure_topology_cache(); + root_tree_.ensure_interface_cache(); + this->ensure_animation_data_processed(root_tree_); + + for (const bNode *node : root_tree_.group_input_nodes()) { + for (const int i : root_tree_.interface_inputs().index_range()) { + const bNodeSocket &socket = node->output_socket(i); + if (!socket.is_directly_linked()) { + /* This socket is not linked, hence it's value is never used. Thus we don't have to add + * it to #all_socket_values_. This optimization helps a lot when the node group has a + * very large number of inputs and group input nodes. */ + continue; + } + const SocketInContext socket_in_context{nullptr, &socket}; + const void *input_value = nullptr; + if (!this->treat_socket_as_unknown(socket_in_context)) { + if (tree_input_values.has_value()) { + input_value = (*tree_input_values)[i].get(); + } + } + all_socket_values_.add_new(socket_in_context, InferenceValue(input_value)); + } + } + } + + InferenceValue get_socket_value(const SocketInContext &socket) + { + const std::optional value = all_socket_values_.lookup_try(socket); + if (value.has_value()) { + return *value; + } + if (socket->owner_tree().has_available_link_cycle()) { + return InferenceValue::Unknown(); + } + + BLI_assert(value_tasks_.is_empty()); + value_tasks_.push(socket); + + while (!value_tasks_.is_empty()) { + const SocketInContext &socket = value_tasks_.peek(); + this->value_task(socket); + if (&socket == &value_tasks_.peek()) { + /* The task is finished if it hasn't added any new task it depends on. */ + value_tasks_.pop(); + } + } + + return all_socket_values_.lookup(socket); + } + + private: + void value_task(const SocketInContext &socket) + { + if (all_socket_values_.contains(socket)) { + /* Task is done already. */ + return; + } + const bNode &node = socket->owner_node(); + if (node.is_undefined() && !node.is_custom_group()) { + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + const CPPType *base_type = socket->typeinfo->base_cpp_type; + if (!base_type) { + /* The socket type is unknown for some reason (maybe a socket type from the future?). */ + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + if (socket->is_input()) { + this->value_task__input(socket); + } + else { + this->value_task__output(socket); + } + } + + void value_task__output(const SocketInContext &socket) + { + const NodeInContext node = socket.owner_node(); + if (node->is_muted()) { + this->value_task__output__muted_node(socket); + return; + } + switch (node->type_legacy) { + case NODE_GROUP: + case NODE_CUSTOM_GROUP: { + this->value_task__output__group_node(socket); + return; + } + case NODE_GROUP_INPUT: { + this->value_task__output__group_input_node(socket); + return; + } + case NODE_REROUTE: { + this->value_task__output__reroute_node(socket); + return; + } + case GEO_NODE_SWITCH: { + this->value_task__output__generic_switch( + socket, switch_node_inference_utils::is_socket_selected__switch); + return; + } + case GEO_NODE_INDEX_SWITCH: { + this->value_task__output__generic_switch( + socket, switch_node_inference_utils::is_socket_selected__index_switch); + return; + } + case GEO_NODE_MENU_SWITCH: { + this->value_task__output__generic_switch( + socket, switch_node_inference_utils::is_socket_selected__menu_switch); + return; + } + case SH_NODE_MIX: { + this->value_task__output__generic_switch( + socket, switch_node_inference_utils::is_socket_selected__mix_node); + return; + } + case SH_NODE_MIX_SHADER: { + this->value_task__output__generic_switch( + socket, switch_node_inference_utils::is_socket_selected__shader_mix_node); + return; + } + case SH_NODE_MATH: { + this->value_task__output__float_math(socket); + return; + } + case SH_NODE_VECTOR_MATH: { + this->value_task__output__vector_math(socket); + return; + } + case FN_NODE_INTEGER_MATH: { + this->value_task__output__integer_math(socket); + return; + } + case FN_NODE_BOOLEAN_MATH: { + this->value_task__output__boolean_math(socket); + return; + } + default: { + if (node->typeinfo->build_multi_function) { + this->value_task__output__multi_function_node(socket); + return; + } + break; + } + } + /* If none of the above cases work, the socket value is set to null which means that it is + * unknown/dynamic. */ + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + } + + void value_task__output__group_node(const SocketInContext &socket) + { + const NodeInContext node = socket.owner_node(); + const bNodeTree *group = reinterpret_cast(node->id); + if (!group || ID_MISSING(&group->id)) { + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + group->ensure_topology_cache(); + if (group->has_available_link_cycle()) { + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + this->ensure_animation_data_processed(*group); + const bNode *group_output_node = group->group_output_node(); + if (!group_output_node) { + /* Can't compute the value if the group does not have an output node. */ + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + const ComputeContext &group_context = compute_context_cache_.for_group_node( + socket.context, node->identifier, &node->owner_tree()); + const SocketInContext socket_in_group{&group_context, + &group_output_node->input_socket(socket->index())}; + const std::optional value = all_socket_values_.lookup_try(socket_in_group); + if (!value.has_value()) { + this->push_value_task(socket_in_group); + return; + } + all_socket_values_.add_new(socket, *value); + } + + void value_task__output__group_input_node(const SocketInContext &socket) + { + /* Group inputs for the root context should be initialized already. */ + BLI_assert(socket.context != nullptr); + + const bke::GroupNodeComputeContext &group_context = + *static_cast(socket.context); + const SocketInContext group_node_input{group_context.parent(), + &group_context.node()->input_socket(socket->index())}; + const std::optional value = all_socket_values_.lookup_try(group_node_input); + if (!value.has_value()) { + this->push_value_task(group_node_input); + return; + } + all_socket_values_.add_new(socket, *value); + } + + void value_task__output__reroute_node(const SocketInContext &socket) + { + const SocketInContext input_socket = socket.owner_node().input_socket(0); + const std::optional value = all_socket_values_.lookup_try(input_socket); + if (!value.has_value()) { + this->push_value_task(input_socket); + return; + } + all_socket_values_.add_new(socket, *value); + } + + void value_task__output__float_math(const SocketInContext &socket) + { + const NodeInContext node = socket.owner_node(); + const NodeMathOperation operation = NodeMathOperation(node->custom1); + switch (operation) { + case NODE_MATH_MULTIPLY: { + this->value_task__output__generic_eval( + socket, [&](const Span inputs) -> std::optional { + const std::optional a = inputs[0].get(); + const std::optional b = inputs[1].get(); + if (a == 0.0f || b == 0.0f) { + return InferenceValue(&scope_.construct(0.0f)); + } + if (a.has_value() && b.has_value()) { + return InferenceValue(&scope_.construct(*a * *b)); + } + return std::nullopt; + }); + break; + } + default: { + this->value_task__output__multi_function_node(socket); + break; + } + } + } + + void value_task__output__vector_math(const SocketInContext &socket) + { + const NodeInContext node = socket.owner_node(); + const NodeVectorMathOperation operation = NodeVectorMathOperation(node->custom1); + switch (operation) { + case NODE_VECTOR_MATH_MULTIPLY: { + this->value_task__output__generic_eval( + socket, [&](const Span inputs) -> std::optional { + const std::optional a = inputs[0].get(); + const std::optional b = inputs[1].get(); + if (a == float3(0.0f) || b == float3(0.0f)) { + return InferenceValue(&scope_.construct(0.0f)); + } + if (a.has_value() && b.has_value()) { + return InferenceValue(&scope_.construct(*a * *b)); + } + return std::nullopt; + }); + break; + } + case NODE_VECTOR_MATH_SCALE: { + this->value_task__output__generic_eval( + socket, [&](const Span inputs) -> std::optional { + const std::optional a = inputs[0].get(); + const std::optional scale = inputs[3].get(); + if (a == float3(0.0f) || scale == 0.0f) { + return InferenceValue(&scope_.construct(0.0f)); + } + if (a.has_value() && scale.has_value()) { + return InferenceValue(&scope_.construct(*a * *scale)); + } + return std::nullopt; + }); + break; + } + default: { + this->value_task__output__multi_function_node(socket); + break; + } + } + } + + void value_task__output__integer_math(const SocketInContext &socket) + { + const NodeInContext node = socket.owner_node(); + const NodeIntegerMathOperation operation = NodeIntegerMathOperation(node->custom1); + switch (operation) { + case NODE_INTEGER_MATH_MULTIPLY: { + this->value_task__output__generic_eval( + socket, [&](const Span inputs) -> std::optional { + const std::optional a = inputs[0].get(); + const std::optional b = inputs[1].get(); + if (a == 0 || b == 0) { + return InferenceValue(&scope_.construct(0)); + } + if (a.has_value() && b.has_value()) { + return InferenceValue(&scope_.construct(*a * *b)); + } + return std::nullopt; + }); + break; + } + default: { + this->value_task__output__multi_function_node(socket); + break; + } + } + } + + void value_task__output__boolean_math(const SocketInContext &socket) + { + const NodeInContext node = socket.owner_node(); + const NodeBooleanMathOperation operation = NodeBooleanMathOperation(node->custom1); + + const auto handle_binary_op = + [&](FunctionRef(std::optional, std::optional)> fn) { + this->value_task__output__generic_eval( + socket, [&](const Span inputs) -> std::optional { + const std::optional a = inputs[0].get(); + const std::optional b = inputs[1].get(); + const std::optional result = fn(a, b); + if (result.has_value()) { + return InferenceValue(&scope_.construct(*result)); + } + return std::nullopt; + }); + }; + switch (operation) { + case NODE_BOOLEAN_MATH_AND: { + handle_binary_op( + [](const std::optional &a, const std::optional &b) -> std::optional { + if (a == false || b == false) { + return false; + } + if (a.has_value() && b.has_value()) { + return *a && *b; + } + return std::nullopt; + }); + break; + } + case NODE_BOOLEAN_MATH_OR: { + handle_binary_op( + [](const std::optional &a, const std::optional &b) -> std::optional { + if (a == true || b == true) { + return true; + } + if (a.has_value() && b.has_value()) { + return *a || *b; + } + return std::nullopt; + }); + break; + } + case NODE_BOOLEAN_MATH_NAND: { + handle_binary_op( + [](const std::optional &a, const std::optional &b) -> std::optional { + if (a == false || b == false) { + return true; + } + if (a.has_value() && b.has_value()) { + return !(*a && *b); + } + return std::nullopt; + }); + break; + } + case NODE_BOOLEAN_MATH_NOR: { + handle_binary_op( + [](const std::optional &a, const std::optional &b) -> std::optional { + if (a == true || b == true) { + return false; + } + if (a.has_value() && b.has_value()) { + return !(*a || *b); + } + return std::nullopt; + }); + break; + } + case NODE_BOOLEAN_MATH_IMPLY: { + handle_binary_op( + [](const std::optional &a, const std::optional &b) -> std::optional { + if (a == false || b == true) { + return true; + } + if (a.has_value() && b.has_value()) { + return !*a || *b; + } + return std::nullopt; + }); + break; + } + case NODE_BOOLEAN_MATH_NIMPLY: { + handle_binary_op( + [](const std::optional &a, const std::optional &b) -> std::optional { + if (a == false || b == true) { + return false; + } + if (a.has_value() && b.has_value()) { + return *a && !*b; + } + return std::nullopt; + }); + break; + } + default: { + this->value_task__output__multi_function_node(socket); + break; + } + } + } + + /** + * Assumes that the first available input is a condition that selects one of the remaining inputs + * which is then output. + */ + void value_task__output__generic_switch( + const SocketInContext &socket, + const FunctionRef + is_selected_socket) + { + const NodeInContext node = socket.owner_node(); + BLI_assert(node->input_sockets().size() >= 1); + BLI_assert(node->output_sockets().size() >= 1); + + const SocketInContext condition_socket{socket.context, + get_first_available_bsocket(node->input_sockets())}; + const std::optional condition_value = all_socket_values_.lookup_try( + condition_socket); + if (!condition_value.has_value()) { + this->push_value_task(condition_socket); + return; + } + if (condition_value->is_unknown()) { + /* The condition value is not a simple static value, so the output is unknown. */ + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + Vector selected_inputs; + for (const int input_i : + node->input_sockets().index_range().drop_front(condition_socket->index() + 1)) + { + const SocketInContext input_socket = node.input_socket(input_i); + if (!input_socket->is_available()) { + continue; + } + if (input_socket->type == SOCK_CUSTOM && STREQ(input_socket->idname, "NodeSocketVirtual")) { + continue; + } + const bool is_selected = is_selected_socket(input_socket, *condition_value); + if (is_selected) { + selected_inputs.append(input_socket.socket); + } + } + if (selected_inputs.is_empty()) { + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + if (selected_inputs.size() == 1) { + /* A single input is selected, so just pass through this value without regarding others. */ + const SocketInContext selected_input{socket.context, selected_inputs[0]}; + const std::optional input_value = all_socket_values_.lookup_try( + selected_input); + if (!input_value.has_value()) { + this->push_value_task(selected_input); + return; + } + all_socket_values_.add_new(socket, *input_value); + return; + } + + /* Multiple inputs are selected. */ + if (node->typeinfo->build_multi_function) { + /* Try to compute the output value from the multiple selected inputs. */ + this->value_task__output__multi_function_node(socket); + return; + } + /* Can't compute the output value, so set it to be unknown. */ + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + } + + void value_task__output__generic_eval( + const SocketInContext &socket, + const FunctionRef(Span inputs)> eval_fn) + { + const NodeInContext node = socket.owner_node(); + const int inputs_num = node->input_sockets().size(); + + Array input_values(inputs_num, InferenceValue::Unknown()); + std::optional next_unknown_input_index; + for (const int input_i : IndexRange(inputs_num)) { + const SocketInContext input_socket = node.input_socket(input_i); + if (!input_socket->is_available()) { + continue; + } + const std::optional input_value = all_socket_values_.lookup_try( + input_socket); + if (!input_value.has_value()) { + next_unknown_input_index = input_i; + break; + } + input_values[input_i] = *input_value; + } + const std::optional output_value = eval_fn(input_values); + if (output_value.has_value()) { + /* Was able to compute the output value. */ + all_socket_values_.add_new(socket, *output_value); + return; + } + if (!next_unknown_input_index.has_value()) { + /* The output is still unknown even though we know as much about the inputs as possible + * already. */ + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + /* Request the next input socket. */ + const SocketInContext next_input = node.input_socket(*next_unknown_input_index); + this->push_value_task(next_input); + } + + void value_task__output__multi_function_node(const SocketInContext &socket) + { + const NodeInContext node = socket.owner_node(); + const int inputs_num = node->input_sockets().size(); + + /* Gather all input values are return early if any of them is not known. */ + Vector input_values(inputs_num); + for (const int input_i : IndexRange(inputs_num)) { + const SocketInContext input_socket = node.input_socket(input_i); + const std::optional input_value = all_socket_values_.lookup_try( + input_socket); + if (!input_value.has_value()) { + this->push_value_task(input_socket); + return; + } + if (input_value->is_unknown()) { + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + input_values[input_i] = input_value->data(); + } + + /* Get the multi-function for the node. */ + NodeMultiFunctionBuilder builder{*node.node, node->owner_tree()}; + node->typeinfo->build_multi_function(builder); + const mf::MultiFunction &fn = builder.function(); + + /* We only evaluate the node for a single value here. */ + const IndexMask mask(1); + + /* Prepare parameters for the multi-function evaluation. */ + mf::ParamsBuilder params{fn, &mask}; + for (const int input_i : IndexRange(inputs_num)) { + const SocketInContext input_socket = node.input_socket(input_i); + if (!input_socket->is_available()) { + continue; + } + params.add_readonly_single_input( + GPointer(input_socket->typeinfo->base_cpp_type, input_values[input_i])); + } + for (const int output_i : node->output_sockets().index_range()) { + const SocketInContext output_socket = node.output_socket(output_i); + if (!output_socket->is_available()) { + continue; + } + /* Allocate memory for the output value. */ + const CPPType &base_type = *output_socket->typeinfo->base_cpp_type; + void *value = scope_.allocate_owned(base_type); + params.add_uninitialized_single_output(GMutableSpan(base_type, value, 1)); + all_socket_values_.add_new(output_socket, InferenceValue(value)); + } + mf::ContextBuilder context; + /* Actually evaluate the multi-function. The outputs will be written into the memory allocated + * earlier, which has been added to #all_socket_values_ already. */ + fn.call(mask, params, context); + } + + void value_task__output__muted_node(const SocketInContext &socket) + { + const NodeInContext node = socket.owner_node(); + + SocketInContext input_socket; + for (const bNodeLink &internal_link : node->internal_links()) { + if (internal_link.tosock == socket.socket) { + input_socket = SocketInContext{socket.context, internal_link.fromsock}; + break; + } + } + if (!input_socket) { + /* The output does not have an internal link to an input. */ + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + const std::optional input_value = all_socket_values_.lookup_try(input_socket); + if (!input_value.has_value()) { + this->push_value_task(input_socket); + return; + } + const void *converted_value = this->convert_type_if_necessary( + input_value->data(), *input_socket.socket, *socket.socket); + all_socket_values_.add_new(socket, InferenceValue(converted_value)); + } + + void value_task__input(const SocketInContext &socket) + { + if (socket->is_multi_input()) { + /* Can't know the single value of a multi-input. */ + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + const bNodeLink *source_link = nullptr; + const Span connected_links = socket->directly_linked_links(); + for (const bNodeLink *link : connected_links) { + if (!link->is_used()) { + continue; + } + if (link->fromnode->is_dangling_reroute()) { + continue; + } + source_link = link; + break; + } + if (!source_link) { + this->value_task__input__unlinked(socket); + return; + } + this->value_task__input__linked({socket.context, source_link->fromsock}, socket); + } + + void value_task__input__unlinked(const SocketInContext &socket) + { + if (this->treat_socket_as_unknown(socket)) { + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + if (animated_sockets_.contains(socket.socket)) { + /* The value of animated sockets is not known statically. */ + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + if (const SocketDeclaration *socket_decl = socket.socket->runtime->declaration) { + if (socket_decl->input_field_type == InputSocketFieldType::Implicit) { + /* Implicit fields inputs don't have a single static value. */ + all_socket_values_.add_new(socket, InferenceValue::Unknown()); + return; + } + } + + void *value_buffer = scope_.allocate_owned(*socket->typeinfo->base_cpp_type); + socket->typeinfo->get_base_cpp_value(socket->default_value, value_buffer); + all_socket_values_.add_new(socket, InferenceValue(value_buffer)); + } + + void value_task__input__linked(const SocketInContext &from_socket, + const SocketInContext &to_socket) + { + const std::optional from_value = all_socket_values_.lookup_try(from_socket); + if (!from_value.has_value()) { + this->push_value_task(from_socket); + return; + } + const void *converted_value = this->convert_type_if_necessary( + from_value->data(), *from_socket.socket, *to_socket.socket); + all_socket_values_.add_new(to_socket, InferenceValue(converted_value)); + } + + const void *convert_type_if_necessary(const void *src, + const bNodeSocket &from_socket, + const bNodeSocket &to_socket) + { + if (!src) { + return nullptr; + } + const CPPType *from_type = from_socket.typeinfo->base_cpp_type; + const CPPType *to_type = to_socket.typeinfo->base_cpp_type; + if (from_type == to_type) { + return src; + } + if (!to_type) { + return nullptr; + } + const bke::DataTypeConversions &conversions = bke::get_implicit_type_conversions(); + if (!conversions.is_convertible(*from_type, *to_type)) { + return nullptr; + } + void *dst = scope_.allocate_owned(*to_type); + conversions.convert_to_uninitialized(*from_type, *to_type, src, dst); + return dst; + } + + bool treat_socket_as_unknown(const SocketInContext &socket) const + { + if (!top_level_ignored_inputs_.has_value()) { + return false; + } + if (socket.context) { + return false; + } + if (socket->is_output()) { + return false; + } + return (*top_level_ignored_inputs_)[socket->index_in_all_inputs()]; + } + + void ensure_animation_data_processed(const bNodeTree &tree) + { + if (!trees_with_handled_animation_data_.add(&tree)) { + return; + } + if (!tree.adt) { + return; + } + + static std::regex pattern(R"#(nodes\["(.*)"\].inputs\[(\d+)\].default_value)#"); + MultiValueMap animated_inputs_by_node_name; + auto handle_rna_path = [&](const char *rna_path) { + std::cmatch match; + if (!std::regex_match(rna_path, match, pattern)) { + return; + } + const StringRef node_name{match[1].first, match[1].second - match[1].first}; + const int socket_index = std::stoi(match[2]); + animated_inputs_by_node_name.add(node_name, socket_index); + }; + + /* Gather all inputs controlled by fcurves. */ + if (tree.adt->action) { + animrig::foreach_fcurve_in_action_slot( + tree.adt->action->wrap(), tree.adt->slot_handle, [&](const FCurve &fcurve) { + handle_rna_path(fcurve.rna_path); + }); + } + /* Gather all inputs controlled by drivers. */ + LISTBASE_FOREACH (const FCurve *, driver, &tree.adt->drivers) { + handle_rna_path(driver->rna_path); + } + + /* Actually find the #bNodeSocket for each controlled input. */ + if (!animated_inputs_by_node_name.is_empty()) { + for (const bNode *node : tree.all_nodes()) { + const Span animated_inputs = animated_inputs_by_node_name.lookup(node->name); + const Span input_sockets = node->input_sockets(); + for (const int socket_index : animated_inputs) { + if (socket_index < 0 || socket_index >= input_sockets.size()) { + /* This can happen when the animation data is not immediately updated after a socket is + * removed. */ + continue; + } + const bNodeSocket &socket = *input_sockets[socket_index]; + animated_sockets_.add(&socket); + } + } + } + } + + void push_value_task(const SocketInContext &socket) + { + value_tasks_.push(socket); + } + + static const bNodeSocket *get_first_available_bsocket(const Span sockets) + { + for (const bNodeSocket *socket : sockets) { + if (socket->is_available()) { + return socket; + } + } + return nullptr; + } +}; + +SocketValueInferencer::SocketValueInferencer( + const bNodeTree &tree, + ResourceScope &scope, + bke::ComputeContextCache &compute_context_cache, + const std::optional> tree_input_values, + const std::optional> top_level_ignored_inputs) + : impl_(scope.construct( + tree, scope, compute_context_cache, tree_input_values, top_level_ignored_inputs)) +{ +} + +InferenceValue SocketValueInferencer::get_socket_value(const SocketInContext &socket) +{ + return impl_.get_socket_value(socket); +} + +namespace switch_node_inference_utils { + +bool is_socket_selected__switch(const SocketInContext &socket, const InferenceValue &condition) +{ + const bool is_true = condition.get_known(); + const int selected_index = is_true ? 2 : 1; + return socket->index() == selected_index; +} + +bool is_socket_selected__index_switch(const SocketInContext &socket, + const InferenceValue &condition) +{ + const int index = condition.get_known(); + return socket->index() == index + 1; +} + +bool is_socket_selected__menu_switch(const SocketInContext &socket, + const InferenceValue &condition) +{ + const NodeMenuSwitch &storage = *static_cast( + socket->owner_node().storage); + const int menu_value = condition.get_known(); + const NodeEnumItem &item = storage.enum_definition.items_array[socket->index() - 1]; + return menu_value == item.identifier; +} + +bool is_socket_selected__mix_node(const SocketInContext &socket, const InferenceValue &condition) +{ + const NodeShaderMix &storage = *static_cast(socket.owner_node()->storage); + if (storage.data_type == SOCK_RGBA && storage.blend_type != MA_RAMP_BLEND) { + return true; + } + + const bool clamp_factor = storage.clamp_factor != 0; + bool only_a = false; + bool only_b = false; + if (storage.data_type == SOCK_VECTOR && storage.factor_mode == NODE_MIX_MODE_NON_UNIFORM) { + const float3 mix_factor = condition.get_known(); + if (clamp_factor) { + only_a = mix_factor.x <= 0.0f && mix_factor.y <= 0.0f && mix_factor.z <= 0.0f; + only_b = mix_factor.x >= 1.0f && mix_factor.y >= 1.0f && mix_factor.z >= 1.0f; + } + else { + only_a = float3{0.0f, 0.0f, 0.0f} == mix_factor; + only_b = float3{1.0f, 1.0f, 1.0f} == mix_factor; + } + } + else { + const float mix_factor = condition.get_known(); + if (clamp_factor) { + only_a = mix_factor <= 0.0f; + only_b = mix_factor >= 1.0f; + } + else { + only_a = mix_factor == 0.0f; + only_b = mix_factor == 1.0f; + } + } + if (only_a) { + if (STREQ(socket->name, "B")) { + return false; + } + } + if (only_b) { + if (STREQ(socket->name, "A")) { + return false; + } + } + return true; +} + +bool is_socket_selected__shader_mix_node(const SocketInContext &socket, + const InferenceValue &condition) +{ + const float mix_factor = condition.get_known(); + if (mix_factor == 0.0f) { + if (STREQ(socket->identifier, "Shader_001")) { + return false; + } + } + else if (mix_factor == 1.0f) { + if (STREQ(socket->identifier, "Shader")) { + return false; + } + } + return true; +} + +} // namespace switch_node_inference_utils + +} // namespace blender::nodes