diff --git a/app/apis/depends.py b/app/apis/depends.py
index 285b5cfddc63e9ddba17db8be6ce9598313a3221..5fd12b0593ddb9f4271fa268a061fa58e5f33514 100644
--- a/app/apis/depends.py
+++ b/app/apis/depends.py
@@ -1,25 +1,25 @@
 """Dependencies that are injected in the routes."""
-from collections.abc import Generator
+from collections.abc import Iterator
 
+from sqlalchemy.engine import Connection
 from sqlalchemy.orm import Session
 
-from ..db import AdminSession
+from ..db.dbadmin import begin_session
+from ..db.dbprojects import begin_connection
 
 
-def get_session() -> Generator[Session, None, None]:
-    """Returns admin database session.
+def get_session() -> Iterator[Session]:
+    """Returns an admin database transaction session.
 
-    The transaction is automatically committed at the end.
+    At the end of the request, the transaction is automatically committed and the
+    session closed. When using this session, the method `session.flush` needs to be run
+    in order to get the server-side generated values of the added resources.
     """
-    with AdminSession.begin() as session:
+    with begin_session() as session:
         yield session
 
 
-def get_session_as_you_go() -> Generator[Session, None, None]:
-    """Returns an admin database session.
-
-    When using this session, commits must be explicitly set.
-    SQLAlchemy refers to this style as 'commit as you go'.
-    """
-    with AdminSession() as session:
-        yield session
+def get_connection(project: str) -> Iterator[Connection]:
+    """Returns a transaction connection associated with a project."""
+    with begin_connection(project) as connection:
+        yield connection
diff --git a/app/apis/v0/datasets.py b/app/apis/v0/datasets.py
index d52049abe8ecb11f5ab57233368368475a6dcea1..49ee73647d4c51c57a4e951dba5118a689ac5352 100644
--- a/app/apis/v0/datasets.py
+++ b/app/apis/v0/datasets.py
@@ -1,34 +1,34 @@
 """Definitions of the endpoints related the datasets."""
 import logging
 
-from fastapi import APIRouter, HTTPException, status
+from fastapi import APIRouter, Depends, HTTPException, status
 from sqlalchemy import text
 from sqlalchemy.engine import Connection
 
-from app.schemas import Dataset
-
-from ...helpers import begin_transaction
+from ...schemas import Dataset
+from ..depends import get_connection
 
 logger = logging.getLogger(__name__)
 router = APIRouter()
 
 
 @router.get('/projects/{project}/datasets', summary='Lists the datasets of a project.')
-def list_datasets(project: str) -> list[Dataset]:
+def list_datasets(
+    project: str, connection: Connection = Depends(get_connection)
+) -> list[Dataset]:
     """Lists the datasets belonging to a project."""
-    with begin_transaction(project) as conn:
-        stmt = text('SELECT schema_name FROM information_schema.schemata')
-        result = conn.execute(stmt)
-        schemas = result.scalars().all()
-
-        datasets = [
-            Dataset(
-                name=f'{project}.{_}',
-                description=_get_dataset_description_postgresql(conn, _),
-            )
-            for _ in schemas
-            if _filter_schema(_)
-        ]
+    stmt = text('SELECT schema_name FROM information_schema.schemata')
+    result = connection.execute(stmt)
+    schemas = result.scalars().all()
+
+    datasets = [
+        Dataset(
+            name=f'{project}.{_}',
+            description=_get_dataset_description_postgresql(connection, _),
+        )
+        for _ in schemas
+        if _filter_schema(_)
+    ]
 
     return datasets
 
@@ -42,13 +42,14 @@ def _filter_schema(schema: str) -> bool:
 
 
 @router.post('/projects/{project}/datasets', summary='Creates a dataset in a project.')
-def post_dataset(project: str, dataset: Dataset) -> Dataset:
+def post_dataset(
+    project: str, dataset: Dataset, connection: Connection = Depends(get_connection)
+) -> Dataset:
     """Creates a dataset in a project."""
-    with begin_transaction(project) as conn:
-        *_, schema = dataset.name.split('.', 1)
-        stmt = text(f'CREATE SCHEMA {schema}')
-        conn.execute(stmt)
-        _set_dataset_description_postgresql(conn, schema, dataset.description)
+    *_, schema = dataset.name.split('.', 1)
+    stmt = text(f'CREATE SCHEMA {schema}')
+    connection.execute(stmt)
+    _set_dataset_description_postgresql(connection, schema, dataset.description)
 
     return Dataset(name=f'{project}.{schema}', description=dataset.description)
 
@@ -56,21 +57,22 @@ def post_dataset(project: str, dataset: Dataset) -> Dataset:
 @router.get(
     '/projects/{project}/datasets/{dataset}', summary='Gets a dataset of a project.'
 )
-def get_dataset(project: str, dataset: str) -> Dataset:
+def get_dataset(
+    project: str, dataset: str, connection: Connection = Depends(get_connection)
+) -> Dataset:
     """Lists the datasets belonging to a project."""
-    with begin_transaction(project) as conn:
-        stmt = text(
-            f"""
-            SELECT
-            FROM information_schema.schemata
-            WHERE schema_name='{dataset}'"""
-        )
-        result = conn.execute(stmt).first()
-        if result is None:
-            msg = f"The dataset '{dataset}' does not exist in the project '{project}'."
-            raise HTTPException(status.HTTP_404_NOT_FOUND, msg)
+    stmt = text(
+        f"""
+        SELECT
+        FROM information_schema.schemata
+        WHERE schema_name='{dataset}'"""
+    )
+    result = connection.execute(stmt).first()
+    if result is None:
+        msg = f"The dataset '{dataset}' does not exist in the project '{project}'."
+        raise HTTPException(status.HTTP_404_NOT_FOUND, msg)
 
-        description = _get_dataset_description_postgresql(conn, dataset)
+    description = _get_dataset_description_postgresql(connection, dataset)
 
     return Dataset(
         name=f'{project}.{dataset}',
@@ -78,13 +80,13 @@ def get_dataset(project: str, dataset: str) -> Dataset:
     )
 
 
-def _get_dataset_description_postgresql(conn: Connection, dataset: str) -> str:
+def _get_dataset_description_postgresql(connection: Connection, dataset: str) -> str:
     stmt = text(f"SELECT obj_description(CAST('{dataset}' AS regnamespace))")
-    return conn.execute(stmt).scalar_one() or ''
+    return connection.execute(stmt).scalar_one() or ''
 
 
 def _set_dataset_description_postgresql(
-    conn: Connection, dataset: str, description: str
+    connection: Connection, dataset: str, description: str
 ) -> None:
     stmt = text(f"COMMENT ON SCHEMA {dataset} IS '{description}'")
-    conn.execute(stmt)
+    connection.execute(stmt)
diff --git a/app/apis/v0/projects.py b/app/apis/v0/projects.py
index d1eec0cf34f770af777a63b10d33791142da0292..00b29f1b0b85790201c7cad2edf102f8a1e9246c 100644
--- a/app/apis/v0/projects.py
+++ b/app/apis/v0/projects.py
@@ -2,16 +2,17 @@
 import logging
 
 from fastapi import APIRouter, Depends, HTTPException, status
-from sqlalchemy import create_engine, text, update
-from sqlalchemy.exc import DBAPIError
+from pydantic import SecretStr
+from sqlalchemy import update
 from sqlalchemy.future import select
 from sqlalchemy.orm import Session
 
 from ... import db
 from ...db.dbadmin import DBProjectServer
+from ...db.dbprojects import create_database
 from ...helpers import fix_sqlalchemy2_stubs_non_nullable_column
 from ...schemas import Project, ProjectCreate
-from ..depends import get_session, get_session_as_you_go
+from ..depends import get_session
 
 logger = logging.getLogger(__name__)
 router = APIRouter()
@@ -23,9 +24,9 @@ def list_projects(*, session: Session = Depends(get_session)) -> list[Project]:
     return db.project.list_(session)
 
 
-@router.post('')
+@router.post('', summary='Creates a project.')
 def create_project(
-    *, session: Session = Depends(get_session_as_you_go), project: ProjectCreate
+    *, session: Session = Depends(get_session), project: ProjectCreate
 ) -> Project:
     """Creates a new project."""
     if db.project.get(session, project.name) is not None:
@@ -34,36 +35,21 @@ def create_project(
             detail=f"The project '{project.name} already exists.",
         )
 
-    if project.uri is None:
+    is_user_project = project.uri is None
+
+    if is_user_project:
         server = _find_best_project_server(session, project.max_size)
+        project.project_server_id = server.id
+        project.uri = SecretStr(f'{server.uri}/{project.name}')
 
-        # preempt the required storage
-        _update_server_available_size(session, server, -project.max_size)
-        session.commit()
-
-        try:
-            with create_engine(
-                fix_sqlalchemy2_stubs_non_nullable_column(server.uri),
-                pool_pre_ping=True,
-                future=True,
-            ).connect() as conn:
-                stmt = text(f'CREATE DATABASE "{project.name}"')
-                conn.execution_options(isolation_level='AUTOCOMMIT').execute(stmt)
-        except DBAPIError as exc:
-            logger.error(str(exc))
-            # release the preempted storage
-            _update_server_available_size(session, server, +project.max_size)
-            session.commit()
-            raise HTTPException(
-                status_code=status.HTTP_502_BAD_GATEWAY,
-                detail='The project database could not be created: '
-                f'{type(exc).__name__}: {exc}',
-            )
+    output = db.project.create(session, project)
 
-        project.project_server_id = server.id
-        project.uri = f'{server.uri}/{project.name}'
+    if is_user_project:
+        _update_server_available_size(session, server, -project.max_size)
+        server_uri = fix_sqlalchemy2_stubs_non_nullable_column(server.uri)
+        create_database(server_uri, project.name)
 
-    return db.project.create(session, project)
+    return output
 
 
 def _find_best_project_server(session: Session, max_size: int) -> DBProjectServer:
@@ -86,6 +72,7 @@ def _update_server_available_size(
         .values(available_size=DBProjectServer.available_size + size)
     )
     session.execute(stmt)
+    session.flush()
 
 
 @router.get('/{project}', summary='Gets a project.')
diff --git a/app/apis/v0/tables.py b/app/apis/v0/tables.py
index 435fa37614bf2def0362d914f0d8dd1860a69ed2..34c9f8c75d133d530317ae46bdf31a8b54407616 100644
--- a/app/apis/v0/tables.py
+++ b/app/apis/v0/tables.py
@@ -4,19 +4,18 @@ import os
 from typing import Literal
 
 import requests
-from fastapi import APIRouter, HTTPException, status
+from fastapi import APIRouter, Depends, HTTPException, status
 from pandas import DataFrame
 from sqlalchemy import text
 from sqlalchemy.engine import Connection
 
-from app.schemas import (
+from ...db.dbprojects import project_engines, select_table_column_names
+from ...schemas import (
     BodyCreateTableFromESAPGatewayQuery,
     BodyCreateTableFromMapping,
     Table,
 )
-
-from ...db import project_engines
-from ...helpers import begin_transaction, table_column_names
+from ..depends import get_connection
 
 logger = logging.getLogger(__name__)
 router = APIRouter()
@@ -27,21 +26,22 @@ HEADERS_JSON = {'Accept': 'application/json'}
     '/projects/{project}/datasets/{dataset}/tables',
     summary='Lists the tables of a dataset.',
 )
-def list_tables(project: str, dataset: str) -> list[Table]:
+def list_tables(
+    project: str, dataset: str, connection: Connection = Depends(get_connection)
+) -> list[Table]:
     """Lists the table in a dataset."""
-    with begin_transaction(project) as conn:
-        stmt = text(
-            f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{dataset}'"  # noqa: E501
+    stmt = text(
+        f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{dataset}'"  # noqa: E501
+    )
+    result = connection.execute(stmt)
+    tables = result.scalars().all()
+    return [
+        Table(
+            name=f'{project}.{dataset}.{_}',
+            description=_get_table_description_postgresql(connection, dataset, _),
         )
-        result = conn.execute(stmt)
-        tables = result.scalars().all()
-        return [
-            Table(
-                name=f'{project}.{dataset}.{_}',
-                description=_get_table_description_postgresql(conn, dataset, _),
-            )
-            for _ in tables
-        ]
+        for _ in tables
+    ]
 
 
 @router.post(
@@ -55,6 +55,7 @@ def create_table_from_mapping(
     dataset: str,
     body: BodyCreateTableFromMapping,
     if_exists: Literal['fail', 'replace', 'append'] = 'fail',
+    connection: Connection = Depends(get_connection),
 ) -> Table:
     """Creates a table from a columnar input."""
     *_, table = body.name.split('.')
@@ -76,8 +77,7 @@ def create_table_from_mapping(
             f"The table '{table}' already exists in dataset '{project}.{dataset}'.",
         )
 
-    with begin_transaction(project) as conn:
-        _set_table_description_postgresql(conn, dataset, table, description)
+    _set_table_description_postgresql(connection, dataset, table, description)
 
     return Table(name=f'{project}.{dataset}.{table}', description=description)
 
@@ -87,7 +87,9 @@ def create_table_from_mapping(
     summary='Performs an ESAP Gateway query and stores it in a table.',
 )
 def create_esap_gateway_operation(
-    project: str, body: BodyCreateTableFromESAPGatewayQuery
+    project: str,
+    body: BodyCreateTableFromESAPGatewayQuery,
+    connection: Connection = Depends(get_connection),
 ) -> Table:
     """Creates an operation that queries the ESAP API Gateway."""
     *_, dataset, table = body.name.split('.')
@@ -105,8 +107,7 @@ def create_esap_gateway_operation(
             break
         page += 1
 
-    with begin_transaction(project) as conn:
-        _set_table_description_postgresql(conn, dataset, table, body.description)
+    _set_table_description_postgresql(connection, dataset, table, body.description)
 
     return Table(name=f'{project}.{dataset}.{table}', description=body.description)
 
@@ -149,33 +150,37 @@ def _create_esap_gateway_operation_paginated(
     '/projects/{project}/datasets/{dataset}/tables/{table}',
     summary='Gets a table from a dataset.',
 )
-def get_table(project: str, dataset: str, table: str) -> Table:
+def get_table(
+    project: str,
+    dataset: str,
+    table: str,
+    connection: Connection = Depends(get_connection),
+) -> Table:
     """Lists the datasets belonging to a project."""
-    with begin_transaction(project) as conn:
+    stmt = text(
+        f"""
+        SELECT
+        FROM information_schema.tables
+        WHERE table_schema='{dataset}' AND table_name='{table}'"""
+    )
+    result = connection.execute(stmt).first()
+    if result is None:
         stmt = text(
             f"""
-            SELECT
-            FROM information_schema.tables
-            WHERE table_schema='{dataset}' AND table_name='{table}'"""
+            SELECT
+            FROM information_schema.schemata
+            WHERE schema_name='{dataset}'"""
         )
-        result = conn.execute(stmt).first()
+        result = connection.execute(stmt).first()
         if result is None:
-            stmt = text(
-                f"""
-                SELECT
-                FROM information_schema.schemata
-                WHERE schema_name='{dataset}'"""
-            )
-            result = conn.execute(stmt).first()
-            if result is None:
-                msg = f"The dataset '{dataset}' does not exist in the project '{project}'."  # noqa
-                raise HTTPException(status.HTTP_404_NOT_FOUND, msg)
-
-            dataset_name = f'{project}.{dataset}'
-            msg = f"The table '{table}' does not exist in the dataset '{dataset_name}'."
+            msg = f"The dataset '{dataset}' does not exist in the project '{project}'."  # noqa
             raise HTTPException(status.HTTP_404_NOT_FOUND, msg)
 
-        description = _get_table_description_postgresql(conn, dataset, table)
+        dataset_name = f'{project}.{dataset}'
+        msg = f"The table '{table}' does not exist in the dataset '{dataset_name}'."
+        raise HTTPException(status.HTTP_404_NOT_FOUND, msg)
+
+    description = _get_table_description_postgresql(connection, dataset, table)
 
     return Table(
         name=f'{project}.{dataset}.{table}',
@@ -201,12 +206,16 @@ def _set_table_description_postgresql(
     '/projects/{project}/datasets/{dataset}/tables/{table}/content',
     summary='Gets the content of a table.',
 )
-def get_table_content(project: str, dataset: str, table: str) -> list[dict]:
+def get_table_content(
+    project: str,
+    dataset: str,
+    table: str,
+    connection: Connection = Depends(get_connection),
+) -> list[dict]:
     """Returns the whole table as json."""
-    column_names = get_table_column_names(project, dataset, table)
-    with begin_transaction(project) as conn:
-        stmt = text(f'SELECT * FROM {dataset}.{table}')
-        result = conn.execute(stmt)
+    column_names = select_table_column_names(connection, dataset, table)
+    stmt = text(f'SELECT * FROM {dataset}.{table}')
+    result = connection.execute(stmt)
     content = result.all()
     return [dict(zip(column_names, row)) for row in content]
 
@@ -215,6 +224,8 @@ def get_table_content(project: str, dataset: str, table: str) -> list[dict]:
     '/projects/{project}/datasets/{dataset}/tables/{table}/column-names',
     summary='Gets the column names of a table.',
 )
-def get_table_column_names(project: str, dataset: str, table: str) -> list[str]:
+def get_table_column_names(
+    dataset: str, table: str, connection: Connection = Depends(get_connection)
+) -> list[str]:
     """Returns the column names of a table."""
-    return table_column_names(project, dataset, table)
+    return select_table_column_names(connection, dataset, table)
diff --git a/app/apis/v0/users.py b/app/apis/v0/users.py
index 902f042a2d9f35a5b6455fbdbd80c421617f0164..51f310cdbe8707cc169ca1e8c93369fe52d80085 100644
--- a/app/apis/v0/users.py
+++ b/app/apis/v0/users.py
@@ -6,15 +6,14 @@ from sqlalchemy.orm import Session
 
 from app.schemas import User
 
-from ...db.dbadmin import DBUser
-from ..depends import get_session
+from ...db.dbadmin import DBUser, begin_session
 
 logger = logging.getLogger(__name__)
 router = APIRouter()
 
 
 @router.post('/', response_model=User)
-def create_user(*, session: Session = Depends(get_session), user: User) -> DBUser:
+def create_user(*, session: Session = Depends(begin_session), user: User) -> DBUser:
     """Creates a user."""
     db_user = DBUser(
         first_name=user.first_name,
diff --git a/app/db/__init__.py b/app/db/__init__.py
index ff509bee7359cdca1fcd1afa0b751dc99072da72..73eb57724b7d6e519bba1cbc4f6f5bada9509554 100644
--- a/app/db/__init__.py
+++ b/app/db/__init__.py
@@ -1,11 +1,5 @@
 """Accessors to the admin or project databases."""
 from .accessors import project
-from .dbadmin import AdminSession
-from .dbprojects.engines import project_engines
 
-__all__ = (
-    'AdminSession',
-    'project_engines',
-    'project',
-)
+__all__ = ('project',)
diff --git a/app/db/accessors/base.py b/app/db/accessors/base.py
index 9c04727f16ce1d7d56d21ca7021b46c7e409e6ab..8abe84b2631eb10957015372cba4b991844b7962 100644
--- a/app/db/accessors/base.py
+++ b/app/db/accessors/base.py
@@ -2,7 +2,7 @@
 from typing import Generic, List, Optional, Type, TypeVar
 
 from fastapi.encoders import jsonable_encoder
-from pydantic import BaseModel
+from pydantic import BaseModel, SecretStr
 from sqlalchemy import Column
 from sqlalchemy.future import select
 from sqlalchemy.orm import Session
@@ -55,11 +55,12 @@ class AccessorBase(Generic[DBModelType, SchemaType, CreateSchemaType, Identifier
 
     def create(self, session: Session, resource: CreateSchemaType) -> SchemaType:
         """Creates a resource."""
-        serialized_resource = jsonable_encoder(resource)
+        serialized_resource = jsonable_encoder(
+            resource, custom_encoder={SecretStr: lambda v: v.get_secret_value()}
+        )
         db_resource = self.model(**serialized_resource)  # type: ignore
         session.add(db_resource)
-        session.commit()
-        session.refresh(db_resource)
+        session.flush()
         return self.schema.from_orm(db_resource)
 
     def delete(
@@ -71,5 +72,5 @@
         if db_resource is None:
             return db_resource
         session.delete(db_resource)
-        session.commit()
+        session.flush()
         return self.schema.from_orm(db_resource)
diff --git a/app/db/dbadmin/__init__.py b/app/db/dbadmin/__init__.py
index 03790d34cedebe807574ff697ede4e0665a99657..2cde7f6b4900642b0ac2516372a7e6af9185e7ae 100644
--- a/app/db/dbadmin/__init__.py
+++ b/app/db/dbadmin/__init__.py
@@ -1,10 +1,10 @@
 """The ESAP-DB module handling the admin database."""
 from .models import DBProject, DBProjectServer, DBUser
-from .sessions import AdminSession
+from .sessions import begin_session
 
 __all__ = (
-    'AdminSession',
     'DBProject',
     'DBProjectServer',
     'DBUser',
+    'begin_session',
 )
diff --git a/app/db/dbadmin/sessions.py b/app/db/dbadmin/sessions.py
index 539239d95413ad34c29f38f54092396ed8ef59b2..a1a967c35816cca3d64647d758ecc30ada3026a2 100644
--- a/app/db/dbadmin/sessions.py
+++ b/app/db/dbadmin/sessions.py
@@ -2,8 +2,11 @@
 
 An instance of this class can be used to connect the admin or a project database.
 """
+from contextlib import contextmanager
+from typing import Iterator
+
 from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import Session, sessionmaker
 
 from ...config import settings
 
@@ -11,3 +14,15 @@ assert settings.SQLALCHEMY_DBADMIN_URI is not None
 
 engine = create_engine(settings.SQLALCHEMY_DBADMIN_URI, pool_pre_ping=True, future=True)
 AdminSession = sessionmaker(autocommit=False, autoflush=False, bind=engine, future=True)
+
+
+@contextmanager
+def begin_session() -> Iterator[Session]:
+    """Returns an admin database transaction session.
+
+    On exit, the transaction is automatically committed and the session closed.
+    When using this session, the method `session.flush` needs to be run in order
+    to get the server-side generated values of the added resources.
+    """
+    with AdminSession.begin() as session:
+        yield session
diff --git a/app/db/dbprojects/__init__.py b/app/db/dbprojects/__init__.py
index 9f541560e83a422ddd004ef20a38fa173c76debe..0b22f4a7fa9a7174ac246e45fd39741328cd68c7 100644
--- a/app/db/dbprojects/__init__.py
+++ b/app/db/dbprojects/__init__.py
@@ -1 +1,10 @@
 """The ESAP-DB module handling the project databases."""
+from .engines import begin_connection, project_engines
+from .operations import create_database, select_table_column_names
+
+__all__ = (
+    'begin_connection',
+    'create_database',
+    'project_engines',
+    'select_table_column_names',
+)
diff --git a/app/db/dbprojects/engines.py b/app/db/dbprojects/engines.py
index ab12e2a13b64019538a9a87bf424c3f4dbcb0476..8be7eb0c01c49606189f3aa42fd056e400a0d0aa 100644
--- a/app/db/dbprojects/engines.py
+++ b/app/db/dbprojects/engines.py
@@ -1,13 +1,14 @@
 """Factory for the SQLAlchemy engines of the projects, with a caching mechanism."""
 import collections
+from contextlib import contextmanager
 from typing import Iterator
 
 from fastapi import HTTPException, status
 from sqlalchemy import create_engine
-from sqlalchemy.engine import Engine
+from sqlalchemy.engine import Connection, Engine
 from sqlalchemy.future import select
 
-from ..dbadmin import AdminSession, DBProject
+from ..dbadmin import DBProject, begin_session
 
 
 class ProjectEngines(collections.abc.Mapping):
@@ -37,14 +38,22 @@ class ProjectEngines(collections.abc.Mapping):
 def _create_project_engine(project_name: str) -> Engine:
     """Factory for the project engines."""
     stmt = select(DBProject).where(DBProject.name == project_name)
-    project = AdminSession().execute(stmt).scalar_one_or_none()
-    if project is None:
-        raise HTTPException(
-            status.HTTP_404_NOT_FOUND,
-            f"The project '{project_name}' does not exist.",
-        )
-    uri = project.uri
-    return create_engine(uri, pool_pre_ping=True, future=True)
+    with begin_session() as session:
+        db_project = session.execute(stmt).scalar_one_or_none()
+        if db_project is None:
+            raise HTTPException(
+                status.HTTP_404_NOT_FOUND,
+                f"The project '{project_name}' does not exist.",
+            )
+        return create_engine(db_project.uri, pool_pre_ping=True, future=True)
 
 
 project_engines = ProjectEngines()
+
+
+@contextmanager
+def begin_connection(project: str) -> Iterator[Connection]:
+    """Returns a transaction connection associated with the project database."""
+    engine = project_engines[project]
+    with engine.begin() as connection:
+        yield connection
diff --git a/app/db/dbprojects/operations.py b/app/db/dbprojects/operations.py
new file mode 100644
index 0000000000000000000000000000000000000000..29b8db0aee03fa9821430f7cbdee7cd49dd47156
--- /dev/null
+++ b/app/db/dbprojects/operations.py
@@ -0,0 +1,45 @@
+"""PostgreSQL operations on the user project database servers."""
+import logging
+
+from fastapi import HTTPException, status
+from sqlalchemy import create_engine, text
+from sqlalchemy.engine import Connection
+from sqlalchemy.exc import DBAPIError
+
+from ...helpers import fix_sqlalchemy2_stubs_non_nullable_column
+
+logger = logging.getLogger(__name__)
+
+
+def create_database(server_uri: str, project_name: str) -> None:
+    """Creates a PostgreSQL database."""
+    try:
+        with create_engine(
+            fix_sqlalchemy2_stubs_non_nullable_column(server_uri),
+            pool_pre_ping=True,
+            future=True,
+        ).connect() as conn:
+            stmt = text(f'CREATE DATABASE "{project_name}"')
+            conn.execution_options(isolation_level='AUTOCOMMIT').execute(stmt)
+    except DBAPIError as exc:
+        logger.error(str(exc))
+        raise HTTPException(
+            status_code=status.HTTP_502_BAD_GATEWAY,
+            detail='The project database could not be created: '
+            f'{type(exc).__name__}: {exc}',
+        )
+
+
+def select_table_column_names(
+    connection: Connection, dataset: str, table: str
+) -> list[str]:
+    """Returns the column names of a table."""
+    stmt = text(
+        f"""
+        SELECT column_name
+        FROM information_schema.columns
+        WHERE table_schema = '{dataset}' AND table_name = '{table}'
+        """
+    )
+    result = connection.execute(stmt)
+    return result.scalars().all()
diff --git a/app/helpers.py b/app/helpers.py
index 0718578ac2c1ee542205557de114232f0efd0681..22dc951d4b7038c4f5d2b116bc12e6f4473217a0 100644
--- a/app/helpers.py
+++ b/app/helpers.py
@@ -1,37 +1,9 @@
 """Some helper functions for the API controlers."""
-from contextlib import contextmanager
-from typing import Iterator, Optional, TypeVar
-
-from sqlalchemy import text
-from sqlalchemy.engine import Connection
-
-from .db import project_engines
+from typing import Optional, TypeVar
 
 T = TypeVar('T')
 
 
-@contextmanager
-def begin_transaction(project: str) -> Iterator[Connection]:
-    """Returns a cursor associated with the project database."""
-    engine = project_engines[project]
-    with engine.begin() as connection:
-        yield connection
-
-
-def table_column_names(project: str, dataset: str, table: str) -> list[str]:
-    """Returns the column names of a table."""
-    with begin_transaction(project) as connection:
-        stmt = text(
-            f"""
-            SELECT column_name
-            FROM information_schema.columns
-            WHERE table_schema = '{dataset}' AND table_name = '{table}'
-            """
-        )
-        result = connection.execute(stmt)
-        return result.scalars().all()
-
-
 def fix_sqlalchemy2_stubs_non_nullable_column(arg: Optional[T]) -> T:
     """Non-nullable model columns are of type Optional[...] for mypy.
diff --git a/app/schemas/project.py b/app/schemas/project.py
index 69d610f42d144a973791759ea3ccbf67c6222a8e..0abfabbaa7a1991cabe1c7fee6ff13b27f49d2fd 100644
--- a/app/schemas/project.py
+++ b/app/schemas/project.py
@@ -1,28 +1,41 @@
 """The Pydantic classes to represent a project."""
 from typing import Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, SecretStr
 
 
 class ProjectBase(BaseModel):
     """The Project base schema."""
 
-    project_server_id: Optional[int]
+    project_server_id: Optional[int] = None
     name: str
     description: str = ''
-    uri: Optional[str]
+    uri: Optional[SecretStr] = None
     max_size: int = 10 * 2 ** 30
 
 
 class ProjectCreate(ProjectBase):
     """Schema to create a project."""
 
+    class Config:
+        schema_extra = {
+            'example': {
+                'name': 'my_project',
+                'description': 'My first project 😀',
+                'max_size': 10 * 2 ** 30,
+            }
+        }
+
 
 class Project(ProjectBase):
     """The Project schema mapped to the database."""
 
     id: int
-    uri: str
+    uri: SecretStr
 
     class Config:
         orm_mode = True
+        json_encoders = {
+            SecretStr: lambda v: 'postgresql://********:********@'
+            + v.get_secret_value().split('@', 1)[1],
+        }
diff --git a/migrations/env.py b/migrations/env.py
index 416072ab26bda5060d2f2c8b77053117dd05c575..37930dd2e3fe2619dc96ab68fe631e1ebd64480d 100644
--- a/migrations/env.py
+++ b/migrations/env.py
@@ -12,6 +12,7 @@ from sqlalchemy import engine_from_config, pool
 # for 'autogenerate' support
 # from myapp import mymodel
 # target_metadata = mymodel.Base.metadata
+import app.db.dbprojects.engines
 from app.config import settings
 from app.db.dbadmin import models
 
diff --git a/scripts/initialize_dbadmin.py b/scripts/initialize_dbadmin.py
index 2222b9c133d24730e03d27a9fe17819d21961d93..237c4a17f4cba72cb3fd78c2ae0122a27c37fdf6 100644
--- a/scripts/initialize_dbadmin.py
+++ b/scripts/initialize_dbadmin.py
@@ -4,8 +4,7 @@ import logging
 from sqlalchemy.exc import IntegrityError
 
 from app.config import settings
-from app.db import AdminSession
-from app.db.dbadmin import DBProjectServer
+from app.db.dbadmin import DBProjectServer, begin_session
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -18,7 +17,7 @@ def initialize_dbadmin() -> None:
     max_size = 2 ** 50  # 1 PiB
     server = DBProjectServer(uri=uri, max_size=max_size, available_size=max_size)
     try:
-        with AdminSession.begin() as session:
+        with begin_session() as session:
             session.add(server)
             logger.info(f'Added project server: {server_name}')
     except IntegrityError:
diff --git a/scripts/wait_for_dbadmin.py b/scripts/wait_for_dbadmin.py
index 4805972d9cffeb97257cbe11bf173836f9bfeef4..c3cd8e87d947f6bea9e503e72161eaa36c3a3375 100644
--- a/scripts/wait_for_dbadmin.py
+++ b/scripts/wait_for_dbadmin.py
@@ -4,7 +4,7 @@ import logging
 from sqlalchemy import select
 from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
 
-from app.db import AdminSession
+from app.db.dbadmin import begin_session
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -21,7 +21,7 @@ wait_seconds = 1
 )
 def execute_command() -> None:
     """Attempts to execute a command in the admin database."""
-    with AdminSession.begin() as session:
+    with begin_session() as session:
         try:
             # Try to create session to check if DB is awake
             session.execute(select(1))
diff --git a/scripts/wait_for_initialized_dbadmin.py b/scripts/wait_for_initialized_dbadmin.py
index 36bc1c18a9574ae01c4ddcea73eaae366c5bd034..c3e89cbfc8f05f699d842fc4fa35c3bebcb332f0 100644
--- a/scripts/wait_for_initialized_dbadmin.py
+++ b/scripts/wait_for_initialized_dbadmin.py
@@ -4,8 +4,7 @@ import logging
 from sqlalchemy import select
 from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
 
-from app.db import AdminSession
-from app.db.dbadmin import DBProjectServer
+from app.db.dbadmin import DBProjectServer, begin_session
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -22,7 +21,7 @@ wait_seconds = 1
 )
 def execute_command() -> None:
     """Ensures the project database has at least one entry."""
-    with AdminSession.begin() as session:
+    with begin_session() as session:
         stmt = select(DBProjectServer)
         try:
             result = session.execute(stmt).first()
diff --git a/tests/conftest.py b/tests/conftest.py
deleted file mode 100644
index 3ff499e20e86086c71036f68ce2545cdc7ae278e..0000000000000000000000000000000000000000
--- a/tests/conftest.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from typing import Generator
-
-import pytest
-from sqlalchemy.orm import Session
-
-from app.db import AdminSession
-
-
-@pytest.fixture(scope='session')
-def db() -> Generator[Session, None, None]:
-    yield AdminSession()
diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py
index 85de4b97dfa858bee31a5433a91ff27d1638704e..00874870b4ba27db1c9d971ee62ff80d4b88f783 100644
--- a/tests/functional/conftest.py
+++ b/tests/functional/conftest.py
@@ -1,4 +1,4 @@
-from typing import Generator
+from typing import Iterator
 
 import pytest
 from fastapi.testclient import TestClient
@@ -7,6 +7,6 @@ from app.main import app
 
 
 @pytest.fixture(scope='session')
-def client() -> Generator[TestClient, None, None]:
+def client() -> Iterator[TestClient]:
     with TestClient(app) as c:
         yield c
diff --git a/tests/functional/v0/conftest.py b/tests/functional/v0/conftest.py
index 85887f3a6a1424465e41af54ac1313482cb586c4..d4b3dfe28827bfe6ebfc63b402d7110047f3a6a2 100644
--- a/tests/functional/v0/conftest.py
+++ b/tests/functional/v0/conftest.py
@@ -1,4 +1,4 @@
-from typing import Generator
+from typing import Iterator
 
 import pytest
 from fastapi.testclient import TestClient
@@ -9,10 +9,10 @@ from .helpers import stage_dataset, stage_project
 
 
 @pytest.fixture(scope='package')
-def project(client: TestClient) -> Generator[Project, None, None]:
+def project(client: TestClient) -> Iterator[Project]:
     yield stage_project(client)
 
 
 @pytest.fixture(scope='package')
-def dataset(client: TestClient, project: Project) -> Generator[Dataset, None, None]:
+def dataset(client: TestClient, project: Project) -> Iterator[Dataset]:
     yield stage_dataset(client, project)
diff --git a/tests/functional/v0/test_projects.py b/tests/functional/v0/test_projects.py
index 1d63fcf3c925ad0d5c1933db708f632fce3d2707..f40ad2a9bfe4aec9cb57eec5669d117d87fbc355 100644
--- a/tests/functional/v0/test_projects.py
+++ b/tests/functional/v0/test_projects.py
@@ -16,7 +16,7 @@ def test_create_project(client: TestClient) -> None:
     project_post = response.json()
     assert project_post.pop('id') is not None
     assert project_post.pop('project_server_id') is not None
-    assert project_post.pop('uri').startswith('postgresql')
+    assert project_post.pop('uri').startswith('postgresql://********:********@')
     assert project == project_post
diff --git a/tests/unit/db/conftest.py b/tests/unit/db/conftest.py
index 5db581cb183412d7a1a9c27e5aaca54a62415cfe..f8205be36512e8995565f44e98d336af4d6e092b 100644
--- a/tests/unit/db/conftest.py
+++ b/tests/unit/db/conftest.py
@@ -1,11 +1,12 @@
-from typing import Generator
+from typing import Iterator
 
 import pytest
 from sqlalchemy.orm import Session
 
-from app.apis.depends import get_session_as_you_go
+from app.db.dbadmin import begin_session
 
 
 @pytest.fixture
-def session() -> Generator[Session, None, None]:
-    yield from get_session_as_you_go()
+def session() -> Iterator[Session]:
+    with begin_session() as session:
+        yield session
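Note (reviewer sketch, not part of the patch): with tests/conftest.py removed, tests that need a project connection can wire up the new `get_connection` dependency through FastAPI's `dependency_overrides`. The snippet below is only an assumed illustration of that wiring — the in-memory SQLite engine and the `override_get_connection` helper are hypothetical stand-ins and do not provide the PostgreSQL `information_schema` views that the dataset/table endpoints query.

from collections.abc import Iterator

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.engine import Connection

from app.apis.depends import get_connection
from app.main import app

# Assumed stand-in engine; real tests would point at a PostgreSQL project server.
engine = create_engine('sqlite://', future=True)


def override_get_connection() -> Iterator[Connection]:
    """Yields a transaction-scoped connection bound to the stand-in engine."""
    with engine.begin() as connection:
        yield connection


# Swap the path-parameter-driven dependency for the fixed test connection.
app.dependency_overrides[get_connection] = override_get_connection
client = TestClient(app)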