From e6818c9fa25817b3996bfd0da48c13329424a057 Mon Sep 17 00:00:00 2001
From: Klaas Kliffen <kliffen@astron.nl>
Date: Thu, 18 Nov 2021 09:06:26 +0000
Subject: [PATCH] Resolve SDC-355

---
 README.md                            |   3 +
 atdb_csv_gen/atdb_csv_gen/csv_gen.py | 151 +++++++++++++++++++--------
 atdb_csv_gen/requirements.txt        |   1 +
 3 files changed, 110 insertions(+), 45 deletions(-)

diff --git a/README.md b/README.md
index 670d61a..d7580b2 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,9 @@ cp ./atdb_csv_gen/.env.example ./atdb_csv_gen/.env
 export $(cat .env | xargs)  # or export env vars elsewhere
 atdb_csv_gen -v -o out.csv 650273
 
+# or using a config file following .env.example layout:
+atdb_csv_gen -c path/to/config.env -o out.csv 650273
+
 # More info and flags
 atdb_csv_gen -h
 ```
diff --git a/atdb_csv_gen/atdb_csv_gen/csv_gen.py b/atdb_csv_gen/atdb_csv_gen/csv_gen.py
index 8a1ad36..e14b44e 100755
--- a/atdb_csv_gen/atdb_csv_gen/csv_gen.py
+++ b/atdb_csv_gen/atdb_csv_gen/csv_gen.py
@@ -3,61 +3,64 @@
 
 import logging
 import os
-import sys
 from argparse import ArgumentParser, Namespace
 from sys import version
 from typing import Iterator, List, Optional, Tuple
 
-from sqlalchemy import create_engine
-from sqlalchemy.engine.base import Connection
+from dotenv import load_dotenv
+from sqlalchemy import create_engine as sqla_create_engine
+from sqlalchemy.engine.base import Connection, Engine
 from sqlalchemy.sql import text
 
-USER = os.getenv("DB_USERNAME")
-PASSWORD = os.getenv("DB_PASSWORD")
-DATABASE = os.getenv("DB_DATABASE")
-HOST = os.getenv("DB_HOST", "localhost")
-PORT = os.getenv("DP_PORT", "5432")
-
-
 logger = logging.getLogger(__name__)
 
+
 def parse_args() -> Namespace:
-    """ Parse CLI arguments """
+    """Parse CLI arguments"""
 
-    parser = ArgumentParser(description="Generates CSV for atdb_load_tasks_from_table\n"
+    parser = ArgumentParser(
+        description="Generates CSV for atdb_load_tasks_from_table\n"
         "Requires the following env vars:\n"
         "\tDB_USERNAME, DB_PASSWORD, DB_DATABASE and optionally:\n"
         "\tDB_HOST (localhost by default) and DB_PORT (5432, postgres default)"
     )
     parser.add_argument("obs_id", type=str, help="Observation ID (SAS ID)")
     parser.add_argument("-o", type=str, help="Output file")
-    parser.add_argument("--save-missing", type=str, help="Directory to store csv with missing files")
+    parser.add_argument(
+        "-c", type=str, help="Config file (in .env format)", default=None
+    )
+    parser.add_argument(
+        "--save-missing", type=str, help="Directory to store csv with missing files"
+    )
     parser.add_argument("-v", action="store_true", help="Verbose logging")
     parser.add_argument("-q", action="store_true", help="Quiet logging (only errors)")
     return parser.parse_args()
 
 
 def fetch_registered_files(conn: Connection, obs_id: str) -> List[Tuple[int, str]]:
-    """ Fetch the registered files from the LTA """
+    """Fetch the registered files from the LTA"""
 
     sql = text("SELECT file_size, surl FROM ldv_delete.aw_uris WHERE obsid = :obs_id")
     return [row for row in conn.execute(sql, obs_id=obs_id)]
 
 
 def fetch_crawled_files(conn: Connection, obs_id: str) -> List[Tuple[int, str]]:
-    """ Fetch crawled files from the LTA """
+    """Fetch crawled files from the LTA"""
 
-    sql = text("SELECT file_size, surl FROM ldv_delete.lta_uris WHERE dir_name like CONCAT('/pnfs/grid.sara.nl/data/lofar/ops/projects/%/',:obs_id)")
+    sql = text(
+        "SELECT file_size, surl FROM ldv_delete.lta_uris WHERE dir_name like CONCAT('/pnfs/grid.sara.nl/data/lofar/ops/projects/%/',:obs_id)"
+    )
     return [row for row in conn.execute(sql, obs_id=obs_id)]
 
 
-def tuple_list_to_csv(l: List[Tuple[int, str]], sep: str=";") -> Iterator[str]:
-    """ Convert a list of tuples to a lines in a csv """
+def tuple_list_to_csv(l: List[Tuple[int, str]], sep: str = ";") -> Iterator[str]:
+    """Convert a list of tuples to lines of a csv"""
 
     return map(lambda row: f"{row[0]}{sep}{row[1]}", l)
 
+
 def write_output(csv_data: str, output: Optional[str]):
-    """ Write csv data to output file or stdout """
+    """Write csv data to output file or stdout"""
     if output is not None:
         dirs = os.path.dirname(output)
         if dirs != "":
@@ -70,34 +73,54 @@ def write_output(csv_data: str, output: Optional[str]):
     else:
         print(csv_data)
 
-def gen_csv(obs_id: str, o_file: Optional[str] = None, save_missing:Optional[str] = None):
-    """ Generate CSV file for ATDB
 
-    Parameters
-    ----------
-    obs_id : string
-        observation id
-    o_file : string, optional
-        File to write the output to (by default to stdout)
-    save_missing : string, optional
-        Directory to write missing files output to (by default only shows warnings!)
-    """
+def create_engine(env_file: Optional[str] = None) -> Engine:
+    """Create sqlalchemy engine from env vars or config file"""
+
+    # Optionally load the config file
+    if env_file is not None:
+        load_dotenv(env_file)
+
+    USER = os.getenv("DB_USERNAME")
+    PASSWORD = os.getenv("DB_PASSWORD")
+    DATABASE = os.getenv("DB_DATABASE")
+    HOST = os.getenv("DB_HOST", "localhost")
+    PORT = os.getenv("DB_PORT", "5432")
 
-    if not all([USER, PASSWORD, DATABASE, HOST, PORT]):
-        raise RuntimeError("Missing DB credentials in env vars.\n"
-                           "Did you export the .env file?")
+    if not all((USER, PASSWORD, DATABASE, HOST, PORT)):
+        raise RuntimeError(
+            "Missing DB credentials in env vars.\nDid you export the .env file?"
+        )
 
-    engine = create_engine(f"postgresql://{USER}:{PASSWORD}@{HOST}:{PORT}/{DATABASE}")
+    engine = sqla_create_engine(f"postgresql://{USER}:{PASSWORD}@{HOST}:{PORT}/{DATABASE}")
     logger.debug("Connected to %s:%s", HOST, PORT)
 
+    return engine
+
+
+def query_file_lists(
+    obs_id: str, config_file: Optional[str] = None
+) -> Tuple[List[Tuple[int, str]], List[Tuple[int, str]]]:
+    """Query database for both file lists"""
+
+    engine = create_engine(config_file)
+
+    logger.info("Query file lists...")
     with engine.connect() as conn:
         conn: Connection
-
         registered = fetch_registered_files(conn, obs_id)
         logger.debug("Found %i registered files", len(registered))
         crawled = fetch_crawled_files(conn, obs_id)
         logger.debug("Found %i crawled files", len(crawled))
-        
+    logger.info("Query complete")
+    return registered, crawled
+
+
+def match_file_lists(
+    registered: List[Tuple[int, str]],
+    crawled: List[Tuple[int, str]],
+    save_missing: Optional[str] = None,
+) -> set:  # set of "size;surl" csv-line strings (see tuple_list_to_csv)
     # Create set for easier manipulation
     reg_data = set(tuple_list_to_csv(registered))
     cra_data = set(tuple_list_to_csv(crawled))
@@ -105,27 +128,59 @@ def gen_csv(obs_id: str, o_file: Optional[str] = None, save_missing:Optional[str
     # Registered, no file found
     reg_no_file = reg_data.difference(cra_data)
     if len(reg_no_file) > 0:
-        logger.warning("Observation contains %i registered files which are not present on disk",  len(reg_no_file))
+        logger.warning(
+            "Observation contains %i registered files which are not present on disk",
+            len(reg_no_file),
+        )
         for row in reg_no_file:
             logger.warning(row)
         if save_missing:
-            write_output("\n".join(reg_no_file), os.path.join(save_missing, "reg_no_file.csv"))
+            write_output(
+                "\n".join(reg_no_file), os.path.join(save_missing, "reg_no_file.csv")
+            )
 
     # File found, not registered
     file_no_reg = cra_data.difference(reg_data)
     if len(file_no_reg) > 0:
-        logger.warning("Observation contains %i files on disk which are not registered in the LTA", len(file_no_reg))
+        logger.warning(
+            "Observation contains %i files on disk which are not registered in the LTA",
+            len(file_no_reg),
+        )
         for row in file_no_reg:
             logger.warning(row)
         if save_missing:
-            write_output("\n".join(file_no_reg), os.path.join(save_missing, "file_no_reg.csv"))
+            write_output(
+                "\n".join(file_no_reg), os.path.join(save_missing, "file_no_reg.csv")
+            )
 
+    return reg_data.intersection(cra_data)
+
+
+def gen_csv(
+    obs_id: str,
+    o_file: Optional[str] = None,
+    save_missing: Optional[str] = None,
+    config_file: Optional[str] = None,
+):
+    """Generate CSV file for ATDB
+
+    Parameters
+    ----------
+    obs_id : string
+        observation id
+    o_file : string, optional
+        File to write the output to (by default to stdout)
+    save_missing : string, optional
+        Directory to write missing files output to (by default only shows warnings!)
+    config_file: string, optional
+        File in .env format to load configuration from
+    """
 
-    inter = reg_data.intersection(cra_data)
+    registered, crawled = query_file_lists(obs_id, config_file)
 
-    write_output("\n".join(inter), o_file)
+    intersection = match_file_lists(registered, crawled, save_missing)
 
-    logger.info("Done")
+    write_output("\n".join(intersection), o_file)
 
 
 def main():
@@ -135,10 +190,16 @@ def main():
         logging.basicConfig(level=logging.DEBUG)
     elif args.q:
         logging.basicConfig(level=logging.ERROR)
-    else:    
+    else:
         logging.basicConfig(level=logging.INFO)
 
-    gen_csv(obs_id=args.obs_id,o_file=args.o, save_missing=args.save_missing)
+    gen_csv(
+        obs_id=args.obs_id,
+        o_file=args.o,
+        save_missing=args.save_missing,
+        config_file=args.c,
+    )
+
 
 if __name__ == "__main__":
     main()
diff --git a/atdb_csv_gen/requirements.txt b/atdb_csv_gen/requirements.txt
index 41d60df..10da2bf 100644
--- a/atdb_csv_gen/requirements.txt
+++ b/atdb_csv_gen/requirements.txt
@@ -1,2 +1,3 @@
 psycopg2-binary>=2.9.1
 sqlalchemy>=1.4.26
+python-dotenv>=0.19.2
\ No newline at end of file
-- 
GitLab