from typing import Any, Generic, TypeVar
from uuid import UUID

from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.base import Base

ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)


class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic async CRUD operations for a single SQLAlchemy model.

    Methods that write (``create``, ``update``, ``remove``) call
    ``flush()`` but never ``commit()``: changes are sent to the database
    inside the session's current transaction, and committing is left to
    the caller. The model is expected to expose an ``id`` column, which is
    used for lookup and for stable ordering in ``get_multi``.
    """

    def __init__(self, model: type[ModelType]) -> None:
        """Bind this CRUD helper to the mapped model class *model*."""
        self.model = model

    async def get(self, db: AsyncSession, *, id: UUID) -> ModelType | None:
        """Return the row whose ``id`` equals *id*, or ``None`` if none matches.

        ``id`` shadows the builtin but is kept because it is part of the
        public keyword-only API.
        """
        result = await db.execute(select(self.model).where(self.model.id == id))
        return result.scalar_one_or_none()

    async def get_multi(
        self, db: AsyncSession, *, skip: int = 0, limit: int = 100
    ) -> list[ModelType]:
        """Return up to *limit* rows, skipping the first *skip*, ordered by ``id``.

        Args:
            db: Async session used to run the query.
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).
        """
        stmt = (
            select(self.model)
            # Without an ORDER BY the database may return rows in any order,
            # so OFFSET/LIMIT pages could overlap or miss rows between calls.
            # Ordering by the primary key makes pagination deterministic.
            .order_by(self.model.id)
            .offset(skip)
            .limit(limit)
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())

    async def create(
        self, db: AsyncSession, *, obj_in: CreateSchemaType, **extra: Any
    ) -> ModelType:
        """Insert a new row built from *obj_in* and return the refreshed instance.

        Args:
            db: Async session to add the new object to.
            obj_in: Pydantic schema whose ``model_dump()`` supplies column values.
            **extra: Additional column values; these override same-named
                fields from *obj_in* (e.g. server-side owner ids).

        The session is flushed (not committed) and the instance refreshed so
        database-generated values are populated on the returned object.
        """
        data = {**obj_in.model_dump(), **extra}
        db_obj = self.model(**data)
        db.add(db_obj)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def update(
        self,
        db: AsyncSession,
        *,
        db_obj: ModelType,
        obj_in: UpdateSchemaType | dict[str, Any],
    ) -> ModelType:
        """Apply changes from *obj_in* to *db_obj* and return the refreshed instance.

        Args:
            db: Async session the object belongs to (or is added to).
            db_obj: The persisted instance to modify in place.
            obj_in: Either a plain dict, whose every key/value is applied as-is
                (including ``None`` values), or a Pydantic schema, of which
                only fields explicitly set by the caller are applied
                (``exclude_unset=True``), so omitted fields keep their values.
        """
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(db_obj, field, value)
        # Re-adding ensures the instance is attached to this session
        # (harmless no-op if it already is).
        db.add(db_obj)
        await db.flush()
        await db.refresh(db_obj)
        return db_obj

    async def remove(self, db: AsyncSession, *, id: UUID) -> ModelType | None:
        """Delete the row whose ``id`` equals *id*.

        Returns the deleted instance, or ``None`` if no row matched (in which
        case nothing is deleted). The deletion is flushed, not committed.
        """
        db_obj = await self.get(db, id=id)
        # Explicit None check: an ORM instance's truthiness could be
        # overridden by the model (``__bool__``/``__len__``), which would
        # silently skip the delete.
        if db_obj is not None:
            await db.delete(db_obj)
            await db.flush()
        return db_obj