ウェブサイト検索

FastAPI を使用した ML を活用した Web アプリの構築


FastAPI と Jinja2 テンプレートを使用して、機械学習モデル推論用のシンプルな Web アプリケーションを構築するための初心者向けチュートリアル。

このチュートリアルでは、FastAPI について少し学び、それを使用して機械学習 (ML) モデル推論用の API を構築します。次に、Jinja2 テンプレートを使用して、適切な Web インターフェイスを作成します。これは、API と Web 開発に関する限られた知識でも自分で構築できる、短いですが楽しいプロジェクトです。

FastAPIとは何ですか?

FastAPI は、Python で API を構築するために使用される人気のある最新の Web フレームワークです。高速かつ効率的に設計されており、Python の標準型ヒントを活用して最高の開発エクスペリエンスを提供します。習得が簡単で、数行のコードのみで高パフォーマンスの API を開発できます。 FastAPI は、Uber、Netflix、Microsoft などの企業によって API やアプリケーションを構築するために広く使用されています。その設計は、機械学習モデルの推論とテスト用の API エンドポイントの作成に特に適しています。 Jinja2 テンプレートを統合することで、適切な Web アプリケーションを構築することもできます。

モデルのトレーニング

最も人気のある Iris データセットでランダム フォレスト分類器をトレーニングします。トレーニングが完了したら、モデルの評価メトリクスを表示し、モデルを pickle 形式で保存します。

train_model.py:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
import joblib

# Load the iris dataset
iris = load_iris()
X, y = iris.data, iris.target

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Train a RandomForest classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Evaluate the model
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
report = classification_report(y_test, y_pred, target_names=iris.target_names)

print(f"Model Accuracy: {accuracy}")
print("Classification Report:")
print(report)

# Save the trained model to a file
joblib.dump(clf, "iris_model.pkl")
$ python train_model.py
Model Accuracy: 1.0
Classification Report:
              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        10
  versicolor       1.00      1.00      1.00         9
   virginica       1.00      1.00      1.00        11

    accuracy                           1.00        30
   macro avg       1.00      1.00      1.00        30
weighted avg       1.00      1.00      1.00        30

FastAPI を使用した ML API の構築

次に、モデル推論 API の構築に使用する FastAPI と Unicorn ライブラリをインストールします。 

$ pip install fastapi uvicorn

`app.py` ファイルでは次のことを行います。

  1. 前のステップで保存したモデルをロードします。

  2. 入力と予測用の Python クラスを作成します。必ず dtype を指定してください。 
  3. 次に、predict 関数を作成し、`@app.post` デコレータを使用します。デコレーターは、URL パス `/predict` で POST エンドポイントを定義します。この関数は、クライアントがこのエンドポイントに POST リクエストを送信すると実行されます。
  4. predict 関数は、`IrisInput` クラスから値を取得し、それらを `IrisPrediction` クラスとして返します。
  5. 「uvicorn.run」関数を使用してアプリを実行し、以下に示すようにホスト IP とポート番号を指定します。 

app.py:

from fastapi import FastAPI
from pydantic import BaseModel
import joblib
import numpy as np
from sklearn.datasets import load_iris

# Load the trained model
model = joblib.load("iris_model.pkl")

app = FastAPI()


class IrisInput(BaseModel):
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float


class IrisPrediction(BaseModel):
    predicted_class: int
    predicted_class_name: str


@app.post("/predict", response_model=IrisPrediction)
def predict(data: IrisInput):
    # Convert the input data to a numpy array
    input_data = np.array(
        [[data.sepal_length, data.sepal_width, data.petal_length, data.petal_width]]
    )

    # Make a prediction
    predicted_class = model.predict(input_data)[0]
    predicted_class_name = load_iris().target_names[predicted_class]

    return IrisPrediction(
        predicted_class=predicted_class, predicted_class_name=predicted_class_name
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)

Python ファイルを実行します。 

$ python app.py

FastAPI サーバーが実行されており、リンクをクリックするとアクセスできます。 

INFO:     Started server process [33828]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)

ブラウザーのインデックス ページに移動します。インデックス ページには何もなく、`/predict` POST リクエストだけがあります。そのため、何も表示されません。 

SwaggerUI インターフェイスを使用して API をテストできます。リンクの後に「/docs」を追加するとアクセスできます。

「/predict」オプションをクリックして値を編集し、予測を実行します。最終的に、応答本文セクションで応答を取得します。ご覧のとおり、結果として「Virginica」が得られました。 SwaggerUI 内で直接値を使用してモデルをテストし、本番環境にデプロイする前にモデルが適切に動作していることを確認できます。 

Web アプリケーションの UI を構築する

Swagger UI を使用する代わりに、シンプルで他の Web アプリケーションと同様に結果を表示する独自のユーザー インターフェイスを作成します。これを実現するには、アプリ内に Jinja2Templates を統合する必要があります。 Jinja2Templates を使用すると、HTML ファイルを使用して適切な Web インターフェイスを構築でき、Web ページのさまざまなコンポーネントをカスタマイズできるようになります。

  1. Jinja2Templates に HTML ファイルが置かれるディレクトリを指定して、Jinja2Templates を開始します。 

  2. 「index.html」テンプレートをルート URL (「/」) の HTML 応答として提供する非同期ルートを定義します。
  3. リクエストとフォームを使用して、「predict」関数の入力引数を変更します。 
  4. アヤメの花の測定用のフォーム データを受け取り、機械学習モデルを使用してアヤメの種を予測し、TemplateResponse を使用して「result.html」にレンダリングされた予測結果を返す、非同期 POST エンドポイント「/predict」を定義します。
  5. コードの残りの部分は同様です。

from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
import joblib
import numpy as np
from sklearn.datasets import load_iris

# Load the trained model
model = joblib.load("iris_model.pkl")

# Initialize FastAPI
app = FastAPI()

# Set up templates
templates = Jinja2Templates(directory="templates")


# Pydantic models for input and output data
class IrisInput(BaseModel):
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float


class IrisPrediction(BaseModel):
    predicted_class: int
    predicted_class_name: str


@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/predict", response_model=IrisPrediction)
async def predict(
    request: Request,
    sepal_length: float = Form(...),
    sepal_width: float = Form(...),
    petal_length: float = Form(...),
    petal_width: float = Form(...),
):
    # Convert the input data to a numpy array
    input_data = np.array([[sepal_length, sepal_width, petal_length, petal_width]])

    # Make a prediction
    predicted_class = model.predict(input_data)[0]
    predicted_class_name = load_iris().target_names[predicted_class]

    return templates.TemplateResponse(
        "result.html",
        {
            "request": request,
            "predicted_class": predicted_class,
            "predicted_class_name": predicted_class_name,
            "sepal_length": sepal_length,
            "sepal_width": sepal_width,
            "petal_length": petal_length,
            "petal_width": petal_width,
        },
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)

次に、「app.py」と同じディレクトリに「templates」という名前のディレクトリを作成します。 「templates」ディレクトリ内に、「index.html」と「result.html」という 2 つの HTML ファイルを作成します。

Web 開発者であれば、HTML コードを簡単に理解できるでしょう。初心者向けに何が起こっているのかを説明します。この HTML コードは、アヤメの花の種類を予測するためのフォームを備えた Web ページを作成します。これにより、ユーザーは「がく片」と「花びら」の測定値を入力し、POST リクエストを介して「/predict」エンドポイントに送信できます。

index.html:

<!DOCTYPE html>

<html>

<head>

<title>Iris Flower Prediction</title>

</head>

<body>

<h1>Predict Iris Flower Species</h1>

<form action="/predict" method="post">

<label for="sepal_length">Sepal Length:</label>

<input type="number" step="any" id="sepal_length" name="sepal_length" required><br>

<label for="sepal_width">Sepal Width:</label>

<input type="number" step="any" id="sepal_width" name="sepal_width" required><br>

<label for="petal_length">Petal Length:</label>

<input type="number" step="any" id="petal_length" name="petal_length" required><br>

<label for="petal_width">Petal Width:</label>

<input type="number" step="any" id="petal_width" name="petal_width" required><br>

<button type="submit">Predict</button>

</form>

</body>

</html>

「result.html」コードは、入力されたがく片と花弁の測定値と予測されたアイリス種を示す予測結果を表示する Web ページを定義します。また、予測クラス名とクラス ID が表示され、インデックス ページに移動するボタンもあります。 

結果.html:

<!DOCTYPE html>

<html>

<head>

<title>Prediction Result</title>

</head>

<body>

<h1>Prediction Result</h1>

<p>Sepal Length: {{ sepal_length }}</p>

<p>Sepal Width: {{ sepal_width }}</p>

<p>Petal Length: {{ petal_length }}</p>

<p>Petal Width: {{ petal_width }}</p>

<h2>Predicted Class: {{ predicted_class_name }} (Class ID: {{ predicted_class }})</h2>

<a href="/">Predict Again</a>

</body>

</html>

Python アプリ ファイルを再度実行します。 

$ python app.py 
INFO:     Started server process [2932]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO:     127.0.0.1:63153 - "GET / HTTP/1.1" 200 OK

リンクをクリックしても、空の画面は表示されません。代わりに、「がく片」と「花びら」の長さと幅を入力できるユーザー インターフェイスが表示されます。 

「予測」ボタンをクリックすると、次のページに進み、結果が表示されます。 [再度予測] ボタンをクリックすると、別の値でモデルをテストできます。

すべてのソース コード、データ、モデル、および情報は、kingabzpro/FastAPI-for-ML GitHub リポジトリで入手できます。 ⭐にスターを付けることを忘れないでください。

結論

多くの大企業は現在、FastAPI を使用してモデルのエンドポイントを作成し、これらのモデルをシステム全体にシームレスに展開して統合できるようにしています。 FastAPI は高速でコーディングが簡単で、最新のデータ スタックの要求を満たすさまざまな機能が付属しています。この分野で仕事を獲得するための鍵は、できるだけ多くのプロジェクトを構築して文書化することです。これにより、一次審査に必要な経験と知識を得ることができます。採用担当者はあなたのプロフィールとポートフォリオを評価して、あなたが彼らのチームに適しているかどうかを判断します。それでは、今すぐ FastAPI を使用してプロジェクトの構築を始めてみてはいかがでしょうか?

関連記事