diff --git a/atdb/communication.py b/atdb/communication.py index a7cb8e86070569df1ffd6b2d3733b4f0e244a1b2..68f8d8b02978622f40d2874a7e427b9f87b2a081 100644 --- a/atdb/communication.py +++ b/atdb/communication.py @@ -1,8 +1,9 @@ """ This module is responsible for the communication to and from ATDB """ -from typing import List, Generator from argparse import Namespace +from typing import List, Generator + import requests @@ -122,6 +123,14 @@ class APIConnector: for item in drf_reply.results: yield item + def update_task_processed_size(self, task_id, processed_size): + """ + Change the whole task content + """ + return self._request_path( + "PUT", f"tasks/{task_id}", content={"size_processed": processed_size} + ) + def change_task_status(self, task_id, status) -> None: """ Change the status of a task diff --git a/atdb/fix.py b/atdb/fix.py new file mode 100644 index 0000000000000000000000000000000000000000..efff84aa5293d36fbe0ab22df950e892b2e6416f --- /dev/null +++ b/atdb/fix.py @@ -0,0 +1,71 @@ +""" +Fix command module +""" +import logging +import atdb.communication as com + + +logger = logging.getLogger("fix") + + +def aggregate_on_tree(tree, field): + """ + Aggregated values with a given field name from a dict tree + """ + if isinstance(tree, dict) and field in tree: + return tree[field] + if isinstance(tree, dict): + total = 0 + for value in tree.values(): + total += aggregate_on_tree(value, field) + return total + if isinstance(tree, list): + total = 0 + for item in tree: + total += aggregate_on_tree(item, field) + return total + + return 0 + + +def compute_output_sizes(outputs): + """ + Computes the size of the output files + """ + if outputs is not None: + return aggregate_on_tree( + {key: value for key, value in outputs.items() if key != "ingest"}, "size" + ) + return 0 + + +def fix_computed_sizes(connector, dry_run=True): + """ + Fix the size of the computed task + """ + for task in connector.list_iter("tasks"): + task_id = task["id"] + size_before = task["size_processed"] + total_output_size = compute_output_sizes(task["outputs"]) + task["size_processed"] = total_output_size + if not dry_run: + if size_before != total_output_size: + connector.update_task_processed_size(task_id, total_output_size) + else: + if size_before != total_output_size: + logger.info( + "Dry run: Size updated for %s from %s to %s", + task_id, + size_before, + total_output_size, + ) + + +def fix(args): + """ + Fix command + + Changes task fields to be consistent with each others + """ + connector = com.APIConnector.from_args(args) + fix_computed_sizes(connector, dry_run=args.dry_run) diff --git a/atdb/main.py b/atdb/main.py index 2268a3a9ed2a2457cc5aae4d26636e175ac674c0..1637e3dd66a06cb4e7f9e5e832c091042f9a40f5 100644 --- a/atdb/main.py +++ b/atdb/main.py @@ -8,6 +8,7 @@ from argparse import ArgumentParser, Namespace from configparser import ConfigParser from atdb.prune import prune +from atdb.fix import fix DEFAULT_PATH = os.path.expanduser("~/.config/ldv/services.cfg") logging.basicConfig( @@ -61,6 +62,9 @@ def parse_args() -> (Namespace, ArgumentParser): prune_parser = subparser.add_parser("prune") prune_parser.add_argument("--workflow_id", help="Filters by workflow id") prune_parser.add_argument("--status", help="Filter by status") + + _ = subparser.add_parser("fix") + return parser.parse_args(), parser @@ -75,6 +79,8 @@ def main(): if args.operation == "prune": prune(args) + elif args.operation == "fix": + fix(args) else: parser.print_help()