# image_pipeline_estimator.py
#
# Copyright (C) 2016, 2017
# ASTRON (Netherlands Institute for Radio Astronomy)
# P.O.Box 2, 7990 AA Dwingeloo, The Netherlands
#
# This file is part of the LOFAR software suite.
# The LOFAR software suite is free software: you can redistribute it
# and/or modify it under the terms of the GNU General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# The LOFAR software suite is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with the LOFAR software suite. If not, see <http://www.gnu.org/licenses/>.
#
# $Id: base_resource_estimator.py 33534 2016-02-08 14:28:26Z schaap $

import logging
from math import ceil
from base_pipeline_estimator import BasePipelineResourceEstimator
from lofar.parameterset import parameterset

logger = logging.getLogger(__name__)

DATAPRODUCTS = "Observation.DataProducts."
PIPELINE = "Observation.ObservationControl.PythonControl."

#Observation.DataProducts.Output_Correlated.storageClusterName=
#Observation.ObservationControl.PythonControl.AWimager

class ImagePipelineResourceEstimator(BasePipelineResourceEstimator):
    """ ResourceEstimator for Imaging Pipelines
    """
    def __init__(self):
        logger.info("init ImagePipelineResourceEstimator")
        BasePipelineResourceEstimator.__init__(self, name='pipeline') #FIXME name='imaging_pipeline'
        self.required_keys = ('Observation.startTime',
                              'Observation.stopTime',
                              DATAPRODUCTS + 'Input_Correlated.enabled',
                              DATAPRODUCTS + 'Input_Correlated.identifications',
                              #DATAPRODUCTS + 'Input_Correlated.storageClusterName',  # TODO: also add input estimates
                              DATAPRODUCTS + 'Output_SkyImage.enabled',
                              DATAPRODUCTS + 'Output_SkyImage.identifications',
                              DATAPRODUCTS + 'Output_SkyImage.storageClusterName',
                              PIPELINE + 'Imaging.slices_per_image',
                              PIPELINE + 'Imaging.subbands_per_image')
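        # Purely illustrative sketch of a parset fragment that satisfies the
        # required_keys above. The key names come from required_keys; all values
        # (times, identifications, cluster name, counts) are made-up examples,
        # not taken from a real specification:
        #
        #   Observation.startTime = 2016-01-01 12:00:00
        #   Observation.stopTime  = 2016-01-01 13:00:00
        #   Observation.DataProducts.Input_Correlated.enabled = true
        #   Observation.DataProducts.Input_Correlated.identifications = [mom.G123.B0.1.C.SAP000.uv.dps]
        #   Observation.DataProducts.Output_SkyImage.enabled = true
        #   Observation.DataProducts.Output_SkyImage.identifications = [mom.G123.B0.1.C.SAP000.img.dps]
        #   Observation.DataProducts.Output_SkyImage.storageClusterName = CEP4
        #   Observation.ObservationControl.PythonControl.Imaging.slices_per_image = 1
        #   Observation.ObservationControl.PythonControl.Imaging.subbands_per_image = 10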

    def _calculate(self, parset, input_files):
        """ Estimate for Imaging Pipeline. Also gets used for MSSS Imaging Pipeline
        calculates: datasize (number of files, file size), bandwidth
        input_files should look something like:
TODO
        'input_files': 
        {'uv': {'nr_of_uv_files': 481, 'uv_file_size': 1482951104}, ...}
        
        reply is something along the lines of:
        {'bandwidth': {'total_size': 19021319494},
        'storage': {'total_size': 713299481024,
        'output_files': 
          {'img': {'nr_of_img_files': 481, 'img_file_size': 148295}
        }}
        """
        logger.debug("start estimate '{}'".format(self.name))
        logger.info('parset: %s ' % parset)
        result = {'errors': [], 'estimates': [{}]}  # can all be described in 1 estimate here
#TODO: really? What if input_files from output_files which has a list of len > 1?
#parset.getString(DATAPRODUCTS + 'Input_Correlated.storageClusterName')
        identifications = parset.getStringVector(DATAPRODUCTS + 'Input_Correlated.identifications')
        input_files = self._filterInputs(input_files, identifications)
        result['estimates'][0]['input_files'] = input_files

        duration = self._getDuration(parset.getString('Observation.startTime'),
                                     parset.getString('Observation.stopTime'))
        slices_per_image = parset.getInt(PIPELINE + 'Imaging.slices_per_image', 0)  # TODO: should these have defaults?
        subbands_per_image = parset.getInt(PIPELINE + 'Imaging.subbands_per_image', 0)

        if not parset.getBool(DATAPRODUCTS + 'Output_SkyImage.enabled'):
            logger.error('Output_SkyImage is not enabled')
            result['errors'].append('Output_SkyImage is not enabled')
        if 'uv' not in input_files:
            logger.error('Missing UV Dataproducts in input_files')
            result['errors'].append('Missing UV Dataproducts in input_files')
        if not slices_per_image or not subbands_per_image:
            logger.error('slices_per_image or subbands_per_image is not valid')
            result['errors'].append('slices_per_image or subbands_per_image is not valid')
        if result['errors']:
            return result  # bail out before using values that failed validation

        nr_input_subbands = input_files['uv']['nr_of_uv_files']
        if nr_input_subbands % (subbands_per_image * slices_per_image) > 0:
            logger.error('number of input subbands is not a multiple of subbands_per_image * slices_per_image')
            result['errors'].append('number of input subbands is not a multiple of subbands_per_image * slices_per_image')
            return result

        logger.debug("calculate sky image data size")
        result['estimates'][0]['output_files'] = {}
        nr_images = nr_input_subbands / (subbands_per_image * slices_per_image)
        result['estimates'][0]['output_files']['img'] = {'nr_of_img_files': nr_images,
                                                         'img_file_size': 1000,  # 1 kB was hardcoded in the Scheduler
                                                         'identifications': parset.getStringVector(DATAPRODUCTS + 'Output_SkyImage.identifications')}
        logger.info("sky_images: {} files {} bytes each".format(result['estimates'][0]['output_files']['img']['nr_of_img_files'],
                                                                result['estimates'][0]['output_files']['img']['img_file_size']))

        # count total data size
        total_data_size = result['estimates'][0]['output_files']['img']['nr_of_img_files'] * \
                          result['estimates'][0]['output_files']['img']['img_file_size'] # bytes
        total_bandwidth = int(ceil(total_data_size * 8 / duration))  # bits/second
        if total_data_size:
            result['resource_types'] = {'bandwidth': total_bandwidth, 'storage': total_data_size}
            result['resource_count'] = 1
            result['root_resource_group'] = parset.getString(DATAPRODUCTS + 'Output_SkyImage.storageClusterName')
        else:
            result['errors'].append('Total data size is zero!')
            logger.error('A data size of zero was calculated!')
        return result
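

if __name__ == '__main__':
    # Minimal usage sketch for local experimentation, assuming a full LOFAR
    # Python environment (lofar.parameterset and base_pipeline_estimator on the
    # PYTHONPATH). All key values, identifications and file sizes below are
    # made up, and the input_files shape simply follows the docstring example;
    # the exact structure expected by the base class _filterInputs() may differ.
    logging.basicConfig(level=logging.DEBUG)

    example_parset = parameterset()
    for key, value in {'Observation.startTime': '2016-01-01 12:00:00',
                       'Observation.stopTime': '2016-01-01 13:00:00',
                       DATAPRODUCTS + 'Input_Correlated.enabled': 'true',
                       DATAPRODUCTS + 'Input_Correlated.identifications': '[mom.G123.B0.1.C.SAP000.uv.dps]',
                       DATAPRODUCTS + 'Output_SkyImage.enabled': 'true',
                       DATAPRODUCTS + 'Output_SkyImage.identifications': '[mom.G123.B0.1.C.SAP000.img.dps]',
                       DATAPRODUCTS + 'Output_SkyImage.storageClusterName': 'CEP4',
                       PIPELINE + 'Imaging.slices_per_image': '1',
                       PIPELINE + 'Imaging.subbands_per_image': '10'}.items():
        example_parset.add(key, value)

    # Hypothetical input estimate, shaped like the docstring example above.
    example_input_files = {'uv': {'nr_of_uv_files': 480, 'uv_file_size': 1482951104,
                                  'identifications': ['mom.G123.B0.1.C.SAP000.uv.dps']}}

    estimator = ImagePipelineResourceEstimator()
    print(estimator._calculate(example_parset, example_input_files))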