{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Solving classification problems with CatBoost"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[Open in Colab](https://colab.research.google.com/github/catboost/tutorials/blob/master/events/pydata_nyc_oct_19_2018.ipynb)\n",
"\n",
"In this tutorial we will use the Amazon Employee Access Challenge dataset from a [Kaggle](https://www.kaggle.com) competition. The data can be downloaded [here](https://www.kaggle.com/c/amazon-employee-access-challenge/data)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Installing the libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#!pip install --user --upgrade catboost\n",
"#!pip install --user --upgrade ipywidgets\n",
"#!pip install shap\n",
"#!pip install scikit-learn\n",
"#!pip install --upgrade numpy\n",
"#!pip install --upgrade pandas\n",
"#!jupyter nbextension enable --py widgetsnbextension"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.10.2\n",
"Python 3.5.5 :: Anaconda custom (64-bit)\r\n"
]
}
],
"source": [
"import catboost\n",
"print(catboost.__version__)\n",
"!python --version"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reading the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
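"The simplest way is to read everything into pandas DataFrames: `catboost.datasets.amazon()` downloads the dataset and returns its train and test parts as two DataFrames."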
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import os\n",
"import numpy as np\n",
"np.set_printoptions(precision=4)\n",
"import catboost\n",
"from catboost import *\n",
"from catboost import datasets"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
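"# Uncomment these lines only if downloading the dataset fails with an SSL certificate error:\n",
"# they disable HTTPS certificate verification for this Python session.\n",
"# import ssl\n",
"# ssl._create_default_https_context = ssl._create_unverified_context"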
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"(train_df, test_df) = catboost.datasets.amazon()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" ACTION RESOURCE MGR_ID ROLE_ROLLUP_1 ROLE_ROLLUP_2 ROLE_DEPTNAME \\\n",
"0 1 39353 85475 117961 118300 123472 \n",
"1 1 17183 1540 117961 118343 123125 \n",
"2 1 36724 14457 118219 118220 117884 \n",
"3 1 36135 5396 117961 118343 119993 \n",
"4 1 42680 5905 117929 117930 119569 \n",
"\n",
" ROLE_TITLE ROLE_FAMILY_DESC ROLE_FAMILY ROLE_CODE \n",
"0 117905 117906 290919 117908 \n",
"1 118536 118536 308574 118539 \n",
"2 117879 267952 19721 117880 \n",
"3 118321 240983 290919 118322 \n",
"4 119323 123932 19793 119325 "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preparing your data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Extract the label values and the feature matrix"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"y = train_df.ACTION\n",
"X = train_df.drop('ACTION', axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Declare the categorical features: every column of `X` is a categorical ID, so all column indices are listed"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 1, 2, 3, 4, 5, 6, 7, 8]\n"
]
}
],
"source": [
"cat_features = list(range(0, X.shape[1]))\n",
"print(cat_features)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check the label balance in the dataset"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Labels: {0, 1}\n",
"Zero count = 1897, One count = 30872\n"
]
}
],
"source": [
"print('Labels: ', set(y))\n",
"print('Zero count = ' + str(len(y) - sum(y)) + ', One count = ' + str(sum(y)))"
]
},
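{
"cell_type": "markdown",
"metadata": {},
"source": [
"The labels are heavily imbalanced: only about 6% of the access requests are denied (`ACTION` = 0). As a quick sanity check, the small sketch below reuses the counts printed by the label-balance cell to show that a trivial model which always predicts `1` is already about 94% accurate; keep this in mind whenever `Accuracy` is used as a metric."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Counts printed by the label-balance cell above (0 = denied, 1 = approved)\n",
"zero_count, one_count = 1897, 30872\n",
"total = zero_count + one_count\n",
"\n",
"# Accuracy of a constant model that always predicts the majority class 1\n",
"majority_accuracy = one_count / total\n",
"print('Objects: {}, majority-class accuracy: {:.4f}'.format(total, majority_accuracy))"
]
},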
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To train a model in CatBoost, we first need to wrap the data in a `Pool` object.\n",
"This class stores the data in CatBoost's internal format.\n",
"\n",
"There are several ways to create a pool.\n",
"The simplest one is to create it from a pandas DataFrame or a NumPy array."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"pool1 = Pool(data=X, \n",
" label=y,\n",
" cat_features=cat_features) # Indices of categorical columns in X"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This approach is not the most efficient, especially for big datasets: all the data has to be copied from pandas into CatBoost's internal format.\n",
"\n",
"That is why CatBoost can also create a `Pool` directly from a file."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at how to load a `Pool` from a file."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's save the data frames to disk."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"dataset_dir = './amazon'\n",
"if not os.path.exists(dataset_dir):\n",
" os.makedirs(dataset_dir)\n",
"\n",
"# Save two versions of the data (TSV without a header, CSV with a header) to show that CatBoost can read both.\n",
"train_df.to_csv(os.path.join(dataset_dir, 'train.tsv'), index=False, sep='\\t', header=False)\n",
"test_df.to_csv(os.path.join(dataset_dir, 'test.tsv'), index=False, sep='\\t', header=False)\n",
"\n",
"train_df.to_csv(os.path.join(dataset_dir, 'train.csv'), index=False, sep=',', header=True)\n",
"test_df.to_csv(os.path.join(dataset_dir, 'test.csv'), index=False, sep=',', header=True)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\t39353\t85475\t117961\t118300\t123472\t117905\t117906\t290919\t117908\r\n",
"1\t17183\t1540\t117961\t118343\t123125\t118536\t118536\t308574\t118539\r\n"
]
}
],
"source": [
"!head -n2 amazon/train.tsv"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ACTION,RESOURCE,MGR_ID,ROLE_ROLLUP_1,ROLE_ROLLUP_2,ROLE_DEPTNAME,ROLE_TITLE,ROLE_FAMILY_DESC,ROLE_FAMILY,ROLE_CODE\r\n",
"1,39353,85475,117961,118300,123472,117905,117906,290919,117908\r\n",
"1,17183,1540,117961,118343,123125,118536,118536,308574,118539\r\n"
]
}
],
"source": [
"!head -n3 amazon/train.csv"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we have the dataset in 2 different formats:\n",
"\n",
"1) tab-separated without a header\n",
"\n",
"2) comma-separated with a header\n",
"\n",
"\n",
"Like pandas, CatBoost can load data in different formats; we just need to pass the proper options.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before loading the data, we need to specify the type of each column (label, categorical feature, numerical feature, ...).\n",
"\n",
"CatBoost uses a special file for this, the column description (`.cd`) file, and provides the helper function `create_cd` to generate it."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from catboost.utils import create_cd\n",
"\n",
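"# Map each feature index (position among the non-label columns) to its name\n",
"feature_names = dict()\n",
"for column, name in enumerate(train_df):\n",
"    if column == 0:\n",
"        continue  # skip the ACTION label column\n",
"    feature_names[column - 1] = name\n",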
" \n",
"create_cd(\n",
" label=0, \n",
" cat_features=list(range(1, train_df.columns.shape[0])),\n",
" feature_names=feature_names,\n",
" output_path=os.path.join(dataset_dir, 'train.cd')\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\tLabel\t\r\n",
"1\tCateg\tRESOURCE\r\n",
"2\tCateg\tMGR_ID\r\n",
"3\tCateg\tROLE_ROLLUP_1\r\n",
"4\tCateg\tROLE_ROLLUP_2\r\n",
"5\tCateg\tROLE_DEPTNAME\r\n",
"6\tCateg\tROLE_TITLE\r\n",
"7\tCateg\tROLE_FAMILY_DESC\r\n",
"8\tCateg\tROLE_FAMILY\r\n",
"9\tCateg\tROLE_CODE\r\n"
]
}
],
"source": [
"!cat amazon/train.cd"
]
},
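{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `.cd` file is plain text with one tab-separated line per column: the column index, its type (`Label`, `Categ`, `Num`, ...) and an optional name. To make the format concrete, here is a hypothetical stdlib helper (`write_column_description` is not part of CatBoost, and this is not how `create_cd` is implemented) that writes an equivalent file by hand."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"\n",
"def write_column_description(path, column_names, label_column, cat_columns):\n",
"    # Hypothetical helper: one '<index>\\t<type>\\t<name>' line per column,\n",
"    # mirroring train.cd above (the label line has an empty name)\n",
"    lines = []\n",
"    for index, name in enumerate(column_names):\n",
"        if index == label_column:\n",
"            lines.append('{}\\tLabel\\t'.format(index))\n",
"        elif index in cat_columns:\n",
"            lines.append('{}\\tCateg\\t{}'.format(index, name))\n",
"        else:\n",
"            lines.append('{}\\tNum\\t{}'.format(index, name))\n",
"    with open(path, 'w') as output_file:\n",
"        output_file.write('\\n'.join(lines) + '\\n')\n",
"    return lines\n",
"\n",
"\n",
"column_names = ['ACTION', 'RESOURCE', 'MGR_ID', 'ROLE_ROLLUP_1', 'ROLE_ROLLUP_2',\n",
"                'ROLE_DEPTNAME', 'ROLE_TITLE', 'ROLE_FAMILY_DESC', 'ROLE_FAMILY', 'ROLE_CODE']\n",
"os.makedirs('amazon', exist_ok=True)\n",
"manual_lines = write_column_description(\n",
"    os.path.join('amazon', 'train_manual.cd'),\n",
"    column_names,\n",
"    label_column=0,\n",
"    cat_columns=set(range(1, len(column_names)))\n",
")\n",
"print('\\n'.join(manual_lines))"
]
},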
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"create_cd(\n",
" label=0, \n",
" cat_features=list(range(1, train_df.columns.shape[0])),\n",
" # feature_names=feature_names,\n",
" output_path=os.path.join(dataset_dir, 'train_without_names.cd')\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\tLabel\t\r\n",
"1\tCateg\t\r\n",
"2\tCateg\t\r\n",
"3\tCateg\t\r\n",
"4\tCateg\t\r\n",
"5\tCateg\t\r\n",
"6\tCateg\t\r\n",
"7\tCateg\t\r\n",
"8\tCateg\t\r\n",
"9\tCateg\t\r\n"
]
}
],
"source": [
"!cat amazon/train_without_names.cd"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"create_cd(\n",
" label=0, \n",
" cat_features=list(range(2, train_df.columns.shape[0])),\n",
" feature_names=feature_names,\n",
" output_path=os.path.join(dataset_dir, 'train_with_num.cd')\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\tLabel\t\r\n",
"1\tNum\tRESOURCE\r\n",
"2\tCateg\tMGR_ID\r\n",
"3\tCateg\tROLE_ROLLUP_1\r\n",
"4\tCateg\tROLE_ROLLUP_2\r\n",
"5\tCateg\tROLE_DEPTNAME\r\n",
"6\tCateg\tROLE_TITLE\r\n",
"7\tCateg\tROLE_FAMILY_DESC\r\n",
"8\tCateg\tROLE_FAMILY\r\n",
"9\tCateg\tROLE_CODE\r\n"
]
}
],
"source": [
"!cat amazon/train_with_num.cd"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"? create_cd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's load the pool from the file:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"pool2 = Pool(\n",
" data=os.path.join(dataset_dir, 'train.tsv'), \n",
" #delimiter=',', \n",
" column_description=os.path.join(dataset_dir, 'train.cd'),\n",
" # has_header=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Loading a `Pool` directly from a file is the fastest way to build it if the data is not already in RAM."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Exercise: load the same pool from the CSV file (hint: see the commented-out `delimiter` and `has_header` options in the `Pool` cell above)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you need maximum performance and your data is already in RAM, you can sometimes do better than simply passing a DataFrame to the `Pool` constructor.\n",
"\n",
"The `FeaturesData` class is a fast way to pass data from NumPy matrices to CatBoost."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"? FeaturesData"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# The fastest way to create a Pool from in-memory data is to build it from a NumPy matrix via FeaturesData.\n",
"# Use it when you need fast predictions or the fastest possible data loading in Python.\n",
"\n",
"# FeaturesData requires categorical feature values to be strings (an object array of str)\n",
"X_prepared = X.values.astype(str).astype(object)\n",
"\n",
"pool3 = Pool(\n",
" data=FeaturesData(cat_feature_data=X_prepared, cat_feature_names=list(X)),\n",
" label=y.values\n",
")\n"
]
},
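{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why convert twice? A small sketch on the first values of `X` (`RESOURCE` and `MGR_ID` of the first two rows): `astype(str)` alone yields a fixed-width NumPy unicode array, and the extra `astype(object)` turns it into an array of ordinary Python `str` objects, which is the form the cell above prepares for `FeaturesData`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# RESOURCE and MGR_ID of the first two rows of X\n",
"ids = np.array([[39353, 85475], [17183, 1540]])\n",
"\n",
"as_unicode = ids.astype(str)                # fixed-width unicode dtype, e.g. <U21\n",
"as_str_objects = as_unicode.astype(object)  # object dtype holding Python str values\n",
"\n",
"print(as_unicode.dtype, as_str_objects.dtype, type(as_str_objects[0, 0]))"
]
},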
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset shape\n",
"dataset 1:(32769, 9)\n",
"dataset 2:(32769, 9)\n",
"dataset 3:(32769, 9)\n",
"\n",
"\n",
"Column names\n",
"dataset 1:\n",
"['RESOURCE', 'MGR_ID', 'ROLE_ROLLUP_1', 'ROLE_ROLLUP_2', 'ROLE_DEPTNAME', 'ROLE_TITLE', 'ROLE_FAMILY_DESC', 'ROLE_FAMILY', 'ROLE_CODE']\n",
"\n",
"dataset 2:\n",
"['RESOURCE', 'MGR_ID', 'ROLE_ROLLUP_1', 'ROLE_ROLLUP_2', 'ROLE_DEPTNAME', 'ROLE_TITLE', 'ROLE_FAMILY_DESC', 'ROLE_FAMILY', 'ROLE_CODE']\n",
"\n",
"dataset 3:\n",
"['RESOURCE', 'MGR_ID', 'ROLE_ROLLUP_1', 'ROLE_ROLLUP_2', 'ROLE_DEPTNAME', 'ROLE_TITLE', 'ROLE_FAMILY_DESC', 'ROLE_FAMILY', 'ROLE_CODE']\n"
]
}
],
"source": [
"print('Dataset shape')\n",
"print('dataset 1:' + str(pool1.shape) + '\\ndataset 2:' + str(pool2.shape) + \n",
" '\\ndataset 3:' + str(pool3.shape))\n",
"\n",
"print('\\n')\n",
"print('Column names')\n",
"print('dataset 1:')\n",
"print(pool1.get_feature_names()) \n",
"print('\\ndataset 2:')\n",
"print(pool2.get_feature_names())\n",
"print('\\ndataset 3:')\n",
"print(pool3.get_feature_names())\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Split your data into train and validation"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/noxoomo/anaconda3/lib/python3.5/site-packages/sklearn/model_selection/_split.py:2069: FutureWarning: From version 0.21, test_size will always complement train_size unless both are specified.\n",
" FutureWarning)\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"X_train, X_validation, y_train, y_validation = train_test_split(X, y, train_size=0.8, random_state=1234)"
]
},
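{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `train_size=0.8`, the validation part receives the remaining 20% of the 32769 rows. A quick arithmetic check, assuming the train share is rounded down, reproduces the shapes printed by the next cells."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"n_objects = 32769  # rows in train_df\n",
"n_train = math.floor(0.8 * n_objects)  # 26215.2 rounded down\n",
"n_validation = n_objects - n_train\n",
"print(n_train, n_validation)"
]
},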
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" RESOURCE MGR_ID ROLE_ROLLUP_1 ROLE_ROLLUP_2 ROLE_DEPTNAME \\\n",
"14926 74463 105908 117961 118225 129617 \n",
"7940 17278 120340 120342 120343 119076 \n",
"24768 79325 17733 117961 118300 119984 \n",
"1633 5173 3075 117961 117962 120677 \n",
"2743 15672 3745 117961 118300 118360 \n",
"\n",
" ROLE_TITLE ROLE_FAMILY_DESC ROLE_FAMILY ROLE_CODE \n",
"14926 118702 132654 118704 118705 \n",
"7940 118834 311236 118424 118836 \n",
"24768 118890 125128 118398 118892 \n",
"1633 120357 120678 118424 120359 \n",
"2743 124435 118362 118363 124436 "
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.head()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(26215, 9)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.shape"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(6554, 9)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_validation.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Selecting the objective function"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Possible options for binary classification:\n",
"\n",
"- `Logloss` for 0/1 labels\n",
"- `CrossEntropy` for targets given as probabilities in [0, 1]"
]
},
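{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both objectives compute the same quantity, $-\\frac{1}{N}\\sum_{i=1}^{N}\\left[t_i \\log p_i + (1 - t_i)\\log(1 - p_i)\\right]$, where $p_i$ is the predicted probability of class 1 for object $i$. With `Logloss` the targets $t_i$ are 0/1 labels, while `CrossEntropy` accepts any $t_i \\in [0, 1]$. The NumPy sketch below is an illustration, not CatBoost's implementation. It also gives a reference point: predicting the overall share of positive labels for every object scores a Logloss of about 0.22 on this data, which puts the validation values of roughly 0.15 in the training logs further down into perspective."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def logloss(targets, probabilities, eps=1e-15):\n",
"    # Mean negative log-likelihood: 0/1 targets give Logloss,\n",
"    # targets in [0, 1] give CrossEntropy\n",
"    t = np.asarray(targets, dtype=float)\n",
"    p = np.clip(np.asarray(probabilities, dtype=float), eps, 1 - eps)\n",
"    return -np.mean(t * np.log(p) + (1 - t) * np.log(1 - p))\n",
"\n",
"\n",
"# Baseline: predict the share of positive labels (counts from the label-balance cell) for every object\n",
"labels = np.array([1] * 30872 + [0] * 1897)\n",
"baseline = logloss(labels, np.full(labels.shape, labels.mean()))\n",
"print('Constant-prediction Logloss: {:.4f}'.format(baseline))"
]
},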
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from catboost import CatBoostClassifier\n",
"\n",
"model = CatBoostClassifier(\n",
" iterations=5,\n",
" learning_rate=0.1,\n",
" #loss_function='Logloss',\n",
" #loss_function='CrossEntropy'\n",
")\n",
"\n",
"train_pool = Pool(data=X_train, \n",
" label=y_train, \n",
" cat_features=cat_features)\n",
"\n",
"validation_pool = Pool(data=X_validation, \n",
" label=y_validation, \n",
" cat_features=cat_features)\n",
"model.fit(\n",
" train_pool,\n",
" eval_set=validation_pool,\n",
" verbose=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model is fitted: True\n",
"Model params:\n",
"{'iterations': 5, 'learning_rate': 0.1, 'loss_function': 'Logloss'}\n"
]
}
],
"source": [
"print('Model is fitted: ' + str(model.is_fitted()))\n",
"print('Model params:')\n",
"print(model.get_params())"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = CatBoostClassifier(\n",
" iterations=5,\n",
" learning_rate=0.1,\n",
" #loss_function='Logloss',\n",
" #loss_function='CrossEntropy'\n",
")\n",
"\n",
"model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" eval_set=(X_validation, y_validation),\n",
" verbose=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model is fitted: True\n",
"Model params:\n",
"{'iterations': 5, 'learning_rate': 0.1, 'loss_function': 'Logloss'}\n"
]
}
],
"source": [
"print('Model is fitted: ' + str(model.is_fitted()))\n",
"print('Model params:')\n",
"print(model.get_params())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training output (stdout)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0:\tlearn: 0.2922585\ttest: 0.2894772\tbest: 0.2894772 (0)\ttotal: 50.6ms\tremaining: 708ms\n",
"1:\tlearn: 0.2200332\ttest: 0.2193253\tbest: 0.2193253 (1)\ttotal: 103ms\tremaining: 667ms\n",
"2:\tlearn: 0.1882770\ttest: 0.1794720\tbest: 0.1794720 (2)\ttotal: 169ms\tremaining: 676ms\n",
"3:\tlearn: 0.1787690\ttest: 0.1655785\tbest: 0.1655785 (3)\ttotal: 237ms\tremaining: 652ms\n",
"4:\tlearn: 0.1748384\ttest: 0.1586419\tbest: 0.1586419 (4)\ttotal: 309ms\tremaining: 619ms\n",
"5:\tlearn: 0.1737776\ttest: 0.1564081\tbest: 0.1564081 (5)\ttotal: 420ms\tremaining: 630ms\n",
"6:\tlearn: 0.1723479\ttest: 0.1542367\tbest: 0.1542367 (6)\ttotal: 508ms\tremaining: 581ms\n",
"7:\tlearn: 0.1719611\ttest: 0.1540092\tbest: 0.1540092 (7)\ttotal: 593ms\tremaining: 518ms\n",
"8:\tlearn: 0.1718945\ttest: 0.1536590\tbest: 0.1536590 (8)\ttotal: 676ms\tremaining: 450ms\n",
"9:\tlearn: 0.1697911\ttest: 0.1524950\tbest: 0.1524950 (9)\ttotal: 738ms\tremaining: 369ms\n",
"10:\tlearn: 0.1684481\ttest: 0.1509084\tbest: 0.1509084 (10)\ttotal: 832ms\tremaining: 303ms\n",
"11:\tlearn: 0.1673863\ttest: 0.1495139\tbest: 0.1495139 (11)\ttotal: 920ms\tremaining: 230ms\n",
"12:\tlearn: 0.1642212\ttest: 0.1473859\tbest: 0.1473859 (12)\ttotal: 1s\tremaining: 154ms\n",
"13:\tlearn: 0.1631211\ttest: 0.1468169\tbest: 0.1468169 (13)\ttotal: 1.08s\tremaining: 77.5ms\n",
"14:\tlearn: 0.1628934\ttest: 0.1465499\tbest: 0.1465499 (14)\ttotal: 1.15s\tremaining: 0us\n",
"\n",
"bestTest = 0.1465498889\n",
"bestIteration = 14\n",
"\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from catboost import CatBoostClassifier\n",
"model = CatBoostClassifier(\n",
" iterations=15,\n",
" #verbose=5,\n",
" logging_level='Verbose'\n",
")\n",
"model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" eval_set=(X_validation, y_validation),\n",
")"
]
},
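{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each line of the log shows the iteration number, the value of the objective (Logloss here) on the training data (`learn`) and on the `eval_set` (`test`), the best `test` value so far together with the iteration where it was reached, and the elapsed and estimated remaining time. The hypothetical helper below sketches how to pull these numbers out of such a line; it is written against the exact layout printed above, which may differ with other metrics or logging settings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"# Hypothetical pattern matching per-iteration lines like the ones printed above\n",
"LOG_LINE = re.compile(\n",
"    r'(?P<iteration>\\d+):\\tlearn: (?P<learn>[\\d.]+)\\ttest: (?P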
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"{ROLE_FAMILY} pr1 tb0 type0, border=1 score 71.20196708\n",
" tensor 0 is redundant, remove it and stop\n",
"0:\tlearn: 0.3020322\ttest: 0.3024549\tbest: 0.3024549 (0)\ttotal: 46.3ms\tremaining: 648ms\n",
"\n",
"{ROLE_FAMILY} pr0 tb0 type0, border=6 score 22.73090753\n",
"{ROLE_FAMILY, ROLE_CODE} pr0 tb0 type1, border=1 score 22.85737971\n",
"{ROLE_TITLE} pr0 tb0 type0, border=13 score 23.18484659\n",
"{ROLE_DEPTNAME, ROLE_FAMILY, ROLE_CODE} pr2 tb0 type0, border=10 score 23.4271605\n",
"{ROLE_ROLLUP_2} pr1 tb0 type0, border=13 score 25.42952513\n",
"{ROLE_TITLE, ROLE_FAMILY} pr2 tb0 type0, border=7 score 25.91828995\n",
"1:\tlearn: 0.2126748\ttest: 0.2068150\tbest: 0.2068150 (1)\ttotal: 137ms\tremaining: 890ms\n",
"\n",
"{MGR_ID} pr2 tb0 type0, border=11 score 9.55130831\n",
"{MGR_ID, ROLE_CODE} pr2 tb0 type0, border=14 score 12.59291657\n",
"{ROLE_FAMILY} pr0 tb0 type1, border=9 score 14.15678078\n",
"{ROLE_TITLE, ROLE_FAMILY} pr1 tb0 type0, border=6 score 14.34218082\n",
"{ROLE_ROLLUP_1, ROLE_FAMILY} pr2 tb0 type0, border=3 score 14.2932558\n",
"{ROLE_FAMILY, ROLE_CODE} pr0 tb0 type1, border=3 score 14.4073355\n",
"2:\tlearn: 0.1856771\ttest: 0.1759821\tbest: 0.1759821 (2)\ttotal: 282ms\tremaining: 1.13s\n",
"\n",
"{MGR_ID} pr2 tb0 type0, border=10 score 6.594784046\n",
"{MGR_ID, ROLE_CODE} pr2 tb0 type0, border=12 score 8.339746264\n",
"{MGR_ID, ROLE_FAMILY_DESC} pr2 tb0 type0, border=11 score 9.115213253\n",
"{ROLE_ROLLUP_2} pr0 tb0 type0, border=12 score 10.5940908\n",
"{MGR_ID, ROLE_FAMILY_DESC} pr0 tb0 type0, border=5 score 10.45322949\n",
"{ROLE_TITLE} pr2 tb0 type0, border=6 score 10.66637876\n",
"3:\tlearn: 0.1756533\ttest: 0.1621181\tbest: 0.1621181 (3)\ttotal: 397ms\tremaining: 1.09s\n",
"\n",
"{MGR_ID} pr2 tb0 type0, border=10 score 4.816932629\n",
"{MGR_ID, ROLE_FAMILY_DESC} pr2 tb0 type0, border=12 score 5.826899508\n",
"{ROLE_ROLLUP_2} pr2 tb0 type0, border=13 score 6.455962742\n",
"{ROLE_FAMILY} pr2 tb0 type0, border=10 score 6.668192917\n",
"{ROLE_TITLE} pr2 tb0 type0, border=8 score 6.787022309\n",
"{ROLE_FAMILY, ROLE_CODE} pr2 tb0 type0, border=1 score 6.784150409\n",
" tensor 5 is redundant, remove it and stop\n",
"4:\tlearn: 0.1732514\ttest: 0.1587816\tbest: 0.1587816 (4)\ttotal: 526ms\tremaining: 1.05s\n",
"\n",
"{MGR_ID} pr2 tb0 type0, border=7 score 2.297190459\n",
"{RESOURCE} pr2 tb0 type0, border=11 score 3.692999825\n",
"{MGR_ID, ROLE_FAMILY_DESC} pr2 tb0 type0, border=9 score 4.566200375\n",
"{RESOURCE} pr0 tb0 type0, border=13 score 5.195022352\n",
"{MGR_ID, ROLE_FAMILY_DESC} pr0 tb0 type0, border=10 score 5.882398347\n",
"{MGR_ID, ROLE_FAMILY_DESC} pr1 tb0 type0, border=6 score 5.715399768\n",
"5:\tlearn: 0.1693831\ttest: 0.1534864\tbest: 0.1534864 (5)\ttotal: 633ms\tremaining: 950ms\n",
"\n",
"{MGR_ID} pr2 tb0 type0, border=9 score 2.56631672\n",
"{MGR_ID, ROLE_FAMILY_DESC} pr2 tb0 type0, border=12 score 3.053372671\n",
"{ROLE_FAMILY} pr0 tb0 type1, border=14 score 3.662099526\n",
" tensor 2 is redundant, remove it and stop\n",
"6:\tlearn: 0.1692429\ttest: 0.1530965\tbest: 0.1530965 (6)\ttotal: 688ms\tremaining: 786ms\n",
"\n",
"{MGR_ID} pr2 tb0 type0, border=9 score 2.662106324\n",
"{MGR_ID, ROLE_CODE} pr0 tb0 type0, border=14 score 3.271764978\n",
" tensor 1 is redundant, remove it and stop\n",
"7:\tlearn: 0.1692302\ttest: 0.1530455\tbest: 0.1530455 (7)\ttotal: 740ms\tremaining: 647ms\n",
"\n",
"{RESOURCE} pr2 tb0 type0, border=9 score 2.704934036\n",
"{ROLE_ROLLUP_2} pr1 tb0 type0, border=9 score 3.380733801\n",
"{ROLE_FAMILY_DESC} pr0 tb0 type1, border=8 score 3.485919025\n",
"{ROLE_ROLLUP_2, ROLE_DEPTNAME} pr2 tb0 type0, border=8 score 4.040211274\n",
"{ROLE_DEPTNAME} pr2 tb0 type0, border=4 score 4.662311897\n",
"{ROLE_ROLLUP_1} pr2 tb0 type0, border=10 score 4.587959116\n",
"8:\tlearn: 0.1678218\ttest: 0.1518261\tbest: 0.1518261 (8)\ttotal: 843ms\tremaining: 562ms\n",
"\n",
"{RESOURCE} pr2 tb0 type0, border=6 score 3.22558481\n",
"{MGR_ID} pr0 tb0 type0, border=6 score 3.598670764\n",
"{MGR_ID} pr2 tb0 type0, border=6 score 4.442316954\n",
"{ROLE_ROLLUP_1} pr1 tb0 type0, border=14 score 4.534285954\n",
" tensor 3 is redundant, remove it and stop\n",
"9:\tlearn: 0.1675718\ttest: 0.1514379\tbest: 0.1514379 (9)\ttotal: 900ms\tremaining: 450ms\n",
"\n",
"{RESOURCE} pr1 tb0 type0, border=8 score 1.706869104\n",
"{RESOURCE} pr2 tb0 type0, border=12 score 2.541399767\n",
"{MGR_ID} pr0 tb0 type0, border=12 score 3.042148696\n",
"{MGR_ID, ROLE_DEPTNAME} pr1 tb0 type0, border=6 score 3.056465669\n",
"{ROLE_DEPTNAME} pr0 tb0 type1, border=14 score 3.417063506\n",
" tensor 4 is redundant, remove it and stop\n",
"10:\tlearn: 0.1662480\ttest: 0.1499766\tbest: 0.1499766 (10)\ttotal: 974ms\tremaining: 354ms\n",
"\n",
"{ROLE_FAMILY_DESC} pr1 tb0 type0, border=13 score 1.188282501\n",
"{MGR_ID, ROLE_FAMILY_DESC} pr2 tb0 type0, border=6 score 2.265758008\n",
"{RESOURCE} pr2 tb0 type0, border=12 score 2.620514967\n",
"{ROLE_CODE} pr0 tb0 type0, border=14 score 3.027549244\n",
" tensor 3 is redundant, remove it and stop\n",
"11:\tlearn: 0.1660153\ttest: 0.1497512\tbest: 0.1497512 (11)\ttotal: 1.04s\tremaining: 260ms\n",
"\n",
"{RESOURCE} pr2 tb0 type0, border=3 score 1.942379095\n",
"{ROLE_FAMILY_DESC} pr2 tb0 type0, border=12 score 2.328950537\n",
"{ROLE_ROLLUP_2, ROLE_FAMILY_DESC} pr2 tb0 type0, border=13 score 2.976624744\n",
"{RESOURCE, ROLE_ROLLUP_1} pr2 tb0 type0, border=11 score 3.247835595\n",
"{RESOURCE, ROLE_ROLLUP_1} pr2 tb0 type0, border=4 score 3.98226103\n",
"{ROLE_ROLLUP_2, ROLE_FAMILY_DESC} pr0 tb0 type0, border=14 score 3.884491014\n",
" tensor 5 is redundant, remove it and stop\n",
"12:\tlearn: 0.1655821\ttest: 0.1492094\tbest: 0.1492094 (12)\ttotal: 1.14s\tremaining: 175ms\n",
"\n",
"{RESOURCE} pr2 tb0 type0, border=6 score 1.401892445\n",
"{ROLE_DEPTNAME} pr2 tb0 type0, border=5 score 2.17295077\n",
"{RESOURCE, ROLE_ROLLUP_1} pr1 tb0 type0, border=6 score 2.344410053\n",
"{RESOURCE, ROLE_ROLLUP_1, ROLE_ROLLUP_2} pr0 tb0 type0, border=6 score 2.784440779\n",
"{ROLE_CODE} pr2 tb0 type0, border=8 score 3.478854515\n",
"{ROLE_DEPTNAME, ROLE_CODE} pr2 tb0 type0, border=7 score 3.183561803\n",
"13:\tlearn: 0.1637591\ttest: 0.1478251\tbest: 0.1478251 (13)\ttotal: 1.27s\tremaining: 91.1ms\n",
"\n",
"{MGR_ID} pr2 tb0 type0, border=5 score 1.678525175\n",
"{RESOURCE} pr2 tb0 type0, border=2 score 1.913548971\n",
"{ROLE_ROLLUP_2} pr1 tb0 type0, border=4 score 2.09899603\n",
"{ROLE_ROLLUP_2, ROLE_FAMILY_DESC} pr0 tb0 type0, border=14 score 2.529101972\n",
" tensor 3 is redundant, remove it and stop\n",
"14:\tlearn: 0.1637261\ttest: 0.1477449\tbest: 0.1477449 (14)\ttotal: 1.34s\tremaining: 0us\n",
"\n",
"bestTest = 0.1477448839\n",
"bestIteration = 14\n",
"\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from catboost import CatBoostClassifier\n",
"model = CatBoostClassifier(\n",
" iterations=15,\n",
" #verbose=5,\n",
" logging_level='Info'\n",
")\n",
"model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" eval_set=(X_validation, y_validation),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Metrics calculation and graph plotting"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b78812f3160e458dbd4fdd81998054e0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from catboost import CatBoostClassifier\n",
"model = CatBoostClassifier(\n",
" iterations=500,\n",
" random_seed=63,\n",
" learning_rate=0.5\n",
")\n",
"model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" eval_set=(X_validation, y_validation),\n",
" verbose=False,\n",
" plot=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Eval metric, custom metrics and best tree count"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a99d9adb234d42a9a82ff7812ebd0718",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from catboost import CatBoostClassifier\n",
"model = CatBoostClassifier(\n",
" iterations=50,\n",
" random_seed=63,\n",
" learning_rate=0.5,\n",
" eval_metric=\"Accuracy\",\n",
" use_best_model=False\n",
")\n",
"\n",
"model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" eval_set=(X_validation, y_validation),\n",
" verbose=False,\n",
" plot=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"50"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.tree_count_"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3237fc1e455c42079880250710ad16a5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from catboost import CatBoostClassifier\n",
"model = CatBoostClassifier(\n",
" iterations=50,\n",
" random_seed=63,\n",
" learning_rate=0.5,\n",
" custom_loss=['AUC', 'Accuracy']\n",
")\n",
"model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" eval_set=(X_validation, y_validation),\n",
" verbose=False,\n",
" plot=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"28"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.tree_count_"
]
},
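{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model above has only 28 trees although `iterations=50`. `use_best_model` is not set in that cell, and when an `eval_set` is passed it defaults to true: the model is then shrunk to the trees up to and including the iteration with the best value of the eval metric on the validation set (Logloss here, since `eval_metric` is not set; the `custom_loss` metrics are only reported). So 28 trees means the best iteration was 27, counting from 0. A plain-Python sketch of that rule on a made-up metric history:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def best_iteration(metric_history, higher_is_better=False):\n",
"    # Index of the best validation value (in this sketch, ties go to the earliest iteration)\n",
"    candidates = range(len(metric_history))\n",
"    if higher_is_better:\n",
"        return max(candidates, key=lambda i: metric_history[i])\n",
"    return min(candidates, key=lambda i: metric_history[i])\n",
"\n",
"\n",
"def trees_kept(metric_history, higher_is_better=False):\n",
"    # Trees 0..best_iteration are kept, i.e. best_iteration + 1 trees\n",
"    return best_iteration(metric_history, higher_is_better) + 1\n",
"\n",
"\n",
"# Made-up validation Logloss values for 6 iterations; the minimum is at iteration 3\n",
"validation_logloss = [0.29, 0.22, 0.18, 0.16, 0.17, 0.19]\n",
"print(best_iteration(validation_logloss), trees_kept(validation_logloss))"
]
},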
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Metric hints"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "987ef484e520494a820090a2bb88ac50",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from catboost import CatBoostClassifier\n",
"\n",
"model = CatBoostClassifier(\n",
" iterations=50,\n",
" random_seed=63,\n",
" learning_rate=0.5,\n",
" eval_metric='AUC:hints=skip_train~false' #default\n",
")\n",
"\n",
"model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" eval_set=(X_validation, y_validation),\n",
" verbose=False,\n",
" plot=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "328bd48940a24bdc807e5bc149b56072",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from catboost import CatBoostClassifier\n",
"model = CatBoostClassifier(\n",
" iterations=50,\n",
" random_seed=63,\n",
" learning_rate=0.5,\n",
" eval_metric='AUC:hints=skip_train~false', #default\n",
" metric_period=10\n",
")\n",
"model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" eval_set=(X_validation, y_validation),\n",
" verbose=False,\n",
" plot=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model comparison"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model1 = CatBoostClassifier(\n",
" learning_rate=0.5,\n",
" iterations=100,\n",
" train_dir='learing_rate_0.5'\n",
")\n",
"\n",
"model2 = CatBoostClassifier(\n",
" learning_rate=0.01,\n",
" iterations=100,\n",
" train_dir='learing_rate_0.01'\n",
")\n",
"\n",
"model1.fit(\n",
" X_train, y_train,\n",
" eval_set=(X_validation, y_validation),\n",
" cat_features=cat_features,\n",
" verbose=False\n",
")\n",
"model2.fit(\n",
" X_train, y_train,\n",
" eval_set=(X_validation, y_validation),\n",
" cat_features=cat_features,\n",
" verbose=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d4ec633ced854e4f99dd4731264681ec",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from catboost import MetricVisualizer\n",
"MetricVisualizer(['learing_rate_0.01', 'learing_rate_0.5']).start()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Overfitting detector"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "80e70b13a7564741b84cd3c7328d04e0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_with_early_stop = CatBoostClassifier(\n",
" iterations=200,\n",
" random_seed=63,\n",
" learning_rate=0.5,\n",
" early_stopping_rounds=20\n",
")\n",
"model_with_early_stop.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" eval_set=(X_validation, y_validation),\n",
" verbose=False,\n",
" plot=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"28\n"
]
}
],
"source": [
"print(model_with_early_stop.tree_count_)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "32fc59c9d0204b41b75e2bb13327d8be",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_with_early_stop = CatBoostClassifier(\n",
" eval_metric='AUC',\n",
" iterations=200,\n",
" random_seed=63,\n",
" learning_rate=0.5,\n",
" early_stopping_rounds=20\n",
")\n",
"model_with_early_stop.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" eval_set=(X_validation, y_validation),\n",
" verbose=False,\n",
" plot=True\n",
")"
]
},
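{
"cell_type": "markdown",
"metadata": {},
"source": [
"`early_stopping_rounds=20` switches on the overfitting detector: training stops once the validation metric (`eval_metric`, which defaults to the Logloss objective and is AUC in the second model) has not improved for 20 iterations after its best value. Because an `eval_set` is passed, `use_best_model` is enabled by default and the model is shrunk to its best iteration, so the first model above keeps 28 trees instead of 200.\n",
"\n",
"Below is a minimal pure-Python sketch of that patience rule, assuming a lower-is-better metric such as Logloss. `early_stop` is a hypothetical helper for illustration only; it is not part of the CatBoost API and ignores details such as metric ties."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch of a patience-based overfitting detector (not CatBoost's code).\n",
"def early_stop(losses, patience):\n",
"    # Returns (best_iteration, iterations_run) for a lower-is-better metric.\n",
"    best_iter = 0\n",
"    best_loss = float('inf')\n",
"    for i, loss in enumerate(losses):\n",
"        if loss < best_loss:\n",
"            best_loss, best_iter = loss, i\n",
"        elif i - best_iter >= patience:\n",
"            # `patience` iterations have passed without improvement: stop here.\n",
"            return best_iter, i + 1\n",
"    return best_iter, len(losses)\n",
"\n",
"# Toy validation curve: best value at iteration 2, then it keeps getting worse.\n",
"best_iter, n_run = early_stop([0.30, 0.20, 0.15, 0.16, 0.17, 0.18], patience=2)\n",
"print(best_iter, n_run, best_iter + 1)  # best iteration, iterations run, trees kept: 2 5 3"
]
},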
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cross-validation"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "447c85da2d564f4890220e78ff7eb230",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from catboost import cv\n",
"\n",
"params = {}\n",
"params['loss_function'] = 'Logloss'\n",
"params['iterations'] = 80\n",
"params['custom_loss'] = 'AUC'\n",
"params['random_seed'] = 63\n",
"params['learning_rate'] = 0.5\n",
"\n",
"cv_data = cv(\n",
" params = params,\n",
" pool = Pool(X, label=y, cat_features=cat_features),\n",
" fold_count=5,\n",
" type = 'Classical',\n",
" shuffle=True,\n",
" partition_random_seed=0,\n",
" plot=True,\n",
" stratified=False,\n",
" verbose=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" test-AUC-mean \n",
" test-AUC-std \n",
" test-Logloss-mean \n",
" test-Logloss-std \n",
" train-AUC-mean \n",
" train-AUC-std \n",
" train-Logloss-mean \n",
" train-Logloss-std \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 0.500000 \n",
" 0.000000 \n",
" 0.302217 \n",
" 0.003871 \n",
" 0.499996 \n",
" 0.000009 \n",
" 0.302189 \n",
" 0.002418 \n",
" \n",
" \n",
" 1 \n",
" 0.625287 \n",
" 0.098486 \n",
" 0.225502 \n",
" 0.014372 \n",
" 0.618981 \n",
" 0.088396 \n",
" 0.227776 \n",
" 0.009594 \n",
" \n",
" \n",
" 2 \n",
" 0.797381 \n",
" 0.031802 \n",
" 0.182074 \n",
" 0.010177 \n",
" 0.758575 \n",
" 0.036903 \n",
" 0.190874 \n",
" 0.004790 \n",
" \n",
" \n",
" 3 \n",
" 0.818542 \n",
" 0.019081 \n",
" 0.168630 \n",
" 0.006834 \n",
" 0.778474 \n",
" 0.012547 \n",
" 0.180657 \n",
" 0.004351 \n",
" \n",
" \n",
" 4 \n",
" 0.823875 \n",
" 0.013242 \n",
" 0.163591 \n",
" 0.007106 \n",
" 0.784664 \n",
" 0.007662 \n",
" 0.176791 \n",
" 0.003464 \n",
" \n",
" \n",
" 5 \n",
" 0.842055 \n",
" 0.006845 \n",
" 0.159424 \n",
" 0.006278 \n",
" 0.801486 \n",
" 0.008265 \n",
" 0.173545 \n",
" 0.003944 \n",
" \n",
" \n",
" 6 \n",
" 0.852505 \n",
" 0.011592 \n",
" 0.155835 \n",
" 0.007250 \n",
" 0.812029 \n",
" 0.004814 \n",
" 0.170815 \n",
" 0.003318 \n",
" \n",
" \n",
" 7 \n",
" 0.859290 \n",
" 0.013055 \n",
" 0.152686 \n",
" 0.007374 \n",
" 0.821573 \n",
" 0.004150 \n",
" 0.168168 \n",
" 0.002686 \n",
" \n",
" \n",
" 8 \n",
" 0.862400 \n",
" 0.012930 \n",
" 0.151412 \n",
" 0.007750 \n",
" 0.825537 \n",
" 0.002166 \n",
" 0.166854 \n",
" 0.002284 \n",
" \n",
" \n",
" 9 \n",
" 0.864313 \n",
" 0.012127 \n",
" 0.150388 \n",
" 0.007312 \n",
" 0.827699 \n",
" 0.004355 \n",
" 0.165980 \n",
" 0.002763 \n",
" \n",
" \n",
" 10 \n",
" 0.864621 \n",
" 0.012251 \n",
" 0.149981 \n",
" 0.007268 \n",
" 0.828238 \n",
" 0.004967 \n",
" 0.165710 \n",
" 0.002916 \n",
" \n",
" \n",
" 11 \n",
" 0.869911 \n",
" 0.012121 \n",
" 0.148150 \n",
" 0.008058 \n",
" 0.832391 \n",
" 0.004098 \n",
" 0.164250 \n",
" 0.002763 \n",
" \n",
" \n",
" 12 \n",
" 0.871178 \n",
" 0.011428 \n",
" 0.147479 \n",
" 0.008175 \n",
" 0.834270 \n",
" 0.005416 \n",
" 0.163328 \n",
" 0.003049 \n",
" \n",
" \n",
" 13 \n",
" 0.871663 \n",
" 0.011281 \n",
" 0.147303 \n",
" 0.008230 \n",
" 0.834886 \n",
" 0.005530 \n",
" 0.163057 \n",
" 0.003074 \n",
" \n",
" \n",
" 14 \n",
" 0.872252 \n",
" 0.010388 \n",
" 0.147200 \n",
" 0.008247 \n",
" 0.835580 \n",
" 0.006042 \n",
" 0.162711 \n",
" 0.003282 \n",
" \n",
" \n",
" 15 \n",
" 0.872506 \n",
" 0.010518 \n",
" 0.147003 \n",
" 0.008478 \n",
" 0.836177 \n",
" 0.006505 \n",
" 0.162350 \n",
" 0.003406 \n",
" \n",
" \n",
" 16 \n",
" 0.873782 \n",
" 0.009503 \n",
" 0.146393 \n",
" 0.008212 \n",
" 0.837587 \n",
" 0.006670 \n",
" 0.161729 \n",
" 0.003425 \n",
" \n",
" \n",
" 17 \n",
" 0.875730 \n",
" 0.008438 \n",
" 0.145739 \n",
" 0.007961 \n",
" 0.840422 \n",
" 0.008338 \n",
" 0.160630 \n",
" 0.003884 \n",
" \n",
" \n",
" 18 \n",
" 0.876655 \n",
" 0.009206 \n",
" 0.145421 \n",
" 0.008027 \n",
" 0.842564 \n",
" 0.005416 \n",
" 0.159910 \n",
" 0.003345 \n",
" \n",
" \n",
" 19 \n",
" 0.876961 \n",
" 0.009529 \n",
" 0.145151 \n",
" 0.008368 \n",
" 0.843762 \n",
" 0.004223 \n",
" 0.159456 \n",
" 0.002908 \n",
" \n",
" \n",
" 20 \n",
" 0.877131 \n",
" 0.009657 \n",
" 0.145095 \n",
" 0.008361 \n",
" 0.843862 \n",
" 0.004137 \n",
" 0.159337 \n",
" 0.002843 \n",
" \n",
" \n",
" 21 \n",
" 0.877797 \n",
" 0.009627 \n",
" 0.144691 \n",
" 0.008332 \n",
" 0.845742 \n",
" 0.004977 \n",
" 0.158503 \n",
" 0.002997 \n",
" \n",
" \n",
" 22 \n",
" 0.878456 \n",
" 0.009885 \n",
" 0.144474 \n",
" 0.008387 \n",
" 0.847026 \n",
" 0.004351 \n",
" 0.157964 \n",
" 0.002731 \n",
" \n",
" \n",
" 23 \n",
" 0.878098 \n",
" 0.010434 \n",
" 0.144497 \n",
" 0.008484 \n",
" 0.848029 \n",
" 0.005031 \n",
" 0.157516 \n",
" 0.002853 \n",
" \n",
" \n",
" 24 \n",
" 0.878518 \n",
" 0.010405 \n",
" 0.144371 \n",
" 0.008634 \n",
" 0.848554 \n",
" 0.005170 \n",
" 0.157213 \n",
" 0.002953 \n",
" \n",
" \n",
" 25 \n",
" 0.879114 \n",
" 0.010706 \n",
" 0.144143 \n",
" 0.008871 \n",
" 0.849741 \n",
" 0.005065 \n",
" 0.156656 \n",
" 0.002807 \n",
" \n",
" \n",
" 26 \n",
" 0.879001 \n",
" 0.010942 \n",
" 0.144083 \n",
" 0.008958 \n",
" 0.850191 \n",
" 0.004955 \n",
" 0.156296 \n",
" 0.002907 \n",
" \n",
" \n",
" 27 \n",
" 0.879322 \n",
" 0.011028 \n",
" 0.143980 \n",
" 0.008761 \n",
" 0.850911 \n",
" 0.005407 \n",
" 0.155861 \n",
" 0.002955 \n",
" \n",
" \n",
" 28 \n",
" 0.879694 \n",
" 0.011326 \n",
" 0.143692 \n",
" 0.009107 \n",
" 0.851471 \n",
" 0.004956 \n",
" 0.155599 \n",
" 0.002789 \n",
" \n",
" \n",
" 29 \n",
" 0.879781 \n",
" 0.011045 \n",
" 0.143615 \n",
" 0.009135 \n",
" 0.851630 \n",
" 0.005069 \n",
" 0.155455 \n",
" 0.002823 \n",
" \n",
" \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" \n",
" \n",
" 50 \n",
" 0.882512 \n",
" 0.012500 \n",
" 0.142604 \n",
" 0.010409 \n",
" 0.860230 \n",
" 0.005378 \n",
" 0.150163 \n",
" 0.003652 \n",
" \n",
" \n",
" 51 \n",
" 0.882453 \n",
" 0.012385 \n",
" 0.142621 \n",
" 0.010400 \n",
" 0.860310 \n",
" 0.005331 \n",
" 0.150098 \n",
" 0.003616 \n",
" \n",
" \n",
" 52 \n",
" 0.882385 \n",
" 0.012415 \n",
" 0.142637 \n",
" 0.010403 \n",
" 0.860454 \n",
" 0.005418 \n",
" 0.149982 \n",
" 0.003705 \n",
" \n",
" \n",
" 53 \n",
" 0.882443 \n",
" 0.012071 \n",
" 0.142524 \n",
" 0.010384 \n",
" 0.861073 \n",
" 0.005605 \n",
" 0.149685 \n",
" 0.003798 \n",
" \n",
" \n",
" 54 \n",
" 0.882297 \n",
" 0.012177 \n",
" 0.142517 \n",
" 0.010484 \n",
" 0.861248 \n",
" 0.005420 \n",
" 0.149561 \n",
" 0.003734 \n",
" \n",
" \n",
" 55 \n",
" 0.882199 \n",
" 0.012224 \n",
" 0.142556 \n",
" 0.010513 \n",
" 0.861319 \n",
" 0.005484 \n",
" 0.149484 \n",
" 0.003754 \n",
" \n",
" \n",
" 56 \n",
" 0.882194 \n",
" 0.012205 \n",
" 0.142574 \n",
" 0.010495 \n",
" 0.861369 \n",
" 0.005492 \n",
" 0.149424 \n",
" 0.003744 \n",
" \n",
" \n",
" 57 \n",
" 0.882277 \n",
" 0.012128 \n",
" 0.142493 \n",
" 0.010513 \n",
" 0.861646 \n",
" 0.005424 \n",
" 0.149269 \n",
" 0.003714 \n",
" \n",
" \n",
" 58 \n",
" 0.882321 \n",
" 0.012203 \n",
" 0.142489 \n",
" 0.010529 \n",
" 0.861708 \n",
" 0.005415 \n",
" 0.149229 \n",
" 0.003720 \n",
" \n",
" \n",
" 59 \n",
" 0.882194 \n",
" 0.012165 \n",
" 0.142555 \n",
" 0.010500 \n",
" 0.861739 \n",
" 0.005397 \n",
" 0.149171 \n",
" 0.003693 \n",
" \n",
" \n",
" 60 \n",
" 0.882152 \n",
" 0.012189 \n",
" 0.142579 \n",
" 0.010494 \n",
" 0.861742 \n",
" 0.005393 \n",
" 0.149148 \n",
" 0.003703 \n",
" \n",
" \n",
" 61 \n",
" 0.882442 \n",
" 0.012275 \n",
" 0.142514 \n",
" 0.010480 \n",
" 0.862027 \n",
" 0.005136 \n",
" 0.148943 \n",
" 0.003607 \n",
" \n",
" \n",
" 62 \n",
" 0.882531 \n",
" 0.012148 \n",
" 0.142492 \n",
" 0.010439 \n",
" 0.862121 \n",
" 0.005156 \n",
" 0.148857 \n",
" 0.003626 \n",
" \n",
" \n",
" 63 \n",
" 0.882384 \n",
" 0.012305 \n",
" 0.142450 \n",
" 0.010402 \n",
" 0.862249 \n",
" 0.005253 \n",
" 0.148737 \n",
" 0.003715 \n",
" \n",
" \n",
" 64 \n",
" 0.882327 \n",
" 0.012362 \n",
" 0.142604 \n",
" 0.010350 \n",
" 0.862404 \n",
" 0.005327 \n",
" 0.148628 \n",
" 0.003732 \n",
" \n",
" \n",
" 65 \n",
" 0.882394 \n",
" 0.012476 \n",
" 0.142554 \n",
" 0.010403 \n",
" 0.862664 \n",
" 0.005393 \n",
" 0.148464 \n",
" 0.003785 \n",
" \n",
" \n",
" 66 \n",
" 0.882718 \n",
" 0.012037 \n",
" 0.142475 \n",
" 0.010294 \n",
" 0.862964 \n",
" 0.005743 \n",
" 0.148285 \n",
" 0.003972 \n",
" \n",
" \n",
" 67 \n",
" 0.882441 \n",
" 0.012522 \n",
" 0.142524 \n",
" 0.010263 \n",
" 0.863391 \n",
" 0.006173 \n",
" 0.148076 \n",
" 0.004102 \n",
" \n",
" \n",
" 68 \n",
" 0.882410 \n",
" 0.012563 \n",
" 0.142535 \n",
" 0.010284 \n",
" 0.863487 \n",
" 0.006285 \n",
" 0.147985 \n",
" 0.004199 \n",
" \n",
" \n",
" 69 \n",
" 0.882688 \n",
" 0.012817 \n",
" 0.142429 \n",
" 0.010369 \n",
" 0.863947 \n",
" 0.006193 \n",
" 0.147717 \n",
" 0.004175 \n",
" \n",
" \n",
" 70 \n",
" 0.882825 \n",
" 0.012729 \n",
" 0.142343 \n",
" 0.010321 \n",
" 0.864080 \n",
" 0.006283 \n",
" 0.147574 \n",
" 0.004268 \n",
" \n",
" \n",
" 71 \n",
" 0.882948 \n",
" 0.012839 \n",
" 0.142289 \n",
" 0.010418 \n",
" 0.864243 \n",
" 0.006387 \n",
" 0.147433 \n",
" 0.004301 \n",
" \n",
" \n",
" 72 \n",
" 0.883091 \n",
" 0.012895 \n",
" 0.142179 \n",
" 0.010493 \n",
" 0.864703 \n",
" 0.006083 \n",
" 0.147174 \n",
" 0.004146 \n",
" \n",
" \n",
" 73 \n",
" 0.883080 \n",
" 0.012929 \n",
" 0.142213 \n",
" 0.010555 \n",
" 0.864784 \n",
" 0.006046 \n",
" 0.147111 \n",
" 0.004117 \n",
" \n",
" \n",
" 74 \n",
" 0.882921 \n",
" 0.012968 \n",
" 0.142223 \n",
" 0.010504 \n",
" 0.865053 \n",
" 0.006098 \n",
" 0.146930 \n",
" 0.004142 \n",
" \n",
" \n",
" 75 \n",
" 0.882927 \n",
" 0.012993 \n",
" 0.142192 \n",
" 0.010595 \n",
" 0.865155 \n",
" 0.006075 \n",
" 0.146842 \n",
" 0.004101 \n",
" \n",
" \n",
" 76 \n",
" 0.882942 \n",
" 0.012987 \n",
" 0.142221 \n",
" 0.010574 \n",
" 0.865304 \n",
" 0.005987 \n",
" 0.146695 \n",
" 0.004043 \n",
" \n",
" \n",
" 77 \n",
" 0.883305 \n",
" 0.012968 \n",
" 0.142100 \n",
" 0.010714 \n",
" 0.865512 \n",
" 0.006000 \n",
" 0.146537 \n",
" 0.003998 \n",
" \n",
" \n",
" 78 \n",
" 0.883523 \n",
" 0.013146 \n",
" 0.142131 \n",
" 0.010681 \n",
" 0.865617 \n",
" 0.005938 \n",
" 0.146448 \n",
" 0.003941 \n",
" \n",
" \n",
" 79 \n",
" 0.883661 \n",
" 0.013047 \n",
" 0.142058 \n",
" 0.010638 \n",
" 0.865727 \n",
" 0.005961 \n",
" 0.146341 \n",
" 0.003937 \n",
" \n",
" \n",
"
\n",
"
80 rows × 8 columns
\n",
"
"
],
"text/plain": [
" test-AUC-mean test-AUC-std test-Logloss-mean test-Logloss-std \\\n",
"0 0.500000 0.000000 0.302217 0.003871 \n",
"1 0.625287 0.098486 0.225502 0.014372 \n",
"2 0.797381 0.031802 0.182074 0.010177 \n",
"3 0.818542 0.019081 0.168630 0.006834 \n",
"4 0.823875 0.013242 0.163591 0.007106 \n",
"5 0.842055 0.006845 0.159424 0.006278 \n",
"6 0.852505 0.011592 0.155835 0.007250 \n",
"7 0.859290 0.013055 0.152686 0.007374 \n",
"8 0.862400 0.012930 0.151412 0.007750 \n",
"9 0.864313 0.012127 0.150388 0.007312 \n",
"10 0.864621 0.012251 0.149981 0.007268 \n",
"11 0.869911 0.012121 0.148150 0.008058 \n",
"12 0.871178 0.011428 0.147479 0.008175 \n",
"13 0.871663 0.011281 0.147303 0.008230 \n",
"14 0.872252 0.010388 0.147200 0.008247 \n",
"15 0.872506 0.010518 0.147003 0.008478 \n",
"16 0.873782 0.009503 0.146393 0.008212 \n",
"17 0.875730 0.008438 0.145739 0.007961 \n",
"18 0.876655 0.009206 0.145421 0.008027 \n",
"19 0.876961 0.009529 0.145151 0.008368 \n",
"20 0.877131 0.009657 0.145095 0.008361 \n",
"21 0.877797 0.009627 0.144691 0.008332 \n",
"22 0.878456 0.009885 0.144474 0.008387 \n",
"23 0.878098 0.010434 0.144497 0.008484 \n",
"24 0.878518 0.010405 0.144371 0.008634 \n",
"25 0.879114 0.010706 0.144143 0.008871 \n",
"26 0.879001 0.010942 0.144083 0.008958 \n",
"27 0.879322 0.011028 0.143980 0.008761 \n",
"28 0.879694 0.011326 0.143692 0.009107 \n",
"29 0.879781 0.011045 0.143615 0.009135 \n",
".. ... ... ... ... \n",
"50 0.882512 0.012500 0.142604 0.010409 \n",
"51 0.882453 0.012385 0.142621 0.010400 \n",
"52 0.882385 0.012415 0.142637 0.010403 \n",
"53 0.882443 0.012071 0.142524 0.010384 \n",
"54 0.882297 0.012177 0.142517 0.010484 \n",
"55 0.882199 0.012224 0.142556 0.010513 \n",
"56 0.882194 0.012205 0.142574 0.010495 \n",
"57 0.882277 0.012128 0.142493 0.010513 \n",
"58 0.882321 0.012203 0.142489 0.010529 \n",
"59 0.882194 0.012165 0.142555 0.010500 \n",
"60 0.882152 0.012189 0.142579 0.010494 \n",
"61 0.882442 0.012275 0.142514 0.010480 \n",
"62 0.882531 0.012148 0.142492 0.010439 \n",
"63 0.882384 0.012305 0.142450 0.010402 \n",
"64 0.882327 0.012362 0.142604 0.010350 \n",
"65 0.882394 0.012476 0.142554 0.010403 \n",
"66 0.882718 0.012037 0.142475 0.010294 \n",
"67 0.882441 0.012522 0.142524 0.010263 \n",
"68 0.882410 0.012563 0.142535 0.010284 \n",
"69 0.882688 0.012817 0.142429 0.010369 \n",
"70 0.882825 0.012729 0.142343 0.010321 \n",
"71 0.882948 0.012839 0.142289 0.010418 \n",
"72 0.883091 0.012895 0.142179 0.010493 \n",
"73 0.883080 0.012929 0.142213 0.010555 \n",
"74 0.882921 0.012968 0.142223 0.010504 \n",
"75 0.882927 0.012993 0.142192 0.010595 \n",
"76 0.882942 0.012987 0.142221 0.010574 \n",
"77 0.883305 0.012968 0.142100 0.010714 \n",
"78 0.883523 0.013146 0.142131 0.010681 \n",
"79 0.883661 0.013047 0.142058 0.010638 \n",
"\n",
" train-AUC-mean train-AUC-std train-Logloss-mean train-Logloss-std \n",
"0 0.499996 0.000009 0.302189 0.002418 \n",
"1 0.618981 0.088396 0.227776 0.009594 \n",
"2 0.758575 0.036903 0.190874 0.004790 \n",
"3 0.778474 0.012547 0.180657 0.004351 \n",
"4 0.784664 0.007662 0.176791 0.003464 \n",
"5 0.801486 0.008265 0.173545 0.003944 \n",
"6 0.812029 0.004814 0.170815 0.003318 \n",
"7 0.821573 0.004150 0.168168 0.002686 \n",
"8 0.825537 0.002166 0.166854 0.002284 \n",
"9 0.827699 0.004355 0.165980 0.002763 \n",
"10 0.828238 0.004967 0.165710 0.002916 \n",
"11 0.832391 0.004098 0.164250 0.002763 \n",
"12 0.834270 0.005416 0.163328 0.003049 \n",
"13 0.834886 0.005530 0.163057 0.003074 \n",
"14 0.835580 0.006042 0.162711 0.003282 \n",
"15 0.836177 0.006505 0.162350 0.003406 \n",
"16 0.837587 0.006670 0.161729 0.003425 \n",
"17 0.840422 0.008338 0.160630 0.003884 \n",
"18 0.842564 0.005416 0.159910 0.003345 \n",
"19 0.843762 0.004223 0.159456 0.002908 \n",
"20 0.843862 0.004137 0.159337 0.002843 \n",
"21 0.845742 0.004977 0.158503 0.002997 \n",
"22 0.847026 0.004351 0.157964 0.002731 \n",
"23 0.848029 0.005031 0.157516 0.002853 \n",
"24 0.848554 0.005170 0.157213 0.002953 \n",
"25 0.849741 0.005065 0.156656 0.002807 \n",
"26 0.850191 0.004955 0.156296 0.002907 \n",
"27 0.850911 0.005407 0.155861 0.002955 \n",
"28 0.851471 0.004956 0.155599 0.002789 \n",
"29 0.851630 0.005069 0.155455 0.002823 \n",
".. ... ... ... ... \n",
"50 0.860230 0.005378 0.150163 0.003652 \n",
"51 0.860310 0.005331 0.150098 0.003616 \n",
"52 0.860454 0.005418 0.149982 0.003705 \n",
"53 0.861073 0.005605 0.149685 0.003798 \n",
"54 0.861248 0.005420 0.149561 0.003734 \n",
"55 0.861319 0.005484 0.149484 0.003754 \n",
"56 0.861369 0.005492 0.149424 0.003744 \n",
"57 0.861646 0.005424 0.149269 0.003714 \n",
"58 0.861708 0.005415 0.149229 0.003720 \n",
"59 0.861739 0.005397 0.149171 0.003693 \n",
"60 0.861742 0.005393 0.149148 0.003703 \n",
"61 0.862027 0.005136 0.148943 0.003607 \n",
"62 0.862121 0.005156 0.148857 0.003626 \n",
"63 0.862249 0.005253 0.148737 0.003715 \n",
"64 0.862404 0.005327 0.148628 0.003732 \n",
"65 0.862664 0.005393 0.148464 0.003785 \n",
"66 0.862964 0.005743 0.148285 0.003972 \n",
"67 0.863391 0.006173 0.148076 0.004102 \n",
"68 0.863487 0.006285 0.147985 0.004199 \n",
"69 0.863947 0.006193 0.147717 0.004175 \n",
"70 0.864080 0.006283 0.147574 0.004268 \n",
"71 0.864243 0.006387 0.147433 0.004301 \n",
"72 0.864703 0.006083 0.147174 0.004146 \n",
"73 0.864784 0.006046 0.147111 0.004117 \n",
"74 0.865053 0.006098 0.146930 0.004142 \n",
"75 0.865155 0.006075 0.146842 0.004101 \n",
"76 0.865304 0.005987 0.146695 0.004043 \n",
"77 0.865512 0.006000 0.146537 0.003998 \n",
"78 0.865617 0.005938 0.146448 0.003941 \n",
"79 0.865727 0.005961 0.146341 0.003937 \n",
"\n",
"[80 rows x 8 columns]"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cv_data"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best validation Logloss score, not stratified: 0.1421±0.0106 on step 79\n"
]
}
],
"source": [
"best_value = np.min(cv_data['test-Logloss-mean'])\n",
"best_iter = cv_data['test-Logloss-mean'].idxmin()\n",
"\n",
"print('Best validation Logloss score, not stratified: {:.4f}±{:.4f} on step {}'.format(\n",
" best_value,\n",
" cv_data['test-Logloss-std'][best_iter],\n",
" best_iter)\n",
")"
]
},
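{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each row of `cv_data` summarizes one iteration across the five folds: the `-mean` columns average the per-fold metric and the `-std` columns measure how much the folds disagree. The cell above then takes the row with the lowest mean test Logloss.\n",
"\n",
"The same bookkeeping can be sketched with NumPy on a made-up per-fold score matrix. `summarize_folds` is a hypothetical helper, and it assumes NumPy's default population standard deviation (`ddof=0`), which may differ from the convention CatBoost uses."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def summarize_folds(fold_scores):\n",
"    # fold_scores: shape (n_folds, n_iterations); lower is better (e.g. Logloss).\n",
"    scores = np.asarray(fold_scores, dtype=float)\n",
"    mean = scores.mean(axis=0)  # analogous to test-Logloss-mean\n",
"    std = scores.std(axis=0)    # analogous to test-Logloss-std (ddof=0 assumed)\n",
"    best_iter = int(np.argmin(mean))\n",
"    return best_iter, float(mean[best_iter]), float(std[best_iter])\n",
"\n",
"# Made-up scores for 2 folds and 3 iterations.\n",
"print(summarize_folds([[0.30, 0.20, 0.25],\n",
"                       [0.30, 0.10, 0.15]]))"
]
},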
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f52baa02b8ad418ea07b4b50be5def76",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best validation Logloss score, not stratified: 0.1412±0.0055 on step 67\n"
]
}
],
"source": [
"cv_data = cv(\n",
" params = params,\n",
" pool = Pool(X, label=y, cat_features=cat_features),\n",
" fold_count=5,\n",
" type = 'Classical',\n",
" shuffle=True,\n",
" partition_random_seed=0,\n",
" plot=True,\n",
" stratified=True,\n",
" verbose=False\n",
")\n",
"\n",
"best_value = np.min(cv_data['test-Logloss-mean'])\n",
"best_iter = cv_data['test-Logloss-mean'].idxmin()\n",
"\n",
"print('Best validation Logloss score, not stratified: {:.4f}±{:.4f} on step {}'.format(\n",
" best_value,\n",
" cv_data['test-Logloss-std'][best_iter],\n",
" best_iter)\n",
")"
]
},
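{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `stratified=True`, every fold keeps roughly the class proportions of the full label vector `y`, which makes the folds more alike when one class is much rarer than the other. In the two runs above, the fold-to-fold spread of the best test Logloss drops from 0.0106 (not stratified) to 0.0055 (stratified).\n",
"\n",
"A minimal sketch of one way to build such folds, assuming a seeded shuffle within each class followed by round-robin assignment. `stratified_folds` is a hypothetical helper and not necessarily how CatBoost partitions the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def stratified_folds(y, fold_count, seed=0):\n",
"    # Returns a fold id per sample so each fold keeps roughly the class ratio of y.\n",
"    y = np.asarray(y)\n",
"    rng = np.random.RandomState(seed)\n",
"    fold_of = np.empty(len(y), dtype=int)\n",
"    for cls in np.unique(y):\n",
"        idx = np.flatnonzero(y == cls)\n",
"        rng.shuffle(idx)\n",
"        # Deal this class's samples round-robin across the folds.\n",
"        fold_of[idx] = np.arange(len(idx)) % fold_count\n",
"    return fold_of\n",
"\n",
"# Toy labels: 20 negatives and 5 positives split into 5 folds.\n",
"y_toy = np.array([0] * 20 + [1] * 5)\n",
"folds = stratified_folds(y_toy, fold_count=5)\n",
"for k in range(5):\n",
"    print(k, np.bincount(y_toy[folds == k], minlength=2))  # [negatives positives]"
]
},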
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Select decision boundary"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ba81252dc8bb4d3e858b3627977b24ba",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = CatBoostClassifier(\n",
" random_seed=63,\n",
" iterations=200,\n",
" learning_rate=0.03,\n",
")\n",
"model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" verbose=False,\n",
" plot=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import sklearn\n",
"from sklearn import metrics\n",
"from catboost.utils import get_roc_curve"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"? get_roc_curve"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"eval_pool = Pool(X_validation, y_validation, cat_features=cat_features)\n",
"curve = get_roc_curve(model, eval_pool)\n",
"(fpr, tpr, thresholds) = curve\n",
"\n",
"plt.figure()\n",
"lw = 2\n",
"roc_auc = sklearn.metrics.auc(fpr, tpr)\n",
"plt.plot(fpr, tpr, color='darkorange',\n",
" lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)\n",
"\n",
"plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')\n",
"plt.xlim([0.0, 1.0])\n",
"plt.ylim([0.0, 1.05])\n",
"plt.xlabel('False Positive Rate')\n",
"plt.ylabel('True Positive Rate')\n",
"plt.title('Receiver operating characteristic')\n",
"plt.legend(loc=\"lower right\")\n",
"plt.show()"
]
},
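{
"cell_type": "markdown",
"metadata": {},
"source": [
"`sklearn.metrics.auc` computes the area under the `(fpr, tpr)` points returned by `get_roc_curve` using the trapezoidal rule. The NumPy sketch below does the same integration by hand. `trapezoid_auc` is a hypothetical helper that sorts the points by FPR and then TPR, assuming they lie on a monotone ROC curve, so it does not depend on the order in which the points are returned; on this validation pool it is expected to agree with `roc_auc` up to floating-point error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def trapezoid_auc(fpr, tpr):\n",
"    # Sort by FPR, then TPR, so ties in FPR are walked upward along the curve.\n",
"    fpr = np.asarray(fpr, dtype=float)\n",
"    tpr = np.asarray(tpr, dtype=float)\n",
"    order = np.lexsort((tpr, fpr))  # the last key (fpr) is the primary key\n",
"    x, y = fpr[order], tpr[order]\n",
"    # Sum the areas of the trapezoids between consecutive points.\n",
"    return float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2.0))\n",
"\n",
"print(trapezoid_auc([0.0, 0.0, 1.0], [0.0, 1.0, 1.0]))  # perfect ranking: 1.0\n",
"print(trapezoid_auc([0.0, 1.0], [0.0, 1.0]))            # random guessing: 0.5"
]
},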
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"from catboost.utils import get_fpr_curve\n",
"from catboost.utils import get_fnr_curve"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"? get_fpr_curve"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure()\n",
"lw = 2\n",
"(thresholds, fpr) = get_fpr_curve(curve=curve)\n",
"(thresholds, fnr) = get_fnr_curve(curve=curve)\n",
"plt.plot(thresholds, fpr, color='blue', lw=lw, label='FPR')\n",
"plt.plot(thresholds, fnr, color='green', lw=lw, label='FNR')\n",
"\n",
"plt.xlim([0.0, 1.0])\n",
"plt.ylim([0.0, 1.05])\n",
"plt.xlabel('Threshold')\n",
"plt.ylabel('Error Rate')\n",
"plt.title('FPR-FNR curves')\n",
"plt.legend(loc=\"lower left\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4719014981770065\n",
"0.9876472944285266\n"
]
}
],
"source": [
"from catboost.utils import select_threshold\n",
"\n",
"print(select_threshold(model=model, data=eval_pool, FNR=0.01))\n",
"print(select_threshold(model=model, data=eval_pool, FPR=0.01))"
]
},
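{
"cell_type": "markdown",
"metadata": {},
"source": [
"`select_threshold` picks a probability cutoff for a target error rate on `eval_pool`: with `FPR=0.01` it returns a high cutoff (about 0.988 here) so that few negatives are accepted, and with `FNR=0.01` a lower one (about 0.472) so that few positives are rejected.\n",
"\n",
"The FPR case can be sketched from raw probabilities and labels: scan candidate cutoffs in increasing order and return the first whose false positive rate is within budget, since a lower cutoff rejects fewer positives. `threshold_for_fpr` is a hypothetical helper that assumes a sample is predicted positive when its probability is strictly above the cutoff; CatBoost's own tie handling and interpolation may differ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def threshold_for_fpr(probs, labels, max_fpr):\n",
"    # Assumed rule: predict the positive class when prob > threshold.\n",
"    probs = np.asarray(probs, dtype=float)\n",
"    labels = np.asarray(labels)\n",
"    negatives = probs[labels == 0]\n",
"    for t in np.unique(probs):  # np.unique returns candidates in ascending order\n",
"        fpr = np.mean(negatives > t)\n",
"        if fpr <= max_fpr:\n",
"            # The lowest cutoff within budget keeps the most true positives.\n",
"            return float(t)\n",
"    return 1.0  # fallback: empty input or no negative samples\n",
"\n",
"probs = [0.1, 0.2, 0.3, 0.4, 0.8, 0.9]\n",
"labels = [0, 0, 0, 0, 1, 1]\n",
"print(threshold_for_fpr(probs, labels, max_fpr=0.25))  # 0.3"
]
},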
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Snapshotting"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"#!rm 'catboost_info/snapshot.bkp'\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0:\tlearn: 0.3019732\ttest: 0.3022828\tbest: 0.3022828 (0)\ttotal: 67ms\tremaining: 1.27s\n",
"1:\tlearn: 0.2238474\ttest: 0.2233617\tbest: 0.2233617 (1)\ttotal: 148ms\tremaining: 1.33s\n",
"2:\tlearn: 0.1904507\ttest: 0.1818139\tbest: 0.1818139 (2)\ttotal: 252ms\tremaining: 1.43s\n",
"3:\tlearn: 0.1802458\ttest: 0.1659696\tbest: 0.1659696 (3)\ttotal: 363ms\tremaining: 1.45s\n",
"4:\tlearn: 0.1754740\ttest: 0.1587777\tbest: 0.1587777 (4)\ttotal: 485ms\tremaining: 1.46s\n",
"5:\tlearn: 0.1703625\ttest: 0.1514527\tbest: 0.1514527 (5)\ttotal: 566ms\tremaining: 1.32s\n",
"6:\tlearn: 0.1684087\ttest: 0.1485265\tbest: 0.1485265 (6)\ttotal: 654ms\tremaining: 1.22s\n",
"7:\tlearn: 0.1672054\ttest: 0.1477918\tbest: 0.1477918 (7)\ttotal: 741ms\tremaining: 1.11s\n",
"8:\tlearn: 0.1662518\ttest: 0.1470028\tbest: 0.1470028 (8)\ttotal: 842ms\tremaining: 1.03s\n",
"9:\tlearn: 0.1656720\ttest: 0.1469340\tbest: 0.1469340 (9)\ttotal: 929ms\tremaining: 929ms\n",
"10:\tlearn: 0.1646958\ttest: 0.1458786\tbest: 0.1458786 (10)\ttotal: 1.02s\tremaining: 832ms\n",
"11:\tlearn: 0.1639292\ttest: 0.1456439\tbest: 0.1456439 (11)\ttotal: 1.1s\tremaining: 737ms\n",
"12:\tlearn: 0.1631213\ttest: 0.1453225\tbest: 0.1453225 (12)\ttotal: 1.18s\tremaining: 638ms\n",
"13:\tlearn: 0.1628037\ttest: 0.1449395\tbest: 0.1449395 (13)\ttotal: 1.25s\tremaining: 538ms\n",
"14:\tlearn: 0.1626140\ttest: 0.1450753\tbest: 0.1449395 (13)\ttotal: 1.32s\tremaining: 440ms\n",
"15:\tlearn: 0.1614957\ttest: 0.1448378\tbest: 0.1448378 (15)\ttotal: 1.42s\tremaining: 356ms\n",
"16:\tlearn: 0.1614173\ttest: 0.1448587\tbest: 0.1448378 (15)\ttotal: 1.51s\tremaining: 267ms\n",
"17:\tlearn: 0.1614154\ttest: 0.1448968\tbest: 0.1448378 (15)\ttotal: 1.61s\tremaining: 179ms\n",
"18:\tlearn: 0.1613764\ttest: 0.1448062\tbest: 0.1448062 (18)\ttotal: 1.65s\tremaining: 87ms\n",
"19:\tlearn: 0.1612640\ttest: 0.1445666\tbest: 0.1445666 (19)\ttotal: 1.72s\tremaining: 0us\n",
"\n",
"bestTest = 0.1445666012\n",
"bestIteration = 19\n",
"\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#!rm 'catboost_info/snapshot.bkp'\n",
"from catboost import CatBoostClassifier\n",
"model = CatBoostClassifier(\n",
" iterations=20,\n",
" save_snapshot=True,\n",
" snapshot_file='snapshot.bkp',\n",
" snapshot_interval=1,\n",
" random_seed=43\n",
")\n",
"model.fit(\n",
" X_train, y_train,\n",
" eval_set=(X_validation, y_validation),\n",
" cat_features=cat_features,\n",
" verbose=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "902ccb0a2fef4f25b41b4697f325192a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"20:\tlearn: 0.1607445\ttest: 0.1440626\tbest: 0.1440626 (20)\ttotal: 1.82s\tremaining: 17.1s\n",
"21:\tlearn: 0.1606770\ttest: 0.1439965\tbest: 0.1439965 (21)\ttotal: 1.9s\tremaining: 15.8s\n",
"22:\tlearn: 0.1606616\ttest: 0.1440083\tbest: 0.1439965 (21)\ttotal: 1.97s\tremaining: 14.8s\n",
"23:\tlearn: 0.1606540\ttest: 0.1440366\tbest: 0.1439965 (21)\ttotal: 2.04s\tremaining: 14s\n",
"24:\tlearn: 0.1606450\ttest: 0.1440838\tbest: 0.1439965 (21)\ttotal: 2.11s\tremaining: 13.6s\n",
"25:\tlearn: 0.1604540\ttest: 0.1439186\tbest: 0.1439186 (25)\ttotal: 2.22s\tremaining: 14.3s\n",
"26:\tlearn: 0.1600134\ttest: 0.1436727\tbest: 0.1436727 (26)\ttotal: 2.33s\tremaining: 15s\n",
"27:\tlearn: 0.1600132\ttest: 0.1436712\tbest: 0.1436712 (27)\ttotal: 2.4s\tremaining: 14.5s\n",
"28:\tlearn: 0.1597968\ttest: 0.1437510\tbest: 0.1436712 (27)\ttotal: 2.49s\tremaining: 14.6s\n",
"29:\tlearn: 0.1597306\ttest: 0.1437718\tbest: 0.1436712 (27)\ttotal: 2.62s\tremaining: 15.3s\n",
"30:\tlearn: 0.1595170\ttest: 0.1437298\tbest: 0.1436712 (27)\ttotal: 2.77s\tremaining: 16.1s\n",
"31:\tlearn: 0.1593775\ttest: 0.1439295\tbest: 0.1436712 (27)\ttotal: 2.88s\tremaining: 16.2s\n",
"32:\tlearn: 0.1589895\ttest: 0.1438260\tbest: 0.1436712 (27)\ttotal: 2.99s\tremaining: 16.2s\n",
"33:\tlearn: 0.1589845\ttest: 0.1438430\tbest: 0.1436712 (27)\ttotal: 3.04s\tremaining: 15.6s\n",
"34:\tlearn: 0.1589319\ttest: 0.1438151\tbest: 0.1436712 (27)\ttotal: 3.12s\tremaining: 15.4s\n",
"35:\tlearn: 0.1588918\ttest: 0.1438180\tbest: 0.1436712 (27)\ttotal: 3.19s\tremaining: 15s\n",
"36:\tlearn: 0.1588731\ttest: 0.1438125\tbest: 0.1436712 (27)\ttotal: 3.24s\tremaining: 14.6s\n",
"37:\tlearn: 0.1588159\ttest: 0.1438478\tbest: 0.1436712 (27)\ttotal: 3.32s\tremaining: 14.4s\n",
"38:\tlearn: 0.1588010\ttest: 0.1438314\tbest: 0.1436712 (27)\ttotal: 3.4s\tremaining: 14.2s\n",
"39:\tlearn: 0.1586945\ttest: 0.1436908\tbest: 0.1436712 (27)\ttotal: 3.49s\tremaining: 14.1s\n",
"40:\tlearn: 0.1586894\ttest: 0.1436654\tbest: 0.1436654 (40)\ttotal: 3.56s\tremaining: 13.9s\n",
"41:\tlearn: 0.1582296\ttest: 0.1432243\tbest: 0.1432243 (41)\ttotal: 3.66s\tremaining: 13.9s\n",
"42:\tlearn: 0.1582208\ttest: 0.1432282\tbest: 0.1432243 (41)\ttotal: 3.71s\tremaining: 13.6s\n",
"43:\tlearn: 0.1582196\ttest: 0.1432159\tbest: 0.1432159 (43)\ttotal: 3.78s\tremaining: 13.4s\n",
"44:\tlearn: 0.1579211\ttest: 0.1430249\tbest: 0.1430249 (44)\ttotal: 3.86s\tremaining: 13.3s\n",
"45:\tlearn: 0.1578847\ttest: 0.1431223\tbest: 0.1430249 (44)\ttotal: 3.96s\tremaining: 13.2s\n",
"46:\tlearn: 0.1578482\ttest: 0.1430689\tbest: 0.1430249 (44)\ttotal: 4.05s\tremaining: 13.2s\n",
"47:\tlearn: 0.1578470\ttest: 0.1430681\tbest: 0.1430249 (44)\ttotal: 4.11s\tremaining: 13s\n",
"48:\tlearn: 0.1578462\ttest: 0.1430580\tbest: 0.1430249 (44)\ttotal: 4.17s\tremaining: 12.7s\n",
"49:\tlearn: 0.1577480\ttest: 0.1432538\tbest: 0.1430249 (44)\ttotal: 4.25s\tremaining: 12.6s\n",
"50:\tlearn: 0.1575513\ttest: 0.1433252\tbest: 0.1430249 (44)\ttotal: 4.34s\tremaining: 12.6s\n",
"51:\tlearn: 0.1571386\ttest: 0.1428501\tbest: 0.1428501 (51)\ttotal: 4.43s\tremaining: 12.5s\n",
"52:\tlearn: 0.1571009\ttest: 0.1428871\tbest: 0.1428501 (51)\ttotal: 4.52s\tremaining: 12.5s\n",
"53:\tlearn: 0.1570543\ttest: 0.1428245\tbest: 0.1428245 (53)\ttotal: 4.6s\tremaining: 12.4s\n",
"54:\tlearn: 0.1569937\ttest: 0.1428190\tbest: 0.1428190 (54)\ttotal: 4.67s\tremaining: 12.2s\n",
"55:\tlearn: 0.1568604\ttest: 0.1428687\tbest: 0.1428190 (54)\ttotal: 4.79s\tremaining: 12.3s\n",
"56:\tlearn: 0.1568583\ttest: 0.1428933\tbest: 0.1428190 (54)\ttotal: 4.84s\tremaining: 12s\n",
"57:\tlearn: 0.1566636\ttest: 0.1425799\tbest: 0.1425799 (57)\ttotal: 4.92s\tremaining: 11.9s\n",
"58:\tlearn: 0.1566105\ttest: 0.1425226\tbest: 0.1425226 (58)\ttotal: 5s\tremaining: 11.9s\n",
"59:\tlearn: 0.1565786\ttest: 0.1425208\tbest: 0.1425208 (59)\ttotal: 5.1s\tremaining: 11.8s\n",
"60:\tlearn: 0.1565691\ttest: 0.1425103\tbest: 0.1425103 (60)\ttotal: 5.2s\tremaining: 11.8s\n",
"61:\tlearn: 0.1564572\ttest: 0.1425675\tbest: 0.1425103 (60)\ttotal: 5.3s\tremaining: 11.8s\n",
"62:\tlearn: 0.1564145\ttest: 0.1426340\tbest: 0.1425103 (60)\ttotal: 5.42s\tremaining: 11.8s\n",
"63:\tlearn: 0.1555005\ttest: 0.1421151\tbest: 0.1421151 (63)\ttotal: 5.52s\tremaining: 11.7s\n",
"64:\tlearn: 0.1548624\ttest: 0.1414761\tbest: 0.1414761 (64)\ttotal: 5.63s\tremaining: 11.7s\n",
"65:\tlearn: 0.1548316\ttest: 0.1415012\tbest: 0.1414761 (64)\ttotal: 5.73s\tremaining: 11.7s\n",
"66:\tlearn: 0.1546954\ttest: 0.1414039\tbest: 0.1414039 (66)\ttotal: 5.83s\tremaining: 11.6s\n",
"67:\tlearn: 0.1545918\ttest: 0.1412461\tbest: 0.1412461 (67)\ttotal: 5.94s\tremaining: 11.6s\n",
"68:\tlearn: 0.1545238\ttest: 0.1412214\tbest: 0.1412214 (68)\ttotal: 6.05s\tremaining: 11.6s\n",
"69:\tlearn: 0.1543865\ttest: 0.1412259\tbest: 0.1412214 (68)\ttotal: 6.17s\tremaining: 11.5s\n",
"70:\tlearn: 0.1541683\ttest: 0.1414024\tbest: 0.1412214 (68)\ttotal: 6.27s\tremaining: 11.5s\n",
"71:\tlearn: 0.1540775\ttest: 0.1414143\tbest: 0.1412214 (68)\ttotal: 6.37s\tremaining: 11.4s\n",
"72:\tlearn: 0.1540220\ttest: 0.1413343\tbest: 0.1412214 (68)\ttotal: 6.45s\tremaining: 11.3s\n",
"73:\tlearn: 0.1539815\ttest: 0.1413039\tbest: 0.1412214 (68)\ttotal: 6.52s\tremaining: 11.2s\n",
"74:\tlearn: 0.1539774\ttest: 0.1413123\tbest: 0.1412214 (68)\ttotal: 6.62s\tremaining: 11.1s\n",
"75:\tlearn: 0.1539583\ttest: 0.1413918\tbest: 0.1412214 (68)\ttotal: 6.71s\tremaining: 11s\n",
"76:\tlearn: 0.1538207\ttest: 0.1412925\tbest: 0.1412214 (68)\ttotal: 6.8s\tremaining: 10.9s\n",
"77:\tlearn: 0.1537301\ttest: 0.1412290\tbest: 0.1412214 (68)\ttotal: 6.89s\tremaining: 10.9s\n",
"78:\tlearn: 0.1536745\ttest: 0.1411666\tbest: 0.1411666 (78)\ttotal: 6.99s\tremaining: 10.8s\n",
"79:\tlearn: 0.1536233\ttest: 0.1412135\tbest: 0.1411666 (78)\ttotal: 7.07s\tremaining: 10.7s\n",
"80:\tlearn: 0.1533953\ttest: 0.1413119\tbest: 0.1411666 (78)\ttotal: 7.18s\tremaining: 10.6s\n",
"81:\tlearn: 0.1533845\ttest: 0.1412750\tbest: 0.1411666 (78)\ttotal: 7.27s\tremaining: 10.6s\n",
"82:\tlearn: 0.1532740\ttest: 0.1413232\tbest: 0.1411666 (78)\ttotal: 7.35s\tremaining: 10.4s\n",
"83:\tlearn: 0.1531701\ttest: 0.1416607\tbest: 0.1411666 (78)\ttotal: 7.45s\tremaining: 10.4s\n",
"84:\tlearn: 0.1530255\ttest: 0.1416930\tbest: 0.1411666 (78)\ttotal: 7.53s\tremaining: 10.3s\n",
"85:\tlearn: 0.1530003\ttest: 0.1416376\tbest: 0.1411666 (78)\ttotal: 7.6s\tremaining: 10.2s\n",
"86:\tlearn: 0.1529827\ttest: 0.1415299\tbest: 0.1411666 (78)\ttotal: 7.7s\tremaining: 10.1s\n",
"87:\tlearn: 0.1529783\ttest: 0.1415534\tbest: 0.1411666 (78)\ttotal: 7.78s\tremaining: 9.98s\n",
"88:\tlearn: 0.1529650\ttest: 0.1415712\tbest: 0.1411666 (78)\ttotal: 7.86s\tremaining: 9.87s\n",
"89:\tlearn: 0.1529603\ttest: 0.1415611\tbest: 0.1411666 (78)\ttotal: 7.95s\tremaining: 9.79s\n",
"90:\tlearn: 0.1526304\ttest: 0.1415594\tbest: 0.1411666 (78)\ttotal: 8.03s\tremaining: 9.69s\n",
"91:\tlearn: 0.1525361\ttest: 0.1417001\tbest: 0.1411666 (78)\ttotal: 8.14s\tremaining: 9.62s\n",
"92:\tlearn: 0.1525222\ttest: 0.1416753\tbest: 0.1411666 (78)\ttotal: 8.25s\tremaining: 9.56s\n",
"93:\tlearn: 0.1523982\ttest: 0.1417682\tbest: 0.1411666 (78)\ttotal: 8.34s\tremaining: 9.48s\n",
"94:\tlearn: 0.1520970\ttest: 0.1417407\tbest: 0.1411666 (78)\ttotal: 8.44s\tremaining: 9.4s\n",
"95:\tlearn: 0.1520886\ttest: 0.1416787\tbest: 0.1411666 (78)\ttotal: 8.53s\tremaining: 9.31s\n",
"96:\tlearn: 0.1519551\ttest: 0.1416528\tbest: 0.1411666 (78)\ttotal: 8.62s\tremaining: 9.22s\n",
"97:\tlearn: 0.1518608\ttest: 0.1417604\tbest: 0.1411666 (78)\ttotal: 8.72s\tremaining: 9.15s\n",
"98:\tlearn: 0.1515918\ttest: 0.1418670\tbest: 0.1411666 (78)\ttotal: 8.81s\tremaining: 9.06s\n",
"99:\tlearn: 0.1514836\ttest: 0.1421103\tbest: 0.1411666 (78)\ttotal: 8.88s\tremaining: 8.95s\n",
"100:\tlearn: 0.1513803\ttest: 0.1424265\tbest: 0.1411666 (78)\ttotal: 8.94s\tremaining: 8.82s\n",
"101:\tlearn: 0.1513654\ttest: 0.1424137\tbest: 0.1411666 (78)\ttotal: 9s\tremaining: 8.7s\n",
"102:\tlearn: 0.1513427\ttest: 0.1423719\tbest: 0.1411666 (78)\ttotal: 9.06s\tremaining: 8.57s\n",
"103:\tlearn: 0.1511778\ttest: 0.1422654\tbest: 0.1411666 (78)\ttotal: 9.14s\tremaining: 8.48s\n",
"104:\tlearn: 0.1511117\ttest: 0.1423467\tbest: 0.1411666 (78)\ttotal: 9.24s\tremaining: 8.41s\n",
"105:\tlearn: 0.1510436\ttest: 0.1423221\tbest: 0.1411666 (78)\ttotal: 9.35s\tremaining: 8.33s\n",
"106:\tlearn: 0.1508977\ttest: 0.1423137\tbest: 0.1411666 (78)\ttotal: 9.43s\tremaining: 8.24s\n",
"107:\tlearn: 0.1508959\ttest: 0.1423254\tbest: 0.1411666 (78)\ttotal: 9.5s\tremaining: 8.13s\n",
"108:\tlearn: 0.1508424\ttest: 0.1424438\tbest: 0.1411666 (78)\ttotal: 9.59s\tremaining: 8.04s\n",
"109:\tlearn: 0.1508101\ttest: 0.1424697\tbest: 0.1411666 (78)\ttotal: 9.67s\tremaining: 7.95s\n",
"110:\tlearn: 0.1507799\ttest: 0.1424306\tbest: 0.1411666 (78)\ttotal: 9.75s\tremaining: 7.85s\n",
"111:\tlearn: 0.1507635\ttest: 0.1424197\tbest: 0.1411666 (78)\ttotal: 9.84s\tremaining: 7.76s\n",
"112:\tlearn: 0.1507610\ttest: 0.1423976\tbest: 0.1411666 (78)\ttotal: 9.92s\tremaining: 7.67s\n",
"113:\tlearn: 0.1507207\ttest: 0.1423904\tbest: 0.1411666 (78)\ttotal: 10s\tremaining: 7.58s\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"114:\tlearn: 0.1506871\ttest: 0.1423714\tbest: 0.1411666 (78)\ttotal: 10.1s\tremaining: 7.52s\n",
"115:\tlearn: 0.1506365\ttest: 0.1423830\tbest: 0.1411666 (78)\ttotal: 10.2s\tremaining: 7.43s\n",
"116:\tlearn: 0.1505736\ttest: 0.1425609\tbest: 0.1411666 (78)\ttotal: 10.3s\tremaining: 7.35s\n",
"117:\tlearn: 0.1505321\ttest: 0.1425929\tbest: 0.1411666 (78)\ttotal: 10.4s\tremaining: 7.29s\n",
"118:\tlearn: 0.1505028\ttest: 0.1425452\tbest: 0.1411666 (78)\ttotal: 10.5s\tremaining: 7.21s\n",
"119:\tlearn: 0.1497712\ttest: 0.1424179\tbest: 0.1411666 (78)\ttotal: 10.7s\tremaining: 7.17s\n",
"120:\tlearn: 0.1497006\ttest: 0.1423512\tbest: 0.1411666 (78)\ttotal: 10.8s\tremaining: 7.11s\n",
"121:\tlearn: 0.1496638\ttest: 0.1423660\tbest: 0.1411666 (78)\ttotal: 11s\tremaining: 7.1s\n",
"122:\tlearn: 0.1496092\ttest: 0.1423799\tbest: 0.1411666 (78)\ttotal: 11.1s\tremaining: 7.04s\n",
"123:\tlearn: 0.1495761\ttest: 0.1422939\tbest: 0.1411666 (78)\ttotal: 11.2s\tremaining: 6.95s\n",
"124:\tlearn: 0.1495648\ttest: 0.1423047\tbest: 0.1411666 (78)\ttotal: 11.3s\tremaining: 6.86s\n",
"125:\tlearn: 0.1493551\ttest: 0.1421014\tbest: 0.1411666 (78)\ttotal: 11.5s\tremaining: 6.8s\n",
"126:\tlearn: 0.1492764\ttest: 0.1420590\tbest: 0.1411666 (78)\ttotal: 11.6s\tremaining: 6.72s\n",
"127:\tlearn: 0.1490840\ttest: 0.1417298\tbest: 0.1411666 (78)\ttotal: 11.7s\tremaining: 6.65s\n",
"128:\tlearn: 0.1490683\ttest: 0.1417569\tbest: 0.1411666 (78)\ttotal: 11.8s\tremaining: 6.58s\n",
"129:\tlearn: 0.1490611\ttest: 0.1417673\tbest: 0.1411666 (78)\ttotal: 12s\tremaining: 6.53s\n",
"130:\tlearn: 0.1490453\ttest: 0.1417709\tbest: 0.1411666 (78)\ttotal: 12.1s\tremaining: 6.45s\n",
"131:\tlearn: 0.1490231\ttest: 0.1417575\tbest: 0.1411666 (78)\ttotal: 12.3s\tremaining: 6.39s\n",
"132:\tlearn: 0.1489009\ttest: 0.1419888\tbest: 0.1411666 (78)\ttotal: 12.3s\tremaining: 6.29s\n",
"133:\tlearn: 0.1486744\ttest: 0.1420246\tbest: 0.1411666 (78)\ttotal: 12.4s\tremaining: 6.2s\n",
"134:\tlearn: 0.1485575\ttest: 0.1421380\tbest: 0.1411666 (78)\ttotal: 12.5s\tremaining: 6.11s\n",
"135:\tlearn: 0.1485306\ttest: 0.1421969\tbest: 0.1411666 (78)\ttotal: 12.6s\tremaining: 6.01s\n",
"136:\tlearn: 0.1483855\ttest: 0.1420489\tbest: 0.1411666 (78)\ttotal: 12.7s\tremaining: 5.91s\n",
"137:\tlearn: 0.1483428\ttest: 0.1421026\tbest: 0.1411666 (78)\ttotal: 12.8s\tremaining: 5.83s\n",
"138:\tlearn: 0.1483134\ttest: 0.1420898\tbest: 0.1411666 (78)\ttotal: 12.9s\tremaining: 5.72s\n",
"139:\tlearn: 0.1483052\ttest: 0.1420878\tbest: 0.1411666 (78)\ttotal: 13s\tremaining: 5.63s\n",
"140:\tlearn: 0.1482157\ttest: 0.1420241\tbest: 0.1411666 (78)\ttotal: 13.1s\tremaining: 5.53s\n",
"141:\tlearn: 0.1481112\ttest: 0.1421268\tbest: 0.1411666 (78)\ttotal: 13.2s\tremaining: 5.45s\n",
"142:\tlearn: 0.1479973\ttest: 0.1422019\tbest: 0.1411666 (78)\ttotal: 13.3s\tremaining: 5.36s\n",
"143:\tlearn: 0.1479948\ttest: 0.1422188\tbest: 0.1411666 (78)\ttotal: 13.4s\tremaining: 5.28s\n",
"144:\tlearn: 0.1479705\ttest: 0.1421976\tbest: 0.1411666 (78)\ttotal: 13.6s\tremaining: 5.22s\n",
"145:\tlearn: 0.1479037\ttest: 0.1423085\tbest: 0.1411666 (78)\ttotal: 13.8s\tremaining: 5.16s\n",
"146:\tlearn: 0.1478758\ttest: 0.1424303\tbest: 0.1411666 (78)\ttotal: 13.9s\tremaining: 5.08s\n",
"147:\tlearn: 0.1478342\ttest: 0.1424301\tbest: 0.1411666 (78)\ttotal: 14s\tremaining: 4.98s\n",
"148:\tlearn: 0.1478086\ttest: 0.1423590\tbest: 0.1411666 (78)\ttotal: 14.1s\tremaining: 4.89s\n",
"149:\tlearn: 0.1477729\ttest: 0.1422865\tbest: 0.1411666 (78)\ttotal: 14.2s\tremaining: 4.8s\n",
"150:\tlearn: 0.1474847\ttest: 0.1422271\tbest: 0.1411666 (78)\ttotal: 14.3s\tremaining: 4.71s\n",
"151:\tlearn: 0.1474169\ttest: 0.1423728\tbest: 0.1411666 (78)\ttotal: 14.4s\tremaining: 4.63s\n",
"152:\tlearn: 0.1472865\ttest: 0.1422631\tbest: 0.1411666 (78)\ttotal: 14.6s\tremaining: 4.54s\n",
"153:\tlearn: 0.1472423\ttest: 0.1422256\tbest: 0.1411666 (78)\ttotal: 14.7s\tremaining: 4.46s\n",
"154:\tlearn: 0.1471948\ttest: 0.1421481\tbest: 0.1411666 (78)\ttotal: 14.8s\tremaining: 4.37s\n",
"155:\tlearn: 0.1468854\ttest: 0.1419873\tbest: 0.1411666 (78)\ttotal: 14.9s\tremaining: 4.28s\n",
"156:\tlearn: 0.1468790\ttest: 0.1420101\tbest: 0.1411666 (78)\ttotal: 15s\tremaining: 4.18s\n",
"157:\tlearn: 0.1468687\ttest: 0.1419619\tbest: 0.1411666 (78)\ttotal: 15.1s\tremaining: 4.09s\n",
"158:\tlearn: 0.1468658\ttest: 0.1419551\tbest: 0.1411666 (78)\ttotal: 15.2s\tremaining: 3.99s\n",
"159:\tlearn: 0.1468288\ttest: 0.1420261\tbest: 0.1411666 (78)\ttotal: 15.3s\tremaining: 3.89s\n",
"160:\tlearn: 0.1467160\ttest: 0.1419321\tbest: 0.1411666 (78)\ttotal: 15.4s\tremaining: 3.8s\n",
"161:\tlearn: 0.1467084\ttest: 0.1419409\tbest: 0.1411666 (78)\ttotal: 15.5s\tremaining: 3.7s\n",
"162:\tlearn: 0.1464073\ttest: 0.1416654\tbest: 0.1411666 (78)\ttotal: 15.7s\tremaining: 3.61s\n",
"163:\tlearn: 0.1462285\ttest: 0.1416156\tbest: 0.1411666 (78)\ttotal: 15.8s\tremaining: 3.52s\n",
"164:\tlearn: 0.1460339\ttest: 0.1416086\tbest: 0.1411666 (78)\ttotal: 15.9s\tremaining: 3.42s\n",
"165:\tlearn: 0.1460218\ttest: 0.1416287\tbest: 0.1411666 (78)\ttotal: 16s\tremaining: 3.33s\n",
"166:\tlearn: 0.1460092\ttest: 0.1416437\tbest: 0.1411666 (78)\ttotal: 16.1s\tremaining: 3.23s\n",
"167:\tlearn: 0.1456841\ttest: 0.1418764\tbest: 0.1411666 (78)\ttotal: 16.2s\tremaining: 3.13s\n",
"168:\tlearn: 0.1456011\ttest: 0.1419543\tbest: 0.1411666 (78)\ttotal: 16.3s\tremaining: 3.04s\n",
"169:\tlearn: 0.1455905\ttest: 0.1420247\tbest: 0.1411666 (78)\ttotal: 16.4s\tremaining: 2.94s\n",
"170:\tlearn: 0.1455581\ttest: 0.1420612\tbest: 0.1411666 (78)\ttotal: 16.5s\tremaining: 2.85s\n",
"171:\tlearn: 0.1455441\ttest: 0.1420924\tbest: 0.1411666 (78)\ttotal: 16.6s\tremaining: 2.75s\n",
"172:\tlearn: 0.1454990\ttest: 0.1421376\tbest: 0.1411666 (78)\ttotal: 16.7s\tremaining: 2.65s\n",
"173:\tlearn: 0.1453813\ttest: 0.1420980\tbest: 0.1411666 (78)\ttotal: 16.9s\tremaining: 2.56s\n",
"174:\tlearn: 0.1451034\ttest: 0.1420903\tbest: 0.1411666 (78)\ttotal: 17s\tremaining: 2.46s\n",
"175:\tlearn: 0.1450984\ttest: 0.1421000\tbest: 0.1411666 (78)\ttotal: 17s\tremaining: 2.36s\n",
"176:\tlearn: 0.1449863\ttest: 0.1421289\tbest: 0.1411666 (78)\ttotal: 17.1s\tremaining: 2.26s\n",
"177:\tlearn: 0.1447876\ttest: 0.1420145\tbest: 0.1411666 (78)\ttotal: 17.2s\tremaining: 2.16s\n",
"178:\tlearn: 0.1447701\ttest: 0.1420666\tbest: 0.1411666 (78)\ttotal: 17.3s\tremaining: 2.06s\n",
"179:\tlearn: 0.1446955\ttest: 0.1420109\tbest: 0.1411666 (78)\ttotal: 17.4s\tremaining: 1.96s\n",
"180:\tlearn: 0.1446544\ttest: 0.1420329\tbest: 0.1411666 (78)\ttotal: 17.5s\tremaining: 1.86s\n",
"181:\tlearn: 0.1445431\ttest: 0.1421508\tbest: 0.1411666 (78)\ttotal: 17.6s\tremaining: 1.76s\n",
"182:\tlearn: 0.1442690\ttest: 0.1421298\tbest: 0.1411666 (78)\ttotal: 17.7s\tremaining: 1.67s\n",
"183:\tlearn: 0.1441832\ttest: 0.1421534\tbest: 0.1411666 (78)\ttotal: 17.8s\tremaining: 1.57s\n",
"184:\tlearn: 0.1440706\ttest: 0.1423811\tbest: 0.1411666 (78)\ttotal: 17.9s\tremaining: 1.47s\n",
"185:\tlearn: 0.1439856\ttest: 0.1425963\tbest: 0.1411666 (78)\ttotal: 18s\tremaining: 1.37s\n",
"186:\tlearn: 0.1439761\ttest: 0.1426506\tbest: 0.1411666 (78)\ttotal: 18.1s\tremaining: 1.28s\n",
"187:\tlearn: 0.1439714\ttest: 0.1426176\tbest: 0.1411666 (78)\ttotal: 18.2s\tremaining: 1.18s\n",
"188:\tlearn: 0.1439492\ttest: 0.1426109\tbest: 0.1411666 (78)\ttotal: 18.4s\tremaining: 1.08s\n",
"189:\tlearn: 0.1439454\ttest: 0.1426164\tbest: 0.1411666 (78)\ttotal: 18.4s\tremaining: 984ms\n",
"190:\tlearn: 0.1439413\ttest: 0.1426552\tbest: 0.1411666 (78)\ttotal: 18.5s\tremaining: 885ms\n",
"191:\tlearn: 0.1438359\ttest: 0.1427087\tbest: 0.1411666 (78)\ttotal: 18.6s\tremaining: 787ms\n",
"192:\tlearn: 0.1438287\ttest: 0.1427033\tbest: 0.1411666 (78)\ttotal: 18.7s\tremaining: 689ms\n",
"193:\tlearn: 0.1437879\ttest: 0.1426357\tbest: 0.1411666 (78)\ttotal: 18.8s\tremaining: 590ms\n",
"194:\tlearn: 0.1437860\ttest: 0.1426468\tbest: 0.1411666 (78)\ttotal: 18.9s\tremaining: 492ms\n",
"195:\tlearn: 0.1437606\ttest: 0.1426516\tbest: 0.1411666 (78)\ttotal: 19s\tremaining: 393ms\n",
"196:\tlearn: 0.1437591\ttest: 0.1426400\tbest: 0.1411666 (78)\ttotal: 19.1s\tremaining: 294ms\n",
"197:\tlearn: 0.1437577\ttest: 0.1426297\tbest: 0.1411666 (78)\ttotal: 19.2s\tremaining: 196ms\n",
"198:\tlearn: 0.1437447\ttest: 0.1426317\tbest: 0.1411666 (78)\ttotal: 19.3s\tremaining: 98ms\n",
"199:\tlearn: 0.1437230\ttest: 0.1426660\tbest: 0.1411666 (78)\ttotal: 19.4s\tremaining: 0us\n",
"\n",
"bestTest = 0.1411666436\n",
"bestIteration = 78\n",
"\n",
"Shrink model to first 79 iterations.\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = CatBoostClassifier(\n",
" iterations=200,\n",
" save_snapshot=True,\n",
" snapshot_file='snapshot.bkp',\n",
" snapshot_interval=1,\n",
" random_seed=43\n",
")\n",
"model.fit(\n",
" X_train, y_train,\n",
" eval_set=(X_validation, y_validation),\n",
" cat_features=cat_features,\n",
" verbose=True,\n",
" plot=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model predictions"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"? model.predict_proba"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.0688 0.9312]\n",
" [0.008 0.992 ]\n",
" [0.0068 0.9932]\n",
" ...\n",
" [0.0129 0.9871]\n",
" [0.0184 0.9816]\n",
" [0.0273 0.9727]]\n"
]
}
],
"source": [
"print(model.predict_proba(data=X_validation))"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1. 1. 1. ... 1. 1. 1.]\n"
]
}
],
"source": [
"print(model.predict(data=X_validation))"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2.6056 4.8177 4.9809 ... 4.3367 3.9791 3.5714]\n"
]
}
],
"source": [
"raw_pred = model.predict(data=X_validation, prediction_type='RawFormulaVal')\n",
"print(raw_pred)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.9312 0.992 0.9932 ... 0.9871 0.9816 0.9727]\n"
]
}
],
"source": [
"import math\n",
"def sigmoid(x):\n",
" return 1 / (1 + math.exp(-x))\n",
"probabilities = [sigmoid(x) for x in raw_pred]\n",
"print(np.array(probabilities))"
]
},
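{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above applies `math.exp` to one value at a time. A vectorized NumPy sketch gives the same class-1 probabilities. It assumes the default `Logloss` objective, so `RawFormulaVal` is a log-odds score. The helper `raw_to_proba` is a hypothetical name, and the three raw values are copied from the output above. The result should match the first rows printed by `predict_proba`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def raw_to_proba(raw):\n",
"    # Element-wise sigmoid: maps raw log-odds to P(class = 1)\n",
"    raw = np.asarray(raw, dtype=float)\n",
"    p1 = 1.0 / (1.0 + np.exp(-raw))\n",
"    # Stack [P(class = 0), P(class = 1)] to mirror the two columns of predict_proba\n",
"    return np.column_stack([1.0 - p1, p1])\n",
"\n",
"raw_pred_sample = np.array([2.6056, 4.8177, 4.9809])  # first raw values printed above\n",
"proba = raw_to_proba(raw_pred_sample)\n",
"print(np.round(proba, 4))"
]
},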
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.0688 0.9312]\n",
" [0.008 0.992 ]\n",
" [0.0068 0.9932]\n",
" ...\n",
" [0.0129 0.9871]\n",
" [0.0184 0.9816]\n",
" [0.0273 0.9727]]\n"
]
}
],
"source": [
"X_prepared = X_validation.values.astype(str).astype(object)\n",
"# For FeaturesData class categorial features must have type str\n",
"\n",
"fast_predictions = model.predict_proba(data=FeaturesData(cat_feature_data=X_prepared, \n",
" cat_feature_names=list(X_validation)))\n",
"\n",
"print(fast_predictions)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Staged prediction"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration 0, predictions:\n",
"[[0.3065 0.6935]\n",
" [0.2317 0.7683]\n",
" [0.2317 0.7683]\n",
" ...\n",
" [0.2317 0.7683]\n",
" [0.2317 0.7683]\n",
" [0.2746 0.7254]]\n",
"Iteration 1, predictions:\n",
"[[0.3002 0.6998]\n",
" [0.1762 0.8238]\n",
" [0.1422 0.8578]\n",
" ...\n",
" [0.1422 0.8578]\n",
" [0.1916 0.8084]\n",
" [0.2116 0.7884]]\n",
"Iteration 2, predictions:\n",
"[[0.3307 0.6693]\n",
" [0.13 0.87 ]\n",
" [0.1038 0.8962]\n",
" ...\n",
" [0.1454 0.8546]\n",
" [0.1956 0.8044]\n",
" [0.1938 0.8062]]\n"
]
}
],
"source": [
"predictions_gen = model.staged_predict_proba(data=X_validation, ntree_start=2, ntree_end=8, eval_period=2)\n",
"for iteration, predictions in enumerate(predictions_gen):\n",
" print('Iteration ' + str(iteration) + ', predictions:')\n",
" print(predictions)"
]
},
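{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, the raw prediction of a boosted model is a sum of per-tree contributions. Staged prediction reports the partial sum, converted to a probability, every `eval_period` trees. The sketch below illustrates only that idea. `staged_proba` is a hypothetical helper and the per-tree values are made up. It starts from the first tree rather than modelling `ntree_start`, and it is not how CatBoost computes predictions internally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def staged_proba(tree_contributions, eval_period):\n",
"    # tree_contributions[i]: raw (log-odds) value that tree i adds for one object\n",
"    contributions = np.asarray(tree_contributions, dtype=float)\n",
"    staged = []\n",
"    for end in range(eval_period, len(contributions) + 1, eval_period):\n",
"        raw = contributions[:end].sum()            # partial ensemble of the first `end` trees\n",
"        staged.append(1.0 / (1.0 + np.exp(-raw)))  # convert to P(class = 1)\n",
"    return staged\n",
"\n",
"example_contributions = [0.4, 0.3, 0.2, 0.2, 0.1, 0.1]  # made-up values\n",
"staged = staged_proba(example_contributions, eval_period=2)\n",
"print(np.round(staged, 4))"
]
},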
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Solving MultiClassification problem"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1536e93770b14eb48d913a0209269c87",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from catboost import CatBoostClassifier\n",
"model = CatBoostClassifier(\n",
" iterations=150,\n",
" random_seed=43,\n",
" loss_function='MultiClass'\n",
" #loss_function='MultiClassOneVsAll'\n",
")\n",
"model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" eval_set=(X_validation, y_validation),\n",
" verbose=False,\n",
" plot=True\n",
")"
]
},
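{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `loss_function='MultiClass'`, the model outputs one raw value per class, and the probabilities are obtained with a softmax. Even for this two-class target, `predict_proba` therefore returns one column per class. Below is a minimal NumPy sketch of that softmax. `softmax` is our own helper and the raw scores are made up. For two classes, the softmax of scores (a, b) equals the sigmoid of b - a."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def softmax(raw_scores):\n",
"    # raw_scores: array of shape (n_objects, n_classes), one raw value per class\n",
"    raw_scores = np.asarray(raw_scores, dtype=float)\n",
"    shifted = raw_scores - raw_scores.max(axis=1, keepdims=True)  # subtract row max for numerical stability\n",
"    exp_scores = np.exp(shifted)\n",
"    return exp_scores / exp_scores.sum(axis=1, keepdims=True)\n",
"\n",
"# Made-up raw scores for three objects and two classes\n",
"raw_scores = np.array([[0.0, 2.0], [1.0, -1.0], [0.5, 0.5]])\n",
"probs = softmax(raw_scores)\n",
"print(np.round(probs, 4))"
]
},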
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Metric evaluation on a new dataset"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0:\tlearn: 0.6569860\ttotal: 53.2ms\tremaining: 10.6s\n",
"50:\tlearn: 0.1950260\ttotal: 3.33s\tremaining: 9.74s\n",
"100:\tlearn: 0.1700584\ttotal: 6.72s\tremaining: 6.58s\n",
"150:\tlearn: 0.1641016\ttotal: 10.6s\tremaining: 3.42s\n",
"199:\tlearn: 0.1604074\ttotal: 14.1s\tremaining: 0us\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = CatBoostClassifier(\n",
" random_seed=63,\n",
" iterations=200,\n",
" learning_rate=0.03,\n",
")\n",
"model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" verbose=50\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"metrics = model.eval_metrics(data=pool1, \n",
" metrics=['Logloss','AUC'],\n",
" ntree_start=0,\n",
" ntree_end=0, \n",
" eval_period=1,\n",
" plot=True)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AUC values:\n",
"[0.4998 0.538 0.5504 0.5888 0.6503 0.6487 0.6487 0.6601 0.6601 0.6612\n",
" 0.6614 0.67 0.6699 0.6697 0.6697 0.6698 0.6698 0.6698 0.6698 0.7265\n",
" 0.7338 0.7376 0.7478 0.7478 0.7506 0.7591 0.7642 0.7877 0.8148 0.8265\n",
" 0.8353 0.8459 0.8533 0.8564 0.8573 0.8748 0.8874 0.893 0.8948 0.898\n",
" 0.9003 0.9052 0.9122 0.9159 0.92 0.9204 0.9209 0.9246 0.9262 0.9266\n",
" 0.9279 0.9303 0.9311 0.9312 0.9327 0.9329 0.9337 0.9341 0.9341 0.9349\n",
" 0.9354 0.9365 0.9391 0.9411 0.9424 0.9435 0.9446 0.9458 0.9463 0.9471\n",
" 0.9478 0.9481 0.9482 0.9483 0.9486 0.9497 0.951 0.9514 0.9515 0.9519\n",
" 0.9522 0.9526 0.953 0.9537 0.9543 0.9544 0.9545 0.9548 0.9548 0.9548\n",
" 0.9548 0.9548 0.955 0.9551 0.9553 0.9554 0.9557 0.9557 0.9557 0.9557\n",
" 0.956 0.9566 0.9566 0.9569 0.9568 0.957 0.9569 0.9572 0.9573 0.9575\n",
" 0.9576 0.9576 0.9577 0.958 0.9587 0.9594 0.9595 0.9594 0.9601 0.9609\n",
" 0.9609 0.9608 0.9608 0.9612 0.9616 0.9618 0.9622 0.9623 0.9623 0.9625\n",
" 0.9625 0.9628 0.9629 0.9629 0.9629 0.963 0.9631 0.9632 0.9635 0.9636\n",
" 0.9637 0.9638 0.9638 0.964 0.9641 0.9643 0.9645 0.9645 0.9645 0.9647\n",
" 0.9647 0.9647 0.9647 0.9648 0.9648 0.9648 0.965 0.9652 0.9652 0.9653\n",
" 0.9653 0.9654 0.9655 0.9655 0.9655 0.9656 0.9656 0.9657 0.9657 0.9657\n",
" 0.9658 0.9658 0.9659 0.9659 0.966 0.966 0.966 0.966 0.966 0.966\n",
" 0.966 0.966 0.9661 0.9662 0.9663 0.9663 0.9663 0.9663 0.9663 0.9663\n",
" 0.9665 0.9664 0.9666 0.9666 0.9666 0.9666 0.9666 0.9667 0.9667 0.9666]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "79673ab90b80409f9098024a23d609ec",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"print('AUC values:')\n",
"print(np.array(metrics['AUC']))"
]
},
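{
"cell_type": "markdown",
"metadata": {},
"source": [
"`eval_metrics` returns a dictionary that maps each metric name to a list with one value per evaluated step. The step with the best AUC on `pool1` can therefore be read off with `np.argmax`. The sketch below uses the first few AUC values printed above. `best_tree_count` is a hypothetical helper. Mapping step `i` to a model with `i + 1` trees is an assumption, made only for `ntree_start=0` and `eval_period=1` as in the call above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def best_tree_count(metric_values, higher_is_better=True):\n",
"    # metric_values: one value per evaluated step, e.g. metrics['AUC'] from eval_metrics\n",
"    values = np.asarray(metric_values, dtype=float)\n",
"    best_index = int(np.argmax(values) if higher_is_better else np.argmin(values))\n",
"    # Assumption: step i corresponds to a model with i + 1 trees (ntree_start=0, eval_period=1)\n",
"    return best_index + 1, values[best_index]\n",
"\n",
"auc_values = [0.4998, 0.538, 0.5504, 0.5888, 0.6503, 0.6487]  # first values printed above\n",
"n_trees, best_auc = best_tree_count(auc_values)\n",
"print(n_trees, best_auc)"
]
},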
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Saving the model"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"my_best_model = CatBoostClassifier(iterations=10)\n",
"my_best_model.fit(\n",
" X_train, y_train,\n",
" eval_set=(X_validation, y_validation),\n",
" cat_features=cat_features,\n",
" verbose=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"my_best_model.save_model('catboost_model.bin')"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'iterations': 10, 'loss_function': 'Logloss', 'logging_level': 'Silent'}\n",
"17379826207872\n"
]
}
],
"source": [
"my_best_model.load_model('catboost_model.bin')\n",
"print(my_best_model.get_params())\n",
"print(my_best_model.random_seed_)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"my_best_model.save_model('catboost_model.json', format='json')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Feature importances"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0:\tlearn: 0.6569860\ttotal: 56.9ms\tremaining: 11.3s\n",
"50:\tlearn: 0.1950260\ttotal: 4.06s\tremaining: 11.9s\n",
"100:\tlearn: 0.1700584\ttotal: 7.18s\tremaining: 7.04s\n",
"150:\tlearn: 0.1641016\ttotal: 10.8s\tremaining: 3.5s\n",
"199:\tlearn: 0.1604074\ttotal: 14.3s\tremaining: 0us\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = CatBoostClassifier(\n",
" random_seed=63,\n",
" iterations=200,\n",
" learning_rate=0.03)\n",
"\n",
"model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" verbose=50\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"fstrs = model.get_feature_importance(prettified=True)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{b'MGR_ID': 18.398056223683028,\n",
" b'RESOURCE': 21.27625707750949,\n",
" b'ROLE_CODE': 11.943230282191124,\n",
" b'ROLE_DEPTNAME': 15.252806962601875,\n",
" b'ROLE_FAMILY': 2.4789173515872176,\n",
" b'ROLE_FAMILY_DESC': 9.984073192415533,\n",
" b'ROLE_ROLLUP_1': 2.6278918788673633,\n",
" b'ROLE_ROLLUP_2': 13.582536028250486,\n",
" b'ROLE_TITLE': 4.456231002893895}"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"{feature_name : value for feature_name, value in fstrs}"
]
},
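{
"cell_type": "markdown",
"metadata": {},
"source": [
"The default feature importances are normalized so that they sum to 100, which means each value can be read as a percentage. The short sketch below decodes the byte-string names and ranks the features. The values are copied from the output above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fstrs_dict = {\n",
"    b'MGR_ID': 18.398056223683028,\n",
"    b'RESOURCE': 21.27625707750949,\n",
"    b'ROLE_CODE': 11.943230282191124,\n",
"    b'ROLE_DEPTNAME': 15.252806962601875,\n",
"    b'ROLE_FAMILY': 2.4789173515872176,\n",
"    b'ROLE_FAMILY_DESC': 9.984073192415533,\n",
"    b'ROLE_ROLLUP_1': 2.6278918788673633,\n",
"    b'ROLE_ROLLUP_2': 13.582536028250486,\n",
"    b'ROLE_TITLE': 4.456231002893895,\n",
"}\n",
"\n",
"# Decode byte-string names and sort features by importance, largest first\n",
"ranked_importances = sorted(\n",
"    ((name.decode('utf-8'), value) for name, value in fstrs_dict.items()),\n",
"    key=lambda item: item[1],\n",
"    reverse=True,\n",
")\n",
"for name, value in ranked_importances:\n",
"    print('{:<18} {:6.2f}'.format(name, value))\n",
"print('Total: {:.2f}'.format(sum(value for _, value in ranked_importances)))"
]
},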
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Shap values"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"https://github.com/slundberg/shap"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
"def object_predictions(model, obj):\n",
" print('Probability of class 1 = {:.4f}'.format(model.predict_proba([obj])[0][1]))\n",
" print('Formula raw prediction = {:.4f}'.format(model.predict([obj], prediction_type='RawFormulaVal')[0]))\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The model has complex ctrs, so the SHAP values will be calculated approximately.\n"
]
}
],
"source": [
"import shap\n",
"explainer = shap.TreeExplainer(model)\n",
"shap_values = explainer.shap_values(Pool(X, y, cat_features=cat_features))\n"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3.345740996864643"
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"explainer.expected_value"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Probability of class 1 = 0.9798\n",
"Formula raw prediction = 3.8820\n"
]
}
],
"source": [
"object_predictions(model, X.iloc[3,:])"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" Visualization omitted, Javascript library not loaded! \n",
" Have you run `initjs()` in this notebook? If this notebook was from another\n",
" user you must also trust this notebook (File -> Trust notebook). If you are viewing\n",
" this notebook on github the Javascript has been stripped for security. If you are using\n",
" JupyterLab this error is because a JupyterLab extension has not yet been written.\n",
"
\n",
" "
],
"text/plain": [
""
]
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"shap.initjs()\n",
"shap.force_plot(explainer.expected_value, shap_values[3,:], X.iloc[3,:])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above explanation shows features each contributing to push the model output from the base value (the average model output over the training dataset we passed) to the model output. Features pushing the prediction higher are shown in red, those pushing the prediction lower are in blue"
]
},
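{
"cell_type": "markdown",
"metadata": {},
"source": [
"SHAP values for this model are additive on the raw scale. The base value plus the SHAP values of one object should add up to that object's raw formula value, and the sigmoid of that value is the class-1 probability. The check below uses the numbers printed in this section for objects 3 and 91. The warning above says SHAP values are calculated approximately for models with complex CTRs, so the sum of `shap_values[i, :]` may match the raw prediction only approximately."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"base_value = 3.345740996864643           # explainer.expected_value printed above\n",
"raw_by_object = {3: 3.8820, 91: 0.5772}  # 'Formula raw prediction' for objects 3 and 91\n",
"\n",
"required_shap_sum = {}\n",
"object_probability = {}\n",
"for index, raw in raw_by_object.items():\n",
"    # The SHAP values of this object must cover the gap between the base value and the raw output\n",
"    required_shap_sum[index] = raw - base_value\n",
"    object_probability[index] = 1.0 / (1.0 + math.exp(-raw))\n",
"    print('object {}: sum of SHAP values = {:+.4f}, P(class 1) = {:.4f}'.format(\n",
"        index, required_shap_sum[index], object_probability[index]))"
]
},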
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Probability of class 1 = 0.6404\n",
"Formula raw prediction = 0.5772\n"
]
}
],
"source": [
"object_predictions(model, X.iloc[91,:])"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" Visualization omitted, Javascript library not loaded! \n",
" Have you run `initjs()` in this notebook? If this notebook was from another\n",
" user you must also trust this notebook (File -> Trust notebook). If you are viewing\n",
" this notebook on github the Javascript has been stripped for security. If you are using\n",
" JupyterLab this error is because a JupyterLab extension has not yet been written.\n",
"
\n",
" "
],
"text/plain": [
""
]
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import shap\n",
"shap.initjs()\n",
"shap.force_plot(explainer.expected_value, shap_values[91,:], X.iloc[91,:])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get an overview of which features are most important for a model we can plot the SHAP values of every feature for every sample. The plot below sorts features by the sum of SHAP value magnitudes over all samples, and uses SHAP values to show the distribution of the impacts each feature has on the model output. The color represents the feature value (red high, blue low)."
]
},
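{
"cell_type": "markdown",
"metadata": {},
"source": [
"The feature ordering used by these summary plots can be reproduced from the `shap_values` matrix (objects in rows, features in columns) by averaging absolute values per column. Ranking by the mean or by the sum of magnitudes gives the same order. The sketch below uses a small made-up matrix, and `rank_features_by_shap` is a hypothetical helper. With the real data you would pass `shap_values` and `list(X.columns)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def rank_features_by_shap(shap_matrix, feature_names):\n",
"    # Mean absolute SHAP value per feature (column)\n",
"    importance = np.abs(np.asarray(shap_matrix, dtype=float)).mean(axis=0)\n",
"    order = np.argsort(importance)[::-1]  # largest average impact first\n",
"    return [(feature_names[i], float(importance[i])) for i in order]\n",
"\n",
"# Made-up SHAP values: 3 objects x 3 features\n",
"toy_shap = np.array([\n",
"    [0.5, -0.1, 0.0],\n",
"    [-0.7, 0.2, 0.1],\n",
"    [0.6, -0.3, -0.1],\n",
"])\n",
"ranking = rank_features_by_shap(toy_shap, ['A', 'B', 'C'])\n",
"print(ranking)"
]
},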
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"shap.summary_plot(shap_values, X)"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"shap.summary_plot(shap_values, X, plot_type=\"bar\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameter tunning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training speed"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a92eadf3c1c044dea4c7e513a9cc0443",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
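"from catboost import CatBoostClassifier\n",
"\n",
"# Settings that trade some quality for faster training\n",
"fast_model = CatBoostClassifier(\n",
" random_seed=63,\n",
" iterations=150,\n",
" learning_rate=0.01,\n",
" boosting_type='Plain',  # classic gradient boosting instead of the slower Ordered scheme\n",
" bootstrap_type='Bernoulli',\n",
" subsample=0.5,  # each tree is built on a random half of the objects\n",
" rsm=0.5,  # use a random half of the features at each split selection\n",
" one_hot_max_size=20,  # one-hot encode categorical features with at most 20 values instead of computing CTRs\n",
" leaf_estimation_iterations=2,  # fewer gradient steps when computing leaf values\n",
" max_ctr_complexity=1)  # no combinations of categorical features\n",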
"\n",
"fast_model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" verbose=False,\n",
" plot=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Accuracy"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e1847556d61748cea1e37bb8c30ccb1a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
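"tunned_model = CatBoostClassifier(\n",
" random_seed=63,\n",
" iterations=1000,\n",
" learning_rate=0.03,\n",
" l2_leaf_reg=3,  # L2 regularization of leaf values\n",
" bagging_temperature=1,  # intensity of the Bayesian bootstrap\n",
" random_strength=1,  # amount of randomness used when scoring splits\n",
" one_hot_max_size=2,  # one-hot encode only categorical features with at most 2 values\n",
" leaf_estimation_method='Newton'\n",
")\n",
"tunned_model.fit(\n",
" X_train, y_train,\n",
" cat_features=cat_features,\n",
" verbose=False,\n",
" # with an eval_set, use_best_model defaults to True, so tree_count_ reflects the best iteration\n",
" eval_set=(X_validation, y_validation),\n",
" plot=True\n",
")"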
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training the model after parameter tuning"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0:\tlearn: 0.6431867\ttotal: 103ms\tremaining: 1m 51s\n",
"100:\tlearn: 0.1590533\ttotal: 8.58s\tremaining: 1m 23s\n",
"200:\tlearn: 0.1517316\ttotal: 16.1s\tremaining: 1m 10s\n",
"300:\tlearn: 0.1487795\ttotal: 23.5s\tremaining: 1m\n",
"400:\tlearn: 0.1468414\ttotal: 31.3s\tremaining: 53.2s\n",
"500:\tlearn: 0.1448880\ttotal: 39.4s\tremaining: 45.7s\n",
"600:\tlearn: 0.1433889\ttotal: 48.4s\tremaining: 38.8s\n",
"700:\tlearn: 0.1421200\ttotal: 58.1s\tremaining: 31.7s\n",
"800:\tlearn: 0.1411938\ttotal: 1m 6s\tremaining: 23.4s\n",
"900:\tlearn: 0.1401819\ttotal: 1m 15s\tremaining: 15.2s\n",
"1000:\tlearn: 0.1394250\ttotal: 1m 25s\tremaining: 6.99s\n",
"1082:\tlearn: 0.1388355\ttotal: 1m 32s\tremaining: 0us\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
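"# Retrain on the full labeled dataset with the tuned parameters\n",
"best_model = CatBoostClassifier(\n",
" random_seed=63,\n",
" iterations=int(tunned_model.tree_count_ * 1.2),  # ~20% more trees than the best validated model, for the larger dataset\n",
" learning_rate=0.03,\n",
" l2_leaf_reg=3,\n",
" bagging_temperature=1,\n",
" random_strength=1,\n",
" one_hot_max_size=2,\n",
" leaf_estimation_method='Newton'\n",
")\n",
"best_model.fit(\n",
" X, y,\n",
" cat_features=cat_features,\n",
" verbose=100\n",
")"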
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Calculate predictions for the contest"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predictions:\n",
"[[0.4024 0.5976]\n",
" [0.015 0.985 ]\n",
" [0.0101 0.9899]\n",
" ...\n",
" [0.0089 0.9911]\n",
" [0.0496 0.9504]\n",
" [0.0135 0.9865]]\n"
]
}
],
"source": [
"X_test = test_df.drop('id', axis=1)\n",
"test_pool = Pool(data=X_test, cat_features=cat_features)\n",
"# Shape (n_samples, 2): column 0 is P(class 0), column 1 is P(class 1)\n",
"contest_predictions = best_model.predict_proba(test_pool)\n",
"print('Predictions:')\n",
"print(contest_predictions)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare the submission"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"# Column 1 of predict_proba holds the probability of class 1 (access granted)\n",
"submission = pd.DataFrame({\n",
"    'Id': test_df['id'],\n",
"    'Action': contest_predictions[:, 1]\n",
"})\n",
"submission.to_csv('submit.csv', index=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Submit your solution [here](https://www.kaggle.com/c/amazon-employee-access-challenge/submit).\n",
"Good luck!!!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
},
"widgets": {
"state": {
"1057714ebc614324aa3ba2cf69408966": {
"views": [
{
"cell_index": 17
}
]
},
"8381e9eed05f4a03905ae8a56d7ab4ea": {
"views": [
{
"cell_index": 48
}
]
},
"f49684e8c5c44241bfe2c7f577f5cb41": {
"views": [
{
"cell_index": 53
}
]
}
},
"version": "1.2.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}