import re
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

import requests
from lofardata.models import (
SUBMISSION_STATUS,
ATDBProcessingSite,
DataProduct,
WorkSpecification,
)
from requests.auth import AuthBase
from requests.exceptions import RequestException
from celery.utils.log import get_task_logger

from ldvspec.celery import app

logger = get_task_logger(__name__)

class RequestNotOk(Exception):
    pass


class WorkSpecificationNoSite(Exception):
    pass


class InvalidPredecessor(Exception):
    """Raised when a predecessor specification does not have exactly one ATDB task."""

    pass

class SessionStore:
    """requests.Session singleton"""

    _session = None

    @classmethod
    def get_session(cls) -> requests.Session:
        if cls._session is None:
            cls._session = requests.Session()
        return cls._session
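
# Usage sketch: every task in this module shares a single Session, so HTTP
# connections to the ATDB host are pooled and reused within a worker process:
#
#   sess = SessionStore.get_session()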
@app.task
def define_work_specification(workspecification_id):
specification = WorkSpecification.objects.get(pk=workspecification_id)
filters = specification.filters
dataproducts = DataProduct.objects.filter(**filters).order_by("surl")
inputs = {
"surls": [
{"surl": dataproduct.surl, "size": dataproduct.filesize}
for dataproduct in dataproducts
]
}
    if specification.inputs is None:
        specification.inputs = inputs
    else:
        specification.inputs.update(inputs)
    # Persist the gathered inputs; without this the task has no effect.
    specification.save()
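
# After this task runs, WorkSpecification.inputs holds (values hypothetical):
#
#   {"surls": [{"surl": "srm://...", "size": 123456}, ...]}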
def _parse_surl(surl: str) -> dict:
parsed = urlparse(surl)
host = parsed.hostname
path = parsed.path
pattern = r"^.*/projects\/(?P<project>.*_\d*)\/(?P<sas_id>\w\d*)\/"
data = re.match(pattern, path).groupdict()
data["location"] = host
return data
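
# Illustration of what _parse_surl extracts; the SURL below is hypothetical
# but follows the typical LOFAR layout this regex expects:
#
#   _parse_surl("srm://srm.grid.sara.nl:8443/pnfs/grid.sara.nl/data/lofar"
#               "/ops/projects/LC0_001/123456/some_file.tar")
#   # -> {"project": "LC0_001", "sas_id": "123456",
#   #     "location": "srm.grid.sara.nl"}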
def _prepare_request_payload(
    entries: List[dict],
    filter_id: str,
    workflow_url: str,
    purge_policy: str = "no",
    predecessor: Optional[int] = None,
    optional_parameters: Optional[Dict[str, Any]] = None,
):
    # Parse a single surl for its project, sas_id and location.
    # This assumes that a task spans at most one project, SAS ID and location!
parsed = _parse_surl(entries[0]["surl"])
project_id = parsed["project"]
sas_id = parsed["sas_id"]
inputs = [
{
"size": e["size"],
"surl": e["surl"],
"type": "File",
"location": parsed["location"],
}
for e in entries
]
data = {
"project": project_id,
"sas_id": sas_id,
"task_type": "regular",
"filter": filter_id,
"purge_policy": purge_policy,
"new_status": "defining",
"new_workflow_uri": workflow_url,
"size_to_process": sum([e["size"] for e in entries]),
"inputs": inputs,
}
if predecessor:
data["predecessor"] = predecessor
return data
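
# For a single-entry batch the resulting ATDB payload looks roughly like this
# (all values hypothetical):
#
#   {
#       "project": "LC0_001",
#       "sas_id": "123456",
#       "task_type": "regular",
#       "filter": "ldv-spec:1",
#       "purge_policy": "no",
#       "new_status": "defining",
#       "new_workflow_uri": "https://.../workflows/1",
#       "size_to_process": 123456,
#       "inputs": [{"size": 123456, "surl": "srm://...", "type": "File",
#                   "location": "srm.grid.sara.nl"}],
#   }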
class TokenAuth(AuthBase):
    """Basic token auth.

    Adds an `Authorization: Token <token>` header to every request."""
def __init__(self, token: str):
self._token = token
def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
r.headers["Authorization"] = f"Token {self._token}"
return r
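
# Usage sketch (token value hypothetical): requests invokes the TokenAuth
# instance on each prepared request and attaches the header before sending:
#
#   sess.get(site.url + "tasks/", auth=TokenAuth("s3cr3t-t0ken"))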
def split_entries_to_batches(entries: List[Any], batch_size: int) -> List[List[Any]]:
"""Split the list of entries into batches of at most `batch_size` entries"""
n_entries = len(entries)
    # NOTE: consider batching by total file size instead of the number of files
batches: List[List[Any]] = []
if batch_size == 0:
batches.append(entries)
elif batch_size > 0:
# Calculate amount of required batches
num_batches = n_entries // batch_size
num_batches += 1 if n_entries % batch_size else 0
for n in range(num_batches):
batches.append(entries[n * batch_size : (n + 1) * batch_size])
return batches
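
# Behaviour sketch:
#
#   split_entries_to_batches([1, 2, 3, 4, 5], 2)  # -> [[1, 2], [3, 4], [5]]
#   split_entries_to_batches([1, 2, 3], 0)        # -> [[1, 2, 3]]
#
# A negative batch_size falls through both branches and yields no batches.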
@app.task
def insert_task_into_atdb(workspecification_id: int):
"""This creates the task in ATDB and set's it to defining"""
sess = SessionStore.get_session()
work_spec: WorkSpecification = WorkSpecification.objects.get(
pk=workspecification_id
)
    inputs: Dict[str, Any] = work_spec.inputs.copy()
    entries: List[dict] = inputs["surls"]
    batches = split_entries_to_batches(entries, work_spec.batch_size)
site: Optional[ATDBProcessingSite] = work_spec.processing_site
if site is None:
raise WorkSpecificationNoSite()
url = site.url + "tasks/"
# Task ID of the predecessor
    atdb_predecessor_task_id: Optional[int] = None
if work_spec.predecessor_specification is not None:
predecessor: WorkSpecification = work_spec.predecessor_specification
        if not predecessor.related_tasks or len(predecessor.related_tasks) != 1:
            logger.error("WorkSpecification %s has no valid predecessor", work_spec.pk)
            raise InvalidPredecessor()
# Should only be 1 entry
atdb_predecessor_task_id = predecessor.related_tasks[0]
try:
for batch in batches:
            payload = _prepare_request_payload(
                entries=batch,
                filter_id=f"ldv-spec:{work_spec.pk}",
                workflow_url=work_spec.selected_workflow,
                purge_policy=work_spec.purge_policy,
                predecessor=atdb_predecessor_task_id,
            )
res = sess.post(url, json=payload, auth=TokenAuth(site.access_token))
if not res.ok:
raise RequestNotOk()
# Store ATDB Task ID in related_tasks
if work_spec.related_tasks is None:
work_spec.related_tasks = []
data = res.json()
work_spec.related_tasks.append(data["id"])
# All went well
work_spec.submission_status = SUBMISSION_STATUS.DEFINING
if work_spec.is_auto_submit:
set_tasks_defined.delay(workspecification_id)
except (RequestException, RequestNotOk):
work_spec.submission_status = SUBMISSION_STATUS.ERROR
finally:
work_spec.save()
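
# Typically queued from the web layer rather than called directly, e.g.:
#
#   insert_task_into_atdb.delay(work_spec.pk)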
def update_related_tasks(
work_spec: WorkSpecification,
delete: bool,
data: Optional[dict],
on_success_status: SUBMISSION_STATUS,
):
sess = SessionStore.get_session()
site: Optional[ATDBProcessingSite] = work_spec.processing_site
if site is None:
raise WorkSpecificationNoSite()
url = site.url + "tasks/"
task_ids: List[int] = work_spec.related_tasks
try:
for task_id in task_ids:
if delete:
res = sess.delete(
url + str(task_id) + "/", auth=TokenAuth(site.access_token)
)
else:
res = sess.put(
url + str(task_id) + "/",
json=data,
auth=TokenAuth(site.access_token),
)
if not res.ok:
raise RequestNotOk()
# All went well
work_spec.submission_status = on_success_status
if delete:
work_spec.related_tasks = None
except (RequestException, RequestNotOk):
work_spec.submission_status = SUBMISSION_STATUS.ERROR
finally:
work_spec.save()
@app.task
def set_tasks_defined(workspecification_id: int):
"""This sets tasks to defined so they can be picked up by ATDB services"""
work_spec: WorkSpecification = WorkSpecification.objects.get(
pk=workspecification_id
)
if work_spec.submission_status != SUBMISSION_STATUS.DEFINING:
raise ValueError("Invalid WorkSpecification state")
update_related_tasks(
work_spec, False, {"new_status": "defined"}, SUBMISSION_STATUS.SUBMITTED
)
@app.task
def delete_tasks_from_atdb(workspecification_id: int):
"""Removes related tasks from ATDB (for retrying)"""
work_spec: WorkSpecification = WorkSpecification.objects.get(
pk=workspecification_id
)
update_related_tasks(work_spec, True, None, SUBMISSION_STATUS.NOT_SUBMITTED)
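
# Retry flow sketch: remove the stale ATDB tasks, then resubmit the
# specification once its status is back to NOT_SUBMITTED:
#
#   delete_tasks_from_atdb.delay(work_spec.pk)
#   insert_task_into_atdb.delay(work_spec.pk)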