CNum 0.2.1
CPU-optimized ML library for C++
Loading...
Searching...
No Matches
Matrix.h
Go to the documentation of this file.
1#ifndef MATRIX_H
2#define MATRIX_H
3
8
9#include <iostream>
10#include <vector>
11#include <atomic>
12#include <memory>
13#include <functional>
14#include <cmath>
15#include <future>
16#include <span>
17#include <cstring>
18#include <stdexcept>
19#include <string>
20
21namespace CNum::DataStructs {
/// @brief Dimension selector passed to Matrix::get(Dim, size_t) to pick a row or a column.
/// NOTE(review): the enumerators (ROW and COL, per the reference index at the end of this page)
/// sit on listing lines 27-28, which this rendering does not show.
26 enum Dim: uint8_t {
29 };
30
/// Forward declaration: a list of indices representing a subset or ordering of data (defined in IndexMask.h).
31 class IndexMask;
/// Forward declaration: a bit mask used for representing subsets of elements in a container (defined in BinaryMask.h).
32 class BinaryMask;
33
34 /**
35 * @class Matrix
36 * @brief 2d array abstraction
37 *
38 * Used for storing 2d, tabular data. Used in conjunction with CNum ML models
39 * and linear algebra operations
40 * @tparam T The type of the data stored
41 */
/// Declaration only: the member definitions are pulled in from "Matrix.tpp" at the end of this header.
/// The class holds a uniquely-owned T[] buffer plus its row and column counts.
42 template <typename T>
43 class Matrix {
44 private:
/// Owning pointer to the element storage.
45 ::std::unique_ptr<T[]> _data;
/// Column count.
46 size_t _cols;
/// Row count.
47 size_t _rows;
48
/// @brief Private helper taking a scalar and a void(T &, T) callback and returning a new matrix.
/// NOTE(review): presumably applies func to every element paired with val -- confirm in Matrix.tpp.
53 Matrix<T> element_wise(T val, ::std::function< void(T &, T) > func) const noexcept;
54
55 /// @brief Move Logic
56 void move(Matrix<T> &&other) noexcept;
57
58
/// @brief Copy logic counterpart of move().
/// NOTE(review): declared noexcept; if the implementation allocates a new buffer, an allocation
/// failure would call std::terminate instead of propagating -- confirm in Matrix.tpp.
59 void copy(const Matrix<T> &other) noexcept;
64
/// @brief Private static helper for parallel work.
/// Takes a thread count, a total element count and a callback receiving a size_t; the meaning of
/// the callback argument is not visible in this header.
65 // Data must be aligned to the cache-line size to avoid false sharing
66 static void par_execute(uint8_t num_threads,
67 size_t total_el,
68 ::std::function< void(size_t) > callback);
69
70 public:
/// @brief Default Overloaded Constructor.
/// All three parameters are defaulted (0 rows, 0 cols, null buffer), so this also serves as the default constructor.
75 Matrix(size_t rows = 0, size_t cols = 0, ::std::unique_ptr<T[]> ptr = nullptr);
76
/// @brief Copy Constructor.
/// NOTE(review): noexcept on a copy that presumably allocates -- see the note on copy().
78 Matrix(const Matrix &other) noexcept;
79
/// @brief Copy Assignment
81 Matrix<T> &operator=(const Matrix &other) noexcept;
82
/// @brief Move Constructor.
84 Matrix(Matrix &&other) noexcept;
85
86 /// @brief Move Assignment
87 Matrix<T> &operator=(Matrix &&other) noexcept;
88
91
/// @brief Dot Product.
/// Not noexcept, unlike the scalar overload below.
95 Matrix<T> operator*(const Matrix &other) const;
96
/// @brief Scale a matrix.
99 Matrix<T> operator*(T scale_factor) const noexcept;
100
/// @brief Vector dot product (1d).
104 T dot(const Matrix<T> &other) const;
105
/// @brief Add two matrices element wise.
109 Matrix<T> operator+(const Matrix &other) const;
/// @brief Add a value to every element in a matrix.
114 Matrix<T> operator+(T a) const noexcept;
115
/// @brief Subtract two matrices element wise.
117 /// @param other The matrix to subtract
118 /// @return The resultant matrix
119 Matrix<T> operator-(const Matrix &other) const;
120
/// @brief Subtract a value from every element in a matrix.
124 Matrix<T> operator-(T a) const noexcept;
125
/// @brief Take the absolute value of all elements in a matrix.
128 Matrix<T> abs() const;
129
133
137
/// @brief Get the sum of all elements in a matrix.
140 T sum() const;
141
/// @brief Get the mean of all values in a matrix.
144 T mean() const;
145
/// @brief Get the standard deviation of all elements in a matrix.
148 T std() const;
149
/// @brief Get value of a matrix.
154 T get(size_t row, size_t col) const;
155
/// @brief Get a copy of a Row/Col (prefer views for memory efficiency).
160 Matrix<T> get(Dim d, size_t idx) const;
161
166
/// @brief Get a row view.
/// Returns a span of non-const T from a const member function.
170 ::std::span<T> get_row_view(size_t idx) const;
171
/// @brief Get the value at index idx of a Matrix with shape=(n,1).
175 T operator[](size_t idx) const;
176
/// @brief Apply a binary mask.
180 Matrix<T> operator[](const BinaryMask &bin_mask) const;
/// @brief Apply index mask.
185 Matrix<T> operator[](const IndexMask &idx_mask) const noexcept;
186
/// @brief Apply IndexMask column wise.
188 /// @param idx_mask The mask containing the column indices to preserve
189 /// @return The masked matrix
190 Matrix<T> col_wise_mask_application(const IndexMask &idx_mask) const noexcept;
191
/// @brief Create a binary mask of values less than another.
195 BinaryMask operator<(T val) const;
196
/// @brief Create a binary mask of values less than or equal to another.
200 BinaryMask operator<=(T val) const;
201
/// @brief Create a binary mask of values greater than another.
204 /// @return The resultant binary mask
205 BinaryMask operator>(T val) const;
206
/// @brief Create a binary mask of values greater than or equal to another.
210 BinaryMask operator>=(T val) const;
/// @brief Create a binary mask of values equal to another.
215 BinaryMask operator==(T val) const;
216
/// @brief Create a binary mask of values not equal to another.
220 BinaryMask operator!=(T val) const;
221
/// @brief Argsort.
/// @param descending Defaults to false.
224 /// @return An index mask with the sorting order
225 IndexMask argsort(bool descending = false) const;
226
/// @brief Transpose a matrix.
229 Matrix<T> transpose() const noexcept;
230
233
/// NOTE(review): no brief is visible for this factory; presumably builds a rows x cols matrix
/// with every element set to val -- confirm in Matrix.tpp.
235 static Matrix<T> init_const(size_t rows, size_t cols, T val);
236
/// NOTE(review): no brief is visible for this factory; presumably builds a dim x dim identity
/// matrix -- confirm in Matrix.tpp.
240 static Matrix<T> identity(size_t dim);
241
/// Takes the vector of column matrices by non-const reference.
244 /// @return The merged matrix
245 static Matrix<T> join_cols(::std::vector< Matrix<T> > &cols);
246
/// NOTE(review): no brief is visible; presumably stacks the given matrices on top of each other,
/// with total_rows supplied by the caller -- confirm in Matrix.tpp.
250 static Matrix<T> combine_vertically(::std::vector< Matrix<T> > &matrices, size_t total_rows);
251
/// @return Row count (presumably _rows).
254 size_t get_rows() const;
255
/// @return Column count (presumably _cols).
257 size_t get_cols() const;
258
/// Const overload of begin(); returns a pointer to const T.
261 const T *begin() const;
262
/// Const overload of end(); returns a pointer to const T.
265 const T *end() const;
266
268 /// @return Raw pointer
269 T *begin();
/// Mutable counterpart of end() const; returns a raw pointer.
272 T *end();
273
/// NOTE(review): presumably the total element count (rows * cols) -- confirm in Matrix.tpp.
276 size_t size() const;
277
/// Returns an rvalue reference to a unique_ptr<T[]>.
/// NOTE(review): presumably refers to _data, so ownership only transfers if the caller actually
/// moves from the result -- confirm the intended contract in Matrix.tpp.
280 ::std::unique_ptr<T[]> &&move_ptr();
281
/// Prints the matrix; output format and destination are defined in Matrix.tpp.
283 void print_matrix() const;
284
285
286 };
287
288#include "Matrix.tpp"
289};
290
291#endif
A bit mask used for representing subsets of elements in a container.
Definition BinaryMask.h:22
A list of indices representing a subset or ordering of data.
Definition IndexMask.h:19
static Matrix< double > init_const(size_t rows, size_t cols, double val)
T sum() const
Get the sum of all elements in a matrix.
Definition Matrix.h:229
Matrix< T > operator-(T a) const noexcept
Subtract a value from every element in a matrix.
Definition Matrix.h:204
Matrix< T > squared() const
Square all elements in a matrix.
Definition Matrix.h:261
BinaryMask operator<=(T val) const
Create a binary mask of values less than or equal to another.
Definition Matrix.h:301
Matrix< T > operator[](const BinaryMask &bin_mask) const
Apply a binary mask.
Definition Matrix.h:404
const double * begin() const
Definition Matrix.h:494
Matrix< T > operator-(const Matrix &other) const
Subtract two matrices element wise.
Definition Matrix.h:188
Matrix(Matrix &&other) noexcept
Move Constructor.
Definition Matrix.h:55
Matrix< T > & operator=(Matrix &&other) noexcept
Move Assignment.
Definition Matrix.h:60
T get(size_t row, size_t col) const
Get value of a matrix.
Definition Matrix.h:342
Matrix< T > abs() const
Take the absolute value of all elements in a matrix.
Definition Matrix.h:224
::std::span< T > get_row_view(size_t idx) const
Get a row view.
Definition Matrix.h:395
IndexMask argsort(bool descending=false) const
Argsort.
Definition Matrix.h:414
const double * end() const
Definition Matrix.h:497
static Matrix< double > identity(size_t dim)
BinaryMask operator==(T val) const
Create a binary mask of values equal to another.
Definition Matrix.h:316
size_t get_rows() const
Definition Matrix.h:482
Matrix< T > operator*(const Matrix &other) const
Dot Product.
Definition Matrix.h:86
static Matrix< double > combine_vertically(::std::vector< Matrix< double > > &matrices, size_t total_rows)
T operator[](size_t idx) const
Get the value at index idx of a Matrix with shape=(n,1).
Definition Matrix.h:331
Matrix< T > operator[](const IndexMask &idx_mask) const noexcept
Apply index mask.
Definition Matrix.h:409
T std() const
Get the standard deviation of all elements in a matrix.
Definition Matrix.h:244
Matrix< T > transpose() const noexcept
Transpose a matrix.
Definition Matrix.h:426
Matrix< T > get(Dim d, size_t idx) const
Get a copy of a Row/Col (prefer views for memory efficiency).
Definition Matrix.h:352
T dot(const Matrix< T > &other) const
Vector dot product (1d).
Definition Matrix.h:117
Matrix(const Matrix &other) noexcept
Copy Constructor.
Definition Matrix.h:29
BinaryMask operator>(T val) const
Create a binary mask of values greater than another.
Definition Matrix.h:306
Matrix< T > col_wise_mask_application(const IndexMask &idx_mask) const noexcept
Apply IndexMask column wise.
Definition Matrix.h:326
BinaryMask operator!=(T val) const
Create a binary mask of values not equal to another.
Definition Matrix.h:321
Matrix< T > standardize() const
Standardize Matrix.
Definition Matrix.h:268
BinaryMask operator<(T val) const
Create a binary mask of values less than another.
Definition Matrix.h:296
size_t size() const
Definition Matrix.h:500
Matrix< T > operator+(T a) const noexcept
Add a value to every element in a matrix.
Definition Matrix.h:181
~Matrix()
Destructor.
Definition Matrix.h:66
Matrix(size_t rows=0, size_t cols=0, ::std::unique_ptr< T[]> ptr=nullptr)
Default Overloaded Constructor.
Definition Matrix.h:7
Matrix< T > operator*(T scale_factor) const noexcept
Scale a matrix.
Definition Matrix.h:110
Matrix< T > operator+(const Matrix &other) const
Add two matrices element wise.
Definition Matrix.h:165
size_t get_cols() const
Definition Matrix.h:485
void print_matrix() const
Definition Matrix.h:508
CNum::DataStructs::Views::StrideView< T > get_col_view(size_t idx) const
Get a column view.
Definition Matrix.h:386
::std::unique_ptr< double[]> && move_ptr()
Definition Matrix.h:503
BinaryMask operator>=(T val) const
Create a binary mask of values greater than or equal to another.
Definition Matrix.h:311
static Matrix< double > join_cols(::std::vector< Matrix< double > > &cols)
Matrix< T > & operator=(const Matrix &other) noexcept
Copy Assignment.
Definition Matrix.h:34
T mean() const
Get the mean of all values in a matrix.
Definition Matrix.h:235
Definition StrideView.h:32
The data structures used in CNum.
Definition ConcurrentQueue.h:8
Dim
Definition Matrix.h:26
@ COL
Definition Matrix.h:28
@ ROW
Definition Matrix.h:27