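"""Celery tasks for defining LDV work specifications and submitting them as tasks to ATDB."""
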
import re
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

import requests
from celery.utils.log import get_task_logger
from lofardata.models import (
    SUBMISSION_STATUS,
    ATDBProcessingSite,
    DataProduct,
    WorkSpecification,
)
from requests.auth import AuthBase
from requests.exceptions import RequestException

from ldvspec.celery import app

logger = get_task_logger(__name__)


class RequestNotOk(Exception):
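    """Raised when ATDB responds with a non-OK (non-2xx) status code."""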
    pass


class WorkSpecificationNoSite(Exception):
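    """Raised when a WorkSpecification has no ATDB processing site configured."""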
    pass

class InvalidPredecessor(ValueError):
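    """Raised when a predecessor specification does not have exactly one related ATDB task."""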
    pass

class SessionStore:
    """requests.Session Singleton"""

    _session = None

    @classmethod
    def get_session(cls) -> requests.Session:
        if cls._session is None:
            cls._session = requests.Session()
        return cls._session


@app.task
def define_work_specification(workspecification_id: int):
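    """Collect the data products matching the specification's filters and store them as its inputs."""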
    specification = WorkSpecification.objects.get(pk=workspecification_id)
    filters = specification.filters

    dataproducts = DataProduct.objects.filter(**filters).order_by("surl")
    inputs = {
        "surls": [
            {"surl": dataproduct.surl, "size": dataproduct.filesize}
            for dataproduct in dataproducts
        ]
    }
    if specification.inputs is None:
        specification.inputs = inputs
    else:
        specification.inputs.update(inputs)
    specification.is_ready = True
    specification.save()


def _parse_surl(surl: str) -> dict:
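    """Extract the project, SAS ID and hostname from a storage URL (surl).

    Assumes the path contains a `/projects/<project>/<sas_id>/` segment, e.g.
    (hypothetical surl):
    srm://srm.grid.sara.nl/pnfs/grid.sara.nl/data/lofar/ops/projects/LC0_001/123456/file.tar
    """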
    parsed = urlparse(surl)
    host = parsed.hostname
    path = parsed.path
    pattern = r"^.*/projects\/(?P<project>.*_\d*)\/(?P<sas_id>\w\d*)\/"
    match = re.match(pattern, path)
    if match is None:
        raise ValueError(f"surl path does not match /projects/<project>/<sas_id>/: {path}")
    data = match.groupdict()
    data["location"] = host
    return data


def _prepare_request_payload(
    entries: List[dict],
    filter_id: str,
    workflow_url: str,
    purge_policy: str = "no",
    predecessor: Optional[int] = None,
    optional_parameters: Optional[Dict[str, Any]] = None,
):
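    """Build the JSON payload for creating a single ATDB task from a batch of surl entries."""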
    # Parse a single surl for project, sas_id and location.
    # This assumes that all entries in a task share the same project, sas_id and location!
    parsed = _parse_surl(entries[0]["surl"])
    project_id = parsed["project"]
    sas_id = parsed["sas_id"]

    inputs = [
        {
            "size": e["size"],
            "surl": e["surl"],
            "type": "File",
            "location": parsed["location"],
        }
        for e in entries
    ]
    if optional_parameters:
        inputs = {**optional_parameters, "surls": inputs}

    data = {
        "project": project_id,
        "sas_id": sas_id,
        "task_type": "regular",
        "filter": filter_id,
        "purge_policy": purge_policy,
        "new_status": "defining",
        "new_workflow_uri": workflow_url,
        "size_to_process": sum([e["size"] for e in entries]),
        "inputs": inputs,
    }

    if predecessor:
        data["predecessor"] = predecessor

    return data


class TokenAuth(AuthBase):
    """Basic Token Auth

    Adds a: `Authorization: Token <token>` header"""
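
    # Typical usage (see insert_task_into_atdb below):
    #   session.post(url, json=payload, auth=TokenAuth(site.access_token))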

    def __init__(self, token: str):
        self._token = token

    def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
        r.headers["Authorization"] = f"Token {self._token}"
        return r


def split_entries_to_batches(entries: List[Any], batch_size: int) -> List[List[Any]]:
    """Split the list of entries into batches of at most `batch_size` entries"""
    n_entries = len(entries)

    # NOTE: Think about using file size instead of amount of files
    batches: List[List[Any]] = []
    if batch_size == 0:
        batches.append(entries)
    elif batch_size > 0:
        # Calculate the number of required batches (ceiling division)
        num_batches = n_entries // batch_size
        num_batches += 1 if n_entries % batch_size else 0

        for n in range(num_batches):
            batches.append(entries[n * batch_size : (n + 1) * batch_size])

    return batches


@app.task
def insert_task_into_atdb(workspecification_id: int):
    """This creates the task in ATDB and set's it to defining"""
    sess = SessionStore.get_session()

    work_spec: WorkSpecification = WorkSpecification.objects.get(
        pk=workspecification_id
    )
    inputs: Dict[str, Any] = work_spec.inputs.copy()
    entries: List[dict] = inputs.pop("surls")

    batches = split_entries_to_batches(entries, work_spec.batch_size)

    site: Optional[ATDBProcessingSite] = work_spec.processing_site
    if site is None:
        raise WorkSpecificationNoSite()
    url = site.url + "tasks/"

    # Task ID of the predecessor
    atdb_predecessor_task_id: Optional[int] = None
    if work_spec.predecessor_specification is not None:
        predecessor: WorkSpecification = work_spec.predecessor_specification
        # related_tasks may be None if the predecessor was never submitted
        if predecessor.related_tasks is None or len(predecessor.related_tasks) != 1:
            logger.error("Workspecification {} has no valid predecessor".format(work_spec.pk))
            raise InvalidPredecessor()
        # Should only be 1 entry
        atdb_predecessor_task_id = predecessor.related_tasks[0]

    try:
        for batch in batches:
            payload = _prepare_request_payload(
                entries=batch,
                optional_parameters=inputs,
                filter_id=f"ldv-spec:{work_spec.pk}",
                workflow_url=work_spec.selected_workflow,
                purge_policy=work_spec.purge_policy,
                predecessor=atdb_predecessor_task_id,
            )

            res = sess.post(url, json=payload, auth=TokenAuth(site.access_token))

            if not res.ok:
                raise RequestNotOk()

            # Store ATDB Task ID in related_tasks
            if work_spec.related_tasks is None:
                work_spec.related_tasks = []

            data = res.json()
            work_spec.related_tasks.append(data["id"])

        # All went well
        work_spec.submission_status = SUBMISSION_STATUS.DEFINING
        if work_spec.is_auto_submit:
            set_tasks_defined.delay(workspecification_id)
    except (RequestException, RequestNotOk):
        work_spec.submission_status = SUBMISSION_STATUS.ERROR
    finally:
        work_spec.save()


def update_related_tasks(
    work_spec: WorkSpecification,
    delete: bool,
    data: Optional[dict],
    on_success_status: SUBMISSION_STATUS,
):
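    """Update (PUT) or delete every ATDB task related to the given work specification.

    On success the specification moves to `on_success_status`; any failed
    request marks it as SUBMISSION_STATUS.ERROR instead.
    """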
    sess = SessionStore.get_session()

    site: Optional[ATDBProcessingSite] = work_spec.processing_site
    if site is None:
        raise WorkSpecificationNoSite()
    url = site.url + "tasks/"

    task_ids: List[int] = work_spec.related_tasks
    try:
        for task_id in task_ids:
            if delete:
                res = sess.delete(
                    url + str(task_id) + "/", auth=TokenAuth(site.access_token)
                )
            else:
                res = sess.put(
                    url + str(task_id) + "/",
                    json=data,
                    auth=TokenAuth(site.access_token),
                )

            if not res.ok:
                raise RequestNotOk()

        # All went well
        work_spec.submission_status = on_success_status
        if delete:
            work_spec.related_tasks = None
    except (RequestException, RequestNotOk):
        work_spec.submission_status = SUBMISSION_STATUS.ERROR
    finally:
        work_spec.save()


@app.task
def set_tasks_defined(workspecification_id: int):
    """This sets tasks to defined so they can be picked up by ATDB services"""

    work_spec: WorkSpecification = WorkSpecification.objects.get(
        pk=workspecification_id
    )

    if work_spec.submission_status != SUBMISSION_STATUS.DEFINING:
        raise ValueError("Invalid WorkSpecification state")

    update_related_tasks(
        work_spec, False, {"new_status": "defined"}, SUBMISSION_STATUS.SUBMITTED
    )


@app.task
def delete_tasks_from_atdb(workspecification_id: int):
    """Removes related tasks from ATDB (for retrying)"""

    work_spec: WorkSpecification = WorkSpecification.objects.get(
        pk=workspecification_id
    )

    update_related_tasks(work_spec, True, None, SUBMISSION_STATUS.NOT_SUBMITTED)