From 831162d224f7bbcb02446070ed89abd799a23ec6 Mon Sep 17 00:00:00 2001
From: Jan David Mol <mol@astron.nl>
Date: Sun, 14 Aug 2016 12:24:46 +0000
Subject: [PATCH] Task #9678: Fix use of socket.recv in jobserver and lofarnode
 in case of closed connections and interrupts

---
 .../framework/lofarpipe/support/jobserver.py  | 15 ++++++-----
 .../framework/lofarpipe/support/lofarnode.py  | 20 +++++++++-----
 .../framework/lofarpipe/support/utilities.py  | 26 +++++++++++++++++++
 3 files changed, 49 insertions(+), 12 deletions(-)

diff --git a/CEP/Pipeline/framework/lofarpipe/support/jobserver.py b/CEP/Pipeline/framework/lofarpipe/support/jobserver.py
index 10182fe2890..a88a795a422 100644
--- a/CEP/Pipeline/framework/lofarpipe/support/jobserver.py
+++ b/CEP/Pipeline/framework/lofarpipe/support/jobserver.py
@@ -21,7 +21,7 @@ import cPickle as pickle
 
 from lofarpipe.support.lofarexceptions import PipelineQuit
 from lofarpipe.support.pipelinelogging import log_process_output
-from lofarpipe.support.utilities import spawn_process
+from lofarpipe.support.utilities import spawn_process, socket_recv
 
 class JobStreamHandler(SocketServer.StreamRequestHandler):
     """
@@ -41,17 +41,20 @@ class JobStreamHandler(SocketServer.StreamRequestHandler):
         Each request is expected to be a 4-bute length followed by either a
         GET/PUT request or a pickled LogRecord.
         """
+
         while True:
-            chunk = self.request.recv(4)
+            # Read message length
             try:
+                chunk = socket_recv(self.request, 4)
                 slen = struct.unpack(">L", chunk)[0]
             except:
                 break
-            chunk = self.connection.recv(slen)
-            while len(chunk) < slen:
-                chunk = chunk + self.connection.recv(slen - len(chunk))
-            input_msg = chunk.split(" ", 2)
+
+            # Read message
+            chunk = socket_recv(self.request, slen)
             try:
+                input_msg = chunk.split(" ", 2)
+
                 # Can we handle this message type?
                 if input_msg[0] == "GET":
                     self.send_arguments(int(input_msg[1]))
diff --git a/CEP/Pipeline/framework/lofarpipe/support/lofarnode.py b/CEP/Pipeline/framework/lofarpipe/support/lofarnode.py
index ab028f856f4..c8127e4db51 100644
--- a/CEP/Pipeline/framework/lofarpipe/support/lofarnode.py
+++ b/CEP/Pipeline/framework/lofarpipe/support/lofarnode.py
@@ -16,6 +16,7 @@ import logging.handlers
 import cPickle as pickle
 
 from lofarpipe.support.usagestats import UsageStats
+from lofarpipe.support.utilities  import socket_recv
 
 def run_node(*args):
     """
@@ -128,18 +129,25 @@ class LOFARnodeTCP(LOFARnode):
         while True:
             tries -= 1
             try:
+                # connect
                 s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                 self.__try_connect(s)
+
+                # send request
                 message = "GET %d" % self.job_id
                 s.sendall(struct.pack(">L", len(message)) + message)
-                chunk = s.recv(4)
+
+                # receive response length
+                chunk = socket_recv(s, 4)
                 slen = struct.unpack(">L", chunk)[0]
-                chunk = s.recv(slen)
-                while len(chunk) < slen:
-                    chunk += s.recv(slen - len(chunk))
+
+                # receive response
+                chunk = socket_recv(s, slen)
+
+                # parse response
                 self.arguments = pickle.loads(chunk)
-            except socket.error, e:
-                print "Failed to get recipe arguments from server"
+            except (IOError, socket.error) as e:
+                print "Failed to get recipe arguments from server: %s" % (e,)
                 if tries > 0:
                     timeout = random.uniform(min_timeout, max_timeout)
                     print("Retrying in %f seconds (%d more %s)." %
diff --git a/CEP/Pipeline/framework/lofarpipe/support/utilities.py b/CEP/Pipeline/framework/lofarpipe/support/utilities.py
index c445578c03c..31bf7c21a6b 100644
--- a/CEP/Pipeline/framework/lofarpipe/support/utilities.py
+++ b/CEP/Pipeline/framework/lofarpipe/support/utilities.py
@@ -299,3 +299,29 @@ def catch_segfaults(cmd, cwd, env, logger, max = 1, cleanup = lambda: None,
         logger.error("Too many segfaults from %s; aborted" % (cmd[0]))
         raise subprocess.CalledProcessError(process.returncode, cmd[0])
     return process
+
+def socket_recv(socket, numbytes):
+    """
+    Read numbytes from the given socket.
+    
+    Raises IOError if connection has closed before all data could be read.
+    """
+
+    data = ""
+    while numbytes > 0:
+        try:
+            chunk = socket.recv(numbytes)
+        except IOError, e:
+            if e.errno == errno.EINTR:
+                continue
+            else:
+                raise
+
+        if not chunk:
+             raise IOError("Connection closed. Received '%s', need %d more bytes" % (data,numbytes))
+
+        data += chunk
+        numbytes -= len(chunk)
+
+    return data
+
-- 
GitLab