今回は、シンプルながら非常に強力な分類アルゴリズムであるロジスティック回帰について学んでいきましょう。ロジスティック回帰を使って、アヤメの花を分類する方法を見ていきます。
環境構築については別の記事で紹介しているので、参考にしてください。
動画でも紹介しているのでご覧ください。
データセットの準備
まずは、アヤメのデータセットをロードしましょう。Scikit-learnライブラリは、機械学習の初学者がよく使うIrisデータセットを含んでいます
。
from sklearn.datasets import load_iris
import pandas as pd
import numpy as np
# Irisデータセットのロード
iris = load_iris()
iris_df = pd.DataFrame(data= np.c_[iris['data'], iris['target']],
columns= iris['feature_names'] + ['target'])
データの視覚化
視覚化はデータを理解する上で非常に重要です。Seabornライブラリを使って、アヤメのデータを視覚化してみましょう。
import seaborn as sns
import matplotlib.pyplot as plt
# ターゲット列が整数型であることを確認し、整数型に変換
iris_df['target'] = iris_df['target'].astype(int)
# ターゲットの数値を種類名にマッピング
iris_df['species'] = iris_df['target'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})
# データの視覚化
sns.pairplot(iris_df.drop('target', axis=1), hue='species')
plt.show()
実際に実行するとこのように各種変数がグラフ化されて表示されます。
このpairplotというseabornのメソッドを使用すると、簡単にグラフ化ができます。
相関を確認
次に、ヒートマップを使って、特徴量間の相関を確認します。相関が高い特徴量は、互いに関連していることを意味しています。
# 相関行列を計算
corr = iris_df.drop('target', axis=1).corr()
# ヒートマップを作成
sns.heatmap(corr, annot=True)
plt.show()
実行すると、ヒートMAPが表示されます。
この結果より、petal width/lengthの相関が強いことがわかります。
ロジスティック回帰モデルの訓練
ロジスティック回帰は、特定の特徴量に基づいて、アヤメがどの種類に属するかを予測します。
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
# データとラベルを準備
X = iris.data
y = iris.target
# 訓練セットとテストセットに分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# モデルの訓練
model_lr = LogisticRegression(max_iter=200)
model_lr.fit(X_train, y_train)
# テストデータの予測
predictions_lr = model_lr.predict(X_test)
# 正確度の計算
accuracy_lr = accuracy_score(y_test, predictions_lr)
print(f"Logistic Regression Accuracy: {accuracy_lr}")
実行すると、簡単にテストデータの予測を行い、正確度を表示されます。今回の結果は1.0となり100%で正解していることがわかります。
また、境界を確認するために、グラフ化した結果がこちらです。
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
import numpy as np
import matplotlib.pyplot as plt
# Load the iris dataset
iris = load_iris()
# Prepare subplots
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 6))
# First subplot for sepal features (sepal length and sepal width)
X_sepal = iris.data[:, :2] # Selecting sepal length and sepal width
logreg_sepal = LogisticRegression(C=1e5, solver='lbfgs', multi_class='multinomial')
logreg_sepal.fit(X_sepal, iris.target)
x_min, x_max = X_sepal[:, 0].min() - .5, X_sepal[:, 0].max() + .5
y_min, y_max = X_sepal[:, 1].min() - .5, X_sepal[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, .02), np.arange(y_min, y_max, .02))
Z_sepal = logreg_sepal.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
axes[0].contourf(xx, yy, Z_sepal, cmap=plt.cm.Paired, alpha=0.8)
axes[0].scatter(X_sepal[:, 0], X_sepal[:, 1], c=iris.target, edgecolors='k', cmap=plt.cm.Paired)
axes[0].set_xlabel('Sepal length')
axes[0].set_ylabel('Sepal width')
axes[0].set_title('Logistic Regression on Sepal Features')
# Second subplot for petal features (petal length and petal width)
X_petal = iris.data[:, 2:4] # Selecting petal length and petal width
logreg_petal = LogisticRegression(C=1e5, solver='lbfgs', multi_class='multinomial')
logreg_petal.fit(X_petal, iris.target)
x_min, x_max = X_petal[:, 0].min() - .5, X_petal[:, 0].max() + .5
y_min, y_max = X_petal[:, 1].min() - .5, X_petal[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, .02), np.arange(y_min, y_max, .02))
Z_petal = logreg_petal.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
axes.contourf(xx, yy, Z_petal, cmap=plt.cm.Paired, alpha=0.8)
axes.scatter(X_petal[:, 0], X_petal[:, 1], c=iris.target, edgecolors='k', cmap=plt.cm.Paired)
axes.set_xlabel('Petal length')
axes.set_ylabel('Petal width')
axes.set_title('Logistic Regression on Petal Features')
# Adjust the layout
plt.tight_layout()
plt.show()
左のグラフがsepal length/width, 右のグラフがpetal/length/widthの結果になっています。このように3種類のグラフ(setosa,virgicolor, verginica)の3つのグラフにこのように分類されます。
この結果より、未知のデータより、アヤメの種類を分類することが可能になります。人間の目でも、このグラフ化されると、どの花の種類に分類されるかわかりますね。
最後にこのモデルで未知の値から花の種類を分類するプログラムがこちらです。
未知のデータを入れて、花の分類プログラム
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
# Load the iris dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target
# Train the LogisticRegression model
model = LogisticRegression(max_iter=200)
model.fit(X, y)
# Define a function to predict the class and probability of iris species
def predict_iris_species(features):
species = ['setosa', 'versicolor', 'virginica']
prediction = model.predict([features])
probabilities = model.predict_proba([features])
print(f"Predicted species: {species[prediction[0]]}")
print("Probabilities for each class:")
for specie, probability in zip(species, probabilities[0]):
print(f" {specie}: {probability:.4f}")
# Example: Predict the species and probabilities for a new iris with features
new_iris_features = [6.3, 2.5, 5, 1.5]
predict_iris_species(new_iris_features)
このプログラムでは、new_iris_featuresに適当な値を入れると、アヤメの種類を表示するプログラムになっています。実際にやった結果がこちらです。
境界付近のデータを入れているため、確率は低いですが、versicolorと分類されました。
このように簡単に機械学習できるのでscikit-learnで機械学習にトライしてみてください。