Issue
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.inspection import permutation_importance
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
result = permutation_importance(rf,
                                X_test,
                                y_test,
                                n_repeats=10,
                                random_state=42,
                                n_jobs=2)
sorted_idx = result.importances_mean.argsort()
fig, ax = plt.subplots()
ax.boxplot(result.importances[sorted_idx].T,
           vert=False,
           labels=X_test.columns[sorted_idx])
ax.set_title("Permutation Importances (test set)")
fig.tight_layout()
plt.show()
In the code above, taken from this example in the documentation, is there a way to plot the top 3 features only instead of all the features?
Solution
argsort "returns the indices that would sort an array," so here sorted_idx contains the feature indices ordered from least to most important. Since you just want the 3 most important features, take only the last 3 indices:
sorted_idx = result.importances_mean.argsort()[-3:]
# array([4, 0, 1])
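To see why the last three indices are the right ones, here is a minimal sketch on a small array. The importance values are made up for illustration and are not taken from the documentation example; they are chosen only so that the result matches the indices shown above.

```python
import numpy as np

# Hypothetical mean importances for five features (made-up values).
importances_mean = np.array([0.20, 0.35, 0.01, 0.05, 0.10])

order = importances_mean.argsort()  # ascending: least to most important
top3 = order[-3:]                   # last three entries = three most important
top3_desc = order[::-1][:3]         # same three features, most important first

print(order)      # [2 3 4 0 1]
print(top3)       # [4 0 1]
print(top3_desc)  # [1 0 4]
```

Keeping the ascending order (top3 rather than top3_desc) is probably what you want for a horizontal boxplot, since the boxes are drawn bottom to top and the most important feature then ends up at the top.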
Then the plotting code can remain as is, but now it will only plot the top 3 features:
# unchanged, apart from an optional smaller figsize
fig, ax = plt.subplots(figsize=(6, 3))
ax.boxplot(result.importances[sorted_idx].T,
           vert=False, labels=X_test.columns[sorted_idx])
ax.set_title("Permutation Importances (test set)")
fig.tight_layout()
plt.show()
Note that if you prefer to leave sorted_idx untouched (e.g., to use the full indices elsewhere in the code), either change sorted_idx to sorted_idx[-3:] inline:
sorted_idx = result.importances_mean.argsort()  # unchanged
ax.boxplot(result.importances[sorted_idx[-3:]].T,  # replace sorted_idx with sorted_idx[-3:]
           vert=False,
           labels=X_test.columns[sorted_idx[-3:]])  # replace sorted_idx with sorted_idx[-3:]
or store the filtered indices in a separate variable:
sorted_idx = result.importances_mean.argsort()  # unchanged
top3_idx = sorted_idx[-3:]  # store top 3 indices
ax.boxplot(result.importances[top3_idx].T,  # replace sorted_idx with top3_idx
           vert=False,
           labels=X_test.columns[top3_idx])  # replace sorted_idx with top3_idx
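If the number of features to show might change, the slicing can be wrapped in a small helper. This is only a sketch: top_k_indices is a hypothetical name, not part of scikit-learn or NumPy, and the random array below merely stands in for result.importances, which has shape (n_features, n_repeats).

```python
import numpy as np

def top_k_indices(importances_mean, k=3):
    """Return indices of the k largest values, ordered least to most important."""
    order = np.asarray(importances_mean).argsort()
    # Guard k == 0 explicitly: order[-0:] would return the whole array.
    return order[-k:] if k > 0 else order[:0]

# Stand-in for result.importances: 5 features, 10 permutation repeats.
rng = np.random.default_rng(0)
importances = rng.normal(size=(5, 10))

idx = top_k_indices(importances.mean(axis=1), k=3)
boxplot_data = importances[idx].T  # one column per selected feature

print(boxplot_data.shape)  # (10, 3)
```

With the real data you would pass result.importances_mean to the helper and then use the returned indices in place of sorted_idx in the ax.boxplot call, for both the data and the labels.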
Answered By - tdy