diff --git a/LCS/PyCommon/dbcredentials.py b/LCS/PyCommon/dbcredentials.py index 2de15f70e22af080aae19d0e39cd2438d1cdb2ee..7daff938fed59acf62aa3477acc5e4b78dca1a04 100644 --- a/LCS/PyCommon/dbcredentials.py +++ b/LCS/PyCommon/dbcredentials.py @@ -130,9 +130,14 @@ class Credentials: return options class DBCredentials: - NoSectionError = NoSectionError - def __init__(self, filepatterns=None): + self.filepatterns = filepatterns if filepatterns else [ + "{LOFARROOT}/etc/dbcredentials/*.ini", + "{HOME}/.lofar/dbcredentials/*.ini", + ] + self.read_config_from_files() + + def read_config_from_files(self, filepatterns=None): """ Read database credentials from all configuration files matched by any of the patterns. @@ -154,13 +159,10 @@ class DBCredentials: These database credentials can subsequently be queried under their symbolic name ("OTDB" in the example). """ - if filepatterns is None: - filepatterns = [ - "{LOFARROOT}/etc/dbcredentials/*.ini", - "{HOME}/.lofar/dbcredentials/*.ini", - ] + if filepatterns is not None: + self.filepatterns = filepatterns - self.files = sum([findfiles(p) for p in filepatterns],[]) + self.files = sum([findfiles(p) for p in self.filepatterns],[]) # make sure the files are mode 600 to hide passwords for file in self.files: @@ -175,6 +177,19 @@ class DBCredentials: self.config = SafeConfigParser() self.config.read(self.files) + def create_default_file(self, database): + """ + creates a dbcredentials file with defaults in ~/.lofar/dbcredentials/<database>.ini + :param database: name of the database/file + """ + new_path = os.path.join(user_info.pw_dir, '.lofar', 'dbcredentials', database+'.ini') + if not os.path.exists(os.path.dirname(new_path)): + os.makedirs(os.path.dirname(new_path)) + with open(new_path, 'w') as new_file: + new_file.write("[database:%s]\nhost=localhost\nuser=%s\npassword=unknown\ntype=unknown\nport=0\ndatabase=%s" + % (database,user_info.pw_name,database)) + logger.info("created default dbcredentials file for database=%s at %s", database, new_path) + def get(self, database): """ Return credentials for a given database. @@ -182,8 +197,15 @@ class DBCredentials: # create default credentials creds = Credentials() - # read configuration (can throw NoSectionError) - d = dict(self.config.items(self._section(database))) + try: + # read configuration (can throw NoSectionError) + d = dict(self.config.items(self._section(database))) + except NoSectionError: + # create defaults file, and reload + self.create_default_file(database) + self.read_config_from_files() + # re-read configuration now that we have created a new file with defaults + d = dict(self.config.items(self._section(database))) # save the full config to support custom fields creds.config = d @@ -202,6 +224,7 @@ class DBCredentials: return creds + def set(self, database, credentials): """ Add or overwrite credentials for a given database.