BaseTrain
Overview¶
The train pipeline performs common model training activities, such as running a grid search, versioning models, computing feature importance, and checking for model bias.
Attributes¶
BaseTrain contains no default attributes.
Configuration¶
BaseTrain contains no default or required configuration.
Interface¶
The following methods are part of BaseTrain and should be implemented in any class that inherits from this base class:
split_data¶
Splits the input data into a dictionary containing training/test/validation data.
def split_data(self, data, *args, **kwargs) -> Any
Arguments:
data (object): The data to split.
Returns:
data_dict (dict): The split data.
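As an illustration of the contract described above, here is a minimal sketch of a split_data implementation. It is not the library's own code: the class name SimpleTrain, the keyword options test_frac, valid_frac, and seed, and the dictionary keys "train", "valid", and "test" are all assumptions made for this example, and the class is left standalone where a real implementation would inherit from BaseTrain.

```python
import random
from typing import Any


class SimpleTrain:  # hypothetical; a real class would inherit from BaseTrain
    def split_data(self, data, *args, **kwargs) -> Any:
        """Shuffle the rows and split them into train/valid/test partitions."""
        # Hypothetical options, not part of the documented interface.
        test_frac = kwargs.get("test_frac", 0.2)
        valid_frac = kwargs.get("valid_frac", 0.1)
        seed = kwargs.get("seed", 0)
        if not 0 <= test_frac + valid_frac < 1:
            raise ValueError("test_frac + valid_frac must be in [0, 1)")

        rows = list(data)                  # copy so the caller's data is untouched
        random.Random(seed).shuffle(rows)  # seeded so the split is reproducible
        n_test = int(len(rows) * test_frac)
        n_valid = int(len(rows) * valid_frac)

        # The key names are an assumption; the interface only asks for a
        # dictionary of training/test/validation data.
        return {
            "test": rows[:n_test],
            "valid": rows[n_test:n_test + n_valid],
            "train": rows[n_test + n_valid:],
        }
```

Calling SimpleTrain().split_data(rows) returns a dictionary that the later methods can accept as their data argument. Every row lands in exactly one partition.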
train_model¶
Trains a model.
def train_model(self, data, *args, **kwargs) -> tuple[Any, Any]
Arguments:
data (object): The dictionary of train/test/validation data.
Returns:
model (object): The trained model.
model_version (object): The model version object created by versioning the model.
check_model¶
Runs model checks on a model.
def check_model(self, data, model, model_version, *args, **kwargs)
Arguments:
data (object): The dictionary of train/test/validation data.
model (object): The trained model.
model_version (object): The model version object created by versioning the model.
Returns:
Nothing.
analyze_model¶
Runs various model analyses, such as feature importance, baseline comparison, confusion matrix, and error plots. The analyses vary by problem type.
def analyze_model(self, data, model, model_version, *args, **kwargs)
Arguments:
data (object): The dictionary of train/test/validation data.
model (object): The trained model.
model_version (object): The model version object created by versioning the model.
Returns:
Nothing.
compare_models¶
Runs a model drift analysis and generates a drift report.
def compare_models(self, data, model, model_version, *args, **kwargs) -> bool
Arguments:
data (object): The dictionary of train/test/validation data.
model (object): The trained model.
model_version (object): The model version object created by versioning the model.
Returns:
is_new_model_better
(bool): Whether or not the current model version performs better than the deployed version.
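The boolean return value can be illustrated with a small sketch that compares a new model against a deployed one on held-out data. Everything beyond the method signature is an assumption for this example: the class name SimpleTrain, the "X_test" and "y_test" keys, the deployed_model keyword, the use of mean absolute error, and the treatment of models as plain callables. A real implementation would typically also produce the drift report, which is omitted here.

```python
from typing import Sequence


def mean_absolute_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Average absolute difference between targets and predictions."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


class SimpleTrain:  # hypothetical; a real class would inherit from BaseTrain
    def compare_models(self, data, model, model_version, *args, **kwargs) -> bool:
        # Hypothetical: the deployed model is supplied by the caller.
        deployed_model = kwargs.get("deployed_model")
        if deployed_model is None:
            return True  # nothing deployed yet, so the new model wins by default

        # Assumed layout of the data dictionary.
        features, targets = data["X_test"], data["y_test"]
        new_error = mean_absolute_error(targets, [model(x) for x in features])
        old_error = mean_absolute_error(targets, [deployed_model(x) for x in features])

        # Lower error is better; a tie keeps the deployed model.
        return new_error < old_error
```

Returning False on a tie is a deliberate choice in this sketch: it avoids replacing a deployed model unless the new one is strictly better.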
check_model_bias¶
Runs model bias checks and logs bias metrics.
def check_model_bias(self, data, model, model_version, *args, **kwargs)
Arguments:
data (object): The dictionary of train/test/validation data.
model (object): The trained model.
model_version (object): The model version object created by versioning the model.
Returns:
Nothing.
retrain_model_on_all_data¶
Retrains the model on all available data.
def retrain_model_on_all_data(self, data, model_version, *args, **kwargs) -> tuple[Any, Any]
Arguments:
data (object): The dictionary of train/test/validation data.
model_version (object): The model version object created by versioning the model.
Returns:
model (object): The newly trained model.
experiment (object): The experiment the model was trained in.
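To show how the seven methods fit together, the sketch below wires them into one training run. It rests on several assumptions: the class RecordingTrain and the function run_training are hypothetical, the call order is one plausible sequence and is not confirmed by this page, and the toy "model" is just a function that predicts the mean. Each method records its own name so the flow is easy to inspect.

```python
from typing import Any


class RecordingTrain:  # hypothetical; a real class would inherit from BaseTrain
    """Toy class exposing the seven documented methods; each logs its name."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def split_data(self, data, *args, **kwargs) -> Any:
        self.calls.append("split_data")
        half = len(data) // 2
        return {"train": data[:half], "test": data[half:]}

    def train_model(self, data, *args, **kwargs) -> tuple[Any, Any]:
        self.calls.append("train_model")
        mean = sum(data["train"]) / len(data["train"])
        return (lambda _x: mean), {"version": 1}  # (model, model_version)

    def check_model(self, data, model, model_version, *args, **kwargs) -> None:
        self.calls.append("check_model")

    def analyze_model(self, data, model, model_version, *args, **kwargs) -> None:
        self.calls.append("analyze_model")

    def compare_models(self, data, model, model_version, *args, **kwargs) -> bool:
        self.calls.append("compare_models")
        return True  # pretend the new model beats the deployed one

    def check_model_bias(self, data, model, model_version, *args, **kwargs) -> None:
        self.calls.append("check_model_bias")

    def retrain_model_on_all_data(self, data, model_version, *args, **kwargs) -> tuple[Any, Any]:
        self.calls.append("retrain_model_on_all_data")
        rows = data["train"] + data["test"]
        mean = sum(rows) / len(rows)
        return (lambda _x: mean), {"experiment": "toy"}  # (model, experiment)


def run_training(pipeline, raw_data):
    """Assumed orchestration order; adjust to match the real workflow."""
    data = pipeline.split_data(raw_data)
    model, model_version = pipeline.train_model(data)
    pipeline.check_model(data, model, model_version)
    pipeline.analyze_model(data, model, model_version)
    pipeline.check_model_bias(data, model, model_version)
    # Retrain on everything only when the new model is worth promoting.
    if pipeline.compare_models(data, model, model_version):
        model, _experiment = pipeline.retrain_model_on_all_data(data, model_version)
    return model
```

In this sketch the retraining step is gated on compare_models, so a model that does not beat the deployed version is returned as trained on the training split only.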