Customer Churn Prediction using Decision Trees

The subscription market has become more competitive as the number of available options has grown. For example, OTT (over-the-top) platforms such as Netflix, Prime Video, Hotstar, and Zee5 are all competing for market share. Growing that share depends on retaining existing customers while attracting new ones. With advances in data science and machine learning, companies can now predict which customers are likely to cancel their subscriptions.

Companies often prioritize retaining existing customers over acquiring new ones. Hence, they invest heavily in identifying which customers might stop using their products or services; this loss of customers is called customer churn. Alongside customer retention campaigns and product publicity programs, churn prediction is common practice among companies working in competitive markets.

In this article, we will build a decision tree-based churn prediction model for an investment bank so that it can retain more customers.

Key Points

  1. Problem Statement
  2. Understanding Churn prediction Dataset
  3. Data preprocessing Steps
  4. Exploratory Data Analysis
  5. Model Building
  6. Model Evaluation

First, let's understand how companies lose customers.

Why do Industries use customer churn prediction?

Let Jack be a customer of X bank for more than ten years. Now, with the fast-growing inflation rate and the slow-growing interest rate provided by the bank, he is concerned about his savings. Plus, he is disappointed by the outdated product designs and the tedious application process for new services. So Jack decides to transfer his account to another bank. 

A report by Sells Agency states that 20% of customers leave their banks due to poor in-bank experience and tiring procedures, and 3.6% leave their banks for new products and services. 

Industries such as airline ticketing, telecommunications, OTT platforms, banking, and finance need to retain customers. These are all competitive markets in which companies constantly fight for market share. To do so, they continuously collect customer data to build ML models and learn customer behaviour. One such model is the customer churn prediction model that we are building in this blog.

Problem statement for Churn Prediction

Understanding the problem statement is the first step of model building. Here we are given a dataset of bank customers, and we want to build a churn prediction model that classifies customers into two classes: potential churners and non-churners. This is a classification problem solved with a supervised learning approach, meaning we have labelled input and output data samples. The model should be precise, as its predictions might affect the bank's future investments. Now, let's study the dataset in detail.

Understanding the Dataset for churn prediction

The dataset we used to build the model is available on Kaggle. It has 14 columns/features and 10k rows/samples.

Let's see the effect of these features on the model's prediction.

  • CreditScore — can affect customer churn since a customer with a higher credit score is less likely to leave the bank.
  • Geography — a customer's location can affect their decision to leave the bank.
  • Tenure — refers to the number of years the customer has been a bank's client. Usually, older clients are more loyal and less likely to leave a bank.
  • Balance — is also an excellent indicator of customer churn, as people with a higher balance in their accounts are less likely to leave the bank than those with lower balances.
  • NumOfProducts — refers to the number of products a customer has purchased through the bank.
  • HasCrCard — denotes whether or not a customer has a credit card. This column is also relevant since people with credit cards are less likely to leave the bank.
  • IsActiveMember — active customers are less likely to leave the bank.
  • EstimatedSalary — as with balance, people with lower salaries are more likely to leave the bank than those with higher wages.
  • Exited — The target column tells whether the customer is a potential churner.

Other columns are RowNumber, CustomerId, Surname, Gender, and Age, having their usual meaning.

Let's look at some rows of the Dataset and the different datatypes used to store them.

import pandas as pd
df = pd.read_csv(r"N:\Machine learning\Algorithms\churn.csv")  # raw string keeps the Windows backslashes literal

Customer churn prediction data initial samples

Here we can infer that the columns RowNumber, CustomerId, and Surname only identify customers and do not affect the target variable, so we can drop them. The other observation is that we have two categorical columns, Geography and Gender, which we will encode using LabelEncoder from Python's sklearn library.
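The two steps described above can be sketched as follows. This is a minimal illustration on a tiny made-up frame, not the Kaggle data; the variable name sample and all values are assumptions for the example.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Tiny stand-in for the churn data; the values are made up for illustration.
sample = pd.DataFrame({
    "RowNumber": [1, 2, 3, 4],
    "CustomerId": [101, 102, 103, 104],
    "Surname": ["A", "B", "C", "D"],
    "Geography": ["France", "Spain", "Germany", "France"],
    "Gender": ["Female", "Male", "Male", "Female"],
    "Exited": [1, 0, 0, 1],
})

# Identifier columns carry no predictive signal, so drop them.
sample = sample.drop(columns=["RowNumber", "CustomerId", "Surname"])

# LabelEncoder maps each category to an integer (classes sorted alphabetically).
for col in ["Geography", "Gender"]:
    sample[col] = LabelEncoder().fit_transform(sample[col])

print(sample)
```

Note that label encoding imposes an arbitrary numeric order on Geography. One-hot encoding is a common alternative when that order could mislead a model.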

Now let's jump to the data preprocessing steps and clean the data.

Data Preprocessing steps

Data preprocessing is cleaning the Dataset before feeding it to the model. There are some standard preprocessing steps we must follow.

1. Checking the presence of null values and duplicates and dropping them.


RowNumber        0
CustomerId       0
Surname          0
CreditScore      0
Geography        0
Gender           0
Age              0
Tenure           0
Balance          0
NumOfProducts    0
HasCrCard        0
IsActiveMember   0
EstimatedSalary  0


There are no missing or duplicate values in any column, so we do not need to drop anything. Otherwise, we could use df.dropna() to drop rows with missing values and df.drop_duplicates() to drop repeated rows.
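As a rough sketch of how these checks might look in pandas, the example below uses a made-up frame (toy) that deliberately contains one missing value and one duplicated row.

```python
import numpy as np
import pandas as pd

# Made-up frame with one missing value and one duplicated row.
toy = pd.DataFrame({
    "CreditScore": [619, 608, np.nan, 608],
    "Age": [42, 41, 39, 41],
})

print(toy.isnull().sum())      # missing values per column
print(toy.duplicated().sum())  # number of fully duplicated rows

# Drop rows with missing values, then drop repeated rows.
clean = toy.dropna().drop_duplicates()
print(len(clean))
```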

2. Removing outliers.

Outliers are data samples that lie far from the rest of the data. They can distort what the model learns and pull its predictions towards them. There are two standard methods to detect outliers: the interquartile range (IQR) and the standard deviation. We will use the IQR method:

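import numpy as np

# numcols: list of the numeric column names to check (defined elsewhere)
for i in numcols:
    q75, q25 = np.percentile(df[i], [75, 25])
    iqr = q75 - q25
    min_val = q25 - (iqr * 1.5)  # lower fence
    max_val = q75 + (iqr * 1.5)  # upper fence
    # Keep only rows inside the fences (removal step implied by the text)
    df = df[(df[i] >= min_val) & (df[i] <= max_val)]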

We have removed outliers from the customer dataset and can validate this with box plots.

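import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots(2, 2, figsize=(15, 15))

# One box plot per numeric column, split by the target; color_0_1 is the two-class palette defined elsewhere
for i, subplot in zip(numcols, ax.flatten()):
    sns.boxplot(x='Exited', y=i, data=df, ax=subplot, palette=color_0_1)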

Box plot after removing the outliers from the churn rate prediction data

The Age column still has some outliers, but fewer than before preprocessing. The plots show that the data is spread fairly evenly around the median line of each box, except for the Balance column. Box plots also help in understanding how the data is distributed around the quartiles.
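To make the two outlier rules mentioned in this step concrete, here is a small sketch on made-up numbers. The z_limit cutoff of 2.0 is an assumption chosen for this tiny sample; 3.0 is a more common choice on larger data.

```python
import numpy as np

values = np.array([10, 12, 11, 13, 12, 11, 100], dtype=float)  # made-up data

# IQR rule: flag points beyond 1.5 * IQR from the quartiles.
q25, q75 = np.percentile(values, [25, 75])
iqr = q75 - q25
iqr_outliers = values[(values < q25 - 1.5 * iqr) | (values > q75 + 1.5 * iqr)]

# Standard-deviation rule: flag points more than z_limit deviations from the mean.
z_limit = 2.0  # assumed cutoff for this tiny sample
z_scores = (values - values.mean()) / values.std()
std_outliers = values[np.abs(z_scores) > z_limit]

print(iqr_outliers, std_outliers)
```

Both rules flag the value 100 here. The standard-deviation rule is more sensitive to the outlier itself, since an extreme value inflates both the mean and the standard deviation it is judged against.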

3. Feature scaling

Feature scaling means bringing features onto a comparable range by normalizing or standardizing them. In models trained with gradient-based updates, the parameter updates are affected by the magnitude of the features, which is why features must be scaled to the same range. Tree-based algorithms, however, are insensitive to feature scaling: a split only compares a feature's values against a threshold, so the magnitude and variance of a feature do not change which splits are chosen. Therefore, it is unnecessary to scale data for our model.
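The claim that trees are insensitive to scaling can be checked with a quick experiment. The sketch below uses synthetic data from make_classification (not the bank dataset) and compares a tree trained on raw features with one trained on standardized features; all names and settings are assumptions for the illustration.

```python
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data; not the bank dataset.
X, y = make_classification(n_samples=500, n_features=6, random_state=0)
X_scaled = StandardScaler().fit_transform(X)

tree_raw = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X, y)
tree_scaled = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_scaled, y)

# Fraction of samples on which the two trees agree.
agreement = (tree_raw.predict(X) == tree_scaled.predict(X_scaled)).mean()
print(f"Prediction agreement: {agreement:.3f}")
```

The two trees should agree on essentially every sample, because standardization preserves the ordering of values within each feature.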

Now, we will perform exploratory data analysis to understand the trend followed by the data.

Exploratory Data Analysis of Bank Customers Dataset

This section will present some data visualizations of the bank customer dataset. 

1. Heatmap showing the correlation of features.

A heatmap is a colour-coded table of the pairwise correlations between features. We study heatmaps to see how strongly each feature depends on the other features and on the target variable. In the heatmap below, most correlation values are near zero, so the features show little linear dependence on one another.

import matplotlib.pyplot as plt
import seaborn as sns

# Correlate numeric columns only; text columns cannot be correlated
sns.heatmap(df.select_dtypes('number').corr(), cmap='Blues', annot=True)

Heatmap analysis to select the final set of features for customer churn rate prediction

The other use of a heatmap is feature selection. When two features are highly correlated (positively or negatively), we generally keep only one of them and drop the other.
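One possible way to automate this selection rule is sketched below. The frame demo, the column BalanceCopy, and the 0.9 cutoff are all made up for the illustration.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
balance = rng.normal(size=200)
# Made-up features: "BalanceCopy" is almost a duplicate of "Balance".
demo = pd.DataFrame({
    "Balance": balance,
    "BalanceCopy": balance * 2 + rng.normal(scale=0.01, size=200),
    "Age": rng.normal(size=200),
})

corr = demo.corr().abs()
# Keep only the upper triangle so each pair is reported once.
upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
threshold = 0.9  # assumed cutoff
to_drop = [col for col in upper.columns if (upper[col] > threshold).any()]
print(to_drop)
```

Here only BalanceCopy is flagged, so Balance is the feature that would be kept from the correlated pair.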

2. Count plots to show the trend of features w.r.t. the target

Count plots are well suited to comparing the values of a categorical feature against the target variable. We can see that female customers churn more often, and people with credit cards are less likely to turn their back on the bank. Similarly, customers who stay active through transactions or bank visits are less likely to leave or change banks.

fig, ax = plt.subplots(1, 3, figsize=(15, 15))

# categorical_features: list of the categorical columns to plot (defined elsewhere)
for i, subplot in zip(categorical_features, ax.flatten()):
    sns.countplot(x=i, hue="Exited", data=df, ax=subplot, palette=color_0_1)

Count plot analysis to see which features are truly affecting the prediction
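The same comparison can also be read off as numbers. As a small sketch on made-up rows, the mean of a 0/1 target within each group gives that group's churn rate.

```python
import pandas as pd

# Made-up rows, only to show the calculation.
toy = pd.DataFrame({
    "Gender": ["Female", "Female", "Male", "Male", "Male", "Female"],
    "Exited": [1, 0, 0, 0, 1, 1],
})

# The mean of a 0/1 target within each group is that group's churn rate.
churn_rate = toy.groupby("Gender")["Exited"].mean()
print(churn_rate)
```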

These observations suggest that banks should spend more money on maintaining the quality of service and retaining their customers. Products like credit cards, loans, tax rebates, and fixed deposits are essential to hook customers and ensure their engagement.

Model building

We are ready to build our Customer churn prediction model! But first, we must select which machine learning algorithm is best for churn prediction.

We have a small dataset of 10,000 rows, so we need an algorithm that learns well even with limited training data. A decision tree is a good option because tree-based algorithms are easy to implement, are highly explainable, work well with small datasets, and require minimal data preprocessing.
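The training code in this section relies on X_train and Y_train. A minimal sketch of how such a split might be produced is shown here on synthetic data; the 80/20 split, the stratification, and the class ratio are assumptions rather than the settings used for the bank dataset.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Synthetic stand-in with roughly 20% positives (assumed ratio for an imbalanced churn target).
X, Y = make_classification(n_samples=1000, n_features=10, weights=[0.8, 0.2], random_state=42)

# stratify=Y keeps the churner share similar in both splits.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, stratify=Y, random_state=42
)
print(X_train.shape, X_test.shape)
```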

from sklearn.tree import DecisionTreeClassifier

# Baseline tree with default hyperparameters
dtc = DecisionTreeClassifier().fit(X_train, Y_train)
############### Tuning the parameters  ##############
dtc_new = DecisionTreeClassifier(criterion='entropy', min_samples_split=10, min_samples_leaf=6, max_features='sqrt', random_state=1).fit(X_train, Y_train)


Hyperparameters to tune in the decision tree

We can import the Decision Tree Classifier from the sklearn library and change the default parameter values. 

  1. criterion specifies the measure used to choose the split at each node. The default is the Gini index, which measures the probability that a randomly chosen instance would be misclassified if it were labelled according to the class distribution in the node. The alternative, entropy, measures the impurity or randomness of the data points in the node.
  2. min_samples_split is the minimum number of samples a node must contain before it can be split further. A value between 8 and 10 is a reasonable starting point.
  3. min_samples_leaf is the minimum number of samples each leaf must contain after a split. A value between 2 and 6 is a reasonable starting point.
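To make the two split criteria concrete, the sketch below computes both impurity measures from a node's class proportions using their standard definitions. The helper names gini and entropy and the example node are illustrative only.

```python
import numpy as np

def gini(p):
    """Gini impurity: 1 - sum(p_k^2)."""
    p = np.asarray(p, dtype=float)
    return 1.0 - np.sum(p ** 2)

def entropy(p):
    """Shannon entropy in bits: -sum(p_k * log2(p_k)), skipping zero-probability classes."""
    p = np.asarray(p, dtype=float)
    p = p[p > 0]
    return -np.sum(p * np.log2(p))

node = [0.8, 0.2]  # made-up node: 80% non-churners, 20% churners
print(round(gini(node), 3), round(entropy(node), 3))
```

Both measures are zero for a pure node and reach their maximum when the classes are evenly mixed (0.5 for Gini and 1 bit for entropy in the two-class case).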

Other parameters, like max_depth, max_features, and class_weight, can also be tuned.
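Rather than trying values by hand, these parameters can be searched systematically. The following is a sketch using scikit-learn's GridSearchCV on synthetic data; the grid values, the F1 scoring choice, and the 5-fold setting are assumptions, not the configuration used for the results in this article.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data; not the bank dataset.
X, y = make_classification(n_samples=600, n_features=10, weights=[0.8, 0.2], random_state=1)

# Assumed search grid built from the parameters discussed above.
param_grid = {
    "criterion": ["gini", "entropy"],
    "min_samples_split": [8, 10],
    "min_samples_leaf": [2, 6],
    "max_depth": [3, 5, None],
}

clf = GridSearchCV(DecisionTreeClassifier(random_state=1), param_grid, scoring="f1", cv=5)
clf.fit(X, y)

print(clf.best_params_)
print(round(clf.best_score_, 3))
```

After fitting, best_params_ holds the winning combination and best_score_ its mean cross-validated score.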

Let's evaluate the model's performance before and after tuning the parameters.

Model Evaluation

An evaluation metric measures how well the model performs on train and test data. The choice of metric depends on the problem type. For a regression problem, mean squared error, root mean squared error, and R-squared are the most commonly used metrics. For a classification problem, the most common are the confusion matrix, accuracy, F1 score, and area under the curve (AUC).

Confusion Matrix

A confusion matrix is the best way to inspect the model's actual predictions. It divides them into four groups: true positives, true negatives, false positives, and false negatives. In the matrix below, the high count in the true negative cell shows that our model is good at predicting non-churners, but it needs more tuning for better overall performance.

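from sklearn.metrics import confusion_matrix

# Confusion matrix of the tuned tree on the test set
cm = confusion_matrix(Y_test, dtc_new.predict(X_test))
sns.heatmap(cm, annot=True,
            linewidths=1, square=True, cmap='Blues_r', fmt='0.4g')
plt.ylabel('Actual label')
plt.xlabel('Predicted label')

# Record of the parameter search; clf is the fitted search object and model_name its label (both defined elsewhere)
summary = {
    'model': model_name,
    'best_score': clf.best_score_,
    'best_params': clf.best_params_,
}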

Confusion matrix for Decision tree to find the accuracy for churn rate prediction
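The four cells of a confusion matrix are enough to derive the headline metrics. Below is a small worked sketch on made-up labels; it is not the output of the churn model.

```python
from sklearn.metrics import confusion_matrix

# Made-up labels: 1 = churner, 0 = non-churner.
y_true = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 0, 0, 0, 0, 1, 1, 1, 0, 0]

# For binary labels, ravel() returns the cells in the order tn, fp, fn, tp.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

accuracy = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)  # share of predicted churners who really churn
recall = tp / (tp + fn)     # share of real churners the model catches
f1 = 2 * precision * recall / (precision + recall)
print(tn, fp, fn, tp, accuracy, round(f1, 2))
```

In this toy case accuracy is 0.7 while F1 is only about 0.57, which illustrates why accuracy alone can look flattering when churners are the minority class.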

ROC: Receiver Operating Characteristic Curve

We can try to optimize model performance by changing the classification threshold and choosing the value that fits best. Imagine plotting a confusion matrix each time we vary the threshold. Yes, the confusion matrix will result in a lot of confusion. So we use the ROC curve in its place. It plots the true positive rate against the false positive rate across all threshold values. The optimum threshold is the point that combines the highest true positive rate with the lowest false positive rate, that is, the point closest to the top-left corner of the plot.

ROC curve to evaluate the decision tree model for customer churn prediction
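As a hedged sketch of how the curve and a candidate threshold might be computed, the example below uses synthetic data and picks the threshold with Youden's J statistic (TPR minus FPR). That selection rule is one common option assumed for the illustration, not necessarily the one used for the plot above.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data; not the bank dataset.
X, y = make_classification(n_samples=1000, n_features=10, weights=[0.8, 0.2], random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, stratify=y, random_state=0)

tree = DecisionTreeClassifier(min_samples_leaf=6, random_state=0).fit(X_tr, y_tr)
churn_prob = tree.predict_proba(X_te)[:, 1]  # probability of the positive class

fpr, tpr, thresholds = roc_curve(y_te, churn_prob)
auc = roc_auc_score(y_te, churn_prob)

# Youden's J statistic (TPR - FPR) is one common way to pick a threshold.
best = np.argmax(tpr - fpr)
print(f"AUC={auc:.3f}, threshold={thresholds[best]:.3f}")
```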

The results for the other evaluation metrics are listed below.

from sklearn.metrics import accuracy_score, f1_score

def eval_classification(model):
    # Predictions and class probabilities on both splits
    y_pred = model.predict(X_test)
    y_pred_train = model.predict(X_train)
    y_pred_prob = model.predict_proba(X_test)
    y_pred_prob_train = model.predict_proba(X_train)
    print("Accuracy (Test Set): %.2f" % accuracy_score(Y_test, y_pred))
    print("F1-Score (Test Set): %.2f" % f1_score(Y_test, y_pred))

Model performance before tuning parameters

Accuracy (Test Set): 0.78
F1-Score (Test Set): 0.47

Model performance after tuning parameters

Accuracy (Test Set): 0.83
F1-Score (Test Set): 0.52

These results could be better, but they serve as a benchmark to start from. The path from good to best is complicated, and many research papers claim accuracy gains of 1% or more. Let's look at one such paper, which uses an advanced evaluation metric for profit-driven decision trees.

An advanced evaluation metric for churn prediction

A churn prediction model should follow a profit-driven approach. This means that both the cost of wrong predictions and the cost of retaining a customer should be considered when predicting the output. The aim of a churn model is not only to identify potential churners but also to predict which customers are most profitable to the company and thus worth retaining. ProfTree is one such decision tree-based machine learning algorithm; it uses an evolutionary algorithm to learn profit-driven decision trees.

Standard classification metrics such as the area under the ROC curve (AUC) are not recommended for evaluating profit-driven decision trees because they treat all misclassifications as equally costly. So a new evaluation metric is proposed, the "expected maximum profit measure for customer churn," which helps identify the most profitable model. The paper builds this metric on customer lifetime value (CLV), and defines a churner as a customer whose CLV is decreasing. CLV is the discounted value of future marginal earnings related to the customer's activity, and a simplified version can be calculated with the following equations.

CLV = (Customer Value / Churn Rate) × Profit Margin

Customer Value = Average Order Value × Purchase Frequency
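As a worked illustration of the two equations above, the sketch below plugs in made-up numbers. The function and parameter names are hypothetical and chosen only to mirror the terms in the formulas.

```python
def customer_value(average_order_value, purchase_frequency):
    """Customer Value = Average Order Value * Purchase Frequency."""
    return average_order_value * purchase_frequency

def clv(average_order_value, purchase_frequency, churn_rate, profit_margin):
    """CLV = (Customer Value / Churn Rate) * Profit Margin."""
    return customer_value(average_order_value, purchase_frequency) / churn_rate * profit_margin

# Made-up inputs: 50 per order, 12 orders a year, 20% yearly churn, 30% margin.
print(clv(50.0, 12, 0.20, 0.30))
```

With these inputs the customer value is 600, and dividing by the churn rate and applying the margin gives a CLV of roughly 900. Halving the churn rate to 10% would double that figure, which is why retention has such a direct effect on profit.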

For reference, read the ProfTree paper.

Case study


Netflix

Netflix is unbeatable, with the OTT industry's lowest churn rate of 2.3%. The reason for this is Netflix's recommendation system and data analysis. Netflix analyses everything: which movie is watched and for how long, which genre is liked the most, and which web series is trending. It updates its system every 24 hours to catch customers' attention by offering them something new every time they visit. Over 80% of the content a user watches comes from Netflix recommendations. Netflix constantly monitors its potential churners and keeps improving its user experience.

Tata Sky

DTH (direct-to-home) is one of the fastest-growing industries, and customer retention is the key to success in it. Customers have a plethora of options and are influenced by offers, service, network quality, and publicity. Tata Sky's dedicated team studies customer recharge patterns and identifies customer cohorts needing focused retention efforts. This helped them immensely during the Covid lockdown period in designing packages that satisfied customer needs.


Conclusion

Customer churn models will become an integral part of every industry in the future. It is crucial to understand how they work and what advancements are possible. In this blog, we discussed the complete process of model development, from understanding the problem statement through data preprocessing and data visualization to model evaluation. Then we discussed a research paper presenting an advanced decision tree approach and a new evaluation metric for measuring the performance of profit-driven decision trees. We hope you enjoyed the article.

If you have any queries, doubts, or feedback, please write to us. Enjoy learning, Enjoy algorithms!
