{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CatBoost and CoreML tutorial: Titanic dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "CatBoost supports exporting models to Apple's [CoreML](https://developer.apple.com/machine-learning/) format, which lets you easily embed ML models into applications on Apple platforms.\n", "\n", "Currently, only models that use float features and one-hot-encoded categorical features can be exported.\n", "\n", "This tutorial demonstrates how to export a CatBoost model trained on the [Titanic](https://www.kaggle.com/c/titanic/data) dataset to a CoreML model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Get the Titanic dataset:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from catboost import Pool, CatBoost\n", "from catboost.datasets import titanic" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# titanic() returns a (train, test) pair of DataFrames; only the labeled train part is used\n", "train_df = titanic()[0]\n", "X, y = train_df.drop('Survived', axis=1), train_df.Survived" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>PassengerId</th><th>Pclass</th><th>Name</th><th>Sex</th><th>Age</th><th>SibSp</th><th>Parch</th><th>Ticket</th><th>Fare</th><th>Cabin</th><th>Embarked</th></tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr><th>0</th><td>1</td><td>3</td><td>Braund, Mr. Owen Harris</td><td>male</td><td>22.0</td><td>1</td><td>0</td><td>A/5 21171</td><td>7.2500</td><td>NaN</td><td>S</td></tr>\n",
"    <tr><th>1</th><td>2</td><td>1</td><td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td><td>female</td><td>38.0</td><td>1</td><td>0</td><td>PC 17599</td><td>71.2833</td><td>C85</td><td>C</td></tr>\n",
"    <tr><th>2</th><td>3</td><td>3</td><td>Heikkinen, Miss. Laina</td><td>female</td><td>26.0</td><td>0</td><td>0</td><td>STON/O2. 3101282</td><td>7.9250</td><td>NaN</td><td>S</td></tr>\n",
"    <tr><th>3</th><td>4</td><td>1</td><td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td><td>female</td><td>35.0</td><td>1</td><td>0</td><td>113803</td><td>53.1000</td><td>C123</td><td>S</td></tr>\n",
"    <tr><th>4</th><td>5</td><td>3</td><td>Allen, Mr. William Henry</td><td>male</td><td>35.0</td><td>0</td><td>0</td><td>373450</td><td>8.0500</td><td>NaN</td><td>S</td></tr>\n",
"  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " PassengerId Pclass Name \\\n", "0 1 3 Braund, Mr. Owen Harris \n", "1 2 1 Cumings, Mrs. John Bradley (Florence Briggs Th... \n", "2 3 3 Heikkinen, Miss. Laina \n", "3 4 1 Futrelle, Mrs. Jacques Heath (Lily May Peel) \n", "4 5 3 Allen, Mr. William Henry \n", "\n", " Sex Age SibSp Parch Ticket Fare Cabin Embarked \n", "0 male 22.0 1 0 A/5 21171 7.2500 NaN S \n", "1 female 38.0 1 0 PC 17599 71.2833 C85 C \n", "2 female 26.0 0 0 STON/O2. 3101282 7.9250 NaN S \n", "3 female 35.0 1 0 113803 53.1000 C123 S \n", "4 male 35.0 0 0 373450 8.0500 NaN S " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us drop the Name and Ticket features. Most of their values occur in only one object, so one-hot encoding them makes no sense and would lead to overfitting." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "X.drop(['Name', 'Ticket'], axis=1, inplace=True)\n", "categorical_features_indices = np.where(X.dtypes != float)[0]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Treat every non-float column as categorical (including the integer columns PassengerId, Pclass, SibSp and Parch)\n", "is_cat = (X.dtypes != float)\n", "for feature, feat_is_cat in is_cat.to_dict().items():\n", "    if feat_is_cat:\n", "        # CatBoost rejects NaN in categorical features, so replace it with a placeholder string\n", "        X[feature] = X[feature].fillna(\"NAN\")\n", "\n", "cat_features_index = np.where(is_cat)[0]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>PassengerId</th><th>Pclass</th><th>Sex</th><th>Age</th><th>SibSp</th><th>Parch</th><th>Fare</th><th>Cabin</th><th>Embarked</th></tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr><th>0</th><td>1</td><td>3</td><td>male</td><td>22.0</td><td>1</td><td>0</td><td>7.2500</td><td>NAN</td><td>S</td></tr>\n",
"    <tr><th>1</th><td>2</td><td>1</td><td>female</td><td>38.0</td><td>1</td><td>0</td><td>71.2833</td><td>C85</td><td>C</td></tr>\n",
"    <tr><th>2</th><td>3</td><td>3</td><td>female</td><td>26.0</td><td>0</td><td>0</td><td>7.9250</td><td>NAN</td><td>S</td></tr>\n",
"    <tr><th>3</th><td>4</td><td>1</td><td>female</td><td>35.0</td><td>1</td><td>0</td><td>53.1000</td><td>C123</td><td>S</td></tr>\n",
"    <tr><th>4</th><td>5</td><td>3</td><td>male</td><td>35.0</td><td>0</td><td>0</td><td>8.0500</td><td>NAN</td><td>S</td></tr>\n",
"  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " PassengerId Pclass Sex Age SibSp Parch Fare Cabin Embarked\n", "0 1 3 male 22.0 1 0 7.2500 NAN S\n", "1 2 1 female 38.0 1 0 71.2833 C85 C\n", "2 3 3 female 26.0 0 0 7.9250 NAN S\n", "3 4 1 female 35.0 1 0 53.1000 C123 S\n", "4 5 3 male 35.0 0 0 8.0500 NAN S" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X.head()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "train_pool = Pool(data=X, label=y, cat_features=cat_features_index)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train the model. `one_hot_max_size` is set to 255 so that categorical features with up to 255 distinct values are one-hot encoded, because only one-hot categorical features can be exported to CoreML:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Learning rate set to 0.016216\n", "0:\tlearn: 0.6862663\ttotal: 73.2ms\tremaining: 1m 13s\n", "100:\tlearn: 0.4272007\ttotal: 2.6s\tremaining: 23.2s\n", "200:\tlearn: 0.4044455\ttotal: 4.59s\tremaining: 18.3s\n", "300:\tlearn: 0.3928060\ttotal: 6.28s\tremaining: 14.6s\n", "400:\tlearn: 0.3852512\ttotal: 7.88s\tremaining: 11.8s\n", "500:\tlearn: 0.3750366\ttotal: 9.58s\tremaining: 9.54s\n", "600:\tlearn: 0.3624703\ttotal: 12.5s\tremaining: 8.28s\n", "700:\tlearn: 0.3493490\ttotal: 15s\tremaining: 6.39s\n", "800:\tlearn: 0.3390201\ttotal: 17.3s\tremaining: 4.3s\n", "900:\tlearn: 0.3301084\ttotal: 19.5s\tremaining: 2.14s\n", "999:\tlearn: 0.3224542\ttotal: 21.6s\tremaining: 0us\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = CatBoost(params={'loss_function': 'Logloss', 'one_hot_max_size': 255, 'verbose': 100})\n", "model.fit(train_pool)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Predict class probabilities for the first passenger:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "test_pool = Pool(data=X[0:1], cat_features=cat_features_index)" ] }, { "cell_type": "code", "execution_count": 12, 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.88003974, 0.11996026]])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict(test_pool, prediction_type=\"Probability\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Save the model in CoreML format:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "model.save_model(\n", "    \"titanic.mlmodel\",\n", "    format=\"coreml\",\n", "    export_parameters={\n", "        'prediction_type': 'probability'\n", "    }\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the exported model, the input features are named `feature_i`, where `i` is the zero-based index of the feature column in `X` (for example, `feature_0` is `PassengerId` and `feature_8` is `Embarked`)." ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "Now you can import the saved model into Xcode and use it directly from Swift:\n", "\n", "```swift\n", "import CoreML\n", "\n", "let model = titanic()\n", "\n", "// Feature values of the passenger with PassengerId 2 (the second row of X).\n", "// Categorical features, including the integer-valued ones, are passed as strings;\n", "// float features (Age, Fare) are passed as Double.\n", "let passengerId = \"2\"\n", "let pclass = \"1\"\n", "let sex = \"female\"\n", "let age = 38.0\n", "let sibsp = \"1\"\n", "let parch = \"0\"\n", "let fare = 71.2833\n", "let cabin = \"C85\"\n", "let embarked = \"C\"\n", "\n", 
"guard let titanicOutput = try? model.prediction(feature_0: passengerId, feature_1: pclass, feature_2: sex, feature_3: age, feature_4: sibsp, feature_5: parch, feature_6: fare, feature_7: cabin, feature_8: embarked) else {\n", "    fatalError(\"Unexpected runtime error.\")\n", "}\n", "\n", "print(String(\n", "    format: \"Probability of survival: %1.5f\",\n", "    titanicOutput.prediction[0].doubleValue\n", "))\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For practice, the Titanic model is easy to integrate into Apple's [MarsHabitatPricer](https://developer.apple.com/documentation/coreml/integrating_a_core_ml_model_into_your_app) example project." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }