サムネがコーヒーの記事は書きかけです。

FastAPIによるMariaDBのcrud操作と非同期処理【バックエンド】

これまではFastAPIを使用してデータをSQLite形式で取り扱っていましたが、今回は本格的にサーバーに展開できるようにMaria DBに切り替えていきます。また、データベースのcrud操作を非同期処理に切り替えます。

具体的には、以下の記事で作成したユーザーの登録センテンスに対するマルコフ連鎖を抽出するアプリケーションをMariaDB+非同期処理に書き換えていきます。

データベースへの接続(database.py)

import os
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from typing import cast, Callable, AsyncContextManager

mariadb_user: str = os.environ["MARIADB_USER"]
mariadb_password: str = os.environ["MARIADB_PASSWORD"]
mariadb_database: str = os.environ["MARIADB_DATABASE"]

SQLALCHEMY_DATABASE_URL = f"mariadb+asyncmy://{mariadb_user}:{mariadb_password}@db/{mariadb_database}?charset=utf8mb4"

engine = create_async_engine(SQLALCHEMY_DATABASE_URL, pool_recycle=3600)
SessionLocal = cast(
    Callable[[], AsyncContextManager[AsyncSession]],
    sessionmaker(
        engine, future=True, autocommit=False, autoflush=False, class_=AsyncSession
    )
)

Base = declarative_base()

コンテナでのapt-get

コンテナ上に直接モジュールをインストールします。

今回はgraphvizというグラフ描画モジュールをインストールする必要がありました。

dockerfileに以下の文を足します。

RUN apt-get update && export DEBIAN_FRONTEND=noninteractive && apt-get -y install --no-install-recommends <module name>

データベースのベースモデル(models.py)

モデルはSQLite3の時と全く同じにしました。

from sqlalchemy import Column, ForeignKey, Integer, String, Text
from sqlalchemy.orm import relationship
from .database import Base


class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
    username = Column(String(100), unique=True, index=True)
    sentences = relationship("Item", back_populates="owner", lazy="selectin")


class Item(Base):
    __tablename__ = "sentences"

    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
    sentence = Column(Text, index=True)
    owner_id = Column(Integer, ForeignKey("users.id"))

    owner = relationship("User", back_populates="sentences", lazy="joined")

図にするとこのような感じ。

schemas.pyの定義

こちらも、以前の記事から変更はありません。

from typing import List, Optional
from pydantic import BaseModel


class ItemBase(BaseModel):
    sentence: str


class ItemCreate(ItemBase):
    pass


class Item(ItemBase):
    id: int
    owner_id: int

    class Config:
        orm_mode = True


class UserBase(BaseModel):
    username: str


class UserCreate(UserBase):
    pass


class User(UserBase):
    id: int
    sentences: List[Item] = []
    class Config:
        orm_mode = True

crud処理の非同期化

関数名の前にasyncを付けます。また、データの読み込みなど非同期処理ができそうな箇所にawaitを付けます。

from sqlalchemy.orm import Session
from . import models
from . import schemas
from sqlalchemy.future import select


async def get_user(db: Session, user_id: int):
    return (await db.scalars(select(models.User).where(models.User.id == user_id))).first()

async def get_sentences(db: Session,user_id: int) -> list[models.Item]:
    stmt = select(models.Item).order_by(models.Item.id).where(models.Item.owner_id == user_id)
    items: list[models.Item] = (await db.scalars(stmt)).all()
    return items

async def get_user_by_username(db: Session, username: str):
    return (await db.scalars(select(models.User).where(models.User.username == username))).first()


async def get_users(db: Session, skip: int = 0, limit: int = 100):
    return (await db.scalars(select(models.User).offset(skip).limit(limit))).all()

async def create_user(db: Session, user: schemas.UserCreate):
    db_user = models.User(username = user.username)
    db.add(db_user)
    await db.commit()
    await db.refresh(db_user)
    return db_user

async def get_sentences_all(db: Session, skip: int = 0, limit: int = 100):
    return (await db.scalars(select(models.Item).offset(skip).limit(limit))).all()

async def create_user_sentence(db: Session, sentence: schemas.ItemCreate, user_id: int):
    db_sentence = models.Item(**sentence.dict(), owner_id=user_id)
    db.add(db_sentence)
    await db.commit()
    await db.refresh(db_sentence)
    return db_sentence

非同期の場合は処理内容を逐一実行しているとおかしなことになるので、大きな処理(今回はマルコフ連鎖の抽出)は関数にまとめてawaitを付け加えて実行する形にします。

from graphviz import Digraph
def temp(sentences,user_id):
    sentences = [i.sentence for i in sentences]
    words = [i.lower() for i in '\n'.join(sentences).split()]
    chain = {i: [] for i in words}
    for i in [[words[i], words[i+1]] for i in range(len(words)-1)]:
        chain[i[0]].append(i[1])
    chain = {i: chain[i] for i in chain if i[-1] != "." and i[-1] != ","}
    edges = {(i, j): str(1/len(list(set(chain[i]))))[:4] if len(
        chain[i]) != 0 else "0" for i in chain for j in chain[i]}
    G = Digraph(format="png")
    G.attr("node", shape="circle")
    for i, j in edges:
        G.edge(str(i), str(j), label=edges[(i, j)])
    text = "".join([str(i) for i in G.source])
    edges_json = [i.replace("\n", '').replace('}', '')
                  for i in text.split('\t')][2:]
    G.render(f"markov{user_id}")
    return edges_json

main.py

クラッド処理を統合します。

import asyncio
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
# from .utils.spa_staticfiles import SPAStaticFiles # For StaticFiles for SPA (React, Vue, etc.)
from .schemas import ItemCreate

from typing import List
from fastapi import Depends, FastAPI, HTTPException
from sqlalchemy.orm import Session
from . import crud
from .import models
from .import schemas
from .database import SessionLocal, engine
from graphviz import Digraph
import json
from fastapi.responses import FileResponse
from .functioins import temp
app = FastAPI()


async def get_db():
    async with SessionLocal() as db:
        yield db


@app.on_event("startup")
async def startup_event():
    async with engine.begin() as conn:
        # Create DB tables.
        await conn.run_sync(models.Base.metadata.create_all)


@app.post("/users/", response_model=schemas.User)
async def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)):
    db_user = await crud.get_user_by_username(db, username=user.username)
    if db_user:
        raise HTTPException(status_code=400, detail="Username already taken.")
    return await crud.create_user(db=db, user=user)


@app.get("/users/", response_model=List[schemas.User])
async def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    users = await crud.get_users(db, skip=skip, limit=limit)
    return users


@app.get("/users/{user_id}", response_model=schemas.User)
async def read_user(user_id: int, db: Session = Depends(get_db)):
    db_user = await crud.get_user(db, user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found.")
    return db_user


@app.post("/users/{user_id}/items/", response_model=schemas.Item)
async def create_sentence_for_user(
        user_id: int, sentence: schemas.ItemCreate, db: Session = Depends(get_db)
):
    return await crud.create_user_sentence(db=db, sentence=sentence, user_id=user_id)


@app.get("/users/sentences/{user_id}", response_model=schemas.User)
async def read_sentences(user_id: int, db: Session = Depends(get_db)):
    sentences = await crud.get_sentences(db=db, user_id=user_id)
    print(sentences)
    return await crud.get_user(db, user_id=user_id)


@app.get("/users/markovchain/{user_id}")
async def generate_chain(user_id: int, db: Session = Depends(get_db)):
    sentences = await crud.get_sentences(db=db, user_id=user_id)
    edges_json = await asyncio.to_thread(temp, sentences, user_id)
    return {str(user_id): json.dumps(edges_json)}


@app.get("/image/{user_id}")
async def get_image(user_id: int):
    return FileResponse(f"markov{user_id}.png")

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です