diff --git a/tangostationcontrol/tangostationcontrol/toolkit/retriever.py b/tangostationcontrol/tangostationcontrol/toolkit/retriever.py index 190c014da690d3e3542f26f59fc84a4596573111..2905abee59ff8a080715c203d3f9a1667d53e6ab 100644 --- a/tangostationcontrol/tangostationcontrol/toolkit/retriever.py +++ b/tangostationcontrol/tangostationcontrol/toolkit/retriever.py @@ -1,8 +1,9 @@ #! /usr/bin/env python3 from tango import DeviceProxy, AttributeProxy -from tangostationcontrol.toolkit.archiver import * +from tangostationcontrol.toolkit.archiver import split_tango_name +from abc import ABC, abstractmethod from datetime import datetime, timedelta from sqlalchemy import create_engine, and_ from sqlalchemy.orm import sessionmaker @@ -10,15 +11,10 @@ from sqlalchemy.orm.exc import NoResultFound import importlib import numpy -class Retriever(): +class Retriever(ABC): """ The Retriever class implements retrieve operations on a given DBMS """ - def __init__(self, cm_name: str = 'archiving/hdbppts/confmanager01'): - self.cm_name = cm_name - self.session, self.dbms = self.connect_to_archiving_db() - self.ab = self.set_archiver_base() - def get_db_credentials(self): """ @@ -34,33 +30,14 @@ class Retriever(): pw = str([s for s in config_list if "password" in s][0].split('=')[1]) return host,dbname,port,user,pw - def connect_to_archiving_db(self): + def create_session(self,libname:str,user:str,pw:str,host:str,port:str,dbname:str): """ - Returns a session to a MySQL DBMS using default credentials. + Returns a session to a DBMS using default credentials. """ - host,dbname,port,user,pw = self.get_db_credentials() - # Set sqlalchemy library connection - if host=='archiver-maria-db': - libname = 'mysql+pymysql' - dbms = 'mysql' - elif host=='archiver-timescale': - libname = 'postgresql+psycopg2' - dbms = 'postgres' - else: - raise ValueError(f"Invalid hostname: {host}") - engine = create_engine(libname+'://'+user+':'+pw+'@'+host+':'+port+'/'+dbname) + connection_string = f"{libname}://{user}:{pw}@{host}:{port}/{dbname}" + engine = create_engine(connection_string) Session = sessionmaker(bind=engine) - return Session(),dbms - - def set_archiver_base(self): - """ - Sets the right mapper class following the DBMS connection - """ - if self.dbms == 'postgres': - ab = importlib.import_module('.archiver_base_ts', package=__package__) - elif self.dbms == 'mysql': - ab = importlib.import_module('.archiver_base_mysql', package=__package__) - return ab + return Session def get_all_archived_attributes(self): """ @@ -222,3 +199,56 @@ class Retriever(): #masked_values = np.multiply(temp_array_values,mask_array_values) masked_values = numpy.ma.masked_array(temp_array_values,mask=numpy.invert(mask_array_values.astype(bool))) return masked_values, mask_values, temp_values + +class Retriever_MySQL(Retriever): + + def __init__(self, cm_name: str = 'archiving/hdbpp/confmanager01'): + self.cm_name = cm_name + self.session, self.dbms = self.connect_to_archiving_db() + self.ab = self.set_archiver_base() + + def connect_to_archiving_db(self): + """ + Returns a session to a MySQL DBMS using default credentials. + """ + host,dbname,port,user,pw = super().get_db_credentials() + # Set sqlalchemy library connection + if host=='archiver-maria-db': + libname = 'mysql+pymysql' + else: + raise ValueError(f"Invalid hostname: {host}") + Session = super().create_session(libname,user,pw,host,port,dbname) + return Session() + + def set_archiver_base(self): + """ + Sets the right mapper class following the DBMS connection + """ + return importlib.import_module('.archiver_base_mysql', package=__package__) + +class Retriever_Timescale(Retriever): + + def __init__(self, cm_name: str = 'archiving/hdbppts/confmanager01'): + self.cm_name = cm_name + self.session, self.dbms = self.connect_to_archiving_db() + self.ab = self.set_archiver_base() + + def connect_to_archiving_db(self): + """ + Returns a session to a MySQL DBMS using default credentials. + """ + host,dbname,port,user,pw = super().get_db_credentials() + # Set sqlalchemy library connection + if host=='archiver-timescale': + libname = 'postgresql+psycopg2' + else: + raise ValueError(f"Invalid hostname: {host}") + Session = super().create_session(libname,user,pw,host,port,dbname) + return Session() + + def set_archiver_base(self): + """ + Sets the right mapper class following the DBMS connection + """ + return importlib.import_module('.archiver_base_ts', package=__package__) + \ No newline at end of file