diff --git a/src/ska/logging/transactions.py b/src/ska/logging/transactions.py index 8e5d1ccc9af4145f582cad765cdfe71b07830ff3..1e409c303e1443531f7cc03b46b71742fa113f24 100644 --- a/src/ska/logging/transactions.py +++ b/src/ska/logging/transactions.py @@ -21,15 +21,13 @@ class TransactionIDTagsFilter(logging.Filter): """ def get_transaction_id(self): - if hasattr(thread_local_data, "transaction_ids"): - thread_id = threading.get_ident() - return thread_local_data.transaction_ids.get(thread_id, None) + if hasattr(thread_local_data, "transaction_id"): + return thread_local_data.transaction_id return None - def get_frame(self): - if hasattr(thread_local_data, "frames"): - thread_id = threading.get_ident() - return thread_local_data.frames.get(thread_id, None) + def get_caller_frame(self): + if hasattr(thread_local_data, "caller_frame"): + return thread_local_data.caller_frame return None def filter(self, record): @@ -47,15 +45,15 @@ class TransactionIDTagsFilter(logging.Filter): # difficult to debug where the `Enter` and `Exit` of a transaction log message # originated. # From Python 3.8 we should rather use `stacklevel` - frame = self.get_frame() - if frame: + caller_frame = self.get_caller_frame() + if caller_frame: if record.pathname == __file__ and record.funcName in [ "__enter__", "__exit__", ]: - record.filename = os.path.basename(frame.filename) - record.lineno = frame.lineno - record.funcName = frame.function + record.filename = os.path.basename(caller_frame.filename) + record.lineno = caller_frame.lineno + record.funcName = caller_frame.function return True @@ -163,7 +161,7 @@ class Transaction: self._transaction_id_key = transaction_id_key self._transaction_id = self._get_id_from_params_or_generate_new_id(transaction_id) - self._frame = inspect.stack()[1] + self._caller_frame = inspect.stack()[1] if transaction_id and params.get(self._transaction_id_key): self.logger.info( @@ -177,22 +175,12 @@ class Transaction: self._transaction_filter = TransactionIDTagsFilter() def store_thread_data(self): - thread_id = threading.get_ident() - if not hasattr(thread_local_data, "transaction_ids"): - thread_local_data.transaction_ids = {} - if not hasattr(thread_local_data, "frames"): - thread_local_data.frames = {} - thread_local_data.transaction_ids[thread_id] = self._transaction_id - thread_local_data.frames[thread_id] = self._frame + thread_local_data.transaction_id = self._transaction_id + thread_local_data.caller_frame = self._caller_frame def clear_thread_data(self): - thread_id = threading.get_ident() - if hasattr(thread_local_data, "transaction_ids"): - if thread_id in thread_local_data.transaction_ids: - del thread_local_data.transaction_ids[thread_id] - if hasattr(thread_local_data, "frames"): - if thread_id in thread_local_data.frames: - del thread_local_data.frames[thread_id] + thread_local_data.transaction_id = None + thread_local_data.caller_frame = None def __enter__(self): self.store_thread_data()