diff --git a/atdb/communication.py b/atdb/communication.py index 427d5859d004549c2360d55562856d8f79f68812..86625ae1ab905c76e9a1a445b41c4d3133715d12 100644 --- a/atdb/communication.py +++ b/atdb/communication.py @@ -1,7 +1,17 @@ +""" +This module is responsible for the communication to and from ATDB +""" +from typing import List, Generator +from argparse import Namespace import requests class APIError(Exception): + """ + The APIError is an exception which is raised when the communication with the ATDB API + fails or the requested operation fails + """ + def __init__(self, reply: requests.Response): status_code = reply.status_code reason = reply.reason @@ -14,47 +24,76 @@ class APIError(Exception): class DRFReply: + """ + A class to represent the DJANGO REST framework reply + """ + def __init__(self, response): self._content = response.json() @property - def n_items(self): + def n_items(self) -> int: + """ + Returns the number of items in the DRF reply + """ return self._content["count"] @property - def results(self): + def results(self) -> List[dict]: + """ + Access to the results list + """ return self._content["results"] @property def next_page_url(self): + """ + Access to the next page if the results are paginated + """ return self._content["next"] @property def previous_page_url(self): + """ + Access to the previous page of results if results are paginated + """ return self._content["previous"] class APIConnector: + """ + A class to represent the connection to the API + """ + def __init__(self, url, token): self._url = url.rstrip("/") self._token = token self._session = None @staticmethod - def from_args(args): + def from_args(args: Namespace): + """ + Creates API connector from command line arguments + """ return APIConnector(args.atdb_site, args.token) - def session(self): + def session(self) -> requests.Session: + """ + Returns a http session object and creates if it is not initialized + """ if self._session is None: self._session = self.start_session() return self._session - def start_session(self): - s = requests.Session() - s.headers["Authorization"] = f"Token {self._token}" - s.headers["content-type"] = "application/json" - s.headers["cache-control"] = "no-cache" - return s + def start_session(self) -> requests.Session: + """ + Start a session + """ + session_instance = requests.Session() + session_instance.headers["Authorization"] = f"Token {self._token}" + session_instance.headers["content-type"] = "application/json" + session_instance.headers["cache-control"] = "no-cache" + return session_instance def _request_url(self, method, url, query=None, content=None): url = url.replace("http://", "https://") @@ -67,17 +106,24 @@ class APIConnector: url = "/".join((self._url, item.lstrip("/"), "")) return self._request_url(method, url, query=query, content=content) - def list_iter(self, item, query=None): - response = self._request_path("get", item, query=query) + def list_iter(self, object_type, query=None) -> Generator[dict]: + """ + Returns a list iterator to a specific object_type in the REST API + + """ + response = self._request_path("get", object_type, query=query) drf_reply = DRFReply(response) - for r in drf_reply.results: - yield r + for item in drf_reply.results: + yield item while drf_reply.next_page_url is not None: drf_reply = DRFReply(self._request_url("get", drf_reply.next_page_url)) - for r in drf_reply.results: - yield r + for item in drf_reply.results: + yield item - def change_task_status(self, task_id, status): + def change_task_status(self, task_id, status) -> None: + """ + Change the status of a task + """ self._request_path("PUT", f"tasks/{task_id}", content={"new_status": status}) diff --git a/atdb/main.py b/atdb/main.py index 27daf908e1e4dfc052724e3f45df93a106384fa2..cf3734acab0b2786f04e27198b18065e1fae9316 100644 --- a/atdb/main.py +++ b/atdb/main.py @@ -1,10 +1,15 @@ +""" +Main entry point for command line script +""" + import logging import os -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from configparser import ConfigParser from atdb.prune import prune +DEFAULT_PATH = os.path.expanduser("~/.config/ldv/services.cfg") logging.basicConfig( level=logging.DEBUG, format="%(asctime)s %(levelname)s : %(message)s" ) @@ -12,8 +17,10 @@ logging.basicConfig( logger = logging.getLogger("atdb_mngr") -def read_conf_file(args, additional_location=None): - DEFAULT_PATH = os.path.expanduser("~/.config/ldv/services.cfg") +def read_conf_file(args: Namespace, additional_location=None): + """ + Reads configuration files and append results to args namespace + """ parser = ConfigParser() if additional_location is not None: parser.read(additional_location) @@ -28,7 +35,10 @@ def read_conf_file(args, additional_location=None): return args -def parse_args(): +def parse_args() -> Namespace: + """ + Parse command line arguments + """ parser = ArgumentParser(description="ATDB management tool") parser.add_argument( "--atdb_site", help="ATDB url", default="https://sdc-dev.astron.nl:5554/atdb" @@ -53,6 +63,9 @@ def parse_args(): def main(): + """ + Main entry point + """ args, parser = parse_args() args = read_conf_file(args, args.config) if args.v: diff --git a/atdb/prune.py b/atdb/prune.py index 2134c5b2142bfc47539c0e4bcd09134cb354128f..847debf1bdd7f049f50ac3a335906aaed33f7334 100644 --- a/atdb/prune.py +++ b/atdb/prune.py @@ -1,4 +1,8 @@ +""" +Prune command module +""" import logging +from typing import Union, List import gfal2 @@ -7,7 +11,11 @@ from atdb.communication import APIConnector logger = logging.getLogger("prune") -def extract_surls_from_obj(obj, partial=None): +def extract_surls_from_obj(obj: Union[dict, list], partial: List[Union[dict, list]] = None) -> List[ + Union[dict, list]]: + """ + Iterate over a nested object to extract surl values + """ if partial is None: partial = [] @@ -15,23 +23,27 @@ def extract_surls_from_obj(obj, partial=None): if isinstance(obj, dict) and "surl" in obj: partial.append(obj["surl"]) elif isinstance(obj, dict): - for key, value in obj.items(): + for value in obj.values(): extract_surls_from_obj(value, partial=partial) - elif isinstance(obj, list) or isinstance(obj, tuple): + elif isinstance(obj, (list, tuple)): for value in obj: extract_surls_from_obj(value, partial=partial) - except Exception as e: - logging.exception(e) - print(obj, partial) - raise SystemExit(1) + except KeyError as exception: + logging.exception(exception) return partial -def extract_task_surls_from_field(item, field_name): +def extract_task_surls_from_field(item: dict, field_name: str) -> List[Union[dict, list]]: + """ + Extract from task object field the surl + """ return extract_surls_from_obj(item[field_name]) -def remove_surl_locations(surls, dry_run=False): +def remove_surl_locations(surls: List[str], dry_run=False) -> None: + """ + Removes SURL location if dry_run is specified it only tests + """ context = gfal2.creat_context() for surl in surls: @@ -43,6 +55,9 @@ def remove_surl_locations(surls, dry_run=False): def prune(args): + """ + Prune command entry point + """ connector = APIConnector.from_args(args) workflow_id = args.workflow_id status = args.status diff --git a/tests/test_prune.py b/tests/test_prune.py index 06f7b90df8861ad645af21bd68cba1d98e3edd58..bd0f9eae52fb7d596b8f66c24c03264998fd0577 100644 --- a/tests/test_prune.py +++ b/tests/test_prune.py @@ -1,3 +1,6 @@ +""" +Test prune command +""" from unittest import TestCase from atdb.prune import extract_surls_from_obj @@ -6,6 +9,9 @@ class TestPruneUtils(TestCase): """Test Case of the prune utility functions""" def test_surl_filtering(self): + """ + Test surl filtering utility function + """ test_data = { "item": [{"surl": "onesurl"}, 1, ["ciao", {"surl": "another_surl"}]], "item2": {"surl": "third_surl"}, diff --git a/tox.ini b/tox.ini index c0530645e8f647ea7d94921271c1227dcd3522b0..49c39c0fb289bfa445fb8f21b1908275018fe1fc 100644 --- a/tox.ini +++ b/tox.ini @@ -21,7 +21,7 @@ commands = [testenv:coverage] commands = {envpython} --version - {envpython} -m pytest --cov-report xml --cov-report html --cov=map + {envpython} -m pytest --cov-report xml --cov-report html --cov=atdb # Use generative name and command prefixes to reuse the same virtualenv # for all linting jobs. @@ -34,8 +34,8 @@ commands = black: {envpython} -m black --version black: {envpython} -m black --check --diff . pylint: {envpython} -m pylint --version - pylint: {envpython} -m pylint map tests - format: {envpython} -m autopep8 -v -aa --in-place --recursive map + pylint: {envpython} -m pylint atdb tests --extension-pkg-allow-list=gfal2 + format: {envpython} -m autopep8 -v -aa --in-place --recursive atdb format: {envpython} -m autopep8 -v -aa --in-place --recursive tests format: {envpython} -m black -v .