diff --git a/atdb_csv_gen/atdb_csv_gen/csv_gen.py b/atdb_csv_gen/atdb_csv_gen/csv_gen.py index 3f4f4e850b7a517e855426d57af9232423a7b626..0a87bd337639713a8cb7375804b8c98339f05f1c 100755 --- a/atdb_csv_gen/atdb_csv_gen/csv_gen.py +++ b/atdb_csv_gen/atdb_csv_gen/csv_gen.py @@ -13,6 +13,10 @@ from sqlalchemy.engine.base import Connection, Engine from sqlalchemy.sql import text logger = logging.getLogger(__name__) +_EXPECTED_STR_PER_DATATYPE = { + 'MS': '.MS_', + 'BF': '_bf.' +} def parse_args() -> Namespace: @@ -20,9 +24,9 @@ def parse_args() -> Namespace: parser = ArgumentParser( description="Generates CSV for atdb_load_tasks_from_table\n" - "Requires the following env vars:\n" - "\tDB_USERNAME, DB_PASSWORD, DB_DATABASE and optionally:\n" - "\tDB_HOST (localhost by default) and DB_PORT (5432, postgres default)" + "Requires the following env vars:\n" + "\tDB_USERNAME, DB_PASSWORD, DB_DATABASE and optionally:\n" + "\tDB_HOST (localhost by default) and DB_PORT (5432, postgres default)" ) parser.add_argument("obs_id", type=str, help="Observation ID (SAS ID)") parser.add_argument("-o", type=str, help="Output file") @@ -32,6 +36,7 @@ def parse_args() -> Namespace: parser.add_argument( "--save-missing", type=str, help="Directory to store csv with missing files" ) + parser.add_argument("--datatype", help="Select a specific datatype", choices=_EXPECTED_STR_PER_DATATYPE.keys()) parser.add_argument("-v", action="store_true", help="Verbose logging") parser.add_argument("-q", action="store_true", help="Quiet logging (only errors)") return parser.parse_args() @@ -99,7 +104,7 @@ def create_engine(env_file: Optional[str] = None) -> Engine: def query_file_lists( - obs_id: str, config_file: Optional[str] = None + obs_id: str, config_file: Optional[str] = None ) -> Tuple[List[Tuple[int, str]], List[Tuple[int, str]]]: """Query database for both file lists""" @@ -117,9 +122,9 @@ def query_file_lists( def match_file_lists( - registered: List[Tuple[int, str]], - crawled: List[Tuple[int, str]], - save_missing: Optional[str] = None, + registered: List[Tuple[int, str]], + crawled: List[Tuple[int, str]], + save_missing: Optional[str] = None, ) -> List[Tuple[int, str]]: # Create set for easier manipulation reg_data = set(tuple_list_to_csv(registered)) @@ -156,11 +161,20 @@ def match_file_lists( return reg_data.intersection(cra_data) +def filter_datatype(size_path_list, datatype): + if datatype is None: + return size_path_list + else: + expected_string = _EXPECTED_STR_PER_DATATYPE[datatype] + return [item for item in size_path_list if expected_string in item] + + def gen_csv( - obs_id: str, - o_file: Optional[str] = None, - save_missing: Optional[str] = None, - config_file: Optional[str] = None, + obs_id: str, + o_file: Optional[str] = None, + save_missing: Optional[str] = None, + config_file: Optional[str] = None, + datatype: Optional[str] = None, ): """Generate CSV file for ATDB @@ -180,7 +194,9 @@ def gen_csv( intersection = match_file_lists(registered, crawled, save_missing) - write_output("\n".join(intersection), o_file) + filtered_intersection = filter_datatype(intersection, datatype) + + write_output("\n".join(filtered_intersection), o_file) def main(): @@ -198,6 +214,7 @@ def main(): o_file=args.o, save_missing=args.save_missing, config_file=args.c, + datatype=args.datatype )