/*
 * Editorial for Fibonacci Matrix
 *
 * Remember to use this editorial only when stuck, and do not copy-paste code from it.
 * Please be respectful to the problem author and the editorialist.
 *
 * Submitting an official solution before solving the problem yourself is a bannable offence.
 */
#include <bits/stdc++.h>
using namespace std;

#define MOD 998244353

// Fills every element of m from standard input, row by row, left to right.
// The shape of m must already be set; this function does not resize it.
void read(vector<vector<int>> &m)
{
    for (size_t r = 0; r < m.size(); ++r)
        for (size_t c = 0; c < m[r].size(); ++c)
            cin >> m[r][c];
}

// Reshapes m to `rows` rows of `cols` columns each. Elements that fall
// inside the new bounds keep their values; newly created ones are 0.
void resize(vector<vector<int>> &m, int rows, int cols)
{
    m.resize(rows);
    for (size_t r = 0; r < m.size(); ++r)
        m[r].resize(cols);
}

// Square-matrix convenience overload: reshapes m to n x n by forwarding
// to the rectangular overload with equal row and column counts.
void resize(vector<vector<int>> &m, int n)
{
    const int side = n;
    resize(m, side, side);
}

// Number of rows in matrix a (the size of the outer vector).
int rows(const vector<vector<int>> &a)
{
    return static_cast<int>(a.size());
}

// Number of columns in matrix a, taken from its first row (rows are assumed
// to all have the same length). Returns 0 for a matrix with no rows instead
// of indexing a[0] on an empty vector, which would be undefined behaviour.
int cols(const vector<vector<int>> &a)
{
    return a.empty() ? 0 : static_cast<int>(a[0].size());
}

// Modular addition: returns (a + b) reduced by a single subtraction of MOD.
// NOTE: assumes a and b are already in [0, MOD); larger inputs would not be
// fully reduced by one subtraction.
int add(int a, int b)
{
    int sum = a + b;
    if (sum >= MOD)
        sum -= MOD;
    return sum;
}

// Modular multiplication: the product is formed in 64 bits so two values
// below MOD cannot overflow, then reduced modulo MOD.
// NOTE: assumes non-negative inputs; C++ % keeps the sign of a negative product.
int mult(int a, int b)
{
    const long long product = static_cast<long long>(a) * b;
    return static_cast<int>(product % MOD);
}

// In-place modular accumulation: x becomes add(x, y).
void addref(int &x, int y)
{
    const int updated = add(x, y);
    x = updated;
}

// Returns a new matrix with every entry of m multiplied by `scalar` modulo MOD.
// The result has rows(m) rows and cols(m) columns.
vector<vector<int>> mult(int scalar, const vector<vector<int>> &m)
{
    const int height = rows(m);
    const int width = cols(m);
    vector<vector<int>> scaled;
    resize(scaled, height, width);
    for (int r = 0; r < height; ++r)
        for (int c = 0; c < width; ++c)
            scaled[r][c] = mult(scalar, m[r][c]);
    return scaled;
}

vector<vector<int>> mult(const vector<vector<int>> &a, const vector<vector<int>> &b) 
{
    vector<vector<int>> res;
    resize(res, rows(a), cols(b));
    assert(cols(a) == rows(b));
    for(int i = 0; i < rows(a); i++) 
        for(int j = 0; j < cols(b); j++) 
            for(int k = 0; k < cols(a); k++) 
                addref(res[i][j], mult(a[i][k], b[k][j]));
    return res;
}

// Copies the blockSize x blockSize matrix `block` into m at block coordinates
// (blockRow, blockColumn), i.e. with its top-left corner at element
// (blockRow * blockSize, blockColumn * blockSize). m must already be large enough.
void build(vector<vector<int>> &m, const vector<vector<int>> &block, int blockRow, int blockColumn, int blockSize)
{
    const int top = blockRow * blockSize;
    const int left = blockColumn * blockSize;
    for (int r = 0; r < blockSize; ++r)
        for (int c = 0; c < blockSize; ++c)
            m[top + r][left + c] = block[r][c];
}

// Returns the n x n identity matrix (ones on the main diagonal, zeros elsewhere).
vector<vector<int>> one(int n)
{
    vector<vector<int>> identity(n, vector<int>(n, 0));
    for (int d = 0; d < n; ++d)
        identity[d][d] = 1;
    return identity;
}

// Reads n, k and the n x n matrix M from stdin. It then prints block row 3
// (rows 3n..4n-1) of full^k * row. `full` is the 4n x 4n block matrix and
// `row` the 4n x n block column assembled below. The recurrence they encode
// comes from the problem statement, which is not reproduced in this file.
int main()
{
    // Only cin/cout are used, so decoupling iostreams from C stdio is safe.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    long long k;
    cin >> n >> k;
    vector<vector<int>> m;
    resize(m, n);
    read(m);

    const vector<vector<int>> e = one(n);
    vector<vector<int>> row, full;
    resize(row, 4 * n, n);
    resize(full, 4 * n, 4 * n);

    // row = [ M ; I ; 2M ; I ]
    build(row, m, 0, 0, n);
    build(row, e, 1, 0, n);
    build(row, mult(2, m), 2, 0, n);
    build(row, e, 3, 0, n);

    // full = [ M   M^2  M  0 ]
    //        [ I   0    0  0 ]
    //        [ 2M  0    M  0 ]
    //        [ I   0    0  I ]
    build(full, m, 0, 0, n);
    build(full, mult(m, m), 0, 1, n);
    build(full, m, 0, 2, n);
    build(full, e, 1, 0, n);
    build(full, mult(2, m), 2, 0, n);
    build(full, m, 2, 2, n);
    build(full, e, 3, 0, n);
    build(full, e, 3, 3, n);

    // Binary exponentiation applied directly to the 4n x n column.
    // All powers of `full` commute, so multiplying `row` by each selected
    // full^(2^i) yields full^k * row without materialising full^k. Each such
    // step costs (4n)^2 * n instead of (4n)^3.
    vector<vector<int>> answer = row;
    while (k > 0)
    {
        if (k & 1)
            answer = mult(full, answer);
        k >>= 1;
        if (k > 0) // skip the final squaring, whose result would never be used
            full = mult(full, full);
    }

    for (int i = 3 * n; i < 4 * n; i++)
    {
        for (int j = 0; j < n; j++)
            cout << answer[i][j] << ' ';
        cout << '\n'; // '\n' instead of endl: no per-line flush; cout is flushed at exit
    }
    return 0;
}