from sqlalchemy.orm import Session
from models import Company
from schemas.company import CompanySchema
from repositories.base_repository import BaseRepository
from typing import Optional

class CompanyRepository(BaseRepository[CompanySchema, Company]):
    """Repository for ``Company`` rows, exposed to callers as ``CompanySchema``.

    Conversion between the SQLAlchemy ``Company`` model and the
    ``CompanySchema`` entity is done by ``_to_entity`` / ``_to_db_model``.
    """

    def __init__(self, db: Session):
        """Bind the repository to ``db`` for the ``Company`` model.

        Args:
            db: SQLAlchemy session, passed to ``BaseRepository`` together
                with the ``Company`` model class.
        """
        super().__init__(db, Company)

    def _to_entity(self, db_model: Company) -> Optional[CompanySchema]:
        """Convert an ORM ``Company`` into a ``CompanySchema``.

        Args:
            db_model: ORM instance to convert; may be ``None`` (e.g. the
                result of a lookup that matched no row).

        Returns:
            The validated schema, or ``None`` when ``db_model`` is falsy.
        """
        # NOTE(review): model_validate on an ORM object presumably relies on
        # CompanySchema enabling from_attributes — confirm in schemas.company.
        return CompanySchema.model_validate(db_model) if db_model else None

    def _to_db_model(self, entity: CompanySchema) -> Company:
        """Build a new (unsaved) ORM ``Company`` from ``entity``.

        Every field produced by ``entity.model_dump()`` is passed as a keyword
        argument to the ``Company`` constructor.
        """
        return Company(**entity.model_dump())

    async def create_company(self, company_data: CompanySchema) -> CompanySchema:
        """Create a company by delegating to ``BaseRepository.create``.

        Args:
            company_data: Schema describing the company to create.

        Returns:
            Whatever ``self.create`` returns (annotated as ``CompanySchema``).
        """
        return await self.create(company_data)

    async def get_by_user_id(self, user_id: int) -> Optional[CompanySchema]:
        """Return the first company whose ``user_id`` matches, or ``None``.

        No ORDER BY is applied, so if several rows match, which one is
        returned is up to the database.

        Args:
            user_id: Value compared against ``Company.user_id``.

        Returns:
            The matching row converted to ``CompanySchema``, or ``None`` when
            no row matches.
        """
        # NOTE(review): db is a synchronous sqlalchemy.orm.Session, so this
        # query blocks the event loop despite the ``async def``.
        db_company = (
            self.db.query(Company)
            .filter(Company.user_id == user_id)
            .first()
        )
        # Convert through _to_entity so callers receive the annotated schema
        # type instead of a raw ORM instance.
        return self._to_entity(db_company)
