Skip to content
Snippets Groups Projects
Commit 6aa152a1 authored by Mattia Mancini's avatar Mattia Mancini
Browse files

code refactor

parent 78180d5f
Branches
No related tags found
No related merge requests found
from typing import List, Tuple, Union from typing import List, Tuple, Union, Dict
from .jobs import SlurmJobStatus from .jobs import SlurmJobStatus
import subprocess from subprocess import run as run_process
__SQUEUE_PATH='squeue'
__SRUN_PATH='srun'
class EmptyListException(Exception): class EmptyListException(Exception):
...@@ -21,8 +24,8 @@ def __list_contains_valid_ids(ids_list): ...@@ -21,8 +24,8 @@ def __list_contains_valid_ids(ids_list):
raise EmptyListException() raise EmptyListException()
def slurm_get_processes_status(job_ids: Union[List, Tuple] = ()): def __compose_get_processes_status_cmd(job_ids: Union[List, Tuple] = ()):
cmd = ['squeue', '--states=all', '-h'] cmd = ['--states=all', '-h']
fmt = '%i;%j;%t;%T;%r' fmt = '%i;%j;%t;%T;%r'
cmd += ['--format=%s' % fmt] cmd += ['--format=%s' % fmt]
if job_ids: if job_ids:
...@@ -30,13 +33,30 @@ def slurm_get_processes_status(job_ids: Union[List, Tuple] = ()): ...@@ -30,13 +33,30 @@ def slurm_get_processes_status(job_ids: Union[List, Tuple] = ()):
else: else:
cmd += ['-a'] cmd += ['-a']
process_status = subprocess.run(cmd) return cmd
def __execute_squeue(args):
process_status = run_process([__SQUEUE_PATH] + args)
if process_status.returncode > 0: if process_status.returncode > 0:
raise SlurmCallError() raise SlurmCallError()
output = process_status.stdout output = process_status.stdout
return output
def __parse_squeue_output(squeue_output) -> List[SlurmJobStatus]:
'''
Parses the output of squeue
e.g.
123;test_job;CD;COMPLETED;None
:param squeue_output:
:return:
'''
jobs_found = [] jobs_found = []
if process_status.stdout: if squeue_output:
for line in output.split('\n'): for line in squeue_output.split('\n'):
if not line:
continue
job_id, job_name, status_code, status, reason = line.split(';') job_id, job_name, status_code, status, reason = line.split(';')
jobs_found.append(SlurmJobStatus(job_id=job_id, jobs_found.append(SlurmJobStatus(job_id=job_id,
job_name=job_name, job_name=job_name,
...@@ -45,3 +65,14 @@ def slurm_get_processes_status(job_ids: Union[List, Tuple] = ()): ...@@ -45,3 +65,14 @@ def slurm_get_processes_status(job_ids: Union[List, Tuple] = ()):
reason=reason)) reason=reason))
return jobs_found return jobs_found
def __map_job_status_per_jobid(job_status_list: List[SlurmJobStatus]) -> Dict[str, SlurmJobStatus]:
return {job_status.job_id: job_status for job_status in job_status_list}
def get_jobs_status(job_ids: Union[List, Tuple] = ()):
args = __compose_get_processes_status_cmd(job_ids)
output = __execute_squeue(args)
parsed_output = __parse_squeue_output(output)
return __map_job_status_per_jobid(parsed_output)
\ No newline at end of file
from airflow.plugins_manager import AirflowPlugin from airflow.plugins_manager import AirflowPlugin
from airflow.executors.base_executor import BaseExecutor from airflow.executors.base_executor import BaseExecutor
from airflow.utils.state import State from airflow.utils.state import State
from slurm_cli.slurm_control import get_jobs_status
import subprocess import subprocess
# Will show up under airflow.executors.slurm.SlurmExecutor # Will show up under airflow.executors.slurm.SlurmExecutor
...@@ -40,5 +41,5 @@ class SlurmExecutorPlugin(AirflowPlugin): ...@@ -40,5 +41,5 @@ class SlurmExecutorPlugin(AirflowPlugin):
import sys import sys
if __name__=='__main__': if __name__=='__main__':
print('output', _slurm_get_processes_status()) print('output', get_jobs_status())
print('output', _slurm_get_processes_status(sys.argv[1:])) print('output', get_jobs_status(sys.argv[1:]))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment