Geometry Nodes: add structure type inferencing tests

The way these tests work is similar to the existing field inferencing tests.
There is a .blend file that is opened and then we check the inferred structure
types from Python. A new `NodeSocket.inferred_structure_type` property is added
to be able to access this information. Other then the field inferencing tests,
this patch does not directly check the socket shapes, which are not always
exactly determined by the inferred structure type.

This also fixes a few issues I found while adding the tests.

Pull Request: https://projects.blender.org/blender/blender/pulls/140520
This commit is contained in:
Jacques Lucke
2025-06-18 08:39:01 +02:00
parent 588b9ff3cd
commit 1bb49edf7f
7 changed files with 362 additions and 44 deletions

View File

@@ -259,18 +259,19 @@ static ZoneInOutChange simulation_zone_requirements_propagate(
ZoneInOutChange change = ZoneInOutChange::None;
for (const int i : output_node.output_sockets().index_range()) {
/* First input node output is Delta Time which does not appear in the output node outputs. */
const bNodeSocket &socket_input = input_node.input_socket(i);
const bNodeSocket &socket_output = output_node.output_socket(i);
const bNodeSocket &input_of_input_node = input_node.input_socket(i);
const bNodeSocket &output_of_output_node = output_node.output_socket(i);
const bNodeSocket &input_of_output_node = output_node.input_socket(i + 1);
const DataRequirement new_value = merge(
input_requirements[socket_input.index_in_all_inputs()],
calc_output_socket_requirement(socket_output, input_requirements));
if (input_requirements[socket_input.index_in_all_inputs()] != new_value) {
input_requirements[socket_input.index_in_all_inputs()] = new_value;
input_requirements[input_of_input_node.index_in_all_inputs()],
calc_output_socket_requirement(output_of_output_node, input_requirements));
if (input_requirements[input_of_input_node.index_in_all_inputs()] != new_value) {
input_requirements[input_of_input_node.index_in_all_inputs()] = new_value;
change |= ZoneInOutChange::In;
}
if (input_requirements[socket_input.index_in_all_inputs()] != new_value) {
input_requirements[socket_input.index_in_all_inputs()] = new_value;
change |= ZoneInOutChange::In;
if (input_requirements[input_of_output_node.index_in_all_inputs()] != new_value) {
input_requirements[input_of_output_node.index_in_all_inputs()] = new_value;
change |= ZoneInOutChange::Out;
}
}
return change;
@@ -283,18 +284,19 @@ static ZoneInOutChange repeat_zone_requirements_propagate(
{
ZoneInOutChange change = ZoneInOutChange::None;
for (const int i : output_node.output_sockets().index_range()) {
const bNodeSocket &socket_input = input_node.input_socket(i + 1);
const bNodeSocket &socket_output = output_node.output_socket(i);
const bNodeSocket &input_of_input_node = input_node.input_socket(i + 1);
const bNodeSocket &output_of_output_node = output_node.output_socket(i);
const bNodeSocket &input_of_output_node = output_node.input_socket(i);
const DataRequirement new_value = merge(
input_requirements[socket_input.index_in_all_inputs()],
calc_output_socket_requirement(socket_output, input_requirements));
if (input_requirements[socket_input.index_in_all_inputs()] != new_value) {
input_requirements[socket_input.index_in_all_inputs()] = new_value;
input_requirements[input_of_input_node.index_in_all_inputs()],
calc_output_socket_requirement(output_of_output_node, input_requirements));
if (input_requirements[input_of_input_node.index_in_all_inputs()] != new_value) {
input_requirements[input_of_input_node.index_in_all_inputs()] = new_value;
change |= ZoneInOutChange::In;
}
if (input_requirements[socket_input.index_in_all_inputs()] != new_value) {
input_requirements[socket_input.index_in_all_inputs()] = new_value;
change |= ZoneInOutChange::In;
if (input_requirements[input_of_output_node.index_in_all_inputs()] != new_value) {
input_requirements[input_of_output_node.index_in_all_inputs()] = new_value;
change |= ZoneInOutChange::Out;
}
}
return change;
@@ -382,23 +384,41 @@ static void propagate_right_to_left(const bNodeTree &tree,
input_requirements[socket->index_in_all_inputs()]);
}
/* When a data requirement could be provided by multiple node inputs (i.e. only a single
* node input involved in a math operation has to be a volume grid for the output to be a
* grid), it's better to not propagate the data requirement than incorrectly saying that
* all of the inputs have it. */
Vector<int, 8> inputs_with_links;
for (const int input : node_interface.outputs[output].linked_inputs) {
const bNodeSocket &input_socket = *input_sockets[input];
if (input_socket.is_directly_linked()) {
inputs_with_links.append(input_socket.index_in_all_inputs());
switch (output_requirement) {
case DataRequirement::Invalid:
case DataRequirement::None: {
break;
}
}
if (inputs_with_links.size() == 1) {
input_requirements[inputs_with_links.first()] = output_requirement;
}
else {
for (const int input : inputs_with_links) {
input_requirements[input] = DataRequirement::None;
case DataRequirement::Single: {
/* If the output is a single, all inputs must be singles. */
for (const int input : node_interface.outputs[output].linked_inputs) {
const bNodeSocket &input_socket = *input_sockets[input];
input_requirements[input_socket.index_in_all_inputs()] = DataRequirement::Single;
}
break;
}
case DataRequirement::Field:
case DataRequirement::Grid: {
/* When a data requirement could be provided by multiple node inputs (i.e. only a
* single node input involved in a math operation has to be a volume grid for the
* output to be a grid), it's better to not propagate the data requirement than
* incorrectly saying that all of the inputs have it. */
Vector<int, 8> inputs_with_links;
for (const int input : node_interface.outputs[output].linked_inputs) {
const bNodeSocket &input_socket = *input_sockets[input];
if (input_socket.is_directly_linked()) {
inputs_with_links.append(input_socket.index_in_all_inputs());
}
}
if (inputs_with_links.size() == 1) {
input_requirements[inputs_with_links.first()] = output_requirement;
}
else {
for (const int input : inputs_with_links) {
input_requirements[input] = DataRequirement::None;
}
}
break;
}
}
}
@@ -428,7 +448,7 @@ static StructureType left_to_right_merge(const StructureType a, const StructureT
if ((a == StructureType::Dynamic && b == StructureType::Field) ||
(a == StructureType::Field && b == StructureType::Dynamic))
{
return StructureType::Field;
return StructureType::Dynamic;
}
if ((a == StructureType::Dynamic && b == StructureType::Grid) ||
(a == StructureType::Grid && b == StructureType::Dynamic))
@@ -483,16 +503,17 @@ static ZoneInOutChange repeat_zone_status_propagate(const bNode &input_node,
{
ZoneInOutChange change = ZoneInOutChange::None;
for (const int i : output_node.output_sockets().index_range()) {
const bNodeSocket &input = input_node.output_socket(i + 1);
const bNodeSocket &output = output_node.output_socket(i);
const StructureType new_value = left_to_right_merge(structure_types[input.index_in_tree()],
structure_types[output.index_in_tree()]);
if (structure_types[input.index_in_tree()] != new_value) {
structure_types[input.index_in_tree()] = new_value;
const bNodeSocket &input_of_input_node = input_node.output_socket(i + 1);
const bNodeSocket &output_of_output_node = output_node.output_socket(i);
const StructureType new_value = left_to_right_merge(
structure_types[input_of_input_node.index_in_tree()],
structure_types[output_of_output_node.index_in_tree()]);
if (structure_types[input_of_input_node.index_in_tree()] != new_value) {
structure_types[input_of_input_node.index_in_tree()] = new_value;
change |= ZoneInOutChange::In;
}
if (structure_types[output.index_in_tree()] != new_value) {
structure_types[output.index_in_tree()] = new_value;
if (structure_types[output_of_output_node.index_in_tree()] != new_value) {
structure_types[output_of_output_node.index_in_tree()] = new_value;
change |= ZoneInOutChange::Out;
}
}
@@ -583,6 +604,16 @@ static void propagate_left_to_right(const bNodeTree &tree,
}
}
}
/* Outputs of these nodes have dynamic structure type but should start out as single values. */
for (const StringRefNull idname : {"GeometryNodeRepeatInput", "GeometryNodeRepeatOutput"}) {
for (const bNode *node : tree.nodes_by_type(idname)) {
for (const bNodeSocket *socket : node->output_sockets()) {
structure_types[socket->index_in_tree()] = StructureType::Single;
}
}
}
while (true) {
bool need_update = false;
for (const bNode *node : tree.toposort_left_to_right()) {

View File

@@ -275,6 +275,19 @@ static void rna_NodeSocket_type_set(PointerRNA *ptr, int value)
blender::bke::node_modify_socket_type_static(ntree, &node, sock, value, 0);
}
static int rna_NodeSocket_inferred_structure_type_get(PointerRNA *ptr)
{
bNodeTree *tree = reinterpret_cast<bNodeTree *>(ptr->owner_id);
bNodeSocket *socket = ptr->data_as<bNodeSocket>();
tree->ensure_topology_cache();
if (tree->runtime->inferred_structure_types.size() != tree->all_sockets().size()) {
/* This cache is outdated or not available on this tree type. */
return int(blender::nodes::StructureType::Dynamic);
}
const int index = socket->index_in_tree();
return int(tree->runtime->inferred_structure_types[index]);
}
static void rna_NodeSocket_bl_idname_get(PointerRNA *ptr, char *value)
{
const bNodeSocket *node = static_cast<const bNodeSocket *>(ptr->data);
@@ -807,6 +820,16 @@ static void rna_def_node_socket(BlenderRNA *brna)
RNA_def_property_ui_text(prop, "Shape", "Socket shape");
RNA_def_property_update(prop, NC_NODE | NA_EDITED, "rna_NodeSocket_update");
prop = RNA_def_property(srna, "inferred_structure_type", PROP_ENUM, PROP_NONE);
RNA_def_property_enum_items(prop, rna_enum_node_socket_structure_type_items);
RNA_def_property_clear_flag(prop, PROP_EDITABLE);
RNA_def_property_enum_funcs(
prop, "rna_NodeSocket_inferred_structure_type_get", nullptr, nullptr);
RNA_def_property_ui_text(prop,
"Inferred Structure Type",
"Best known structure type of the socket. This may not match the "
"socket shape, e.g. for unlinked input sockets");
/* registration */
prop = RNA_def_property(srna, "bl_idname", PROP_STRING, PROP_NONE);
RNA_def_property_string_funcs(prop,

View File

@@ -29,7 +29,7 @@ static void node_declare(NodeDeclarationBuilder &b)
const eNodeSocketDatatype data_type = eNodeSocketDatatype(node->custom1);
b.add_input(data_type, "Grid").hide_value();
b.add_input(data_type, "Grid").hide_value().structure_type(StructureType::Grid);
b.add_output<decl::Matrix>("Transform")
.description("Transform from grid index space to object space");

View File

@@ -94,6 +94,8 @@ static void node_declare(NodeDeclarationBuilder &b)
input_decl.supports_field();
output_decl.dependent_field({input_decl.index()});
}
input_decl.structure_type(StructureType::Dynamic);
output_decl.structure_type(StructureType::Dynamic);
}
}
}
@@ -176,6 +178,8 @@ static void node_declare(NodeDeclarationBuilder &b)
input_decl.supports_field();
output_decl.dependent_field({input_decl.index()});
}
input_decl.structure_type(StructureType::Dynamic);
output_decl.structure_type(StructureType::Dynamic);
}
}
b.add_input<decl::Extend>("", "__extend__").structure_type(StructureType::Dynamic);

BIN
tests/files/node_group/structure_type_inference.blend (Stored with Git LFS) Normal file

Binary file not shown.

View File

@@ -580,6 +580,13 @@ if(TEST_SRC_DIR_EXISTS)
--testdir "${TEST_SRC_DIR}/node_group"
)
add_blender_test(
bl_node_structure_type_inference
--python ${CMAKE_CURRENT_LIST_DIR}/bl_node_structure_type_inference.py
--
--testdir "${TEST_SRC_DIR}/node_group"
)
add_blender_test(
bl_node_group_compat
--python ${CMAKE_CURRENT_LIST_DIR}/bl_node_group_compat.py

View File

@@ -0,0 +1,250 @@
# SPDX-FileCopyrightText: 2025 Blender Authors
#
# SPDX-License-Identifier: GPL-2.0-or-later
import pathlib
import sys
import tempfile
import bpy
import unittest
args = None
class StructureTypeInferenceTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.testdir = args.testdir
def setUp(self):
self.assertTrue(self.testdir.exists(),
"Test dir {0} should exist".format(self.testdir))
def load_testfile(self):
bpy.ops.wm.open_mainfile(filepath=str(self.testdir / "structure_type_inference.blend"))
def assertDynamic(self, socket):
self.assertEqual(socket.inferred_structure_type, "DYNAMIC")
def assertSingle(self, socket):
self.assertEqual(socket.inferred_structure_type, "SINGLE")
def assertField(self, socket):
self.assertEqual(socket.inferred_structure_type, "FIELD")
def assertGrid(self, socket):
self.assertEqual(socket.inferred_structure_type, "GRID")
def test_empty_group(self):
self.load_testfile()
tree = bpy.data.node_groups["test_empty_group"]
node = tree.nodes["Group Input"]
self.assertDynamic(node.outputs["Geometry"])
self.assertDynamic(node.outputs["Value"])
node = tree.nodes["Group Output"]
self.assertSingle(node.inputs["Geometry"])
self.assertSingle(node.inputs["Value"])
def test_math_node(self):
self.load_testfile()
tree = bpy.data.node_groups["test_math_node"]
node = tree.nodes["Group Input"]
self.assertDynamic(node.outputs["A"])
self.assertDynamic(node.outputs["B"])
node = tree.nodes["Group Output"]
self.assertDynamic(node.inputs["Out"])
def test_cube_node(self):
self.load_testfile()
tree = bpy.data.node_groups["test_cube_node"]
node = tree.nodes["Group Input"]
self.assertSingle(node.outputs["Size"])
self.assertSingle(node.outputs["Vertices X"])
self.assertSingle(node.outputs["Vertices Y"])
self.assertSingle(node.outputs["Vertices Z"])
node = tree.nodes["Group Output"]
self.assertSingle(node.inputs["Mesh"])
self.assertField(node.inputs["UV Map"])
def test_set_position_node(self):
self.load_testfile()
tree = bpy.data.node_groups["test_set_position_node"]
node = tree.nodes["Group Input"]
self.assertSingle(node.outputs["Geometry"])
self.assertField(node.outputs["Selection"])
self.assertField(node.outputs["Position"])
self.assertField(node.outputs["Offset"])
node = tree.nodes["Group Output"]
self.assertSingle(node.inputs["Geometry"])
def test_cube_with_math_node(self):
self.load_testfile()
tree = bpy.data.node_groups["test_cube_with_math_node"]
node = tree.nodes["Group Input"]
self.assertSingle(node.outputs["A"])
self.assertSingle(node.outputs["B"])
node = tree.nodes["Group Output"]
self.assertSingle(node.inputs["Mesh"])
self.assertField(node.inputs["UV Map"])
def test_output_field(self):
self.load_testfile()
tree = bpy.data.node_groups["test_output_field"]
node = tree.nodes["Group Output"]
self.assertField(node.inputs["Position"])
self.assertField(node.inputs["Normal 1"])
self.assertField(node.inputs["Normal 2"])
def test_add_all_types(self):
self.load_testfile()
tree = bpy.data.node_groups["test_add_all_types"]
node = tree.nodes["Group Input"]
self.assertDynamic(node.outputs["Auto"])
self.assertSingle(node.outputs["Single"])
self.assertDynamic(node.outputs["Dynamic"])
self.assertField(node.outputs["Field"])
self.assertGrid(node.outputs["Grid"])
node = tree.nodes["Group Output"]
self.assertDynamic(node.inputs["auto+auto"])
self.assertDynamic(node.inputs["auto+single"])
self.assertDynamic(node.inputs["auto+dynamic"])
self.assertDynamic(node.inputs["auto+field"])
self.assertGrid(node.inputs["auto+grid"])
self.assertSingle(node.inputs["single+single"])
self.assertDynamic(node.inputs["single+dynamic"])
self.assertField(node.inputs["single+field"])
self.assertGrid(node.inputs["single+grid"])
self.assertDynamic(node.inputs["dynamic+dynamic"])
self.assertDynamic(node.inputs["dynamic+field"])
self.assertGrid(node.inputs["dynamic+grid"])
self.assertField(node.inputs["field+field"])
self.assertGrid(node.inputs["field+grid"])
self.assertGrid(node.inputs["grid+grid"])
def test_requirement_combinations(self):
self.load_testfile()
tree = bpy.data.node_groups["test_requirement_combinations"]
node = tree.nodes["Group Input"]
self.assertDynamic(node.outputs["none"])
self.assertDynamic(node.outputs["dynamic"])
self.assertSingle(node.outputs["single"])
self.assertField(node.outputs["field"])
self.assertGrid(node.outputs["grid"])
self.assertDynamic(node.outputs["none+dynamic"])
self.assertSingle(node.outputs["none+single"])
self.assertField(node.outputs["none+field"])
self.assertGrid(node.outputs["none+grid"])
self.assertDynamic(node.outputs["dynamic+dynamic"])
self.assertSingle(node.outputs["dynamic+single"])
self.assertField(node.outputs["dynamic+field"])
self.assertGrid(node.outputs["dynamic+grid"])
self.assertSingle(node.outputs["single+single"])
self.assertSingle(node.outputs["single+dynamic"])
self.assertSingle(node.outputs["single+field"])
self.assertDynamic(node.outputs["single+grid"])
self.assertField(node.outputs["field+field"])
self.assertSingle(node.outputs["field+single"])
self.assertField(node.outputs["field+dynamic"])
self.assertDynamic(node.outputs["field+grid"])
self.assertGrid(node.outputs["grid+grid"])
self.assertDynamic(node.outputs["grid+single"])
self.assertDynamic(node.outputs["grid+field"])
self.assertGrid(node.outputs["grid+dynamic"])
self.assertDynamic(node.outputs["dynamic+single+field"])
self.assertDynamic(node.outputs["single+field+grid"])
self.assertDynamic(node.outputs["dynamic+field+grid"])
self.assertDynamic(node.outputs["dynamic+single+grid"])
self.assertDynamic(node.outputs["dynamic+single+field+grid"])
def test_simulation_zone(self):
self.load_testfile()
tree = bpy.data.node_groups["test_simulation_zone"]
node = tree.nodes["Group Input"]
self.assertSingle(node.outputs["single 1"])
self.assertSingle(node.outputs["single 2"])
self.assertSingle(node.outputs["single 3"])
self.assertSingle(node.outputs["single 4"])
self.assertDynamic(node.outputs["dynamic"])
def test_repeat_zone(self):
self.load_testfile()
tree = bpy.data.node_groups["test_repeat_zone"]
node = tree.nodes["Group Input"]
self.assertSingle(node.outputs["Iterations"])
self.assertSingle(node.outputs["single 1"])
self.assertSingle(node.outputs["single 2"])
self.assertSingle(node.outputs["single 3"])
self.assertSingle(node.outputs["single 4"])
node = tree.nodes["Group Output"]
self.assertSingle(node.inputs["single 1"])
self.assertSingle(node.inputs["single 2"])
self.assertSingle(node.inputs["single 3"])
self.assertSingle(node.inputs["single 4"])
self.assertField(node.inputs["field 1"])
self.assertField(node.inputs["field 2"])
self.assertField(node.inputs["field 3"])
self.assertGrid(node.inputs["grid 1"])
self.assertGrid(node.inputs["grid 2"])
self.assertGrid(node.inputs["grid 3"])
def test_closure_zone(self):
self.load_testfile()
tree = bpy.data.node_groups["test_closure_zone"]
node = tree.nodes["Closure Input"]
self.assertField(node.outputs["field"])
self.assertSingle(node.outputs["single"])
node = tree.nodes["Closure Output"]
self.assertGrid(node.inputs["grid"])
self.assertField(node.inputs["field"])
def main():
global args
import argparse
if '--' in sys.argv:
argv = [sys.argv[0]] + sys.argv[sys.argv.index('--') + 1:]
else:
argv = sys.argv
parser = argparse.ArgumentParser()
parser.add_argument('--testdir', required=True, type=pathlib.Path)
args, remaining = parser.parse_known_args(argv)
unittest.main(argv=remaining)
if __name__ == "__main__":
main()