import json
import logging
import os
from argparse import ArgumentParser
from datetime import datetime, timedelta
from typing import Any, AnyStr, ByteString, Dict, Iterable, Optional, Union

import astropy.io.fits as fits
import h5py
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy
from astropy.coordinates import AltAz, EarthLocation, SkyCoord
from astropy.time import Time

logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s', level=logging.INFO)
def decode_str(str_or_byteslike: Union[ByteString, AnyStr]):
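    """Decode a bytes-like value to str; return the value unchanged if it is not bytes."""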
try:
return str_or_byteslike.decode()
except (UnicodeDecodeError, AttributeError):
return str_or_byteslike
_ROOT_SELECTED_FIELDS = (
"ANTENNA_SET",
"CHANNELS_PER_SUBANDS",
"CHANNEL_WIDTH",
"CHANNEL_WIDTH_UNIT",
"FILEDATE",
"FILENAME",
"FILETYPE",
"FILTER_SELECTION",
"NOF_SAMPLES",
"NOTES",
"OBSERVATION_END_MJD",
"OBSERVATION_END_UTC",
"OBSERVATION_ID",
"OBSERVATION_NOF_BITS_PER_SAMPLE",
"OBSERVATION_START_MJD",
"OBSERVATION_START_UTC",
"POINT_DEC",
"POINT_RA",
"PRIMARY_POINTING_DIAMETER",
"PROJECT_ID",
"SAMPLING_RATE",
"SAMPLING_RATE_UNIT",
"SAMPLING_TIME",
"SAMPLING_TIME_UNIT",
"SUBBAND_WIDTH",
"SUBBAND_WIDTH_UNIT",
"TARGET",
"TELESCOPE",
"TOTAL_BAND_WIDTH",
"TOTAL_INTEGRATION_TIME",
"TOTAL_INTEGRATION_TIME_UNIT",
)
_BEAM_FIELDS = (
"BEAM_FREQUENCY_CENTER",
"BEAM_FREQUENCY_MAX",
"BEAM_FREQUENCY_MIN",
"BEAM_FREQUENCY_UNIT",
"BEAM_STATIONS_LIST",
"DYNSPEC_BANDWIDTH",
"POINT_DEC",
"POINT_RA",
"point_start_azimuth",
"point_end_azimuth",
"point_start_elevation",
"point_end_elevation",
"SIGNAL_SUM",
"STOCKES_COMPONENT",
"TARGET",
"TRACKING",
"REF_LOCATION_FRAME",
"REF_LOCATION_UNIT",
"REF_LOCATION_VALUE",
"REF_TIME_FRAME",
"REF_TIME_UNIT",
"REF_TIME_VALUE"
)
def mjd_to_datetime(mjd) -> datetime:
return Time(mjd, format='mjd').to_datetime()
class SmartJsonEncoder(json.JSONEncoder):
    """JSON encoder that serializes numpy scalars, numpy arrays and datetimes."""
    def default(self, o: Any) -> Any:
        if isinstance(o, numpy.integer):
            return int(o)
        elif isinstance(o, numpy.ndarray):
            return o.tolist()
        elif isinstance(o, datetime):
            return o.isoformat()
        try:
            return super().default(o)
        except TypeError:
            logging.error('Cannot convert object of type %s: %r', type(o), o)
            raise
def parse_args():
parser = ArgumentParser(description='Scintillation averaging script')
parser.add_argument('scintillation_dataset', help='Scintillation dataset [e.g. Dynspec_rebinned_L271905_SAP000.h5]')
parser.add_argument('output_directory', help='Output directory')
return parser.parse_args()
def open_dataset(path):
if not os.path.exists(path):
raise FileNotFoundError(f'Cannot find file at {path}')
return h5py.File(path, mode='r')
def copy_attrs_to_dict(h5_leaf, dict_container: Optional[Dict] = None,
exclude_fields: Optional[Iterable] = None,
include_fields: Optional[Iterable] = None):
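    """Copy the attributes of an HDF5 node into a dict, decoding bytes values and
    honouring the optional include/exclude field filters."""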
exclude_fields_set = set(exclude_fields) if exclude_fields else None
if dict_container is None:
dict_container = {}
for key, value in h5_leaf.attrs.items():
if include_fields is not None and key not in include_fields:
continue
if exclude_fields_set and key in exclude_fields_set:
continue
if isinstance(value, datetime):
dict_container[key] = value.isoformat()
elif isinstance(value, list) or isinstance(value, tuple) or isinstance(value, numpy.ndarray):
dict_container[key] = list(map(decode_str, value))
else:
dict_container[key] = decode_str(value)
return dict_container
def parse_datetime_str(datetime_str):
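    # Parse only the first whitespace-delimited token as an ISOT timestamp (drops any trailing suffix).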
return Time(datetime_str.split(' ')[0], format='isot', scale='utc').to_datetime()
def extract_root_metadata(dataset):
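    """Extract the selected root-level attributes, trimming the UTC strings to their first token."""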
metadata = dict()
copy_attrs_to_dict(dataset['/'], metadata, include_fields=_ROOT_SELECTED_FIELDS)
metadata['OBSERVATION_START_UTC'] = metadata['OBSERVATION_START_UTC'].split(' ')[0]
metadata['OBSERVATION_END_UTC'] = metadata['OBSERVATION_END_UTC'].split(' ')[0]
return metadata
def extract_coordinates_metadata(dataset):
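    """Flatten the attributes of the COORDINATES group and of its children into one dict."""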
coordinates = copy_attrs_to_dict(dataset)
for item in dataset:
copy_attrs_to_dict(dataset[item], coordinates)
return coordinates
def extract_dynspec(dataset):
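    """Collect the attributes of a DYNSPEC group and its sub-groups, keeping only the fields in _BEAM_FIELDS."""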
dynspec_root = {}
copy_attrs_to_dict(dataset, dynspec_root)
for item in dataset:
if item == 'COORDINATES':
coordinates = extract_coordinates_metadata(dataset[item])
dynspec_root.update(coordinates)
else:
dynspec_root.update(copy_attrs_to_dict(dataset[item]))
dynspec_root = {key: value for key, value in dynspec_root.items() if key in _BEAM_FIELDS}
return dynspec_root
def compute_start_end_azimuth_elevation(metadata):
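    """Compute the pointing azimuth/elevation at observation start and end and add them to the metadata."""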
x, y, z = metadata['REF_LOCATION_VALUE']
unit, *_ = metadata['REF_LOCATION_UNIT']
ra, dec = metadata['POINT_RA'], metadata['POINT_DEC']
start_time = Time(metadata['OBSERVATION_START_UTC'].split(' ')[0], format='isot', scale='utc')
end_time = Time(metadata['OBSERVATION_END_UTC'].split(' ')[0], format='isot', scale='utc')
location = EarthLocation(x=x, y=y, z=z, unit=unit)
start_altaz, end_altaz = SkyCoord(ra=ra, dec=dec, unit='deg').transform_to(AltAz(obstime=[start_time, end_time],
location=location))
metadata['point_start_azimuth'] = start_altaz.az.deg
metadata['point_start_elevation'] = start_altaz.alt.deg
metadata['point_end_azimuth'] = end_altaz.az.deg
metadata['point_end_elevation'] = end_altaz.alt.deg
return metadata
def extract_metadata(dataset):
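    """Build a metadata dict per DYNSPEC group, merged with the root metadata."""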
root_metadata = extract_root_metadata(dataset)
metadata_per_dynspec = {}
for dynspec in dataset['/']:
if not dynspec.startswith('DYNSPEC'):
continue
metadata_per_dynspec[dynspec] = extract_dynspec(dataset[dynspec])
metadata_per_dynspec[dynspec].update(root_metadata)
compute_start_end_azimuth_elevation(metadata_per_dynspec[dynspec])
return metadata_per_dynspec
def create_fits_from_dataset(sample_info, data_array, output_path):
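    """Write the averaged dynamic spectrum to a FITS file with a time/frequency axis description in the header."""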
start_datetime = sample_info['sample_start_datetime']
end_datetime = sample_info['sample_end_datetime']
    delta_seconds = (end_datetime - start_datetime).total_seconds() / data_array.shape[0]
start_freq = sample_info['sample_start_frequency']
end_freq = sample_info['sample_end_frequency']
delta_freq = (end_freq - start_freq) / data_array.shape[1]
hdu_lofar = fits.PrimaryHDU()
hdu_lofar.data = data_array[:, :, 0].T
hdu_lofar.header['SIMPLE'] = True
hdu_lofar.header['BITPIX'] = 8
    hdu_lofar.header['NAXIS'] = 2
hdu_lofar.header['NAXIS1'] = data_array.shape[0]
hdu_lofar.header['NAXIS2'] = data_array.shape[1]
hdu_lofar.header['EXTEND'] = True
hdu_lofar.header['DATE'] = start_datetime.strftime("%Y-%m-%d")
hdu_lofar.header['CONTENT'] = start_datetime.strftime("%Y/%m/%d") + ' Radio Flux Intensity LOFAR ' + \
sample_info['ANTENNA_SET']
hdu_lofar.header['ORIGIN'] = 'ASTRON Netherlands'
hdu_lofar.header['TELESCOP'] = sample_info['TELESCOPE']
hdu_lofar.header['INSTRUME'] = sample_info['ANTENNA_SET']
hdu_lofar.header['OBJECT'] = sample_info['TARGET'][0]
hdu_lofar.header['DATE-OBS'] = start_datetime.strftime("%Y/%m/%d")
hdu_lofar.header['TIME-OBS'] = start_datetime.strftime("%H:%M:%S.%f")
hdu_lofar.header['DATE-END'] = end_datetime.strftime("%Y/%m/%d")
hdu_lofar.header['TIME-END'] = end_datetime.strftime("%H:%M:%S.%f")
hdu_lofar.header['BZERO'] = 0.
hdu_lofar.header['BSCALE'] = 1.
hdu_lofar.header['BUNIT'] = 'digits '
hdu_lofar.header['DATAMIN'] = numpy.min(data_array)
hdu_lofar.header['DATAMAX'] = numpy.max(data_array)
hdu_lofar.header['CRVAL1'] = start_datetime.timestamp()
hdu_lofar.header['CRPIX1'] = 0
hdu_lofar.header['CTYPE1'] = 'Time [UT]'
hdu_lofar.header['CDELT1'] = delta_seconds
hdu_lofar.header['CRVAL2'] = start_freq
hdu_lofar.header['CRPIX2'] = 0
hdu_lofar.header['CTYPE2'] = 'Frequency [MHz]'
hdu_lofar.header['CDELT2'] = delta_freq
hdu_lofar.header['RA'] = sample_info['POINT_RA']
hdu_lofar.header['DEC'] = sample_info['POINT_DEC']
hdu_lofar.header['HISTORY'] = ' '
full_hdu = fits.HDUList([hdu_lofar])
full_hdu.writeto(output_path, overwrite=True)
def make_plot(data_array, time_axis, frequency_axis, station_name, plot_full_path):
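    """Render the background-subtracted dynamic spectrum and save it to plot_full_path."""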
fig = plt.figure(figsize=(6, 4), dpi=120)
ax = plt.gca()
start_time = datetime.fromtimestamp(time_axis[0]).strftime("%Y/%m/%d %H:%M:%S")
datetime_axis = [datetime.fromtimestamp(time_s) for time_s in time_axis]
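    # Channels above 30 MHz are used to estimate the colour-scale statistics.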
high_freq = frequency_axis > 30
times = mdates.date2num(datetime_axis)
title = f'Dynspec {station_name} - {start_time}'
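    # Subtract a per-channel background: the mean of the 10th-30th percentile of the sorted time samples.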
data_fits_new = data_array - numpy.nanmean(
numpy.sort(data_array, 0)[
int(data_array.shape[0] * 0.1):int(data_array.shape[0] * 0.3), :], 0)
high_freq_mean = data_fits_new[:, high_freq, 0].mean()
high_freq_std = data_fits_new[:, high_freq, 0].std()
vmin = (high_freq_mean - 2 * high_freq_std)
vmax = (high_freq_mean + 3 * high_freq_std)
ax.imshow(data_fits_new[:, :, 0].T, origin='lower', aspect='auto',
vmin=vmin,
vmax=vmax,
extent=[times[0], times[-1], frequency_axis[0], frequency_axis[-1]],
cmap='inferno')
plt.suptitle(title)
ax.xaxis_date()
ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))
    # Truncate limits to whole hours; adding a timedelta avoids invalid dates at day/month roll-over.
    start_xlim = datetime_axis[0].replace(minute=0, second=0, microsecond=0)
    end_xlim = datetime_axis[-1].replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
    ax.set_xlim([start_xlim, end_xlim])
ax.set_xlabel('Time (UT)')
ax.set_ylabel('Frequency (MHz)')
plt.savefig(plot_full_path)
plt.close(fig)
def create_averaged_dataset(sample_info, start_index, data_array):
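    """Median-average the data over each output window along time and convert positive
    values to dB; returns the float32 cube plus the time and frequency axes."""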
average_window = sample_info['average_window_samples']
start_datetime, end_datetime = sample_info['sample_start_datetime'], sample_info['sample_end_datetime']
start_freq, end_freq = sample_info['sample_start_frequency'], sample_info['sample_end_frequency']
output_samples = sample_info['sample_time_samples']
tmp_array = numpy.zeros([output_samples, *data_array.shape[1:]], dtype=numpy.float64)
time_axis = numpy.linspace(start_datetime.timestamp(), end_datetime.timestamp(), output_samples)
frequency_axis = numpy.linspace(start_freq, end_freq, data_array.shape[1])
    for i in range(output_samples):
        index = i * average_window + start_index
        tmp_array[i, :, :] = numpy.median(data_array[index:index + average_window, :, :], axis=0)
tmp_array[tmp_array > 0] = 10.0 * numpy.log10(tmp_array[tmp_array > 0])
return numpy.array(tmp_array, dtype=numpy.float32), time_axis, frequency_axis
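# Round a datetime up/down to an integer multiple of `interval` seconds (unix-epoch aligned).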
def round_up_datetime(datet, interval):
return datetime.fromtimestamp(numpy.ceil(datet.timestamp() / interval) * interval)
def round_down_datetime(datet, interval):
return datetime.fromtimestamp(numpy.floor(datet.timestamp() / interval) * interval)
def split_samples(dynspec_name,
metadata,
dataset: h5py.File,
sample_window,
averaging_window,
out_directory):
"""
:param dynspec_name: Dynspec tab naem
:param dataset: dynspec dataset
:param sample_window: sample window in seconds
:param averaging_window: averaging window in seconds
:return:
"""
time_delta, *_ = decode_str(dataset[dynspec_name]['COORDINATES']['TIME'].attrs['INCREMENT'])
obs_start_time = parse_datetime_str(decode_str(dataset[dynspec_name].attrs['DYNSPEC_START_UTC']))
obs_end_time = parse_datetime_str(decode_str(dataset[dynspec_name].attrs['DYNSPEC_STOP_UTC']))
frequency = dataset[dynspec_name]['COORDINATES']['SPECTRAL'].attrs['AXIS_VALUE_WORLD']
antenna_set = metadata['ANTENNA_SET']
start_frequency, end_frequency = frequency[0] / 1.e6, frequency[-1] / 1.e6
station_name, *_ = metadata['BEAM_STATIONS_LIST']
station_name = decode_str(station_name)
averaging_window_in_samples = int(numpy.ceil(averaging_window / time_delta))
averaging_window_in_seconds = averaging_window_in_samples * time_delta
data_array = dataset[dynspec_name]['DATA']
total_time_samples = data_array.shape[0]
start_obs_datetime = round_down_datetime(obs_start_time, sample_window)
end_obs_datetime = round_up_datetime(obs_end_time, sample_window)
time_obs = numpy.linspace(obs_start_time.timestamp(), obs_end_time.timestamp(), total_time_samples)
    n_samples = int((end_obs_datetime - start_obs_datetime).total_seconds() // sample_window)
for i in range(n_samples):
start_sample_datetime = round_down_datetime(start_obs_datetime + timedelta(seconds=sample_window * i),
averaging_window)
end_sample_datetime = round_up_datetime(start_obs_datetime + timedelta(seconds=sample_window * (i + 1)),
averaging_window)
        indices = numpy.where(numpy.logical_and(time_obs > start_sample_datetime.timestamp(),
                                                time_obs <= end_sample_datetime.timestamp()))[0]
        start_index, end_index = indices[0], indices[-1]
fname = start_sample_datetime.strftime(
"LOFAR_%Y%m%d_%H%M%S_") + station_name + '_' + antenna_set
full_path = os.path.join(out_directory, fname)
logging.info('--processing sample number %s for station %s--', i, station_name)
        output_time_samples = int(numpy.ceil(len(indices) / averaging_window_in_samples))
sample_info = {
'average_window_samples': averaging_window_in_samples,
'average_window_seconds': averaging_window_in_seconds,
'sample_time_samples': output_time_samples,
'sample_start_datetime': datetime.fromtimestamp(time_obs[start_index]),
'sample_end_datetime': datetime.fromtimestamp(time_obs[end_index]),
            'n_time_samples': len(indices),
'sample_start_frequency': start_frequency,
'sample_end_frequency': end_frequency,
**metadata
}
averaged_data_array, time_axis, frequency_axis = create_averaged_dataset(sample_info, start_index,
data_array)
make_plot(averaged_data_array, time_axis, frequency_axis, station_name, full_path + '.png')
create_fits_from_dataset(sample_info, averaged_data_array, full_path + '.fits')
store_metadata(sample_info, full_path + '.json')
def store_metadata(metadata, path):
with open(path, 'w') as fout:
json.dump(metadata, fout, cls=SmartJsonEncoder, indent=True)
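

if __name__ == '__main__':
    # Minimal driver sketch: the sample/averaging window lengths are not defined
    # elsewhere in this file, so the values below are illustrative assumptions.
    SAMPLE_WINDOW_SECONDS = 3600     # assumed: one output sample per hour
    AVERAGING_WINDOW_SECONDS = 1     # assumed: one-second averaging per output bin
    args = parse_args()
    os.makedirs(args.output_directory, exist_ok=True)
    with open_dataset(args.scintillation_dataset) as h5_dataset:
        for dynspec_name, metadata in extract_metadata(h5_dataset).items():
            split_samples(dynspec_name, metadata, h5_dataset,
                          SAMPLE_WINDOW_SECONDS, AVERAGING_WINDOW_SECONDS,
                          args.output_directory)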