サムネがコーヒーの記事は書きかけです。

FastAPIとSQLalchemyによるデータベースの操作【Python】【バックエンド】

仕事でFast APIを使うことになったので、データベース操作の基本をまとめておきます。

FastAPIのインストール

pip install fastapi

Uvicornのインストール

 pip install uvicorn

ローカルサーバーの起動

uvicornを使用してローカルサーバーを起動。ポート8000番代

uvicorn main:app --reload

パスパラメータ

from fastapi import FastAPI

app = FastAPI()

@app.get("/")
async def root():
    return {"message": "Hello World"}

@app.get("/items/{item_id}")
async def read_item(item_id: int):
    return {"item_id": item_id}

#Path parameter + Pydantic 
@app.get("/about/{id}")
def about(id:int):
    return {"data":id}

SQLAlchemyによるデータベース操作

簡単な例として、ユーザーから文章を受け取ってマルコフ連鎖の状態遷移図を返すアプリを作成していきます。

database

データベースを作成します。

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

SQLALCHEMY_DATABASE_URL = "sqlite:///./app.db"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()

schemas

ベースクラスを定義します。この時、Pydanticのベースモデルを継承して型安全性を保ちます。

from typing import List, Optional
from pydantic import BaseModel


class ItemBase(BaseModel):
    sentence: str


class ItemCreate(ItemBase):
    pass


class Item(ItemBase):
    id: int
    owner_id: int

    class Config:
        orm_mode = True


class UserBase(BaseModel):
    username: str


class UserCreate(UserBase):
    pass


class User(UserBase):
    id: int
    sentences: List[Item] = []
    class Config:
        orm_mode = True

models

ベースクラスを使用して、データベースモデルの定義を行います。

from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from sqlalchemy_serializer import SerializerMixin 
from database import Base

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True, index=True)
    username = Column(String, unique=True, index=True)

    sentences = relationship("Item", back_populates="owner")



class Item(Base,SerializerMixin):
    __tablename__ = "sentences"

    id = Column(Integer, primary_key=True, index=True)
    sentence = Column(String, index=True)
    owner_id = Column(Integer, ForeignKey("users.id"))

    owner = relationship("User", back_populates="sentences")

crud

データベースの操作を行う関数を作成します。

from sqlalchemy.orm import Session
import models
import schemas

def get_user(db: Session, user_id: int):
    return db.query(models.User).filter(models.User.id == user_id).first()

def get_sentences(db: Session,user_id: int):
    sentences = [i.sentence for i in db.query(models.User).order_by(models.User.id)[user_id-1].sentences]
    return sentences

def get_user_by_username(db: Session, username: str):
    return db.query(models.User).filter(models.User.username == username).first()

def get_users(db: Session, skip: int = 0, limit: int = 100):
    return db.query(models.User).offset(skip).limit(limit).all()

def create_user(db: Session, user: schemas.UserCreate):
    db_user = models.User(username = user.username)
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user

def get_sentences_all(db: Session, skip: int = 0, limit: int = 100):
    return db.query(models.Item).offset(skip).limit(limit).all()


def create_user_sentence(db: Session, sentence: schemas.ItemCreate, user_id: int):
    db_sentence = models.Item(**sentence.dict(), owner_id=user_id)
    db.add(db_sentence)
    db.commit()
    db.refresh(db_sentence)
    return db_sentence

main.py

ここでURLの割り当てとデータベース操作の統括を行います。

from typing import List
from fastapi import Depends, FastAPI, HTTPException
from sqlalchemy.orm import Session
import crud
import models
import schemas
from database import SessionLocal, engine

models.Base.metadata.create_all(bind=engine)
app = FastAPI()

def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


@app.post("/users/", response_model=schemas.User)
def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)):
    db_user = crud.get_user_by_username(db, username=user.username)
    if db_user:
        raise HTTPException(status_code=400, detail="Username already taken.")
    return crud.create_user(db=db, user=user)


@app.get("/users/", response_model=List[schemas.User])
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    users = crud.get_users(db, skip=skip, limit=limit)
    return users


@app.get("/users/{user_id}", response_model=schemas.User)
def read_user(user_id: int, db: Session = Depends(get_db)):
    db_user = crud.get_user(db, user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found.")
    return db_user


@app.get("/users/sentences/{user_id}", response_model=schemas.User)
def read_sentences(user_id: int, db: Session = Depends(get_db)):
    sentences = crud.get_sentences(db=db, user_id=user_id)
    print(sentences)
    return crud.get_user(db, user_id=user_id)


@app.post("/users/{user_id}/items/", response_model=schemas.Item)
def create_sentence_for_user(
        user_id: int, sentence: schemas.ItemCreate, db: Session = Depends(get_db)
):
    return crud.create_user_sentence(db=db, sentence=sentence, user_id=user_id)

サーバーの起動

以下をmain.pyと同じディレクトリで実行します。

uvicorn main:app --reload
>>>
INFO:     Will watch for changes in these directories: ['/Users']
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO:     Started reloader process [48055] using StatReload
INFO:     Started server process [48057]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     127.0.0.1:57169 - "GET /docs HTTP/1.1" 200 OK
INFO:     127.0.0.1:57169 - "GET /openapi.json HTTP/1.1" 200 OK
INFO:     127.0.0.1:57169 - "GET /docs HTTP/1.1" 200 OK
INFO:     127.0.0.1:57170 - "GET /docs HTTP/1.1" 200 OK
INFO:     127.0.0.1:57170 - "GET /openapi.json HTTP/1.1" 200 OK
INFO:     127.0.0.1:57169 - "GET /openapi.json HTTP/1.1" 200 OK

Swagger UIへのアクセス

FastAPIの場合、自動でAPIドキュメントを生成してくれます。

http://127.0.0.1:8000/docs

実行結果

ここまでで、登録したユーザーネームに対する文章を取り出すことができるようになりました。

処理の記述

任意の処理、今回ならマルコフ連鎖の抽出をmain.pyメソッドに追加してみます。

以下のように書き換えました。

from typing import List
from fastapi import Depends, FastAPI, HTTPException
from sqlalchemy.orm import Session
import crud
import models
import schemas
from database import SessionLocal, engine
from  graphviz import Digraph
import json
from fastapi.responses import FileResponse

models.Base.metadata.create_all(bind=engine)
app = FastAPI()

def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

@app.post("/users/", response_model=schemas.User)
def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)):
    db_user = crud.get_user_by_username(db, username=user.username)
    if db_user:
        raise HTTPException(status_code=400, detail="Username already taken.")
    return crud.create_user(db=db, user=user)

@app.get("/users/", response_model=List[schemas.User])
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    users = crud.get_users(db, skip=skip, limit=limit)
    return users

@app.get("/users/{user_id}", response_model=schemas.User)
def read_user(user_id: int, db: Session = Depends(get_db)):
    db_user = crud.get_user(db, user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found.")
    return db_user


@app.post("/users/{user_id}/items/", response_model=schemas.Item)
def create_sentence_for_user(
        user_id: int, sentence: schemas.ItemCreate, db: Session = Depends(get_db)
):
    return crud.create_user_sentence(db=db, sentence=sentence, user_id=user_id)


@app.get("/users/sentences/{user_id}", response_model=schemas.User)
def read_sentences(user_id: int, db: Session = Depends(get_db)):
    sentences = crud.get_sentences(db=db, user_id=user_id)
    print(sentences)
    return crud.get_user(db, user_id=user_id)

@app.get("/users/markovchain/{user_id}")
def generate_chain(user_id: int, db: Session = Depends(get_db)):
    sentences = crud.get_sentences(db=db, user_id=user_id)

    words = [i.lower() for i in '\n'.join(sentences).split()]
    chain = {i:[] for i in words}
    for i in [[words[i],words[i+1]] for i in range(len(words)-1)]:
        chain[i[0]].append(i[1])

    chain = {i:chain[i] for i in chain if i[-1] != "."  and  i[-1] != ","}
    edges = {(i,j):str(1/len(list(set(chain[i]))))[:4] if len(chain[i]) !=0 else "0" for i in chain for j in chain[i]}
   

    G = Digraph(format="svg")

    G.attr("node", shape="circle")
    for i,j in edges:
        G.edge(str(i), str(j),label=edges[(i,j)])
    
    text = "".join([str(i) for i in G.source])

    edges_json = [i.replace("\n",'').replace('}','') for i in text.split('\t')][2:]
    print(edges_json)
    G.render(f"markov{user_id}")

    return {str(user_id):json.dumps(edges_json)}

ファイルのレスポンス

FastAPIを使うと、簡単に任意のファイルをレスポンスすることができるようです。

@app.get("/image/{user_id}")
async def get_image(user_id:int):
    return FileResponse(f"markov{user_id}.svg")

この場合、サーバーのルートディレクトリで直接ファイルの受け渡しをしているので、もう少しいい方法がないか探します。

Swagger UIによる確認

上記のアプリの出力結果を見てみます。

文章の登録

ユーザ一id1に対応するカラムに以下の文章を登録します。

I like a dog.
I like dogs.
I like a cat.
I like cats.
I like the cat.
I like the dog.
I like the cats.
I like the dogs.

登録結果

{
  "username": "user1",
  "id": 1,
  "sentences": [
    {
      "sentence": "I like a dog.",
      "id": 1,
      "owner_id": 1
    },
    {
      "sentence": "I like dogs.",
      "id": 2,
      "owner_id": 1
    },
    {
      "sentence": "I like a cat.",
      "id": 3,
      "owner_id": 1
    },
    {
      "sentence": "I like cats.",
      "id": 4,
      "owner_id": 1
    },
    {
      "sentence": "I like the cat.",
      "id": 5,
      "owner_id": 1
    },
    {
      "sentence": "I like the cats.",
      "id": 6,
      "owner_id": 1
    },
    {
      "sentence": "I like the dog.",
      "id": 7,
      "owner_id": 1
    },
    {
      "sentence": "I like the dogs.",
      "id": 8,
      "owner_id": 1
    }
  ]
}

マルコフ連鎖の生成

ユーザーid1に対応するデータから文章を読み込んで処理を行います。

実行結果

{
  "1": "[\"i -> like [label=1.0]\", \"like -> a [label=0.25]\", \"like -> \\\"dogs.\\\" [label=0.25]\", \"like -> \\\"cats.\\\" [label=0.25]\", \"like -> the [label=0.25]\", \"a -> \\\"dog.\\\" [label=0.5]\", \"a -> \\\"cat.\\\" [label=0.5]\", \"the -> \\\"cat.\\\" [label=0.25]\", \"the -> \\\"cats.\\\" [label=0.25]\", \"the -> \\\"dog.\\\" [label=0.25]\", \"the -> \\\"dogs.\\\" [label=0.25]\"]"
}

ファイルのレスポンス

マルコフ連鎖を生成したときにレンダリングされたpngを取り出します。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です