오답노트

[ML] 모델에 대한 설명 - Partial Dependence Plots 본문

Python/ML

[ML] 모델에 대한 설명 - Partial Dependence Plots

권멋져 2022. 9. 5. 19:12

Partial Dependence Plots

관심 Feature의 값이 변할때, 모델에 미치는 영향을 시각화한다.

데이터 셋의  모든 row에 관심 Feature 모든 값을 돌면서 하나로 통일하고 예측한다.

그 값들의 평균을 계산한다.

 

from sklearn.inspection import plot_partial_dependence

var = 'rm'
plt.rcParams['figure.figsize'] = 12, 8
plot_partial_dependence(model, 
                        features = [var], 
                        X = x_train, 
                        kind = 'both')
plt.grid()
plt.show()

 

예측할 모델과 관심 Feature 그리고 데이터 셋을 입력해 Partial Dependence Plots를 출력한다.

kind 옵션은 both일 경우 average와 개별 instance와 함께 값을 그린다.

 

plot_partial_dependence 결과

위 그래프를 분석해보면, rm이 6.7일 때와 7.4 정도 일 때 급격하게 y에 대한 예측값이 높아지는 것을 알 수 있다.

 

plot_partial_dependence(model, features = ['rm','lstat'], X = x_train)
plt.show()

 

feature에 리스트로 복수개의 변수를 넣으면 해당 변수마다 Partial Dependence Plots을 그린다.

 

위 그래프를 보면 rm이 증가할 수록 예측값은 상승하고, lstat가 증가할 수록 예측값은 감소한다.

 

plot_partial_dependence(model, features = [('rm','lstat')], X = x_train)
plt.show()

feature에 튜플로 변수를 입력하게 되면 한 그래프에서 예측값에 대한 변화를 확인 할 수 있다.

위 그래프를 보면 rm이 6.7일 때와 7.4 정도 일 때 lstat 값의 거의 상관없이 예측값이 증가한다. 또, rm이 7.5 부근을 넘어가면 lstat 값의 상관없이 예측값이 증가한다.