Skip to content
Snippets Groups Projects
Unverified Commit 03a598dd authored by SKAJohanVenter's avatar SKAJohanVenter
Browse files

SAR-150 Updated the use of thread_local

parent ee5b7175
No related branches found
No related tags found
No related merge requests found
...@@ -21,15 +21,13 @@ class TransactionIDTagsFilter(logging.Filter): ...@@ -21,15 +21,13 @@ class TransactionIDTagsFilter(logging.Filter):
""" """
def get_transaction_id(self): def get_transaction_id(self):
if hasattr(thread_local_data, "transaction_ids"): if hasattr(thread_local_data, "transaction_id"):
thread_id = threading.get_ident() return thread_local_data.transaction_id
return thread_local_data.transaction_ids.get(thread_id, None)
return None return None
def get_frame(self): def get_caller_frame(self):
if hasattr(thread_local_data, "frames"): if hasattr(thread_local_data, "caller_frame"):
thread_id = threading.get_ident() return thread_local_data.caller_frame
return thread_local_data.frames.get(thread_id, None)
return None return None
def filter(self, record): def filter(self, record):
...@@ -47,15 +45,15 @@ class TransactionIDTagsFilter(logging.Filter): ...@@ -47,15 +45,15 @@ class TransactionIDTagsFilter(logging.Filter):
# difficult to debug where the `Enter` and `Exit` of a transaction log message # difficult to debug where the `Enter` and `Exit` of a transaction log message
# originated. # originated.
# From Python 3.8 we should rather use `stacklevel` # From Python 3.8 we should rather use `stacklevel`
frame = self.get_frame() caller_frame = self.get_caller_frame()
if frame: if caller_frame:
if record.pathname == __file__ and record.funcName in [ if record.pathname == __file__ and record.funcName in [
"__enter__", "__enter__",
"__exit__", "__exit__",
]: ]:
record.filename = os.path.basename(frame.filename) record.filename = os.path.basename(caller_frame.filename)
record.lineno = frame.lineno record.lineno = caller_frame.lineno
record.funcName = frame.function record.funcName = caller_frame.function
return True return True
...@@ -163,7 +161,7 @@ class Transaction: ...@@ -163,7 +161,7 @@ class Transaction:
self._transaction_id_key = transaction_id_key self._transaction_id_key = transaction_id_key
self._transaction_id = self._get_id_from_params_or_generate_new_id(transaction_id) self._transaction_id = self._get_id_from_params_or_generate_new_id(transaction_id)
self._frame = inspect.stack()[1] self._caller_frame = inspect.stack()[1]
if transaction_id and params.get(self._transaction_id_key): if transaction_id and params.get(self._transaction_id_key):
self.logger.info( self.logger.info(
...@@ -177,22 +175,12 @@ class Transaction: ...@@ -177,22 +175,12 @@ class Transaction:
self._transaction_filter = TransactionIDTagsFilter() self._transaction_filter = TransactionIDTagsFilter()
def store_thread_data(self): def store_thread_data(self):
thread_id = threading.get_ident() thread_local_data.transaction_id = self._transaction_id
if not hasattr(thread_local_data, "transaction_ids"): thread_local_data.caller_frame = self._caller_frame
thread_local_data.transaction_ids = {}
if not hasattr(thread_local_data, "frames"):
thread_local_data.frames = {}
thread_local_data.transaction_ids[thread_id] = self._transaction_id
thread_local_data.frames[thread_id] = self._frame
def clear_thread_data(self): def clear_thread_data(self):
thread_id = threading.get_ident() thread_local_data.transaction_id = None
if hasattr(thread_local_data, "transaction_ids"): thread_local_data.caller_frame = None
if thread_id in thread_local_data.transaction_ids:
del thread_local_data.transaction_ids[thread_id]
if hasattr(thread_local_data, "frames"):
if thread_id in thread_local_data.frames:
del thread_local_data.frames[thread_id]
def __enter__(self): def __enter__(self):
self.store_thread_data() self.store_thread_data()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment