# views.py -- Django views for the lofardata app.
import logging
import time
from typing import Tuple
import numpy

from django.contrib.auth.models import User
from django.core.exceptions import ObjectDoesNotExist
from django.http import HttpResponseRedirect
from django.shortcuts import redirect, render
from django.urls import reverse
from django.views.generic import CreateView, DeleteView, DetailView, UpdateView, TemplateView
from django.views.generic.list import ListView
from django_filters import rest_framework as filters
from rest_framework import generics, status, viewsets
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.reverse import reverse_lazy
from rest_framework.schemas.openapi import AutoSchema

from .forms import WorkSpecificationForm
from .models import (
    ATDBProcessingSite,
    DataFilterType,
    DataProduct,
    DataProductFilter,
    WorkSpecification,
)
from .serializers import (
    ATDBProcessingSiteSerializer,
    DataProductFlatSerializer,
    DataProductSerializer,
    WorkSpecificationSerializer,
)
from .tasks import insert_task_into_atdb


def compute_size_of_inputs(inputs: dict) -> Tuple[int, int, float]:
    """Recursively compute total size, file count and average file size.

    Walks an arbitrarily nested structure of dicts/lists/tuples.  A dict
    that carries a "size" key is treated as a single file descriptor; any
    other container is recursed into; scalars contribute nothing.

    :param inputs: nested structure of file descriptors (dicts with "size").
    :return: (total_size, number_of_files, average_file_size); the average
        is 0 when no files were found.
    """
    total_size = 0
    number_of_files = 0

    if isinstance(inputs, dict) and "size" in inputs:
        # Leaf node: a single file descriptor.
        number_of_files = 1
        total_size = inputs["size"]
    elif isinstance(inputs, (dict, list, tuple)):
        values = inputs.values() if isinstance(inputs, dict) else inputs
        for value in values:
            item_total, item_count, _ = compute_size_of_inputs(value)
            total_size += item_total
            number_of_files += item_count

    # Guard against division by zero for empty/unknown inputs.
    average_file_size = total_size / number_of_files if number_of_files else 0

    return total_size, number_of_files, average_file_size


def compute_inputs_histogram(inputs):
    """Compute a histogram of the file sizes found in *inputs*.

    :param inputs: dict or iterable whose entries are iterables of file
        descriptors (dicts with a "size" key).
    :return: (min_size, max_size, n_bins, counts, biggest_bucket,
        formatted_bins) describing the histogram.
    """
    # Flatten the nested structure into a plain array of sizes.
    if isinstance(inputs, dict):
        inputs = inputs.values()
    inputs_sizes = numpy.array([item['size'] for entry in inputs for item in entry])

    # define histogram values
    min_size = inputs_sizes.min()
    max_size = inputs_sizes.max()

    # Use at most 100 bins, fewer when there are fewer distinct sizes.
    n_distinct_sizes = len(numpy.unique(inputs_sizes))
    n_bins = min(n_distinct_sizes, 100)
    counts, buckets = numpy.histogram(inputs_sizes, bins=n_bins, range=(min_size, max_size))
    # BUG FIX: the old code did `format_size(bucket) % bucket`, applying the
    # "%" operator to an already-formatted string, which raises TypeError.
    # format_size() already returns the final label.
    formatted_bins = [format_size(bucket) for bucket in buckets]

    return min_size, max_size, n_bins, counts.tolist(), counts.max(), formatted_bins


def format_size(num, suffix="B"):
    """Render a byte count as a human-readable string (e.g. ``1.500KB``).

    Zero renders as ``-``; values beyond the zetta range fall back to the
    ``Yi`` prefix.
    """
    if num == 0:
        return "-"
    value = num
    for prefix in ("", "K", "M", "G", "T", "P", "E", "Z"):
        if abs(value) < 1024.0:
            return f"{value:3.3f}{prefix}{suffix}"
        value /= 1024.0
    return f"{value:.1f}Yi{suffix}"


class DynamicFilterSet(filters.FilterSet):
    """FilterSet whose filters are loaded from database rows at init time.

    Subclasses must point ``Meta.filter_class`` at the model that stores
    the filter definitions (field, lookup_type, name).
    """

    class Meta:
        filter_class = None

    def __init__(self, *args, **kwargs):
        self._load_filters()
        super().__init__(*args, **kwargs)

    def _load_filters(self):
        """Populate ``base_filters`` from the configured filter model."""
        filter_model = self.Meta.filter_class
        if filter_model is None:
            raise Exception("Define filter_class meta attribute")
        for definition in filter_model.objects.all():
            model_field = self.Meta.model._meta.get_field(definition.field)
            filter_cls, *_ = self.filter_for_lookup(model_field, definition.lookup_type)
            self.base_filters[definition.name] = filter_cls(definition.field)


# --- Filters ---
class DataProductFilterSet(DynamicFilterSet):
    """Filter set for DataProduct; additional filters are loaded at runtime
    from DataProductFilter rows (see DynamicFilterSet)."""

    class Meta:
        model = DataProduct
        filter_class = DataProductFilter
        # Static filters that are always available, independent of the
        # database-defined ones.
        fields = {
            "obs_id": ["exact", "icontains"],
        }


# ---------- GUI Views -----------


def api(request):
    """Render the API landing page listing the known ATDB processing sites."""
    hosts = ATDBProcessingSite.objects.values("name", "url")
    return render(request, "lofardata/api.html", {"atdb_hosts": hosts})


def preprocess_filters_specification_view(specification):
    """Annotate every data product filter with a default (and choices).

    Defaults come from the specification's stored filter values when
    available, otherwise the empty string.  Dropdown filters additionally
    get the distinct values currently present in the database.

    :param specification: WorkSpecification or None.
    :return: queryset of DataProductFilter objects with ``default`` set and,
        for dropdown filters, ``choices``.
    """
    dataproduct_filters = DataProductFilter.objects.all()
    for item in dataproduct_filters:
        has_stored_value = (
            specification is not None
            and specification.filters
            and item.field in specification.filters
        )
        item.default = specification.filters[item.field] if has_stored_value else ""

        if item.filter_type == DataFilterType.DROPDOWN:
            item.choices = DataProduct.objects.distinct(item.field).values_list(item.field)
    return dataproduct_filters


def retrieve_general_dataproduct_information(sas_id):
    """Collect the distinct data product properties for one SAS ID.

    The collected dysco_compression values are reduced to a single ratio
    (fraction of products with dysco compression enabled), stored in a
    one-element list for template convenience.

    :param sas_id: observation id (obs_id) to look up.
    :return: dict mapping each property name to a list of distinct values.
    """
    # Per SAS ID, the retrieved data products should have these unique values
    data_products = DataProduct.objects.filter(obs_id=sas_id).values('dataproduct_source',
                                                                     'dataproduct_type',
                                                                     'project',
                                                                     'location',
                                                                     'activity',
                                                                     'antenna_set',
                                                                     'instrument_filter',
                                                                     'dysco_compression').distinct()
    combined_data_products_on_key = combine_dataproducts_on_key(data_products, {})

    # BUG FIX: an unknown sas_id yields no products, so the key is missing
    # and the old code raised KeyError (and risked ZeroDivisionError).
    dysco_compressions = combined_data_products_on_key.get('dysco_compression', [])
    if dysco_compressions:
        true_count = len([flag for flag in dysco_compressions if flag])
        dysco_compression_true_percentage = true_count / len(dysco_compressions)
    else:
        dysco_compression_true_percentage = 0
    # put in a list for template convenience
    combined_data_products_on_key['dysco_compression'] = [dysco_compression_true_percentage]
    return combined_data_products_on_key


def combine_dataproducts_on_key(data_products, combined_data_products):
    """Fold every data product dict into the accumulator and return it."""
    for product in data_products:
        combined_data_products = fill_unique_nested_dict(product, combined_data_products)
    return combined_data_products


def fill_unique_nested_dict(data_product, combined_data_products_on_key):
    """Accumulate each value of *data_product* into a per-key list of uniques.

    None and empty-string values are ignored, but still create an empty list
    for their key so templates can iterate safely.

    :param data_product: dict of field name -> value for one data product.
    :param combined_data_products_on_key: accumulator dict of field -> list
        of distinct values seen so far (mutated in place).
    :return: the accumulator.
    """
    for key, value in data_product.items():
        # BUG FIX: the old code compared strings with `is not ''` (identity,
        # not equality) and its else-branch replaced the whole list with
        # [value], discarding values accumulated from earlier products.
        values_for_key = combined_data_products_on_key.setdefault(key, [])
        if value is not None and value != '' and value not in values_for_key:
            values_for_key.append(value)
    return combined_data_products_on_key


class Specifications(ListView):
    """Index page listing work specifications.

    Staff and superusers see every specification; other users only see the
    ones they created.  Newest first.
    """

    serializer_class = WorkSpecificationSerializer
    template_name = "lofardata/index.html"
    model = WorkSpecification
    ordering = ["-created_on"]

    def get_queryset(self):
        user: User = self.request.user
        specifications = WorkSpecification.objects.all()
        if not (user.is_staff or user.is_superuser):
            specifications = specifications.filter(created_by=user.id)
        return specifications.order_by("-created_on")


class WorkSpecificationCreateUpdateView(UpdateView):
    """Create or update a WorkSpecification through a single form view.

    Without a ``pk`` URL kwarg the view operates on a fresh, unsaved
    instance (create); with a ``pk`` it edits the stored specification.
    """

    template_name = "lofardata/workspecification/create_update.html"
    model = WorkSpecification
    form_class = WorkSpecificationForm

    def get_object(self, queryset=None):
        # BUG FIX: use .get() so a missing "pk" kwarg means "create new"
        # instead of raising KeyError when other kwargs are present.
        pk = self.kwargs.get("pk")
        if pk is None:
            return WorkSpecification()
        return WorkSpecification.objects.get(pk=pk)

    def get_context_data(self, **kwargs):
        context = super().get_context_data(**kwargs)
        try:
            specification = WorkSpecification.objects.get(pk=context["object"].pk)
        except ObjectDoesNotExist:
            # Fresh (unsaved) specification: no stored filter defaults yet.
            specification = None
        context["filters"] = preprocess_filters_specification_view(specification)
        context["processing_sites"] = list(ATDBProcessingSite.objects.values("name", "url"))
        return context

    def create_successor(self, specification):
        """Create and persist a successor specification; return its edit URL."""
        successor = WorkSpecification()
        successor.predecessor_specification = specification
        successor.processing_site = specification.processing_site
        successor.save()
        return self.get_success_url(pk=successor.pk)

    def form_valid(self, form):
        action_ = form.data["action"]
        specification = form.instance
        if action_ == "Submit":
            # Re-submitting invalidates any previous ATDB task state.
            specification.async_task_result = None
            specification.is_ready = False
        if action_ == "Send":
            # Hand the specification over to ATDB asynchronously.
            insert_task_into_atdb.delay(specification.pk)
        if action_ == "Successor":
            specification.save()
            # The successor is built and saved by create_successor(); the
            # previous in-line duplicate that was never saved (and whose
            # selected_workflow copy therefore had no effect) was removed.
            return HttpResponseRedirect(self.create_successor(specification))

        return super().form_valid(form)

    def get_success_url(self, **kwargs):
        # No pk -> back to the index; otherwise stay on the edit page.
        pk = kwargs.get("pk")
        if pk is None:
            return reverse_lazy("index")
        return reverse_lazy("specification-update", kwargs={"pk": pk})


class WorkSpecificationDetailView(DetailView):
    """Detail page showing the filters and size statistics of a specification."""

    template_name = "lofardata/workspecification/detail.html"
    model = WorkSpecification

    def get_context_data(self, **kwargs):
        context = super().get_context_data(**kwargs)
        specification = WorkSpecification.objects.get(pk=context["object"].pk)
        context["filters"] = preprocess_filters_specification_view(specification)

        total_size, n_files, avg_size = compute_size_of_inputs(specification.inputs)
        context["number_of_files"] = n_files
        context["total_input_size"] = format_size(total_size)
        # A positive batch_size means tasks process avg_size * batch_size
        # bytes each; otherwise a single task takes everything.
        if specification.batch_size > 0:
            per_task_size = avg_size * specification.batch_size
        else:
            per_task_size = total_size
        context["size_per_task"] = format_size(per_task_size)
        return context


class WorkSpecificationDeleteView(DeleteView):
    """Confirmation page and deletion of a WorkSpecification; redirects to
    the index afterwards."""

    template_name = "lofardata/workspecification/delete.html"
    model = WorkSpecification
    success_url = reverse_lazy("index")


class WorkSpecificationInputsView(DetailView):
    """Detail page rendering the raw inputs of a WorkSpecification."""

    template_name = "lofardata/workspecification/inputs.html"
    model = WorkSpecification


class WorkSpecificationATDBTasksView(DetailView):
    """Detail page rendering the ATDB tasks of a WorkSpecification."""

    template_name = "lofardata/workspecification/tasks.html"
    model = WorkSpecification


class WorkSpecificationDatasetSizeInfoView(DetailView):
    """Page showing a histogram of the input file sizes of a specification."""

    template_name = "lofardata/workspecification/dataset_size_info.html"
    model = WorkSpecification

    def get_context_data(self, **kwargs):
        context = super().get_context_data(**kwargs)
        specification = WorkSpecification.objects.get(pk=context["object"].pk)

        histogram = compute_inputs_histogram(specification.inputs)
        min_size, max_size, n_bins, counts, biggest_bucket, bins = histogram

        context.update({
            "min_size": format_size(min_size),
            "max_size": format_size(max_size),
            "biggest_bucket": biggest_bucket,
            "n_bins": n_bins,
            "counts": counts,
            "bins": bins,
        })
        return context


class DataProductViewPerSasID(TemplateView):
    """Overview of the distinct data product properties for one SAS ID.

    The SAS ID is taken from the ``obs_id`` filter stored on the
    specification identified by the ``pk`` URL kwarg.
    """

    template_name = "lofardata/workspecification/dataproducts.html"

    def get_context_data(self, **kwargs):
        context = super().get_context_data(**kwargs)

        try:
            specification = WorkSpecification.objects.get(pk=kwargs['pk'])
        except ObjectDoesNotExist:
            # Unknown specification: render the template without product info.
            return context

        sas_id = specification.filters['obs_id']
        context.update(
            sas_id=sas_id,
            dataproduct_info=retrieve_general_dataproduct_information(sas_id),
        )
        return context


# ---------- REST API views ----------
class DataProductView(generics.ListCreateAPIView):
    """REST endpoint: list (with filtering) and create DataProduct records."""

    model = DataProduct
    serializer_class = DataProductSerializer

    queryset = DataProduct.objects.all().order_by("obs_id")

    # using the Django Filter Backend - https://django-filter.readthedocs.io/en/latest/index.html
    filter_backends = (filters.DjangoFilterBackend,)
    # NOTE(review): recent django-filter releases expect `filterset_class`;
    # `filter_class` only works on older versions -- confirm the pinned version.
    filter_class = DataProductFilterSet


class ATDBProcessingSiteView(viewsets.ReadOnlyModelViewSet):
    """Read-only REST endpoint listing the configured ATDB processing sites."""

    model = ATDBProcessingSite
    serializer_class = ATDBProcessingSiteSerializer

    queryset = ATDBProcessingSite.objects.all().order_by("pk")


class DataProductDetailsView(generics.RetrieveUpdateDestroyAPIView):
    """REST endpoint: retrieve, update or delete a single DataProduct."""

    model = DataProduct
    serializer_class = DataProductSerializer
    queryset = DataProduct.objects.all()


class InsertWorkSpecificationSchema(AutoSchema):
    """OpenAPI schema using a fixed operation-id base for the multi-insert
    endpoint."""

    def get_operation_id_base(self, path, method, action):
        # Same operation id regardless of path/method/action.
        return "createDataProductMulti"


class InsertMultiDataproductView(generics.CreateAPIView):
    """Create one DataProduct, or several when the payload is a JSON array."""

    queryset = DataProduct.objects.all()
    serializer_class = DataProductFlatSerializer
    schema = InsertWorkSpecificationSchema()

    def get_serializer(self, *args, **kwargs):
        """Enable many=True on the serializer for list payloads."""
        payload = kwargs.get("data", {})
        if isinstance(payload, list):
            kwargs["many"] = True
        return super().get_serializer(*args, **kwargs)


class WorkSpecificationViewset(viewsets.ModelViewSet):
    """REST CRUD endpoint for work specifications.

    Staff and superusers see every specification; regular users only the
    ones they created (matching the Specifications list view).
    """

    queryset = WorkSpecification.objects.all()
    serializer_class = WorkSpecificationSerializer

    def get_queryset(self):
        current_user: User = self.request.user
        # BUG FIX: the old condition (`not is_staff or not is_superuser`)
        # also restricted staff members who are not superusers, which was
        # inconsistent with Specifications.get_queryset.
        if current_user.is_staff or current_user.is_superuser:
            return self.queryset
        return self.queryset.filter(created_by=current_user.id)

    @action(detail=True, methods=["POST"])
    def submit(self, request, pk=None) -> Response:
        """Queue the specification for insertion into ATDB, then redirect."""
        # TODO: check that there are some matches in the request?
        insert_task_into_atdb.delay(pk)

        time.sleep(1)  # allow for some time to pass

        return redirect("index")