averaging.py
    import h5py
    from argparse import ArgumentParser
    import os
    from typing import Dict, Optional, Iterable, Any, Union, ByteString, AnyStr
    from datetime import datetime, timedelta
    import json
    import numpy
    import logging
    from astropy.coordinates import EarthLocation, SkyCoord, AltAz
    from astropy.time import Time
    import astropy.io.fits as fits
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    
    logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s', level=logging.INFO)
    
    
    def decode_str(str_or_byteslike: Union[ByteString, AnyStr]):
        try:
            return str_or_byteslike.decode()
        except (UnicodeDecodeError, AttributeError):
            return str_or_byteslike
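
    # Example: decode_str(b'LBA_OUTER') -> 'LBA_OUTER'; non-bytes values such as
    # plain str or numbers pass through unchanged.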
    
    
    _ROOT_SELECTED_FIELDS = (
        "ANTENNA_SET",
        "CHANNELS_PER_SUBANDS",
        "CHANNEL_WIDTH",
        "CHANNEL_WIDTH_UNIT",
        "FILEDATE",
        "FILENAME",
        "FILETYPE",
        "FILTER_SELECTION",
        "NOF_SAMPLES",
        "NOTES",
        "OBSERVATION_END_MJD",
        "OBSERVATION_END_UTC",
        "OBSERVATION_ID",
        "OBSERVATION_NOF_BITS_PER_SAMPLE",
        "OBSERVATION_START_MJD",
        "OBSERVATION_START_UTC",
        "POINT_DEC",
        "POINT_RA",
        "PRIMARY_POINTING_DIAMETER",
        "PROJECT_ID",
        "SAMPLING_RATE",
        "SAMPLING_RATE_UNIT",
        "SAMPLING_TIME",
        "SAMPLING_TIME_UNIT",
        "SUBBAND_WIDTH",
        "SUBBAND_WIDTH_UNIT",
        "TARGET",
        "TELESCOPE",
        "TOTAL_BAND_WIDTH",
        "TOTAL_INTEGRATION_TIME",
        "TOTAL_INTEGRATION_TIME_UNIT",
    
    )
    
    _BEAM_FIELDS = (
        "BEAM_FREQUENCY_CENTER",
        "BEAM_FREQUENCY_MAX",
        "BEAM_FREQUENCY_MIN",
        "BEAM_FREQUENCY_UNIT",
        "BEAM_STATIONS_LIST",
        "DYNSPEC_BANDWIDTH",
        "POINT_DEC",
        "POINT_RA",
        "point_start_azimuth",
        "point_end_azimuth",
        "point_start_elevation",
        "point_end_elevation",
        "SIGNAL_SUM",
        "STOCKES_COMPONENT",
        "TARGET",
        "TRACKING",
        "REF_LOCATION_FRAME",
        "REF_LOCATION_UNIT",
        "REF_LOCATION_VALUE",
        "REF_TIME_FRAME",
        "REF_TIME_UNIT",
        "REF_TIME_VALUE"
    )
    
    
    def mjd_to_datetime(mjd) -> datetime:
        return Time(mjd, format='mjd').to_datetime()
    
    
    class SmartJsonEncoder(json.JSONEncoder):
        """JSON encoder for the numpy and datetime values found in the metadata."""

        def default(self, o: Any) -> Any:
            if isinstance(o, numpy.integer):
                return int(o)
            elif isinstance(o, numpy.floating):
                return float(o)
            elif isinstance(o, numpy.ndarray):
                return o.tolist()
            elif isinstance(o, datetime):
                return o.isoformat()
            return super().default(o)
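
    # Usage sketch: json.dumps({'n': numpy.int32(3), 'ax': numpy.arange(2)},
    #                          cls=SmartJsonEncoder) == '{"n": 3, "ax": [0, 1]}'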
    
    
    def parse_args():
        parser = ArgumentParser(description='Scintillation averaging script')
        parser.add_argument('scintillation_dataset', help='Scintillation dataset [e.g. Dynspec_rebinned_L271905_SAP000.h5]')
        parser.add_argument('output_directory', help='Output directory')
    
        return parser.parse_args()
    
    
    def open_dataset(path):
        if not os.path.exists(path):
            raise FileNotFoundError(f'Cannot find file at {path}')
        return h5py.File(path, mode='r')
    
    
    def copy_attrs_to_dict(h5_leaf, dict_container: Optional[Dict] = None,
                           exclude_fields: Optional[Iterable] = None,
                           include_fields: Optional[Iterable] = None):
        exclude_fields_set = set(exclude_fields) if exclude_fields else None
        if dict_container is None:
            dict_container = {}
    
        for key, value in h5_leaf.attrs.items():
            if include_fields is not None and key not in include_fields:
                continue
            if exclude_fields_set and key in exclude_fields_set:
                continue
    
            if isinstance(value, datetime):
                dict_container[key] = value.isoformat()
            elif isinstance(value, (list, tuple, numpy.ndarray)):
                dict_container[key] = list(map(decode_str, value))
            else:
                dict_container[key] = decode_str(value)
        return dict_container
    
    
    def parse_datetime_str(datetime_str):
        return Time(datetime_str.split(' ')[0], format='isot', scale='utc').to_datetime()
    
    
    def extract_root_metadata(dataset):
        metadata = dict()
        copy_attrs_to_dict(dataset['/'], metadata, include_fields=_ROOT_SELECTED_FIELDS)
        metadata['OBSERVATION_START_UTC'] = metadata['OBSERVATION_START_UTC'].split(' ')[0]
        metadata['OBSERVATION_END_UTC'] = metadata['OBSERVATION_END_UTC'].split(' ')[0]
        return metadata
    
    
    def extract_coordinates_metadata(dataset):
        coordinates = copy_attrs_to_dict(dataset)
        for item in dataset:
            copy_attrs_to_dict(dataset[item], coordinates)
        return coordinates
    
    
    def extract_dynspec(dataset):
        dynspec_root = {}
        copy_attrs_to_dict(dataset, dynspec_root)
        for item in dataset:
            if item == 'COORDINATES':
                coordinates = extract_coordinates_metadata(dataset[item])
                dynspec_root.update(coordinates)
            else:
                dynspec_root.update(copy_attrs_to_dict(dataset[item]))
    
        dynspec_root = {key: value for key, value in dynspec_root.items() if key in _BEAM_FIELDS}
        return dynspec_root
    
    
    def compute_start_end_azimuth_elevation(metadata):
        """Add the pointing's azimuth and elevation at observation start and end
        (keys point_start_*/point_end_*) to the metadata dict in place."""
        x, y, z = metadata['REF_LOCATION_VALUE']
        unit, *_ = metadata['REF_LOCATION_UNIT']
        ra, dec = metadata['POINT_RA'], metadata['POINT_DEC']
        start_time = Time(metadata['OBSERVATION_START_UTC'].split(' ')[0], format='isot', scale='utc')
        end_time = Time(metadata['OBSERVATION_END_UTC'].split(' ')[0], format='isot', scale='utc')

        location = EarthLocation(x=x, y=y, z=z, unit=unit)
        pointing = SkyCoord(ra=ra, dec=dec, unit='deg')
        start_altaz, end_altaz = pointing.transform_to(
            AltAz(obstime=[start_time, end_time], location=location))
        metadata['point_start_azimuth'] = start_altaz.az.deg
        metadata['point_start_elevation'] = start_altaz.alt.deg
    
        metadata['point_end_azimuth'] = end_altaz.az.deg
        metadata['point_end_elevation'] = end_altaz.alt.deg
        return metadata
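
    # Usage note: the metadata dict must already hold REF_LOCATION_VALUE (geocentric
    # x/y/z), REF_LOCATION_UNIT, POINT_RA/POINT_DEC in degrees, and the
    # OBSERVATION_*_UTC strings -- i.e. the fields collected by extract_root_metadata
    # and extract_dynspec above.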
    
    
    def extract_metadata(dataset):
        root_metadata = extract_root_metadata(dataset)
        metadata_per_dynspec = {}
        for dynspec in dataset['/']:
            if not dynspec.startswith('DYNSPEC'):
                continue
            metadata_per_dynspec[dynspec] = extract_dynspec(dataset[dynspec])
            metadata_per_dynspec[dynspec].update(root_metadata)
            compute_start_end_azimuth_elevation(metadata_per_dynspec[dynspec])
    
        return metadata_per_dynspec
    
    
    def create_fits_from_dataset(sample_info, data_array, output_path):
        start_datetime = sample_info['sample_start_datetime']
        end_datetime = sample_info['sample_end_datetime']
        # total_seconds() rather than .seconds: the latter silently drops whole days.
        delta_seconds = (end_datetime - start_datetime).total_seconds() / data_array.shape[0]
    
        start_freq = sample_info['sample_start_frequency']
        end_freq = sample_info['sample_end_frequency']
        delta_freq = (end_freq - start_freq) / data_array.shape[1]
    
        hdu_lofar = fits.PrimaryHDU()
        # Transpose so that time runs along FITS axis 1 and frequency along axis 2,
        # matching the CRVAL/CDELT keywords set below.
        hdu_lofar.data = data_array[:, :, 0].T
        # SIMPLE/BITPIX/NAXIS* are structural keywords that astropy recomputes from
        # the data on write; they are set here only for explicitness.
        hdu_lofar.header['SIMPLE'] = True
        hdu_lofar.header['BITPIX'] = 8
        hdu_lofar.header['NAXIS'] = 2
        hdu_lofar.header['NAXIS1'] = data_array.shape[0]
        hdu_lofar.header['NAXIS2'] = data_array.shape[1]
        hdu_lofar.header['EXTEND'] = True
        hdu_lofar.header['DATE'] = start_datetime.strftime("%Y-%m-%d")
        hdu_lofar.header['CONTENT'] = start_datetime.strftime("%Y/%m/%d") + ' Radio Flux Intensity LOFAR ' + \
                                      sample_info['ANTENNA_SET']
        hdu_lofar.header['ORIGIN'] = 'ASTRON Netherlands'
        hdu_lofar.header['TELESCOP'] = sample_info['TELESCOPE']
        hdu_lofar.header['INSTRUME'] = sample_info['ANTENNA_SET']
        hdu_lofar.header['OBJECT'] = sample_info['TARGET'][0]
        hdu_lofar.header['DATE-OBS'] = start_datetime.strftime("%Y/%m/%d")
        hdu_lofar.header['TIME-OBS'] = start_datetime.strftime("%H:%M:%S.%f")
        hdu_lofar.header['DATE-END'] = end_datetime.strftime("%Y/%m/%d")
        hdu_lofar.header['TIME-END'] = end_datetime.strftime("%H:%M:%S.%f")
    
        hdu_lofar.header['BZERO'] = 0.
        hdu_lofar.header['BSCALE'] = 1.
        hdu_lofar.header['BUNIT'] = 'digits  '
        hdu_lofar.header['DATAMIN'] = numpy.min(data_array)
        hdu_lofar.header['DATAMAX'] = numpy.max(data_array)
    
        hdu_lofar.header['CRVAL1'] = start_datetime.timestamp()
        hdu_lofar.header['CRPIX1'] = 0
        hdu_lofar.header['CTYPE1'] = 'Time [UT]'
        hdu_lofar.header['CDELT1'] = delta_seconds
        hdu_lofar.header['CRVAL2'] = start_freq
        hdu_lofar.header['CRPIX2'] = 0
        hdu_lofar.header['CTYPE2'] = 'Frequency [MHz]'
        hdu_lofar.header['CDELT2'] = delta_freq
        hdu_lofar.header['RA'] = sample_info['POINT_RA']
        hdu_lofar.header['DEC'] = sample_info['POINT_DEC']
    
        hdu_lofar.header['HISTORY'] = '        '
    
        full_hdu = fits.HDUList([hdu_lofar])
        full_hdu.writeto(output_path, overwrite=True)
    
    
    def make_plot(data_array, time_axis, frequency_axis, station_name, plot_full_path):
        fig = plt.figure(figsize=(6, 4), dpi=120)
        ax = plt.gca()
        start_time = datetime.fromtimestamp(time_axis[0]).strftime("%Y/%m/%d %H:%M:%S")
        datetime_axis = [datetime.fromtimestamp(time_s) for time_s in time_axis]
        high_freq = frequency_axis > 30  # mask of channels above 30 MHz
    
        times = mdates.date2num(datetime_axis)
        title = f'Dynspec {station_name} - {start_time}'
        # Per-channel background estimate: mean of the 10th-30th percentile of the
        # time-sorted samples, subtracted from the data.
        data_fits_new = data_array - numpy.nanmean(
            numpy.sort(data_array, 0)[
                int(data_array.shape[0] * 0.1):int(data_array.shape[0] * 0.3), :], 0)

        # Colour limits derived from the statistics of the >30 MHz channels.
        high_freq_mean = data_fits_new[:, high_freq, 0].mean()
        high_freq_std = data_fits_new[:, high_freq, 0].std()
        vmin = high_freq_mean - 2 * high_freq_std
        vmax = high_freq_mean + 3 * high_freq_std
    
        ax.imshow(data_fits_new[:, :, 0].T, origin='lower', aspect='auto',
                  vmin=vmin,
                  vmax=vmax,
                  extent=[times[0], times[-1], frequency_axis[0], frequency_axis[-1]],
                  cmap='inferno')
        plt.suptitle(title)
        ax.xaxis_date()
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))
        # Snap the x-limits to whole hours; timedelta arithmetic avoids invalid
        # dates at day/month boundaries.
        start_hour = datetime_axis[0].replace(minute=0, second=0, microsecond=0)
        end_hour = datetime_axis[-1].replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
        ax.set_xlim([start_hour, end_hour])
        ax.set_xlabel('Time (UT)')
        ax.set_ylabel('Frequency (MHz)')
        plt.savefig(plot_full_path)
        plt.close(fig)
    
    
    def create_averaged_dataset(sample_info, start_index, data_array):
        average_window = sample_info['average_window_samples']
        start_datetime, end_datetime = sample_info['sample_start_datetime'], sample_info['sample_end_datetime']
        start_freq, end_freq = sample_info['sample_start_frequency'], sample_info['sample_end_frequency']
        output_samples = sample_info['sample_time_samples']
    
        tmp_array = numpy.zeros([output_samples, *data_array.shape[1:]], dtype=numpy.float64)
    
        time_axis = numpy.linspace(start_datetime.timestamp(), end_datetime.timestamp(), output_samples)
        frequency_axis = numpy.linspace(start_freq, end_freq, data_array.shape[1])
    
        for i in range(output_samples):
            index = i * average_window + start_index
            # Median over each averaging window along the time axis.
            tmp_array[i, :, :] = numpy.median(data_array[index:index + average_window, :, :], axis=0)

        # Convert positive powers to dB; non-positive samples are left as-is.
        tmp_array[tmp_array > 0] = 10.0 * numpy.log10(tmp_array[tmp_array > 0])
    
        return numpy.array(tmp_array, dtype=numpy.float32), time_axis, frequency_axis
    
    
    def round_up_datetime(datet, interval):
        return datetime.fromtimestamp(numpy.ceil(datet.timestamp() / interval) * interval)
    
    
    def round_down_datetime(datet, interval):
        return datetime.fromtimestamp(numpy.floor(datet.timestamp() / interval) * interval)
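
    # Example: round_down_datetime(datetime(2021, 1, 1, 12, 34, 56), 60) floors to
    # 2021-01-01 12:34:00, while round_up_datetime with the same arguments yields
    # 2021-01-01 12:35:00 (the arithmetic is done on POSIX timestamps).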
    
    
    def split_samples(dynspec_name,
                      metadata,
                      dataset: h5py.File,
                      sample_window,
                      averaging_window,
                      out_directory):
        """
    
        :param dynspec_name: Dynspec tab naem
        :param dataset: dynspec dataset
        :param sample_window: sample window in seconds
        :param averaging_window: averaging window in seconds
    
        :return:
        """
        time_delta, *_ = decode_str(dataset[dynspec_name]['COORDINATES']['TIME'].attrs['INCREMENT'])
        obs_start_time = parse_datetime_str(decode_str(dataset[dynspec_name].attrs['DYNSPEC_START_UTC']))
        obs_end_time = parse_datetime_str(decode_str(dataset[dynspec_name].attrs['DYNSPEC_STOP_UTC']))
    
        frequency = dataset[dynspec_name]['COORDINATES']['SPECTRAL'].attrs['AXIS_VALUE_WORLD']
        antenna_set = metadata['ANTENNA_SET']
    
        start_frequency, end_frequency = frequency[0] / 1.e6, frequency[-1] / 1.e6
    
        station_name, *_ = metadata['BEAM_STATIONS_LIST']
        station_name = decode_str(station_name)
        averaging_window_in_samples = int(numpy.ceil(averaging_window / time_delta))
        averaging_window_in_seconds = averaging_window_in_samples * time_delta
    
        data_array = dataset[dynspec_name]['DATA']
        total_time_samples = data_array.shape[0]
    
        start_obs_datetime = round_down_datetime(obs_start_time, sample_window)
        end_obs_datetime = round_up_datetime(obs_end_time, sample_window)
        time_obs = numpy.linspace(obs_start_time.timestamp(), obs_end_time.timestamp(), total_time_samples)
        n_samples = int((end_obs_datetime - start_obs_datetime).total_seconds() // sample_window)
    
        for i in range(n_samples):
            start_sample_datetime = round_down_datetime(start_obs_datetime + timedelta(seconds=sample_window * i),
                                                        averaging_window)
            end_sample_datetime = round_up_datetime(start_obs_datetime + timedelta(seconds=sample_window * (i + 1)),
                                                    averaging_window)

            indices = numpy.where(numpy.logical_and(time_obs > start_sample_datetime.timestamp(),
                                                    time_obs <= end_sample_datetime.timestamp()))[0]
            start_index, end_index = indices[0], indices[-1]
    
            fname = start_sample_datetime.strftime(
                "LOFAR_%Y%m%d_%H%M%S_") + station_name + '_' + antenna_set
    
            full_path = os.path.join(out_directory, fname)
            logging.info('--processing sample number %s for station %s--', i, station_name)
            output_time_samples = int(numpy.ceil(len(indices) / averaging_window_in_samples))
            sample_info = {
                'average_window_samples': averaging_window_in_samples,
                'average_window_seconds': averaging_window_in_seconds,
                'sample_time_samples': output_time_samples,
                'sample_start_datetime': datetime.fromtimestamp(time_obs[start_index]),
                'sample_end_datetime': datetime.fromtimestamp(time_obs[end_index]),
                'n_time_samples': len(indices),
                'sample_start_frequency': start_frequency,
                'sample_end_frequency': end_frequency,
                **metadata
            }
    
            averaged_data_array, time_axis, frequency_axis = create_averaged_dataset(sample_info, start_index,
                                                                                     data_array)
            make_plot(averaged_data_array, time_axis, frequency_axis, station_name, full_path + '.png')
            create_fits_from_dataset(sample_info, averaged_data_array, full_path + '.fits')
            store_metadata(sample_info, full_path + '.json')
    
    
    def store_metadata(metadata, path):
        with open(path, 'w') as fout:
            json.dump(metadata, fout, cls=SmartJsonEncoder, indent=True)
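

    # Minimal driver sketch (assumed, not part of the original file): wire the CLI
    # arguments defined in parse_args() to the pipeline above. The 3600 s sample
    # window and 1 s averaging window are illustrative defaults, not values taken
    # from the source.
    if __name__ == '__main__':
        args = parse_args()
        os.makedirs(args.output_directory, exist_ok=True)
        with open_dataset(args.scintillation_dataset) as h5_dataset:
            for dynspec_name, dynspec_metadata in extract_metadata(h5_dataset).items():
                split_samples(dynspec_name, dynspec_metadata, h5_dataset,
                              sample_window=3600, averaging_window=1,
                              out_directory=args.output_directory)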