Commit 182106d4 authored by mancini

implemented scheduling workflow

parent e82bd3ed
from airflow.plugins_manager import AirflowPlugin
from airflow.executors.base_executor import BaseExecutor
from airflow.utils.state import State
from slurm_cli.slurm_control import get_jobs_status, run_job
import logging
import uuid

logger = logging.getLogger(__name__)


def reindex_job_status_by_job_name(job_list):
    """Re-key the mapping returned by get_jobs_status by Slurm job name."""
    return {job_status.job_name: job_status for job_status in job_list.values()}


# Will show up under airflow.executors.slurm.SlurmExecutor
class SlurmExecutor(BaseExecutor):

    def __init__(self):
        super().__init__()
        # Maps the Slurm job name of each submitted task to its Airflow task key.
        self.commands_to_check = {}

    def execute_async(self, key, command, queue=None, executor_config=None):
        logger.debug('execute_async called for %s', key)
        # Build a unique Slurm job name from the dag_id plus a UUID.
        unique_id = str(key[0]) + str(uuid.uuid1())
        # Airflow's 'default' queue means "no specific Slurm partition".
        queue = queue if queue != 'default' else None
        logger.debug('submitting job %s on queue %s', key, queue)
        run_job(cmd=command, queue=queue, task_name=unique_id)
        self.commands_to_check[unique_id] = key

    def check_state(self):
        ids = list(self.commands_to_check.keys())
        statuses = reindex_job_status_by_job_name(get_jobs_status(job_name=ids))
        logger.debug('statuses found are %s', statuses)
        logger.debug('commands to check are %s', self.commands_to_check)
        completed_jobs = []
        for unique_id, key in self.commands_to_check.items():
            status = statuses.get(unique_id)
            if status is None:
                # Job not (yet) reported by Slurm; check again on the next poll.
                continue
            # Slurm compact state codes: CD completed, F failed,
            # CG completing, R running, PD pending.
            if status.status_code == 'CD':
                self.change_state(key, State.SUCCESS)
                completed_jobs.append(unique_id)
            elif status.status_code == 'F':
                self.change_state(key, State.FAILED)
                completed_jobs.append(unique_id)
            elif status.status_code in ('CG', 'R'):
                self.change_state(key, State.RUNNING)
            elif status.status_code == 'PD':
                self.change_state(key, State.SCHEDULED)
        # Forget jobs that reached a terminal state.
        for unique_id in completed_jobs:
            if unique_id in self.commands_to_check:
                self.commands_to_check.pop(unique_id)
            else:
                logger.error('id %s missing in %s', unique_id, self.commands_to_check)

    def trigger_tasks(self, open_slots):
        # Refresh Slurm job states before the base executor schedules new tasks.
        self.check_state()
        super().trigger_tasks(open_slots)

    def sync(self):
        # Polling happens in check_state(), called from trigger_tasks().
        pass

    def end(self):
        self.heartbeat()
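
The slurm_cli.slurm_control module imported above is not part of this commit. The executor only relies on run_job submitting a command under a given job name (optionally to a specific queue/partition) and on get_jobs_status returning, for a list of job names, objects exposing job_name and a Slurm compact state code. A minimal sketch of that assumed interface, built on the standard sbatch and squeue tools, could look like the following; the real module may differ, and everything here is illustrative:

# Illustrative sketch only: the real slurm_cli.slurm_control is not shown in
# this commit. It mimics the interface the executor uses, run_job() and
# get_jobs_status(), by wrapping the sbatch/squeue command line tools.
import subprocess
from collections import namedtuple

JobStatus = namedtuple('JobStatus', ['job_name', 'status_code'])


def run_job(cmd, task_name, queue=None):
    # Submit the command as a Slurm batch job named task_name.
    args = ['sbatch', '--job-name', task_name]
    if queue is not None:
        args += ['--partition', queue]  # Airflow queue mapped to a Slurm partition
    command = ' '.join(cmd) if isinstance(cmd, (list, tuple)) else cmd
    args += ['--wrap', command]
    subprocess.check_call(args)


def get_jobs_status(job_name):
    # Return a dict of JobStatus entries for the given job names, using the
    # compact state codes (%t) the executor checks: PD, R, CG, CD, F, ...
    if not job_name:
        return {}
    out = subprocess.check_output(
        ['squeue', '--noheader', '--states=all', '--format=%j %t',
         '--name', ','.join(job_name)],
        text=True)
    statuses = {}
    for i, line in enumerate(out.splitlines()):
        name, code = line.rsplit(' ', 1)
        statuses[i] = JobStatus(job_name=name, status_code=code)
    return statuses

In practice, finished jobs only stay visible to squeue for a short time (MinJobAge), so a real implementation may also need sacct; the sketch sticks to squeue because the executor matches its compact state codes.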
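
The file imports AirflowPlugin, but the plugin class itself is not visible in this diff. With Airflow 1.10-style plugin registration, exposing the executor under airflow.executors.slurm.SlurmExecutor (as the comment in the module says) would look roughly like this; the class name is illustrative:

# Illustrative plugin registration (not shown in this commit): under Airflow
# 1.10, executors listed on a plugin are exposed as airflow.executors.<name>.
class SlurmPlugin(AirflowPlugin):
    name = 'slurm'
    executors = [SlurmExecutor]

The executor is then selected through the executor option in the [core] section of airflow.cfg; the exact value format depends on how the Airflow version in use integrates plugin executors.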