Issue
I would like to instantiate a sklearn estimator from its name. For example:
```python
name = 'RandomForestClassifier'
clf = BaseEstimator(name)
print(clf)  # should print something like RandomForestClassifier(...)
```
What I have tried is clf = eval(name), but then I can't call clf.fit(X, y), because clf is <class 'sklearn.ensemble._forest.RandomForestClassifier'> (the class itself) instead of an instance such as RandomForestClassifier().
I can't find how to create a sklearn estimator from its name.
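The class-versus-instance distinction described above can be seen in a short sketch. It assumes the estimator lives in sklearn.ensemble, and it looks the name up with getattr on that module rather than eval, which avoids evaluating arbitrary strings. The lookup alone gives back the class; calling that class is what produces an object with a usable fit method.

```python
import sklearn.ensemble

name = "RandomForestClassifier"

# Looking the name up returns the class itself, just like eval(name) would
cls = getattr(sklearn.ensemble, name)
print(cls)  # a class object, not something you can fit yet

# Calling the class creates an estimator instance with default parameters
clf = cls()
print(clf)
```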
Solution
As far as I understand, what you want to achieve is to dynamically instantiate existing classifiers based on some user input given as a string. BaseEstimator is the base class for all estimators, so I think you should use it only if you are planning to build your own estimator.
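For contrast, this is roughly what building your own estimator on top of BaseEstimator looks like. MajorityClassifier is a hypothetical toy example (not part of scikit-learn) that always predicts the most frequent label seen during fit; ClassifierMixin supplies a default score method based on accuracy. It is only a sketch and skips the input validation a real estimator would do.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin


class MajorityClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical toy classifier: predicts the most frequent label seen in fit()."""

    def fit(self, X, y):
        # Count each label and remember the most common one
        labels, counts = np.unique(y, return_counts=True)
        self.majority_ = labels[np.argmax(counts)]
        return self

    def predict(self, X):
        # Return the stored majority label for every input row
        return np.full(len(X), self.majority_)


clf = MajorityClassifier().fit([[0], [1], [2]], [1, 1, 0])
print(clf.predict([[5], [6]]))  # [1 1]
```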
I'm not sure how you managed to run the code above, as BaseEstimator() doesn't take any arguments:
```
clf = BaseEstimator('RandomForestClassifier')
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-6-59dd1a62618a>", line 1, in <module>
    clf = BaseEstimator('RandomForestClassifier')
TypeError: BaseEstimator() takes no arguments
```
In any case, a simple way to instantiate classifiers dynamically is to map each supported name to its class:
```python
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier


def get_clf(name):
    # Map each supported name to its estimator class (not an instance)
    classifiers = {
        "DecisionTreeClassifier": DecisionTreeClassifier,
        "RandomForestClassifier": RandomForestClassifier,
    }
    if name not in classifiers:
        raise ValueError(f"{name} is not a recognised option")
    classifier = classifiers[name]
    # Calling the class creates an estimator instance with default parameters
    return classifier()
```
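Because the lookup returns a class, it is easy to forward constructor arguments as well. The sketch below is a hypothetical variant named get_clf_with_params (not part of scikit-learn) that passes keyword arguments through to the chosen class, followed by a fit on the bundled iris dataset to show that the result is a real, fittable estimator.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

# Supported names mapped to their classes
CLASSIFIERS = {
    "DecisionTreeClassifier": DecisionTreeClassifier,
    "RandomForestClassifier": RandomForestClassifier,
}


def get_clf_with_params(name, **params):
    """Hypothetical variant of get_clf that forwards keyword arguments to the constructor."""
    if name not in CLASSIFIERS:
        raise ValueError(f"{name} is not a recognised option")
    return CLASSIFIERS[name](**params)


X, y = load_iris(return_X_y=True)
clf = get_clf_with_params("RandomForestClassifier", n_estimators=10, random_state=0)
clf.fit(X, y)  # works because clf is an instance, not the class
print(clf)
```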
Alternatively, you can go with a more straightforward solution:
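```python
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier


def get_clf(name):
    if name not in ('RandomForestClassifier', 'DecisionTreeClassifier'):
        raise ValueError(f"{name} is not a recognised option")
    # Return a new instance of the requested classifier
    if name == 'RandomForestClassifier':
        return RandomForestClassifier()
    elif name == 'DecisionTreeClassifier':
        return DecisionTreeClassifier()
    else:
        # Unreachable after the check above; kept as a safeguard
        raise RuntimeError()
```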
Answered By - Giorgos Myrianthous