5-1 결정 트리PYTHON/데이터분석2023. 9. 15. 17:21
Table of Contents
#로지스틱 회귀로 와인 분류하기
캔에 인쇄된 알콜 도수, 당도, PH 값으로 와인종류를 구별할 수 있는 방법이 있을까?
6497개의 와인 샘플 데이터가 있다.
import pandas as pd
wine = pd.read_csv('https://bit.ly/wine_csv_data')
처음 5개 샘플 확인
wine.head()
class는 타깃값으로 0이면 레드, 1이면 화이트와인
레드와인과 화이트와인을 구분하는 이진분류 문제다.
화이트와인이 양성클래스이다.
즉 전체 와인 데이터에서 화이트 와인을 골라내는 문제다.
info()로 데이터프레임의 각 열의 데이터 타입과 누락된 데이터 있는지 확인
describe()로 열에 대한 간략한 통계 출력
차례로 평균, 표준편차, 최소, 1,2,3사분위수(데이터를 순서대로 4등분), 최대
알콜도수, 당도, pH값의 스케일이 다르다. 표준화해야한다.
판다스 데이터 프레임을 넘파이 배열로 바꾸고 훈련세트, 테스트세트로 나눈다.
data = wine[['alcohol', 'sugar', 'pH']].to_numpy()
target = wine['class'].to_numpy()
from sklearn.model_selection import train_test_split
train_input, test_input, train_target, test_target
= train_test_split(data, target, test_size=0.2, random_state=42)
훈련세트 : 5197개 테스트세트 : 1300개
print(train_input.shape, test_input.shape)
#(5197, 3) (1300, 3)
StandardScaler()로 훈련세트를 전처리한 다음(표준점수로 변환)
같은 객체를 그대로 사용해 테스트 세트를 변환
from sklearn.preprocessing import StandardScaler
ss = StandardScaler()
ss.fit(train_input)
train_scaled = ss.transform(train_input)
test_scaled = ss.transform(test_input)
로지스틱회귀모델 훈련
점수가 둘다 낮다 과소적합이다.
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression()
lr.fit(train_scaled, train_target)
print(lr.score(train_scaled, train_target))
print(lr.score(test_scaled, test_target
"""
0.7808350971714451
0.7776923076923077
"""
#설명하기 쉬운 모델과 어려운 모델
로지스틱 회귀가 학습한 계수와 절편을 출력
-> 무슨 의미인지 이해하기 어려움
print(lr.coef_, lr.intercept_)
#[[ 0.51270274 1.6733911 -0.68767781]] [1.81777902]
#결정트리
트리처럼 이동..
훈련세트에 대한 점수가 엄청높다.
from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier(random_state=42)
dt.fit(train_scaled, train_target)
print(dt.score(train_scaled, train_target))
print(dt.score(test_scaled, test_target))
"""
0.996921300750433
0.8592307692307692
"""
출력
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
plt.figure(figsize=(10,7))
plot_tree(dt)
plt.show()
길이 제한해서 출력
왼 : yes , 오 : no
사각형 안 : 테스트 조건, 불순도, 총샘플수, 클래스별 샘플 수(음성,양성)
plt.figure(figsize=(10,7))
plot_tree(dt, max_depth=1, filled=True, feature_names=['alcohol', 'sugar', 'pH'])
plt.show()
'PYTHON > 데이터분석' 카테고리의 다른 글
4-2 확률적 경사 하강법 (0) | 2023.09.15 |
---|---|
4-1 로지스틱 회귀 (0) | 2023.09.15 |
3-3 특성 공학과 규제 (0) | 2023.09.15 |
3-2 선형회귀 (0) | 2023.09.15 |
3-1 KNN 회귀 (0) | 2023.09.15 |