{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Solving classification problems with CatBoost" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catboost/tutorials/blob/master/events/pydata_nyc_oct_19_2018.ipynb)\n", "\n", "In this tutorial we will use the Amazon Employee Access Challenge dataset from a [Kaggle](https://www.kaggle.com) competition for our experiments. The data can be downloaded [here](https://www.kaggle.com/c/amazon-employee-access-challenge/data)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Installing the libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "#!pip install --user --upgrade catboost\n", "#!pip install --user --upgrade ipywidgets\n", "#!pip install shap\n", "#!pip install scikit-learn\n", "#!pip install --upgrade numpy\n", "#!pip install --upgrade pandas\n", "#!jupyter nbextension enable --py widgetsnbextension" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.10.2\n", "Python 3.5.5 :: Anaconda custom (64-bit)\r\n" ] } ], "source": [ "import catboost\n", "print(catboost.__version__)\n", "!python --version" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reading the data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The simplest way is to read everything into a pandas DataFrame." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import os\n", "import numpy as np\n", "np.set_printoptions(precision=4)\n", "import catboost\n", "from catboost import *\n", "from catboost import datasets" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Uncomment the lines below only if downloading the dataset fails with an SSL certificate error;\n", "# they disable HTTPS certificate verification for this Python process.\n", "# import ssl\n", "# ssl._create_default_https_context = ssl._create_unverified_context" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Load the Amazon dataset as two pandas DataFrames: train and test\n", "(train_df, test_df) = catboost.datasets.amazon()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div><table border=\"1\" class=\"dataframe\"><thead><tr style=\"text-align: right;\"><th></th><th>ACTION</th><th>RESOURCE</th><th>MGR_ID</th><th>ROLE_ROLLUP_1</th><th>ROLE_ROLLUP_2</th><th>ROLE_DEPTNAME</th><th>ROLE_TITLE</th><th>ROLE_FAMILY_DESC</th><th>ROLE_FAMILY</th><th>ROLE_CODE</th></tr></thead><tbody><tr><th>0</th><td>1</td><td>39353</td><td>85475</td><td>117961</td><td>118300</td><td>123472</td><td>117905</td><td>117906</td><td>290919</td><td>117908</td></tr><tr><th>1</th><td>1</td><td>17183</td><td>1540</td><td>117961</td><td>118343</td><td>123125</td><td>118536</td><td>118536</td><td>308574</td><td>118539</td></tr><tr><th>2</th><td>1</td><td>36724</td><td>14457</td><td>118219</td><td>118220</td><td>117884</td><td>117879</td><td>267952</td><td>19721</td><td>117880</td></tr><tr><th>3</th><td>1</td><td>36135</td><td>5396</td><td>117961</td><td>118343</td><td>119993</td><td>118321</td><td>240983</td><td>290919</td><td>118322</td></tr><tr><th>4</th><td>1</td><td>42680</td><td>5905</td><td>117929</td><td>117930</td><td>119569</td><td>119323</td><td>123932</td><td>19793</td><td>119325</td></tr></tbody></table></div>
\n" ], "text/plain": [ " ACTION RESOURCE MGR_ID ROLE_ROLLUP_1 ROLE_ROLLUP_2 ROLE_DEPTNAME \\\n", "0 1 39353 85475 117961 118300 123472 \n", "1 1 17183 1540 117961 118343 123125 \n", "2 1 36724 14457 118219 118220 117884 \n", "3 1 36135 5396 117961 118343 119993 \n", "4 1 42680 5905 117929 117930 119569 \n", "\n", " ROLE_TITLE ROLE_FAMILY_DESC ROLE_FAMILY ROLE_CODE \n", "0 117905 117906 290919 117908 \n", "1 118536 118536 308574 118539 \n", "2 117879 267952 19721 117880 \n", "3 118321 240983 290919 118322 \n", "4 119323 123932 19793 119325 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparing your data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Extracting the label values" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "y = train_df.ACTION\n", "X = train_df.drop('ACTION', axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Declaring the categorical features" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0, 1, 2, 3, 4, 5, 6, 7, 8]\n" ] } ], "source": [ "# Every column of X is a categorical ID, so all column indices are listed\n", "cat_features = list(range(0, X.shape[1]))\n", "print(cat_features)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looking at the label balance in the dataset" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Labels: {0, 1}\n", "Zero count = 1897, One count = 30872\n" ] } ], "source": [ "print('Labels: ', set(y))\n", "print('Zero count = ' + str(len(y) - sum(y)) + ', One count = ' + str(sum(y)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To train a model in CatBoost, we need to wrap the data in a special class: Pool.\n", "This class stores the data in CatBoost's internal format.\n", "\n", "There are several ways to create a 
pool.\n", "The simplest one is to create it from a pandas DataFrame or a numpy array." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "pool1 = Pool(data=X, \n", " label=y,\n", " cat_features=cat_features) # Indices of categorical columns in X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This approach is not the most efficient, especially for big datasets: everything has to be copied from pandas into CatBoost's internal format.\n", "\n", "That is why CatBoost can also create a Pool directly from a file." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see how to load a Pool from a file." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's save the DataFrame to disk." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "dataset_dir = './amazon'\n", "if not os.path.exists(dataset_dir):\n", " os.makedirs(dataset_dir)\n", "\n", "# CatBoost can read files with or without a header and with different separators, so we save both variants.\n", "train_df.to_csv(os.path.join(dataset_dir, 'train.tsv'), index=False, sep='\\t', header=False)\n", "test_df.to_csv(os.path.join(dataset_dir, 'test.tsv'), index=False, sep='\\t', header=False)\n", "\n", "train_df.to_csv(os.path.join(dataset_dir, 'train.csv'), index=False, sep=',', header=True)\n", "test_df.to_csv(os.path.join(dataset_dir, 'test.csv'), index=False, sep=',', header=True)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\t39353\t85475\t117961\t118300\t123472\t117905\t117906\t290919\t117908\r\n", "1\t17183\t1540\t117961\t118343\t123125\t118536\t118536\t308574\t118539\r\n" ] } ], "source": [ "!head -n2 amazon/train.tsv" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ACTION,RESOURCE,MGR_ID,ROLE_ROLLUP_1,ROLE_ROLLUP_2,ROLE_DEPTNAME,ROLE_TITLE,ROLE_FAMILY_DESC,ROLE_FAMILY,ROLE_CODE\r\n", "1,39353,85475,117961,118300,123472,117905,117906,290919,117908\r\n", "1,17183,1540,117961,118343,123125,118536,118536,308574,118539\r\n" ] } ], "source": [ "!head -n3 amazon/train.csv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we have the dataset in 2 different formats:\n", "\n", "1) tab-separated without a header\n", "\n", "2) comma-separated with a header\n", "\n", "\n", "Like pandas, CatBoost can load data in different formats; we just need to pass the proper options.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before loading the data, we need to specify the type of each column (label, categorical feature, numerical feature).\n", "\n", "CatBoost reads the column types from a special file, the column description (`.cd`) file, and the `create_cd` helper function makes it easy to create one." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from catboost.utils import create_cd\n", "\n", "# Feature names are indexed without the label column (ACTION is column 0)\n", "feature_names = dict()\n", "for column, name in enumerate(train_df):\n", " if column == 0:\n", " continue\n", " feature_names[column - 1] = name\n", " \n", "create_cd(\n", " label=0, \n", " cat_features=list(range(1, train_df.columns.shape[0])),\n", " feature_names=feature_names,\n", " output_path=os.path.join(dataset_dir, 'train.cd')\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\tLabel\t\r\n", "1\tCateg\tRESOURCE\r\n", "2\tCateg\tMGR_ID\r\n", "3\tCateg\tROLE_ROLLUP_1\r\n", "4\tCateg\tROLE_ROLLUP_2\r\n", "5\tCateg\tROLE_DEPTNAME\r\n", "6\tCateg\tROLE_TITLE\r\n", "7\tCateg\tROLE_FAMILY_DESC\r\n", "8\tCateg\tROLE_FAMILY\r\n", "9\tCateg\tROLE_CODE\r\n" ] } ], "source": [ "!cat amazon/train.cd" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "create_cd(\n", " label=0, \n", " 
cat_features=list(range(1, train_df.columns.shape[0])),\n", " # feature_names=feature_names,\n", " output_path=os.path.join(dataset_dir, 'train_without_names.cd')\n", ")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\tLabel\t\r\n", "1\tCateg\t\r\n", "2\tCateg\t\r\n", "3\tCateg\t\r\n", "4\tCateg\t\r\n", "5\tCateg\t\r\n", "6\tCateg\t\r\n", "7\tCateg\t\r\n", "8\tCateg\t\r\n", "9\tCateg\t\r\n" ] } ], "source": [ "!cat amazon/train_without_names.cd" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "create_cd(\n", " label=0, \n", " cat_features=list(range(2, train_df.columns.shape[0])),\n", " feature_names=feature_names,\n", " output_path=os.path.join(dataset_dir, 'train_with_num.cd')\n", ")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\tLabel\t\r\n", "1\tNum\tRESOURCE\r\n", "2\tCateg\tMGR_ID\r\n", "3\tCateg\tROLE_ROLLUP_1\r\n", "4\tCateg\tROLE_ROLLUP_2\r\n", "5\tCateg\tROLE_DEPTNAME\r\n", "6\tCateg\tROLE_TITLE\r\n", "7\tCateg\tROLE_FAMILY_DESC\r\n", "8\tCateg\tROLE_FAMILY\r\n", "9\tCateg\tROLE_CODE\r\n" ] } ], "source": [ "!cat amazon/train_with_num.cd" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "? 
create_cd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's load the pool from the file:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "pool2 = Pool(\n", " data=os.path.join(dataset_dir, 'train.tsv'), \n", " #delimiter=',', \n", " column_description=os.path.join(dataset_dir, 'train.cd'),\n", " # has_header=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Loading a Pool from a file is the fastest way to build it if the data is not in RAM yet." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Exercise: load the same pools from the CSV files (hint: use the commented-out `delimiter` and `has_header` options above)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you want maximum performance and your data is already in RAM, then in some cases you can do better than simply passing a DataFrame to the Pool constructor.\n", "\n", "The FeaturesData class provides a fast way to pass data from numpy matrices to CatBoost." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "? FeaturesData" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# The fastest way to create a Pool is from a numpy matrix. 
Use it if you need fast predictions\n", "# or the fastest way to load the data in Python.\n", "\n", "# FeaturesData expects categorical features as a numpy array of str values with dtype=object\n", "X_prepared = X.values.astype(str).astype(object)\n", "\n", "pool3 = Pool(\n", " data=FeaturesData(cat_feature_data=X_prepared, cat_feature_names=list(X)),\n", " label=y.values\n", ")\n" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset shape\n", "dataset 1:(32769, 9)\n", "dataset 2:(32769, 9)\n", "dataset 3:(32769, 9)\n", "\n", "\n", "Column names\n", "dataset 1:\n", "['RESOURCE', 'MGR_ID', 'ROLE_ROLLUP_1', 'ROLE_ROLLUP_2', 'ROLE_DEPTNAME', 'ROLE_TITLE', 'ROLE_FAMILY_DESC', 'ROLE_FAMILY', 'ROLE_CODE']\n", "\n", "dataset 2:\n", "['RESOURCE', 'MGR_ID', 'ROLE_ROLLUP_1', 'ROLE_ROLLUP_2', 'ROLE_DEPTNAME', 'ROLE_TITLE', 'ROLE_FAMILY_DESC', 'ROLE_FAMILY', 'ROLE_CODE']\n", "\n", "dataset 3:\n", "['RESOURCE', 'MGR_ID', 'ROLE_ROLLUP_1', 'ROLE_ROLLUP_2', 'ROLE_DEPTNAME', 'ROLE_TITLE', 'ROLE_FAMILY_DESC', 'ROLE_FAMILY', 'ROLE_CODE']\n" ] } ], "source": [ "print('Dataset shape')\n", "print('dataset 1:' + str(pool1.shape) + '\\ndataset 2:' + str(pool2.shape) + \n", " '\\ndataset 3:' + str(pool3.shape))\n", "\n", "print('\\n')\n", "print('Column names')\n", "print('dataset 1:')\n", "print(pool1.get_feature_names()) \n", "print('\\ndataset 2:')\n", "print(pool2.get_feature_names())\n", "print('\\ndataset 3:')\n", "print(pool3.get_feature_names())\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Splitting your data into train and validation sets" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/noxoomo/anaconda3/lib/python3.5/site-packages/sklearn/model_selection/_split.py:2069: FutureWarning: From version 0.21, test_size will always complement train_size unless both are specified.\n", " FutureWarning)\n" ] } 
], "source": [ "from sklearn.model_selection import train_test_split\n", "X_train, X_validation, y_train, y_validation = train_test_split(X, y, train_size=0.8, random_state=1234)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div><table border=\"1\" class=\"dataframe\"><thead><tr style=\"text-align: right;\"><th></th><th>RESOURCE</th><th>MGR_ID</th><th>ROLE_ROLLUP_1</th><th>ROLE_ROLLUP_2</th><th>ROLE_DEPTNAME</th><th>ROLE_TITLE</th><th>ROLE_FAMILY_DESC</th><th>ROLE_FAMILY</th><th>ROLE_CODE</th></tr></thead><tbody><tr><th>14926</th><td>74463</td><td>105908</td><td>117961</td><td>118225</td><td>129617</td><td>118702</td><td>132654</td><td>118704</td><td>118705</td></tr><tr><th>7940</th><td>17278</td><td>120340</td><td>120342</td><td>120343</td><td>119076</td><td>118834</td><td>311236</td><td>118424</td><td>118836</td></tr><tr><th>24768</th><td>79325</td><td>17733</td><td>117961</td><td>118300</td><td>119984</td><td>118890</td><td>125128</td><td>118398</td><td>118892</td></tr><tr><th>1633</th><td>5173</td><td>3075</td><td>117961</td><td>117962</td><td>120677</td><td>120357</td><td>120678</td><td>118424</td><td>120359</td></tr><tr><th>2743</th><td>15672</td><td>3745</td><td>117961</td><td>118300</td><td>118360</td><td>124435</td><td>118362</td><td>118363</td><td>124436</td></tr></tbody></table></div>
\n" ], "text/plain": [ " RESOURCE MGR_ID ROLE_ROLLUP_1 ROLE_ROLLUP_2 ROLE_DEPTNAME \\\n", "14926 74463 105908 117961 118225 129617 \n", "7940 17278 120340 120342 120343 119076 \n", "24768 79325 17733 117961 118300 119984 \n", "1633 5173 3075 117961 117962 120677 \n", "2743 15672 3745 117961 118300 118360 \n", "\n", " ROLE_TITLE ROLE_FAMILY_DESC ROLE_FAMILY ROLE_CODE \n", "14926 118702 132654 118704 118705 \n", "7940 118834 311236 118424 118836 \n", "24768 118890 125128 118398 118892 \n", "1633 120357 120678 118424 120359 \n", "2743 124435 118362 118363 124436 " ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.head()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(26215, 9)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.shape" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(6554, 9)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_validation.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Selecting the objective function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Possible options for binary classification:\n", "\n", "`Logloss` (the default for binary labels)\n", "\n", "`CrossEntropy` for targets given as probabilities" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "\n", "model = CatBoostClassifier(\n", " iterations=5,\n", " learning_rate=0.1,\n", " #loss_function='Logloss',\n", " #loss_function='CrossEntropy'\n", ")\n", "\n", "train_pool = Pool(data=X_train, \n", " label=y_train, \n", " cat_features=cat_features)\n", "\n", "validation_pool = Pool(data=X_validation, \n", " 
label=y_validation, \n", " cat_features=cat_features)\n", "model.fit(\n", " train_pool,\n", " eval_set=validation_pool,\n", " verbose=False\n", ")" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model is fitted: True\n", "Model params:\n", "{'iterations': 5, 'learning_rate': 0.1, 'loss_function': 'Logloss'}\n" ] } ], "source": [ "print('Model is fitted: ' + str(model.is_fitted()))\n", "print('Model params:')\n", "print(model.get_params())" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = CatBoostClassifier(\n", " iterations=5,\n", " learning_rate=0.1,\n", " #loss_function='Logloss',\n", " #loss_function='CrossEntropy'\n", ")\n", "\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False\n", ")" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model is fitted: True\n", "Model params:\n", "{'iterations': 5, 'learning_rate': 0.1, 'loss_function': 'Logloss'}\n" ] } ], "source": [ "print('Model is fitted: ' + str(model.is_fitted()))\n", "print('Model params:')\n", "print(model.get_params())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stdout of the training" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0:\tlearn: 0.2922585\ttest: 0.2894772\tbest: 0.2894772 (0)\ttotal: 50.6ms\tremaining: 708ms\n", "1:\tlearn: 0.2200332\ttest: 0.2193253\tbest: 0.2193253 (1)\ttotal: 103ms\tremaining: 667ms\n", "2:\tlearn: 0.1882770\ttest: 0.1794720\tbest: 0.1794720 (2)\ttotal: 169ms\tremaining: 676ms\n", "3:\tlearn: 0.1787690\ttest: 0.1655785\tbest: 0.1655785 
(3)\ttotal: 237ms\tremaining: 652ms\n", "4:\tlearn: 0.1748384\ttest: 0.1586419\tbest: 0.1586419 (4)\ttotal: 309ms\tremaining: 619ms\n", "5:\tlearn: 0.1737776\ttest: 0.1564081\tbest: 0.1564081 (5)\ttotal: 420ms\tremaining: 630ms\n", "6:\tlearn: 0.1723479\ttest: 0.1542367\tbest: 0.1542367 (6)\ttotal: 508ms\tremaining: 581ms\n", "7:\tlearn: 0.1719611\ttest: 0.1540092\tbest: 0.1540092 (7)\ttotal: 593ms\tremaining: 518ms\n", "8:\tlearn: 0.1718945\ttest: 0.1536590\tbest: 0.1536590 (8)\ttotal: 676ms\tremaining: 450ms\n", "9:\tlearn: 0.1697911\ttest: 0.1524950\tbest: 0.1524950 (9)\ttotal: 738ms\tremaining: 369ms\n", "10:\tlearn: 0.1684481\ttest: 0.1509084\tbest: 0.1509084 (10)\ttotal: 832ms\tremaining: 303ms\n", "11:\tlearn: 0.1673863\ttest: 0.1495139\tbest: 0.1495139 (11)\ttotal: 920ms\tremaining: 230ms\n", "12:\tlearn: 0.1642212\ttest: 0.1473859\tbest: 0.1473859 (12)\ttotal: 1s\tremaining: 154ms\n", "13:\tlearn: 0.1631211\ttest: 0.1468169\tbest: 0.1468169 (13)\ttotal: 1.08s\tremaining: 77.5ms\n", "14:\tlearn: 0.1628934\ttest: 0.1465499\tbest: 0.1465499 (14)\ttotal: 1.15s\tremaining: 0us\n", "\n", "bestTest = 0.1465498889\n", "bestIteration = 14\n", "\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=15,\n", " #verbose=5,\n", " logging_level='Verbose'\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", ")" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "{ROLE_FAMILY} pr1 tb0 type0, border=1 score 71.20196708\n", " tensor 0 is redundant, remove it and stop\n", "0:\tlearn: 0.3020322\ttest: 0.3024549\tbest: 0.3024549 (0)\ttotal: 46.3ms\tremaining: 648ms\n", "\n", "{ROLE_FAMILY} pr0 tb0 type0, border=6 score 
22.73090753\n", "{ROLE_FAMILY, ROLE_CODE} pr0 tb0 type1, border=1 score 22.85737971\n", "{ROLE_TITLE} pr0 tb0 type0, border=13 score 23.18484659\n", "{ROLE_DEPTNAME, ROLE_FAMILY, ROLE_CODE} pr2 tb0 type0, border=10 score 23.4271605\n", "{ROLE_ROLLUP_2} pr1 tb0 type0, border=13 score 25.42952513\n", "{ROLE_TITLE, ROLE_FAMILY} pr2 tb0 type0, border=7 score 25.91828995\n", "1:\tlearn: 0.2126748\ttest: 0.2068150\tbest: 0.2068150 (1)\ttotal: 137ms\tremaining: 890ms\n", "\n", "{MGR_ID} pr2 tb0 type0, border=11 score 9.55130831\n", "{MGR_ID, ROLE_CODE} pr2 tb0 type0, border=14 score 12.59291657\n", "{ROLE_FAMILY} pr0 tb0 type1, border=9 score 14.15678078\n", "{ROLE_TITLE, ROLE_FAMILY} pr1 tb0 type0, border=6 score 14.34218082\n", "{ROLE_ROLLUP_1, ROLE_FAMILY} pr2 tb0 type0, border=3 score 14.2932558\n", "{ROLE_FAMILY, ROLE_CODE} pr0 tb0 type1, border=3 score 14.4073355\n", "2:\tlearn: 0.1856771\ttest: 0.1759821\tbest: 0.1759821 (2)\ttotal: 282ms\tremaining: 1.13s\n", "\n", "{MGR_ID} pr2 tb0 type0, border=10 score 6.594784046\n", "{MGR_ID, ROLE_CODE} pr2 tb0 type0, border=12 score 8.339746264\n", "{MGR_ID, ROLE_FAMILY_DESC} pr2 tb0 type0, border=11 score 9.115213253\n", "{ROLE_ROLLUP_2} pr0 tb0 type0, border=12 score 10.5940908\n", "{MGR_ID, ROLE_FAMILY_DESC} pr0 tb0 type0, border=5 score 10.45322949\n", "{ROLE_TITLE} pr2 tb0 type0, border=6 score 10.66637876\n", "3:\tlearn: 0.1756533\ttest: 0.1621181\tbest: 0.1621181 (3)\ttotal: 397ms\tremaining: 1.09s\n", "\n", "{MGR_ID} pr2 tb0 type0, border=10 score 4.816932629\n", "{MGR_ID, ROLE_FAMILY_DESC} pr2 tb0 type0, border=12 score 5.826899508\n", "{ROLE_ROLLUP_2} pr2 tb0 type0, border=13 score 6.455962742\n", "{ROLE_FAMILY} pr2 tb0 type0, border=10 score 6.668192917\n", "{ROLE_TITLE} pr2 tb0 type0, border=8 score 6.787022309\n", "{ROLE_FAMILY, ROLE_CODE} pr2 tb0 type0, border=1 score 6.784150409\n", " tensor 5 is redundant, remove it and stop\n", "4:\tlearn: 0.1732514\ttest: 0.1587816\tbest: 0.1587816 (4)\ttotal: 
526ms\tremaining: 1.05s\n", "\n", "{MGR_ID} pr2 tb0 type0, border=7 score 2.297190459\n", "{RESOURCE} pr2 tb0 type0, border=11 score 3.692999825\n", "{MGR_ID, ROLE_FAMILY_DESC} pr2 tb0 type0, border=9 score 4.566200375\n", "{RESOURCE} pr0 tb0 type0, border=13 score 5.195022352\n", "{MGR_ID, ROLE_FAMILY_DESC} pr0 tb0 type0, border=10 score 5.882398347\n", "{MGR_ID, ROLE_FAMILY_DESC} pr1 tb0 type0, border=6 score 5.715399768\n", "5:\tlearn: 0.1693831\ttest: 0.1534864\tbest: 0.1534864 (5)\ttotal: 633ms\tremaining: 950ms\n", "\n", "{MGR_ID} pr2 tb0 type0, border=9 score 2.56631672\n", "{MGR_ID, ROLE_FAMILY_DESC} pr2 tb0 type0, border=12 score 3.053372671\n", "{ROLE_FAMILY} pr0 tb0 type1, border=14 score 3.662099526\n", " tensor 2 is redundant, remove it and stop\n", "6:\tlearn: 0.1692429\ttest: 0.1530965\tbest: 0.1530965 (6)\ttotal: 688ms\tremaining: 786ms\n", "\n", "{MGR_ID} pr2 tb0 type0, border=9 score 2.662106324\n", "{MGR_ID, ROLE_CODE} pr0 tb0 type0, border=14 score 3.271764978\n", " tensor 1 is redundant, remove it and stop\n", "7:\tlearn: 0.1692302\ttest: 0.1530455\tbest: 0.1530455 (7)\ttotal: 740ms\tremaining: 647ms\n", "\n", "{RESOURCE} pr2 tb0 type0, border=9 score 2.704934036\n", "{ROLE_ROLLUP_2} pr1 tb0 type0, border=9 score 3.380733801\n", "{ROLE_FAMILY_DESC} pr0 tb0 type1, border=8 score 3.485919025\n", "{ROLE_ROLLUP_2, ROLE_DEPTNAME} pr2 tb0 type0, border=8 score 4.040211274\n", "{ROLE_DEPTNAME} pr2 tb0 type0, border=4 score 4.662311897\n", "{ROLE_ROLLUP_1} pr2 tb0 type0, border=10 score 4.587959116\n", "8:\tlearn: 0.1678218\ttest: 0.1518261\tbest: 0.1518261 (8)\ttotal: 843ms\tremaining: 562ms\n", "\n", "{RESOURCE} pr2 tb0 type0, border=6 score 3.22558481\n", "{MGR_ID} pr0 tb0 type0, border=6 score 3.598670764\n", "{MGR_ID} pr2 tb0 type0, border=6 score 4.442316954\n", "{ROLE_ROLLUP_1} pr1 tb0 type0, border=14 score 4.534285954\n", " tensor 3 is redundant, remove it and stop\n", "9:\tlearn: 0.1675718\ttest: 0.1514379\tbest: 0.1514379 (9)\ttotal: 
900ms\tremaining: 450ms\n", "\n", "{RESOURCE} pr1 tb0 type0, border=8 score 1.706869104\n", "{RESOURCE} pr2 tb0 type0, border=12 score 2.541399767\n", "{MGR_ID} pr0 tb0 type0, border=12 score 3.042148696\n", "{MGR_ID, ROLE_DEPTNAME} pr1 tb0 type0, border=6 score 3.056465669\n", "{ROLE_DEPTNAME} pr0 tb0 type1, border=14 score 3.417063506\n", " tensor 4 is redundant, remove it and stop\n", "10:\tlearn: 0.1662480\ttest: 0.1499766\tbest: 0.1499766 (10)\ttotal: 974ms\tremaining: 354ms\n", "\n", "{ROLE_FAMILY_DESC} pr1 tb0 type0, border=13 score 1.188282501\n", "{MGR_ID, ROLE_FAMILY_DESC} pr2 tb0 type0, border=6 score 2.265758008\n", "{RESOURCE} pr2 tb0 type0, border=12 score 2.620514967\n", "{ROLE_CODE} pr0 tb0 type0, border=14 score 3.027549244\n", " tensor 3 is redundant, remove it and stop\n", "11:\tlearn: 0.1660153\ttest: 0.1497512\tbest: 0.1497512 (11)\ttotal: 1.04s\tremaining: 260ms\n", "\n", "{RESOURCE} pr2 tb0 type0, border=3 score 1.942379095\n", "{ROLE_FAMILY_DESC} pr2 tb0 type0, border=12 score 2.328950537\n", "{ROLE_ROLLUP_2, ROLE_FAMILY_DESC} pr2 tb0 type0, border=13 score 2.976624744\n", "{RESOURCE, ROLE_ROLLUP_1} pr2 tb0 type0, border=11 score 3.247835595\n", "{RESOURCE, ROLE_ROLLUP_1} pr2 tb0 type0, border=4 score 3.98226103\n", "{ROLE_ROLLUP_2, ROLE_FAMILY_DESC} pr0 tb0 type0, border=14 score 3.884491014\n", " tensor 5 is redundant, remove it and stop\n", "12:\tlearn: 0.1655821\ttest: 0.1492094\tbest: 0.1492094 (12)\ttotal: 1.14s\tremaining: 175ms\n", "\n", "{RESOURCE} pr2 tb0 type0, border=6 score 1.401892445\n", "{ROLE_DEPTNAME} pr2 tb0 type0, border=5 score 2.17295077\n", "{RESOURCE, ROLE_ROLLUP_1} pr1 tb0 type0, border=6 score 2.344410053\n", "{RESOURCE, ROLE_ROLLUP_1, ROLE_ROLLUP_2} pr0 tb0 type0, border=6 score 2.784440779\n", "{ROLE_CODE} pr2 tb0 type0, border=8 score 3.478854515\n", "{ROLE_DEPTNAME, ROLE_CODE} pr2 tb0 type0, border=7 score 3.183561803\n", "13:\tlearn: 0.1637591\ttest: 0.1478251\tbest: 0.1478251 (13)\ttotal: 1.27s\tremaining: 
91.1ms\n", "\n", "{MGR_ID} pr2 tb0 type0, border=5 score 1.678525175\n", "{RESOURCE} pr2 tb0 type0, border=2 score 1.913548971\n", "{ROLE_ROLLUP_2} pr1 tb0 type0, border=4 score 2.09899603\n", "{ROLE_ROLLUP_2, ROLE_FAMILY_DESC} pr0 tb0 type0, border=14 score 2.529101972\n", " tensor 3 is redundant, remove it and stop\n", "14:\tlearn: 0.1637261\ttest: 0.1477449\tbest: 0.1477449 (14)\ttotal: 1.34s\tremaining: 0us\n", "\n", "bestTest = 0.1477448839\n", "bestIteration = 14\n", "\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=15,\n", " #verbose=5,\n", " logging_level='Info'\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Metrics calculation and graph plotting" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b78812f3160e458dbd4fdd81998054e0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=500,\n", " random_seed=63,\n", " learning_rate=0.5\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Eval metric, custom metrics and the best tree count" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a99d9adb234d42a9a82ff7812ebd0718", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=50,\n", " random_seed=63,\n", " learning_rate=0.5,\n", " eval_metric=\"Accuracy\",\n", " use_best_model=False # keep all 50 trees, even past the best validation iteration\n", ")\n", "\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "50" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.tree_count_" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3237fc1e455c42079880250710ad16a5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=50,\n", " 
random_seed=63,\n", " learning_rate=0.5,\n", " custom_loss=['AUC', 'Accuracy']\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "28" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.tree_count_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Metric hints" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "987ef484e520494a820090a2bb88ac50", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "\n", "model = CatBoostClassifier(\n", " iterations=50,\n", " random_seed=63,\n", " learning_rate=0.5,\n", " eval_metric='AUC:hints=skip_train~false' # also calculate AUC on the training set (skipped by default)\n", ")\n", "\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "328bd48940a24bdc807e5bc149b56072", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] 
}, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=50,\n", " random_seed=63,\n", " learning_rate=0.5,\n", " eval_metric='AUC:hints=skip_train~false', # also compute AUC on the training set\n", " metric_period=10 # calculate metrics every 10 iterations\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model comparison" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model1 = CatBoostClassifier(\n", " learning_rate=0.5,\n", " iterations=100,\n", " train_dir='learning_rate_0.5'\n", ")\n", "\n", "model2 = CatBoostClassifier(\n", " learning_rate=0.01,\n", " iterations=100,\n", " train_dir='learning_rate_0.01'\n", ")\n", "\n", "model1.fit(\n", " X_train, y_train,\n", " eval_set=(X_validation, y_validation),\n", " cat_features=cat_features,\n", " verbose=False\n", ")\n", "model2.fit(\n", " X_train, y_train,\n", " eval_set=(X_validation, y_validation),\n", " cat_features=cat_features,\n", " verbose=False\n", ")" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d4ec633ced854e4f99dd4731264681ec", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from catboost 
import MetricVisualizer\n", 
"MetricVisualizer(['learning_rate_0.01', 'learning_rate_0.5']).start()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overfitting detector" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "80e70b13a7564741b84cd3c7328d04e0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_with_early_stop = CatBoostClassifier(\n", " iterations=200,\n", " random_seed=63,\n", " learning_rate=0.5,\n", " early_stopping_rounds=20\n", ")\n", "model_with_early_stop.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "28\n" ] } ], "source": [ "print(model_with_early_stop.tree_count_)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "32fc59c9d0204b41b75e2bb13327d8be", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_with_early_stop = CatBoostClassifier(\n", 
eval_metric='AUC',\n", " iterations=200,\n", " random_seed=63,\n", " learning_rate=0.5,\n", " early_stopping_rounds=20\n", ")\n", "model_with_early_stop.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cross-validation" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "447c85da2d564f4890220e78ff7eb230", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from catboost import cv\n", "\n", "params = {}\n", "params['loss_function'] = 'Logloss'\n", "params['iterations'] = 80\n", "params['custom_loss'] = 'AUC'\n", "params['random_seed'] = 63\n", "params['learning_rate'] = 0.5\n", "\n", "cv_data = cv(\n", " params = params,\n", " pool = Pool(X, label=y, cat_features=cat_features),\n", " fold_count=5,\n", " type = 'Classical',\n", " shuffle=True,\n", " partition_random_seed=0,\n", " plot=True,\n", " stratified=False,\n", " verbose=False\n", ")" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " test-AUC-mean test-AUC-std test-Logloss-mean test-Logloss-std \\\n", "0 0.500000 0.000000 0.302217 0.003871 \n", "1 0.625287 0.098486 0.225502 0.014372 \n", "2 0.797381 0.031802 0.182074 0.010177 \n", "3 0.818542 0.019081 0.168630 0.006834 \n", "4 0.823875 0.013242 0.163591 0.007106 \n", "5 0.842055 0.006845 0.159424 0.006278 \n", "6 0.852505 0.011592 0.155835 0.007250 \n", "7 0.859290 0.013055 0.152686 0.007374 \n", "8 0.862400 0.012930 0.151412 0.007750 \n", "9 0.864313 0.012127 0.150388 0.007312 \n", "10 0.864621 0.012251 0.149981 0.007268 \n", "11 0.869911 0.012121 0.148150 0.008058 \n", "12 0.871178 0.011428 0.147479 0.008175 \n", "13 0.871663 0.011281 0.147303 0.008230 \n", "14 0.872252 0.010388 0.147200 0.008247 \n", "15 0.872506 0.010518 0.147003 0.008478 \n", "16 0.873782 0.009503 0.146393 0.008212 \n", "17 0.875730 0.008438 0.145739 0.007961 \n", "18 0.876655 0.009206 0.145421 0.008027 \n", "19 0.876961 0.009529 0.145151 0.008368 \n", "20 0.877131 0.009657 0.145095 0.008361 \n", "21 0.877797 0.009627 0.144691 0.008332 \n", "22 0.878456 0.009885 0.144474 0.008387 \n", "23 0.878098 0.010434 0.144497 0.008484 \n", "24 0.878518 0.010405 0.144371 0.008634 \n", "25 0.879114 0.010706 0.144143 0.008871 \n", "26 0.879001 0.010942 0.144083 0.008958 \n", "27 0.879322 0.011028 0.143980 0.008761 \n", "28 0.879694 0.011326 0.143692 0.009107 \n", "29 0.879781 0.011045 0.143615 0.009135 \n", ".. ... ... ... ... 
\n", "50 0.882512 0.012500 0.142604 0.010409 \n", "51 0.882453 0.012385 0.142621 0.010400 \n", "52 0.882385 0.012415 0.142637 0.010403 \n", "53 0.882443 0.012071 0.142524 0.010384 \n", "54 0.882297 0.012177 0.142517 0.010484 \n", "55 0.882199 0.012224 0.142556 0.010513 \n", "56 0.882194 0.012205 0.142574 0.010495 \n", "57 0.882277 0.012128 0.142493 0.010513 \n", "58 0.882321 0.012203 0.142489 0.010529 \n", "59 0.882194 0.012165 0.142555 0.010500 \n", "60 0.882152 0.012189 0.142579 0.010494 \n", "61 0.882442 0.012275 0.142514 0.010480 \n", "62 0.882531 0.012148 0.142492 0.010439 \n", "63 0.882384 0.012305 0.142450 0.010402 \n", "64 0.882327 0.012362 0.142604 0.010350 \n", "65 0.882394 0.012476 0.142554 0.010403 \n", "66 0.882718 0.012037 0.142475 0.010294 \n", "67 0.882441 0.012522 0.142524 0.010263 \n", "68 0.882410 0.012563 0.142535 0.010284 \n", "69 0.882688 0.012817 0.142429 0.010369 \n", "70 0.882825 0.012729 0.142343 0.010321 \n", "71 0.882948 0.012839 0.142289 0.010418 \n", "72 0.883091 0.012895 0.142179 0.010493 \n", "73 0.883080 0.012929 0.142213 0.010555 \n", "74 0.882921 0.012968 0.142223 0.010504 \n", "75 0.882927 0.012993 0.142192 0.010595 \n", "76 0.882942 0.012987 0.142221 0.010574 \n", "77 0.883305 0.012968 0.142100 0.010714 \n", "78 0.883523 0.013146 0.142131 0.010681 \n", "79 0.883661 0.013047 0.142058 0.010638 \n", "\n", " train-AUC-mean train-AUC-std train-Logloss-mean train-Logloss-std \n", "0 0.499996 0.000009 0.302189 0.002418 \n", "1 0.618981 0.088396 0.227776 0.009594 \n", "2 0.758575 0.036903 0.190874 0.004790 \n", "3 0.778474 0.012547 0.180657 0.004351 \n", "4 0.784664 0.007662 0.176791 0.003464 \n", "5 0.801486 0.008265 0.173545 0.003944 \n", "6 0.812029 0.004814 0.170815 0.003318 \n", "7 0.821573 0.004150 0.168168 0.002686 \n", "8 0.825537 0.002166 0.166854 0.002284 \n", "9 0.827699 0.004355 0.165980 0.002763 \n", "10 0.828238 0.004967 0.165710 0.002916 \n", "11 0.832391 0.004098 0.164250 0.002763 \n", "12 0.834270 0.005416 0.163328 
0.003049 \n", "13 0.834886 0.005530 0.163057 0.003074 \n", "14 0.835580 0.006042 0.162711 0.003282 \n", "15 0.836177 0.006505 0.162350 0.003406 \n", "16 0.837587 0.006670 0.161729 0.003425 \n", "17 0.840422 0.008338 0.160630 0.003884 \n", "18 0.842564 0.005416 0.159910 0.003345 \n", "19 0.843762 0.004223 0.159456 0.002908 \n", "20 0.843862 0.004137 0.159337 0.002843 \n", "21 0.845742 0.004977 0.158503 0.002997 \n", "22 0.847026 0.004351 0.157964 0.002731 \n", "23 0.848029 0.005031 0.157516 0.002853 \n", "24 0.848554 0.005170 0.157213 0.002953 \n", "25 0.849741 0.005065 0.156656 0.002807 \n", "26 0.850191 0.004955 0.156296 0.002907 \n", "27 0.850911 0.005407 0.155861 0.002955 \n", "28 0.851471 0.004956 0.155599 0.002789 \n", "29 0.851630 0.005069 0.155455 0.002823 \n", ".. ... ... ... ... \n", "50 0.860230 0.005378 0.150163 0.003652 \n", "51 0.860310 0.005331 0.150098 0.003616 \n", "52 0.860454 0.005418 0.149982 0.003705 \n", "53 0.861073 0.005605 0.149685 0.003798 \n", "54 0.861248 0.005420 0.149561 0.003734 \n", "55 0.861319 0.005484 0.149484 0.003754 \n", "56 0.861369 0.005492 0.149424 0.003744 \n", "57 0.861646 0.005424 0.149269 0.003714 \n", "58 0.861708 0.005415 0.149229 0.003720 \n", "59 0.861739 0.005397 0.149171 0.003693 \n", "60 0.861742 0.005393 0.149148 0.003703 \n", "61 0.862027 0.005136 0.148943 0.003607 \n", "62 0.862121 0.005156 0.148857 0.003626 \n", "63 0.862249 0.005253 0.148737 0.003715 \n", "64 0.862404 0.005327 0.148628 0.003732 \n", "65 0.862664 0.005393 0.148464 0.003785 \n", "66 0.862964 0.005743 0.148285 0.003972 \n", "67 0.863391 0.006173 0.148076 0.004102 \n", "68 0.863487 0.006285 0.147985 0.004199 \n", "69 0.863947 0.006193 0.147717 0.004175 \n", "70 0.864080 0.006283 0.147574 0.004268 \n", "71 0.864243 0.006387 0.147433 0.004301 \n", "72 0.864703 0.006083 0.147174 0.004146 \n", "73 0.864784 0.006046 0.147111 0.004117 \n", "74 0.865053 0.006098 0.146930 0.004142 \n", "75 0.865155 0.006075 0.146842 0.004101 \n", "76 0.865304 0.005987 
0.146695 0.004043 \n", "77 0.865512 0.006000 0.146537 0.003998 \n", "78 0.865617 0.005938 0.146448 0.003941 \n", "79 0.865727 0.005961 0.146341 0.003937 \n", "\n", "[80 rows x 8 columns]" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cv_data" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best validation Logloss score, not stratified: 0.1421±0.0106 on step 79\n" ] } ], "source": [ "best_value = np.min(cv_data['test-Logloss-mean'])\n", "best_iter = cv_data['test-Logloss-mean'].idxmin()\n", "\n", "print('Best validation Logloss score, not stratified: {:.4f}±{:.4f} on step {}'.format(\n", " best_value,\n", " cv_data['test-Logloss-std'][best_iter],\n", " best_iter)\n", ")" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f52baa02b8ad418ea07b4b50be5def76", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Best validation Logloss score, stratified: 0.1412±0.0055 on step 67\n" ] } ], "source": [ "cv_data = cv(\n", " params = params,\n", " pool = Pool(X, label=y, cat_features=cat_features),\n", " fold_count=5,\n", " type = 'Classical',\n", " shuffle=True,\n", " partition_random_seed=0,\n", " plot=True,\n", " stratified=True,\n", " verbose=False\n", ")\n", "\n", "best_value = np.min(cv_data['test-Logloss-mean'])\n", "best_iter = cv_data['test-Logloss-mean'].idxmin()\n", "\n", "print('Best validation Logloss score, stratified: {:.4f}±{:.4f} on step {}'.format(\n", " best_value,\n", 
cv_data['test-Logloss-std'][best_iter],\n", " best_iter)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Select decision boundary" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ba81252dc8bb4d3e858b3627977b24ba", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = CatBoostClassifier(\n", " random_seed=63,\n", " iterations=200,\n", " learning_rate=0.03,\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import sklearn\n", "from sklearn import metrics\n", "from catboost.utils import get_roc_curve" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "? get_roc_curve" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "eval_pool = Pool(X_validation, y_validation, cat_features=cat_features)\n", "curve = get_roc_curve(model, eval_pool)\n", "(fpr, tpr, thresholds) = curve\n", "\n", "plt.figure()\n", "lw = 2\n", "roc_auc = sklearn.metrics.auc(fpr, tpr)\n", "plt.plot(fpr, tpr, color='darkorange',\n", " lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)\n", "\n", "plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')\n", "plt.xlim([0.0, 1.0])\n", "plt.ylim([0.0, 1.05])\n", "plt.xlabel('False Positive Rate')\n", "plt.ylabel('True Positive Rate')\n", "plt.title('Receiver operating characteristic')\n", "plt.legend(loc=\"lower right\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "from catboost.utils import get_fpr_curve\n", "from catboost.utils import get_fnr_curve" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "? get_fpr_curve" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "lw = 2\n", "(thresholds, fpr) = get_fpr_curve(curve=curve)\n", "(thresholds, fnr) = get_fnr_curve(curve=curve)\n", "plt.plot(thresholds, fpr, color='blue', lw=lw, label='FPR')\n", "plt.plot(thresholds, fnr, color='green', lw=lw, label='FNR')\n", "\n", "plt.xlim([0.0, 1.0])\n", "plt.ylim([0.0, 1.05])\n", "plt.xlabel('Threshold')\n", "plt.ylabel('Error Rate')\n", "plt.title('FPR-FNR curves')\n", "plt.legend(loc=\"lower left\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.4719014981770065\n", "0.9876472944285266\n" ] } ], "source": [ "from catboost.utils import select_threshold\n", "\n", "print(select_threshold(model=model, data=eval_pool, FNR=0.01))\n", "print(select_threshold(model=model, data=eval_pool, FPR=0.01))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Snapshotting" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "#!rm 'catboost_info/snapshot.bkp'\n", "\n" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0:\tlearn: 0.3019732\ttest: 0.3022828\tbest: 0.3022828 (0)\ttotal: 67ms\tremaining: 1.27s\n", "1:\tlearn: 0.2238474\ttest: 0.2233617\tbest: 0.2233617 (1)\ttotal: 148ms\tremaining: 1.33s\n", "2:\tlearn: 0.1904507\ttest: 0.1818139\tbest: 0.1818139 (2)\ttotal: 252ms\tremaining: 1.43s\n", "3:\tlearn: 0.1802458\ttest: 0.1659696\tbest: 0.1659696 (3)\ttotal: 363ms\tremaining: 1.45s\n", "4:\tlearn: 0.1754740\ttest: 0.1587777\tbest: 0.1587777 (4)\ttotal: 485ms\tremaining: 1.46s\n", "5:\tlearn: 0.1703625\ttest: 0.1514527\tbest: 0.1514527 (5)\ttotal: 566ms\tremaining: 1.32s\n", "6:\tlearn: 0.1684087\ttest: 0.1485265\tbest: 0.1485265 (6)\ttotal: 654ms\tremaining: 1.22s\n", "7:\tlearn: 
0.1672054\ttest: 0.1477918\tbest: 0.1477918 (7)\ttotal: 741ms\tremaining: 1.11s\n", "8:\tlearn: 0.1662518\ttest: 0.1470028\tbest: 0.1470028 (8)\ttotal: 842ms\tremaining: 1.03s\n", "9:\tlearn: 0.1656720\ttest: 0.1469340\tbest: 0.1469340 (9)\ttotal: 929ms\tremaining: 929ms\n", "10:\tlearn: 0.1646958\ttest: 0.1458786\tbest: 0.1458786 (10)\ttotal: 1.02s\tremaining: 832ms\n", "11:\tlearn: 0.1639292\ttest: 0.1456439\tbest: 0.1456439 (11)\ttotal: 1.1s\tremaining: 737ms\n", "12:\tlearn: 0.1631213\ttest: 0.1453225\tbest: 0.1453225 (12)\ttotal: 1.18s\tremaining: 638ms\n", "13:\tlearn: 0.1628037\ttest: 0.1449395\tbest: 0.1449395 (13)\ttotal: 1.25s\tremaining: 538ms\n", "14:\tlearn: 0.1626140\ttest: 0.1450753\tbest: 0.1449395 (13)\ttotal: 1.32s\tremaining: 440ms\n", "15:\tlearn: 0.1614957\ttest: 0.1448378\tbest: 0.1448378 (15)\ttotal: 1.42s\tremaining: 356ms\n", "16:\tlearn: 0.1614173\ttest: 0.1448587\tbest: 0.1448378 (15)\ttotal: 1.51s\tremaining: 267ms\n", "17:\tlearn: 0.1614154\ttest: 0.1448968\tbest: 0.1448378 (15)\ttotal: 1.61s\tremaining: 179ms\n", "18:\tlearn: 0.1613764\ttest: 0.1448062\tbest: 0.1448062 (18)\ttotal: 1.65s\tremaining: 87ms\n", "19:\tlearn: 0.1612640\ttest: 0.1445666\tbest: 0.1445666 (19)\ttotal: 1.72s\tremaining: 0us\n", "\n", "bestTest = 0.1445666012\n", "bestIteration = 19\n", "\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#!rm 'catboost_info/snapshot.bkp'\n", "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=20,\n", " save_snapshot=True,\n", " snapshot_file='snapshot.bkp',\n", " snapshot_interval=1,\n", " random_seed=43\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " eval_set=(X_validation, y_validation),\n", " cat_features=cat_features,\n", " verbose=True\n", ")" ] }, { "cell_type": "code", "execution_count": 61, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], 
"text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "902ccb0a2fef4f25b41b4697f325192a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "20:\tlearn: 0.1607445\ttest: 0.1440626\tbest: 0.1440626 (20)\ttotal: 1.82s\tremaining: 17.1s\n", "21:\tlearn: 0.1606770\ttest: 0.1439965\tbest: 0.1439965 (21)\ttotal: 1.9s\tremaining: 15.8s\n", "22:\tlearn: 0.1606616\ttest: 0.1440083\tbest: 0.1439965 (21)\ttotal: 1.97s\tremaining: 14.8s\n", "23:\tlearn: 0.1606540\ttest: 0.1440366\tbest: 0.1439965 (21)\ttotal: 2.04s\tremaining: 14s\n", "24:\tlearn: 0.1606450\ttest: 0.1440838\tbest: 0.1439965 (21)\ttotal: 2.11s\tremaining: 13.6s\n", "25:\tlearn: 0.1604540\ttest: 0.1439186\tbest: 0.1439186 (25)\ttotal: 2.22s\tremaining: 14.3s\n", "26:\tlearn: 0.1600134\ttest: 0.1436727\tbest: 0.1436727 (26)\ttotal: 2.33s\tremaining: 15s\n", "27:\tlearn: 0.1600132\ttest: 0.1436712\tbest: 0.1436712 (27)\ttotal: 2.4s\tremaining: 14.5s\n", "28:\tlearn: 0.1597968\ttest: 0.1437510\tbest: 0.1436712 (27)\ttotal: 2.49s\tremaining: 14.6s\n", "29:\tlearn: 0.1597306\ttest: 0.1437718\tbest: 0.1436712 (27)\ttotal: 2.62s\tremaining: 15.3s\n", "30:\tlearn: 0.1595170\ttest: 0.1437298\tbest: 0.1436712 (27)\ttotal: 2.77s\tremaining: 16.1s\n", "31:\tlearn: 0.1593775\ttest: 0.1439295\tbest: 0.1436712 (27)\ttotal: 2.88s\tremaining: 16.2s\n", "32:\tlearn: 0.1589895\ttest: 0.1438260\tbest: 0.1436712 (27)\ttotal: 2.99s\tremaining: 16.2s\n", "33:\tlearn: 0.1589845\ttest: 0.1438430\tbest: 0.1436712 (27)\ttotal: 3.04s\tremaining: 15.6s\n", "34:\tlearn: 0.1589319\ttest: 0.1438151\tbest: 0.1436712 (27)\ttotal: 3.12s\tremaining: 15.4s\n", "35:\tlearn: 0.1588918\ttest: 0.1438180\tbest: 0.1436712 (27)\ttotal: 3.19s\tremaining: 15s\n", "36:\tlearn: 
0.1588731\ttest: 0.1438125\tbest: 0.1436712 (27)\ttotal: 3.24s\tremaining: 14.6s\n", "37:\tlearn: 0.1588159\ttest: 0.1438478\tbest: 0.1436712 (27)\ttotal: 3.32s\tremaining: 14.4s\n", "38:\tlearn: 0.1588010\ttest: 0.1438314\tbest: 0.1436712 (27)\ttotal: 3.4s\tremaining: 14.2s\n", "39:\tlearn: 0.1586945\ttest: 0.1436908\tbest: 0.1436712 (27)\ttotal: 3.49s\tremaining: 14.1s\n", "40:\tlearn: 0.1586894\ttest: 0.1436654\tbest: 0.1436654 (40)\ttotal: 3.56s\tremaining: 13.9s\n", "41:\tlearn: 0.1582296\ttest: 0.1432243\tbest: 0.1432243 (41)\ttotal: 3.66s\tremaining: 13.9s\n", "42:\tlearn: 0.1582208\ttest: 0.1432282\tbest: 0.1432243 (41)\ttotal: 3.71s\tremaining: 13.6s\n", "43:\tlearn: 0.1582196\ttest: 0.1432159\tbest: 0.1432159 (43)\ttotal: 3.78s\tremaining: 13.4s\n", "44:\tlearn: 0.1579211\ttest: 0.1430249\tbest: 0.1430249 (44)\ttotal: 3.86s\tremaining: 13.3s\n", "45:\tlearn: 0.1578847\ttest: 0.1431223\tbest: 0.1430249 (44)\ttotal: 3.96s\tremaining: 13.2s\n", "46:\tlearn: 0.1578482\ttest: 0.1430689\tbest: 0.1430249 (44)\ttotal: 4.05s\tremaining: 13.2s\n", "47:\tlearn: 0.1578470\ttest: 0.1430681\tbest: 0.1430249 (44)\ttotal: 4.11s\tremaining: 13s\n", "48:\tlearn: 0.1578462\ttest: 0.1430580\tbest: 0.1430249 (44)\ttotal: 4.17s\tremaining: 12.7s\n", "49:\tlearn: 0.1577480\ttest: 0.1432538\tbest: 0.1430249 (44)\ttotal: 4.25s\tremaining: 12.6s\n", "50:\tlearn: 0.1575513\ttest: 0.1433252\tbest: 0.1430249 (44)\ttotal: 4.34s\tremaining: 12.6s\n", "51:\tlearn: 0.1571386\ttest: 0.1428501\tbest: 0.1428501 (51)\ttotal: 4.43s\tremaining: 12.5s\n", "52:\tlearn: 0.1571009\ttest: 0.1428871\tbest: 0.1428501 (51)\ttotal: 4.52s\tremaining: 12.5s\n", "53:\tlearn: 0.1570543\ttest: 0.1428245\tbest: 0.1428245 (53)\ttotal: 4.6s\tremaining: 12.4s\n", "54:\tlearn: 0.1569937\ttest: 0.1428190\tbest: 0.1428190 (54)\ttotal: 4.67s\tremaining: 12.2s\n", "55:\tlearn: 0.1568604\ttest: 0.1428687\tbest: 0.1428190 (54)\ttotal: 4.79s\tremaining: 12.3s\n", "56:\tlearn: 0.1568583\ttest: 0.1428933\tbest: 0.1428190 
(54)\ttotal: 4.84s\tremaining: 12s\n", "57:\tlearn: 0.1566636\ttest: 0.1425799\tbest: 0.1425799 (57)\ttotal: 4.92s\tremaining: 11.9s\n", "58:\tlearn: 0.1566105\ttest: 0.1425226\tbest: 0.1425226 (58)\ttotal: 5s\tremaining: 11.9s\n", "59:\tlearn: 0.1565786\ttest: 0.1425208\tbest: 0.1425208 (59)\ttotal: 5.1s\tremaining: 11.8s\n", "60:\tlearn: 0.1565691\ttest: 0.1425103\tbest: 0.1425103 (60)\ttotal: 5.2s\tremaining: 11.8s\n", "61:\tlearn: 0.1564572\ttest: 0.1425675\tbest: 0.1425103 (60)\ttotal: 5.3s\tremaining: 11.8s\n", "62:\tlearn: 0.1564145\ttest: 0.1426340\tbest: 0.1425103 (60)\ttotal: 5.42s\tremaining: 11.8s\n", "63:\tlearn: 0.1555005\ttest: 0.1421151\tbest: 0.1421151 (63)\ttotal: 5.52s\tremaining: 11.7s\n", "64:\tlearn: 0.1548624\ttest: 0.1414761\tbest: 0.1414761 (64)\ttotal: 5.63s\tremaining: 11.7s\n", "65:\tlearn: 0.1548316\ttest: 0.1415012\tbest: 0.1414761 (64)\ttotal: 5.73s\tremaining: 11.7s\n", "66:\tlearn: 0.1546954\ttest: 0.1414039\tbest: 0.1414039 (66)\ttotal: 5.83s\tremaining: 11.6s\n", "67:\tlearn: 0.1545918\ttest: 0.1412461\tbest: 0.1412461 (67)\ttotal: 5.94s\tremaining: 11.6s\n", "68:\tlearn: 0.1545238\ttest: 0.1412214\tbest: 0.1412214 (68)\ttotal: 6.05s\tremaining: 11.6s\n", "69:\tlearn: 0.1543865\ttest: 0.1412259\tbest: 0.1412214 (68)\ttotal: 6.17s\tremaining: 11.5s\n", "70:\tlearn: 0.1541683\ttest: 0.1414024\tbest: 0.1412214 (68)\ttotal: 6.27s\tremaining: 11.5s\n", "71:\tlearn: 0.1540775\ttest: 0.1414143\tbest: 0.1412214 (68)\ttotal: 6.37s\tremaining: 11.4s\n", "72:\tlearn: 0.1540220\ttest: 0.1413343\tbest: 0.1412214 (68)\ttotal: 6.45s\tremaining: 11.3s\n", "73:\tlearn: 0.1539815\ttest: 0.1413039\tbest: 0.1412214 (68)\ttotal: 6.52s\tremaining: 11.2s\n", "74:\tlearn: 0.1539774\ttest: 0.1413123\tbest: 0.1412214 (68)\ttotal: 6.62s\tremaining: 11.1s\n", "75:\tlearn: 0.1539583\ttest: 0.1413918\tbest: 0.1412214 (68)\ttotal: 6.71s\tremaining: 11s\n", "76:\tlearn: 0.1538207\ttest: 0.1412925\tbest: 0.1412214 (68)\ttotal: 6.8s\tremaining: 10.9s\n", 
"77:\tlearn: 0.1537301\ttest: 0.1412290\tbest: 0.1412214 (68)\ttotal: 6.89s\tremaining: 10.9s\n", "78:\tlearn: 0.1536745\ttest: 0.1411666\tbest: 0.1411666 (78)\ttotal: 6.99s\tremaining: 10.8s\n", "79:\tlearn: 0.1536233\ttest: 0.1412135\tbest: 0.1411666 (78)\ttotal: 7.07s\tremaining: 10.7s\n", "80:\tlearn: 0.1533953\ttest: 0.1413119\tbest: 0.1411666 (78)\ttotal: 7.18s\tremaining: 10.6s\n", "81:\tlearn: 0.1533845\ttest: 0.1412750\tbest: 0.1411666 (78)\ttotal: 7.27s\tremaining: 10.6s\n", "82:\tlearn: 0.1532740\ttest: 0.1413232\tbest: 0.1411666 (78)\ttotal: 7.35s\tremaining: 10.4s\n", "83:\tlearn: 0.1531701\ttest: 0.1416607\tbest: 0.1411666 (78)\ttotal: 7.45s\tremaining: 10.4s\n", "84:\tlearn: 0.1530255\ttest: 0.1416930\tbest: 0.1411666 (78)\ttotal: 7.53s\tremaining: 10.3s\n", "85:\tlearn: 0.1530003\ttest: 0.1416376\tbest: 0.1411666 (78)\ttotal: 7.6s\tremaining: 10.2s\n", "86:\tlearn: 0.1529827\ttest: 0.1415299\tbest: 0.1411666 (78)\ttotal: 7.7s\tremaining: 10.1s\n", "87:\tlearn: 0.1529783\ttest: 0.1415534\tbest: 0.1411666 (78)\ttotal: 7.78s\tremaining: 9.98s\n", "88:\tlearn: 0.1529650\ttest: 0.1415712\tbest: 0.1411666 (78)\ttotal: 7.86s\tremaining: 9.87s\n", "89:\tlearn: 0.1529603\ttest: 0.1415611\tbest: 0.1411666 (78)\ttotal: 7.95s\tremaining: 9.79s\n", "90:\tlearn: 0.1526304\ttest: 0.1415594\tbest: 0.1411666 (78)\ttotal: 8.03s\tremaining: 9.69s\n", "91:\tlearn: 0.1525361\ttest: 0.1417001\tbest: 0.1411666 (78)\ttotal: 8.14s\tremaining: 9.62s\n", "92:\tlearn: 0.1525222\ttest: 0.1416753\tbest: 0.1411666 (78)\ttotal: 8.25s\tremaining: 9.56s\n", "93:\tlearn: 0.1523982\ttest: 0.1417682\tbest: 0.1411666 (78)\ttotal: 8.34s\tremaining: 9.48s\n", "94:\tlearn: 0.1520970\ttest: 0.1417407\tbest: 0.1411666 (78)\ttotal: 8.44s\tremaining: 9.4s\n", "95:\tlearn: 0.1520886\ttest: 0.1416787\tbest: 0.1411666 (78)\ttotal: 8.53s\tremaining: 9.31s\n", "96:\tlearn: 0.1519551\ttest: 0.1416528\tbest: 0.1411666 (78)\ttotal: 8.62s\tremaining: 9.22s\n", "97:\tlearn: 0.1518608\ttest: 
0.1417604\tbest: 0.1411666 (78)\ttotal: 8.72s\tremaining: 9.15s\n", "98:\tlearn: 0.1515918\ttest: 0.1418670\tbest: 0.1411666 (78)\ttotal: 8.81s\tremaining: 9.06s\n", "99:\tlearn: 0.1514836\ttest: 0.1421103\tbest: 0.1411666 (78)\ttotal: 8.88s\tremaining: 8.95s\n", "100:\tlearn: 0.1513803\ttest: 0.1424265\tbest: 0.1411666 (78)\ttotal: 8.94s\tremaining: 8.82s\n", "101:\tlearn: 0.1513654\ttest: 0.1424137\tbest: 0.1411666 (78)\ttotal: 9s\tremaining: 8.7s\n", "102:\tlearn: 0.1513427\ttest: 0.1423719\tbest: 0.1411666 (78)\ttotal: 9.06s\tremaining: 8.57s\n", "103:\tlearn: 0.1511778\ttest: 0.1422654\tbest: 0.1411666 (78)\ttotal: 9.14s\tremaining: 8.48s\n", "104:\tlearn: 0.1511117\ttest: 0.1423467\tbest: 0.1411666 (78)\ttotal: 9.24s\tremaining: 8.41s\n", "105:\tlearn: 0.1510436\ttest: 0.1423221\tbest: 0.1411666 (78)\ttotal: 9.35s\tremaining: 8.33s\n", "106:\tlearn: 0.1508977\ttest: 0.1423137\tbest: 0.1411666 (78)\ttotal: 9.43s\tremaining: 8.24s\n", "107:\tlearn: 0.1508959\ttest: 0.1423254\tbest: 0.1411666 (78)\ttotal: 9.5s\tremaining: 8.13s\n", "108:\tlearn: 0.1508424\ttest: 0.1424438\tbest: 0.1411666 (78)\ttotal: 9.59s\tremaining: 8.04s\n", "109:\tlearn: 0.1508101\ttest: 0.1424697\tbest: 0.1411666 (78)\ttotal: 9.67s\tremaining: 7.95s\n", "110:\tlearn: 0.1507799\ttest: 0.1424306\tbest: 0.1411666 (78)\ttotal: 9.75s\tremaining: 7.85s\n", "111:\tlearn: 0.1507635\ttest: 0.1424197\tbest: 0.1411666 (78)\ttotal: 9.84s\tremaining: 7.76s\n", "112:\tlearn: 0.1507610\ttest: 0.1423976\tbest: 0.1411666 (78)\ttotal: 9.92s\tremaining: 7.67s\n", "113:\tlearn: 0.1507207\ttest: 0.1423904\tbest: 0.1411666 (78)\ttotal: 10s\tremaining: 7.58s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "114:\tlearn: 0.1506871\ttest: 0.1423714\tbest: 0.1411666 (78)\ttotal: 10.1s\tremaining: 7.52s\n", "115:\tlearn: 0.1506365\ttest: 0.1423830\tbest: 0.1411666 (78)\ttotal: 10.2s\tremaining: 7.43s\n", "116:\tlearn: 0.1505736\ttest: 0.1425609\tbest: 0.1411666 (78)\ttotal: 10.3s\tremaining: 7.35s\n", 
"117:\tlearn: 0.1505321\ttest: 0.1425929\tbest: 0.1411666 (78)\ttotal: 10.4s\tremaining: 7.29s\n", "118:\tlearn: 0.1505028\ttest: 0.1425452\tbest: 0.1411666 (78)\ttotal: 10.5s\tremaining: 7.21s\n", "119:\tlearn: 0.1497712\ttest: 0.1424179\tbest: 0.1411666 (78)\ttotal: 10.7s\tremaining: 7.17s\n", "120:\tlearn: 0.1497006\ttest: 0.1423512\tbest: 0.1411666 (78)\ttotal: 10.8s\tremaining: 7.11s\n", "121:\tlearn: 0.1496638\ttest: 0.1423660\tbest: 0.1411666 (78)\ttotal: 11s\tremaining: 7.1s\n", "122:\tlearn: 0.1496092\ttest: 0.1423799\tbest: 0.1411666 (78)\ttotal: 11.1s\tremaining: 7.04s\n", "123:\tlearn: 0.1495761\ttest: 0.1422939\tbest: 0.1411666 (78)\ttotal: 11.2s\tremaining: 6.95s\n", "124:\tlearn: 0.1495648\ttest: 0.1423047\tbest: 0.1411666 (78)\ttotal: 11.3s\tremaining: 6.86s\n", "125:\tlearn: 0.1493551\ttest: 0.1421014\tbest: 0.1411666 (78)\ttotal: 11.5s\tremaining: 6.8s\n", "126:\tlearn: 0.1492764\ttest: 0.1420590\tbest: 0.1411666 (78)\ttotal: 11.6s\tremaining: 6.72s\n", "127:\tlearn: 0.1490840\ttest: 0.1417298\tbest: 0.1411666 (78)\ttotal: 11.7s\tremaining: 6.65s\n", "128:\tlearn: 0.1490683\ttest: 0.1417569\tbest: 0.1411666 (78)\ttotal: 11.8s\tremaining: 6.58s\n", "129:\tlearn: 0.1490611\ttest: 0.1417673\tbest: 0.1411666 (78)\ttotal: 12s\tremaining: 6.53s\n", "130:\tlearn: 0.1490453\ttest: 0.1417709\tbest: 0.1411666 (78)\ttotal: 12.1s\tremaining: 6.45s\n", "131:\tlearn: 0.1490231\ttest: 0.1417575\tbest: 0.1411666 (78)\ttotal: 12.3s\tremaining: 6.39s\n", "132:\tlearn: 0.1489009\ttest: 0.1419888\tbest: 0.1411666 (78)\ttotal: 12.3s\tremaining: 6.29s\n", "133:\tlearn: 0.1486744\ttest: 0.1420246\tbest: 0.1411666 (78)\ttotal: 12.4s\tremaining: 6.2s\n", "134:\tlearn: 0.1485575\ttest: 0.1421380\tbest: 0.1411666 (78)\ttotal: 12.5s\tremaining: 6.11s\n", "135:\tlearn: 0.1485306\ttest: 0.1421969\tbest: 0.1411666 (78)\ttotal: 12.6s\tremaining: 6.01s\n", "136:\tlearn: 0.1483855\ttest: 0.1420489\tbest: 0.1411666 (78)\ttotal: 12.7s\tremaining: 5.91s\n", "137:\tlearn: 
0.1483428\ttest: 0.1421026\tbest: 0.1411666 (78)\ttotal: 12.8s\tremaining: 5.83s\n", "138:\tlearn: 0.1483134\ttest: 0.1420898\tbest: 0.1411666 (78)\ttotal: 12.9s\tremaining: 5.72s\n", "139:\tlearn: 0.1483052\ttest: 0.1420878\tbest: 0.1411666 (78)\ttotal: 13s\tremaining: 5.63s\n", "140:\tlearn: 0.1482157\ttest: 0.1420241\tbest: 0.1411666 (78)\ttotal: 13.1s\tremaining: 5.53s\n", "141:\tlearn: 0.1481112\ttest: 0.1421268\tbest: 0.1411666 (78)\ttotal: 13.2s\tremaining: 5.45s\n", "142:\tlearn: 0.1479973\ttest: 0.1422019\tbest: 0.1411666 (78)\ttotal: 13.3s\tremaining: 5.36s\n", "143:\tlearn: 0.1479948\ttest: 0.1422188\tbest: 0.1411666 (78)\ttotal: 13.4s\tremaining: 5.28s\n", "144:\tlearn: 0.1479705\ttest: 0.1421976\tbest: 0.1411666 (78)\ttotal: 13.6s\tremaining: 5.22s\n", "145:\tlearn: 0.1479037\ttest: 0.1423085\tbest: 0.1411666 (78)\ttotal: 13.8s\tremaining: 5.16s\n", "146:\tlearn: 0.1478758\ttest: 0.1424303\tbest: 0.1411666 (78)\ttotal: 13.9s\tremaining: 5.08s\n", "147:\tlearn: 0.1478342\ttest: 0.1424301\tbest: 0.1411666 (78)\ttotal: 14s\tremaining: 4.98s\n", "148:\tlearn: 0.1478086\ttest: 0.1423590\tbest: 0.1411666 (78)\ttotal: 14.1s\tremaining: 4.89s\n", "149:\tlearn: 0.1477729\ttest: 0.1422865\tbest: 0.1411666 (78)\ttotal: 14.2s\tremaining: 4.8s\n", "150:\tlearn: 0.1474847\ttest: 0.1422271\tbest: 0.1411666 (78)\ttotal: 14.3s\tremaining: 4.71s\n", "151:\tlearn: 0.1474169\ttest: 0.1423728\tbest: 0.1411666 (78)\ttotal: 14.4s\tremaining: 4.63s\n", "152:\tlearn: 0.1472865\ttest: 0.1422631\tbest: 0.1411666 (78)\ttotal: 14.6s\tremaining: 4.54s\n", "153:\tlearn: 0.1472423\ttest: 0.1422256\tbest: 0.1411666 (78)\ttotal: 14.7s\tremaining: 4.46s\n", "154:\tlearn: 0.1471948\ttest: 0.1421481\tbest: 0.1411666 (78)\ttotal: 14.8s\tremaining: 4.37s\n", "155:\tlearn: 0.1468854\ttest: 0.1419873\tbest: 0.1411666 (78)\ttotal: 14.9s\tremaining: 4.28s\n", "156:\tlearn: 0.1468790\ttest: 0.1420101\tbest: 0.1411666 (78)\ttotal: 15s\tremaining: 4.18s\n", "157:\tlearn: 0.1468687\ttest: 
0.1419619\tbest: 0.1411666 (78)\ttotal: 15.1s\tremaining: 4.09s\n", "158:\tlearn: 0.1468658\ttest: 0.1419551\tbest: 0.1411666 (78)\ttotal: 15.2s\tremaining: 3.99s\n", "159:\tlearn: 0.1468288\ttest: 0.1420261\tbest: 0.1411666 (78)\ttotal: 15.3s\tremaining: 3.89s\n", "160:\tlearn: 0.1467160\ttest: 0.1419321\tbest: 0.1411666 (78)\ttotal: 15.4s\tremaining: 3.8s\n", "161:\tlearn: 0.1467084\ttest: 0.1419409\tbest: 0.1411666 (78)\ttotal: 15.5s\tremaining: 3.7s\n", "162:\tlearn: 0.1464073\ttest: 0.1416654\tbest: 0.1411666 (78)\ttotal: 15.7s\tremaining: 3.61s\n", "163:\tlearn: 0.1462285\ttest: 0.1416156\tbest: 0.1411666 (78)\ttotal: 15.8s\tremaining: 3.52s\n", "164:\tlearn: 0.1460339\ttest: 0.1416086\tbest: 0.1411666 (78)\ttotal: 15.9s\tremaining: 3.42s\n", "165:\tlearn: 0.1460218\ttest: 0.1416287\tbest: 0.1411666 (78)\ttotal: 16s\tremaining: 3.33s\n", "166:\tlearn: 0.1460092\ttest: 0.1416437\tbest: 0.1411666 (78)\ttotal: 16.1s\tremaining: 3.23s\n", "167:\tlearn: 0.1456841\ttest: 0.1418764\tbest: 0.1411666 (78)\ttotal: 16.2s\tremaining: 3.13s\n", "168:\tlearn: 0.1456011\ttest: 0.1419543\tbest: 0.1411666 (78)\ttotal: 16.3s\tremaining: 3.04s\n", "169:\tlearn: 0.1455905\ttest: 0.1420247\tbest: 0.1411666 (78)\ttotal: 16.4s\tremaining: 2.94s\n", "170:\tlearn: 0.1455581\ttest: 0.1420612\tbest: 0.1411666 (78)\ttotal: 16.5s\tremaining: 2.85s\n", "171:\tlearn: 0.1455441\ttest: 0.1420924\tbest: 0.1411666 (78)\ttotal: 16.6s\tremaining: 2.75s\n", "172:\tlearn: 0.1454990\ttest: 0.1421376\tbest: 0.1411666 (78)\ttotal: 16.7s\tremaining: 2.65s\n", "173:\tlearn: 0.1453813\ttest: 0.1420980\tbest: 0.1411666 (78)\ttotal: 16.9s\tremaining: 2.56s\n", "174:\tlearn: 0.1451034\ttest: 0.1420903\tbest: 0.1411666 (78)\ttotal: 17s\tremaining: 2.46s\n", "175:\tlearn: 0.1450984\ttest: 0.1421000\tbest: 0.1411666 (78)\ttotal: 17s\tremaining: 2.36s\n", "176:\tlearn: 0.1449863\ttest: 0.1421289\tbest: 0.1411666 (78)\ttotal: 17.1s\tremaining: 2.26s\n", "177:\tlearn: 0.1447876\ttest: 0.1420145\tbest: 0.1411666 
(78)\ttotal: 17.2s\tremaining: 2.16s\n", "178:\tlearn: 0.1447701\ttest: 0.1420666\tbest: 0.1411666 (78)\ttotal: 17.3s\tremaining: 2.06s\n", "179:\tlearn: 0.1446955\ttest: 0.1420109\tbest: 0.1411666 (78)\ttotal: 17.4s\tremaining: 1.96s\n", "180:\tlearn: 0.1446544\ttest: 0.1420329\tbest: 0.1411666 (78)\ttotal: 17.5s\tremaining: 1.86s\n", "181:\tlearn: 0.1445431\ttest: 0.1421508\tbest: 0.1411666 (78)\ttotal: 17.6s\tremaining: 1.76s\n", "182:\tlearn: 0.1442690\ttest: 0.1421298\tbest: 0.1411666 (78)\ttotal: 17.7s\tremaining: 1.67s\n", "183:\tlearn: 0.1441832\ttest: 0.1421534\tbest: 0.1411666 (78)\ttotal: 17.8s\tremaining: 1.57s\n", "184:\tlearn: 0.1440706\ttest: 0.1423811\tbest: 0.1411666 (78)\ttotal: 17.9s\tremaining: 1.47s\n", "185:\tlearn: 0.1439856\ttest: 0.1425963\tbest: 0.1411666 (78)\ttotal: 18s\tremaining: 1.37s\n", "186:\tlearn: 0.1439761\ttest: 0.1426506\tbest: 0.1411666 (78)\ttotal: 18.1s\tremaining: 1.28s\n", "187:\tlearn: 0.1439714\ttest: 0.1426176\tbest: 0.1411666 (78)\ttotal: 18.2s\tremaining: 1.18s\n", "188:\tlearn: 0.1439492\ttest: 0.1426109\tbest: 0.1411666 (78)\ttotal: 18.4s\tremaining: 1.08s\n", "189:\tlearn: 0.1439454\ttest: 0.1426164\tbest: 0.1411666 (78)\ttotal: 18.4s\tremaining: 984ms\n", "190:\tlearn: 0.1439413\ttest: 0.1426552\tbest: 0.1411666 (78)\ttotal: 18.5s\tremaining: 885ms\n", "191:\tlearn: 0.1438359\ttest: 0.1427087\tbest: 0.1411666 (78)\ttotal: 18.6s\tremaining: 787ms\n", "192:\tlearn: 0.1438287\ttest: 0.1427033\tbest: 0.1411666 (78)\ttotal: 18.7s\tremaining: 689ms\n", "193:\tlearn: 0.1437879\ttest: 0.1426357\tbest: 0.1411666 (78)\ttotal: 18.8s\tremaining: 590ms\n", "194:\tlearn: 0.1437860\ttest: 0.1426468\tbest: 0.1411666 (78)\ttotal: 18.9s\tremaining: 492ms\n", "195:\tlearn: 0.1437606\ttest: 0.1426516\tbest: 0.1411666 (78)\ttotal: 19s\tremaining: 393ms\n", "196:\tlearn: 0.1437591\ttest: 0.1426400\tbest: 0.1411666 (78)\ttotal: 19.1s\tremaining: 294ms\n", "197:\tlearn: 0.1437577\ttest: 0.1426297\tbest: 0.1411666 (78)\ttotal: 
19.2s\tremaining: 196ms\n", "198:\tlearn: 0.1437447\ttest: 0.1426317\tbest: 0.1411666 (78)\ttotal: 19.3s\tremaining: 98ms\n", "199:\tlearn: 0.1437230\ttest: 0.1426660\tbest: 0.1411666 (78)\ttotal: 19.4s\tremaining: 0us\n", "\n", "bestTest = 0.1411666436\n", "bestIteration = 78\n", "\n", "Shrink model to first 79 iterations.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = CatBoostClassifier(\n", " iterations=200,\n", " save_snapshot=True,\n", " snapshot_file='snapshot.bkp',\n", " snapshot_interval=1,\n", " random_seed=43\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " eval_set=(X_validation, y_validation),\n", " cat_features=cat_features,\n", " verbose=True,\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model predictions" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "? model.predict_proba" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0.0688 0.9312]\n", " [0.008 0.992 ]\n", " [0.0068 0.9932]\n", " ...\n", " [0.0129 0.9871]\n", " [0.0184 0.9816]\n", " [0.0273 0.9727]]\n" ] } ], "source": [ "print(model.predict_proba(data=X_validation))" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1. 1. 1. ... 1. 1. 1.]\n" ] } ], "source": [ "print(model.predict(data=X_validation))" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2.6056 4.8177 4.9809 ... 
4.3367 3.9791 3.5714]\n" ] } ], "source": [ "raw_pred = model.predict(data=X_validation, prediction_type='RawFormulaVal')\n", "print(raw_pred)" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.9312 0.992 0.9932 ... 0.9871 0.9816 0.9727]\n" ] } ], "source": [ "import math\n", "# Applying the sigmoid to the raw formula value gives the probability of class 1\n", "def sigmoid(x):\n", " return 1 / (1 + math.exp(-x))\n", "probabilities = [sigmoid(x) for x in raw_pred]\n", "print(np.array(probabilities))" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0.0688 0.9312]\n", " [0.008 0.992 ]\n", " [0.0068 0.9932]\n", " ...\n", " [0.0129 0.9871]\n", " [0.0184 0.9816]\n", " [0.0273 0.9727]]\n" ] } ], "source": [ "# FeaturesData requires categorical feature values of type str, stored in an object-dtype array\n", "X_prepared = X_validation.values.astype(str).astype(object)\n", "\n", "fast_predictions = model.predict_proba(data=FeaturesData(cat_feature_data=X_prepared, \n", " cat_feature_names=list(X_validation)))\n", "\n", "print(fast_predictions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Staged prediction" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration 0, predictions:\n", "[[0.3065 0.6935]\n", " [0.2317 0.7683]\n", " [0.2317 0.7683]\n", " ...\n", " [0.2317 0.7683]\n", " [0.2317 0.7683]\n", " [0.2746 0.7254]]\n", "Iteration 1, predictions:\n", "[[0.3002 0.6998]\n", " [0.1762 0.8238]\n", " [0.1422 0.8578]\n", " ...\n", " [0.1422 0.8578]\n", " [0.1916 0.8084]\n", " [0.2116 0.7884]]\n", "Iteration 2, predictions:\n", "[[0.3307 0.6693]\n", " [0.13 0.87 ]\n", " [0.1038 0.8962]\n", " ...\n", " [0.1454 0.8546]\n", " [0.1956 0.8044]\n", " [0.1938 0.8062]]\n" ] } ], "source": [ "predictions_gen = model.staged_predict_proba(data=X_validation, ntree_start=2, ntree_end=8, eval_period=2)\n", "for iteration, 
predictions in enumerate(predictions_gen):\n", " print('Iteration ' + str(iteration) + ', predictions:')\n", " print(predictions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Solving a multiclass classification problem" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1536e93770b14eb48d913a0209269c87", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=150,\n", " random_seed=43,\n", " loss_function='MultiClass'\n", " #loss_function='MultiClassOneVsAll'\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Metric evaluation on a new dataset" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0:\tlearn: 0.6569860\ttotal: 53.2ms\tremaining: 10.6s\n", "50:\tlearn: 0.1950260\ttotal: 3.33s\tremaining: 9.74s\n", "100:\tlearn: 0.1700584\ttotal: 6.72s\tremaining: 6.58s\n", "150:\tlearn: 0.1641016\ttotal: 10.6s\tremaining: 3.42s\n", "199:\tlearn: 0.1604074\ttotal: 14.1s\tremaining: 0us\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = CatBoostClassifier(\n", " random_seed=63,\n", " iterations=200,\n", " learning_rate=0.03,\n", ")\n", 
"model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " verbose=50\n", ")" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "metrics = model.eval_metrics(data=pool1, \n", " metrics=['Logloss','AUC'],\n", " ntree_start=0,\n", " ntree_end=0, \n", " eval_period=1,\n", " plot=True)" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AUC values:\n", "[0.4998 0.538 0.5504 0.5888 0.6503 0.6487 0.6487 0.6601 0.6601 0.6612\n", " 0.6614 0.67 0.6699 0.6697 0.6697 0.6698 0.6698 0.6698 0.6698 0.7265\n", " 0.7338 0.7376 0.7478 0.7478 0.7506 0.7591 0.7642 0.7877 0.8148 0.8265\n", " 0.8353 0.8459 0.8533 0.8564 0.8573 0.8748 0.8874 0.893 0.8948 0.898\n", " 0.9003 0.9052 0.9122 0.9159 0.92 0.9204 0.9209 0.9246 0.9262 0.9266\n", " 0.9279 0.9303 0.9311 0.9312 0.9327 0.9329 0.9337 0.9341 0.9341 0.9349\n", " 0.9354 0.9365 0.9391 0.9411 0.9424 0.9435 0.9446 0.9458 0.9463 0.9471\n", " 0.9478 0.9481 0.9482 0.9483 0.9486 0.9497 0.951 0.9514 0.9515 0.9519\n", " 0.9522 0.9526 0.953 0.9537 0.9543 0.9544 0.9545 0.9548 0.9548 0.9548\n", " 0.9548 0.9548 0.955 0.9551 0.9553 0.9554 0.9557 0.9557 0.9557 0.9557\n", " 0.956 0.9566 0.9566 0.9569 0.9568 0.957 0.9569 0.9572 0.9573 0.9575\n", " 0.9576 0.9576 0.9577 0.958 0.9587 0.9594 0.9595 0.9594 0.9601 0.9609\n", " 0.9609 0.9608 0.9608 0.9612 0.9616 0.9618 0.9622 0.9623 0.9623 0.9625\n", " 0.9625 0.9628 0.9629 0.9629 0.9629 0.963 0.9631 0.9632 0.9635 0.9636\n", " 0.9637 0.9638 0.9638 0.964 0.9641 0.9643 0.9645 0.9645 0.9645 0.9647\n", " 0.9647 0.9647 0.9647 0.9648 0.9648 0.9648 0.965 0.9652 0.9652 0.9653\n", " 0.9653 0.9654 0.9655 0.9655 0.9655 0.9656 0.9656 0.9657 0.9657 0.9657\n", " 0.9658 0.9658 0.9659 0.9659 0.966 0.966 0.966 0.966 0.966 0.966\n", " 0.966 0.966 0.9661 0.9662 
0.9663 0.9663 0.9663 0.9663 0.9663 0.9663\n", " 0.9665 0.9664 0.9666 0.9666 0.9666 0.9666 0.9666 0.9667 0.9667 0.9666]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "79673ab90b80409f9098024a23d609ec", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print('AUC values:')\n", "print(np.array(metrics['AUC']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Saving the model" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "my_best_model = CatBoostClassifier(iterations=10)\n", "my_best_model.fit(\n", " X_train, y_train,\n", " eval_set=(X_validation, y_validation),\n", " cat_features=cat_features,\n", " verbose=False\n", ")" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [], "source": [ "my_best_model.save_model('catboost_model.bin')" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'iterations': 10, 'loss_function': 'Logloss', 'logging_level': 'Silent'}\n", "17379826207872\n" ] } ], "source": [ "my_best_model.load_model('catboost_model.bin')\n", "print(my_best_model.get_params())\n", "print(my_best_model.random_seed_)" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [], "source": [ "my_best_model.save_model('catboost_model.json', format='json')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature importances" ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0:\tlearn: 0.6569860\ttotal: 56.9ms\tremaining: 11.3s\n", "50:\tlearn: 0.1950260\ttotal: 4.06s\tremaining: 
11.9s\n", "100:\tlearn: 0.1700584\ttotal: 7.18s\tremaining: 7.04s\n", "150:\tlearn: 0.1641016\ttotal: 10.8s\tremaining: 3.5s\n", "199:\tlearn: 0.1604074\ttotal: 14.3s\tremaining: 0us\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = CatBoostClassifier(\n", " random_seed=63,\n", " iterations=200,\n", " learning_rate=0.03)\n", "\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " verbose=50\n", ")" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [], "source": [ "fstrs = model.get_feature_importance(prettified=True)" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{b'MGR_ID': 18.398056223683028,\n", " b'RESOURCE': 21.27625707750949,\n", " b'ROLE_CODE': 11.943230282191124,\n", " b'ROLE_DEPTNAME': 15.252806962601875,\n", " b'ROLE_FAMILY': 2.4789173515872176,\n", " b'ROLE_FAMILY_DESC': 9.984073192415533,\n", " b'ROLE_ROLLUP_1': 2.6278918788673633,\n", " b'ROLE_ROLLUP_2': 13.582536028250486,\n", " b'ROLE_TITLE': 4.456231002893895}" ] }, "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ "{feature_name: value for feature_name, value in fstrs}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## SHAP values" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SHAP library: [github.com/slundberg/shap](https://github.com/slundberg/shap)" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [], "source": [ "def object_predictions(model, obj):\n", " print('Probability of class 1 = {:.4f}'.format(model.predict_proba([obj])[0][1]))\n", " print('Formula raw prediction = {:.4f}'.format(model.predict([obj], prediction_type='RawFormulaVal')[0]))\n", " " ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The model has complex ctrs, so the SHAP values will 
be calculated approximately.\n" ] } ], "source": [ "import shap\n", "explainer = shap.TreeExplainer(model)\n", "shap_values = explainer.shap_values(Pool(X, y, cat_features=cat_features))\n" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3.345740996864643" ] }, "execution_count": 82, "metadata": {}, "output_type": "execute_result" } ], "source": [ "explainer.expected_value" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Probability of class 1 = 0.9798\n", "Formula raw prediction = 3.8820\n" ] } ], "source": [ "object_predictions(model, X.iloc[3,:])" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", "
\n", " Visualization omitted, Javascript library not loaded!
\n", " Have you run `initjs()` in this notebook? If this notebook was from another\n", " user you must also trust this notebook (File -> Trust notebook). If you are viewing\n", " this notebook on github the Javascript has been stripped for security. If you are using\n", " JupyterLab this error is because a JupyterLab extension has not yet been written.\n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ "shap.initjs()\n", "shap.force_plot(explainer.expected_value, shap_values[3,:], X.iloc[3,:])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The explanation above shows how each feature contributes to pushing the model output from the base value (the average model output over the training dataset) to the output for this object. For this model the output is the raw formula value (log-odds), not a probability. Features pushing the prediction higher are shown in red; those pushing it lower are shown in blue." ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Probability of class 1 = 0.6404\n", "Formula raw prediction = 0.5772\n" ] } ], "source": [ "object_predictions(model, X.iloc[91,:])" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", "
\n", " Visualization omitted, Javascript library not loaded!
\n", " Have you run `initjs()` in this notebook? If this notebook was from another\n", " user you must also trust this notebook (File -> Trust notebook). If you are viewing\n", " this notebook on github the Javascript has been stripped for security. If you are using\n", " JupyterLab this error is because a JupyterLab extension has not yet been written.\n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import shap\n", "shap.initjs()\n", "shap.force_plot(explainer.expected_value, shap_values[91,:], X.iloc[91,:])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get an overview of which features are most important for a model, we can plot the SHAP values of every feature for every sample. The plot below sorts features by the sum of SHAP value magnitudes over all samples and uses SHAP values to show the distribution of the impact each feature has on the model output. The color represents the feature value (red high, blue low)." ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "shap.summary_plot(shap_values, X)" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "shap.summary_plot(shap_values, X, plot_type=\"bar\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameter tuning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training speed" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a92eadf3c1c044dea4c7e513a9cc0443", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 89, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fast_model = CatBoostClassifier(\n", " random_seed=63,\n", " iterations=150,\n", " learning_rate=0.01,\n", " boosting_type='Plain',\n", " bootstrap_type='Bernoulli',\n", " subsample=0.5,\n", " rsm=0.5,\n", " one_hot_max_size=20,\n", " leaf_estimation_iterations=2,\n", " max_ctr_complexity=1)\n", "\n", "fast_model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Accuracy" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e1847556d61748cea1e37bb8c30ccb1a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": 
[ "" ] }, "execution_count": 90, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tuned_model = CatBoostClassifier(\n", " random_seed=63,\n", " iterations=1000,\n", " learning_rate=0.03,\n", " l2_leaf_reg=3,\n", " bagging_temperature=1,\n", " random_strength=1,\n", " one_hot_max_size=2,\n", " leaf_estimation_method='Newton'\n", ")\n", "tuned_model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " verbose=False,\n", " eval_set=(X_validation, y_validation),\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training the model after parameter tuning" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0:\tlearn: 0.6431867\ttotal: 103ms\tremaining: 1m 51s\n", "100:\tlearn: 0.1590533\ttotal: 8.58s\tremaining: 1m 23s\n", "200:\tlearn: 0.1517316\ttotal: 16.1s\tremaining: 1m 10s\n", "300:\tlearn: 0.1487795\ttotal: 23.5s\tremaining: 1m\n", "400:\tlearn: 0.1468414\ttotal: 31.3s\tremaining: 53.2s\n", "500:\tlearn: 0.1448880\ttotal: 39.4s\tremaining: 45.7s\n", "600:\tlearn: 0.1433889\ttotal: 48.4s\tremaining: 38.8s\n", "700:\tlearn: 0.1421200\ttotal: 58.1s\tremaining: 31.7s\n", "800:\tlearn: 0.1411938\ttotal: 1m 6s\tremaining: 23.4s\n", "900:\tlearn: 0.1401819\ttotal: 1m 15s\tremaining: 15.2s\n", "1000:\tlearn: 0.1394250\ttotal: 1m 25s\tremaining: 6.99s\n", "1082:\tlearn: 0.1388355\ttotal: 1m 32s\tremaining: 0us\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 91, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_model = CatBoostClassifier(\n", " random_seed=63,\n", " iterations=int(tuned_model.tree_count_ * 1.2),\n", ")\n", "best_model.fit(\n", " X, y,\n", " cat_features=cat_features,\n", " verbose=100\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Calculate predictions for the contest" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, 
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predictions:\n", "[[0.4024 0.5976]\n", " [0.015 0.985 ]\n", " [0.0101 0.9899]\n", " ...\n", " [0.0089 0.9911]\n", " [0.0496 0.9504]\n", " [0.0135 0.9865]]\n" ] } ], "source": [ "X_test = test_df.drop('id', axis=1)\n", "test_pool = Pool(data=X_test, cat_features=cat_features)\n", "contest_predictions = best_model.predict_proba(test_pool)\n", "print('Predictions:')\n", "print(contest_predictions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare the submission" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [], "source": [ "# Write one line per test object: its id and the predicted probability of class 1\n", "with open('submit.csv', 'w') as f:\n", "    f.write('Id,Action\\n')\n", "    for idx in range(len(contest_predictions)):\n", "        line = str(test_df['id'][idx]) + ',' + str(contest_predictions[idx][1]) + '\\n'\n", "        f.write(line)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Submit your solution [here](https://www.kaggle.com/c/amazon-employee-access-challenge/submit).\n", "Good luck!!!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" }, "widgets": { "state": { "1057714ebc614324aa3ba2cf69408966": { "views": [ { "cell_index": 17 } ] }, "8381e9eed05f4a03905ae8a56d7ab4ea": { "views": [ { "cell_index": 48 } ] }, "f49684e8c5c44241bfe2c7f577f5cb41": { "views": [ { "cell_index": 53 } ] } }, "version": "1.2.0" } }, "nbformat": 4, "nbformat_minor": 2 }