58 lines
1.9 KiB
Python
58 lines
1.9 KiB
Python
|
|
from typing import Any, Generic, TypeVar
|
||
|
|
from uuid import UUID
|
||
|
|
|
||
|
|
from pydantic import BaseModel
|
||
|
|
from sqlalchemy import select
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
|
||
|
|
from app.db.base import Base
|
||
|
|
|
||
|
|
# Type parameters used by the generic CRUD helper defined in this module.

# The ORM model class being managed; must be a subclass of app.db.base.Base.
ModelType = TypeVar(
    "ModelType",
    bound=Base,
)

# Pydantic schema accepted when creating a row.
CreateSchemaType = TypeVar(
    "CreateSchemaType",
    bound=BaseModel,
)

# Pydantic schema accepted when updating a row.
UpdateSchemaType = TypeVar(
    "UpdateSchemaType",
    bound=BaseModel,
)
|
||
|
|
|
||
|
|
|
||
|
|
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic asynchronous create/read/update/delete helper for one model.

    Every method takes the ``AsyncSession`` to operate on. Write methods
    call ``flush()`` so pending changes are sent to the database, but this
    class never calls ``commit()`` itself.
    """

    def __init__(self, model: type[ModelType]):
        """Bind the helper to *model*, the mapped class it will query."""
        self.model = model

    async def get(self, db: AsyncSession, *, id: UUID) -> ModelType | None:
        """Return the instance whose ``id`` column equals *id*.

        Returns ``None`` when no row matches.
        """
        stmt = select(self.model).where(self.model.id == id)
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_multi(
        self, db: AsyncSession, *, skip: int = 0, limit: int = 100
    ) -> list[ModelType]:
        """Return up to *limit* instances after skipping the first *skip*.

        NOTE(review): the query has no ORDER BY, so which rows land on a
        given page depends on the database's row order -- confirm callers
        do not rely on stable paging.
        """
        stmt = select(self.model).offset(skip).limit(limit)
        result = await db.execute(stmt)
        return list(result.scalars().all())

    async def create(
        self, db: AsyncSession, *, obj_in: CreateSchemaType, **extra: Any
    ) -> ModelType:
        """Build a model instance from *obj_in*, add it and flush.

        Keyword arguments in *extra* are merged over the dumped schema
        fields, so an *extra* key wins when both supply the same name.
        The instance is refreshed after the flush and then returned.
        """
        data = obj_in.model_dump()
        data.update(extra)
        db_obj = self.model(**data)
        db.add(db_obj)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def update(
        self,
        db: AsyncSession,
        *,
        db_obj: ModelType,
        obj_in: UpdateSchemaType | dict[str, Any],
    ) -> ModelType:
        """Apply the changes in *obj_in* to *db_obj*, flush and refresh.

        A plain ``dict`` is applied as given. A schema object is dumped
        with ``exclude_unset=True`` so only fields that were explicitly
        set on it are written to *db_obj*.
        """
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.model_dump(exclude_unset=True)

        for field, value in update_data.items():
            setattr(db_obj, field, value)

        db.add(db_obj)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def remove(self, db: AsyncSession, *, id: UUID) -> ModelType | None:
        """Delete the instance identified by *id* and return it.

        Returns ``None``, without touching the session, when no row
        matches.
        """
        db_obj = await self.get(db, id=id)
        # Compare against None explicitly: a model that defines __bool__ or
        # __len__ could be falsy, and the delete would be silently skipped.
        if db_obj is not None:
            await db.delete(db_obj)
            await db.flush()
        return db_obj
|