今日は、今日は、PythonとPyTorchを使って、あなたの写真の中にある物を識別する方法を学びます。Faster R-CNNという強力なモデルを使って、物体検出を行う方法を一緒に見ていきましょう。
前回はこちらの記事で環境構築まで紹介しました。
今回は実際にプログラムを作成して実践してみましょう。
・物体検出のプログラムがわからない。
・Pytorchを使用して、画像内の物体を検出したい。
こんな疑問に答えます。
・物体検出のプログラムの中身がわかるようになる。
・実際に画像に対して、物体検出を行うことができるようになる。
・バウンディングボックスで画像に記載する方法がわかる。
では実際に作成していきましょう。
Pytorchの物体検出プログラム
初めにプログラムの全文を記載し、この内容について説明していきます。
import torch
from torchvision import models, transforms
import matplotlib.pyplot as plt
from PIL import Image
# モデルをロード
model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
# クラス名のリストをロード
with open('ms_coco_classnames.txt', 'r') as f:
classes = [line.strip() for line in f.readlines()]
# 画像の変換処理を定義
transform = transforms.Compose([
transforms.ToTensor(),
])
# アップロードされた画像を読み込み
img_path = 'test2.jpg' # ここはアップロードされた画像のパス
img = Image.open(img_path)
# 変換処理を適用
img_tensor = transform(img)
# 画像をモデルに入力し、予測を取得
with torch.no_grad():
prediction = model([img_tensor])[0]
# 画像を表示
plt.figure(figsize=(12, 8))
plt.imshow(img)
# バウンディングボックスとクラス名を描画
for i, (box, label_index, score) in enumerate(zip(prediction['boxes'], prediction['labels'], prediction['scores'])):
# スコアが0.5以上のバウンディングボックスのみを描画
if score > 0.5:
box = box.numpy()
label = classes[label_index] if label_index < len(classes) else 'Unknown'
label=label[4:]
plt.gca().add_patch(plt.Rectangle((box[0], box), box - box[0], box - box, fill=False, edgecolor='red', linewidth=2))
plt.text(box[0], box, f'{label}: {score:.2f}', color='white', fontsize=12, bbox=dict(facecolor='red', alpha=0.5))
plt.axis('off') # 軸を非表示に
plt.show()
実行するとこんな感じになります。
このようにクラス名称と確率が表示されます。ではプログラムの内容を紹介していきます。
プログラムの内容説明
ステップ1: 必要なライブラリをインポートする
まず最初に、物体検出を行うために必要なライブラリをインポートします。
import torch
from torchvision import models, transforms
import matplotlib.pyplot as plt
from PIL import Image
- torch: PyTorchライブラリで、ディープラーニングのモデルを扱います。
- models: PyTorchの事前訓練済みモデルを提供します。
- transforms: 画像をモデルが処理できる形式に変換するためのツールです。
- matplotlib.pyplot: 画像を表示するためのライブラリです。
- Image: PIL(Python Imaging Library)を使って画像ファイルを操作します。
すべてが完了すると下記のようにチェックマークになります。ライブラリをインストールしていない場合はインストールしてください。
ステップ2: モデルをロードする
物体検出のために、事前に訓練されたFaster R-CNNモデルをロードします。
# 事前に訓練されたFaster R-CNNモデルをロードします。
model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval() # モデルを評価モードに設定
下記のようにモデルがダウンロードされます。
- models.detection.fasterrcnn_resnet50_fpn(pretrained=True): これにより、事前に訓練されたFaster R-CNNモデルがロードされます。
pretrained=True
は、既に訓練されたデータを使ってモデルをロードすることを意味します。 - model.eval(): モデルを評価モードに設定します。これは、モデルを予測に使用するときに行います。
ステップ3: クラス名のリストをロードする
検出された物体が何であるかを知るためには、クラス名のリストが必要です。
事前にms_coco_classnames.txtを下記URLよりインストールして、プログラムと同じディレクトリに保存します。
ダウンロードしたらZIPファイルから取り出し、同じディレクトリに配置。
下記プログラウでえファイルを読み込みます。
with open('ms_coco_classnames.txt', 'r') as f:
classes = [line.strip() for line in f.readlines()]
ms_coco_classnames.txt: これは、COCOデータセットで使用されるクラス名のリストを含むテキストファイルです。ファイルの各行には、一つのクラス名が含まれています。
これでクラス名称とクラスのIDが把握できるようになります。
ステップ4: 画像の前処理を行う
画像をモデルに入力する前に、適切な形式に変換する必要があります。
transform = transforms.Compose([
transforms.ToTensor(),
])
transforms.ToTensor(): この変換によって、PIL画像またはNumPy ndarray
がFloatTensorに変換され、値が[0., 1.]の範囲にスケールされます。
ステップ5: 画像をモデルに入力し、予測を取得する
画像のパスを記載します。画像も同じディレクトリに配置します。
今回はtest2.jpgで実行します。
img_path = 'test2.jpg' # ここはアップロードされた画像のパス
img = Image.open(img_path)
img_tensor = transform(img)
with torch.no_grad():
prediction = model([img_tensor])[0]
- Image.open(img_path): PILを使用して画像を開きます。
- transform(img): 前処理を適用して画像をテンソルに変換します。
- torch.no_grad(): 予測中に勾配計算を無効にするために使用します。
- model([img_tensor])[0]: モデルに画像テンソルを入力して予測を行います。
これでpredictionまでは完了したので、この後画像に四角の枠で囲う描画の処理を行います。
ステップ6: 画像にバウンディングボックスとクラス名を描画する
figsizeを変更すると、画像のサイズを変更したり、score>0.5 から0.9に値を変更すると、確率が高いもののみ表示するプログラムになります。
plt.figure(figsize=(12, 8))
plt.imshow(img)
for i, (box, label_index, score) in enumerate(zip(prediction['boxes'], prediction['labels'], prediction['scores'])):
if score > 0.5:
box = box.numpy()
label = classes[label_index] if label_index < len(classes) else 'Unknown'
label = label[4:]
plt.gca().add_patch(plt.Rectangle((box[0], box), box - box[0], box - box, fill=False, edgecolor='red', linewidth=2))
plt.text(box[0], box, f'{label}: {score:.2f}', color='white', fontsize=12, bbox=dict(facecolor='red', alpha=0.5))
plt.axis('off')
plt.show()
いろいろな画像でやると面白いのでやってみてください。
- plt.figure(figsize=(12, 8)): 描画する図のサイズを設定します。
- plt.imshow(img): 画像を表示します。
- for i, (box, label_index, score) in enumerate(…): 予測されたバウンディングボックス、ラベルのインデックス、スコアに対してループ処理を行います。
- plt.gca().add_patch(…): バウンディングボックスを描画します。
- plt.text(…): バウンディングボックスの上にクラス名とスコアをテキストとして描画します。
- plt.axis(‘off’): 軸をオフにして表示をクリーンにします。
- plt.show(): 図を表示します。
これで、画像に対して物体検出を行い、結果を視覚的に確認することができます。物体検出はAIの魅力的な応用の一つです。このステップバイステップガイドがあなたの学習に役立つことを願っています!