Tree Models

반응형
Tree Models

1. 의사결정나무(Decision tree)

  • 의사결정나무는 지니 불순도(Gini Impurity) 등의 기준을 사용하여 노드(node)를 재귀적으로 분할하면서 tree 모형을 만드는 방법입니다.
  • if ~ then, else 의 조건문과 같은 형식으로 구성되어 있어서 이해하기 쉽고 처리속도가 비교적 빠르며, 여러 가지 feature들간의 상호 작용을 잘 표현해주고 다양한 데이터에 적용시킬 수 있다는 장점이 있습니다.

1. 1. 노드를 나누는 기준

  • 노드에는 분류의 시작점에 해당하는 Root Node, 그리고 제일 하단의 잎사귀 노드 Leaf Node가 존재합니다.
  • 데이터가 얼마나 잘 분리되어 있는지 불순도를 기준으로 평가하며 불순도 함수 \(f\)가 존재할 때 임의의 노드 \(A\)의 불순도 \(I(A)\)는 다음과 같이 정의합니다. 여기서 \(C\)는 분류의 개수이며 \(p_{iA}\)는 노드 \(A\)에 속한 분류 \(i\)인 표본의 비율입니다.

    \[I(A) = \sum^{C}_{i=1} f(p_{iA})\]

  • 가장 기본적으로 이용되는 불순도 함수는 지니 불순도로 다음과 같이 정의됩니다.

    \[f(p)=p(1-p)\]

2. 적용

2. 1. rpart()

  • 의사 결정 나무를 만드는 다양한 패키기 중 rpart 패키지 안에 내장되어 있는 rpart() 함수를 이용해보겠습니다.

    library(rpart)
    • rpart 패키지는 CART(Classification and Regression Trees)의 아이디어를 구현한 패키지 입니다.
    rpart(formula,
          data)
    • formula : 모형식, 반응변수 ~ 설면변수들
    • data : 모형식을 적용할 데이터프레임
  • 기본적으로 iris 데이터에 대해 rpart() 함수를 적용한 결과 입니다.

    m1 <- rpart(Species ~ ., 
                data = iris)
    m1
    ## n= 150 
    ## 
    ## node), split, n, loss, yval, (yprob)
    ##       * denotes terminal node
    ## 
    ## 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  
    ##   2) Petal.Length< 2.45 50   0 setosa (1.00000000 0.00000000 0.00000000) *
    ##   3) Petal.Length>=2.45 100  50 versicolor (0.00000000 0.50000000 0.50000000)  
    ##     6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259) *
    ##     7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *
    • \(n=150\)은 150개의 데이터가 input 되어있음을 의미하고 밑으로 각 노드들이 표시되어 있습니다.
    • 들여쓰기 형식으로 output이 출력됨을 볼 수 있는데 이는 가지가 갈라지는 모양을 의미합니다. ’*’ 모양은 잎사귀 노드를 의미합니다.
  • 이 또한 predict() 함수를 적용하여 예측을 수행할 수 있습니다.

    head( predict(m1, newdata = iris, type = "class") )
    ##      1      2      3      4      5      6 
    ## setosa setosa setosa setosa setosa setosa 
    ## Levels: setosa versicolor virginica
  • 이 밖에 과적합(over-fitting)을 피하기 위해 가지치기(pruning) prune.rpart() 함수와 rpart.control() 함수 등 다양한 성능 튜닝 함수가 존재합니다.

2. 2. ctree()

  • CART의 과적합 등의 문제를 조금 개선한 방법으로 조건부 추론 나무(Conditional Inference Tree) 모형을 언급할 수 있습니다.
  • 조건부 추론 나무는 조건부 분포에 따라 설명변수와 반응변수(분류) 사이의 연관 관계를 측정하여 노드 분할에 사용할 변수를 선택할 수 있고 노드를 반복 분할하면서 적절한 검정 절차를 적용해 적절한 시점에 노드의 분할을 중단할 수 있습니다.
  • 이는 party 패키지 안에 내장되어 있는 ctree() 함수를 이용하여 적용할 수 있습니다.

    library(party)
    ctree(formula, 
          data)
    • rpart() 함수와 인자가 똑같으며 반환되는 객체는 BinaryTree 객체 입니다.
    m2 <- ctree(Species ~ ., data = iris)
    m2
    ## 
    ##   Conditional inference tree with 4 terminal nodes
    ## 
    ## Response:  Species 
    ## Inputs:  Sepal.Length, Sepal.Width, Petal.Length, Petal.Width 
    ## Number of observations:  150 
    ## 
    ## 1) Petal.Length <= 1.9; criterion = 1, statistic = 140.264
    ##   2)*  weights = 50 
    ## 1) Petal.Length > 1.9
    ##   3) Petal.Width <= 1.7; criterion = 1, statistic = 67.894
    ##     4) Petal.Length <= 4.8; criterion = 0.999, statistic = 13.865
    ##       5)*  weights = 46 
    ##     4) Petal.Length > 4.8
    ##       6)*  weights = 8 
    ##   3) Petal.Width > 1.7
    ##     7)*  weights = 46
  • plot() 함수를 이용하여 그림으로도 표현할 수 있습니다.

    plot(m2)

  • 마찬가지로 ctree() 함수를 이용하여 반환된 BinaryTree 객체는 predict() 함수를 사용하여 예측을 수행할 수 있습니다.

    head( predict(m2, newdata = iris, type = "response") )
    ## [1] setosa setosa setosa setosa setosa setosa
    ## Levels: setosa versicolor virginica

3. 랜덤 포레스트(Random Forest)

  • 랜덤 포레스트는 앙상블(Ensemble) 학습 기법을 사용한 모형으로 주어진 데이터로부터 여러 개의 모형을 학습시킨 다음 예측 시 여러 모형의 예측 결과들을 종합해 사용하여 정확도를 기법입니다.
  • 랜덤 포레스트는 의사 결정 나무를 만들 때 데이터의 일부를 복원 추출(sampling with replacement)을 통하여 해당 데이터에 대해서만 의사 결정 나무를 만들고, 노드 내 데이터를 나누는 기준을 정할 때 전체 변수가 아니라 일부 변수만을 대상으로 하여 가지를 나눕니다.
    • 새로운 데이터에 대한 예측을 수행할 때는 여러 가지의 의사 결정 나무가 내놓은 예측 결과를 투표(voting) 방식으로 결정하여 최종 결과를 내리는 형식입니다. 예를 들어, 총 10개의 의사 결정 나무중 과반 이상이 반응변수를 Y로 예측했다면 최종 결과는 Y로 예측되게 됩니다.
  • 랜덤 포레스트는 일반적으로 성능이 뛰어나고 의사 결정 나무 하나가 아니라 여러 개를 사용해 과적합 문제를 피하게 됩니다.

3. 1. randomForest()

  • R에서 randomForest 패키지는 랜덤 포레스트를 구현한 패키지로 함수 역시 randomForest() 입니다.

    library(randomForest)
    randomForest(formula, 
                 data,
                 ntree = 500,
                 mtry,
                 importance = FALSE)
    • formula : 반응변수 ~ 설명변수 모형식
    • data : formula를 적용할 data.frame
    • ntree = 500 : 생성할 나무의 갯수로 기본 값은 500
    • mtry : 노드를 나눌 기준을 정할 때 고려할 변수의 갯수
    • importance = FALSE : 변수의 중요도 평가 여부
  • iris 데이터에 대한 랜덤 포레스트 모형은 다음과 같이 만들 수 있습니다.

    m3 <- randomForest(Species ~ ., data = iris)
    m3
    ## 
    ## Call:
    ##  randomForest(formula = Species ~ ., data = iris) 
    ##                Type of random forest: classification
    ##                      Number of trees: 500
    ## No. of variables tried at each split: 2
    ## 
    ##         OOB estimate of  error rate: 4%
    ## Confusion matrix:
    ##            setosa versicolor virginica class.error
    ## setosa         50          0         0        0.00
    ## versicolor      0         47         3        0.06
    ## virginica       0          3        47        0.06
    • OOB(Out of Bag) estimate of error rate는 모형을 적합시킬 때 사용되지 않은 데이터를 사용한 에러 추정치를 의미합니다.
    • comfusion matrix를 통해 오분류율을 확인할 수 있습니다.
  • importance = TRUE argument를 이용하여 변수의 중요도를 평가하고 모델링에 사용할 변수를 선택하는데 이용할 수 있습니다. output을 출력할 때는 importance(), varImpPlot() 함수를 이용합니다.

    m3 <- randomForest(Species ~ ., data = iris,
                       importance = TRUE)
    importance(m3)
    ##                 setosa versicolor virginica MeanDecreaseAccuracy
    ## Sepal.Length  6.782548   7.681925  7.666529            10.222971
    ## Sepal.Width   4.719956   1.281748  5.697233             5.597411
    ## Petal.Length 23.938576  33.960217 29.349978            35.147853
    ## Petal.Width  21.357472  31.017910 31.744591            33.167970
    ##              MeanDecreaseGini
    ## Sepal.Length         9.656970
    ## Sepal.Width          2.265402
    ## Petal.Length        44.044858
    ## Petal.Width         43.301342
    • MeanDecreaseAccuracy(정확도) 부분에서는 Petal.Length 변수가 가장 중요함을 알 수 있고, MeanDecreaseGini(불순도 개선) 부분에서도 중요한 부분을 알 수 있습니다.
    varImpPlot(m3, main = "Importance of variables")

반응형

'Statistical Modeling & ML > Classification' 카테고리의 다른 글

Logistic Regression Model  (0) 2017.10.06
TAGS.

Comments