CNum 0.2.1
CPU-optimized ML library for C++
CNum::Model::Tree::TreeBooster Class Reference [abstract]

A decision tree used in various gradient-boosting models as a weak learner.

#include <TreeBooster.h>

Inheritance diagram for CNum::Model::Tree::TreeBooster:
CNum::Model::Tree::XGTreeBooster

Public Member Functions

 TreeBooster (arena_t *a=nullptr, int md=5, int ms=3, double weight_decay=0.0, double reg_lambda=1.0, double gamma=0.0)
 Overloaded default constructor.
 TreeBooster (const TreeBooster &other) noexcept
 Copy constructor.
TreeBooster & operator= (const TreeBooster &other) noexcept
 Copy assignment operator.
 TreeBooster (TreeBooster &&other) noexcept
 Move constructor.
TreeBooster & operator= (TreeBooster &&other) noexcept
 Move assignment operator.
virtual ~TreeBooster ()
 Destructor.
void set_root (TreeBoosterNode *root)
 Set the root of the tree.
virtual void fit (DataMatrix &X, std::shared_ptr< CNum::Data::Shelf[]> shelves, double *g, double *h, DataPartition &partition)=0
CNum::DataStructs::Matrix< double > predict (CNum::DataStructs::Matrix< double > &data)
 Inference (making predictions) on tabular data.
arena_view_t init_hist_view (size_t n_data_cols)
 Allocate space for histograms on the arena.
std::string to_json ()
 Serialize the tree to a JSON-encoded string.

Static Public Member Functions

static size_t partition_data (const CNum::DataStructs::Matrix< int > &X, double *g, double *h, size_t feat, uint8_t bin, const DataPartition &partition)
 Partition the idx array, g, and h based on a split so that each node's slice of the dataset is contiguous.
static void histogram_subtraction (const arena_view_t &parent_hist_view, arena_view_t &small_hist_view, arena_view_t &large_hist_view)
 Subtract the "small" histogram from the parent histogram for histogram caching.

Protected Attributes

TreeBoosterNode * _root
int _max_depth
int _min_samples
double _reg_lambda
double _gamma
double _weight_decay
arena_t * _arena

Detailed Description

A decision tree used in various gradient-boosting models as a weak learner.

The TreeBooster class is a robust and efficient weak learner that is good at recognizing subtle relationships between features in tabular data. Although such trees are weak on their own, they become very powerful when combined in gradient-boosting algorithms.

Constructor & Destructor Documentation

◆ TreeBooster() [1/3]

CNum::Model::Tree::TreeBooster::TreeBooster ( arena_t * a = nullptr,
int md = 5,
int ms = 3,
double weight_decay = 0.0,
double reg_lambda = 1.0,
double gamma = 0.0 )

Overloaded default constructor.

Parameters
a: The arena to use for allocation in this tree (the arena of the parent thread in the tree-building process)

◆ TreeBooster() [2/3]

CNum::Model::Tree::TreeBooster::TreeBooster ( const TreeBooster & other)
noexcept

Copy constructor.

◆ TreeBooster() [3/3]

CNum::Model::Tree::TreeBooster::TreeBooster ( TreeBooster && other)
noexcept

Move constructor.

◆ ~TreeBooster()

virtual CNum::Model::Tree::TreeBooster::~TreeBooster ( )
virtual

Destructor.

Member Function Documentation

◆ fit()

virtual void CNum::Model::Tree::TreeBooster::fit ( DataMatrix & X,
std::shared_ptr< CNum::Data::Shelf[]> shelves,
double * g,
double * h,
DataPartition & partition )
pure virtual

◆ histogram_subtraction()

void CNum::Model::Tree::TreeBooster::histogram_subtraction ( const arena_view_t & parent_hist_view,
arena_view_t & small_hist_view,
arena_view_t & large_hist_view )
static

Subtract the "small" histogram from the parent histogram for histogram caching.

CNum's tree-boosting models use histogram caching to reduce the number of histograms built: only the histogram for the smaller of the two child partitions is built, and it is subtracted from the parent's histogram to yield the larger partition's histogram.

Parameters
parent_hist_view: The arena_view_t with the parent histogram
small_hist_view: The arena_view_t with the small partition's histogram
large_hist_view: The arena_view_t that will be filled with the difference of the parent and small histograms
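The subtraction described above can be sketched as follows. This is a minimal illustration under stated assumptions, not CNum's implementation: `Bin`, `Histogram`, and `subtract_histograms` are hypothetical names, and plain `std::vector` storage stands in for `arena_view_t`, whose layout this page does not describe.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one histogram bin: the summed gradients and
// hessians of the rows that fall into the bin.
struct Bin {
  double g_sum;
  double h_sum;
};

// Hypothetical flat histogram (for example n_features * n_bins entries).
using Histogram = std::vector<Bin>;

// Fill large_hist with parent_hist - small_hist, bin by bin. Every row of the
// parent belongs to exactly one child, so the difference is the histogram of
// the larger child and no second pass over its rows is needed.
void subtract_histograms(const Histogram &parent_hist,
                         const Histogram &small_hist,
                         Histogram &large_hist) {
  assert(parent_hist.size() == small_hist.size());
  large_hist.resize(parent_hist.size());
  for (std::size_t i = 0; i < parent_hist.size(); ++i) {
    large_hist[i].g_sum = parent_hist[i].g_sum - small_hist[i].g_sum;
    large_hist[i].h_sum = parent_hist[i].h_sum - small_hist[i].h_sum;
  }
}
```

The cost of the subtraction depends only on the number of bins, not on the number of rows in the larger partition, which is where the saving comes from.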

◆ init_hist_view()

arena_view_t CNum::Model::Tree::TreeBooster::init_hist_view ( size_t n_data_cols)

Allocate space for histograms on the arena.

Parameters
n_data_cols: The number of features in the dataset
Returns
An arena_view_t with the histograms

◆ operator=() [1/2]

TreeBooster & CNum::Model::Tree::TreeBooster::operator= ( const TreeBooster & other)
noexcept

Copy assignment operator.

◆ operator=() [2/2]

TreeBooster & CNum::Model::Tree::TreeBooster::operator= ( TreeBooster && other)
noexcept

Move assignment operator.

◆ partition_data()

size_t CNum::Model::Tree::TreeBooster::partition_data ( const CNum::DataStructs::Matrix< int > & X,
double * g,
double * h,
size_t feat,
uint8_t bin,
const DataPartition & partition )
static

Partition the idx array, g, and h based on a split so that each node's slice of the dataset is contiguous.

Parameters
X: The dataset (row-wise features)
g: The gradient array
h: The hessian array
feat: The feature associated with the split
bin: The bin associated with the split
partition: The current node's data partition
Returns
The index of the boundary between the left and right partitions
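A minimal sketch of this kind of in-place partition is shown below. It rests on several assumptions that this page does not confirm: `partition_rows` is a hypothetical name, `begin` and `end` stand in for the range that `DataPartition` presumably carries, `g` and `h` are assumed to be stored position-aligned with the idx array, and rows whose bin is less than or equal to the split bin are assumed to go left.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Reorder idx, g, and h within [begin, end) so that rows going left come
// first. feature_bins[row] is the binned value of the split feature for that
// row. Returns the position of the first right-hand entry.
std::size_t partition_rows(const std::vector<std::uint8_t> &feature_bins,
                           std::vector<std::size_t> &idx,
                           std::vector<double> &g,
                           std::vector<double> &h,
                           std::size_t begin, std::size_t end,
                           std::uint8_t split_bin) {
  assert(g.size() == idx.size() && h.size() == idx.size());
  assert(begin <= end && end <= idx.size());
  std::size_t boundary = begin;
  for (std::size_t i = begin; i < end; ++i) {
    if (feature_bins[idx[i]] <= split_bin) {  // assumed "go left" rule
      // Swap all three arrays together so gradients and hessians stay
      // attached to their row index.
      std::swap(idx[i], idx[boundary]);
      std::swap(g[i], g[boundary]);
      std::swap(h[i], h[boundary]);
      ++boundary;
    }
  }
  return boundary;
}
```

After the call, positions `[begin, boundary)` hold the left child's rows and `[boundary, end)` the right child's, so each child can be described by a pair of offsets into the same arrays.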

◆ predict()

CNum::DataStructs::Matrix< double > CNum::Model::Tree::TreeBooster::predict ( CNum::DataStructs::Matrix< double > & data)

Inference (making predictions) on tabular data.

Parameters
data: The data to make predictions on
Returns
The predictions

◆ set_root()

void CNum::Model::Tree::TreeBooster::set_root ( TreeBoosterNode * root)

Set the root of the tree.

◆ to_json()

std::string CNum::Model::Tree::TreeBooster::to_json ( )

Serialize the tree to a JSON-encoded string.

Returns
The JSON string

Member Data Documentation

◆ _arena

arena_t* CNum::Model::Tree::TreeBooster::_arena
protected

◆ _gamma

double CNum::Model::Tree::TreeBooster::_gamma
protected

◆ _max_depth

int CNum::Model::Tree::TreeBooster::_max_depth
protected

◆ _min_samples

int CNum::Model::Tree::TreeBooster::_min_samples
protected

◆ _reg_lambda

double CNum::Model::Tree::TreeBooster::_reg_lambda
protected

◆ _root

TreeBoosterNode* CNum::Model::Tree::TreeBooster::_root
protected

◆ _weight_decay

double CNum::Model::Tree::TreeBooster::_weight_decay
protected

The documentation for this class was generated from the following file: