diff --git a/source/blender/editors/space_node/space_node.cc b/source/blender/editors/space_node/space_node.cc index 56a85b6e64a..71da08f8c62 100644 --- a/source/blender/editors/space_node/space_node.cc +++ b/source/blender/editors/space_node/space_node.cc @@ -473,16 +473,27 @@ static std::optional compute_context_for_tree_path( return current; } +static bool is_evaluate_closure_node_input(const nodes::SocketInContext &socket) +{ + return socket->is_input() && socket->index() == 0 && + socket.owner_node()->is_type("GeometryNodeEvaluateClosure"); +} + +static bool is_closure_zone_output_socket(const nodes::SocketInContext &socket) +{ + return socket->owner_node().is_type("GeometryNodeClosureOutput") && socket->is_output(); +} + static Vector find_origin_sockets_through_contexts( nodes::SocketInContext start_socket, bke::ComputeContextCache &compute_context_cache, - StringRef query_node_idname, + FunctionRef handle_possible_origin_socket_fn, bool find_all); static Vector find_target_sockets_through_contexts( const nodes::SocketInContext start_socket, bke::ComputeContextCache &compute_context_cache, - const StringRef query_node_idname, + const FunctionRef handle_possible_target_socket_fn, const bool find_all) { using BundlePath = Vector; @@ -519,7 +530,7 @@ static Vector find_target_sockets_through_contexts( } continue; } - if (bundle_path.is_empty() && node->is_type(query_node_idname)) { + if (bundle_path.is_empty() && handle_possible_target_socket_fn(socket)) { found_targets.add(socket); if (!find_all) { break; @@ -583,7 +594,7 @@ static Vector find_target_sockets_through_contexts( node->storage); const StringRef key = closure_storage.output_items.items[socket->index()].name; const Vector target_sockets = find_target_sockets_through_contexts( - node.output_socket(0), compute_context_cache, "GeometryNodeEvaluateClosure", true); + node.output_socket(0), compute_context_cache, is_evaluate_closure_node_input, true); for (const auto &target_socket : target_sockets) { const nodes::NodeInContext evaluate_node = target_socket.owner_node(); const auto &evaluate_storage = *static_cast( @@ -606,7 +617,7 @@ static Vector find_target_sockets_through_contexts( node->storage); const StringRef key = evaluate_storage.input_items.items[socket->index() - 1].name; const Vector origin_sockets = find_origin_sockets_through_contexts( - node.input_socket(0), compute_context_cache, "GeometryNodeClosureOutput", true); + node.input_socket(0), compute_context_cache, is_closure_zone_output_socket, true); for (const nodes::SocketInContext origin_socket : origin_sockets) { const bNodeTree &closure_tree = origin_socket->owner_tree(); const bke::bNodeTreeZones *closure_tree_zones = closure_tree.zones(); @@ -679,7 +690,7 @@ static Vector find_target_sockets_through_contexts( const Vector target_sockets = find_target_sockets_through_contexts( {closure_socket_context, &closure_socket}, compute_context_cache, - "GeometryNodeEvaluateClosure", + is_evaluate_closure_node_input, false); if (target_sockets.is_empty()) { return nullptr; @@ -695,7 +706,7 @@ static Vector find_target_sockets_through_contexts( static Vector find_origin_sockets_through_contexts( const nodes::SocketInContext start_socket, bke::ComputeContextCache &compute_context_cache, - const StringRef query_node_idname, + const FunctionRef handle_possible_origin_socket_fn, const bool find_all) { using BundlePath = Vector; @@ -724,6 +735,13 @@ static Vector find_origin_sockets_through_contexts( const BundlePath &bundle_path = socket_to_check.bundle_path; const nodes::NodeInContext &node = socket.owner_node(); if (socket->is_input()) { + if (bundle_path.is_empty() && handle_possible_origin_socket_fn(socket)) { + found_origins.add(socket); + if (!find_all) { + break; + } + continue; + } const bke::bNodeTreeZones *zones = node->owner_tree().zones(); if (!zones) { continue; @@ -762,7 +780,7 @@ static Vector find_origin_sockets_through_contexts( } continue; } - if (bundle_path.is_empty() && node->is_type(query_node_idname)) { + if (bundle_path.is_empty() && handle_possible_origin_socket_fn(socket)) { found_origins.add(socket); if (!find_all) { break; @@ -804,7 +822,7 @@ static Vector find_origin_sockets_through_contexts( node->storage); const StringRef key = evaluate_storage.output_items.items[socket->index()].name; const Vector origin_sockets = find_origin_sockets_through_contexts( - node.input_socket(0), compute_context_cache, "GeometryNodeClosureOutput", true); + node.input_socket(0), compute_context_cache, is_closure_zone_output_socket, true); for (const nodes::SocketInContext origin_socket : origin_sockets) { const bNodeTree &closure_tree = origin_socket->owner_tree(); const nodes::NodeInContext closure_output_node = origin_socket.owner_node(); @@ -839,7 +857,7 @@ static Vector find_origin_sockets_through_contexts( const Vector target_sockets = find_target_sockets_through_contexts( {socket.context, &closure_output_socket}, compute_context_cache, - "GeometryNodeEvaluateClosure", + is_evaluate_closure_node_input, true); for (const nodes::SocketInContext &target_socket : target_sockets) { const nodes::NodeInContext target_node = target_socket.owner_node(); @@ -889,7 +907,9 @@ Vector gather_linked_target_bundle_signatures( const Vector target_sockets = find_target_sockets_through_contexts( {bundle_socket_context, &bundle_socket}, compute_context_cache, - "GeometryNodeSeparateBundle", + [](const nodes::SocketInContext &socket) { + return socket->owner_node().is_type("GeometryNodeSeparateBundle"); + }, true); Vector signatures; for (const nodes::SocketInContext &target_socket : target_sockets) { @@ -907,7 +927,9 @@ Vector gather_linked_origin_bundle_signatures( const Vector origin_sockets = find_origin_sockets_through_contexts( {bundle_socket_context, &bundle_socket}, compute_context_cache, - "GeometryNodeCombineBundle", + [](const nodes::SocketInContext &socket) { + return socket->owner_node().is_type("GeometryNodeCombineBundle"); + }, true); Vector signatures; for (const nodes::SocketInContext &origin_socket : origin_sockets) { @@ -925,7 +947,7 @@ Vector gather_linked_target_closure_signatures( const Vector target_sockets = find_target_sockets_through_contexts( {closure_socket_context, &closure_socket}, compute_context_cache, - "GeometryNodeEvaluateClosure", + is_evaluate_closure_node_input, true); Vector signatures; for (const nodes::SocketInContext &target_socket : target_sockets) { @@ -940,16 +962,19 @@ Vector gather_linked_origin_closure_signatures( const bNodeSocket &closure_socket, bke::ComputeContextCache &compute_context_cache) { - const Vector origin_sockets = find_origin_sockets_through_contexts( + Vector signatures; + find_origin_sockets_through_contexts( {closure_socket_context, &closure_socket}, compute_context_cache, - "GeometryNodeClosureOutput", + [&](const nodes::SocketInContext &socket) { + if (is_closure_zone_output_socket(socket)) { + signatures.append( + nodes::ClosureSignature::from_closure_output_node(socket->owner_node())); + return true; + } + return false; + }, true); - Vector signatures; - for (const nodes::SocketInContext &origin_socket : origin_sockets) { - const nodes::NodeInContext &origin_node = origin_socket.owner_node(); - signatures.append(nodes::ClosureSignature::from_closure_output_node(*origin_node.node)); - } return signatures; }