From e6818c9fa25817b3996bfd0da48c13329424a057 Mon Sep 17 00:00:00 2001 From: Klaas Kliffen <kliffen@astron.nl> Date: Thu, 18 Nov 2021 09:06:26 +0000 Subject: [PATCH] Resolve SDC-355 --- README.md | 3 + atdb_csv_gen/atdb_csv_gen/csv_gen.py | 151 +++++++++++++++++++-------- atdb_csv_gen/requirements.txt | 1 + 3 files changed, 110 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index 670d61a..d7580b2 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,9 @@ cp ./atdb_csv_gen/.env.example ./atdb_csv_gen/.env export $(cat .env | xargs) # or export env vars elsewhere atdb_csv_gen -v -o out.csv 650273 +# or using a config file following .env.example layout: +atdb_csv_gen -c path/to/config.env -o out.csv 650273 + # More info and flags atdb_csv_gen -h ``` diff --git a/atdb_csv_gen/atdb_csv_gen/csv_gen.py b/atdb_csv_gen/atdb_csv_gen/csv_gen.py index 8a1ad36..e14b44e 100755 --- a/atdb_csv_gen/atdb_csv_gen/csv_gen.py +++ b/atdb_csv_gen/atdb_csv_gen/csv_gen.py @@ -3,61 +3,64 @@ import logging import os -import sys from argparse import ArgumentParser, Namespace from sys import version from typing import Iterator, List, Optional, Tuple -from sqlalchemy import create_engine -from sqlalchemy.engine.base import Connection +from dotenv import load_dotenv +from sqlalchemy import create_engine as sqla_create_engine +from sqlalchemy.engine.base import Connection, Engine from sqlalchemy.sql import text -USER = os.getenv("DB_USERNAME") -PASSWORD = os.getenv("DB_PASSWORD") -DATABASE = os.getenv("DB_DATABASE") -HOST = os.getenv("DB_HOST", "localhost") -PORT = os.getenv("DP_PORT", "5432") - - logger = logging.getLogger(__name__) + def parse_args() -> Namespace: - """ Parse CLI arguments """ + """Parse CLI arguments""" - parser = ArgumentParser(description="Generates CSV for atdb_load_tasks_from_table\n" + parser = ArgumentParser( + description="Generates CSV for atdb_load_tasks_from_table\n" "Requires the following env vars:\n" "\tDB_USERNAME, DB_PASSWORD, DB_DATABASE and optionally:\n" "\tDB_HOST (localhost by default) and DB_PORT (5432, postgres default)" ) parser.add_argument("obs_id", type=str, help="Observation ID (SAS ID)") parser.add_argument("-o", type=str, help="Output file") - parser.add_argument("--save-missing", type=str, help="Directory to store csv with missing files") + parser.add_argument( + "-c", type=str, help="Config file (in .env format)", default=None + ) + parser.add_argument( + "--save-missing", type=str, help="Directory to store csv with missing files" + ) parser.add_argument("-v", action="store_true", help="Verbose logging") parser.add_argument("-q", action="store_true", help="Quiet logging (only errors)") return parser.parse_args() def fetch_registered_files(conn: Connection, obs_id: str) -> List[Tuple[int, str]]: - """ Fetch the registered files from the LTA """ + """Fetch the registered files from the LTA""" sql = text("SELECT file_size, surl FROM ldv_delete.aw_uris WHERE obsid = :obs_id") return [row for row in conn.execute(sql, obs_id=obs_id)] def fetch_crawled_files(conn: Connection, obs_id: str) -> List[Tuple[int, str]]: - """ Fetch crawled files from the LTA """ + """Fetch crawled files from the LTA""" - sql = text("SELECT file_size, surl FROM ldv_delete.lta_uris WHERE dir_name like CONCAT('/pnfs/grid.sara.nl/data/lofar/ops/projects/%/',:obs_id)") + sql = text( + "SELECT file_size, surl FROM ldv_delete.lta_uris WHERE dir_name like CONCAT('/pnfs/grid.sara.nl/data/lofar/ops/projects/%/',:obs_id)" + ) return [row for row in conn.execute(sql, obs_id=obs_id)] -def tuple_list_to_csv(l: List[Tuple[int, str]], sep: str=";") -> Iterator[str]: - """ Convert a list of tuples to a lines in a csv """ +def tuple_list_to_csv(l: List[Tuple[int, str]], sep: str = ";") -> Iterator[str]: + """Convert a list of tuples to a line in a csv""" return map(lambda row: f"{row[0]}{sep}{row[1]}", l) + def write_output(csv_data: str, output: Optional[str]): - """ Write csv data to output file or stdout """ + """Write csv data to output file or stdout""" if output is not None: dirs = os.path.dirname(output) if dirs != "": @@ -70,34 +73,54 @@ def write_output(csv_data: str, output: Optional[str]): else: print(csv_data) -def gen_csv(obs_id: str, o_file: Optional[str] = None, save_missing:Optional[str] = None): - """ Generate CSV file for ATDB - Parameters - ---------- - obs_id : string - observation id - o_file : string, optional - File to write the output to (by default to stdout) - save_missing : string, optional - Directory to write missing files output to (by default only shows warnings!) - """ +def create_engine(env_file: Optional[str] = None) -> Engine: + """Create sqlalchemy engine from env vars or config file""" + + # Optionally load the config file + if env_file is not None: + load_dotenv(env_file) + + USER = os.getenv("DB_USERNAME") + PASSWORD = os.getenv("DB_PASSWORD") + DATABASE = os.getenv("DB_DATABASE") + HOST = os.getenv("DB_HOST", "localhost") + PORT = os.getenv("DP_PORT", "5432") - if not all([USER, PASSWORD, DATABASE, HOST, PORT]): - raise RuntimeError("Missing DB credentials in env vars.\n" - "Did you export the .env file?") + if not all((USER, PASSWORD, DATABASE, HOST, PORT)): + raise RuntimeError( + "Missing DB credentials in env vars.\nDid you export the .env file?" + ) - engine = create_engine(f"postgresql://{USER}:{PASSWORD}@{HOST}:{PORT}/{DATABASE}") + engine = sqla_create_engine(f"postgresql://{USER}:{PASSWORD}@{HOST}:{PORT}/{DATABASE}") logger.debug("Connected to %s:%s", HOST, PORT) + return engine + + +def query_file_lists( + obs_id: str, config_file: Optional[str] = None +) -> Tuple[List[Tuple[int, str]], List[Tuple[int, str]]]: + """Query database for both file lists""" + + engine = create_engine(config_file) + + logger.info("Query file lists...") with engine.connect() as conn: conn: Connection - registered = fetch_registered_files(conn, obs_id) logger.debug("Found %i registered files", len(registered)) crawled = fetch_crawled_files(conn, obs_id) logger.debug("Found %i crawled files", len(crawled)) - + logger.info("Query complete") + return registered, crawled + + +def match_file_lists( + registered: List[Tuple[int, str]], + crawled: List[Tuple[int, str]], + save_missing: Optional[str] = None, +) -> List[Tuple[int, str]]: # Create set for easier manipulation reg_data = set(tuple_list_to_csv(registered)) cra_data = set(tuple_list_to_csv(crawled)) @@ -105,27 +128,59 @@ def gen_csv(obs_id: str, o_file: Optional[str] = None, save_missing:Optional[str # Registered, no file found reg_no_file = reg_data.difference(cra_data) if len(reg_no_file) > 0: - logger.warning("Observation contains %i registered files which are not present on disk", len(reg_no_file)) + logger.warning( + "Observation contains %i registered files which are not present on disk", + len(reg_no_file), + ) for row in reg_no_file: logger.warning(row) if save_missing: - write_output("\n".join(reg_no_file), os.path.join(save_missing, "reg_no_file.csv")) + write_output( + "\n".join(reg_no_file), os.path.join(save_missing, "reg_no_file.csv") + ) # File found, not registered file_no_reg = cra_data.difference(reg_data) if len(file_no_reg) > 0: - logger.warning("Observation contains %i files on disk which are not registered in the LTA", len(file_no_reg)) + logger.warning( + "Observation contains %i files on disk which are not registered in the LTA", + len(file_no_reg), + ) for row in file_no_reg: logger.warning(row) if save_missing: - write_output("\n".join(file_no_reg), os.path.join(save_missing, "file_no_reg.csv")) + write_output( + "\n".join(file_no_reg), os.path.join(save_missing, "file_no_reg.csv") + ) + return reg_data.intersection(cra_data) + + +def gen_csv( + obs_id: str, + o_file: Optional[str] = None, + save_missing: Optional[str] = None, + config_file: Optional[str] = None, +): + """Generate CSV file for ATDB + + Parameters + ---------- + obs_id : string + observation id + o_file : string, optional + File to write the output to (by default to stdout) + save_missing : string, optional + Directory to write missing files output to (by default only shows warnings!) + config_file: string, optional + File in .env format to load configuration from + """ - inter = reg_data.intersection(cra_data) + registered, crawled = query_file_lists(obs_id, config_file) - write_output("\n".join(inter), o_file) + intersection = match_file_lists(registered, crawled, save_missing) - logger.info("Done") + write_output("\n".join(intersection), o_file) def main(): @@ -135,10 +190,16 @@ def main(): logging.basicConfig(level=logging.DEBUG) elif args.q: logging.basicConfig(level=logging.ERROR) - else: + else: logging.basicConfig(level=logging.INFO) - gen_csv(obs_id=args.obs_id,o_file=args.o, save_missing=args.save_missing) + gen_csv( + obs_id=args.obs_id, + o_file=args.o, + save_missing=args.save_missing, + config_file=args.c, + ) + if __name__ == "__main__": main() diff --git a/atdb_csv_gen/requirements.txt b/atdb_csv_gen/requirements.txt index 41d60df..10da2bf 100644 --- a/atdb_csv_gen/requirements.txt +++ b/atdb_csv_gen/requirements.txt @@ -1,2 +1,3 @@ psycopg2-binary>=2.9.1 sqlalchemy>=1.4.26 +python-dotenv>=0.19.2 \ No newline at end of file -- GitLab