rustframe/compute/models/
pca.rs1use crate::compute::stats::correlation::covariance_matrix;
13use crate::compute::stats::descriptive::mean_vertical;
14use crate::matrix::{Axis, Matrix, SeriesOps};
15
16pub struct PCA {
18 pub components: Matrix<f64>, pub mean: Matrix<f64>, }
21
22impl PCA {
23 pub fn fit(x: &Matrix<f64>, n_components: usize, _iters: usize) -> Self {
24 let mean = mean_vertical(x); let broadcasted_mean = mean.broadcast_row_to_target_shape(x.rows(), x.cols());
26 let centered_data = x.zip(&broadcasted_mean, |x_i, mean_i| x_i - mean_i);
27 let covariance_matrix = covariance_matrix(¢ered_data, Axis::Col); let mut components = Matrix::zeros(n_components, x.cols());
30 for i in 0..n_components {
31 if i < covariance_matrix.rows() {
32 components.row_copy_from_slice(i, &covariance_matrix.row(i));
33 } else {
34 break;
35 }
36 }
37
38 PCA { components, mean }
39 }
40
41 pub fn transform(&self, x: &Matrix<f64>) -> Matrix<f64> {
43 let broadcasted_mean = self.mean.broadcast_row_to_target_shape(x.rows(), x.cols());
44 let centered_data = x.zip(&broadcasted_mean, |x_i, mean_i| x_i - mean_i);
45 centered_data.matrix_mul(&self.components.transpose())
46 }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::Matrix;

    /// Absolute tolerance used for every floating-point comparison below.
    const EPSILON: f64 = 1e-8;

    /// Fits a single component on three perfectly correlated 2-D samples and
    /// checks the mean, the component row and the projected values.
    #[test]
    fn test_pca_basic() {
        // Rows: [1, 1], [2, 2], [3, 3].
        let data = Matrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0], 3, 2);

        let pca = PCA::fit(&data, 1, 0);

        println!("Data shape: {:?}", data.shape());
        println!("PCA mean shape: {:?}", pca.mean.shape());
        println!("PCA components shape: {:?}", pca.components.shape());

        // Both features average to 2.0.
        assert!((pca.mean.get(0, 0) - 2.0).abs() < EPSILON);
        assert!((pca.mean.get(0, 1) - 2.0).abs() < EPSILON);

        // `fit` copies the first covariance row, which is expected to be
        // [1.0, 1.0] for this data.
        assert!((pca.components.get(0, 0) - 1.0).abs() < EPSILON);
        assert!((pca.components.get(0, 1) - 1.0).abs() < EPSILON);

        // Centred rows [-1, -1], [0, 0], [1, 1] projected onto [1, 1].
        let transformed_data = pca.transform(&data);
        assert_eq!(transformed_data.rows(), 3);
        assert_eq!(transformed_data.cols(), 1);
        assert!((transformed_data.get(0, 0) + 2.0).abs() < EPSILON);
        assert!((transformed_data.get(1, 0) - 0.0).abs() < EPSILON);
        assert!((transformed_data.get(2, 0) - 2.0).abs() < EPSILON);
    }

    /// Requests more components than there are features so that `fit` runs out
    /// of covariance rows; the surplus component rows must remain zero.
    #[test]
    fn test_pca_fit_break_branch() {
        let data = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        let (_, n_features) = data.shape();

        let n_components_large = n_features + 1;
        let pca = PCA::fit(&data, n_components_large, 0);

        assert_eq!(pca.components.rows(), n_components_large);
        assert_eq!(pca.components.cols(), n_features);

        // Rows past the last covariance row were never written by `fit`.
        for i in n_features..n_components_large {
            for j in 0..n_features {
                assert!((pca.components.get(i, j) - 0.0).abs() < EPSILON);
            }
        }
    }
}