Select Git revision
dish_wg_sweep_data_path.py
-
Pieter Donker authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
tmss_http_rest_client.py 7.26 KiB
import logging
logger = logging.getLogger(__file__)
import requests
import os
import json
from datetime import datetime
from lofar.common.datetimeutils import formatDatetime
# usage example:
#
# with TMSSsession('paulus', 'pauluspass', 'localhost', 8000) as tmsssession:
#     response = tmsssession.session.get(url='http://localhost:8000/api/task_draft/')
#     print(response)
class TMSSsession(object):
    '''HTTP REST client session for the TMSS api at http://<host>:<port>/api.

    Authenticates either via the OIDC login flow (OPENID) or via HTTP basic auth (BASICAUTH).
    Use it as a context manager (which calls open()/close()), or call open()/close() explicitly.
    '''

    OPENID = "openid"
    BASICAUTH = "basicauth"

    def __init__(self, username, password, host, port: int=8000, authentication_method=OPENID):
        '''Store the credentials and build the api base url "http://<host>:<port>/api".
        No request is sent until open() is called.'''
        self.session = requests.session()
        self.username = username
        self.password = password
        self.base_url = "http://%s:%d/api" % (host, port)
        self.authentication_method = authentication_method

    @staticmethod
    def create_from_dbcreds_for_ldap(dbcreds_name: str=None):
        '''Factory method to create a TMSSsession object which uses the credentials in the ~/.lofar/dbcredentials/<dbcreds_name>.ini file
        (mis)use the DBCredentials to get a url with user/pass for tmss
        the contents below are used to construct a url like this: http://localhost:8000/api
            [database:TMSS]
            host=localhost
            user=<username>
            password=<password>
            type=http
            port=8000

        :param dbcreds_name: name of the credentials entry. When None, it is taken from the
                             TMSS_CLIENT_DBCREDENTIALS environment variable, falling back to "TMSSClient".
        :return: a TMSSsession using BASICAUTH with the user/password/host/port from the credentials.
        '''
        if dbcreds_name is None:
            dbcreds_name = os.environ.get("TMSS_CLIENT_DBCREDENTIALS", "TMSSClient")

        # imported here so the module can be used without the dbcredentials package
        # unless this factory is actually called
        from lofar.common.dbcredentials import DBCredentials
        dbcreds = DBCredentials().get(dbcreds_name)
        return TMSSsession(username=dbcreds.user, password=dbcreds.password,
                           host=dbcreds.host,
                           port=dbcreds.port,
                           authentication_method=TMSSsession.BASICAUTH)

    def __enter__(self):
        '''open (and login) the session, and return this TMSSsession for use within the context'''
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        '''logout and close the session. Returns None, so exceptions raised in the context propagate.'''
        self.close()

    def open(self):
        '''open the request session and login

        :raises ValueError: if the OIDC login page reports that the username/password are not correct.
        :raises ConnectionError: if the OIDC login does not end with a 200 response.
        '''
        self.session.__enter__()
        # NOTE(review): certificate verification is switched off for every request of this session.
        self.session.verify = False

        if self.authentication_method == self.OPENID:
            # get authentication page of OIDC through TMSS redirect
            response = self.session.get(self.base_url.replace('/api', '/oidc/authenticate/'), allow_redirects=True)
            csrftoken = self.session.cookies['csrftoken']

            # post user credentials to the login page (the url we were redirected to), also pass csrf token
            data = {'username': self.username, 'password': self.password, 'csrfmiddlewaretoken': csrftoken}
            response = self.session.post(url=response.url, data=data, allow_redirects=True)

            # raise when sth went wrong
            if "The username and/or password you specified are not correct" in response.content.decode('utf8'):
                raise ValueError("The username and/or password you specified are not correct")
            if response.status_code != 200:
                raise ConnectionError(response.content.decode('utf8'))
        elif self.authentication_method == self.BASICAUTH:
            self.session.auth = (self.username, self.password)

    def close(self):
        '''logout (best effort) and close the request session'''
        try:
            # logout user
            self.session.get(self.base_url + '/logout/', allow_redirects=True)
        except Exception as e:
            # a failing logout should not prevent closing the session, but should not go unnoticed either
            logger.warning("could not logout from %s: %s", self.base_url, e)
        finally:
            # always release the underlying connections, even when the logout failed
            self.session.close()

    def set_subtask_status(self, subtask_id: int, status: str) -> requests.Response:
        '''set the status for the given subtask by PATCHing its 'state' to the subtask_state url for `status`.

        :return: the raw response of the PATCH request (not checked for success).
        '''
        result = self.session.patch(url='%s/subtask/%s/' % (self.base_url, subtask_id),
                                    json={'state': "%s/subtask_state/%s/" % (self.base_url, status)})
        return result

    def get_subtask_parset(self, subtask_id) -> str:
        '''get the lofar parameterset (as text) for the given subtask

        :raises Exception: when the response status code is not 2xx.
        '''
        result = self.session.get(url='%s/subtask/%s/parset' % (self.base_url, subtask_id))
        if 200 <= result.status_code < 300:
            return result.content.decode('utf-8')
        raise Exception("Could not get parameterset for subtask %s.\nResponse: %s" % (subtask_id, result))

    def get_subtask(self, subtask_id: int) -> dict:
        '''get the subtask as dict for the given subtask'''
        path = 'subtask/%s' % (subtask_id,)
        return self.get_path_as_json_object(path)

    def get_subtasks(self, state: str=None,
                     start_time_less_then: datetime=None, start_time_greater_then: datetime=None,
                     stop_time_less_then: datetime = None, stop_time_greater_then: datetime = None) -> list:
        '''get subtasks filtered by the given parameters.

        Each given parameter is sent as a query parameter on GET <base_url>/subtask/
        (datetimes are formatted with formatDatetime). Parameters that are None are not sent.

        NOTE(review): get_subtask_template reads 'count'/'results' from a list response, so this
        likely returns a paginated dict rather than a plain list -- confirm against the server.
        '''
        clauses = {}
        if state is not None:
            clauses["state__value"] = state
        if start_time_less_then is not None:
            # fixed: the key used to be "start_time__lt=" (stray '='), so this filter was never applied
            clauses["start_time__lt"] = formatDatetime(start_time_less_then)
        if start_time_greater_then is not None:
            clauses["start_time__gt"] = formatDatetime(start_time_greater_then)
        if stop_time_less_then is not None:
            clauses["stop_time__lt"] = formatDatetime(stop_time_less_then)
        if stop_time_greater_then is not None:
            clauses["stop_time__gt"] = formatDatetime(stop_time_greater_then)

        return self.get_path_as_json_object("subtask", clauses)

    def get_path_as_json_object(self, path: str, params: dict=None) -> dict:
        '''get resource at <base_url>/<path>/, interpret it as json, and return it as a native object

        :param params: optional query parameters; the given dict is not modified.
        '''
        full_url = '%s/%s/' % (self.base_url, path)
        return self.get_url_as_json_object(full_url, params=params)

    def get_url_as_json_object(self, full_url: str, params: dict=None) -> dict:
        '''get resource at the given full url (including http://<base_url>), interpret it as json, and return it as a native object

        :param params: optional query parameters; the given dict is not modified.
        :raises Exception: when the response status code is not 2xx.
        '''
        # work on a copy: never mutate the caller's dict (or a shared default)
        params = dict(params) if params else {}

        # only request json explicitly when neither the url nor the params already do so
        if "format=json" not in full_url and params.get("format") != "json":
            params['format'] = 'json'

        result = self.session.get(url=full_url, params=params)
        if 200 <= result.status_code < 300:
            return json.loads(result.content.decode('utf-8'))
        raise Exception("Could not get %s.\nResponse: %s" % (full_url, result))

    def get_subtask_template(self, name: str, version: str=None) -> dict:
        '''get the subtask_template as dict for the given name (and version)

        :return: the single matching template, or None when there is no match.
        :raises ValueError: when more than one template matches.
        '''
        clauses = {}
        if name is not None:
            clauses["name"] = name
        if version is not None:
            clauses["version"] = version
        result = self.get_path_as_json_object('subtask_template', clauses)
        if result['count'] > 1:
            raise ValueError("Found more then one SubtaskTemplate for clauses: %s" % (clauses,))
        elif result['count'] == 1:
            return result['results'][0]
        return None

    def specify_observation_task(self, task_id: int) -> str:
        """specify observation for the given task by just doing a REST API call

        :return: the decoded response body.
        :raises Exception: when the response status code is not 2xx.
        """
        # base_url already ends with '/api'; the old '%s/api/task/...' produced '/api/api/task/...'
        result = self.session.get(url='%s/task/%s/specify_observation' % (self.base_url, task_id))
        if 200 <= result.status_code < 300:
            return result.content.decode('utf-8')
        raise Exception("Could not specify observation for task %s.\nResponse: %s" % (task_id, result))