diff --git a/LCS/PyCommon/dbcredentials.py b/LCS/PyCommon/dbcredentials.py index 7daff938fed59acf62aa3477acc5e4b78dca1a04..8da2664132b8efcb5e33254575719a508ac5ee2c 100644 --- a/LCS/PyCommon/dbcredentials.py +++ b/LCS/PyCommon/dbcredentials.py @@ -131,7 +131,7 @@ class Credentials: class DBCredentials: def __init__(self, filepatterns=None): - self.filepatterns = filepatterns if filepatterns else [ + self.filepatterns = filepatterns if filepatterns is not None else [ "{LOFARROOT}/etc/dbcredentials/*.ini", "{HOME}/.lofar/dbcredentials/*.ini", ] @@ -182,13 +182,17 @@ class DBCredentials: creates a dbcredentials file with defaults in ~/.lofar/dbcredentials/<database>.ini :param database: name of the database/file """ - new_path = os.path.join(user_info.pw_dir, '.lofar', 'dbcredentials', database+'.ini') - if not os.path.exists(os.path.dirname(new_path)): - os.makedirs(os.path.dirname(new_path)) - with open(new_path, 'w') as new_file: - new_file.write("[database:%s]\nhost=localhost\nuser=%s\npassword=unknown\ntype=unknown\nport=0\ndatabase=%s" - % (database,user_info.pw_name,database)) - logger.info("created default dbcredentials file for database=%s at %s", database, new_path) + extensions = list(set(os.path.splitext(pat)[1] for pat in self.filepatterns)) + if extensions: + #pick first extension + extension = extensions[0] + new_path = os.path.join(user_info.pw_dir, '.lofar', 'dbcredentials', database+extension) + if not os.path.exists(os.path.dirname(new_path)): + os.makedirs(os.path.dirname(new_path)) + with open(new_path, 'w') as new_file: + new_file.write("[database:%s]\nhost=localhost\nuser=%s\npassword=unknown\ntype=unknown\nport=0\ndatabase=%s" + % (database,user_info.pw_name,database)) + logger.info("created default dbcredentials file for database=%s at %s", database, new_path) def get(self, database): """ diff --git a/LCS/PyCommon/test/t_dbcredentials.py b/LCS/PyCommon/test/t_dbcredentials.py index 17b303f9b5b53af7d7b80cd98c3188047db8511b..72bd51c23f3148f7f9a5dc3cc06d20ae71a177f6 100644 --- a/LCS/PyCommon/test/t_dbcredentials.py +++ b/LCS/PyCommon/test/t_dbcredentials.py @@ -1,8 +1,16 @@ #!/usr/bin/env python +import os import unittest import tempfile +from uuid import uuid4 from lofar.common.dbcredentials import * +try: + #python2 + from ConfigParser import NoSectionError +except ImportError: + #python3 + from configparser import NoSectionError def setUpModule(): pass @@ -53,8 +61,21 @@ class TestDBCredentials(unittest.TestCase): def test_get_non_existing(self): dbc = DBCredentials(filepatterns=[]) - with self.assertRaises(DBCredentials.NoSectionError): - dbc.get("UNKNOWN") + non_existing_db_name = "UNKNOWN-%s" % (uuid4(),) + with self.assertRaises(NoSectionError): + dbc.get(non_existing_db_name) + + def test_creation_for_non_existing(self): + dbc = DBCredentials(filepatterns=["{HOME}/.lofar/dbcredentials/*.test_extension"]) + + non_existing_db_name = "UNKNOWN-%s" % (uuid4(),) + creds = dbc.get(non_existing_db_name) + self.assertTrue(creds) + self.assertEqual(non_existing_db_name, creds.database) + + expected_path = os.path.expanduser("~/.lofar/dbcredentials/%s.test_extension" % (non_existing_db_name,)) + self.assertTrue(os.path.exists(expected_path)) + os.remove(expected_path) def test_list(self): dbc = DBCredentials(filepatterns=[])