from typing import Optional, List, Dict, Any
from sqlalchemy.orm import Session
from models import SpecialHoliday
from datetime import date
from repositories.base_repository import BaseRepository
from schemas.holiday import (
    SpecialHolidaySchema,
    CreateSpecialHolidaySchema,
    SpecialHolidayResponseSchema
)


class SpecialHolidayRepository(BaseRepository[SpecialHolidaySchema, SpecialHoliday]):
    """Repository for company-specific special holidays.

    Wraps a SQLAlchemy ``Session`` and converts ``SpecialHoliday`` rows into
    the schema objects imported from ``schemas.holiday``.

    NOTE(review): the public methods are declared ``async`` but issue
    synchronous ``Session`` calls, so they block while the query runs.
    """

    def __init__(self, db: Session):
        """Bind the repository to ``db`` and the ``SpecialHoliday`` model."""
        super().__init__(db, SpecialHoliday)

    def _to_entity(self, db_model: SpecialHoliday) -> Optional[SpecialHolidaySchema]:
        """Convert a row into the full schema; return ``None`` for ``None``."""
        if db_model is None:
            # Honour the Optional return annotation instead of raising
            # AttributeError on a missing row.
            return None
        return SpecialHolidaySchema(
            id=db_model.id,
            company_id=db_model.company_id,
            date=db_model.date,
            name=db_model.name,
            created_at=db_model.created_at,
            updated_at=db_model.updated_at,
        )

    def _to_response(self, db_model: SpecialHoliday) -> Optional[SpecialHolidayResponseSchema]:
        """Convert a row into the response schema (id, date, name only).

        Returns ``None`` when ``db_model`` is ``None``.
        """
        if db_model is None:
            return None
        return SpecialHolidayResponseSchema(
            id=db_model.id,
            date=db_model.date,
            name=db_model.name,
        )

    def _to_db_model(self, entity: SpecialHolidaySchema) -> SpecialHoliday:
        """Build a new, unsaved ``SpecialHoliday`` row from ``entity``.

        Only ``company_id``, ``date`` and ``name`` are copied; ``id`` and the
        timestamp fields are not set here.
        """
        return SpecialHoliday(
            company_id=entity.company_id,
            date=entity.date,
            name=entity.name,
        )

    async def create_holiday(self, company_id: int, holiday_date: date, name: str) -> SpecialHolidayResponseSchema:
        """Insert a special holiday and return it as a response schema.

        The arguments are first passed through ``CreateSpecialHolidaySchema``
        (presumably for validation -- confirm against the schema definition),
        then persisted and refreshed so database-generated fields are loaded.

        Raises:
            Exception: whatever the session raises while adding, committing
                or refreshing; the session is rolled back before re-raising.
        """
        holiday = CreateSpecialHolidaySchema(
            company_id=company_id,
            date=holiday_date,
            name=name,
        )
        db_model = SpecialHoliday(
            company_id=holiday.company_id,
            date=holiday.date,
            name=holiday.name,
        )
        try:
            self.db.add(db_model)
            self.db.commit()
            self.db.refresh(db_model)
        except Exception:
            # Without a rollback the session stays in a failed transaction
            # state and later operations on it would also fail.
            self.db.rollback()
            raise
        return self._to_response(db_model)

    async def get_by_company(self, company_id: int, start: int = 1, limit: int = 10) -> Dict[str, Any]:
        """Return one page of a company's holidays plus the total count.

        Args:
            company_id: Company whose holidays are listed.
            start: 1-based position of the first row to return. Values below
                1 are treated as 1.
                NOTE(review): assumed to be a 1-based row position because
                the default is 1 -- confirm against callers.
            limit: Maximum number of rows to return.

        Returns:
            ``{"data": [{"id", "date", "name"}, ...], "total_count": int}``
            where ``total_count`` counts every holiday of the company, not
            just the returned page.
        """
        company_holidays = self.db.query(SpecialHoliday).filter(
            SpecialHoliday.company_id == company_id
        )
        total = company_holidays.count()

        # ``start`` is 1-based while OFFSET is 0-based; passing ``start``
        # through unchanged would silently skip the first row by default.
        offset = max(start - 1, 0)

        # Explicit ordering keeps pages stable between calls; without it the
        # database may return rows in any order.
        holidays = (
            company_holidays
            .order_by(SpecialHoliday.date, SpecialHoliday.id)
            .offset(offset)
            .limit(limit)
            .all()
        )

        # Format holidays for response
        formatted_holidays = [
            {
                'id': holiday.id,
                'date': holiday.date,
                'name': holiday.name,
            }
            for holiday in holidays
        ]

        return {
            "data": formatted_holidays,
            "total_count": total,
        }

    async def get_by_company_and_date(self, company_id: int, date: str) -> Optional[SpecialHolidayResponseSchema]:
        """Return the company's holiday on ``date``, or ``None`` if none exists.

        If several rows match, the first one returned by the query is used.
        Note that the ``date`` parameter shadows the imported ``datetime.date``
        inside this method.
        """
        holiday = self.db.query(SpecialHoliday).filter(
            SpecialHoliday.company_id == company_id,
            SpecialHoliday.date == date,
        ).first()
        return self._to_response(holiday) if holiday else None

    async def delete_holiday(self, holiday_id: int, company_id: int) -> bool:
        """Delete a holiday only if it belongs to ``company_id``.

        Returns ``False`` when no holiday with that id exists for the
        company; otherwise returns the result of the inherited ``delete``
        (defined on ``BaseRepository``, not visible here).
        """
        holiday = self.db.query(SpecialHoliday).filter(
            SpecialHoliday.id == holiday_id,
            SpecialHoliday.company_id == company_id,
        ).first()
        if not holiday:
            # Ownership check failed or the row does not exist.
            return False
        return await self.delete(holiday_id)