import logging
import traceback
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, List, Optional, Type, Union, Any
from sqlalchemy.orm import Session
from sqlalchemy import select, delete, func

from utils.exceptions import DatabaseError, NotFoundException

T = TypeVar('T')  # Pydantic entity type
DBModel = TypeVar('DBModel')  # SQLAlchemy model type


class BaseRepository(ABC, Generic[T, DBModel]):
    """
    Abstract base repository for handling database operations using SQLAlchemy.

    Provides generic CRUD methods that translate between SQLAlchemy models
    (``DBModel``) and Pydantic entities (``T``). Unexpected failures are logged
    and re-raised as ``DatabaseError`` chained to the original cause; write
    operations (``create``, ``update``, ``delete``) roll the session back first.

    Subclasses must implement ``_to_entity`` and ``_to_db_model`` and may
    override the ``before_save`` / ``after_save`` hooks used by ``create``.
    ``model_class`` must expose an ``id`` attribute, because every ID-based
    method filters on ``model_class.id``.

    NOTE(review): the public methods are declared ``async`` but drive a
    synchronous ``Session`` without awaiting anything, so each database call
    blocks the running event loop. Consider ``AsyncSession`` or offloading the
    calls to a worker thread.
    """

    def __init__(self, db_session: Session, model_class: Type[DBModel]):
        """
        Initializes the repository with a database session and model class.

        :param db_session: SQLAlchemy database session.
        :param model_class: The SQLAlchemy model class associated with the repository.
        """
        self.db = db_session
        self.model_class = model_class

    def _select_by_id(self, id: Union[int, str]) -> Optional[DBModel]:
        """
        Fetches the model whose ``id`` equals ``id``, without any error handling.

        Shared by the ID-based readers and ``update``; callers are responsible
        for translating exceptions into ``DatabaseError``.

        :param id: The unique identifier of the record.
        :return: The SQLAlchemy model instance, or None if no row matches.
        """
        query = select(self.model_class).where(self.model_class.id == id)
        return self.db.execute(query).scalar_one_or_none()

    async def get_all(self, start: Optional[int] = None, limit: Optional[int] = None) -> List[T]:
        """
        Retrieves all records of the given model with optional pagination.

        :param start: Offset for pagination; ignored when None.
        :param limit: Maximum number of records to retrieve; ignored when None or <= 0.
        :return: List of Pydantic entities representing the database records.
        :raises DatabaseError: If a database operation or entity conversion fails.
        """
        try:
            query = select(self.model_class)
            if start is not None:
                query = query.offset(start)
            if limit is not None and limit > 0:
                query = query.limit(limit)
            db_models = self.db.execute(query).scalars().all()
            return [self._to_entity(model) for model in db_models]
        except Exception as e:
            logging.exception("Error retrieving entities: %s", e)
            raise DatabaseError() from e

    async def get_all_count(self) -> int:
        """
        Retrieves the total count of records in the table.

        :return: Total number of records.
        :raises DatabaseError: If a database operation fails.
        """
        try:
            query = select(func.count()).select_from(self.model_class)
            return self.db.execute(query).scalar_one()
        except Exception as e:
            logging.exception("Error counting entities: %s", e)
            raise DatabaseError() from e

    async def get_by_id(self, id: Union[int, str]) -> Optional[T]:
        """
        Retrieves a record by its ID.

        :param id: The unique identifier of the record.
        :return: The corresponding Pydantic entity, or None if not found.
        :raises DatabaseError: If a database operation or entity conversion fails.
        """
        try:
            db_model = self._select_by_id(id)
            return self._to_entity(db_model) if db_model is not None else None
        except Exception as e:
            logging.exception("Error retrieving entity by id %r: %s", id, e)
            raise DatabaseError() from e

    async def get_by_id_raw(self, id: Union[int, str]) -> Optional[DBModel]:
        """
        Retrieves a raw SQLAlchemy model by its ID without conversion.

        :param id: The unique identifier of the record.
        :return: The SQLAlchemy model instance, or None if not found.
        :raises DatabaseError: If a database operation fails.
        """
        try:
            return self._select_by_id(id)
        except Exception as e:
            logging.exception("Error retrieving raw model by id %r: %s", id, e)
            raise DatabaseError() from e

    async def get_by_field(
        self, field_name: str, value: Any, convert_to_entity: bool = True
    ) -> Union[T, DBModel, None]:
        """
        Retrieves a single record by a specific field.

        :param field_name: Name of the model attribute to filter by. An unknown
            name raises AttributeError internally and surfaces as DatabaseError.
        :param value: The value to match.
        :param convert_to_entity: Whether to convert the db model to an entity
            (True) or return the raw SQLAlchemy model (False).
        :return: The matching entity or model, or None if not found.
        :raises DatabaseError: If the lookup fails, including when more than one
            row matches (``scalar_one_or_none`` rejects multiple results).
        """
        try:
            column = getattr(self.model_class, field_name)
            query = select(self.model_class).where(column == value)
            db_model = self.db.execute(query).scalar_one_or_none()
            if db_model is None or not convert_to_entity:
                return db_model
            return self._to_entity(db_model)
        except Exception as e:
            logging.exception("Error retrieving entity by field %r: %s", field_name, e)
            raise DatabaseError() from e

    async def create(self, entity: T) -> T:
        """
        Creates a new record in the database.

        Runs ``before_save`` on the entity, persists it, refreshes the model
        from the database, then returns the result of ``after_save``.

        :param entity: The Pydantic entity to be saved.
        :return: The saved Pydantic entity.
        :raises DatabaseError: If a database operation fails; the session is rolled back.
        """
        try:
            entity = self.before_save(entity)
            db_model = self._to_db_model(entity)
            self.db.add(db_model)
            self.db.commit()
            self.db.refresh(db_model)
            return self.after_save(entity, db_model)
        except Exception as e:
            self.db.rollback()
            logging.exception("Error creating entity: %s", e)
            raise DatabaseError() from e

    async def update(self, id: Union[int, str], entity: T) -> T:
        """
        Updates an existing record in the database.

        Every public attribute present on the model built from ``entity`` is
        copied onto the stored row, except ``id``: the row is identified by the
        ``id`` argument and its primary key is never overwritten from the entity.

        :param id: The unique identifier of the record to update.
        :param entity: The Pydantic entity with updated values.
        :return: The updated Pydantic entity.
        :raises DatabaseError: If the record does not exist or a database
            operation fails; the session is rolled back.
        """
        try:
            db_model = self._select_by_id(id)
            if db_model is None:
                logging.error("Error updating entity: no record with id %r", id)
                raise DatabaseError()

            updated_fields = self._get_clean_attributes(self._to_db_model(entity))
            # The entity's id may be None or stale; copying it would rewrite or
            # blank out the primary key of the row being updated.
            updated_fields.pop("id", None)
            for key, value in updated_fields.items():
                setattr(db_model, key, value)

            self.db.commit()
            self.db.refresh(db_model)
            return self._to_entity(db_model)
        except DatabaseError:
            # Already logged above (not-found) or raised by a hook; don't double-wrap.
            self.db.rollback()
            raise
        except Exception as e:
            self.db.rollback()
            logging.exception("Error updating entity: %s", e)
            raise DatabaseError() from e

    async def delete(self, id: Union[int, str]) -> bool:
        """
        Deletes a record by its ID.

        :param id: The unique identifier of the record to delete.
        :return: True once the record has been deleted and committed.
        :raises NotFoundException: If no record exists with that id.
        :raises DatabaseError: If the existence check or the delete fails; a
            failed delete rolls the session back.
        """
        # get_by_id is a coroutine; it must be awaited, otherwise the result is a
        # coroutine object (never None) and the not-found check can never fire.
        if await self.get_by_id(id) is None:
            raise NotFoundException

        try:
            query = delete(self.model_class).where(self.model_class.id == id)
            self.db.execute(query)
            self.db.commit()
            return True
        except Exception as e:
            self.db.rollback()
            logging.exception("Error deleting entity with id %r: %s", id, e)
            raise DatabaseError() from e

    def _get_clean_attributes(self, db_model: DBModel) -> dict:
        """
        Collects a model's public attributes for use in an update.

        Reads the instance ``__dict__`` and drops every key starting with an
        underscore (e.g. SQLAlchemy's internal instance-state attribute).

        :param db_model: The SQLAlchemy model instance.
        :return: A dictionary mapping attribute names to values.
        """
        return {
            key: value for key, value in db_model.__dict__.items()
            if not key.startswith("_")
        }

    def before_save(self, entity: T) -> T:
        """
        Hook method for preprocessing an entity before saving.

        Override in subclasses to implement custom logic.

        :param entity: The Pydantic entity before saving.
        :return: The processed entity.
        """
        return entity

    def after_save(self, entity: T, db_model: DBModel) -> T:
        """
        Hook method for post-processing an entity after saving.

        Override in subclasses to implement custom logic. The default ignores
        ``entity`` and converts the refreshed ``db_model``.

        :param entity: The Pydantic entity before saving.
        :param db_model: The saved SQLAlchemy model.
        :return: The final processed entity.
        """
        return self._to_entity(db_model)

    @abstractmethod
    def _to_entity(self, db_model: Optional[DBModel]) -> Optional[T]:
        """
        Converts a database model instance into a Pydantic entity.

        :param db_model: The SQLAlchemy model instance.
        :return: The corresponding Pydantic entity.
        """
        pass

    @abstractmethod
    def _to_db_model(self, entity: T) -> DBModel:
        """
        Converts a Pydantic entity into a database model.

        :param entity: The Pydantic entity.
        :return: The corresponding SQLAlchemy model instance.
        """
        pass
