diff --git a/shopping_client/shopping_client.py b/shopping_client/shopping_client.py index d27013e356719542ca5e29b41a25f1d38bc520b9..cdd5a12385c50ed833bdc20b343235d7d8df958b 100644 --- a/shopping_client/shopping_client.py +++ b/shopping_client/shopping_client.py @@ -1,6 +1,7 @@ import base64 import getpass import json +import logging import time import urllib.parse from os import getenv @@ -10,10 +11,13 @@ from warnings import warn import pandas as pd import requests +logger = logging.getLogger(__name__) + class shopping_client: endpoint = "esap-api/accounts/user-profiles/" + audience = "rucio" # Audience used by ESAP, might be configurable later def __init__( self, @@ -41,7 +45,10 @@ class shopping_client: self.basket = None def get_basket( - self, convert_to_pandas: bool = False, reload: bool = False, filter_archives: bool = False + self, + convert_to_pandas: bool = False, + reload: bool = False, + filter_archives: bool = False, ) -> Union[list, pd.DataFrame, None]: """Retrieve the shopping basket for a user. Prompts for access token if one was not supplied to constructor. @@ -91,7 +98,7 @@ class shopping_client: return self.basket def _is_valid_token(self, token: Optional[str]) -> bool: - """ Checks expiry of the token """ + """Checks expiry of the token""" if token is None: return False @@ -117,12 +124,14 @@ class shopping_client: item_data = json.loads(item["item_data"]) for connector in self.connectors: - if "archive" in item_data and item_data["archive"] == connector.archive: + if ( + "archive" in item_data + and item_data["archive"] == connector.archive + ): filtered_items.append(item) return filtered_items - def _basket_to_pandas(self): if len(self.connectors): converted_basket = { @@ -146,12 +155,31 @@ class shopping_client: return self.basket def _get_token(self): - # Try to get token from Rucio OIDC file (when running in CERN DLaaS notebook) + + # Generic JH token method using authstate + jh_api_uri = getenv("JUPYTERHUB_API_URL") + jh_api_token = getenv("JUPYTERHUB_API_TOKEN") + # Fallback to older rucio file token_fn = getenv("RUCIO_OIDC_FILE_NAME") - if token_fn is not None: + + try: + if all((jh_api_token, jh_api_uri)): + res = requests.get( + f"{jh_api_uri}/user", + headers={"Authorization": f"token {jh_api_token}"}, + ) + + self.token = res.json()["auth_state"]["exchanged_tokens"][self.audience] + + except KeyError: + logger.warning("JupyterHub without Authstate enabled") + + # Try to get token from Rucio OIDC file (when running in CERN DLaaS notebook) + if self.token is None and token_fn is not None: with open(token_fn) as token_file: self.token = token_file.readline() - else: + + elif self.token is None: self.token = getpass.getpass("Enter your ESAP access token:") if self.token is None: