BLI: Add support for non-square matrix multiplication.

Adds support for multiplying non-square non-equal matrices.

Co-authored-by: Clément Foucault <foucault.clem@gmail.com>
Pull Request: https://projects.blender.org/blender/blender/pulls/115783
This commit is contained in:
casey bianco-davis
2024-03-03 16:26:04 +01:00
committed by Clément Foucault
parent 62bb346af9
commit 3d136d0d00
3 changed files with 141 additions and 60 deletions

View File

@@ -364,24 +364,6 @@ struct alignas(Alignment) MatBase : public vec_struct_base<VecBase<T, NumRow>, N
return *this;
}
/** Multiply two matrices using matrix multiplication. */
MatBase<T, NumRow, NumRow> operator*(const MatBase<T, NumRow, NumCol> &b) const
{
const MatBase &a = *this;
/* This is the reference implementation.
* Might be overloaded with vectorized / optimized code. */
/* TODO(fclem): It should be possible to return non-square matrices when multiplying against
* MatBase<T, NumRow, OtherNumRow>. */
MatBase<T, NumRow, NumRow> result{};
unroll<NumRow>([&](auto j) {
unroll<NumRow>([&](auto i) {
/* Same as dot product, but avoid dependency on vector math. */
unroll<NumCol>([&](auto k) { result[j][i] += a[k][i] * b[j][k]; });
});
});
return result;
}
/** Multiply each component by a scalar. */
friend MatBase operator*(const MatBase &a, T b)
{
@@ -647,34 +629,6 @@ struct MatView : NonCopyable, NonMovable {
return result;
}
/** Multiply two matrices using matrix multiplication. */
template<int OtherSrcNumCol,
int OtherSrcNumRow,
int OtherSrcStartCol,
int OtherSrcStartRow,
int OtherSrcAlignment>
MatBase<T, NumRow, NumRow> operator*(const MatView<T,
NumRow,
NumCol,
OtherSrcNumCol,
OtherSrcNumRow,
OtherSrcStartCol,
OtherSrcStartRow,
OtherSrcAlignment> &b) const
{
const MatView &a = *this;
/* This is the reference implementation.
* Might be overloaded with vectorized / optimized code. */
MatBase<T, NumRow, NumRow> result{};
unroll<NumRow>([&](auto j) {
unroll<NumRow>([&](auto i) {
/* Same as dot product, but avoid dependency on vector math. */
unroll<NumCol>([&](auto k) { result[j][i] += a[k][i] * b[j][k]; });
});
});
return result;
}
MatT operator*(const MatT &b) const
{
return *this * b.view();
@@ -932,6 +886,120 @@ struct MutableMatView
}
};
namespace detail {
/** Multiply two matrices using matrix multiplication. */
template<typename T,
int A_NumCol,
int A_NumRow,
int B_NumCol,
int B_NumRow,
typename MatA,
typename MatB>
MatBase<T, B_NumCol, A_NumRow> matrix_mul_impl(const MatA &a, const MatB &b)
{
static_assert(A_NumCol == B_NumRow);
/* This is the reference implementation.
* Might be overloaded with vectorized / optimized code. */
MatBase<T, B_NumCol, A_NumRow> result{};
unroll<B_NumCol>([&](auto j) {
unroll<A_NumRow>([&](auto i) {
/* Same as dot product, but avoid dependency on vector math. */
unroll<A_NumCol>([&](auto k) { result[j][i] += a[k][i] * b[j][k]; });
});
});
return result;
}
} // namespace detail
template<typename T, int A_NumCol, int A_NumRow, int B_NumCol, int B_NumRow>
MatBase<T, B_NumCol, A_NumRow> operator*(const MatBase<T, A_NumCol, A_NumRow> &a,
const MatBase<T, B_NumCol, B_NumRow> &b)
{
return detail::matrix_mul_impl<T, A_NumCol, A_NumRow, B_NumCol, B_NumRow>(a, b);
}
template<typename T,
int A_NumCol,
int A_NumRow,
int A_SrcNumCol,
int A_SrcNumRow,
int A_SrcStartCol,
int A_SrcStartRow,
int A_SrcAlignment,
int B_NumCol,
int B_NumRow,
int B_SrcNumCol,
int B_SrcNumRow,
int B_SrcStartCol,
int B_SrcStartRow,
int B_SrcAlignment>
MatBase<T, B_NumCol, A_NumRow> operator*(const MatView<T,
A_NumCol,
A_NumRow,
A_SrcNumCol,
A_SrcNumRow,
A_SrcStartCol,
A_SrcStartRow,
A_SrcAlignment> &a,
const MatView<T,
B_NumCol,
B_NumRow,
B_SrcNumCol,
B_SrcNumRow,
B_SrcStartCol,
B_SrcStartRow,
B_SrcAlignment> &b)
{
return detail::matrix_mul_impl<T, A_NumCol, A_NumRow, B_NumCol, B_NumRow>(a, b);
}
template<typename T,
int A_NumCol,
int A_NumRow,
int A_SrcNumCol,
int A_SrcNumRow,
int A_SrcStartCol,
int A_SrcStartRow,
int A_SrcAlignment,
int B_NumCol,
int B_NumRow>
MatBase<T, B_NumCol, A_NumRow> operator*(const MatView<T,
A_NumCol,
A_NumRow,
A_SrcNumCol,
A_SrcNumRow,
A_SrcStartCol,
A_SrcStartRow,
A_SrcAlignment> &a,
const MatBase<T, B_NumCol, B_NumRow> &b)
{
return detail::matrix_mul_impl<T, A_NumCol, A_NumRow, B_NumCol, B_NumRow>(a, b);
}
template<typename T,
int A_NumCol,
int A_NumRow,
int B_NumCol,
int B_NumRow,
int B_SrcNumCol,
int B_SrcNumRow,
int B_SrcStartCol,
int B_SrcStartRow,
int B_SrcAlignment>
MatBase<T, B_NumCol, A_NumRow> operator*(const MatBase<T, A_NumCol, A_NumRow> &a,
const MatView<T,
B_NumCol,
B_NumRow,
B_SrcNumCol,
B_SrcNumRow,
B_SrcStartCol,
B_SrcStartRow,
B_SrcAlignment> &b)
{
return detail::matrix_mul_impl<T, A_NumCol, A_NumRow, B_NumCol, B_NumRow>(a, b);
}
using float2x2 = MatBase<float, 2, 2>;
using float2x3 = MatBase<float, 2, 3>;
using float2x4 = MatBase<float, 2, 4>;
@@ -958,12 +1026,12 @@ using double4x3 = MatBase<double, 4, 3>;
using double4x4 = MatBase<double, 4, 4>;
/* Specialization for SSE optimization. */
template<> float4x4 float4x4::operator*(const float4x4 &b) const;
template<> float3x3 float3x3::operator*(const float3x3 &b) const;
template<> float4x4 operator*(const float4x4 &a, const float4x4 &b);
template<> float3x3 operator*(const float3x3 &a, const float3x3 &b);
extern template float2x2 float2x2::operator*(const float2x2 &b) const;
extern template double2x2 double2x2::operator*(const double2x2 &b) const;
extern template double3x3 double3x3::operator*(const double3x3 &b) const;
extern template double4x4 double4x4::operator*(const double4x4 &b) const;
extern template float2x2 operator*(const float2x2 &a, const float2x2 &b);
extern template double2x2 operator*(const double2x2 &a, const double2x2 &b);
extern template double3x3 operator*(const double3x3 &a, const double3x3 &b);
extern template double4x4 operator*(const double4x4 &a, const double4x4 &b);
} // namespace blender

View File

@@ -21,10 +21,9 @@
namespace blender {
template<> float4x4 float4x4::operator*(const float4x4 &b) const
template<> float4x4 operator*(const float4x4 &a, const float4x4 &b)
{
using namespace math;
const float4x4 &a = *this;
float4x4 result;
#if BLI_HAVE_SSE2
@@ -69,10 +68,9 @@ template<> float4x4 float4x4::operator*(const float4x4 &b) const
return result;
}
template<> float3x3 float3x3::operator*(const float3x3 &b) const
template<> float3x3 operator*(const float3x3 &a, const float3x3 &b)
{
using namespace math;
const float3x3 &a = *this;
float3x3 result;
#if 0 /* 1.2 times slower. Could be used as reference for aligned version. */
@@ -114,10 +112,10 @@ template<> float3x3 float3x3::operator*(const float3x3 &b) const
return result;
}
template float2x2 float2x2::operator*(const float2x2 &b) const;
template double2x2 double2x2::operator*(const double2x2 &b) const;
template double3x3 double3x3::operator*(const double3x3 &b) const;
template double4x4 double4x4::operator*(const double4x4 &b) const;
template float2x2 operator*(const float2x2 &a, const float2x2 &b);
template double2x2 operator*(const double2x2 &a, const double2x2 &b);
template double3x3 operator*(const double3x3 &a, const double3x3 &b);
template double4x4 operator*(const double4x4 &a, const double4x4 &b);
} // namespace blender

View File

@@ -255,6 +255,21 @@ TEST(math_matrix_types, MatrixMultiplyOperator)
EXPECT_EQ(result4[0][1], expect4[0][1]);
EXPECT_EQ(result4[1][0], expect4[1][0]);
EXPECT_EQ(result4[1][1], expect4[1][1]);
float3x4 a5(float4(1), float4(3), float4(5));
float2x3 b5(float3(11, 7, 5), float3(13, 11, 17));
float2x4 expect5(float4(57), float4(131));
float2x4 result5 = a5 * b5;
EXPECT_EQ(result5[0][0], expect5[0][0]);
EXPECT_EQ(result5[0][1], expect5[0][1]);
EXPECT_EQ(result5[0][2], expect5[0][2]);
EXPECT_EQ(result5[0][3], expect5[0][3]);
EXPECT_EQ(result5[1][0], expect5[1][0]);
EXPECT_EQ(result5[1][1], expect5[1][1]);
EXPECT_EQ(result5[1][2], expect5[1][2]);
EXPECT_EQ(result5[1][3], expect5[1][3]);
}
TEST(math_matrix_types, VectorMultiplyOperator)