Paper: A Flag Decomposition for Hierarchical Datasets

This article explains the key idea in the paper A Flag Decomposition for Hierarchical Datasets by Mankovich et al. The method decomposes (or factorizes) a matrix \( A \) into a product \( QR \) in a way that preserves the hierarchical structure of the data.
0. Problem Statement
Hierarchical data can be represented by a list of matrices \( A_i \in \mathbb{R}^{n \times n_i} \), where \( A_i \subset A_{i+1} \) (treating each matrix as an ordered set of columns). We can also represent the extra columns added at each level by \( B_i = A_i \setminus A_{i-1} \) for \( i > 1 \), with \( B_1 = A_1 \). In the example below, three such matrices are shown.
[Figure: nested data matrices \( A_1 \subset A_2 \subset A_3 \); the columns added at each level are highlighted as \( B_1, B_2, B_3 \).]
Note that \( A_3 \) is the entire data matrix. We want to find an orthonormal basis for each subspace such that the hierarchical structure is preserved.
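Here is a minimal NumPy sketch of this setup; the data matrix and index sets are illustrative, not from the paper:

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((6, 5))        # full data matrix (A_3 in the figure)

# Nested column index sets A_1 ⊂ A_2 ⊂ A_3 (indices are illustrative).
Aset = [[0, 1], [0, 1, 2, 3], [0, 1, 2, 3, 4]]

# B_1 = A_1; B_i = A_i \ A_{i-1} for i > 1 (set difference on column indices).
B_indices = [np.array(Aset[0])] + [np.setdiff1d(Aset[i], Aset[i - 1])
                                   for i in range(1, len(Aset))]
Bs = [A[:, idx] for idx in B_indices]
print([B.shape for B in Bs])           # [(6, 2), (6, 2), (6, 1)]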
1. Key idea
We want to create a hierarchy of subspaces that respects the hierarchical structure of the data. Such a subspace can be represented by a matrix with orthonormal columns, \( Q \) satisfying \( Q^TQ = I \). Specifically, for each \( B_i \in \mathbb{R}^{n \times b_i} \), we want to find a set of \( m_i \) orthonormal vectors \( Q_i \in \mathbb{R}^{n \times m_i},\; m_i \leq b_i \) that preserve the structure of \( B_i \) (in a way that will be defined shortly) as much as possible. We also require \( Q_i^T Q_j = 0 \) for \( i \neq j \), meaning the subspaces are mutually orthogonal.
Once we obtain these \( Q_i \) matrices, here is how the data can be represented:\[ \begin{align} B_1 &\approx (Q_1Q_1^T) B_1 \\ B_2 &\approx (Q_1Q_1^T + Q_2Q_2^T)B_2 \\ B_3 &\approx (Q_1Q_1^T + Q_2Q_2^T + Q_3Q_3^T)B_3 \\ &\vdots \\ B_k &\approx \left(\sum_{i=1}^{k} Q_iQ_i^T\right) B_k \end{align} \]where \( k \) is the number of levels in the hierarchy. Here is how one can interpret each term in the sum above:\( \underbrace{\overbrace{Q_i Q_i^T}^{\amber{\text{projection matrix}}} B_i}_{\rose{\text{projection onto subspace i}}} \)
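To build intuition for the sum of projectors, here is a small sketch (all names illustrative): when the \( Q_i \) blocks are mutually orthogonal, the sum of their projection matrices equals the projector onto their joint span, so any \( B \) lying in that span is reproduced exactly:

import numpy as np

rng = np.random.default_rng(1)
n = 8
# Mutually orthogonal orthonormal blocks: split one orthonormal basis.
Q, _ = np.linalg.qr(rng.standard_normal((n, 5)))
Q1, Q2, Q3 = Q[:, :2], Q[:, 2:4], Q[:, 4:5]

# Sum of the individual projectors equals the projector onto the joint span.
P_sum = Q1 @ Q1.T + Q2 @ Q2.T + Q3 @ Q3.T
print(np.allclose(P_sum, Q @ Q.T))     # True

# A matrix whose columns lie in that span is reproduced exactly;
# a general matrix is only approximated.
B = Q @ rng.standard_normal((5, 3))
print(np.allclose(P_sum @ B, B))       # True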
1.1 Finding the \( Q_i \) matrices
The expression \( B_1 \approx Q_1Q_1^T B_1 \) can be seen as finding a projection matrix that best approximates \( B_1 \) in the subspace spanned by \( Q_1 \). This is essentially a least-squares problem, and its solution is given by the first \( m_1 \) left singular vectors of \( B_1 \).\( \rose{Q_1 = U[:, \text{:}m_1]} \ \text{where} \ \amber{B_1 = U \Sigma V^T} \)
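A quick sanity check of this step in NumPy (the matrix sizes are illustrative):

import numpy as np

rng = np.random.default_rng(2)
B1, m1 = rng.standard_normal((6, 3)), 2

U, _, _ = np.linalg.svd(B1, full_matrices=False)
Q1 = U[:, :m1]                              # first m_1 left singular vectors
print(np.allclose(Q1.T @ Q1, np.eye(m1)))   # orthonormal columns: True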
The expression \( B_2 \approx Q_1Q_1^T B_2 + Q_2Q_2^T B_2 \) can be seen as finding a projection matrix that best approximates \( B_2 \) in the subspace spanned by \( Q_2 \), after removing the component already explained by \( Q_1 \).\( \rose{Q_2 = U[:, \text{:}m_2]} \ \text{where} \ \amber{B_2 - Q_1Q_1^TB_2 = U \Sigma V^T} \)
The expression \( B_3 \approx Q_1Q_1^T B_3 + Q_2Q_2^T B_3 + Q_3Q_3^T B_3 \) can be seen as finding a projection matrix that best approximates \( B_3 \) in the subspace spanned by \( Q_3 \), after removing the components already explained by \( Q_1 \) and \( Q_2 \).\( \rose{Q_3 = U[:, \text{:}m_3]} \ \text{where} \ \amber{B_3 - Q_1Q_1^TB_3 - Q_2Q_2^TB_3 = U \Sigma V^T} \)
This pattern continues for all \( Q_i \) where \( i = 1, 2, \ldots, k \).
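Putting the pattern together, a compact sketch of the recursion might look as follows; the helper name hierarchical_bases is ours, not the paper's, and Section 3 gives the more efficient in-place version:

import numpy as np

def hierarchical_bases(Bs, ms):
    """Direct transcription of the recursion: deflate, then truncated SVD.

    Bs: list of B_i blocks (n x b_i); ms: number of basis vectors per level.
    """
    Qs = []
    for B, m in zip(Bs, ms):
        # Remove the component already explained by Q_1, ..., Q_{i-1}.
        residual = B - sum(Q @ (Q.T @ B) for Q in Qs)
        U, _, _ = np.linalg.svd(residual, full_matrices=False)
        Qs.append(U[:, :m])
    return Qs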
2. Decomposition
Now that we have computed all the \( Q_i \) matrices, we can write the data matrix \( A_k \) (shown here for \( k=3 \)) as:\( A_k = \)\( \begin{bmatrix} Q_1 & Q_2 & Q_3 \end{bmatrix} \) \( \begin{bmatrix} Q_1^TB_1 & \amber{Q_1^TB_2} & \rose{Q_1^TB_3} \\ \grey{0} & \amber{Q_2^TB_2} & \rose{Q_2^TB_3} \\ \grey{0} & \grey{0} & \rose{Q_3^TB_3} \end{bmatrix} \)\( = QR \)The right matrix is a \( k \times k \) block upper-triangular matrix \( R \) whose blocks are \( R_{ij} = Q_i^T B_j \in \mathbb{R}^{m_i \times b_j} \) for \( i \leq j \) and \( R_{ij} = 0 \) otherwise. Note that these arguments hold for any \( k \).
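A numerical check of this structure, assuming toy full-rank blocks and keeping all dimensions (\( m_i = b_i \)) so the decomposition is exact:

import numpy as np

rng = np.random.default_rng(3)
Bs = [rng.standard_normal((6, b)) for b in (2, 2, 1)]

# Keep all dimensions (m_i = b_i), so the decomposition is exact.
Qs = []
for B in Bs:
    residual = B - sum(Q @ (Q.T @ B) for Q in Qs)
    U, _, _ = np.linalg.svd(residual, full_matrices=False)
    Qs.append(U)

Q = np.concatenate(Qs, axis=1)
A = np.concatenate(Bs, axis=1)
R = Q.T @ A                       # R_ij = Q_i^T B_j, stacked into one matrix
print(np.allclose(Q @ R, A))      # exact reconstruction: True
print(np.allclose(R[2:, :2], 0))  # blocks below the diagonal vanish: True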
3. Algorithm
The following observation will help us implement the algorithm efficiently:
\( Q_i^T \left( B_j - \sum_{c \ne i } Q_c Q_c^T B_j \right) = Q_i^T B_j, \) which holds because \( Q_i^T Q_c = 0 \) for \( c \neq i \). In practice, this means we can deflate each \( B_j \) in place as we go and still read off the correct blocks \( R_{ij} = Q_i^T B_j \).
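A quick numerical verification of this identity (names illustrative):

import numpy as np

rng = np.random.default_rng(4)
Q, _ = np.linalg.qr(rng.standard_normal((8, 4)))
Q1, Q2 = Q[:, :2], Q[:, 2:]             # mutually orthogonal blocks
Bj = rng.standard_normal((8, 3))

# Deflating B_j by the *other* blocks does not change Q_i^T B_j,
# because Q_i^T Q_c = 0 for c != i.
lhs = Q1.T @ (Bj - Q2 @ (Q2.T @ Bj))
print(np.allclose(lhs, Q1.T @ Bj))      # True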
The implementation below is largely a modification of the official implementation of the algorithm:
import numpy as np

def flag_decomposition(data, Aset, flag_type):
    """Flag decomposition of `data` (n x p), given nested column index sets
    `Aset` and cumulative subspace dimensions `flag_type` = [m_1, m_1+m_2, ...]."""
    n, p = data.shape
    k = len(flag_type)
    # Per-level dimensions m_i from the cumulative flag signature.
    ms = [flag_type[0]] + [flag_type[i] - flag_type[i-1] for i in range(1, k)]
    R = np.zeros((sum(ms), p))
    Qs = []

    # B_1 = A_1; B_i = A_i \ A_{i-1} as column index sets.
    B_indices = [np.array(Aset[0])] + [np.setdiff1d(Aset[i], Aset[i-1])
                                       for i in range(1, len(Aset))]
    Bs = [data[:, idx] for idx in B_indices]

    for i in range(k):
        for j in range(i, k):
            if j == i:
                # Bs[i] has already been deflated by Q_1, ..., Q_{i-1};
                # its top m_i left singular vectors give Q_i.
                Ui, _, _ = np.linalg.svd(Bs[i], full_matrices=False)
                Ui = Ui[:, :ms[i]]
                Qs.append(Ui)
                R_block = Ui.T @ Bs[i]
            else:
                # By the observation above, Q_i^T (deflated B_j) = Q_i^T B_j.
                Ui = Qs[i]
                R_block = Ui.T @ Bs[j]
                # Deflate B_j so later levels only see the unexplained part.
                Bs[j] = Bs[j] - Ui @ R_block
            # Place the block R_ij at the columns of B_j in the original matrix.
            R[sum(ms[:i]):sum(ms[:i+1]), B_indices[j]] = R_block

    return np.concatenate(Qs, axis=1), R
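A toy usage example (the index sets and flag signature are illustrative; with \( m_i = b_i \) the reconstruction is exact):

rng = np.random.default_rng(5)
data = rng.standard_normal((6, 5))
Aset = [[0, 1], [0, 1, 2, 3], [0, 1, 2, 3, 4]]   # nested column index sets
flag_type = [2, 4, 5]                            # cumulative dimensions

Q, R = flag_decomposition(data, Aset, flag_type)
print(np.allclose(Q.T @ Q, np.eye(Q.shape[1])))  # orthonormal columns: True
print(np.allclose(Q @ R, data))                  # exact reconstruction: True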
