From d594420b5c65c6b3fb08db886b51cfee6e930b47 Mon Sep 17 00:00:00 2001
From: lukken <lukken@astron.nl>
Date: Tue, 21 Sep 2021 10:10:52 +0000
Subject: [PATCH] L2SS-340: Change initialization behavior to run immediately

---
 devices/clients/tcp_replicator.py             | 118 ++++++++++--------
 .../client/test_tcp_replicator.py             |   8 +-
 devices/test/clients/test_tcp_replicator.py   |  61 +++++----
 3 files changed, 106 insertions(+), 81 deletions(-)

diff --git a/devices/clients/tcp_replicator.py b/devices/clients/tcp_replicator.py
index f6087ef03..39e9e64dd 100644
--- a/devices/clients/tcp_replicator.py
+++ b/devices/clients/tcp_replicator.py
@@ -4,7 +4,6 @@ from threading import Semaphore
 
 import asyncio
 import logging
-import time
 
 from clients.statistics_client_thread import StatisticsClientThread
 
@@ -43,7 +42,7 @@ class TCPReplicator(Thread, StatisticsClientThread):
     _default_options = {
         "tcp_bind": '127.0.0.1',
         "tcp_port": 6666,
-        "tcp_buffer_size": 128000000,   # In bytes
+        "tcp_buffer_size": 128000000,  # In bytes
     }
 
     def __init__(self, options: dict = None):
@@ -57,7 +56,8 @@ class TCPReplicator(Thread, StatisticsClientThread):
         """
         self._loop = None
 
-        # Create and acquire lock to prevent premature termination in join
+        # Create and acquire lock to prevent leaving the constructor without
+        # starting the thread.
         self.initialization_semaphore = Semaphore()
         self.initialization_semaphore.acquire()
 
@@ -67,8 +67,21 @@ class TCPReplicator(Thread, StatisticsClientThread):
         # Connected clients the event loop is managing
         self._connected_clients = []
 
+        # Parse the configured options
         self.options = self._parse_options(options)
 
+        # We start ourselves immediately to reduce amount of possible states.
+        self.start()
+
+        # Wait until we can hold the semaphore, this indicates the thread has
+        # initialized or encountered an exception.
+        with self.initialization_semaphore:
+            if not self.is_alive():
+                self.join()
+                raise RuntimeError("TCPReplicator failed to initialize")
+
+            logging.debug("TCPReplicator initialization completed")
+
     @property
     def _options(self) -> dict:
         return TCPReplicator._default_options
@@ -114,36 +127,50 @@ class TCPReplicator(Thread, StatisticsClientThread):
             pass
 
     def run(self):
-        """Run is launched by calling .start() on TCPReplicator
+        """Run is launched from constructor of TCPReplicator
 
         It manages an asyncio event loop to orchestrate our TCPServerProtocol.
         """
+        try:
+            logger.info("Starting TCPReplicator thread")
 
-        logger.info("Starting TCPReplicator thread")
+            # Create the event loop, must be done in the new thread
+            self._loop = asyncio.new_event_loop()
 
-        # Create the event loop, must be done in the new thread
-        self._loop = asyncio.new_event_loop()
+            # TODO(Corne): REMOVE ME
+            self._loop.set_debug(True)
 
-        # TODO(Corne): REMOVE ME
-        self._loop.set_debug(True)
+            # Schedule the task to create the server
+            self._loop.create_task(TCPReplicator._run_server(
+                self.options, self._connected_clients))
 
-        # Schedule the task to create the server
-        self._loop.create_task(TCPReplicator._run_server(
-            self.options, self._connected_clients))
+            # Everything is initialized, the constructor can safely return
+            self.initialization_semaphore.release()
 
-        # Everything is initialized, join can now safely be called
-        self.initialization_semaphore.release()
-
-        # Keep running event loop until self._loop.stop() is called
-        self._loop.run_forever()
-
-        # Stop must have been called, close the event loop
-        with self.shutdown_condition:
-            logger.debug("Closing TCPReplicator event loop")
-            self._loop.close()
-            self.shutdown_condition.notify()
+            # Keep running event loop until self._loop.stop() is called.
+            # Calling this will lose control flow to the event loop indefinitely,
+            # upon self._loop.stop() control flow is returned here.
+            self._loop.run_forever()
 
-        return
+            # Stop must have been called, close the event loop
+            with self.shutdown_condition:
+                logger.debug("Closing TCPReplicator event loop")
+                self._loop.close()
+                self.shutdown_condition.notify()
+        except Exception as e:
+            # Log the exception as thread exceptions won't be returned to us
+            # on the main thread. Use the module logger like the rest of run()
+            # and pass exc_info so the original stacktrace ends up in the log.
+            logger.critical("TCPReplicator thread encountered fatal exception:"
+                            " %s", e, exc_info=True)
+            # The exception object itself is still not propagated to the
+            # thread that constructed us. Once we use a threadpool it will be
+            # much easier to retrieve this, for the pattern to do this see:
+            # https://stackoverflow.com/a/6894023
+        finally:
+            # Always release the lock upon error so the constructor can return
+            if self.initialization_semaphore.acquire(blocking=False) is False:
+                self.initialization_semaphore.release()
 
     def transmit(self, data: bytes):
         """Transmit data to connected clients"""
@@ -151,24 +178,17 @@ class TCPReplicator(Thread, StatisticsClientThread):
         if not isinstance(data, (bytes, bytearray)):
             raise TypeError("Data must be byte-like object")
 
-        with self.initialization_semaphore:
-            if not self._loop.is_running():
-                logger.warning("Attempt to transmit with TCPReplicator before"
-                               "fully started.")
-                return
-
-            self._loop.call_soon_threadsafe(
-                self._loop.create_task, self._transmit(data))
+        self._loop.call_soon_threadsafe(
+            self._loop.create_task, self._transmit(data))
 
     def join(self, timeout=None):
-        with self.initialization_semaphore:
-            logging.info("Received shutdown request on TCPReplicator thread")
+        logging.info("Received shutdown request on TCPReplicator thread")
 
-            self._clean_shutdown()
+        self._clean_shutdown()
 
-            # Only call join at the end otherwise Thread will falsely assume
-            # all child 'processes' have stopped
-            super().join(timeout)
+        # Only call join at the end otherwise Thread will falsely assume
+        # all child 'processes' have stopped
+        super().join(timeout)
 
     def disconnect(self):
         # TODO(Corne): Prevent duplicate code across TCPReplicator, UDPReceiver
@@ -215,10 +235,14 @@ class TCPReplicator(Thread, StatisticsClientThread):
     def _clean_shutdown(self):
         """Disconnect clients, stop the event loop and wait for it to close"""
 
-        # This should never ever happen, semaphore race condition
+        # Event loop did not start, this can happen when run raises an exception
+        # early
         if not self._loop:
-            logging.error(
-                "TCPReplicator event loop unset, early termination?!")
+            return
+
+        # The event loop is not running anymore, we can't send tasks to shut
+        # it down further.
+        if not self._loop.is_running():
             return
 
         with self.shutdown_condition:
@@ -232,15 +256,3 @@ class TCPReplicator(Thread, StatisticsClientThread):
                 self._loop.call_soon_threadsafe(
                     self._loop.create_task, self._conditional_stop())
                 self.shutdown_condition.wait()
-
-        # Should never happen, conditional race condition
-        while self._loop.is_running():
-            logging.error("TCPReplicator event loop still running after"
-                          "returning from condition.wait!")
-            time.sleep(1)
-
-        # Should never happen, conditional race condition
-        while not self._loop.is_closed():
-            logging.error("TCPReplicator event loop not closed after"
-                          "returning from condition.wait!")
-            time.sleep(1)
diff --git a/devices/integration_test/client/test_tcp_replicator.py b/devices/integration_test/client/test_tcp_replicator.py
index e5a8d52ca..2bac9356a 100644
--- a/devices/integration_test/client/test_tcp_replicator.py
+++ b/devices/integration_test/client/test_tcp_replicator.py
@@ -35,7 +35,9 @@ class TestTCPReplicator(base.IntegrationTestCase):
         }
 
         replicator = TCPReplicator(test_options)
-        replicator.start()
+
+    def test_start_except(self):
+        self.skipTest("TCPReplicator start failure test not implemented")
 
     def test_start_transmit_empty_stop(self):
         """Test transmitting without clients"""
@@ -45,7 +47,6 @@ class TestTCPReplicator(base.IntegrationTestCase):
         }
 
         replicator = TCPReplicator(test_options)
-        replicator.start()
 
         replicator.transmit("Hello World!".encode('utf-8'))
 
@@ -55,7 +56,6 @@ class TestTCPReplicator(base.IntegrationTestCase):
         }
 
         replicator = TCPReplicator(test_options)
-        replicator.start()
 
         time.sleep(2)
 
@@ -76,7 +76,6 @@ class TestTCPReplicator(base.IntegrationTestCase):
         m_data = "hello world".encode("utf-8")
 
         replicator = TCPReplicator(test_options)
-        replicator.start()
 
         time.sleep(2)
 
@@ -100,7 +99,6 @@ class TestTCPReplicator(base.IntegrationTestCase):
         m_data = "hello world".encode("utf-8")
 
         replicator = TCPReplicator(test_options)
-        replicator.start()
 
         time.sleep(2)
 
diff --git a/devices/test/clients/test_tcp_replicator.py b/devices/test/clients/test_tcp_replicator.py
index 4e2661bc5..0b8039de7 100644
--- a/devices/test/clients/test_tcp_replicator.py
+++ b/devices/test/clients/test_tcp_replicator.py
@@ -12,6 +12,7 @@ import time
 from unittest import mock
 
 from clients.tcp_replicator import TCPReplicator
+from clients import tcp_replicator
 
 from test import base
 
@@ -22,9 +23,21 @@ logger = logging.getLogger()
 
 class TestTCPReplicator(base.TestCase):
 
+    @staticmethod
+    async def dummy_task():
+        pass
+
     def setUp(self):
         super(TestTCPReplicator, self).setUp()
 
+        # Create reusable test fixture for unit tests
+        self.m_tcp_replicator = TCPReplicator
+        stat_agg_patcher = mock.patch.object(
+            self.m_tcp_replicator, '_run_server',
+            spec=TCPReplicator._run_server, return_value=self.dummy_task())
+        self.mock_replicator = stat_agg_patcher.start()
+        self.addCleanup(stat_agg_patcher.stop)
+
     def test_parse_options(self):
         """Validate option parsing"""
 
@@ -36,7 +49,7 @@ class TestTCPReplicator(base.TestCase):
             "tcp_bind": '0.0.0.0',  # I should get set
         }
 
-        replicator = TCPReplicator(test_options)
+        replicator = self.m_tcp_replicator(test_options)
 
         # Ensure replicator initialization does not modify static variable
         self.assertEqual(t_tcp_bind, TCPReplicator._default_options['tcp_bind'])
@@ -53,7 +66,7 @@ class TestTCPReplicator(base.TestCase):
         m_client = mock.Mock()
 
         # Create both a TCPReplicator and TCPServerProtocol separately
-        replicator = TCPReplicator()
+        replicator = self.m_tcp_replicator()
         protocol = TCPReplicator.TCPServerProtocol(
             replicator._options, replicator._connected_clients)
 
@@ -66,8 +79,7 @@ class TestTCPReplicator(base.TestCase):
     def test_start_stop(self):
         """Verify threading behavior, being able to start and stop the thread"""
 
-        replicator = TCPReplicator()
-        replicator.start()
+        replicator = self.m_tcp_replicator()
 
         # Give the thread 5 seconds to stop
         replicator.join(5)
@@ -75,31 +87,39 @@ class TestTCPReplicator(base.TestCase):
         # Thread should now be dead
         self.assertFalse(replicator.is_alive())
 
+    def test_start_exception(self):
+        """Verify the run() methods kills the thread cleanly on exceptions"""
+        m_loop = mock.Mock()
+        m_loop.create_task.side_effect = RuntimeError("Test Error")
+
+        # Signal to _clean_shutdown that the exception has caused the loop to
+        # stop
+        m_loop.is_running.return_value = False
+
+        m_replicator_import = tcp_replicator
+
+        with mock.patch.object(m_replicator_import, 'asyncio') as run_patcher:
+            run_patcher.new_event_loop.return_value = m_loop
+
+            # Constructor should raise an exception if the thread is killed
+            self.assertRaises(RuntimeError, self.m_tcp_replicator)
+
     @timeout_decorator.timeout(5)
     def test_start_stop_delete(self):
         """Verify that deleting the TCPReplicator object safely halts thread"""
 
-        replicator = TCPReplicator()
-        replicator.start()
+        replicator = self.m_tcp_replicator()
 
         del replicator
 
-    @staticmethod
-    async def dummy_task():
-        pass
-
-    @mock.patch.object(TCPReplicator, "_run_server")
-    def test_transmit(self, m_run_server):
+    def test_transmit(self):
         """Test that clients are getting data written to their transport"""
-        m_run_server.return_value = self.dummy_task()
 
         m_data = "Hello World!".encode('utf-8')
 
         m_client = mock.Mock()
 
-        replicator = TCPReplicator()
-
-        replicator.start()
+        replicator = self.m_tcp_replicator()
 
         replicator._connected_clients.append(m_client)
 
@@ -116,15 +136,10 @@ class TestTCPReplicator(base.TestCase):
 
         m_client.transport.write.assert_called_once_with(m_data)
 
-    @mock.patch.object(TCPReplicator, "_run_server")
-    def test_disconnect(self, m_run_server):
-        m_run_server.return_value = self.dummy_task()
-
+    def test_disconnect(self,):
         m_client = mock.Mock()
 
-        replicator = TCPReplicator()
-
-        replicator.start()
+        replicator = self.m_tcp_replicator()
 
         replicator._connected_clients.append(m_client)
 
-- 
GitLab