Skip to content
Snippets Groups Projects
Commit d54dfc27 authored by John Swinbank's avatar John Swinbank
Browse files

Smarter handling of tokens

In particular, don't get into an infinite loop when a valid token can't be
found.

We prefer a user-supplied token if one exists and it's valid; otherwise, fall
back to search other locations.
parent 8e2e4701
No related branches found
No related tags found
1 merge request!12Smarter handling of tokens
......@@ -13,6 +13,17 @@ import requests
logger = logging.getLogger(__name__)
def is_valid_token(token: str) -> bool:
"""Check that the given token has not expired"""
try:
selfdata = token.split(".")[1]
padded = data + "=" * divmod(len(data), 4)[1]
payload = json.loads(base64.urlsafe_b64decode(padded))
return payload["exp"] > int(time.time()) + 10
except Exception as e:
logger.warning(f"Couldn't parse token: {e}")
return False
class ShoppingClient:
......@@ -41,7 +52,6 @@ class ShoppingClient:
self.token = token
self.host = host
self.connectors = connectors
self.basket = None
def get_basket(
......@@ -97,24 +107,7 @@ class ShoppingClient:
return self.basket
def _is_valid_token(self, token: Optional[str]) -> bool:
"""Checks expiry of the token"""
if token is None:
return False
try:
data = token.split(".")[1]
padded = data + "=" * divmod(len(data), 4)[1]
payload = json.loads(base64.urlsafe_b64decode(padded))
return payload["exp"] > int(time.time()) + 10
except KeyError:
raise RuntimeError("Invalid JWT format")
def _request_header(self):
while not self._is_valid_token(self.token):
self._get_token()
return dict(Accept="application/json", Authorization=f"Bearer {self.token}")
# filter on items belonging to the provided connectors
......@@ -165,14 +158,18 @@ class ShoppingClient:
)
return self.basket
def _get_token(self):
@property
def token(self):
# If there is a user-specified token and it's valid, return that.
if self._user_token and is_valid_token(self._user_token):
return self._user_token
# Otherwise, search a variety of possible locations and return the
# first valid token we find.
# Generic JH token method using authstate
# Generic JupyterHub token method using authstate
jh_api_uri = getenv("JUPYTERHUB_API_URL")
jh_api_token = getenv("JUPYTERHUB_API_TOKEN")
# Fallback to older rucio file
token_fn = getenv("RUCIO_OIDC_FILE_NAME")
try:
if all((jh_api_token, jh_api_uri)):
res = requests.get(
......@@ -180,18 +177,23 @@ class ShoppingClient:
headers={"Authorization": f"token {jh_api_token}"},
)
self.token = res.json()["auth_state"]["exchanged_tokens"][self.audience]
if is_valid_token(token := res.json()["auth_state"]["exchanged_tokens"][self.audience]):
return token
except KeyError:
logger.warning("JupyterHub without Authstate enabled")
# Try to get token from Rucio OIDC file (when running in CERN DLaaS notebook)
if self.token is None and token_fn is not None:
with open(token_fn) as token_file:
self.token = token_file.readline()
if rucio_token_filename := getenv("RUCIO_OIDC_FILE_NAME"):
with open(rucio_token_filename) as token_file:
if is_valid_token(token := token_file.readline()):
return token
# Finally, fall back to prompting the user.
if is_valid_token(token := getpass.getpass("Enter your ESAP access token:")):
return token
elif self.token is None:
self.token = getpass.getpass("Enter your ESAP access token:")
raise RuntimeError("No valid token available")
if self.token is None:
raise RuntimeError("No token found!")
@token.setter
def token(self, user_token):
self._user_token = user_token
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment