import random
import secrets
import string
from datetime import datetime, timezone, timedelta
from typing import Optional

from pydantic import EmailStr
from sqlalchemy.orm import Session

from models import OTPLog
from schemas.auth import OTPLogSchema
from repositories.base_repository import BaseRepository


class OTPRepository(BaseRepository[OTPLogSchema, OTPLog]):
    """Data access for one-time passwords stored as ``OTPLog`` rows.

    Codes are numeric strings drawn from a cryptographically secure source.
    Each row records its creation time and an expiry ``DEFAULT_OTP_TTL`` later
    (overridable per call in :meth:`create_otp`).
    """

    # Lifetime of a newly created OTP unless create_otp() is given an explicit ttl.
    DEFAULT_OTP_TTL = timedelta(minutes=15)

    def __init__(self, db: Session):
        super().__init__(db, OTPLog)

    def _to_entity(self, db_model: Optional[OTPLog]) -> Optional[OTPLogSchema]:
        """Return the ORM row unchanged; no mapping to ``OTPLogSchema`` happens.

        NOTE(review): callers therefore receive ``OTPLog`` instances despite the
        annotation — confirm this is what BaseRepository consumers expect.
        """
        return db_model

    def _to_db_model(self, entity: OTPLogSchema) -> OTPLog:
        """Return *entity* unchanged; it is presumably already an ``OTPLog``."""
        return entity

    def generate_otp(self, length=6) -> str:
        """Return a string of *length* random decimal digits.

        ``secrets`` is used instead of ``random``: the Mersenne Twister behind
        ``random`` can be reconstructed from observed outputs, which would let an
        attacker predict future codes.
        """
        return ''.join(secrets.choice(string.digits) for _ in range(length))

    def is_otp_expired(self, expired_at: datetime) -> bool:
        """Return True if *expired_at* lies strictly before the current UTC time.

        Naive datetimes (e.g. read back from a ``DateTime`` column without
        ``timezone=True``) are interpreted as UTC. Aware datetimes are compared
        directly; mixing the two would otherwise raise ``TypeError``.
        """
        now = datetime.now(timezone.utc)
        if expired_at.tzinfo is None:
            expired_at = expired_at.replace(tzinfo=timezone.utc)
        return expired_at < now

    async def create_otp(
        self,
        email: EmailStr,
        otp_type: int,
        *,
        ttl: Optional[timedelta] = None,
    ) -> OTPLogSchema:
        """Generate a new OTP for *email*/*otp_type* and persist it.

        Args:
            email: Address the code is issued for.
            otp_type: Caller-defined OTP category stored in ``OTPLog.type``.
            ttl: How long the code stays valid; defaults to ``DEFAULT_OTP_TTL``.

        Returns:
            Whatever ``BaseRepository.create`` returns for the new row.
        """
        lifetime = self.DEFAULT_OTP_TTL if ttl is None else ttl
        now = datetime.now(timezone.utc)
        otp_log = OTPLog(
            email=email,
            otp=self.generate_otp(),
            type=otp_type,
            created_at=now,
            expired_at=now + lifetime,
        )
        return await self.create(otp_log)

    async def get_latest_otp(self, email: str, otp_type: int) -> Optional[OTPLogSchema]:
        """Return the most recently created unused OTP for *email*/*otp_type*.

        Only rows with ``used_at IS NULL`` are considered. Expiry is NOT checked
        here; use :meth:`is_otp_expired` on the result. Returns None if no row
        matches.

        NOTE(review): this runs a synchronous Session query inside a coroutine,
        so it blocks the event loop for the query's duration.
        """
        return (
            self.db.query(OTPLog)
            .filter(OTPLog.email == email)
            .filter(OTPLog.type == otp_type)
            .filter(OTPLog.used_at.is_(None))
            .order_by(OTPLog.created_at.desc())
            .first()
        )

    async def mark_otp_used(self, otp_log: OTPLogSchema) -> None:
        """Stamp *otp_log* with the current UTC time as ``used_at`` and persist it.

        Persistence goes through ``BaseRepository.update``.
        """
        otp_log.used_at = datetime.now(timezone.utc)
        await self.update(otp_log.id, otp_log)
