Skip to content
Snippets Groups Projects
Commit ec850fda authored by Pierre Chanial's avatar Pierre Chanial
Browse files

Table accessors.

parent 0d742207
No related branches found
No related tags found
1 merge request!19Resolve "Table DB model"
Pipeline #15939 passed
Showing
with 445 additions and 163 deletions
......@@ -32,11 +32,11 @@ def create_dataset(
"""Creates a dataset in a project."""
if '.' not in dataset_in.name:
dataset_in.name = f'{project}.{dataset_in.name}'
if db.dataset.get(session, dataset_in.name) is not None:
schema = dataset_in.name.split('.')[1]
elif (project_in := dataset_in.name.split('.')[0]) != project:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"The dataset '{schema}' already exists in project '{project}'.",
status.HTTP_400_BAD_REQUEST,
f'The project specified in the path ({project}) differs '
f'from that specified in the request body ({project_in}).',
)
dataset_out = db.dataset.create(session, dataset_in)
......@@ -52,10 +52,4 @@ def get_dataset(
) -> Dataset:
"""Gets a dataset."""
name = f'{project}.{dataset}'
dataset_out = db.dataset.get(session, name)
if dataset_out is None:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
f"The dataset '{dataset}' does not exist in the project '{project}'.",
)
return dataset_out
return db.dataset.get_by_name(session, name)
......@@ -29,12 +29,6 @@ def create_project(
*, session: Session = Depends(get_session), project_in: ProjectCreate
) -> Project:
"""Creates a new project."""
if db.project.get(session, project_in.name) is not None:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"The project '{project_in.name} already exists.",
)
is_user_project = project_in.uri is None
if is_user_project:
......@@ -78,8 +72,4 @@ def _update_server_available_size(
@router.get('/{project}', summary='Gets a project.')
def get_project(project: str, *, session: Session = Depends(get_session)) -> Project:
"""Gets a project visible to a user."""
project_ = db.project.get(session, project)
if project_ is None:
msg = f"The project '{project}' does not exist."
raise HTTPException(status.HTTP_404_NOT_FOUND, msg)
return project_
return db.project.get_by_name(session, project)
......@@ -8,14 +8,19 @@ from fastapi import APIRouter, Depends, HTTPException, status
from pandas import DataFrame
from sqlalchemy import text
from sqlalchemy.engine import Connection
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from ... import db
from ...db.dbadmin import begin_session
from ...db.dbprojects import project_engines, select_table_column_names
from ...schemas import (
BodyCreateTableFromESAPGatewayQuery,
BodyCreateTableFromMapping,
Table,
TableCreate,
)
from ..depends import get_connection
from ..depends import get_connection, get_session
logger = logging.getLogger(__name__)
router = APIRouter()
......@@ -27,21 +32,10 @@ HEADERS_JSON = {'Accept': 'application/json'}
summary='Lists the tables of a dataset.',
)
def list_tables(
project: str, dataset: str, connection: Connection = Depends(get_connection)
project: str, dataset: str, session: Session = Depends(get_session)
) -> list[Table]:
"""Lists the table in a dataset."""
stmt = text(
f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{dataset}'" # noqa: E501
)
result = connection.execute(stmt)
tables = result.scalars().all()
return [
Table(
name=f'{project}.{dataset}.{_}',
description=_get_table_description_postgresql(connection, dataset, _),
)
for _ in tables
]
"""Lists the tables in a dataset."""
return db.table.list_by_dataset(session, f'{project}.{dataset}')
@router.post(
......@@ -55,31 +49,40 @@ def create_table_from_mapping(
dataset: str,
body: BodyCreateTableFromMapping,
if_exists: Literal['fail', 'replace', 'append'] = 'fail',
connection: Connection = Depends(get_connection),
session: Session = Depends(get_session),
) -> Table:
"""Creates a table from a columnar input."""
*_, table = body.name.split('.')
description = body.description
dataframes = DataFrame(body.content)
try:
dataframes.to_sql(
table,
project_engines[project],
schema=dataset,
if_exists=if_exists,
index=False,
if len(_) == 2 and _[0] != project:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
f'The project specified in the path ({project}) differs '
f'from that specified in the request body ({_[0]}).',
)
except ValueError:
if if_exists != 'fail':
raise
if len(_) > 1 and _[-1] != dataset:
raise HTTPException(
status.HTTP_409_CONFLICT,
f"The table '{table}' already exists in dataset '{project}.{dataset}'.",
status.HTTP_400_BAD_REQUEST,
f'The dataset specified in the path ({dataset}) differs'
f' from that specified in the request body ({_[-1]}).',
)
table_name = f'{project}.{dataset}.{table}'
if if_exists == 'fail':
table_create = TableCreate(name=table_name, description=body.description)
table_out = db.table.create(session, table_create)
else:
table_out = db.table.get_by_name(session, table_name)
_set_table_description_postgresql(connection, dataset, table, description)
dataframes = DataFrame(body.content)
dataframes.to_sql(
table,
project_engines[project],
schema=dataset,
if_exists=if_exists,
index=False,
)
return Table(name=f'{project}.{dataset}.{table}', description=description)
return table_out
@router.post(
......@@ -89,27 +92,33 @@ def create_table_from_mapping(
def create_esap_gateway_operation(
project: str,
body: BodyCreateTableFromESAPGatewayQuery,
connection: Connection = Depends(get_connection),
) -> Table:
"""Creates an operation that queries the ESAP API Gateway."""
*_, dataset, table = body.name.split('.')
session = requests.Session()
session.headers.update(HEADERS_JSON)
session.trust_env = False
requests_session = requests.Session()
requests_session.headers.update(HEADERS_JSON)
requests_session.trust_env = False
page = 1
while True:
try:
_create_esap_gateway_operation_paginated(
project, dataset, table, session, body.query, page
project, dataset, table, requests_session, body.query, page
)
except StopIteration:
break
page += 1
_set_table_description_postgresql(connection, dataset, table, body.description)
table_create = TableCreate(
name=f'{project}.{dataset}.{table}', description=body.description
)
return Table(name=f'{project}.{dataset}.{table}', description=body.description)
try:
with begin_session() as session:
return db.table.create(session, table_create)
except SQLAlchemyError:
# XXX drop table in project database
raise
def _create_esap_gateway_operation_paginated(
......@@ -154,52 +163,10 @@ def get_table(
project: str,
dataset: str,
table: str,
connection: Connection = Depends(get_connection),
session: Session = Depends(get_session),
) -> Table:
"""Lists the datasets belonging to a project."""
stmt = text(
f"""
SELECT
FROM information_schema.tables
WHERE table_schema='{dataset}' AND table_name='{table}'"""
)
result = connection.execute(stmt).first()
if result is None:
stmt = text(
f"""
SELECT
FROM information_schema.schemata
WHERE schema_name='{dataset}'"""
)
result = connection.execute(stmt).first()
if result is None:
msg = f"The dataset '{dataset}' does not exist in the project '{project}'." # noqa
raise HTTPException(status.HTTP_404_NOT_FOUND, msg)
dataset_name = f'{project}.{dataset}'
msg = f"The table '{table}' does not exist in the dataset '{dataset_name}'."
raise HTTPException(status.HTTP_404_NOT_FOUND, msg)
description = _get_table_description_postgresql(connection, dataset, table)
return Table(
name=f'{project}.{dataset}.{table}',
description=description,
)
def _get_table_description_postgresql(
conn: Connection, dataset: str, table: str
) -> str:
stmt = text(f"SELECT obj_description(CAST('{dataset}.{table}' AS regclass))")
return conn.execute(stmt).scalar_one() or ''
def _set_table_description_postgresql(
conn: Connection, dataset: str, table: str, description: str
) -> None:
stmt = text(f"COMMENT ON TABLE {dataset}.{table} IS '{description}'")
conn.execute(stmt)
"""Gets a table from a dataset."""
return db.table.get_by_name(session, f'{project}.{dataset}.{table}')
@router.get(
......
"""Accessors to the admin or project databases."""
from .accessors import dataset, project
from .accessors import dataset, project, table
__all__ = (
'project',
'dataset',
'table',
)
"""SQLAlchemy accessors for the resources managed by the admin database."""
from .dataset import dataset_accessor as dataset
from .project import project_accessor as project
from .table import table_accessor as table
__all__ = (
'project',
'dataset',
'table',
)
"""Provides the base class for accessor objects."""
from typing import Generic, Optional, Type, TypeVar
from typing import Any, Generic, Optional, Type, TypeVar
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, SecretStr
from sqlalchemy import Column
from sqlalchemy.exc import IntegrityError
from sqlalchemy.future import select
from sqlalchemy.orm import Session
from ...exceptions import ESAPDBResourceExistsError, ESAPDBResourceNotFoundError
from ..dbadmin.models import Base
DBModelType = TypeVar('DBModelType', bound=Base)
SchemaType = TypeVar('SchemaType', bound=BaseModel)
CreateSchemaType = TypeVar('CreateSchemaType', bound=BaseModel)
IdentifierType = TypeVar('IdentifierType', int, str)
class AccessorBase(Generic[DBModelType, SchemaType, CreateSchemaType, IdentifierType]):
class AccessorBase(Generic[DBModelType, SchemaType, CreateSchemaType]):
"""The accessor base class."""
resource_type = ''
def __init__(
self,
model: Type[DBModelType],
schema: Type[SchemaType],
model_primary_key: Optional[Column],
):
"""Accessor object with default methods to Create, Retrieve, Update, Delete (CRUD).
Parameters:
`model`: A SQLAlchemy model class.
`schema`: A Pydantic model class.
`model_primary_key`: The column to by used to select a resource.
"""
self.model = model
self.schema = schema
if model_primary_key is None:
model_primary_key = self.model.id # type: ignore
self.model_primary_key = model_primary_key
def get(self, session: Session, identifier: IdentifierType) -> Optional[SchemaType]:
def get_by_id(self, session: Session, id: Any) -> SchemaType:
"""Retrieves a resource, as specified by the model primary key."""
stmt = select(self.model).where(self.model_primary_key == identifier)
stmt = select(self.model).where(self.model.id == id) # type: ignore
resource = session.execute(stmt).scalars().first()
if resource is None:
return None
raise ESAPDBResourceNotFoundError(
f"The {self.resource_type} '{id}' does not exist."
)
return self.schema.from_orm(resource)
def list_(
......@@ -61,17 +61,14 @@ class AccessorBase(Generic[DBModelType, SchemaType, CreateSchemaType, Identifier
session.flush()
return self.schema.from_orm(db_resource)
def delete(
self, session: Session, identifier: IdentifierType
) -> Optional[SchemaType]:
def delete_by_id(self, session: Session, id: Any) -> None:
"""Deletes a resource, as specified by the model primary key."""
stmt = select(self.model).where(self.model_primary_key == identifier)
stmt = select(self.model).where(self.model.id == id) # type: ignore
db_resource = session.execute(stmt).scalars().first()
if db_resource is None:
return db_resource
return
session.delete(db_resource)
session.flush()
return self.schema.from_orm(db_resource)
def serialize_create_schema(
self, session: Session, resource: CreateSchemaType
......@@ -80,3 +77,53 @@ class AccessorBase(Generic[DBModelType, SchemaType, CreateSchemaType, Identifier
return jsonable_encoder(
resource, custom_encoder={SecretStr: lambda v: v.get_secret_value()}
)
class AccessorWithNameBase(AccessorBase[DBModelType, SchemaType, CreateSchemaType]):
"""The accessor base class, for resources that can also be identified by name."""
def get_by_name(self, session: Session, name: str) -> SchemaType:
"""Retrieves a resource, as specified by their name."""
self.validate_full_name(name)
stmt = select(self.model).where(self.model.name == name) # type: ignore
resource = session.execute(stmt).scalars().first()
if resource is not None:
return self.schema.from_orm(resource)
msg = f"The {self.resource_type} '{name.split('.')[-1]}' does not exist"
parent = self.get_parent_resource(session, name)
if parent is None:
raise ESAPDBResourceNotFoundError(f'{msg}.')
parent_type = type(parent).__name__.lower()
raise ESAPDBResourceNotFoundError(
f"{msg} in the {parent_type} '{parent.name}'." # type: ignore
)
def create(self, session: Session, resource: CreateSchemaType) -> SchemaType:
"""Creates a resource."""
try:
return super().create(session, resource)
except IntegrityError:
msg = f"The {self.resource_type} '{resource.name}' already exists." # type: ignore # noqa
raise ESAPDBResourceExistsError(msg)
def delete_by_name(self, session: Session, name: str) -> None:
"""Deletes a resource, as specified by their name."""
self.validate_full_name(name)
stmt = select(self.model).where(self.model.name == name) # type: ignore
db_resource = session.execute(stmt).scalars().first()
if db_resource is None:
self.get_parent_resource(session, name)
return None
session.delete(db_resource)
session.flush()
def validate_full_name(self, name: str) -> None:
"""Validates that the name is correct."""
def get_parent_resource(self, session: Session, name: str) -> Optional[BaseModel]:
"""Returns the containing resource.
I.e a project for a dataset and a dataset for a table.
"""
return None
"""Provides the dataset accessor object."""
from typing import Optional
from fastapi import HTTPException, status
from sqlalchemy.future import select
from sqlalchemy.orm import Session
from ...schemas.dataset import Dataset, DatasetCreate
from ...schemas.project import Project
from ..dbadmin import DBDataset
from .base import AccessorBase
from .base import AccessorWithNameBase
from .project import project_accessor
class AccessorDataset(AccessorBase[DBDataset, Dataset, DatasetCreate, str]):
class AccessorDataset(AccessorWithNameBase[DBDataset, Dataset, DatasetCreate]):
"""The Dataset accessor class."""
def get(self, session: Session, identifier: str) -> Optional[Dataset]:
"""Retrieves a dataset, as specified by its name."""
dataset = super().get(session, identifier)
if dataset is not None:
return dataset
self._get_project(session, identifier)
return None
resource_type = 'dataset'
def list_by_project(self, session: Session, project_name: str) -> list[Dataset]:
"""Lists resources, according to pagination parameters."""
project = self._get_project(session, project_name + '.')
project = project_accessor.get_by_name(session, project_name)
stmt = select(self.model).where(self.model.project_id == project.id)
resources = session.execute(stmt).scalars()
return [self.schema.from_orm(_) for _ in resources] # type: ignore
......@@ -34,26 +25,24 @@ class AccessorDataset(AccessorBase[DBDataset, Dataset, DatasetCreate, str]):
self, session: Session, resource: DatasetCreate
) -> dict:
"""Serializes a DatasetCreate instance into a dict."""
self.validate_full_name(resource.name)
serialized_resource = super().serialize_create_schema(session, resource)
project = self._get_project(session, resource.name)
project = self.get_parent_resource(session, resource.name)
serialized_resource['project_id'] = project.id
return serialized_resource
@staticmethod
def _get_project(session: Session, identifier: str) -> Project:
parts = identifier.split('.')
def validate_full_name(self, name: str) -> None:
"""Validates that the name is of form 'project.dataset'."""
parts = name.split('.')
if len(parts) != 2:
raise ValueError(
f"The dataset name '{identifier}' is not of the form 'project.dataset'."
)
project_name = parts[0]
project = project_accessor.get(session, project_name)
if project is None:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
f"The project '{project_name}' does not exist.",
f"The dataset name '{name}' is not of the form 'project.dataset'."
)
return project
def get_parent_resource(self, session: Session, name: str) -> Project:
"""Returns the `Project` instance associated with the specified dataset name."""
project_name = name.split('.')[0]
return project_accessor.get_by_name(session, project_name)
dataset_accessor = AccessorDataset(DBDataset, Dataset, DBDataset.name)
dataset_accessor = AccessorDataset(DBDataset, Dataset)
"""Provides the project accessor object."""
from ...schemas.project import Project, ProjectCreate
from ..dbadmin import DBProject
from .base import AccessorBase
from .base import AccessorWithNameBase
class AccessorProject(AccessorBase[DBProject, Project, ProjectCreate, str]):
class AccessorProject(AccessorWithNameBase[DBProject, Project, ProjectCreate]):
"""The Project accessor class."""
resource_type = 'project'
project_accessor = AccessorProject(DBProject, Project, DBProject.name)
project_accessor = AccessorProject(DBProject, Project)
"""Provides the dataset accessor object."""
from sqlalchemy.future import select
from sqlalchemy.orm import Session
from ...schemas.dataset import Dataset
from ...schemas.table import Table, TableCreate
from ..dbadmin import DBTable
from .base import AccessorWithNameBase
from .dataset import dataset_accessor
class AccessorTable(AccessorWithNameBase[DBTable, Table, TableCreate]):
"""The Dataset accessor class."""
resource_type = 'table'
def list_by_dataset(self, session: Session, dataset_name: str) -> list[Table]:
"""Lists resources, according to pagination parameters."""
dataset = dataset_accessor.get_by_name(session, dataset_name)
stmt = select(self.model).where(self.model.dataset_id == dataset.id)
resources = session.execute(stmt).scalars()
return [self.schema.from_orm(_) for _ in resources] # type: ignore
def serialize_create_schema(self, session: Session, resource: TableCreate) -> dict:
"""Serializes a TableCreate instance into a dict."""
self.validate_full_name(resource.name)
serialized_resource = super().serialize_create_schema(session, resource)
dataset = self.get_parent_resource(session, resource.name)
serialized_resource['project_id'] = dataset.project_id
serialized_resource['dataset_id'] = dataset.id
return serialized_resource
def validate_full_name(self, name: str) -> None:
"""Validates that the name is of form 'project.dataset.table'."""
parts = name.split('.')
if len(parts) != 3:
raise ValueError(
f"The table name '{name}' is not of the form 'project.dataset.table'."
)
def get_parent_resource(self, session: Session, name: str) -> Dataset:
"""Returns the `Dataset` instance associated with the specified table name."""
dataset_name = '.'.join(name.split('.')[:2])
return dataset_accessor.get_by_name(session, dataset_name)
table_accessor = AccessorTable(DBTable, Table)
"""The ESAP-DB module handling the admin database."""
from .models import DBDataset, DBProject, DBProjectServer, DBUser
from .models import DBDataset, DBProject, DBProjectServer, DBTable, DBUser
from .sessions import begin_session
__all__ = (
'DBDataset',
'DBProject',
'DBProjectServer',
'DBTable',
'DBUser',
'begin_session',
)
......@@ -37,6 +37,17 @@ class DBDataset(Base):
description = Column(String, nullable=False)
class DBTable(Base):
"""The model representing an ESAP-DB table."""
__tablename__ = 'tables'
id = Column(Integer, primary_key=True, index=True)
project_id = Column(Integer, ForeignKey('projects.id'), nullable=False)
dataset_id = Column(Integer, ForeignKey('datasets.id'), nullable=False)
name = Column(String(513), unique=True, index=True, nullable=False)
description = Column(String, nullable=False)
class DBUser(Base):
"""The model representing an ESAP-DB user."""
......
"""This modules defines the exceptions used by ESAP-DB.
These exceptions, if not caught, are handled by the FastAPI error handlers
and translated into HTTP errors.
"""
from __future__ import annotations
from fastapi import Request
from fastapi.responses import JSONResponse
class ESAPDBError(Exception):
"""The base class for ESAP-DB exceptions."""
status_code = 500
@classmethod
async def error_handler(cls, request: Request, exc: ESAPDBError) -> JSONResponse:
"""The FastAPI error handler associated with this exception."""
return JSONResponse(
status_code=cls.status_code,
content={'status': cls.status_code, 'detail': str(exc)},
)
class ESAPDBValidationError(ESAPDBError, ValueError):
"""When a validation specific to ESAP-DB fails."""
status_code = 400
class ESAPDBResourceNotFoundError(ESAPDBError):
"""When a resource is expected to exist and does not."""
status_code = 404
class ESAPDBResourceExistsError(ESAPDBError):
"""When a resource is not expected to already exists."""
status_code = 409
......@@ -8,6 +8,11 @@ from starlette.middleware.cors import CORSMiddleware
from .apis.v0 import api_v0_router
from .config import settings
from .exceptions import (
ESAPDBResourceExistsError,
ESAPDBResourceNotFoundError,
ESAPDBValidationError,
)
logger = logging.getLogger(__name__)
......@@ -30,3 +35,10 @@ if settings.BACKEND_CORS_ORIGINS:
)
app.include_router(api_v0_router, prefix=settings.API_V0_STR)
for _exc in (
ESAPDBResourceExistsError,
ESAPDBResourceNotFoundError,
ESAPDBValidationError,
):
app.exception_handler(_exc)(_exc.error_handler)
......@@ -7,7 +7,7 @@ client.
from .dataset import Dataset, DatasetCreate
from .project import Project, ProjectCreate
from .schemas import BodyCreateTableFromESAPGatewayQuery, BodyCreateTableFromMapping
from .table import Table
from .table import Table, TableCreate
from .user import User
__all__ = (
......@@ -18,5 +18,6 @@ __all__ = (
'Project',
'ProjectCreate',
'Table',
'TableCreate',
'User',
)
"""The Pydantic classes to represent a table."""
import re
from pydantic import BaseModel
from pydantic import BaseModel, validator
from .helpers import IDENTIFIER_REGEX_STR
from .helpers import IDENTIFIER_MAX_LENGTH, IDENTIFIER_REGEX_STR
TABLE_REGEX = re.compile(rf'^({IDENTIFIER_REGEX_STR}\.){{0,2}}{IDENTIFIER_REGEX_STR}$')
class Table(BaseModel):
class TableBase(BaseModel):
"""The Table schema."""
name: str
description: str = ''
class TableCreate(TableBase):
"""Schema to create a table."""
@validator('name')
def check_name(cls, v: str) -> str:
"""Table name validator."""
*_, table_name = v.split('.')
if len(table_name) > IDENTIFIER_MAX_LENGTH:
raise ValueError(
f'The table name length is greater than {IDENTIFIER_MAX_LENGTH}: '
f'{table_name}'
)
if _ and v.startswith('_'):
raise ValueError(f'A project cannot start with an underscore: {v}')
if TABLE_REGEX.match(v) is None:
raise ValueError(
f'Invalid table name: {v}.'
"Valid characters are alphanumerical or '_#@$'."
)
return v
class Config:
schema_extra = {
'example': {
'name': 'my_project.my_dataset.my_table',
'description': 'My first table 😁',
}
}
class Table(TableBase):
"""The Table schema mapped to the database."""
id: int
project_id: int
dataset_id: int
class Config:
orm_mode = True
"""tables table
Revision ID: 0493dc8f99ac
Revises: 0d789bef6f2d
Create Date: 2021-07-30 01:21:25.322482
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = '0493dc8f99ac'
down_revision = '0d789bef6f2d'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
'tables',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('project_id', sa.Integer(), nullable=False),
sa.Column('dataset_id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=513), nullable=False),
sa.Column('description', sa.String(), nullable=False),
sa.ForeignKeyConstraint(
['dataset_id'],
['datasets.id'],
),
sa.ForeignKeyConstraint(
['project_id'],
['projects.id'],
),
sa.PrimaryKeyConstraint('id'),
)
op.create_index(op.f('ix_tables_id'), 'tables', ['id'], unique=False)
op.create_index(op.f('ix_tables_name'), 'tables', ['name'], unique=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_tables_name'), table_name='tables')
op.drop_index(op.f('ix_tables_id'), table_name='tables')
op.drop_table('tables')
# ### end Alembic commands ###
......@@ -4,4 +4,4 @@ set -ex
# Let the DB start
python /code/scripts/wait_for_initialized_dbadmin.py
pytest --cov=app --cov-report=term-missing tests "${@}"
pytest --cov=app --cov-report=term-missing "${@}"
......@@ -40,6 +40,18 @@ def test_create_dataset_duplicate_failure(client: TestClient, project: Project)
assert response.status_code == 409
def test_create_dataset_invalid_path(client: TestClient, project: Project) -> None:
dataset = {'name': 'project.dataset'}
api = f'{settings.API_V0_STR}/projects/{project.name}/datasets'
response = client.post(api, json=dataset)
assert response.status_code == 400
msg = response.json()['detail']
assert (
msg == f'The project specified in the path ({project.name}) differs '
'from that specified in the request body (project).'
)
def test_get_dataset_success(client: TestClient, project: Project) -> None:
dataset = stage_dataset(client, project)
dataset_name = dataset.name.split('.')[1]
......
import pytest
from fastapi.testclient import TestClient
from app.config import settings
......@@ -39,7 +38,6 @@ def test_create_table_not_found(client: TestClient, project: Project) -> None:
detail = response.json()['detail']
assert detail == "The project 'UNKNOWN_PROJECT' does not exist."
pytest.xfail('failure for unknown dataset. some refactoring is needed')
response = client.post(
f'{settings.API_V0_STR}/projects/{project.name}/datasets/UNKNOWN_DATASET/tables', # noqa
json=payload,
......@@ -67,6 +65,42 @@ def test_create_table_conflict(client: TestClient, dataset: Dataset) -> None:
assert response.status_code == 409
def test_create_table_invalid_path1(client: TestClient, dataset: Dataset) -> None:
project_name, _ = dataset.name.split('.')
payload = {
'name': 'project.dataset.table',
'content': {'x': [1], 'y': ['a']},
}
response = client.post(
f'{settings.API_V0_STR}/projects/{project_name}/datasets/dataset/tables',
json=payload,
)
assert response.status_code == 400
msg = response.json()['detail']
assert (
msg == f'The project specified in the path ({project_name}) differs '
'from that specified in the request body (project).'
)
def test_create_table_invalid_path2(client: TestClient, dataset: Dataset) -> None:
project_name, _ = dataset.name.split('.')
payload = {
'name': f'{project_name}.dataset1.table',
'content': {'x': [1], 'y': ['a']},
}
response = client.post(
f'{settings.API_V0_STR}/projects/{project_name}/datasets/dataset2/tables',
json=payload,
)
assert response.status_code == 400
msg = response.json()['detail']
assert (
msg == 'The dataset specified in the path (dataset2) differs '
'from that specified in the request body (dataset1).'
)
def test_get_table_success(client: TestClient, dataset: Dataset) -> None:
data = {'x': [1, 2, 3, 4], 'y': ['a', 'b', 'c', 'd']}
table = stage_table(client, dataset, data)
......
from typing import Iterator
import pytest
from sqlalchemy.orm import Session
from app import db
from app.db.dbadmin import begin_session
from app.exceptions import ESAPDBResourceNotFoundError
from app.schemas import Dataset, Project
from .helpers import fake_dataset, fake_project
@pytest.fixture
def project(session: Session) -> Iterator[Project]:
project_create = fake_project()
yield db.project.create(session, project_create)
# Some errors are caught with pytest.raises and the session may be invalid.
with begin_session() as session:
try:
datasets = db.dataset.list_by_project(session, project_create.name)
except ESAPDBResourceNotFoundError:
return
for dataset in datasets:
db.dataset.delete_by_name(session, dataset.name)
db.project.delete_by_name(session, project_create.name)
@pytest.fixture
def dataset(session: Session, project: Project) -> Iterator[Dataset]:
dataset_create = fake_dataset(project.name)
yield db.dataset.create(session, dataset_create)
# Some errors are caught with pytest.raises and the session may be invalid.
with begin_session() as session:
try:
tables = db.table.list_by_dataset(session, dataset_create.name)
except ESAPDBResourceNotFoundError:
return
for table in tables:
db.table.delete_by_name(session, table.name)
db.dataset.delete_by_name(session, dataset_create.name)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment