diff --git a/LCS/PyCommon/dbcredentials.py b/LCS/PyCommon/dbcredentials.py index 1d129f9b778a6f4320af726895091afc3bfe891e..e82abd8f804ee4369f33cd824452abddc69b0335 100644 --- a/LCS/PyCommon/dbcredentials.py +++ b/LCS/PyCommon/dbcredentials.py @@ -35,121 +35,121 @@ __all__ = ["Credentials", "DBCredentials", "options_group", "parse_options"] environ = os.environ try: - # Throws a KeyError if user info is not found in /etc/passwd (f.e. in Docker environments) - user_info = pwd.getpwuid(os.getuid()) + # Throws a KeyError if user info is not found in /etc/passwd (f.e. in Docker environments) + user_info = pwd.getpwuid(os.getuid()) - environ.setdefault("HOME", user_info.pw_dir) - environ.setdefault("USER", user_info.pw_name) + environ.setdefault("HOME", user_info.pw_dir) + environ.setdefault("USER", user_info.pw_name) except KeyError: - pass + pass def findfiles(pattern): - """ Returns a list of files matched by `pattern'. + """ Returns a list of files matched by `pattern'. The pattern can include environment variables using the {VAR} notation. - """ - try: - return glob(pattern.format(**environ)) - except KeyError: - return [] + """ + try: + return glob(pattern.format(**environ)) + except KeyError: + return [] class Credentials: - def __init__(self): - # Flavour of database (postgres, mysql, oracle, sqlite) - self.type = "postgres" + def __init__(self): + # Flavour of database (postgres, mysql, oracle, sqlite) + self.type = "postgres" - # Connection information (port 0 = use default) - self.host = "localhost" - self.port = 0 + # Connection information (port 0 = use default) + self.host = "localhost" + self.port = 0 - # Authentication - self.user = environ["USER"] - self.password = "" + # Authentication + self.user = environ["USER"] + self.password = "" - # Database selection - self.database = "" + # Database selection + self.database = "" - # All key-value pairs found in the config - self.config = {} + # All key-value pairs found in the config + self.config = {} - def __str__(self): - return "db={database} addr={host}:{port} auth={user}:{password} type={type}".format(**self.__dict__) + def __str__(self): + return "db={database} addr={host}:{port} auth={user}:{password} type={type}".format(**self.__dict__) - def stringWithHiddenPassword(self): - return "db={database} addr={host}:{port} auth={user}:XXXXXX type={type}".format(**self.__dict__) + def stringWithHiddenPassword(self): + return "db={database} addr={host}:{port} auth={user}:XXXXXX type={type}".format(**self.__dict__) - def pg_connect_options(self): - """ - Returns a dict of options to provide to PyGreSQL's pg.connect function. Use: + def pg_connect_options(self): + """ + Returns a dict of options to provide to PyGreSQL's pg.connect function. Use: - conn = pg.connect(**dbcreds.pg_connect_options()) - """ - return { - "host": self.host, - "port": self.port or -1, + conn = pg.connect(**dbcreds.pg_connect_options()) + """ + return { + "host": self.host, + "port": self.port or -1, - "user": self.user, - "passwd": self.password, + "user": self.user, + "passwd": self.password, - "dbname": self.database, - } + "dbname": self.database, + } - def psycopg2_connect_options(self): - """ - Returns a dict of options to provide to PsycoPG2's psycopg2.connect function. Use: + def psycopg2_connect_options(self): + """ + Returns a dict of options to provide to PsycoPG2's psycopg2.connect function. Use: - conn = psycopg2.connect(**dbcreds.psycopg2_connect_options()) - """ - return { - "host": self.host, - "port": self.port or None, + conn = psycopg2.connect(**dbcreds.psycopg2_connect_options()) + """ + return { + "host": self.host, + "port": self.port or None, - "user": self.user, - "password": self.password, + "user": self.user, + "password": self.password, - "database": self.database, - } + "database": self.database, + } + def mysql_connect_options(self): + """ + Returns a dict of options to provide to python's mysql.connector.connect function. Use: - def mysql_connect_options(self): - """ - Returns a dict of options to provide to python's mysql.connector.connect function. Use: + from mysql import connector + conn = connector.connect(**dbcreds.mysql_connect_options()) + """ + options = {"host": self.host, + "user": self.user, + "passwd": self.password, + "database": self.database} - from mysql import connector - conn = connector.connect(**dbcreds.mysql_connect_options()) - """ - options = { "host": self.host, - "user": self.user, - "passwd": self.password, - "database": self.database } + if self.port: + options["port"] = self.port - if self.port: - options["port"] = self.port + return options - return options class DBCredentials: - NoSectionError = NoSectionError + NoSectionError = NoSectionError - def __init__(self, filepatterns=None): - self.filepatterns = filepatterns if filepatterns is not None else [ - "{LOFARROOT}/etc/dbcredentials/*.ini", - "{HOME}/.lofar/dbcredentials/*.ini", - ] - self.read_config_from_files() + def __init__(self, filepatterns=None): + self.filepatterns = filepatterns if filepatterns is not None else [ + "{LOFARROOT}/etc/dbcredentials/*.ini", + "{HOME}/.lofar/dbcredentials/*.ini", + ] + self.read_config_from_files() - def read_config_from_files(self, filepatterns=None): - """ - Read database credentials from all configuration files matched by any of the patterns. + def read_config_from_files(self, filepatterns=None): + """ + Read database credentials from all configuration files matched by any of the patterns. - By default, the following files are read: + By default, the following files are read: $LOFARROOT/etc/dbcredentials/*.ini ~/.lofar/dbcredentials/*.ini - The configuration files allow for any number of database sections: + The configuration files allow for any number of database sections: [database:OTDB] type = postgres # postgres, mysql, oracle, sqlite @@ -159,205 +159,205 @@ class DBCredentials: password = boskabouter database = LOFAR_4 - These database credentials can subsequently be queried under their - symbolic name ("OTDB" in the example). - """ - if filepatterns is not None: - self.filepatterns = filepatterns - - self.files = sum([findfiles(p) for p in self.filepatterns],[]) - - # make sure the files are mode 600 to hide passwords - for file in self.files: - if oct(stat(file).st_mode & 0o777) != '0o600': - logger.info('Changing permissions of %s to 600' % file) - try: - chmod(file, 0o600) - except Exception as e: - logger.error('Error: Could not change permissions on %s: %s' % (file, str(e))) - - #read the files into config - self.config = ConfigParser() - self.config.read(self.files) - - def create_default_file(self, database): - """ - creates a dbcredentials file with defaults in ~/.lofar/dbcredentials/<database>.ini - :param database: name of the database/file - """ - extensions = list(set(os.path.splitext(pat)[1] for pat in self.filepatterns)) - if extensions: - #pick first extension - extension = extensions[0] - new_path = os.path.join(user_info.pw_dir, '.lofar', 'dbcredentials', database+extension) - if not os.path.exists(os.path.dirname(new_path)): - os.makedirs(os.path.dirname(new_path)) - with open(new_path, 'w+') as new_file: - new_file.write("[database:%s]\nhost=localhost\nuser=%s\npassword=unknown\ntype=unknown\nport=0\ndatabase=%s" - % (database,user_info.pw_name,database)) - logger.info("Created default dbcredentials file for database=%s at %s", database, new_path) - logger.warning(" *** Please fill in the proper credentials for database=%s in new empty credentials file: '%s' ***", database, new_path) - - def get(self, database): - """ - Return credentials for a given database. - """ - # create default credentials - creds = Credentials() - - try: - # read configuration (can throw NoSectionError) - d = dict(self.config.items(self._section(database))) - except NoSectionError: - # create defaults file, and reload - self.create_default_file(database) - self.read_config_from_files() - # re-read configuration now that we have created a new file with defaults - d = dict(self.config.items(self._section(database))) - - # save the full config to support custom fields - creds.config = d - - # parse and convert config information - if "host" in d: creds.host = d["host"] - if "port" in d: creds.port = int(d["port"] or 0) - - if "user" in d: creds.user = d["user"] - if "password" in d: creds.password = d["password"] - - if "database" in d: creds.database = d["database"] - - if "type" in d: creds.type = d["type"] - - return creds - - - def set(self, database, credentials): - """ - Add or overwrite credentials for a given database. - """ - section = self._section(database) - - # create section if needed - try: - self.config.add_section(section) - except DuplicateSectionError: - pass - - # set or override credentials - self.config.set(section, "type", credentials.type) - self.config.set(section, "host", credentials.host) - self.config.set(section, "port", str(credentials.port)) - self.config.set(section, "user", credentials.user) - self.config.set(section, "password", credentials.password) - self.config.set(section, "database", credentials.database) - - def list(self): - """ - Return a list of databases for which credentials are available. - """ - sections = self.config.sections() - return [s[9:] for s in sections if s.startswith("database:")] - - - def _section(self, database): - return "database:%s" % (database,) + These database credentials can subsequently be queried under their + symbolic name ("OTDB" in the example). + """ + if filepatterns is not None: + self.filepatterns = filepatterns + + self.files = sum([findfiles(p) for p in self.filepatterns], []) + + # make sure the files are mode 600 to hide passwords + for file in self.files: + if oct(stat(file).st_mode & 0o777) != '0o600': + logger.info('Changing permissions of %s to 600' % file) + try: + chmod(file, 0o600) + except Exception as e: + logger.error('Error: Could not change permissions on %s: %s' % (file, str(e))) + + # read the files into config + self.config = ConfigParser() + self.config.read(self.files) + + def create_default_file(self, database): + """ + creates a dbcredentials file with defaults in ~/.lofar/dbcredentials/<database>.ini + :param database: name of the database/file + """ + extensions = list(set(os.path.splitext(pat)[1] for pat in self.filepatterns)) + if extensions: + # pick first extension + extension = extensions[0] + new_path = os.path.join(user_info.pw_dir, '.lofar', 'dbcredentials', database + extension) + if not os.path.exists(os.path.dirname(new_path)): + os.makedirs(os.path.dirname(new_path)) + with open(new_path, 'w+') as new_file: + new_file.write( + "[database:%s]\nhost=localhost\nuser=%s\npassword=unknown\ntype=unknown\nport=0\ndatabase=%s" + % (database, user_info.pw_name, database)) + logger.info("Created default dbcredentials file for database=%s at %s", database, new_path) + logger.warning( + " *** Please fill in the proper credentials for database=%s in new empty credentials file: '%s' ***", + database, new_path) + + def get(self, database): + """ + Return credentials for a given database. + """ + # create default credentials + creds = Credentials() + + try: + # read configuration (can throw NoSectionError) + d = dict(self.config.items(self._section(database))) + except NoSectionError: + # create defaults file, and reload + self.create_default_file(database) + self.read_config_from_files() + # re-read configuration now that we have created a new file with defaults + d = dict(self.config.items(self._section(database))) + + # save the full config to support custom fields + creds.config = d + + # parse and convert config information + if "host" in d: creds.host = d["host"] + if "port" in d: creds.port = int(d["port"] or 0) + + if "user" in d: creds.user = d["user"] + if "password" in d: creds.password = d["password"] + + if "database" in d: creds.database = d["database"] + + if "type" in d: creds.type = d["type"] + + return creds + + def set(self, database, credentials): + """ + Add or overwrite credentials for a given database. + """ + section = self._section(database) + + # create section if needed + try: + self.config.add_section(section) + except DuplicateSectionError: + pass + + # set or override credentials + self.config.set(section, "type", credentials.type) + self.config.set(section, "host", credentials.host) + self.config.set(section, "port", str(credentials.port)) + self.config.set(section, "user", credentials.user) + self.config.set(section, "password", credentials.password) + self.config.set(section, "database", credentials.database) + + def list(self): + """ + Return a list of databases for which credentials are available. + """ + sections = self.config.sections() + return [s[9:] for s in sections if s.startswith("database:")] + + def _section(self, database): + return "database:%s" % (database,) def options_group(parser, default_credentials=""): - """ + """ Return an optparse.OptionGroup containing command-line parameters for database connections and authentication. - """ - group = OptionGroup(parser, "Database Credentials") - group.add_option("-D", "--database", dest="dbName", type="string", default="", - help="Name of the database") - group.add_option("-H", "--host", dest="dbHost", type="string", default="", - help="Hostname of the database server") - group.add_option("-p", "--port", dest="dbPort", type="string", default="", - help="Port number of the database server") - group.add_option("-U", "--user", dest="dbUser", type="string", default="", - help="User of the database server") - group.add_option("-P", "--password", dest="dbPassword", type="string", default="", - help="Password of the database server") - group.add_option("-C", "--dbcredentials", dest="dbcredentials", type="string", default=default_credentials, - help="Name of database credential set to use [default=%default]") - - return group + """ + group = OptionGroup(parser, "Database Credentials") + group.add_option("-D", "--database", dest="dbName", type="string", default="", + help="Name of the database") + group.add_option("-H", "--host", dest="dbHost", type="string", default="", + help="Hostname of the database server") + group.add_option("-p", "--port", dest="dbPort", type="string", default="", + help="Port number of the database server") + group.add_option("-U", "--user", dest="dbUser", type="string", default="", + help="User of the database server") + group.add_option("-P", "--password", dest="dbPassword", type="string", default="", + help="Password of the database server") + group.add_option("-C", "--dbcredentials", dest="dbcredentials", type="string", default=default_credentials, + help="Name of database credential set to use [default=%default]") + + return group def parse_options(options, filepatterns=None): - """ + """ Parses command-line parameters provided through options_group() and returns a credentials dictionary. `filepatterns' can be used to override the patterns used to find configuration files. - """ + """ - dbc = DBCredentials(filepatterns) + dbc = DBCredentials(filepatterns) - # get default values - try: - creds = dbc.get(options.dbcredentials) - except NoSectionError: - # credentials will have to be supplied on the command line - creds = Credentials() + # get default values + try: + creds = dbc.get(options.dbcredentials) + except NoSectionError: + # credentials will have to be supplied on the command line + creds = Credentials() - # process supplied overrides - if options.dbHost: creds.host = options.dbHost - if options.dbPort: creds.port = options.dbPort - if options.dbUser: creds.user = options.dbUser - if options.dbPassword: creds.password = options.dbPassword - if options.dbName: creds.database = options.dbName + # process supplied overrides + if options.dbHost: creds.host = options.dbHost + if options.dbPort: creds.port = options.dbPort + if options.dbUser: creds.user = options.dbUser + if options.dbPassword: creds.password = options.dbPassword + if options.dbName: creds.database = options.dbName - return creds + return creds if __name__ == "__main__": - import sys - from optparse import OptionParser - - parser = OptionParser("%prog [options]") - parser.add_option("-D", "--database", dest="database", type="string", default="", - help="Print credentials of a specific database") - parser.add_option("-S", "--shell", dest="shell", action="store_true", default=False, - help="Use machine-readable output for use in shell scripts") - parser.add_option("-L", "--list", dest="list", action="store_true", default=False, - help="List known databases") - parser.add_option("-F", "--files", dest="files", action="store_true", default=False, - help="List names of parsed configuration files") - (options, args) = parser.parse_args() - - if not options.database and not options.list and not options.files: - logger.error("Missing database name") - parser.print_help() - sys.exit(1) - - dbc = DBCredentials() - - if options.files: - """ Print list of configuration files that we've read. """ - if dbc.files: - logger.info("\n".join(dbc.files)) - sys.exit(0) - - if options.list: - """ Print list of databases. """ - databases = dbc.list() - if databases: - logger.info("\n".join(databases)) - sys.exit(0) - - """ Print credentials of a specific database. """ - creds = dbc.get(options.database) - - if options.shell: - print("DBUSER=%s" % (creds.user,)) - print("DBPASSWORD=%s" % (creds.password,)) - print("DBDATABASE=%s" % (creds.database,)) - print("DBHOST=%s" % (creds.host,)) - print("DBPORT=%s" % (creds.port,)) - else: - logger.info(str(creds)) - + import sys + from optparse import OptionParser + + parser = OptionParser("%prog [options]") + parser.add_option("-D", "--database", dest="database", type="string", default="", + help="Print credentials of a specific database") + parser.add_option("-S", "--shell", dest="shell", action="store_true", default=False, + help="Use machine-readable output for use in shell scripts") + parser.add_option("-L", "--list", dest="list", action="store_true", default=False, + help="List known databases") + parser.add_option("-F", "--files", dest="files", action="store_true", default=False, + help="List names of parsed configuration files") + (options, args) = parser.parse_args() + + if not options.database and not options.list and not options.files: + logger.error("Missing database name") + parser.print_help() + sys.exit(1) + + dbc = DBCredentials() + + if options.files: + """ Print list of configuration files that we've read. """ + if dbc.files: + logger.info("\n".join(dbc.files)) + sys.exit(0) + + if options.list: + """ Print list of databases. """ + databases = dbc.list() + if databases: + logger.info("\n".join(databases)) + sys.exit(0) + + """ Print credentials of a specific database. """ + creds = dbc.get(options.database) + + if options.shell: + print("DBUSER=%s" % (creds.user,)) + print("DBPASSWORD=%s" % (creds.password,)) + print("DBDATABASE=%s" % (creds.database,)) + print("DBHOST=%s" % (creds.host,)) + print("DBPORT=%s" % (creds.port,)) + else: + logger.info(str(creds))