Skip to content
Snippets Groups Projects
dataset.py 2.24 KiB
Newer Older
Pierre Chanial's avatar
Pierre Chanial committed
"""Provides the dataset accessor object."""
from typing import Optional

from fastapi import HTTPException, status
from sqlalchemy.future import select
from sqlalchemy.orm import Session

from ...schemas.dataset import Dataset, DatasetCreate
from ...schemas.project import Project
from ..dbadmin import DBDataset
from .base import AccessorBase
from .project import project_accessor


class AccessorDataset(AccessorBase[DBDataset, Dataset, DatasetCreate, str]):
    """The Dataset accessor class.

    Datasets are identified by a name of the form 'project.dataset', where the
    part before the dot is the name of the project the dataset belongs to.
    """

    def get(self, session: Session, identifier: str) -> Optional[Dataset]:
        """Retrieves a dataset, as specified by its name.

        Args:
            session: The database session.
            identifier: The dataset name, of the form 'project.dataset'.

        Returns:
            The dataset, or None if it does not exist while its project does.

        Raises:
            ValueError: If the dataset is not found and the identifier is not of
                the form 'project.dataset'.
            HTTPException: 404 if the dataset is not found and its project does
                not exist.
        """
        dataset = super().get(session, identifier)
        if dataset is not None:
            return dataset
        # The dataset is missing: validate the identifier and its project so that
        # a malformed name or a missing project is reported as such.
        self._get_project(session, identifier)
        return None

    def list_by_project(self, session: Session, project_name: str) -> list[Dataset]:
        """Lists all the datasets belonging to a project.

        Args:
            session: The database session.
            project_name: The name of the project (without any dataset part).

        Returns:
            The datasets of the project.

        Raises:
            ValueError: If the project name contains a dot.
            HTTPException: 404 if the project does not exist.
        """
        if '.' in project_name:
            # Without this check, _get_project would reject the synthetic
            # 'project.' identifier with a message about the *dataset* name.
            raise ValueError(
                f"The project name '{project_name}' must not contain a dot."
            )
        # The trailing dot turns the project name into a 'project.' identifier,
        # which _get_project accepts (its dataset part is simply empty).
        project = self._get_project(session, project_name + '.')
        stmt = select(self.model).where(self.model.project_id == project.id)
        resources = session.execute(stmt).scalars()
        return [self.schema.from_orm(resource) for resource in resources]  # type: ignore

    def serialize_create_schema(
        self, session: Session, resource: DatasetCreate
    ) -> dict:
        """Serializes a DatasetCreate instance into a dict.

        The base serialization is extended with the 'project_id' of the project
        named by the first part of the dataset name.

        Raises:
            ValueError: If the dataset name is not of the form 'project.dataset'.
            HTTPException: 404 if the project does not exist.
        """
        serialized_resource = super().serialize_create_schema(session, resource)
        project = self._get_project(session, resource.name)
        serialized_resource['project_id'] = project.id
        return serialized_resource

    @staticmethod
    def _get_project(session: Session, identifier: str) -> Project:
        """Returns the project owning the dataset `identifier`.

        Args:
            session: The database session.
            identifier: A dataset name, of the form 'project.dataset'.

        Raises:
            ValueError: If the identifier does not contain exactly one dot.
            HTTPException: 404 if the project does not exist.
        """
        parts = identifier.split('.')
        if len(parts) != 2:
            raise ValueError(
                f"The dataset name '{identifier}' is not of the form 'project.dataset'."
            )
        project_name = parts[0]
        project = project_accessor.get(session, project_name)
        if project is None:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND,
                f"The project '{project_name}' does not exist.",
            )
        return project


# Shared accessor instance for datasets. The third argument is the ORM column
# `DBDataset.name`; presumably AccessorBase matches identifiers against it in
# `get` (NOTE(review): confirm against AccessorBase).
dataset_accessor = AccessorDataset(
    DBDataset,
    Dataset,
    DBDataset.name,
)