From 182106d411572f4b3a6ea9279b8f02e0d79bea52 Mon Sep 17 00:00:00 2001
From: mancini <mancini@astron.nl>
Date: Wed, 18 Sep 2019 12:15:33 +0200
Subject: [PATCH] Implement scheduling workflow

---
 lib/slurm_executor/slurm.py | 60 +++++++++++++++++++++++++++----------
 1 file changed, 44 insertions(+), 16 deletions(-)

diff --git a/lib/slurm_executor/slurm.py b/lib/slurm_executor/slurm.py
index 8bf1b29..8f981a9 100644
--- a/lib/slurm_executor/slurm.py
+++ b/lib/slurm_executor/slurm.py
@@ -1,35 +1,63 @@
 from airflow.plugins_manager import AirflowPlugin
 from airflow.executors.base_executor import BaseExecutor
 from airflow.utils.state import State
-from slurm_cli.slurm_control import get_jobs_status
-import subprocess
+from slurm_cli.slurm_control import get_jobs_status, run_job
+import logging
+import uuid
+
+logger = logging.getLogger(__name__)
+
+
+def reindex_job_status_by_job_name(job_list):
+    return {job_status.job_name: job_status for job_status in job_list.values()}
+
 
 # Will show up under airflow.executors.slurm.SlurmExecutor
 class SlurmExecutor(BaseExecutor):
     def __init__(self):
         super().__init__()
-        self.commands_to_run = []
+        self.commands_to_check = {}
 
     def execute_async(self, key, command, queue=None, executor_config=None):
         print("execute async called")
-        self.commands_to_run.append((key, command,))
-    
-    def trigger_tasks(self, open_slots):
-        print('trigger tasks called', open_slots)
-        super().trigger_tasks(open_slots)
+        unique_id = str(key[0]) + str(uuid.uuid1())
+        queue = queue if queue != 'default' else None
+        logger.debug('submitting job %s on queue %s', key, queue)
+        run_job(cmd=command, queue=queue, task_name=unique_id)
+        self.commands_to_check[unique_id] = key
 
-    def sync(self):
-        for key, command in self.commands_to_run:
-            self.log.info("Executing command with key %s: %s", key, command)
+    def check_state(self):
+        ids = list(self.commands_to_check.keys())
+        statuses = reindex_job_status_by_job_name(get_jobs_status(job_name=ids))
+        logger.debug('statuses found are %s', statuses)
+        logger.debug('commands to check are %s', self.commands_to_check)
 
-            try:
-                subprocess.check_call(command, close_fds=True)
+        completed_jobs = []
+        for unique_id, key in self.commands_to_check.items():
+            status = statuses[unique_id]
+            if status.status_code == 'CD':
                 self.change_state(key, State.SUCCESS)
-            except subprocess.CalledProcessError as e:
+                completed_jobs.append(unique_id)
+            elif status.status_code == 'F':
                 self.change_state(key, State.FAILED)
-                self.log.error("Failed to execute task %s.", str(e))
+                completed_jobs.append(unique_id)
+            elif status.status_code in ('CG', 'R'):
+                self.change_state(key, State.RUNNING)
+            elif status.status_code == 'PD':
+                self.change_state(key, State.SCHEDULED)
+
+        for unique_id in completed_jobs:
+            if unique_id in self.commands_to_check:
+                self.commands_to_check.pop(unique_id)
+            else:
+                logger.error('id %s missing in %s', unique_id, self.commands_to_check)
+
+    def trigger_tasks(self, open_slots):
+        self.check_state()
+        super().trigger_tasks(open_slots)
 
-        self.commands_to_run = []
+    def sync(self):
+        pass
 
     def end(self):
         self.heartbeat()
-- 
GitLab