CNum 0.2.1
CPU-optimized ML library for C++
Loading...
Searching...
No Matches
XGTreeBooster.h
Go to the documentation of this file.
1#ifndef XG_TREE_BOOSTER_H
2#define XG_TREE_BOOSTER_H
3
4#include "CNum/Data/Data.h"
8
9namespace CNum::Model::Tree {
/// Tree booster that derives from TreeBooster and overrides its node-fitting,
/// fit-preparation and fit entry points.
/// NOTE(review): the leading number on each line below is the listing line
/// number carried over from the rendered documentation, not C++ source text.
14 class XGTreeBooster : public TreeBooster {
15 private:
/// Fits a node from a Matrix<double> feature matrix.
/// NOTE(review): the name suggests an exact greedy split search — confirm
/// against the implementation.
/// @param X     feature matrix of doubles (read-only)
/// @param g     presumably per-sample gradients — TODO confirm
/// @param h     presumably per-sample hessians — TODO confirm
/// @param node  the node to fit
/// @param depth depth value for this node; defaults to 0
17 virtual void fit_node_greedy(const CNum::DataStructs::Matrix<double> &X,
18 double *g,
19 double *h,
20 TreeBoosterNode *node,
21 int depth = 0) override;
22
/// Fits a node from a Matrix<int> feature matrix together with the shelves
/// (bins and the value ranges they represent).
/// NOTE(review): the name and parent_hist_view suggest a histogram-based
/// split search that uses the parent's histogram — confirm against the
/// implementation.
/// @param X                integer feature matrix (read-only); presumably
///                         bin indices — TODO confirm
/// @param shelves          shared array of Shelf objects
/// @param g                presumably per-sample gradients — TODO confirm
/// @param h                presumably per-sample hessians — TODO confirm
/// @param partition        data partition of the samples this node works with
/// @param parent_hist_view read-only arena view; presumably the parent node's
///                         histogram — TODO confirm
/// @param node             the node to fit
/// @param depth            depth value for this node; defaults to 0
35 virtual void fit_node_hist(const CNum::DataStructs::Matrix<int> &X,
36 std::shared_ptr<CNum::Data::Shelf[]> shelves,
37 double *g,
38 double *h,
39 DataPartition &partition,
40 const arena_view_t &parent_hist_view,
41 TreeBoosterNode *node,
42 int depth = 0) override;
43
/// Fit-preparation overload for a Matrix<int> feature matrix.
/// @param X         integer feature matrix (read-only)
/// @param shelves   shared array of Shelf objects
/// @param g         presumably per-sample gradients — TODO confirm
/// @param h         presumably per-sample hessians — TODO confirm
/// @param partition data partition passed by non-const reference
50 virtual void fit_prep(const CNum::DataStructs::Matrix<int> &X,
51 std::shared_ptr<CNum::Data::Shelf[]> shelves,
52 double *g,
53 double *h,
54 DataPartition &partition) override;
55
/// Fit-preparation overload for a Matrix<double> feature matrix; parameters
/// otherwise match the Matrix<int> overload.
62 virtual void fit_prep(const CNum::DataStructs::Matrix<double> &X,
63 std::shared_ptr<CNum::Data::Shelf[]> shelves,
64 double *g,
65 double *h,
66 DataPartition &partition) override;
67
68 public:
/// Overloaded constructor; every parameter has a default, so it also serves
/// as the default constructor.
/// @param a            arena pointer; defaults to nullptr
/// @param md           defaults to 5; presumably maximum tree depth — TODO confirm
/// @param ms           defaults to 3; presumably minimum samples — TODO confirm
/// @param weight_decay defaults to 0.0
/// @param reg_lambda   defaults to 1.0; presumably a regularisation
///                     coefficient — TODO confirm
/// @param gamma        defaults to 0.0; presumably a split-gain
///                     threshold — TODO confirm
75 XGTreeBooster(arena_t *a = nullptr,
76 int md = 5,
77 int ms = 3,
78 double weight_decay = 0.0,
79 double reg_lambda = 1.0,
80 double gamma = 0.0);
81
84
/// Unified fit function: takes a DataMatrix, which is a std::variant holding
/// either a Matrix<int> or a Matrix<double>.
/// @param X         feature matrix variant (non-const reference)
/// @param shelves   shared array of Shelf objects
/// @param g         presumably per-sample gradients — TODO confirm
/// @param h         presumably per-sample hessians — TODO confirm
/// @param partition data partition of the samples to work with
91 virtual void fit(DataMatrix &X,
92 std::shared_ptr<CNum::Data::Shelf[]> shelves,
93 double *g,
94 double *h,
95 DataPartition &partition) override;
96 };
97};
98
99#endif
struct arena arena_t
struct arena_view arena_view_t
2d array abstraction
Definition Matrix.h:43
A node in a TreeBooster, used for gathering and storing information about the decision-making process...
Definition TreeBoosterNode.h:25
TreeBooster(arena_t *a=nullptr, int md=5, int ms=3, double weight_decay=0.0, double reg_lambda=1.0, double gamma=0.0)
Overloaded default constructor.
virtual void fit(DataMatrix &X, std::shared_ptr< CNum::Data::Shelf[]> shelves, double *g, double *h, DataPartition &partition) override
Unified fit function.
XGTreeBooster(arena_t *a=nullptr, int md=5, int ms=3, double weight_decay=0.0, double reg_lambda=1.0, double gamma=0.0)
Overloaded constructor.
Tree-based models.
Definition GBModel.h:11
std::variant< CNum::DataStructs::Matrix< int >, CNum::DataStructs::Matrix< double > > DataMatrix
Definition TreeDefs.h:23
Contains bins and the ranges of values they represent.
Definition Data.h:31
A data partition for the set of samples a tree node has to work with during the tree building process...
Definition TreeDefs.h:39