Nodes: add selection output for each item in Menu Switch node

This adds a boolean output for each of the menu items. The output is true, if
the passed in menu value is that item. This avoids the need to compare the
output value to the input values to get a boolean for whether a specific menu
item was passed in.

Support is added for Geometry Nodes as well as the Compositor. Usage/Value
inferencing has been updated as well.

Pull Request: https://projects.blender.org/blender/blender/pulls/145712
This commit is contained in:
Jacques Lucke
2025-09-08 13:01:47 +02:00
parent ff468e44ee
commit bfc5f8d51c
5 changed files with 141 additions and 40 deletions

View File

@@ -2,6 +2,7 @@
*
* SPDX-License-Identifier: GPL-2.0-or-later */
#include "BLI_array_utils.hh"
#include "node_geometry_util.hh"
#include "DNA_node_types.h"
@@ -38,6 +39,9 @@ NODE_STORAGE_FUNCS(NodeMenuSwitch)
static void node_declare(blender::nodes::NodeDeclarationBuilder &b)
{
b.use_custom_socket_order();
b.allow_any_socket_order();
const bNodeTree *ntree = b.tree_or_null();
const bNode *node = b.node_or_null();
if (node == nullptr) {
@@ -62,6 +66,17 @@ static void node_declare(blender::nodes::NodeDeclarationBuilder &b)
menu_structure_type = StructureType::Single;
}
auto &output = b.add_output(data_type, "Output");
if (supports_fields) {
output.dependent_field().reference_pass_all();
}
else if (data_type == SOCK_GEOMETRY) {
output.propagate_all();
}
output.structure_type(value_structure_type);
b.add_default_layout();
auto &menu = b.add_input<decl::Menu>("Menu");
if (supports_fields) {
menu.supports_field();
@@ -70,7 +85,7 @@ static void node_declare(blender::nodes::NodeDeclarationBuilder &b)
for (const NodeEnumItem &enum_item : storage.enum_definition.items()) {
const std::string identifier = MenuSwitchItemsAccessor::socket_identifier_for_item(enum_item);
auto &input = b.add_input(data_type, enum_item.name, std::move(identifier))
auto &input = b.add_input(data_type, enum_item.name, identifier)
.socket_name_ptr(
&ntree->id, MenuSwitchItemsAccessor::item_srna, &enum_item, "name")
.compositor_realization_mode(CompositorInputRealizationMode::None);
@@ -80,17 +95,14 @@ static void node_declare(blender::nodes::NodeDeclarationBuilder &b)
/* Labels are ugly in combination with data-block pickers and are usually disabled. */
input.hide_label(ELEM(data_type, SOCK_OBJECT, SOCK_IMAGE, SOCK_COLLECTION, SOCK_MATERIAL));
input.structure_type(value_structure_type);
auto &item_output =
b.add_output<decl::Bool>(enum_item.name, std::move(identifier)).align_with_previous();
if (supports_fields) {
item_output.dependent_field({menu.index()});
item_output.structure_type(menu_structure_type);
}
}
auto &output = b.add_output(data_type, "Output");
if (supports_fields) {
output.dependent_field().reference_pass_all();
}
else if (data_type == SOCK_GEOMETRY) {
output.propagate_all();
}
output.structure_type(value_structure_type);
b.add_input<decl::Extend>("", "__extend__").structure_type(StructureType::Dynamic);
}
@@ -167,7 +179,10 @@ class MenuSwitchFn : public mf::MultiFunction {
for (const NodeEnumItem &enum_item : enum_def.items()) {
builder.single_input(enum_item.name, type);
}
builder.single_output("Output", type);
builder.single_output("Output", type, mf::ParamFlag::SupportsUnusedOutput);
for (const NodeEnumItem &item : enum_def.items()) {
builder.single_output<bool>(item.name, mf::ParamFlag::SupportsUnusedOutput);
}
this->set_signature(&signature_);
}
@@ -180,8 +195,15 @@ class MenuSwitchFn : public mf::MultiFunction {
/* Use one extra mask at the end for invalid indices. */
const int invalid_index = inputs_num;
GMutableSpan output = params.uninitialized_single_output(
signature_.params.index_range().last(), "Output");
GMutableSpan value_output = params.uninitialized_single_output_if_required(1 + inputs_num,
"Output");
Array<MutableSpan<bool>> item_mask_outputs(inputs_num);
for (const int item_i : IndexRange(inputs_num)) {
const int param_index = 2 + inputs_num + item_i;
item_mask_outputs[item_i] = params.uninitialized_single_output_if_required<bool>(
param_index);
}
auto find_item_index = [&](const MenuValue value) -> int {
for (const int i : enum_def_.items().index_range()) {
@@ -196,11 +218,27 @@ class MenuSwitchFn : public mf::MultiFunction {
if (const std::optional<MenuValue> value = values.get_if_single()) {
const int index = find_item_index(*value);
if (index < inputs_num) {
const GVArray inputs = params.readonly_single_input(value_inputs_start + index);
inputs.materialize_to_uninitialized(mask, output.data());
if (!value_output.is_empty()) {
const GVArray inputs = params.readonly_single_input(value_inputs_start + index);
inputs.materialize_to_uninitialized(mask, value_output.data());
}
for (const int item_i : IndexRange(inputs_num)) {
MutableSpan<bool> item_mask_output = item_mask_outputs[item_i];
if (!item_mask_output.is_empty()) {
index_mask::masked_fill(item_mask_output, item_i == index, mask);
}
}
}
else {
type_.fill_construct_indices(type_.default_value(), output.data(), mask);
if (!value_output.is_empty()) {
type_.fill_construct_indices(type_.default_value(), value_output.data(), mask);
}
for (const int item_i : IndexRange(inputs_num)) {
MutableSpan<bool> item_mask_output = item_mask_outputs[item_i];
if (!item_mask_output.is_empty()) {
index_mask::masked_fill(item_mask_output, false, mask);
}
}
}
return;
}
@@ -210,14 +248,23 @@ class MenuSwitchFn : public mf::MultiFunction {
IndexMask::from_groups<int64_t>(
mask, memory, [&](const int64_t i) { return find_item_index(values[i]); }, masks);
for (const int i : IndexRange(inputs_num)) {
if (!masks[i].is_empty()) {
const GVArray inputs = params.readonly_single_input(value_inputs_start + i);
inputs.materialize_to_uninitialized(masks[i], output.data());
for (const int item_i : IndexRange(inputs_num)) {
const IndexMask &mask_for_index = masks[item_i];
if (!mask_for_index.is_empty() && !value_output.is_empty()) {
const GVArray inputs = params.readonly_single_input(value_inputs_start + item_i);
inputs.materialize_to_uninitialized(mask_for_index, value_output.data());
}
MutableSpan<bool> item_mask_output = item_mask_outputs[item_i];
if (!item_mask_output.is_empty()) {
if (mask.size() != mask_for_index.size()) {
/* First set output to false before setting selected items to true. */
index_mask::masked_fill(item_mask_output, false, mask);
}
index_mask::masked_fill(item_mask_output, true, mask_for_index);
}
}
type_.fill_construct_indices(type_.default_value(), output.data(), masks[invalid_index]);
type_.fill_construct_indices(type_.default_value(), value_output.data(), masks[invalid_index]);
}
};
@@ -252,6 +299,11 @@ class LazyFunctionForMenuSwitchNode : public LazyFunction {
}
lf_index_by_bsocket[node.output_socket(0).index_in_tree()] = outputs_.append_and_get_index_as(
"Value", CPPType::get<bke::SocketValueVariant>());
for (const int i : enum_def_.items().index_range()) {
const NodeEnumItem &enum_item = enum_def_.items()[i];
lf_index_by_bsocket[node.output_socket(i + 1).index_in_tree()] =
outputs_.append_and_get_index_as(enum_item.name, CPPType::get<SocketValueVariant>());
}
}
void execute_impl(lf::Params &params, const lf::Context & /*context*/) const override
@@ -270,7 +322,8 @@ class LazyFunctionForMenuSwitchNode : public LazyFunction {
for (const int i : IndexRange(enum_def_.items_num)) {
const NodeEnumItem &enum_item = enum_def_.items_array[i];
const int input_index = i + 1;
if (enum_item.identifier == condition.value) {
const bool is_selected = enum_item.identifier == condition.value;
if (is_selected) {
SocketValueVariant *value_to_forward =
params.try_get_input_data_ptr_or_request<SocketValueVariant>(input_index);
if (value_to_forward == nullptr) {
@@ -283,6 +336,9 @@ class LazyFunctionForMenuSwitchNode : public LazyFunction {
else {
params.set_input_unused(input_index);
}
if (!params.output_was_set(i + 1)) {
params.set_output(i + 1, SocketValueVariant(is_selected));
}
}
/* No guarantee that the switch input matches any enum,
* set default outputs to ensure valid state. */
@@ -310,11 +366,13 @@ class LazyFunctionForMenuSwitchNode : public LazyFunction {
}
std::unique_ptr<MultiFunction> multi_function = std::make_unique<MenuSwitchFn>(
enum_def_, *field_base_type_);
GField output_field{FieldOperation::from(std::move(multi_function), std::move(item_fields))};
std::shared_ptr<fn::FieldOperation> operation = FieldOperation::from(std::move(multi_function),
std::move(item_fields));
void *output_ptr = params.get_output_data_ptr(0);
SocketValueVariant::ConstructIn(output_ptr, std::move(output_field));
params.output_set(0);
params.set_output(0, SocketValueVariant::From(GField(operation, 0)));
for (const int item_i : IndexRange(enum_def_.items_num)) {
params.set_output(item_i + 1, SocketValueVariant::From(GField(operation, item_i + 1)));
}
}
};
@@ -364,24 +422,33 @@ class MenuSwitchOperation : public NodeOperation {
void execute() override
{
Result &output = this->get_result("Output");
Result &value_output = this->get_result("Output");
const MenuValue menu_identifier = this->get_input("Menu").get_single_value<MenuValue>();
const NodeEnumDefinition &enum_definition = node_storage(bnode()).enum_definition;
bool found_item = false;
for (const int i : IndexRange(enum_definition.items_num)) {
const NodeEnumItem &enum_item = enum_definition.items()[i];
if (enum_item.identifier != menu_identifier.value) {
continue;
}
const std::string identifier = MenuSwitchItemsAccessor::socket_identifier_for_item(
enum_item);
const bool is_selected = enum_item.identifier == menu_identifier.value;
Result &item_output = this->get_result(identifier);
if (item_output.should_compute()) {
item_output.allocate_single_value();
item_output.set_single_value(is_selected);
}
if (!is_selected) {
continue;
}
const Result &input = this->get_input(identifier);
output.share_data(input);
return;
value_output.share_data(input);
found_item = true;
}
/* The menu identifier didn't match any item, so allocate an invalid output. */
output.allocate_invalid();
if (!found_item) {
/* The menu identifier didn't match any item, so allocate an invalid output. */
value_output.allocate_invalid();
}
}
};

View File

@@ -3599,12 +3599,13 @@ struct GeometryNodesLazyFunctionBuilder {
input_index++;
}
}
int output_index = 0;
for (const bNodeSocket *bsocket : bnode.output_sockets()) {
if (bsocket->is_available()) {
lf::OutputSocket &lf_socket = lf_node.output(0);
lf::OutputSocket &lf_socket = lf_node.output(output_index);
graph_params.lf_output_by_bsocket.add(bsocket, &lf_socket);
mapping_->bsockets_by_lf_socket_map.add(&lf_socket, bsocket);
break;
output_index++;
}
}

View File

@@ -182,8 +182,13 @@ struct SocketUsageInferencer {
break;
}
case GEO_NODE_MENU_SWITCH: {
this->usage_task__input__generic_switch(
socket, switch_node_inference_utils::is_socket_selected__menu_switch);
if (socket->index() == 0) {
this->usage_task__input__fallback(socket);
}
else {
this->usage_task__input__generic_switch(
socket, switch_node_inference_utils::is_socket_selected__menu_switch);
}
break;
}
case SH_NODE_MIX: {

View File

@@ -169,8 +169,13 @@ class SocketValueInferencerImpl {
return;
}
case GEO_NODE_MENU_SWITCH: {
this->value_task__output__generic_switch(
socket, switch_node_inference_utils::is_socket_selected__menu_switch);
if (socket->index() == 0) {
this->value_task__output__generic_switch(
socket, switch_node_inference_utils::is_socket_selected__menu_switch);
}
else {
this->value_task__output__menu_switch_selection(socket);
}
return;
}
case SH_NODE_MIX: {
@@ -272,6 +277,26 @@ class SocketValueInferencerImpl {
all_socket_values_.add_new(socket, *value);
}
void value_task__output__menu_switch_selection(const SocketInContext &socket)
{
const NodeInContext node = socket.owner_node();
const SocketInContext input_socket = node.input_socket(0);
const std::optional<InferenceValue> value = all_socket_values_.lookup_try(input_socket);
if (!value.has_value()) {
this->push_value_task(input_socket);
return;
}
const std::optional<MenuValue> menu_value = value->get_if_primitive<MenuValue>();
if (!menu_value.has_value()) {
all_socket_values_.add_new(socket, InferenceValue::Unknown());
return;
}
const NodeMenuSwitch &storage = *static_cast<const NodeMenuSwitch *>(node->storage);
const NodeEnumItem &item = storage.enum_definition.items_array[socket->index() - 1];
const bool is_selected = item.identifier == menu_value->value;
all_socket_values_.add_new(socket, this->make_primitive_inference_value(is_selected));
}
void value_task__output__float_math(const SocketInContext &socket)
{
const NodeInContext node = socket.owner_node();

Binary file not shown.