{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CatBoost and CoreML tutorial — Titanic dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"CatBoost supports exporting models to Apple's [CoreML](https://developer.apple.com/machine-learning/) format, which lets you easily embed ML models into applications on Apple's platforms.\n",
"\n",
"Currently, only models that use float and one-hot encoded features can be exported.\n",
"\n",
"This tutorial demonstrates how to export a CatBoost model trained on the [Titanic](https://www.kaggle.com/c/titanic/data) dataset to a CoreML model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get the Titanic dataset:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from catboost import Pool, CatBoost\n",
"from catboost.datasets import titanic"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_df = titanic()[0]\n",
"X, y = train_df.drop('Survived', axis=1), train_df.Survived"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>PassengerId</th><th>Pclass</th><th>Name</th><th>Sex</th><th>Age</th><th>SibSp</th><th>Parch</th><th>Ticket</th><th>Fare</th><th>Cabin</th><th>Embarked</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>1</td><td>3</td><td>Braund, Mr. Owen Harris</td><td>male</td><td>22.0</td><td>1</td><td>0</td><td>A/5 21171</td><td>7.2500</td><td>NaN</td><td>S</td></tr>\n",
"    <tr><th>1</th><td>2</td><td>1</td><td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td><td>female</td><td>38.0</td><td>1</td><td>0</td><td>PC 17599</td><td>71.2833</td><td>C85</td><td>C</td></tr>\n",
"    <tr><th>2</th><td>3</td><td>3</td><td>Heikkinen, Miss. Laina</td><td>female</td><td>26.0</td><td>0</td><td>0</td><td>STON/O2. 3101282</td><td>7.9250</td><td>NaN</td><td>S</td></tr>\n",
"    <tr><th>3</th><td>4</td><td>1</td><td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td><td>female</td><td>35.0</td><td>1</td><td>0</td><td>113803</td><td>53.1000</td><td>C123</td><td>S</td></tr>\n",
"    <tr><th>4</th><td>5</td><td>3</td><td>Allen, Mr. William Henry</td><td>male</td><td>35.0</td><td>0</td><td>0</td><td>373450</td><td>8.0500</td><td>NaN</td><td>S</td></tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
" PassengerId Pclass Name \\\n",
"0 1 3 Braund, Mr. Owen Harris \n",
"1 2 1 Cumings, Mrs. John Bradley (Florence Briggs Th... \n",
"2 3 3 Heikkinen, Miss. Laina \n",
"3 4 1 Futrelle, Mrs. Jacques Heath (Lily May Peel) \n",
"4 5 3 Allen, Mr. William Henry \n",
"\n",
" Sex Age SibSp Parch Ticket Fare Cabin Embarked \n",
"0 male 22.0 1 0 A/5 21171 7.2500 NaN S \n",
"1 female 38.0 1 0 PC 17599 71.2833 C85 C \n",
"2 female 26.0 0 0 STON/O2. 3101282 7.9250 NaN S \n",
"3 female 35.0 1 0 113803 53.1000 C123 S \n",
"4 male 35.0 0 0 373450 8.0500 NaN S "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Drop the `Name` and `Ticket` features: nearly every value of these features belongs to a single passenger, so one-hot encoding them would only lead to overfitting."
]
},
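{
"cell_type": "markdown",
"metadata": {},
"source": [
"The argument above can be checked by counting distinct values per column. The snippet below is a minimal sketch on a hypothetical three-row frame whose values are copied from the table above; `sample` and `unique_ratio` are illustrative names, not part of the tutorial. A ratio of 1.0 means every row has its own value, so a one-hot encoding of that column would simply memorize individual passengers.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Hypothetical mini-sample: three rows taken from the table shown above.\n",
"sample = pd.DataFrame({\n",
"    'Name': ['Braund, Mr. Owen Harris', 'Heikkinen, Miss. Laina', 'Allen, Mr. William Henry'],\n",
"    'Ticket': ['A/5 21171', 'STON/O2. 3101282', '373450'],\n",
"    'Sex': ['male', 'female', 'male'],\n",
"})\n",
"\n",
"# Share of rows that carry a distinct value, per column.\n",
"unique_ratio = sample.nunique() / len(sample)\n",
"print(unique_ratio)\n",
"```\n",
"\n",
"In the notebook itself, the same check would be `X.nunique() / len(X)`."
]
},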
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"X.drop(['Name', 'Ticket'], axis=1, inplace=True)\n",
"categorical_features_indices = np.where(X.dtypes != float)[0]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Treat every non-float column as categorical\n",
"is_cat = (X.dtypes != float)\n",
"for feature, feat_is_cat in is_cat.to_dict().items():\n",
"    if feat_is_cat:\n",
"        # Categorical values must not be NaN, so use a placeholder string\n",
"        X[feature] = X[feature].fillna(\"NAN\")\n",
"\n",
"cat_features_index = np.where(is_cat)[0]"
]
},
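{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above fills missing values only in the categorical columns, because CatBoost does not accept NaN as a categorical value; float columns such as `Age` can keep their NaN entries, which CatBoost handles on its own. A minimal sketch of the same pattern on a hypothetical two-row frame (`toy` is an illustrative name, not part of the tutorial):\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical data: one float column and one string column, each with a gap.\n",
"toy = pd.DataFrame({'Age': [22.0, np.nan], 'Cabin': [np.nan, 'C85']})\n",
"\n",
"for column in toy.columns:\n",
"    if toy[column].dtype != float:\n",
"        # Assign the result back rather than filling a column in place.\n",
"        toy[column] = toy[column].fillna('NAN')\n",
"\n",
"print(toy)  # Age keeps its NaN, Cabin gets the 'NAN' placeholder\n",
"```"
]
},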
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>PassengerId</th><th>Pclass</th><th>Sex</th><th>Age</th><th>SibSp</th><th>Parch</th><th>Fare</th><th>Cabin</th><th>Embarked</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>1</td><td>3</td><td>male</td><td>22.0</td><td>1</td><td>0</td><td>7.2500</td><td>NAN</td><td>S</td></tr>\n",
"    <tr><th>1</th><td>2</td><td>1</td><td>female</td><td>38.0</td><td>1</td><td>0</td><td>71.2833</td><td>C85</td><td>C</td></tr>\n",
"    <tr><th>2</th><td>3</td><td>3</td><td>female</td><td>26.0</td><td>0</td><td>0</td><td>7.9250</td><td>NAN</td><td>S</td></tr>\n",
"    <tr><th>3</th><td>4</td><td>1</td><td>female</td><td>35.0</td><td>1</td><td>0</td><td>53.1000</td><td>C123</td><td>S</td></tr>\n",
"    <tr><th>4</th><td>5</td><td>3</td><td>male</td><td>35.0</td><td>0</td><td>0</td><td>8.0500</td><td>NAN</td><td>S</td></tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
" PassengerId Pclass Sex Age SibSp Parch Fare Cabin Embarked\n",
"0 1 3 male 22.0 1 0 7.2500 NAN S\n",
"1 2 1 female 38.0 1 0 71.2833 C85 C\n",
"2 3 3 female 26.0 0 0 7.9250 NAN S\n",
"3 4 1 female 35.0 1 0 53.1000 C123 S\n",
"4 5 3 male 35.0 0 0 8.0500 NAN S"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"train_pool = Pool(data=X, label=y, cat_features=cat_features_index)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train the model:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Learning rate set to 0.016216\n",
"0:\tlearn: 0.6862663\ttotal: 73.2ms\tremaining: 1m 13s\n",
"100:\tlearn: 0.4272007\ttotal: 2.6s\tremaining: 23.2s\n",
"200:\tlearn: 0.4044455\ttotal: 4.59s\tremaining: 18.3s\n",
"300:\tlearn: 0.3928060\ttotal: 6.28s\tremaining: 14.6s\n",
"400:\tlearn: 0.3852512\ttotal: 7.88s\tremaining: 11.8s\n",
"500:\tlearn: 0.3750366\ttotal: 9.58s\tremaining: 9.54s\n",
"600:\tlearn: 0.3624703\ttotal: 12.5s\tremaining: 8.28s\n",
"700:\tlearn: 0.3493490\ttotal: 15s\tremaining: 6.39s\n",
"800:\tlearn: 0.3390201\ttotal: 17.3s\tremaining: 4.3s\n",
"900:\tlearn: 0.3301084\ttotal: 19.5s\tremaining: 2.14s\n",
"999:\tlearn: 0.3224542\ttotal: 21.6s\tremaining: 0us\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = CatBoost(params={'loss_function': 'Logloss', 'one_hot_max_size': 255, 'verbose': 100})\n",
"model.fit(train_pool)"
]
},
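{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because only float and one-hot features can be exported, it helps to know which categorical columns fit under `one_hot_max_size`: CatBoost one-hot encodes a categorical feature only when its number of distinct values does not exceed that limit. The snippet below is a minimal sketch of such a check on a hypothetical toy frame with a deliberately small limit (the training cell above uses 255); `toy`, `cat_columns` and `cardinality` are illustrative names.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Hypothetical stand-in for X; in the notebook, run the same lines on X itself.\n",
"toy = pd.DataFrame({\n",
"    'Sex': ['male', 'female', 'female'],\n",
"    'Age': [22.0, 38.0, 26.0],\n",
"    'Cabin': ['NAN', 'C85', 'C123'],\n",
"})\n",
"one_hot_max_size = 2  # small on purpose; the training cell uses 255\n",
"\n",
"# Same rule as the notebook: every non-float column is categorical.\n",
"cat_columns = [c for c in toy.columns if toy[c].dtype != float]\n",
"cardinality = toy[cat_columns].nunique()\n",
"\n",
"# Columns listed here have too many distinct values to be one-hot encoded.\n",
"print(cardinality[cardinality > one_hot_max_size])\n",
"```"
]
},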
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Predict probabilities:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"test_pool = Pool(data=X[0:1], cat_features=cat_features_index)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.88003974, 0.11996026]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.predict(test_pool, prediction_type=\"Probability\")"
]
},
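{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two columns of this output are the predicted probabilities of class 0 (did not survive) and class 1 (survived), and they sum to 1. For a `Logloss` model, the class 1 probability is the sigmoid of the raw model score. The sketch below only illustrates that relationship in plain Python, using the number printed above; `sigmoid` and `raw_score` are illustrative names.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def sigmoid(raw_score):\n",
"    # Map a raw score to a probability in (0, 1).\n",
"    return 1.0 / (1.0 + math.exp(-raw_score))\n",
"\n",
"p_survived = 0.11996026  # second column of the output above\n",
"\n",
"# Inverse of the sigmoid: the raw score that corresponds to this probability.\n",
"raw_score = math.log(p_survived / (1.0 - p_survived))\n",
"\n",
"print(round(sigmoid(raw_score), 8))  # 0.11996026\n",
"print(round(1.0 - p_survived, 8))    # 0.88003974, the first column\n",
"```"
]
},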
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Save model:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"model.save_model(\n",
"    \"titanic.mlmodel\",\n",
"    format=\"coreml\",\n",
"    export_parameters={\n",
"        'prediction_type': 'probability'\n",
"    }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the exported model, every feature is named \"feature_i\", where i is the zero-based index of the feature's column in the dataset."
]
},
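{
"cell_type": "markdown",
"metadata": {},
"source": [
"To keep track of which \"feature_i\" corresponds to which column, the names can be listed next to the column order of `X`. This is a small sketch in plain Python; the column list is copied from the table shown earlier (after `Name` and `Ticket` were dropped), and `coreml_names` is an illustrative name.\n",
"\n",
"```python\n",
"# Column order of X after dropping Name and Ticket.\n",
"columns = ['PassengerId', 'Pclass', 'Sex', 'Age', 'SibSp',\n",
"           'Parch', 'Fare', 'Cabin', 'Embarked']\n",
"\n",
"# feature_0 is the first column, feature_1 the second, and so on.\n",
"coreml_names = {f'feature_{i}': name for i, name in enumerate(columns)}\n",
"\n",
"for coreml_name, column in coreml_names.items():\n",
"    print(coreml_name, '->', column)\n",
"```\n",
"\n",
"In the notebook, `list(X.columns)` gives the same list, and the arguments of the Swift call below follow this order."
]
},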
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"Now you can import the saved model into Xcode and use it directly from Swift:\n",
"\n",
"```swift\n",
"import CoreML\n",
"\n",
"let model = titanic()\n",
"\n",
"let passengerId = \"1\"\n",
"let pclass = \"1\"\n",
"let sex = \"female\"\n",
"let age = 38.0\n",
"let sibsp = \"1\"\n",
"let parch = \"0\"\n",
"let fare = 71.2833\n",
"let cabin = \"C85\"\n",
"let embarked = \"C\"\n",
"\n",
"guard let titanicOutput = try? model.prediction(\n",
"    feature_0: passengerId, feature_1: pclass, feature_2: sex,\n",
"    feature_3: age, feature_4: sibsp, feature_5: parch,\n",
"    feature_6: fare, feature_7: cabin, feature_8: embarked\n",
") else {\n",
"    fatalError(\"Unexpected runtime error.\")\n",
"}\n",
"\n",
"print(String(\n",
"    format: \"Probability of survival: %1.5f\",\n",
"    titanicOutput.prediction[0].doubleValue\n",
"))\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to practice, the titanic model is easy to integrate into Apple's [MarsHabitatPricer](https://developer.apple.com/documentation/coreml/integrating_a_core_ml_model_into_your_app) example project."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}