Refactor: Nodes: extract value inferencer construction

This extracts the construction of the `SocketValueInferencer` out of
`SocketUsageInferencer`, which leads to a better separation of concerns and
gives the caller more flexibility. In the future, I especially want to get
information about which group input values were required to determine the
usage of other group inputs. This might help with caching the inferred values.

Pull Request: https://projects.blender.org/blender/blender/pulls/147352
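
Sketch of the new calling pattern (a minimal illustration based on the diffs below; `tree`, `scope`, `compute_context_cache` and `input_values` stand in for whatever the caller already has in scope):

  /* The caller now constructs the value inferencer itself, optionally providing known group
   * input values through a callback... */
  SocketValueInferencer value_inferencer{
      tree, scope, compute_context_cache, [&](const int group_input_i) {
        return input_values[group_input_i];
      }};
  /* ...and passes it to the usage inferencer, which previously built its own value inferencer
   * from a span of input values. */
  SocketUsageInferencer usage_inferencer{tree, scope, value_inferencer, compute_context_cache};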
Jacques Lucke
2025-10-04 15:18:24 +02:00
parent d990026fd6
commit 2323bd2691
5 changed files with 91 additions and 99 deletions

View File

@@ -34,10 +34,9 @@ class SocketUsageInferencer {
  public:
   SocketUsageInferencer(const bNodeTree &tree,
-                        std::optional<Span<InferenceValue>> tree_input_values,
                         ResourceScope &scope,
+                        SocketValueInferencer &value_inferencer,
                         bke::ComputeContextCache &compute_context_cache,
-                        std::optional<Span<bool>> top_level_ignored_inputs = std::nullopt,
                         bool ignore_top_level_node_muting = false);
   bool is_socket_used(const SocketInContext &socket);

View File

@@ -82,11 +82,12 @@ class SocketValueInferencer {
   SocketValueInferencerImpl &impl_;
  public:
-  SocketValueInferencer(const bNodeTree &tree,
-                        ResourceScope &scope,
-                        bke::ComputeContextCache &compute_context_cache,
-                        const std::optional<Span<InferenceValue>> tree_input_values,
-                        const std::optional<Span<bool>> top_level_ignored_inputs);
+  SocketValueInferencer(
+      const bNodeTree &tree,
+      ResourceScope &scope,
+      bke::ComputeContextCache &compute_context_cache,
+      FunctionRef<InferenceValue(int group_input_i)> group_input_value_fn = nullptr,
+      std::optional<Span<bool>> top_level_ignored_inputs = std::nullopt);
   InferenceValue get_socket_value(const SocketInContext &socket);
 };
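
Because group input values are now requested through `group_input_value_fn` instead of being copied from a span up front, a caller could also observe which group inputs the inferencer actually had to look up. A hypothetical sketch of that idea (not part of this commit; `requested_inputs` and `input_values` are made up for illustration):

  Set<int> requested_inputs;
  SocketValueInferencer value_inferencer{
      tree, scope, compute_context_cache, [&](const int group_input_i) {
        /* Remember that this input value was needed, e.g. as part of a cache key later. */
        requested_inputs.add(group_input_i);
        return input_values[group_input_i];
      }};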

View File

@@ -385,8 +385,13 @@ static void foreach_active_gizmo_exposed_to_modifier(
   ResourceScope scope;
   const Vector<InferenceValue> input_values = get_geometry_nodes_input_inference_values(
       *nmd.node_group, nmd.settings.properties, scope);
+  SocketValueInferencer value_inferencer{
+      *nmd.node_group, scope, compute_context_cache, [&](const int group_input_i) {
+        return input_values[group_input_i];
+      }};
   socket_usage_inference::SocketUsageInferencer usage_inferencer(
-      *nmd.node_group, input_values, scope, compute_context_cache);
+      *nmd.node_group, scope, value_inferencer, compute_context_cache);
   const ComputeContext &root_compute_context = compute_context_cache.for_modifier(nullptr, nmd);
   for (auto &&item : tree.runtime->gizmo_propagation->gizmo_inputs_by_group_inputs.items()) {

View File

@@ -35,11 +35,10 @@ class SocketUsageInferencerImpl {
  private:
   friend InputSocketUsageParams;
-  ResourceScope &scope_;
   bke::ComputeContextCache &compute_context_cache_;
   /** Inferences the socket values if possible. */
-  SocketValueInferencer value_inferencer_;
+  SocketValueInferencer &value_inferencer_;
   /** Root node tree. */
   const bNodeTree &root_tree_;
@@ -77,15 +76,11 @@ class SocketUsageInferencerImpl {
   SocketUsageInferencer *owner_ = nullptr;
   SocketUsageInferencerImpl(const bNodeTree &tree,
-                            const std::optional<Span<InferenceValue>> tree_input_values,
-                            ResourceScope &scope,
+                            SocketValueInferencer &value_inferencer,
                             bke::ComputeContextCache &compute_context_cache,
-                            const std::optional<Span<bool>> top_level_ignored_inputs,
                             const bool ignore_top_level_node_muting)
-      : scope_(scope),
-        compute_context_cache_(compute_context_cache),
-        value_inferencer_(
-            tree, scope_, compute_context_cache_, tree_input_values, top_level_ignored_inputs),
+      : compute_context_cache_(compute_context_cache),
+        value_inferencer_(value_inferencer),
         root_tree_(tree),
         ignore_top_level_node_muting_(ignore_top_level_node_muting)
   {
@@ -735,19 +730,13 @@ class SocketUsageInferencerImpl {
   }
 };
-SocketUsageInferencer::SocketUsageInferencer(
-    const bNodeTree &tree,
-    const std::optional<Span<InferenceValue>> tree_input_values,
-    ResourceScope &scope,
-    bke::ComputeContextCache &compute_context_cache,
-    const std::optional<Span<bool>> top_level_ignored_inputs,
-    const bool ignore_top_level_node_muting)
-    : impl_(scope.construct<SocketUsageInferencerImpl>(tree,
-                                                       tree_input_values,
-                                                       scope,
-                                                       compute_context_cache,
-                                                       top_level_ignored_inputs,
-                                                       ignore_top_level_node_muting))
+SocketUsageInferencer::SocketUsageInferencer(const bNodeTree &tree,
+                                             ResourceScope &scope,
+                                             SocketValueInferencer &value_inferencer,
+                                             bke::ComputeContextCache &compute_context_cache,
+                                             const bool ignore_top_level_node_muting)
+    : impl_(scope.construct<SocketUsageInferencerImpl>(
+          tree, value_inferencer, compute_context_cache, ignore_top_level_node_muting))
 {
   impl_.owner_ = this;
 }
@@ -776,15 +765,13 @@ Array<SocketUsage> infer_all_sockets_usage(const bNodeTree &tree)
   {
     /* Find actual socket usages. */
-    SocketUsageInferencer inferencer{tree,
-                                     std::nullopt,
-                                     scope,
-                                     compute_context_cache,
-                                     std::nullopt,
-                                     ignore_top_level_node_muting};
-    inferencer.mark_top_level_node_outputs_as_used();
+    SocketValueInferencer value_inferencer{tree, scope, compute_context_cache};
+    SocketUsageInferencer usage_inferencer{
+        tree, scope, value_inferencer, compute_context_cache, ignore_top_level_node_muting};
+    usage_inferencer.mark_top_level_node_outputs_as_used();
     for (const bNodeSocket *socket : all_input_sockets) {
-      all_usages[socket->index_in_tree()].is_used = inferencer.is_socket_used({nullptr, socket});
+      all_usages[socket->index_in_tree()].is_used = usage_inferencer.is_socket_used(
+          {nullptr, socket});
     }
   }
@@ -797,20 +784,22 @@ Array<SocketUsage> infer_all_sockets_usage(const bNodeTree &tree)
       only_controllers_used[i] = !input_may_affect_visibility(socket);
     }
   });
-  SocketUsageInferencer inferencer_all_unknown{tree,
-                                               std::nullopt,
-                                               scope,
-                                               compute_context_cache,
-                                               all_ignored_inputs,
-                                               ignore_top_level_node_muting};
-  SocketUsageInferencer inferencer_only_controllers{tree,
-                                                    std::nullopt,
-                                                    scope,
-                                                    compute_context_cache,
-                                                    only_controllers_used,
-                                                    ignore_top_level_node_muting};
-  inferencer_all_unknown.mark_top_level_node_outputs_as_used();
-  inferencer_only_controllers.mark_top_level_node_outputs_as_used();
+  SocketValueInferencer value_inferencer_all_unknown{
+      tree, scope, compute_context_cache, nullptr, all_ignored_inputs};
+  SocketUsageInferencer usage_inferencer_all_unknown{tree,
+                                                     scope,
+                                                     value_inferencer_all_unknown,
+                                                     compute_context_cache,
+                                                     ignore_top_level_node_muting};
+  SocketValueInferencer value_inferencer_only_controllers{
+      tree, scope, compute_context_cache, nullptr, only_controllers_used};
+  SocketUsageInferencer usage_inferencer_only_controllers{tree,
+                                                          scope,
+                                                          value_inferencer_only_controllers,
+                                                          compute_context_cache,
+                                                          ignore_top_level_node_muting};
+  usage_inferencer_all_unknown.mark_top_level_node_outputs_as_used();
+  usage_inferencer_only_controllers.mark_top_level_node_outputs_as_used();
   for (const bNodeSocket *socket : all_input_sockets) {
     SocketUsage &usage = all_usages[socket->index_in_tree()];
     if (usage.is_used) {
@@ -818,12 +807,12 @@ Array<SocketUsage> infer_all_sockets_usage(const bNodeTree &tree)
       continue;
     }
     const SocketInContext socket_ctx{nullptr, socket};
-    if (inferencer_only_controllers.is_socket_used(socket_ctx)) {
+    if (usage_inferencer_only_controllers.is_socket_used(socket_ctx)) {
       /* The input should be visible if it's used if only visibility-controlling inputs are
        * considered. */
       continue;
     }
-    if (!inferencer_all_unknown.is_socket_used(socket_ctx)) {
+    if (!usage_inferencer_all_unknown.is_socket_used(socket_ctx)) {
       /* The input should be visible if it's never used, regardless of any inputs. Its usage does
        * not depend on any visibility-controlling input. */
       continue;
@@ -836,7 +825,7 @@ Array<SocketUsage> infer_all_sockets_usage(const bNodeTree &tree)
       continue;
     }
     const SocketInContext socket_ctx{nullptr, socket};
-    if (inferencer_only_controllers.is_disabled_output(socket_ctx)) {
+    if (usage_inferencer_only_controllers.is_disabled_output(socket_ctx)) {
       SocketUsage &usage = all_usages[socket->index_in_tree()];
       usage.is_visible = false;
     }
@@ -863,43 +852,50 @@ void infer_group_interface_usage(const bNodeTree &group,
   {
     /* Detect actually used inputs. */
-    SocketUsageInferencer inferencer{group, group_input_values, scope, compute_context_cache};
+    SocketValueInferencer value_inferencer{
+        group, scope, compute_context_cache, [&](const int group_input_i) {
+          return group_input_values[group_input_i];
+        }};
+    SocketUsageInferencer usage_inferencer{group, scope, value_inferencer, compute_context_cache};
     for (const int i : group.interface_inputs().index_range()) {
-      r_input_usages[i].is_used |= inferencer.is_group_input_used(i);
+      r_input_usages[i].is_used |= usage_inferencer.is_group_input_used(i);
     }
   }
   bool visibility_controlling_input_exists = false;
-  Array<InferenceValue, 32> inputs_all_unknown(group_input_values.size(),
-                                               InferenceValue::Unknown());
-  Array<InferenceValue, 32> inputs_only_controllers = group_input_values;
   for (const int i : group.interface_inputs().index_range()) {
     const bNodeTreeInterfaceSocket &io_socket = *group.interface_inputs()[i];
     if (input_may_affect_visibility(io_socket)) {
       visibility_controlling_input_exists = true;
     }
-    else {
-      inputs_only_controllers[i] = InferenceValue::Unknown();
-    }
   }
   if (!visibility_controlling_input_exists) {
     /* If there is no visibility controller inputs, all inputs are always visible. */
     return;
   }
-  SocketUsageInferencer inferencer_all_unknown{
-      group, inputs_all_unknown, scope, compute_context_cache};
-  SocketUsageInferencer inferencer_only_controllers{
-      group, inputs_only_controllers, scope, compute_context_cache};
+  SocketValueInferencer value_inferencer_all_unknown{group, scope, compute_context_cache};
+  SocketUsageInferencer usage_inferencer_all_unknown{
+      group, scope, value_inferencer_all_unknown, compute_context_cache};
+  SocketValueInferencer value_inferencer_only_controllers{
+      group, scope, compute_context_cache, [&](const int group_input_i) {
+        const bNodeTreeInterfaceSocket &io_socket = *group.interface_inputs()[group_input_i];
+        if (input_may_affect_visibility(io_socket)) {
+          return group_input_values[group_input_i];
+        }
+        return InferenceValue::Unknown();
+      }};
+  SocketUsageInferencer usage_inferencer_only_controllers{
+      group, scope, value_inferencer_only_controllers, compute_context_cache};
   for (const int i : group.interface_inputs().index_range()) {
     if (r_input_usages[i].is_used) {
       /* Used inputs are always visible. */
       continue;
     }
-    if (inferencer_only_controllers.is_group_input_used(i)) {
+    if (usage_inferencer_only_controllers.is_group_input_used(i)) {
       /* The input should be visible if it's used if only visibility-controlling inputs are
        * considered. */
       continue;
     }
-    if (!inferencer_all_unknown.is_group_input_used(i)) {
+    if (!usage_inferencer_all_unknown.is_group_input_used(i)) {
       /* The input should be visible if it's never used, regardless of any inputs. Its usage does
        * not depend on any visibility-controlling input. */
       continue;
@@ -908,7 +904,7 @@ void infer_group_interface_usage(const bNodeTree &group,
   }
   if (r_output_usages) {
     for (const int i : group.interface_outputs().index_range()) {
-      if (inferencer_only_controllers.is_disabled_group_output(i)) {
+      if (usage_inferencer_only_controllers.is_disabled_group_output(i)) {
         SocketUsage &usage = (*r_output_usages)[i];
         usage.is_used = false;
         usage.is_visible = false;

View File

@@ -40,6 +40,8 @@ class SocketValueInferencerImpl {
    */
   Map<SocketInContext, InferenceValue> all_socket_values_;
+  FunctionRef<InferenceValue(int group_input_i)> group_input_value_fn_;
   /**
    * All sockets that have animation data and thus their value is not fixed statically. This can
    * contain sockets from multiple different trees.
@@ -51,39 +53,21 @@ class SocketValueInferencerImpl {
   const bNodeTree &root_tree_;
  public:
-  SocketValueInferencerImpl(const bNodeTree &tree,
-                            ResourceScope &scope,
-                            bke::ComputeContextCache &compute_context_cache,
-                            const std::optional<Span<InferenceValue>> tree_input_values,
-                            const std::optional<Span<bool>> top_level_ignored_inputs)
+  SocketValueInferencerImpl(
+      const bNodeTree &tree,
+      ResourceScope &scope,
+      bke::ComputeContextCache &compute_context_cache,
+      const FunctionRef<InferenceValue(int group_input_i)> group_input_value_fn,
+      const std::optional<Span<bool>> top_level_ignored_inputs)
       : scope_(scope),
         compute_context_cache_(compute_context_cache),
+        group_input_value_fn_(group_input_value_fn),
         top_level_ignored_inputs_(top_level_ignored_inputs),
         root_tree_(tree)
   {
     root_tree_.ensure_topology_cache();
     root_tree_.ensure_interface_cache();
     this->ensure_animation_data_processed(root_tree_);
-    for (const bNode *node : root_tree_.group_input_nodes()) {
-      for (const int i : root_tree_.interface_inputs().index_range()) {
-        const bNodeSocket &socket = node->output_socket(i);
-        if (!socket.is_directly_linked()) {
-          /* This socket is not linked, hence it's value is never used. Thus we don't have to add
-           * it to #all_socket_values_. This optimization helps a lot when the node group has a
-           * very large number of inputs and group input nodes. */
-          continue;
-        }
-        const SocketInContext socket_in_context{nullptr, &socket};
-        InferenceValue input_value = InferenceValue::Unknown();
-        if (!this->treat_socket_as_unknown(socket_in_context)) {
-          if (tree_input_values.has_value()) {
-            input_value = (*tree_input_values)[i];
-          }
-        }
-        all_socket_values_.add_new(socket_in_context, input_value);
-      }
-    }
   }
   InferenceValue get_socket_value(const SocketInContext &socket)
@@ -255,8 +239,15 @@ class SocketValueInferencerImpl {
   void value_task__output__group_input_node(const SocketInContext &socket)
   {
-    /* Group inputs for the root context should be initialized already. */
-    BLI_assert(socket.context != nullptr);
+    const bool is_root_context = socket.context == nullptr;
+    if (is_root_context) {
+      InferenceValue value = InferenceValue::Unknown();
+      if (group_input_value_fn_) {
+        value = group_input_value_fn_(socket->index());
+      }
+      all_socket_values_.add_new(socket, value);
+      return;
+    }
     const bke::GroupNodeComputeContext &group_context =
         *static_cast<const bke::GroupNodeComputeContext *>(socket.context);
@@ -900,10 +891,10 @@ SocketValueInferencer::SocketValueInferencer(
     const bNodeTree &tree,
     ResourceScope &scope,
     bke::ComputeContextCache &compute_context_cache,
-    const std::optional<Span<InferenceValue>> tree_input_values,
+    const FunctionRef<InferenceValue(int group_input_i)> group_input_value_fn,
     const std::optional<Span<bool>> top_level_ignored_inputs)
     : impl_(scope.construct<SocketValueInferencerImpl>(
-          tree, scope, compute_context_cache, tree_input_values, top_level_ignored_inputs))
+          tree, scope, compute_context_cache, group_input_value_fn, top_level_ignored_inputs))
 {
 }