The problem is taken from a recent Stack Overflow question:
Find the maximum sum of matrix elements under the following constraints:
- Exactly one element from each row has to be included in the sum.
- If the element at (i, j) is selected, then (i + 1, j) is not.
- Values in the matrix are nonnegative integers.
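Below is my working (as far as I can tell) solution. It is based on the following observation: given an optimal path (the selected elements, one per row) for an N x M matrix, its part covering the first N - 1 rows is either the optimal path for the (N - 1) x M submatrix, or there exists at most one better path for that submatrix (one ending in the column the optimal path uses in row N). Hence it is enough to track, row by row, the two best paths and the columns they end in.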
I aimed to make the code testable at compile time; for that purpose I also wrote a lightweight wrapper around std::array whose interface is a subset of, e.g., Eigen's Matrix classes.
Compiles fine with both gcc and clang (-std=c++17 -Werror -Wall -Wpedantic).
Any thoughts on how to improve it?
#include <array>
/*
 * Small access wrapper for an array, stored in row-major order.
 * Underlying storage is kept public so the type stays an aggregate and can be
 * constructed efficiently, e.g. Array2d<int, 2, 2>{{1, 2, 3, 4}}.
 * Element access performs no bounds checking.
 */
template <typename T, std::size_t rows_, std::size_t cols_>
struct Array2d {
    // Mutable element access. constexpr so that a matrix can also be written
    // during constant evaluation (std::array's non-const operator[] is
    // constexpr since C++17).
    constexpr T& operator()(std::size_t row, std::size_t col) {
        return storage[index(row, col)];
    }
    constexpr T const & operator()(std::size_t row, std::size_t col) const {
        return storage[index(row, col)];
    }
    constexpr std::size_t rows() const { return rows_; }
    constexpr std::size_t cols() const { return cols_; }

    std::array<T, rows_ * cols_> storage;

private:
    // Row-major layout: row `row` occupies storage[cols_ * row, cols_ * (row + 1)).
    constexpr std::size_t index (std::size_t row, std::size_t col) const {
        return cols_ * row + col;
    }
};

// Input type for MaxSum::solve: a rows x cols matrix of ints.
template<std::size_t rows, std::size_t cols>
using Problem = Array2d<int, rows, cols>;
namespace MaxSum {
namespace Details {

/*
 * Value of a (partial) path together with the column of its last element.
 */
struct IndexedRowValue {
    int val;
    std::size_t index;
};

/*
 * The `count` best path values for one row, best first, each ending in a
 * distinct column.
 */
template<std::size_t count>
using TopElements = std::array<IndexedRowValue, count>;

/*
 * Orders the candidates of column 0 (value v1) and column 1 (value v2),
 * best first. On a tie, column 0 comes first.
 */
constexpr TopElements<2> sorted_first_two (int v1, int v2) {
    auto const first = IndexedRowValue{v1, 0};
    auto const second = IndexedRowValue{v2, 1};
    return v1 >= v2
        ? TopElements<2>{first, second}
        : TopElements<2>{second, first};
}

/*
 * Inserts the candidate (path_value, col) into top_paths if it is strictly
 * better than one of the two entries, keeping the list ordered best first.
 */
constexpr void update_top (TopElements<2> & top_paths, int path_value,
                           std::size_t col) {
    if (path_value > top_paths[0].val) {
        top_paths[1] = top_paths[0];
        top_paths[0] = IndexedRowValue{path_value, col};
    }
    else if (path_value > top_paths[1].val) {
        top_paths[1] = IndexedRowValue{path_value, col};
    }
}

/*
 * Two best single-element paths of row 0.
 * Precondition: input.rows() >= 1 and input.cols() >= 2.
 */
template<typename Matrix>
constexpr TopElements<2> find_top2_in_first_row (Matrix const & input) {
    auto result = sorted_first_two(input(0, 0), input(0, 1));
    for (std::size_t col = 2; col < input.cols(); ++col)
        update_top(result, input(0, col), col);
    return result;
}

/*
 * Best value of a path ending in column `col` of the current row, whose element
 * is `val`: `val` plus the best previous-row path not ending in `col`.
 * The two entries of top_last_row end in distinct columns, so at least one of
 * them is always admissible.
 */
constexpr int best_path_value_through_element(TopElements<2> const & top_last_row,
                                              int val,
                                              std::size_t col) {
    auto const & predecessor = top_last_row[0].index != col
        ? top_last_row[0]
        : top_last_row[1];
    return predecessor.val + val;
}

/*
 * Two best paths ending in row `row`, given the two best paths ending in
 * row `row - 1`.
 * Precondition: 1 <= row < input.rows() and input.cols() >= 2.
 */
template<typename Matrix>
constexpr TopElements<2> find_best_paths_for_row(TopElements<2> const & top_last_row,
                                                 std::size_t row,
                                                 Matrix const & input) {
    auto const path_0 = best_path_value_through_element(top_last_row, input(row, 0), 0u);
    auto const path_1 = best_path_value_through_element(top_last_row, input(row, 1), 1u);
    auto top_paths = sorted_first_two(path_0, path_1);
    for (std::size_t col = 2; col < input.cols(); ++col) {
        auto const path_col = best_path_value_through_element(top_last_row, input(row, col), col);
        update_top(top_paths, path_col, col);
    }
    return top_paths;
}

/*
 * Row-by-row dynamic programming that keeps only the two best paths per row.
 * Key observation: the optimal path through row i extends either the best or
 * the second-best path of row i - 1 (whichever does not end in the same
 * column), so two entries per row are sufficient.
 * Precondition: input.rows() >= 1 and input.cols() >= 2.
 */
template<typename Matrix>
constexpr int solve_non_trivial(Matrix const & input) {
    auto top_paths = find_top2_in_first_row(input);
    for (std::size_t row = 1; row < input.rows(); ++row)
        top_paths = find_best_paths_for_row(top_paths, row, input);
    return top_paths[0].val;
}

} // namespace Details

/*
 * Finds max sum of elements of input Matrix, with following constraints:
 *   - exactly one element from each row is selected;
 *   - if element at (i, j) has been selected, then (i + 1, j) can't be selected.
 *
 * Matrix must provide rows(), cols() and a const operator()(row, col).
 * Matrix elements are required to be nonnegative integers; the sum must fit in int.
 *
 * Returns:
 *   - the maximum sum whenever a valid selection exists (any number of rows
 *     as long as there are at least 2 columns, or a 1x1 matrix);
 *   - 0 for a matrix without rows (empty selection);
 *   - 0 when no valid selection exists (no columns, or a single column with
 *     two or more rows, which would force the same column in adjacent rows).
 */
template<typename Matrix>
constexpr int solve (Matrix const & input) {
    if (input.rows() == 0)
        return 0;
    // With two or more columns, adjacent rows can always use different
    // columns, so any number of rows is solvable (rows > cols included).
    if (input.cols() >= 2)
        return Details::solve_non_trivial(input);
    if (input.cols() == 1 && input.rows() == 1)
        return input(0, 0);
    return 0;
}

} // namespace MaxSum
int main() {
constexpr auto trivial = Problem<1u, 1u>{{1}};
static_assert(MaxSum::solve(trivial) == 1);
constexpr auto problem2x2_0 = Problem<2u, 2u>{{1, 0, 0, 1}};
static_assert(MaxSum::solve(problem2x2_0) == 2);
constexpr auto problem2x2_1 = Problem<2u, 2u>{{10, 0, 9, 0}};
static_assert(MaxSum::solve(problem2x2_1) == 10);
constexpr auto problem2x2_2 = Problem<2u, 2u>{{10, 2, 9, 0}};
static_assert(MaxSum::solve(problem2x2_2) == 11);
constexpr auto problem1x5 = Problem<1u, 5u>{{10, 2, 9, 7, 6}};
static_assert(MaxSum::solve(problem1x5) == 10);
constexpr auto problem1x7 = Problem<1u, 7u>{{10, 2, 9, 7, 6, 12, 11}};
static_assert(MaxSum::solve(problem1x7) == 12);
constexpr auto problem3x3 = Problem<3u, 3u>{{1, 2, 3,
5, 6, 4,
3, 2, 4}};
static_assert(MaxSum::solve(problem3x3) == 13);
constexpr auto problem4x4 = Problem<4u, 4u>{{1, 2, 3, 4,
5, 6, 7, 8,
9, 1, 4, 2,
6, 3, 5, 7}};
static_assert(MaxSum::solve(problem4x4) == 27);
}
Non-`const` access operators (`operator()` here) can be `constexpr` too: `constexpr T& operator()(std::size_t row, std::size_t col) { /*...*/ }`