diff --git a/raytracing/src/matrix.cpp b/raytracing/src/matrix.cpp index dee8ec8..9a66723 100644 --- a/raytracing/src/matrix.cpp +++ b/raytracing/src/matrix.cpp @@ -208,16 +208,16 @@ double Matrix::determinant(void) /* ------------------------------------------------------------------------- */ -Matrix Matrix::sub_matrix(uint8_t a_rows, uint8_t a_cols) +Matrix Matrix::sub_matrix(uint8_t a_row, uint8_t a_col) { Matrix the_sub = *this; - the_sub.m_data.erase(the_sub.m_data.begin() + a_rows); + the_sub.m_data.erase(the_sub.m_data.begin() + a_row); the_sub.m_rows--; for (int the_row = 0; the_row < the_sub.m_rows; the_row++) { - the_sub.m_data[the_row].erase(the_sub.m_data[the_row].begin() + a_cols); + the_sub.m_data[the_row].erase(the_sub.m_data[the_row].begin() + a_col); } the_sub.m_cols--; @@ -226,9 +226,23 @@ Matrix Matrix::sub_matrix(uint8_t a_rows, uint8_t a_cols) /* ------------------------------------------------------------------------- */ -double Matrix::minor(uint8_t a_rows, uint8_t a_cols) +double Matrix::minor(uint8_t a_row, uint8_t a_col) { - return sub_matrix(a_rows, a_cols).determinant(); + return sub_matrix(a_row, a_col).determinant(); +} + +/* ------------------------------------------------------------------------- */ + +double Matrix::cofactor(uint8_t a_row, uint8_t a_col) +{ + int8_t an_inverter = 1; + + if ((a_row + a_col) % 2 != 0) + { + an_inverter = -1; + } + + return an_inverter * minor(a_row, a_col); } /* ------------------------------------------------------------------------- */ diff --git a/raytracing/src/matrix.h b/raytracing/src/matrix.h index bdd946f..068bb94 100644 --- a/raytracing/src/matrix.h +++ b/raytracing/src/matrix.h @@ -61,8 +61,9 @@ namespace Raytracer bool transpose(void); double determinant(void); - Matrix sub_matrix(uint8_t a_rows, uint8_t a_cols); - double minor(uint8_t a_rows, uint8_t a_cols); + Matrix sub_matrix(uint8_t a_row, uint8_t a_col); + double minor(uint8_t a_row, uint8_t a_col); + double cofactor(uint8_t a_row, uint8_t a_col); static Matrix identity(void); diff --git a/tests/03_matrix.cpp b/tests/03_matrix.cpp index bae341b..1e3c32e 100644 --- a/tests/03_matrix.cpp +++ b/tests/03_matrix.cpp @@ -294,3 +294,19 @@ TEST_CASE("[Matrix] Calculating a minor of a 3x3 matrix", "[Matrix]") REQUIRE(a.minor(1, 0) == 25); } + +/* ------------------------------------------------------------------------- */ + +TEST_CASE("[Matrix] Calculating a cofactor of a 3x3 matrix", "[Matrix]") +{ + Matrix a = { + {3, 5, 0}, + {2, -1, -7}, + {6, -1, 5} + }; + + REQUIRE(a.minor(0, 0) == -12.0); + REQUIRE(a.cofactor(0, 0) == -12.0); + REQUIRE(a.minor(1, 0) == 25.0); + REQUIRE(a.cofactor(1, 0) == -25.0); +}