Files
test/tests/gtests/functions/FN_multi_function_network_test.cc
Jacques Lucke 2679236047 Cleanup: avoid static initialization order issues when accessing CPPTypes
Instead of depending on static initialization order of globals use
static variables within functions. Those are initialized on first use.
This is every so slighly less efficient, but avoids a full class of problems.
2020-07-17 14:15:06 +02:00

254 lines
7.7 KiB
C++

/* Apache License, Version 2.0 */
#include "testing/testing.h"
#include "FN_multi_function_builder.hh"
#include "FN_multi_function_network.hh"
#include "FN_multi_function_network_evaluation.hh"
namespace blender::fn {
TEST(multi_function_network, Test1)
{
CustomMF_SI_SO<int, int> add_10_fn("add 10", [](int value) { return value + 10; });
CustomMF_SI_SI_SO<int, int, int> multiply_fn("multiply", [](int a, int b) { return a * b; });
MFNetwork network;
MFNode &node1 = network.add_function(add_10_fn);
MFNode &node2 = network.add_function(multiply_fn);
MFOutputSocket &input_socket = network.add_input("Input", MFDataType::ForSingle<int>());
MFInputSocket &output_socket = network.add_output("Output", MFDataType::ForSingle<int>());
network.add_link(node1.output(0), node2.input(0));
network.add_link(node1.output(0), node2.input(1));
network.add_link(node2.output(0), output_socket);
network.add_link(input_socket, node1.input(0));
MFNetworkEvaluator network_fn{{&input_socket}, {&output_socket}};
{
Array<int> values = {4, 6, 1, 2, 0};
Array<int> results(values.size(), 0);
MFParamsBuilder params(network_fn, values.size());
params.add_readonly_single_input(values.as_span());
params.add_uninitialized_single_output(results.as_mutable_span());
MFContextBuilder context;
network_fn.call({0, 2, 3, 4}, params, context);
EXPECT_EQ(results[0], 14 * 14);
EXPECT_EQ(results[1], 0);
EXPECT_EQ(results[2], 11 * 11);
EXPECT_EQ(results[3], 12 * 12);
EXPECT_EQ(results[4], 10 * 10);
}
{
int value = 3;
Array<int> results(5, 0);
MFParamsBuilder params(network_fn, results.size());
params.add_readonly_single_input(&value);
params.add_uninitialized_single_output(results.as_mutable_span());
MFContextBuilder context;
network_fn.call({1, 2, 4}, params, context);
EXPECT_EQ(results[0], 0);
EXPECT_EQ(results[1], 13 * 13);
EXPECT_EQ(results[2], 13 * 13);
EXPECT_EQ(results[3], 0);
EXPECT_EQ(results[4], 13 * 13);
}
}
class ConcatVectorsFunction : public MultiFunction {
public:
ConcatVectorsFunction()
{
MFSignatureBuilder signature = this->get_builder("Concat Vectors");
signature.vector_mutable<int>("A");
signature.vector_input<int>("B");
}
void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override
{
GVectorArrayRef<int> a = params.vector_mutable<int>(0);
VArraySpan<int> b = params.readonly_vector_input<int>(1);
for (uint i : mask) {
a.extend(i, b[i]);
}
}
};
class AppendFunction : public MultiFunction {
public:
AppendFunction()
{
MFSignatureBuilder signature = this->get_builder("Append");
signature.vector_mutable<int>("Vector");
signature.single_input<int>("Value");
}
void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override
{
GVectorArrayRef<int> vectors = params.vector_mutable<int>(0);
VSpan<int> values = params.readonly_single_input<int>(1);
for (uint i : mask) {
vectors.append(i, values[i]);
}
}
};
class SumVectorFunction : public MultiFunction {
public:
SumVectorFunction()
{
MFSignatureBuilder signature = this->get_builder("Sum Vector");
signature.vector_input<int>("Vector");
signature.single_output<int>("Sum");
}
void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override
{
VArraySpan<int> vectors = params.readonly_vector_input<int>(0);
MutableSpan<int> sums = params.uninitialized_single_output<int>(1);
for (uint i : mask) {
int sum = 0;
VSpan<int> vector = vectors[i];
for (uint j = 0; j < vector.size(); j++) {
sum += vector[j];
}
sums[i] = sum;
}
}
};
class CreateRangeFunction : public MultiFunction {
public:
CreateRangeFunction()
{
MFSignatureBuilder builder = this->get_builder("Create Range");
builder.single_input<int>("Size");
builder.vector_output<int>("Range");
}
void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override
{
VSpan<int> sizes = params.readonly_single_input<int>(0, "Size");
GVectorArrayRef<int> ranges = params.vector_output<int>(1, "Range");
for (int i : mask) {
int size = sizes[i];
for (int j : IndexRange(size)) {
ranges.append(i, j);
}
}
}
};
TEST(multi_function_network, Test2)
{
CustomMF_SI_SO<int, int> add_3_fn("add 3", [](int value) { return value + 3; });
ConcatVectorsFunction concat_vectors_fn;
AppendFunction append_fn;
SumVectorFunction sum_fn;
CreateRangeFunction create_range_fn;
MFNetwork network;
MFOutputSocket &input1 = network.add_input("Input 1", MFDataType::ForVector<int>());
MFOutputSocket &input2 = network.add_input("Input 2", MFDataType::ForSingle<int>());
MFInputSocket &output1 = network.add_output("Output 1", MFDataType::ForVector<int>());
MFInputSocket &output2 = network.add_output("Output 2", MFDataType::ForSingle<int>());
MFNode &node1 = network.add_function(add_3_fn);
MFNode &node2 = network.add_function(create_range_fn);
MFNode &node3 = network.add_function(concat_vectors_fn);
MFNode &node4 = network.add_function(sum_fn);
MFNode &node5 = network.add_function(append_fn);
MFNode &node6 = network.add_function(sum_fn);
network.add_link(input2, node1.input(0));
network.add_link(node1.output(0), node2.input(0));
network.add_link(node2.output(0), node3.input(1));
network.add_link(input1, node3.input(0));
network.add_link(input1, node4.input(0));
network.add_link(node4.output(0), node5.input(1));
network.add_link(node3.output(0), node5.input(0));
network.add_link(node5.output(0), node6.input(0));
network.add_link(node3.output(0), output1);
network.add_link(node6.output(0), output2);
// std::cout << network.to_dot() << "\n\n";
MFNetworkEvaluator network_fn{{&input1, &input2}, {&output1, &output2}};
{
Array<int> input_value_1 = {3, 6};
int input_value_2 = 4;
GVectorArray output_value_1(CPPType::get<int32_t>(), 5);
Array<int> output_value_2(5, -1);
MFParamsBuilder params(network_fn, 5);
params.add_readonly_vector_input(GVArraySpan(input_value_1.as_span(), 5));
params.add_readonly_single_input(&input_value_2);
params.add_vector_output(output_value_1);
params.add_uninitialized_single_output(output_value_2.as_mutable_span());
MFContextBuilder context;
network_fn.call({1, 2, 4}, params, context);
EXPECT_EQ(output_value_1[0].size(), 0);
EXPECT_EQ(output_value_1[1].size(), 9);
EXPECT_EQ(output_value_1[2].size(), 9);
EXPECT_EQ(output_value_1[3].size(), 0);
EXPECT_EQ(output_value_1[4].size(), 9);
EXPECT_EQ(output_value_2[0], -1);
EXPECT_EQ(output_value_2[1], 39);
EXPECT_EQ(output_value_2[2], 39);
EXPECT_EQ(output_value_2[3], -1);
EXPECT_EQ(output_value_2[4], 39);
}
{
GVectorArray input_value_1(CPPType::get<int32_t>(), 3);
GVectorArrayRef<int> input_value_ref_1 = input_value_1;
input_value_ref_1.extend(0, {3, 4, 5});
input_value_ref_1.extend(1, {1, 2});
Array<int> input_value_2 = {4, 2, 3};
GVectorArray output_value_1(CPPType::get<int32_t>(), 3);
Array<int> output_value_2(3, -1);
MFParamsBuilder params(network_fn, 3);
params.add_readonly_vector_input(input_value_1);
params.add_readonly_single_input(input_value_2.as_span());
params.add_vector_output(output_value_1);
params.add_uninitialized_single_output(output_value_2.as_mutable_span());
MFContextBuilder context;
network_fn.call({0, 1, 2}, params, context);
EXPECT_EQ(output_value_1[0].size(), 10);
EXPECT_EQ(output_value_1[1].size(), 7);
EXPECT_EQ(output_value_1[2].size(), 6);
EXPECT_EQ(output_value_2[0], 45);
EXPECT_EQ(output_value_2[1], 16);
EXPECT_EQ(output_value_2[2], 15);
}
}
} // namespace blender::fn