{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# catboost for rust tutorial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q numpy pandas catboost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CatBoost version 0.14.2\n",
      "NumPy version 1.16.3\n",
      "Pandas version 0.24.2\n"
     ]
    }
   ],
   "source": [
    "import catboost as cb\n",
    "import catboost.datasets as cbd\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# print module versions for reproducibility\n",
    "print('CatBoost version {}'.format(cb.__version__))\n",
    "print('NumPy version {}'.format(np.__version__))\n",
    "print('Pandas version {}'.format(pd.__version__))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " Download \"Adult Data Set\" [1] from UCI Machine Learning Repository.\n",
      "\n",
      " Will return two pandas.DataFrame-s, first with train part (adult.data) and second with test part\n",
      " (adult.test) of the dataset.\n",
      "\n",
      " [1]: https://archive.ics.uci.edu/ml/datasets/Adult\n",
      " \n"
     ]
    }
   ],
   "source": [
    "# We are going to use UCI Adult Data Set because it has both numerical and categorical \n",
    "# features and also has missing features.\n",
    "print(cbd.adult.__doc__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_fixed_adult():\n",
    "    \"\"\"Load the UCI Adult Data Set and split it into features and labels.\n",
    "\n",
    "    Missing values in object-typed (categorical) columns of both parts are replaced\n",
    "    with the string 'nan' before splitting.\n",
    "\n",
    "    Returns a tuple (X_train, y_train, X_test, y_test); the 'income' column is the\n",
    "    label and every other column is a feature.\n",
    "    \"\"\"\n",
    "    train, test = cbd.adult()\n",
    "\n",
    "    # CatBoost doesn't support pandas.DataFrame missing values for categorical features out\n",
    "    # of the box (see issue #571 on GitHub or issue MLTOOLS-2785 in internal tracker). So\n",
    "    # we have to replace them with some designated string manually.\n",
    "    for dataset in (train, test):\n",
    "        # Compare against the builtin `object`: the `np.object` alias was deprecated in\n",
    "        # NumPy 1.20 and removed in NumPy 1.24.\n",
    "        for name in (name for name, dtype in dict(dataset.dtypes).items() if dtype == object):\n",
    "            # Assign the filled column back instead of calling `fillna(..., inplace=True)` on\n",
    "            # the selected column: the in-place call may act on a temporary copy and leave\n",
    "            # `dataset` unchanged.\n",
    "            dataset[name] = dataset[name].fillna('nan')\n",
    "\n",
    "    X_train, y_train = train.drop('income', axis=1), train.income\n",
    "    X_test, y_test = test.drop('income', axis=1), test.income\n",
    "    return X_train, y_train, X_test, y_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, y_train, _, _ = get_fixed_adult()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "
\n", " | age | \n", "workclass | \n", "fnlwgt | \n", "education | \n", "education-num | \n", "marital-status | \n", "occupation | \n", "relationship | \n", "race | \n", "sex | \n", "capital-gain | \n", "capital-loss | \n", "hours-per-week | \n", "native-country | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "39.0 | \n", "State-gov | \n", "77516.0 | \n", "Bachelors | \n", "13.0 | \n", "Never-married | \n", "Adm-clerical | \n", "Not-in-family | \n", "White | \n", "Male | \n", "2174.0 | \n", "0.0 | \n", "40.0 | \n", "United-States | \n", "
1 | \n", "50.0 | \n", "Self-emp-not-inc | \n", "83311.0 | \n", "Bachelors | \n", "13.0 | \n", "Married-civ-spouse | \n", "Exec-managerial | \n", "Husband | \n", "White | \n", "Male | \n", "0.0 | \n", "0.0 | \n", "13.0 | \n", "United-States | \n", "
2 | \n", "38.0 | \n", "Private | \n", "215646.0 | \n", "HS-grad | \n", "9.0 | \n", "Divorced | \n", "Handlers-cleaners | \n", "Not-in-family | \n", "White | \n", "Male | \n", "0.0 | \n", "0.0 | \n", "40.0 | \n", "United-States | \n", "
3 | \n", "53.0 | \n", "Private | \n", "234721.0 | \n", "11th | \n", "7.0 | \n", "Married-civ-spouse | \n", "Handlers-cleaners | \n", "Husband | \n", "Black | \n", "Male | \n", "0.0 | \n", "0.0 | \n", "40.0 | \n", "United-States | \n", "
4 | \n", "28.0 | \n", "Private | \n", "338409.0 | \n", "Bachelors | \n", "13.0 | \n", "Married-civ-spouse | \n", "Prof-specialty | \n", "Wife | \n", "Black | \n", "Female | \n", "0.0 | \n", "0.0 | \n", "40.0 | \n", "Cuba | \n", "