BLI: Add support for non-square matrix multiplication.
Adds support for multiplying non-square non-equal matrices. Co-authored-by: Clément Foucault <foucault.clem@gmail.com> Pull Request: https://projects.blender.org/blender/blender/pulls/115783
This commit is contained in:
committed by
Clément Foucault
parent
62bb346af9
commit
3d136d0d00
@@ -364,24 +364,6 @@ struct alignas(Alignment) MatBase : public vec_struct_base<VecBase<T, NumRow>, N
|
||||
return *this;
|
||||
}
|
||||
|
||||
/** Multiply two matrices using matrix multiplication. */
|
||||
MatBase<T, NumRow, NumRow> operator*(const MatBase<T, NumRow, NumCol> &b) const
|
||||
{
|
||||
const MatBase &a = *this;
|
||||
/* This is the reference implementation.
|
||||
* Might be overloaded with vectorized / optimized code. */
|
||||
/* TODO(fclem): It should be possible to return non-square matrices when multiplying against
|
||||
* MatBase<T, NumRow, OtherNumRow>. */
|
||||
MatBase<T, NumRow, NumRow> result{};
|
||||
unroll<NumRow>([&](auto j) {
|
||||
unroll<NumRow>([&](auto i) {
|
||||
/* Same as dot product, but avoid dependency on vector math. */
|
||||
unroll<NumCol>([&](auto k) { result[j][i] += a[k][i] * b[j][k]; });
|
||||
});
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
/** Multiply each component by a scalar. */
|
||||
friend MatBase operator*(const MatBase &a, T b)
|
||||
{
|
||||
@@ -647,34 +629,6 @@ struct MatView : NonCopyable, NonMovable {
|
||||
return result;
|
||||
}
|
||||
|
||||
/** Multiply two matrices using matrix multiplication. */
|
||||
template<int OtherSrcNumCol,
|
||||
int OtherSrcNumRow,
|
||||
int OtherSrcStartCol,
|
||||
int OtherSrcStartRow,
|
||||
int OtherSrcAlignment>
|
||||
MatBase<T, NumRow, NumRow> operator*(const MatView<T,
|
||||
NumRow,
|
||||
NumCol,
|
||||
OtherSrcNumCol,
|
||||
OtherSrcNumRow,
|
||||
OtherSrcStartCol,
|
||||
OtherSrcStartRow,
|
||||
OtherSrcAlignment> &b) const
|
||||
{
|
||||
const MatView &a = *this;
|
||||
/* This is the reference implementation.
|
||||
* Might be overloaded with vectorized / optimized code. */
|
||||
MatBase<T, NumRow, NumRow> result{};
|
||||
unroll<NumRow>([&](auto j) {
|
||||
unroll<NumRow>([&](auto i) {
|
||||
/* Same as dot product, but avoid dependency on vector math. */
|
||||
unroll<NumCol>([&](auto k) { result[j][i] += a[k][i] * b[j][k]; });
|
||||
});
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
MatT operator*(const MatT &b) const
|
||||
{
|
||||
return *this * b.view();
|
||||
@@ -932,6 +886,120 @@ struct MutableMatView
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
/** Multiply two matrices using matrix multiplication. */
|
||||
template<typename T,
|
||||
int A_NumCol,
|
||||
int A_NumRow,
|
||||
int B_NumCol,
|
||||
int B_NumRow,
|
||||
typename MatA,
|
||||
typename MatB>
|
||||
MatBase<T, B_NumCol, A_NumRow> matrix_mul_impl(const MatA &a, const MatB &b)
|
||||
{
|
||||
static_assert(A_NumCol == B_NumRow);
|
||||
/* This is the reference implementation.
|
||||
* Might be overloaded with vectorized / optimized code. */
|
||||
MatBase<T, B_NumCol, A_NumRow> result{};
|
||||
unroll<B_NumCol>([&](auto j) {
|
||||
unroll<A_NumRow>([&](auto i) {
|
||||
/* Same as dot product, but avoid dependency on vector math. */
|
||||
unroll<A_NumCol>([&](auto k) { result[j][i] += a[k][i] * b[j][k]; });
|
||||
});
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template<typename T, int A_NumCol, int A_NumRow, int B_NumCol, int B_NumRow>
|
||||
MatBase<T, B_NumCol, A_NumRow> operator*(const MatBase<T, A_NumCol, A_NumRow> &a,
|
||||
const MatBase<T, B_NumCol, B_NumRow> &b)
|
||||
{
|
||||
return detail::matrix_mul_impl<T, A_NumCol, A_NumRow, B_NumCol, B_NumRow>(a, b);
|
||||
}
|
||||
|
||||
template<typename T,
|
||||
int A_NumCol,
|
||||
int A_NumRow,
|
||||
int A_SrcNumCol,
|
||||
int A_SrcNumRow,
|
||||
int A_SrcStartCol,
|
||||
int A_SrcStartRow,
|
||||
int A_SrcAlignment,
|
||||
int B_NumCol,
|
||||
int B_NumRow,
|
||||
int B_SrcNumCol,
|
||||
int B_SrcNumRow,
|
||||
int B_SrcStartCol,
|
||||
int B_SrcStartRow,
|
||||
int B_SrcAlignment>
|
||||
MatBase<T, B_NumCol, A_NumRow> operator*(const MatView<T,
|
||||
A_NumCol,
|
||||
A_NumRow,
|
||||
A_SrcNumCol,
|
||||
A_SrcNumRow,
|
||||
A_SrcStartCol,
|
||||
A_SrcStartRow,
|
||||
A_SrcAlignment> &a,
|
||||
const MatView<T,
|
||||
B_NumCol,
|
||||
B_NumRow,
|
||||
B_SrcNumCol,
|
||||
B_SrcNumRow,
|
||||
B_SrcStartCol,
|
||||
B_SrcStartRow,
|
||||
B_SrcAlignment> &b)
|
||||
{
|
||||
return detail::matrix_mul_impl<T, A_NumCol, A_NumRow, B_NumCol, B_NumRow>(a, b);
|
||||
}
|
||||
template<typename T,
|
||||
int A_NumCol,
|
||||
int A_NumRow,
|
||||
int A_SrcNumCol,
|
||||
int A_SrcNumRow,
|
||||
int A_SrcStartCol,
|
||||
int A_SrcStartRow,
|
||||
int A_SrcAlignment,
|
||||
int B_NumCol,
|
||||
int B_NumRow>
|
||||
MatBase<T, B_NumCol, A_NumRow> operator*(const MatView<T,
|
||||
A_NumCol,
|
||||
A_NumRow,
|
||||
A_SrcNumCol,
|
||||
A_SrcNumRow,
|
||||
A_SrcStartCol,
|
||||
A_SrcStartRow,
|
||||
A_SrcAlignment> &a,
|
||||
const MatBase<T, B_NumCol, B_NumRow> &b)
|
||||
{
|
||||
return detail::matrix_mul_impl<T, A_NumCol, A_NumRow, B_NumCol, B_NumRow>(a, b);
|
||||
}
|
||||
|
||||
template<typename T,
|
||||
int A_NumCol,
|
||||
int A_NumRow,
|
||||
int B_NumCol,
|
||||
int B_NumRow,
|
||||
int B_SrcNumCol,
|
||||
int B_SrcNumRow,
|
||||
int B_SrcStartCol,
|
||||
int B_SrcStartRow,
|
||||
int B_SrcAlignment>
|
||||
MatBase<T, B_NumCol, A_NumRow> operator*(const MatBase<T, A_NumCol, A_NumRow> &a,
|
||||
const MatView<T,
|
||||
B_NumCol,
|
||||
B_NumRow,
|
||||
B_SrcNumCol,
|
||||
B_SrcNumRow,
|
||||
B_SrcStartCol,
|
||||
B_SrcStartRow,
|
||||
B_SrcAlignment> &b)
|
||||
{
|
||||
return detail::matrix_mul_impl<T, A_NumCol, A_NumRow, B_NumCol, B_NumRow>(a, b);
|
||||
}
|
||||
|
||||
using float2x2 = MatBase<float, 2, 2>;
|
||||
using float2x3 = MatBase<float, 2, 3>;
|
||||
using float2x4 = MatBase<float, 2, 4>;
|
||||
@@ -958,12 +1026,12 @@ using double4x3 = MatBase<double, 4, 3>;
|
||||
using double4x4 = MatBase<double, 4, 4>;
|
||||
|
||||
/* Specialization for SSE optimization. */
|
||||
template<> float4x4 float4x4::operator*(const float4x4 &b) const;
|
||||
template<> float3x3 float3x3::operator*(const float3x3 &b) const;
|
||||
template<> float4x4 operator*(const float4x4 &a, const float4x4 &b);
|
||||
template<> float3x3 operator*(const float3x3 &a, const float3x3 &b);
|
||||
|
||||
extern template float2x2 float2x2::operator*(const float2x2 &b) const;
|
||||
extern template double2x2 double2x2::operator*(const double2x2 &b) const;
|
||||
extern template double3x3 double3x3::operator*(const double3x3 &b) const;
|
||||
extern template double4x4 double4x4::operator*(const double4x4 &b) const;
|
||||
extern template float2x2 operator*(const float2x2 &a, const float2x2 &b);
|
||||
extern template double2x2 operator*(const double2x2 &a, const double2x2 &b);
|
||||
extern template double3x3 operator*(const double3x3 &a, const double3x3 &b);
|
||||
extern template double4x4 operator*(const double4x4 &a, const double4x4 &b);
|
||||
|
||||
} // namespace blender
|
||||
|
||||
@@ -21,10 +21,9 @@
|
||||
|
||||
namespace blender {
|
||||
|
||||
template<> float4x4 float4x4::operator*(const float4x4 &b) const
|
||||
template<> float4x4 operator*(const float4x4 &a, const float4x4 &b)
|
||||
{
|
||||
using namespace math;
|
||||
const float4x4 &a = *this;
|
||||
float4x4 result;
|
||||
|
||||
#if BLI_HAVE_SSE2
|
||||
@@ -69,10 +68,9 @@ template<> float4x4 float4x4::operator*(const float4x4 &b) const
|
||||
return result;
|
||||
}
|
||||
|
||||
template<> float3x3 float3x3::operator*(const float3x3 &b) const
|
||||
template<> float3x3 operator*(const float3x3 &a, const float3x3 &b)
|
||||
{
|
||||
using namespace math;
|
||||
const float3x3 &a = *this;
|
||||
float3x3 result;
|
||||
|
||||
#if 0 /* 1.2 times slower. Could be used as reference for aligned version. */
|
||||
@@ -114,10 +112,10 @@ template<> float3x3 float3x3::operator*(const float3x3 &b) const
|
||||
return result;
|
||||
}
|
||||
|
||||
template float2x2 float2x2::operator*(const float2x2 &b) const;
|
||||
template double2x2 double2x2::operator*(const double2x2 &b) const;
|
||||
template double3x3 double3x3::operator*(const double3x3 &b) const;
|
||||
template double4x4 double4x4::operator*(const double4x4 &b) const;
|
||||
template float2x2 operator*(const float2x2 &a, const float2x2 &b);
|
||||
template double2x2 operator*(const double2x2 &a, const double2x2 &b);
|
||||
template double3x3 operator*(const double3x3 &a, const double3x3 &b);
|
||||
template double4x4 operator*(const double4x4 &a, const double4x4 &b);
|
||||
|
||||
} // namespace blender
|
||||
|
||||
|
||||
@@ -255,6 +255,21 @@ TEST(math_matrix_types, MatrixMultiplyOperator)
|
||||
EXPECT_EQ(result4[0][1], expect4[0][1]);
|
||||
EXPECT_EQ(result4[1][0], expect4[1][0]);
|
||||
EXPECT_EQ(result4[1][1], expect4[1][1]);
|
||||
|
||||
float3x4 a5(float4(1), float4(3), float4(5));
|
||||
float2x3 b5(float3(11, 7, 5), float3(13, 11, 17));
|
||||
|
||||
float2x4 expect5(float4(57), float4(131));
|
||||
|
||||
float2x4 result5 = a5 * b5;
|
||||
EXPECT_EQ(result5[0][0], expect5[0][0]);
|
||||
EXPECT_EQ(result5[0][1], expect5[0][1]);
|
||||
EXPECT_EQ(result5[0][2], expect5[0][2]);
|
||||
EXPECT_EQ(result5[0][3], expect5[0][3]);
|
||||
EXPECT_EQ(result5[1][0], expect5[1][0]);
|
||||
EXPECT_EQ(result5[1][1], expect5[1][1]);
|
||||
EXPECT_EQ(result5[1][2], expect5[1][2]);
|
||||
EXPECT_EQ(result5[1][3], expect5[1][3]);
|
||||
}
|
||||
|
||||
TEST(math_matrix_types, VectorMultiplyOperator)
|
||||
|
||||
Reference in New Issue
Block a user