Commit 214263a8 authored by Pierre Chanial's avatar Pierre Chanial

Create table from SQL query involving tables in the same project.

parent 9ef115c4
Pipeline #15995 canceled with stages in 10 minutes and 4 seconds
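For context, here is a sketch of the capability this commit adds, exercised through the HTTP API. The route is inferred from the path parameters of create_table_from below and is not confirmed by this diff; all project, dataset and table names are hypothetical:

import requests

# Create a table as the result of a SELECT query over tables of the same project.
payload = {
    'name': 'my_dataset.union_table',  # hypothetical target table
    'description': 'Union of two tables',
    'query': 'SELECT * FROM my_dataset.t1 UNION SELECT * FROM my_dataset.t2',
}
response = requests.post(
    'http://localhost:8000/api/v0/projects/my_project/datasets/my_dataset/tables',
    json=payload,
)
response.raise_for_status()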
......@@ -6,7 +6,7 @@ import pickle
import re
from io import StringIO
from pathlib import Path
from typing import Literal, Optional
from typing import Callable, Literal, Optional, Union
import pandas as pd
import requests
......@@ -21,6 +21,7 @@ from ...db.dbadmin import begin_session
from ...db.dbprojects import project_engines
from ...db.dbprojects.operations import drop_table, select_table_column_names
from ...exceptions import ESAPDBValidationError
from ...helpers import uid
from ...schemas import (
BodyCreateTableFrom,
BodyCreateTableFromESAPGatewayQuery,
......@@ -36,6 +37,7 @@ HEADERS_JSON = {'Accept': 'application/json'}
REGEX_CONTENT_DISPOSITION_NAME = re.compile(r'\bname\s*=\s*"(.*)"')
REGEX_CONTENT_DISPOSITION_FILENAME = re.compile(r'\bfilename\s*=\s*"(.*)"')
REGEX_VALID_FORMAT = re.compile('[a-z]')
# Captures the SELECT clause, stripping surrounding whitespace and semicolons.
REGEX_SELECT_CLAUSE = re.compile(r'[\s;]*(SELECT .*?)[\s;]*$', re.DOTALL)
@router.get(
......@@ -60,44 +62,58 @@ def create_table_from(
dataset: str,
body: BodyCreateTableFrom,
if_exists: Literal['fail', 'replace', 'append'] = 'fail',
connection: Connection = Depends(get_connection),
session: Session = Depends(get_session),
) -> Table:
"""Creates a table from a columnar input."""
if isinstance(body.content, str):
# FIXME arbitrary code execution, data cannot be trusted
dataframes = pickle.loads(base64.b64decode(body.content))
name = body.name
elif body.content is not None:
dataframes = pd.DataFrame(body.content)
name = body.name
"""Creates a table from a data source."""
if body.name is not None:
parts = body.name.split('.')
if len(parts) == 3 and parts[0] != project:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
f'The project specified in the path ({project}) differs '
f'from that specified in the request body ({parts[0]}).',
)
if len(parts) >= 2 and parts[-2] != dataset:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
f'The dataset specified in the path ({dataset}) differs'
f' from that specified in the request body ({parts[-2]}).',
)
dataframe_getter: Optional[Callable[[], pd.DataFrame]]
if body.content is not None:
assert body.name is not None
*_, table = body.name.split('.')
def dataframe_getter() -> pd.DataFrame:
assert body.content is not None
return _get_dataframe_content(body.content)
elif body.path is not None:
if body.name:
# the request can be deferred until after the Table creation
*_, table = body.name.split('.')
def dataframe_getter() -> pd.DataFrame:
return _get_dataframe_path(body)[1]
else:
# the request cannot be deferred: its response is needed for the table name
table, df = _get_dataframe_path(body)
def dataframe_getter() -> pd.DataFrame:
return df
else:
assert body.path is not None
response = requests.get(body.path)
response.raise_for_status()
name = normalize_identifier(_infer_name_from_response(body.name, response))
format = _infer_format_from_response(body.params.pop('format', None), response)
if format == 'csv':
dataframes = pd.read_csv(StringIO(response.text), **body.params)
assert body.query is not None
if body.name is not None:
*_, table = body.name.split('.')
else:
raise ESAPDBValidationError(
f"Cannot handle format '{format}'. Valid formats are: csv."
)
table = uid('anonymous')
dataframe_getter = None
assert name is not None
*_, table = name.split('.')
if len(_) == 2 and _[0] != project:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
f'The project specified in the path ({project}) differs '
f'from that specified in the request body ({_[0]}).',
)
if len(_) > 1 and _[-1] != dataset:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
f'The dataset specified in the path ({dataset}) differs'
f' from that specified in the request body ({_[-1]}).',
)
table_name = f'{project}.{dataset}.{table}'
if if_exists == 'fail':
......@@ -106,21 +122,47 @@ def create_table_from(
else:
table_out = db.table.get_by_name(session, table_name)
dataframes.to_sql(
table,
project_engines[project],
schema=dataset,
if_exists=if_exists,
index=False,
)
if dataframe_getter is not None:
df = dataframe_getter()
_write_dataframe(df, project, dataset, table, if_exists)
else:
assert body.query is not None
stmt = text(_get_create_table_as_statement(body.query, f'{dataset}.{table}'))
connection.execute(stmt)
return table_out
def _get_dataframe_content(content: Union[str, dict[str, list]]) -> pd.DataFrame:
if isinstance(content, str):
# FIXME arbitrary code execution, data cannot be trusted
return pickle.loads(base64.b64decode(content))
return pd.DataFrame(content)
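A round-trip sketch of the base64-encoded pickle form handled above. The FIXME stands: pickle can execute arbitrary code during deserialization, so this input must only come from trusted clients:

import base64
import pickle

import pandas as pd

# Encode a DataFrame the way a trusted client would for the 'content' field.
encoded = base64.b64encode(pickle.dumps(pd.DataFrame({'x': [1, 2]}))).decode()
df = _get_dataframe_content(encoded)
assert list(df['x']) == [1, 2]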
def _get_dataframe_path(body: BodyCreateTableFrom) -> tuple[str, pd.DataFrame]:
assert body.path is not None
response = requests.get(body.path)
response.raise_for_status()
table = normalize_identifier(_infer_name_from_response(body.name, response))
format = _infer_format_from_response(body.params.pop('format', None), response)
if format == 'csv':
df = pd.read_csv(StringIO(response.text), **body.params)
else:
raise ESAPDBValidationError(
f"Cannot handle format '{format}'. Valid formats are: csv."
)
return table, df
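For illustration, assuming a CSV file served at a hypothetical URL, the helper above infers the table name from the response and parses the body in one step:

body = BodyCreateTableFrom(path='https://example.org/data/fruits.csv')
table, df = _get_dataframe_path(body)
# table is 'fruits', inferred from the URL (or from the Content-Disposition
# header when present); df is the DataFrame parsed from the CSV payload.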
def _infer_name_from_response(name: Optional[str], response: requests.Response) -> str:
name = name or None
if name:
return name
return name.split('.')[-1]
url = Path(response.url)
if url.suffix[1:] in {'csv'}:
......@@ -169,11 +211,39 @@ def _infer_format_from_response(
)
def _write_dataframe(
dataframes: pd.DataFrame, project: str, dataset: str, table: str, if_exists: str
) -> None:
dataframes.to_sql(
table,
project_engines[project],
schema=dataset,
if_exists=if_exists,
index=False,
)
def _valid_format(value: str) -> bool:
"""Some heuristics to check if a suffix is valid."""
return len(value) <= 8 and REGEX_VALID_FORMAT.search(value) is not None
def _get_create_table_as_statement(query: str, table_name: str) -> str:
match = REGEX_SELECT_CLAUSE.match(query)
if match is None:
raise ESAPDBValidationError(
'The SQL query statement does not start with a SELECT clause.'
)
query = match[1]
if ';' in query:
raise ESAPDBValidationError(
"The SQL query must not contain the character ';' to avoid SQL "
'injections. A better approach would be to parse the SQL '
'statement to avoid false positives in string literals.'
)
return f'CREATE TABLE {table_name} AS {query}'
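A few illustrative calls (behavior follows the code above; table names are arbitrary):

_get_create_table_as_statement('  SELECT * FROM t1; ', 'my_dataset.t2')
# -> 'CREATE TABLE my_dataset.t2 AS SELECT * FROM t1'
#    (surrounding whitespace and semicolons are stripped by REGEX_SELECT_CLAUSE)
_get_create_table_as_statement('DROP TABLE t1', 'my_dataset.t2')
# -> ESAPDBValidationError: the statement does not start with a SELECT clause
_get_create_table_as_statement('SELECT 1; DROP TABLE t1', 'my_dataset.t2')
# -> ESAPDBValidationError: an embedded ';' is rejected to limit SQL injection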
@router.post(
'/projects/{project}/esap-gateway-operations',
summary='Performs an ESAP Gateway query and stores it in a table.',
......
"""Some helper functions for the API controlers."""
from typing import Optional, TypeVar
from uuid import uuid4
T = TypeVar('T')
......@@ -12,3 +13,10 @@ def fix_sqlalchemy2_stubs_non_nullable_column(arg: Optional[T]) -> T:
"""
assert arg is not None
return arg
def uid(arg: str = '') -> str:
"""Random project, dataset or table identifier."""
if arg:
arg += '_'
return arg + str(uuid4()).replace('-', '_')
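For example (the UUID shown is illustrative):

uid('table')  # -> 'table_9f1c2b7e_4a3d_4c1b_8e2f_0a1b2c3d4e5f'
# The dashes of the UUID are replaced with underscores, and the optional
# prefix keeps the identifier starting with a letter, as with the
# uid('anonymous') fallback used in create_table_from.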
......@@ -12,26 +12,29 @@ from pydantic import BaseModel, Field, root_validator
class BodyCreateTableFrom(BaseModel):
"""The body schema to create a table.
Current available inputs are:
Currently available input sources are:
* a mapping column name / list of values
* a base64 encoded pickled Pandas DataFrame
* a SELECT query
"""
name: Optional[str] = None
description: str = ''
content: Union[None, str, dict[str, list]] = None
path: Optional[str] = None
query: Optional[str] = None
params: dict[str, Any] = Field(default_factory=dict)
@root_validator()
def check_data_source(cls, values: dict) -> dict:
"""Ensures there is only one data source."""
if values['content'] is None and values['path'] is None:
nsource = sum(values[_] is not None for _ in ['content', 'path', 'query'])
if nsource == 0:
raise ValueError('No data source is specified in the request body.')
if values['content'] is not None and values['path'] is not None:
if nsource > 1:
raise ValueError(
'Ambiguous data source: both the content and path property are '
'specified in the request body.'
'Ambiguous data source: exactly one of the following sources must be '
"set in the request body: 'content', 'path', 'query'."
)
return values
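Illustrative payloads for the validator (field values are arbitrary):

BodyCreateTableFrom(content={'x': [1, 2]})                 # ok: one source
BodyCreateTableFrom(query='SELECT 1')                      # ok: one source
BodyCreateTableFrom()                                      # ValueError: no source
BodyCreateTableFrom(content={'x': [1]}, query='SELECT 1')  # ValueError: ambiguous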
......
......@@ -5,10 +5,9 @@ from typing import Iterator, Optional
from fastapi.testclient import TestClient
from app.config import settings
from app.helpers import uid
from app.schemas import Dataset, Project, Table
from ...helpers import uid
DATA_PATH = Path(__file__).parents[1] / 'data'
......
from fastapi.testclient import TestClient
from app.config import settings
from app.helpers import uid
from app.schemas import Dataset, Project
from ...helpers import uid
from .helpers import staged_dataset, staged_table
......
from fastapi.testclient import TestClient
from app.config import settings
from app.helpers import uid
from app.schemas import Project
from ...helpers import uid
from .helpers import staged_project
......
......@@ -6,10 +6,10 @@ import responses
from fastapi.testclient import TestClient
from app.config import settings
from app.helpers import uid
from app.schemas import Dataset, Project, Table
from ...helpers import uid
from .helpers import DATA_PATH, staged_table
from .helpers import DATA_PATH, staged_dataset, staged_table
def test_create_table_mapping_success(client: TestClient, dataset: Dataset) -> None:
......@@ -71,6 +71,46 @@ def test_create_table_csv_success(
assert table.description == payload['description']
def test_create_table_query_success(
client: TestClient, project: Project, dataset: Dataset
) -> None:
payload1 = {
'name': uid('table1'),
'content': {'x': 5 * ['vegetable'], 'y': list('🥑🌽🥒🍆🥦')},
}
payload2 = {
'name': uid('table2'),
'description': 'Union of two tables',
'content': {'x': 6 * ['fruit'], 'y': list('🍓🥝🍇🍐🍏🍍')},
}
with staged_dataset(client, project) as dataset1, staged_table(
client, dataset1, payload1
) as table1, staged_dataset(client, project) as dataset2, staged_table(
client, dataset2, payload2
) as table2:
_, dataset1_name, table1_name = table1.name.split('.')
_, dataset2_name, table2_name = table2.name.split('.')
payload3 = {
'name': f'{dataset.name}.{uid("table3")}',
'description': 'union table',
'query': f'SELECT * FROM {dataset1_name}.{table1_name} UNION SELECT * FROM {dataset2_name}.{table2_name} ORDER BY x, y',
}
with staged_table(client, dataset, payload3) as table3:
project_name, dataset_name, table_name = table3.name.split('.')
api = f'{settings.API_V0_STR}/projects/{project_name}/datasets/{dataset_name}/tables/{table_name}/content'
response = client.get(api)
assert response.status_code == 200
actual = pd.DataFrame(response.json())
# FIXME the order is wrong, should investigate collation
actual = actual.sort_values(['x', 'y']).reset_index(drop=True)
df1 = pd.DataFrame(payload1['content'])
df2 = pd.DataFrame(payload2['content'])
expected = pd.concat([df1, df2]).sort_values(['x', 'y']).reset_index(drop=True)
assert actual.equals(expected)
assert table3.name == payload3['name']
assert table3.description == payload3['description']
def test_create_table_not_found(client: TestClient, project: Project) -> None:
payload = {
'name': uid('table'),
......
from uuid import uuid4
def uid(arg: str = '') -> str:
if arg:
arg += '_'
return arg + str(uuid4()).replace('-', '_')
from uuid import uuid4
from app.helpers import uid
from app.schemas import DatasetCreate, ProjectCreate, ProjectServerCreate, TableCreate
from ....helpers import uid
def fake_project_server() -> ProjectServerCreate:
return ProjectServerCreate(
......
......@@ -6,9 +6,9 @@ from sqlalchemy.orm import Session
from app import db
from app.exceptions import ESAPDBResourceExistsError, ESAPDBResourceNotFoundError
from app.helpers import uid
from app.schemas import Project
from ....helpers import uid
from .helpers import fake_project
......