Introduction

Infant mortality, the death of an infant before reaching 1 year of age, is one of the most important indicators of child health development in a country [1]. The reduction in child mortality, particularly infant deaths, is one of the key Millennium Development Goals (MDGs) [2]. Globally, the infant mortality rate (per 1000 live births) decreased from 65 deaths in 1990 to 29 in 2018 [3]. In Bangladesh, this rate also declined from 87 deaths in 1993 to 38 in 2014 [4]. To meet MDG-4 and the Sustainable Development Goals (SDGs), the reduction of infant mortality can play an important role in improving child health [5]. Bangladesh is still far from achieving the SDG target for the infant mortality rate (5 deaths per 1000 live births) [4, 6].

Several studies have measured the consequences of infant mortality and investigated the significant factors affecting it in different countries. For example, maternal education and antenatal care were identified as potential predictors of infant mortality [7, 8]. The birth interval (the duration between two subsequent pregnancies) was identified as a significant factor for infant mortality [9, 10]. Moreover, multiple and preterm births, mother's age, domestic abuse, place of residence, wealth index and metabolic abnormalities were statistically significant determinants of infant mortality [11, 12]. Higher educational attainment of mothers, low birth weight and birth size were among the most important factors of infant mortality [13, 14]. In Bangladesh, antenatal care during pregnancy, place of delivery, wealth status, birth size, and sex of the child were identified as significant predictors of infant mortality [4, 15].

Most previous studies of mortality analysed the data using logistic regression (LR), notably for dichotomous outcome variables. However, the LR approach is often challenging and may not be suitable for estimating model parameters and predicting mortality. Recently, machine learning (ML) techniques have been developed as a further advancement for modelling health data, incorporating artificial intelligence and exploring unknown relationships or patterns in large volumes of data [16]. In health sciences and medical research, the ML approach is widely used for predicting numerous clinical responses based on multiple covariates [17, 18] and for other tasks such as fever detection from Twitter [19], identification of diabetic eye diseases [20], epilepsy detection [21], and cardiac arrhythmia detection [22]. Therefore, several well-known ML techniques, decision tree (DT), random forest (RF), support vector machine (SVM) and LR, were assessed via simulations for the classification of infant mortality and the identification of its predictive factors in Bangladesh. The systematic assessment of these ML techniques was carried out through simulation studies comparing their accuracy, sensitivity, specificity, precision, F1-score, receiver operating characteristic (ROC) curves and k-fold cross-validation.

Materials and methods

Data and variables

This study used representative cross-sectional survey data on infant mortality extracted from the latest Bangladesh Demographic and Health Survey (BDHS), 2017–2018 [23]. A two-stage stratified random sampling design was used to collect the data. In the first stage, enumeration areas (EAs) were selected with probability proportional to size (PPS). A complete listing of households was then carried out in all selected EAs to provide a sampling frame for the second-stage household selection. The second stage used systematic sampling, selecting 30 households per EA on average, to provide statistically reliable estimates of key demographic and health variables for the country. In total, the survey selected 20,250 households, and 20,100 women of reproductive age (15–49 years) were interviewed. Detailed information can be accessed at https://dhsprogram.com/data/available-datasets.cfm. In total, 26,145 infant records were used in this study after filtering out all cases with missing values.
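As an illustration of this extraction and filtering step, a minimal pandas sketch is given below; the file name and column names are placeholders rather than the actual BDHS variable codes.

```python
import pandas as pd

# Placeholder file and column names; the real BDHS birth-recode file uses DHS variable codes.
births = pd.read_csv("bdhs_2017_18_births.csv")

# Binary outcome: infant death = child died before reaching 12 months of age.
births["infant_death"] = ((births["child_alive"] == 0) &
                          (births["age_at_death_months"] < 12)).astype(int)

covariates = ["mother_age_first_marriage", "mother_age_first_birth", "mother_bmi",
              "birth_interval", "antenatal_visits", "division", "residence", "religion",
              "wealth_index", "child_sex", "birth_order", "toilet_facility", "cooking_fuel"]

# Keep only complete cases, as in the study (26,145 records after filtering).
analysis = births.dropna(subset=["infant_death"] + covariates)
X = pd.get_dummies(analysis[covariates], drop_first=True)  # numeric design matrix for the classifiers
y = analysis["infant_death"]
```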

The binary indicator of infant mortality (yes, no) was considered as the outcome variable of interest. Based on the literature, various maternal, socio-economic, demographic and environmental factors were considered as covariates of infant mortality: mother's age at first marriage, age at first birth, mother's body mass index (BMI), birth interval, antenatal care visits, tetanus toxoid (TT) injection, administrative division, place of residence, religion, parents' education, mother's occupation and empowerment, media exposure, children ever born, gender of the child, birth order, toilet facility, drinking water and type of cooking fuel.

Machine learning models

We adopted four different machine learning techniques: decision tree (DT), random forest (RF), support vector machine (SVM), and logistic regression (LR). The DT is a commonly used machine learning technique that develops a prediction algorithm for the target variable. In this approach, the total population is divided into branch-like segments forming an inverted tree with a root node, internal nodes and leaf nodes. The algorithm can handle a large amount of data without imposing a complicated parametric structure [24, 25]. The RF is a tree-structured classifier that depends on a collection of random variables; the number of trees and the maximum depth of each tree are set through hyper-parameters [26]. The RF is an ensemble learning method that classifies the outcome variable using a large number of decorrelated decision trees [27]. The SVM is a supervised machine learning algorithm that can be used for classification and regression problems [28, 29]. The SVM training process creates a model, or classification function, that assigns new observations to one group on either side of a hyper-plane, generating a non-probabilistic binary linear classifier for a two-group learning problem. In the SVM approach, the kernel trick is employed to map the data into a higher-dimensional space before the learning task is solved as an optimization problem [30]. The best-performing sigmoid kernel was used in this study. The LR is a commonly used statistical technique for classification problems and for predicting the probability of occurrence of an event (here, infant mortality). In the LR approach, the association between infant mortality (the binary response) and a set of covariates is modelled [27].
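The four classifiers can be configured in scikit-learn as in the sketch below; the hyper-parameter values are illustrative defaults rather than the tuned settings used in the study, and only the sigmoid kernel for the SVM is stated in the text.

```python
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression

# Illustrative configurations of the four classifiers compared in the study.
models = {
    "DT": DecisionTreeClassifier(max_depth=5, random_state=1119),
    "RF": RandomForestClassifier(n_estimators=500, random_state=1119),
    "SVM": SVC(kernel="sigmoid", probability=True, random_state=1119),  # sigmoid kernel as in the study
    "LR": LogisticRegression(max_iter=1000),
}
```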

Prediction performance parameters

We considered prediction performance parameters derived from the confusion matrix, a cross-tabulation of true against predicted classifications in terms of true positives \((T^{+ve})\), false positives \((F^{+ve})\), true negatives \((T^{-ve})\) and false negatives \((F^{-ve})\). The different performance measures, accuracy (proportion of data values correctly classified into their true groups), sensitivity (ability to correctly classify data values in the positive group), specificity (ability to accurately classify data points in the negative group) and precision (proportion of predicted positives that truly belong to the positive group), are calculated from the confusion matrix as: accuracy = \((T^{+ve}+T^{-ve})/(T^{+ve}+T^{-ve}+F^{+ve}+F^{-ve})\), sensitivity = \(T^{+ve}/(T^{+ve}+F^{-ve})\), specificity = \(T^{-ve}/(T^{-ve}+F^{+ve})\) and precision = \(T^{+ve}/(T^{+ve}+F^{+ve})\) [27]. The receiver operating characteristic (ROC) curve is a two-dimensional diagram used to visualize, organise and select classifiers based on their performance [31]. The true-positive rate is plotted against the false-positive rate to reveal the sensitivity of the classifier. The area under the ROC curve (AUC) measures how well a classifier separates data values into the two diagnostic groups; for an outstanding classifier, the true-positive rate is high and the AUC is close to 1 [24].
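These measures can be computed directly from a fitted classifier; the helper below is a minimal sketch using scikit-learn, with the fitted model and test data assumed to come from the splits described later.

```python
from sklearn.metrics import confusion_matrix, f1_score, roc_auc_score

def evaluate(model, X_test, y_test):
    """Accuracy, sensitivity, specificity, precision, F1-score and AUC from the confusion matrix."""
    y_pred = model.predict(X_test)
    tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    sensitivity = tp / (tp + fn)                                   # true-positive rate
    specificity = tn / (tn + fp)                                   # true-negative rate
    precision = tp / (tp + fp) if (tp + fp) > 0 else float("nan")  # undefined if no positives are predicted
    f1 = f1_score(y_test, y_pred)
    auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1])
    return accuracy, sensitivity, specificity, precision, f1, auc
```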

k-Fold cross-validation

Cross-validation is a verification approach used to evaluate how well a model generalizes to an independent data set [27]. In the k-fold cross-validation technique, the training data set is randomly subdivided into k mutually exclusive sub-samples (folds) of approximately equal size. The model is then trained k times iteratively, where in each iteration one sub-sample is held out for testing (cross-validation) and the remaining \((k-1)\) sub-samples are used for training the model. The average of the k cross-validation results is used to estimate the accuracy of the ML techniques [27]. The 3-, 5-, 10- and 20-fold cross-validation approaches were used to evaluate the predictive performance of the classifiers for the large sample size (n = 26,145) considered in this study.
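A minimal sketch of this procedure with scikit-learn is shown below, assuming the feature matrix X, outcome y and models dictionary from the earlier sketches; the shuffled folds use a fixed seed for reproducibility.

```python
from sklearn.model_selection import StratifiedKFold, cross_val_score

for k in (3, 5, 10, 20):
    cv = StratifiedKFold(n_splits=k, shuffle=True, random_state=1)
    for name, model in models.items():
        scores = cross_val_score(model, X, y, cv=cv, scoring="accuracy")
        se = scores.std(ddof=1) / len(scores) ** 0.5
        print(f"{name}, k={k}: mean accuracy = {scores.mean():.4f} (SE = {se:.4f})")
```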

Results

Features selection of infant mortality

We first conducted chi-square (\(\chi ^2\)) tests for the feature selection of infant mortality [32]. Based on the literature, twenty-one exposure variables were initially considered to examine their association with infant mortality. Statistically significant associations were observed between infant mortality and the following covariates: mother's BMI (\(\chi ^2=52.295\), \(p<0.001\)), birth interval (\(\chi ^2=957.434\), \(p<0.001\)), region (\(\chi ^2=27.940\), \(p<0.05\)), religion (\(\chi ^2=7.723\), \(p<0.05\)), maternal education (\(\chi ^2=82.227\), \(p<0.001\)), father's education (\(\chi ^2=63.747\), \(p<0.001\)), mother's occupation (\(\chi ^2=10.889\), \(p<0.001\)), mass-media exposure (\(\chi ^2=9.662\), \(p<0.05\)), wealth index (\(\chi ^2=25.853\), \(p<0.001\)), gender of the child (\(\chi ^2=46.344\), \(p<0.001\)), birth order (\(\chi ^2=12.302\), \(p<0.05\)), children ever born (\(\chi ^2=252.454\), \(p<0.001\)), household toilet facility (\(\chi ^2=14.479\), \(p<0.001\)), and household cooking fuel (\(\chi ^2=8.757\), \(p<0.001\)).
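This screening can be reproduced by cross-tabulating each covariate against the outcome and applying a \(\chi ^2\) test of independence; a sketch with scipy, assuming the analysis DataFrame from the data-preparation step, is shown below.

```python
import pandas as pd
from scipy.stats import chi2_contingency

def chi2_screen(df, outcome, variables):
    """Chi-square test of association between the binary outcome and each categorical covariate."""
    rows = []
    for var in variables:
        table = pd.crosstab(df[var], df[outcome])
        stat, p, dof, _ = chi2_contingency(table)
        rows.append({"variable": var, "chi2": round(stat, 3), "p_value": p})
    return pd.DataFrame(rows).sort_values("p_value")

# Example: chi2_screen(analysis, "infant_death", covariates)
```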

In addition, the Boruta algorithm, implemented in the R package Boruta [33], was used to identify potential predictors of infant mortality based on the mean decrease in accuracy [34]. The output is shown in Fig. 1: the selected potential predictors are marked with green boxes, while the tentative and rejected ones (which hold a lower mean decrease in accuracy) are marked with yellow and red boxes, respectively. Overall, the Boruta algorithm identified a total of seventeen features, comprising the fourteen features extracted by the \(\chi ^2\) test plus three additional predictors (mother's age at first marriage, age at first birth and place of residence).
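The study ran Boruta through its R package; for readers following the Python pipeline used elsewhere in the paper, a comparable selection can be sketched with the boruta package (BorutaPy), assuming the X and y objects from the data-preparation step. The settings below are illustrative, not those used by the authors.

```python
from boruta import BorutaPy
from sklearn.ensemble import RandomForestClassifier

# Random-forest base learner for Boruta; configuration is illustrative only.
rf = RandomForestClassifier(n_jobs=-1, class_weight="balanced", max_depth=5, random_state=1)

selector = BorutaPy(rf, n_estimators="auto", random_state=1)
selector.fit(X.values, y.values)  # BorutaPy expects numpy arrays

confirmed = X.columns[selector.support_].tolist()       # confirmed predictors (green in Fig. 1)
tentative = X.columns[selector.support_weak_].tolist()  # tentative predictors (yellow in Fig. 1)
print(confirmed, tentative)
```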

Fig. 1

Features selection using the Boruta algorithm

Evaluation of models

The important potential predictors of infant mortality selected using the Boruta algorithm and the chi-square test were used to assess the performances of the various ML techniques, based on the confusion-matrix parameters and the area under the ROC curve (Fig. 2). The performance parameters calculated from the confusion matrix, accuracy, sensitivity, specificity, precision, and F1-score, are summarised in Table 1. The LR approach appeared to perform better than the other machine learning techniques in terms of accuracy and specificity for the predictors extracted by both methods (Boruta and the \(\chi ^2\) test). However, although the LR model produced the highest accuracy (accuracy = 0.903) and specificity (specificity = 1.00) in both scenarios, it failed to predict any positive cases, so its precision could not be assessed and its sensitivity and F1-score were both zero. In contrast, the RF model performed well for the important predictors of infant mortality selected by both the Boruta algorithm and the chi-square test. For example, when the important determinants were extracted using the Boruta algorithm, the RF technique achieved 89.3% correct predictions overall (accuracy = 0.893), identified 33.9% of true positive cases (sensitivity = 0.339) and 97.9% of true negative cases (specificity = 0.979), made 71.5% correct positive predictions (precision = 0.715), and balanced precision and recall reasonably (F1-score = 0.460). The RF model also performed better when the potential predictors associated with infant mortality were identified using the \(\chi ^2\) test.

Table 1 Accuracy, sensitivity, specificity, precision and F1-score of different machine learning models
Fig. 2

ROC curves for predicting infant mortality in Bangladesh using DT, RF, SVM and LR, based on the potential factors selected by the Boruta algorithm (a) and the \(\chi ^2\) test (b)

Figure 2 shows the AUC of the different machine learning techniques, DT, RF, SVM, and LR, fitted in Python (version 3.7.3) with the random seed 1119 for the risk factors selected by the Boruta algorithm and the \(\chi ^2\) test. The LR model showed the maximum AUC among the ML techniques, but it failed to classify the true positive cases of infant mortality. Accounting for this, the performance of the RF model was relatively better in all scenarios. However, these performances were evaluated on a single realization (a run with the random seed 1119); a simulation study is therefore more appropriate for exploring the overall performance of the ML models.
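This single realization corresponds to one 70/30 train–test split with the seed 1119; a minimal sketch, reusing the models dictionary and evaluate helper from the earlier sketches, is given below.

```python
from sklearn.model_selection import train_test_split

# One realization: 70% training data, 30% test data, random seed 1119.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=1119)

for name, model in models.items():
    model.fit(X_train, y_train)
    acc, sens, spec, prec, f1, auc = evaluate(model, X_test, y_test)
    print(f"{name}: acc={acc:.3f} sens={sens:.3f} spec={spec:.3f} "
          f"prec={prec:.3f} f1={f1:.3f} auc={auc:.3f}")
```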

Simulation study

A simulation study and k-fold cross-validation were conducted to assess the overall performances of the different ML techniques for the prediction of infant mortality in Bangladesh. We ran 1000 simulations with the random seeds 1111 to 2111, using 70% of the data for training and 30% for testing in each run. The five performance measures of the confusion matrix, together with their uncertainty estimates based on the 1000 simulations, are reported in Table 2. Based on the simulation results, the RF technique performed best when the associated factors were identified using the Boruta algorithm as well as the chi-square test.
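The repeated-split simulation can be sketched as below, again assuming the models dictionary and evaluate helper; each run uses a different random seed, and the metrics are averaged over all runs.

```python
import numpy as np
from sklearn.model_selection import train_test_split

results = {name: [] for name in models}

# 1000 repeated 70/30 splits, one random seed per run.
for seed in range(1111, 2111):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=seed)
    for name, model in models.items():
        model.fit(X_train, y_train)
        results[name].append(evaluate(model, X_test, y_test)[:5])  # accuracy ... F1-score

for name, runs in results.items():
    runs = np.array(runs, dtype=float)  # precision may be NaN when no positives are predicted
    means = np.nanmean(runs, axis=0)
    ses = np.nanstd(runs, axis=0, ddof=1) / np.sqrt(len(runs))
    print(name, np.round(means, 4), np.round(ses, 4))
```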

Table 2 The mean estimated value of performance parameters with their uncertainty estimates of different ML models based on 1000 simulations

The summary results of the k-fold cross-validation of the various ML techniques, with 10-, 20- and 30-fold repetitions, random seed 1 and the shuffle argument set to True, are given in Table 3. The mean accuracy across folds was highest, and the standard error of the accuracy scores lowest, for the LR model compared with the other ML techniques. However, the LR model was unable to adequately predict the true positive cases of infant mortality, as its sensitivity score was 0 and its precision could not be calculated. Thus, setting aside the LR model, the RF was found to be the most suitable technique for predicting infant mortality among the machine learning approaches considered in this study.

Table 3 Results of k-fold cross-validation for different machine learning models

Predicting infant mortality using the random forest (RF)

For the full data set, the best-performing ML algorithm, the RF technique, was fitted to predict infant mortality using the potential factors extracted by the Boruta algorithm and the \(\chi ^2\) test. Model interpretation is an important part of the evaluation process [35]. Explaining the association between model variables and the output is quite challenging for "black box" models like the RF, since the information is hidden within the model structure. Feature contributions were therefore calculated separately for each selected factor, providing more precise information about the association between the variables and the predicted probability of infant mortality. These results are organised in Figs. 3 and 4, where the selected factors were identified using the Boruta algorithm and the \(\chi ^2\) test, respectively.
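The paper does not state the exact method used to compute these contributions; one common option for a fitted RF is its impurity-based feature importances, normalised to percentages, sketched below under that assumption and reusing the confirmed feature list from the Boruta step.

```python
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

# Fit the RF on the full data with the Boruta-selected features. Assumption: impurity-based
# importances stand in for the feature contributions reported in Figs. 3 and 4.
rf_full = RandomForestClassifier(n_estimators=500, random_state=1119)
rf_full.fit(X[confirmed], y)

contrib = pd.Series(rf_full.feature_importances_, index=confirmed).sort_values(ascending=False) * 100
print(contrib.round(1))  # percentage contribution of each selected factor
```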

Fig. 3

Feature contributions for the selected features based on the Boruta algorithm using the fitted RF model to predict infant mortality in Bangladesh

Figure 3 illustrates that, among the seventeen factors selected by the Boruta algorithm for the prediction of infant mortality, feature contributions were higher for administrative division (20.0%), father's education (9.5%), mother's education (8.3%), birth interval (8.2%), mother's BMI (7.2%), and wealth index (7.0%). Conversely, feature contributions were lower for mother's age at first marriage (3.1%), age at first birth (2.6%), religion (2.1%), cooking fuel (2.0%), and total children ever born (1.6%).

Fig. 4

Feature contributions for the selected features based on the \(\chi ^2\)-Test using the fitted RF model to predict infant mortality

Figure 4 shows that, among the fourteen variables selected by the \(\chi ^2\) test, feature contributions were higher for administrative division (24.4%), father's education (10.8%), mother's education (9.3%), birth order (8.4%), mother's BMI (7.7%), and wealth index (7.6%). In contrast, feature contributions were lower for household toilet facility (5.1%), child's gender (3.9%), religion (2.4%), cooking fuel (2.3%), and total children ever born (1.9%) for predicting infant mortality in Bangladesh.

Discussion and conclusion

Globally, infant mortality is a major public health concern. In our study, four well-known machine learning techniques, DT, RF, SVM and LR, were adopted to assess the potential predictors of infant mortality in Bangladesh. These techniques were evaluated using 70% of the observations as training data and the remaining 30% as test data. All the machine learning techniques were evaluated on the important features selected by both the Boruta algorithm and the chi-square test. The predictive performances of all models were compared via simulations using the different performance parameters of the confusion matrix (accuracy, sensitivity, specificity, precision, and F1-score), the AUC, and k-fold cross-validation.

The traditional \(\chi ^2\) test identified fourteen significant predictors of infant mortality in Bangladesh: administrative division, birth interval [9, 10], religion, body mass index (BMI), parents' education [7, 13], mother's occupation, children ever born, gender of the child [4, 15], media exposure, birth order [36], wealth index, toilet facility and cooking fuel, while the Boruta algorithm suggested three additional predictors: age at first marriage, age at first birth and place of residence [37]. Infant mortality was significantly higher among mothers who married in their teens (before 18 years) and gave their first birth before 20 years [38, 39]. The pregnancy interval was significantly associated with infant mortality, and short intervals between two pregnancies (\(\le \) 2 years) can increase the risk of infant death [9, 10, 36]. Infant mortality was also higher among women from rural areas than urban areas [9]. The educational attainment of both mothers and fathers was found to be a significant factor: consistent with other studies, a lower risk of infant death was observed among more highly educated mothers and fathers compared with their counterparts [7, 8, 13]. The prevalence of infant mortality decreased for mothers who had mass-media exposure, higher birth order and better sanitation facilities [40, 41]. Poor economic status of the family also leads to higher infant mortality [42].

Although the accuracy and specificity scores were higher for the LR model (Table 1), this model failed to correctly classify any true positive cases of infant mortality. As a result, the precision of the LR model could not be computed, and both its sensitivity and F1-score were 0 (zero). In contrast, the random forest technique predicted infant mortality better than the other machine learning approaches. Based on the simulation results, the RF technique was found to be superior among the machine learning models considered in this study, with an accuracy of 88.90%, a sensitivity of 4.80%, a specificity of 97.89%, a precision of 19.6%, and an F1-score of 7.71% when the important features were extracted using the Boruta algorithm. Moreover, the highest mean accuracy scores for the 10-fold (MAcc = 88.78%), 20-fold (MAcc = 88.82%) and 30-fold (MAcc = 88.77%) repetitions were obtained with the RF classification technique.