サムネがコーヒーの記事は書きかけです。

抽象化した公開鍵取得クラスの作成 Python

Auth0とLINEの両方でログインできる機能を実装した際に、公開鍵の取得(リロード)部分を自力で実装する必要があったので、忘れないうちにまとめておきます。

ポイント

公開鍵は更新される可能性がある

そもそも公開鍵をハードコードして使うことも可能ですが、外部サーバーの仕様で変更になった場合に対応できないため、直接問い合わせて更新したほうが良いとの考えで今回のクラスを実装します。

外部サーバーと通信するためレート制限について考慮する

公開鍵をリロードするときは外部のサーバーと通信をすることになります。そのため、依存性注入などでエンドポイントを保護している場合、APIが呼び出されるたびに外部サーバーへリクエストが発生してしまうことになります。

外部サーバーとの非同期通信時に排他制御をかける

外部サーバーに問い合わせる際、非同期関数の場合エンドポイントへの同時アクセス時にリロードが重複してしまう問題が出てきます。この問題を解消するために、asyncio lock()によりリロード処理を排他制御する必要があります。

署名検証は必ず行う

JWTの仕様として署名検証は必ず行う必要があるため、送られてくるトークンは検証が完了するまで得体の知れないものとして扱う必要があります。

JWKのデータ構造をTypeAdapterで保証する

PydanticのTypeAdapterを使用して署名検証に必要なプロパティを内包したデータを必ず受け取れるように、データ構造を固定します。
https://docs.pydantic.dev/latest/api/type_adapter/

データベースではなくメモリ上で公開鍵を保持する

わざわざデータベースを介さずに、インスタンスを生成してメモリ上で公開鍵取得のための情報を保持しておく方が効率的になります。

署名検証アルゴリズム別に型を定義する

例えば下記のように署名検証アルゴリズムであるRS256とES256用にそれぞれのTypeDictを作成し、descriminated unionを使用してJwkDictという型を生成しておくことで複数の検証方式を共通化することができます。

https://docs.pydantic.dev/latest/concepts/unions

from typing import TypedDict, Literal, Annotated

class RSAJwk(TypedDict):
    kty: str
    alg: Literal["RS256"]
    use: str
    kid: str
    n: str
    e: str
    x5t: str
    x5c: list[str]


class ECDSAJwk(TypedDict):
    kty: str
    alg: Literal["ES256"]
    use: str
    kid: str
    crv: str
    x: str
    y: str
    xvalue: str
    yvalue: str


JwkDict = Annotated[RSAJwk | ECDSAJwk, Field(discriminator="alg")]


class JwksResponse(TypedDict):
    keys: list[JwkDict]

公開鍵Fetcher

引数のkidに対応する公開鍵情報がすでに存在すればそのまま返す、そうでなければ前回の取得を試みてから指定時間(min_reload_interval,秒)以上経過している場合にのみ実際に公開鍵を取りに行くというような処理を行うメソッドを内包した公開鍵Fetcherクラスを実装します。

class PublicKeyFetcher:
    def __init__(self, jwk_url: str, min_reload_interval: float) -> None:
        self._jwk_url: str = jwk_url
        self._min_reload_interval: float = min_reload_interval
        self._jwk_response_adapter = TypeAdapter(JwksResponse)
        self._keys: dict[str, JwkDict] = {}
        self._last_updated: datetime | None = None
        self._lock = asyncio.Lock()

    async def get_key(self, kid: str) -> JwkDict:
        async with self._lock:
            existing_key = self._keys.get(kid)
            if existing_key:
                return existing_key
            if self._last_updated is None or (
                (datetime.now() - self._last_updated).seconds
                > self._min_reload_interval
            ):
                self._last_updated = datetime.now()
                async with httpx.AsyncClient() as client:
                    try:
                        jwk_response = await client.get(self._jwk_url)
                        jwk_response.raise_for_status()
                        validated_jwk_response = (
                            self._jwk_response_adapter.validate_python(
                                jwk_response.json()
                            )
                        )
                        self._keys = {
                            item["kid"]: item for item in validated_jwk_response["keys"]
                        }

                    except (
                        httpx.HTTPError,
                        ValidationError,
                        json.JSONDecodeError,
                    ) as err:
                        print(err)
                        raise FailedToFetchJWKs
            key = self._keys.get(kid)
            if key is None:
                raise InvalidKID
            return key

上記のような公開鍵取得クラスは下記のように使用できます。

line_public_key_fetcher: PublicKeyFetcher = PublicKeyFetcher(
    jwk_url="https://api.line.me/oauth2/v2.1/certs",
    min_reload_interval=settings.line_public_keys_min_reload_interval,
)

auth0_public_key_fetcher: PublicKeyFetcher = PublicKeyFetcher(
    jwk_url=f"https://{settings.auth0_domain}/.well-known/jwks.json",
    min_reload_interval=settings.auth0_public_keys_min_reload_interval,
)

async def parse_line_id_token(token: str) -> LineIdToken:
    try:
        header = JWTHeader(**jwt.get_unverified_header(token))
        kid = header.kid
        public_key = await line_public_key_fetcher.get_key(kid)

        payload = jwt.decode(
            token,
            public_key,
            issuer="https://access.line.me",
            algorithms=["ES256"],
            audience=settings.line_login_channel_id,
        )
...

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です