트리 모델 계열의 근간인 결정 트리를 학습시켜 연봉을 예측합니다. 모델링에서 발생하는 중요한 이슈인 오버피팅 개념과 해결 방법을 알아봅니다. 1편은 문제정의, 라이브러리 및 데이터 불러오기, 데이터 확인하기입니다.
학습 순서
결정 트리 소개
결정 트리(Decision Tree)는 관측값과 목푯값을 연결시켜주는 예측 모델로서 나무 모양으로 데이터를 분류합니다. 수많은 트리 기반 모델의 기본 모델(based model)이 되는 중요 모델입니다. 트리 기반의 모델은 선형 모델과는 전혀 다른 특징을 가지는데, 선형 모델이 각 변수에 대한 기울기값들을 최적화하여 모델을 만들어나갔다면, 트리 모델에서는 각 변수의 특정 지점을 기준으로 데이터를 분류 해가며 예측 모델을 만듭니다. 예를 들어 남자/여자로 나눠서 각 목푯값 평균치를 나눈다거나, 나이를 30세 이상/미만인 두 부류로 나눠서 평균치를 계산하는 방식으로 데이터를 무수하게 쪼개어나가고, 각 그룹에 대한 예측치를 만들어냅니다.
TOP10 선정 이유
예측력과 성능으로만 따지면 결정 트리 모델을 사용할 일은 없습니다. 시각화가 매우 뛰어나다는 유일한 장점이 있을 뿐입니다. 하지만 앞으로 배울 다른 트리 기반 모델을 설명하려면 결정 트리를 알아야 합니다. 트리 기반 모델은 딥러닝을 제외하고는 현재 가장 유용하고 많이 쓰이는 트렌드이기 때문에 트리 모델을 필수로 알아둬야 합니다.
예측력과 설명력
예측력이란 모델 학습을 통해 얼마나 좋은 예측치를 보여주는가를 의미하며, 설명력은 학습된 모델을 얼마나 쉽게 해석할 수 있는지를 뜻합니다. 알고리즘의 복잡도가 증가할수록 예측력은 좋아지나 설명력은 다소 떨어지는 반비례 관계를 보여줍니다. 즉, 단순한 알고리즘일수록 예측력이 상대적으로 떨어질 수 있으나 해석에 용이하며, 복잡한 알고리즘은 예측력이 뛰어난만큼 해석은 어렵습니다.
결정트리와 회귀 분석은 상대적으로 해석이 쉬워 설명력이 높다고 할 수 있으며, 9장부터 배울 알고리즘들은 복잡도가 증가하여 예측력이 높지만 해석이 어렵습니다. 딥러닝 또한 매우 복잡한 알고리즘으로 해석이 어려워서 이를 블랙박스에 비유하기도 합니다.
예측력과 설명력 중 어느 쪽을 택해야 하는지는 상황에 따라 다릅니다. 예를 들어 의학 계열에서 특정 질병의 발병률에 대한 예측 모델을 만들 때는, 발병률을 높이거나 억제하는 중요한 요인을 밝히는 데는 설명력이 좋은 알고리즘이 적합할 수 있습니다. 다른 예로 사기거래를 예측하는 모델에서는 요인보다는 더 정확하게 사기거래를 잡아낼 수 있어야 하므로 예측력이 높은 알고리즘이 더 적합할 수 있습니다.
▼ 예시 그래프
▼ 장단점
장점
• 데이터에 대한 가정이 없는 모델입니다(Nonparametric Model). 예를 들어 선형 모델은 정규분포에 대한 가정이나 독립변수와 종속변수의 선형 관계 등을 가정으로 하는 모델인 반면, 결정 트리는 데이터에 대한 가정이 없으므로 어디에나 자유롭게 적용할 수 있습니다.
• 아웃라이어에 영향을 거의 받지 않습니다.
• 트리 그래프를 통해서 직관적으로 이해하고 설명할 수 있습니다. 즉 시각화에 굉장히 탁월합니다.
단점
• 트리가 무한정 깊어지면 오버피팅 문제를 야기할 수 있습니다.
• 앞으로 배울 발전된 트리 기반 모델들에 비하면 예측력이 상당히 떨어집니다.
▼ 유용한 곳
• 종속변수가 연속형 데이터와 범주형 데이터 모두에 사용할 수 있습니다.
• 모델링 결과를 시각화할 목적으로 가장 유용합니다.
• 아웃라이어가 문제될 정도로 많을 때 선형 모델보다 좋은 대안이 될 수 있습니다.
1. 문제 정의 : 한눈에 보는 예측 목표
문제 정의
데이터 분석가와 개발자 몸값이 하루가 다르게 뛰어오릅니다. 현재 내 연봉 수준은 적절한 것일까요? 나이, 교육 수준, 혼인 상태, 직업, 인종 성별 등 항목에 따른 연봉을 기재한 데이터셋이 있습니다. 미국은 다민족 국가라서 인종 정보까지 있는 점이 특이합니다. 결정 트리 알고리즘을 활용해서 연봉 등급을 나눠보겠습니다.
▼ 예측 목표
- 미션: 학력, 교육 연수, 혼인 상태, 직업 정보를 담은 연봉 데이터셋을 이용해 연봉을 예측하라.
- 난이도: ★☆☆
- 알고리즘: 결정 트리(Decision Tree)
- 데이터셋 파일명: salary.csv
- 종속변수: class(연봉 등급)
- 데이터셋 소개: 이번 장에서는 연봉 데이터를 사용합니다. 연봉이 $50,000 이상인지 이하인지를 예측하는 것이 목표이며, 종속변수는 class, 독립변수로는 학력, 교육 연수, 혼인 상태, 직업 등이 있습니다.
- 문제 유형: 분류
- 평가지표: 정확도
- 사용한 모델: DecisionTreeClassifier
- 사용 라이브러리:
- numpy (numpy==1.19.5) • pandas (pandas==1.3.5) • seaborn (seaborn==0.11.2) • matplotlib (matplotlib==3.2.2) • sklearn (scikit-learn==1.0.2)
- 예제 코드:
- 위치 : colab.research.google.com/github/musthave-ML10/notebooks/blob/main/
- 파일 : 08_Decision Tree.ipynb
2. 라이브러리 및 데이터 불러오기, 데이터 확인하기
4가지 필수 모듈과 데이터(salary.csv) 파일을 불러오겠습니다.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
file_url = 'https://media.githubusercontent.com/media/musthave-ML10/data_source/ main/salary.csv'
data = pd.read_csv(file_url, skipinitialspace = True) # ❶ 데이터셋 읽기
❶ csv 파일을 불러옵니다. 매개변수 skipinitialspace는 각 데이터의 첫 자리에 있는 공란을 자동 제거합니다. 예를 들어 ‘ Male’를 ‘Male’로 정리합니다. 이번에 사용할 csv 파일에 불필요한 공란이 많아서 적용했습니다.
그럼 head( ) 함수를 실행해 데이터의 전반적인 모습을 살펴봅시다.
data.head() # 상위 5행 출력
종속변수는 class입니다. unique( ) 함수로 몇 가지 값이 있나 확인하겠습니다.
data['class'].unique() # 고윳값 확인
array([‘<=50K’, ‘>50K’], dtype=object)
50K 이하와 초과 두 가지만 있습니다. 50K에서 K는 천 단위를 뜻합니다. 즉 $50,000을 기준으로 나뉘어진 데이터입니다.
다음은 info( ) 함수로 변수별 형태를 보겠습니다.
data.info( ) # 변수 특징 출력
❶ object형, 즉 텍스트로 구성된 범주형 변수가 많습니다. ❷ Non-Null Count를 보면 결측치가 있는 변수도 몇몇 보입니다. 이번 데이터에서는 이런 범주형 변수들과 결측치가 있는 변수들에 대한 처리를 해줄 게 많아 보입니다.
마지막으로 describe( ) 함수로 통계적 정보를 확인합니다.
data.describe() # 통계 정보 출력
object형의 변수가 많다 보니 변수가 5개만 보입니다. 기본적으로 describe( ) 는 object형의 데이터를 제거하고 통계적 수치를 보여주지만, 매개변수를 이용하여 object형의 데이터까지 보이게 하는 방법도 있습니다.
data.describe(include = 'all') # object형이 포함된 통계정보 출력
include 매개변수를 사용해 object형까지도 출력했습니다. ❶ unique, top, freq 행이 추가되었습니다. 새로 추가된 행들은 오로지 object형의 변수들만을 위한 것이고, 기존의 숫자형 변수들은 NaN으로 처리되어 있으며, 반대로 object형의 변수들은 mean, std와 같은 기존의 통계적 정보가 모두 NaN으로 처리되어 있습니다. unique는 각 변수에서 가지고 있는 고유한 value의 숫자입니다. nunique( ) 함수를 사용했을 때와 같은 수치를 보여줍니다. top은 각 변수별로 가장 많이 등장하는 value가 무엇인지를 보여주며, freq는 top에 나와있는 value가 해당 변수에서 총 몇 건인지를 보여줍니다. 예를 들어 workclass 변수는 해당 변수에는 고유한 value가 8종류이며, Private이 총 33,906번 등장합니다(object 변수들을 이런 식으로 살펴보는 방식은 분석에서 크게 유의미하지는 않으니 참고로만 알아두세요). 잠시 후에 각 범주형 변수들을 더욱 자세히 확인한 뒤에 전처리하겠습니다.
2편은 전처리 : 범주형 데이터, 전처리 : 결측치 처리 및 더미 변수 변환입니다.
삼성전자에 마케팅 직군으로 입사하여 앱스토어 결제 데이터를 운영 및 관리했습니다. 데이터에 관심이 생겨 미국으로 유학을 떠나 지금은 모바일 서비스 업체 IDT에서 데이터 사이언티스트로 일합니다. 문과 출신이 미국 현지 데이터 사이언티스트가 되기까지 파이썬과 머신러닝을 배우며 많은 시행착오를 겪었습니다. 제가 겪었던 시행착오를 덜어드리고, 머신러닝에 대한 재미를 전달하고자 유튜버로 활동하고 책을 집필합니다.
현) IDT Corporation (미국 모바일 서비스 업체) 데이터 사이언티스트
전) 콜롬비아 대학교, Machine Learning Tutor, 대학원생 대상
전) 콜롬비아 대학교, Big Data Immersion Program Teaching Assistant
전) 콜롬비아 대학교, M.S. in Applied Analytics
전) 삼성전자 무선사업부, 스마트폰 데이터 분석가
전) 삼성전자 무선사업부, 모바일앱 스토어 데이터 관리 및 운영
강의 : 패스트캠퍼스 〈파이썬을 활용한 이커머스 데이터 분석 입문〉
SNS : www.youtube.com/c/데싸노트