- breast_cancer 데이터를 이용하여 적합한 randomforest의 트리 개수와, 특정값을 구해보자
# 1 데이터 호출
from sklearn.datasets import load_breast_cancer
breast_cancer_dataset = load_breast_cancer()
# 1-1 데이터 분류
X_train,X_test,y_train,y_test = train_test_split(X,y, test_size = 0.3, random_state = 10)
# 2 n_estimators
depths=range(5,201,5)
forest_train,forest_test = [],[]
for i in depths:
forest = RandomForestClassifier(n_estimators=i, random_state=i)
forest.fit(X_train,y_train)
forest_train.append(forest.score(X_train,y_train))
forest_test.append(forest.score(X_test,y_test))
plt.plot(depths,forest_train,label='훈련 정확도',marker='o')
plt.plot(depths,forest_test,label='테스트 정확도',marker='o')
plt.ylabel('accuracy')
plt.xlabel('number of n_estimators')
plt.legend()
2. n_n_estimators 는 randomforest에서 의사결정나무의 개수를 정해준다고 했습니다.
그럼 적합한 나무의 개수가 무엇인지 확인하기 위해 range를 통해 다양한 숫자를 넣어 그래프로 train,test의 정확도를 확인해보았습니다. (random_state도 같이 증가시켜 확인해보았습니다)
위의 그래프를 보면 test 정확도가 25일때 가장 최대이며, 더 이상 높아지지 않다는 것을 추정하면 n_estimators = 25로 지정할 수 있습니다.
# 3 max_features
depths=range(1,31)
tree_train,tree_test = [],[]
for i in depths:
forest = RandomForestClassifier(n_estimators=25, random_state=25, max_features=i)
forest.fit(X_train,y_train)
tree_train.append(forest.score(X_train,y_train))
tree_test.append(forest.score(X_test,y_test))
plt.plot(depths,tree_train,label='훈련 정확도',marker='o')
plt.plot(depths,tree_test,label='테스트 정확도',marker='o')
plt.ylabel('accuracy')
plt.xlabel('number of max_features')
plt.legend()
3 max_features 에서는 적합한 features의 수를 구하기 위한 작업입니다.
위에서 높게 나온 n_estimators와 random_state를 25로 지정해높고 1~ 31까지 반복문을 돌려줍니다.
위 그래프에서 test가 제일 높을 때 max_features의 값은 10인 것을 확인할 수 있습니다.
# 최종 확인
forest = RandomForestClassifier(n_estimators=25, random_state=25, max_features=10)
forest.fit(X_train,y_train)
print("train accuracy {:.2f}".format(forest.score(X_train,y_train)))
print("test accuracy {:.2f}".format(forest.score(X_test, y_test)))
마지막 최종확인으로 가장 적합한 값을 넣어서 최종 train, test 정확도를 확인해보았습니다.
train accuracy 1.00
test accuracy 0.99
이상으로 randomforest 모델을 이용하는 것을 공부했습니다.
'머신러닝 in Python' 카테고리의 다른 글
[Python] LinearRegression(수치예측) (0) | 2019.08.28 |
---|---|
[Python] Gradient Boosting (0) | 2019.08.28 |
[Python] RandomForest (0) | 2019.08.27 |
[Python] DecisionTree (0) | 2019.08.27 |
[Python] k-nearest neighbor 예제 (0) | 2019.08.26 |