#!/usr/bin/env python

# mom.py
#
# Copyright (C) 2015
# ASTRON (Netherlands Institute for Radio Astronomy)
# P.O.Box 2, 7990 AA Dwingeloo, The Netherlands
#
# This file is part of the LOFAR software suite.
# The LOFAR software suite is free software: you can redistribute it
# and/or modify it under the terms of the GNU General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# The LOFAR software suite is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with the LOFAR software suite. If not, see <http://www.gnu.org/licenses/>.
#
# $Id: mom.py 1580 2015-09-30 14:18:57Z loose $

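"""
Helpers to enrich task dicts with disk-usage (storage) details,
which are fetched through an RPC client passed in by the caller.
"""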

import logging

logger = logging.getLogger(__name__)

def updateTaskStorageDetails(task, sqrpc):
    '''Fill in the disk-usage fields of one task dict, or a list of task dicts, in place.

    Every given task first gets 'disk_usage' and 'disk_usage_readable' reset to
    None. Then the tasks on the 'CEP4' cluster whose status is 'finished',
    'completing' or 'aborted' are selected. Their usage is fetched with a single
    sqrpc.getDiskUsageForTasks(...) call and copied into the matching dicts.

    :param task: a task dict, or a list of task dicts. Each selected dict must
                 have an 'otdb_id'; 'cluster' and 'status' are used for selection.
    :param sqrpc: client object providing
                  getDiskUsageForTasks(otdb_ids=..., include_scratch_paths=...).
                  If it is falsy, only the None defaults are applied.
    :return: None; the task dicts are modified in place.

    Any exception raised while querying or processing the usages is logged
    (with traceback) and swallowed. Tasks that were not updated keep their
    None defaults.
    '''
    tasklist = task if isinstance(task, list) else [task]

    # Apply sane defaults first, so callers always find both keys,
    # even when no usage can be determined.
    for t in tasklist:
        t['disk_usage'] = None
        t['disk_usage_readable'] = None

    if not sqrpc:
        return

    statuses = {'finished', 'completing', 'aborted'}
    selected = [t for t in tasklist
                if t.get('cluster') == 'CEP4' and t.get('status') in statuses]

    if not selected:
        return

    try:
        otdb_ids = [t['otdb_id'] for t in selected]
        usages = sqrpc.getDiskUsageForTasks(otdb_ids=otdb_ids,
                                            include_scratch_paths=False).get('otdb_ids')

        if not usages:
            return

        for selected_task in selected:
            # The result is looked up by the stringified otdb_id.
            usage = usages.get(str(selected_task['otdb_id']))
            if usage:
                # Write into the task being iterated. The old code wrote to a
                # stale loop variable, so the usage ended up on the wrong dict.
                selected_task['disk_usage'] = usage.get('disk_usage')
                selected_task['disk_usage_readable'] = usage.get('disk_usage_readable')

    except Exception as e:
        # Best-effort enrichment: log with traceback and keep the defaults.
        logger.exception('%s', e)