Commit 182106d4 authored by Mattia Mancini's avatar Mattia Mancini

implemented scheduling workflow

parent e82bd3ed
from airflow.plugins_manager import AirflowPlugin
from airflow.executors.base_executor import BaseExecutor
from airflow.utils.state import State
from slurm_cli.slurm_control import get_jobs_status
import subprocess
from slurm_cli.slurm_control import get_jobs_status, run_job
import logging
import uuid
logger = logging.getLogger(__name__)
def reindex_job_status_by_job_name(job_list):
return {job_status.job_name: job_status for job_status in job_list.values()}
# Will show up under airflow.executors.slurm.SlurmExecutor
class SlurmExecutor(BaseExecutor):
def __init__(self):
self.commands_to_run = []
self.commands_to_check = {}
def execute_async(self, key, command, queue=None, executor_config=None):
print("execute async called")
self.commands_to_run.append((key, command,))
def trigger_tasks(self, open_slots):
print('trigger tasks called', open_slots)
unique_id = str(key[0]) + str(uuid.uuid1())
queue = queue if queue != 'default' else None
logging.debug('submitting job %s on queue %s', key, queue)
run_job(cmd=command, queue=queue, task_name=unique_id)
self.commands_to_check[unique_id] = key
def sync(self):
for key, command in self.commands_to_run:"Executing command with key %s: %s", key, command)
def check_state(self):
ids = list(self.commands_to_check.keys())
statuses = reindex_job_status_by_job_name(get_jobs_status(job_name=ids))
logger.debug('statuses found are %s', statuses)
logger.debug('commands to check are %s', self.commands_to_check)
subprocess.check_call(command, close_fds=True)
completed_jobs = []
for unique_id, key in self.commands_to_check.items():
status = statuses[unique_id]
if status.status_code == 'CD':
self.change_state(key, State.SUCCESS)
except subprocess.CalledProcessError as e:
elif status.status_code == 'F':
self.change_state(key, State.FAILED)
self.log.error("Failed to execute task %s.", str(e))
elif status.status_code in ('CG', 'R'):
self.change_state(key, State.RUNNING)
elif status.status_code == 'PD':
self.change_state(key, State.SCHEDULED)
for unique_id in completed_jobs:
if unique_id in self.commands_to_check:
logger.error('id %s missing in %s', unique_id, self.commands_to_check)
def trigger_tasks(self, open_slots):
self.commands_to_run = []
def sync(self):
def end(self):
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment