From d70f8a329271cf1a07ffaa04a6a874f3a24fc473 Mon Sep 17 00:00:00 2001
From: Nico Vermaas <vermaas@astron.nl>
Date: Tue, 9 Aug 2022 11:52:52 +0200
Subject: [PATCH] add optional private key and private key password

---
 README.md                                     | 31 +++++++++++++++
 .../connection/retrieve_db_connection.py      | 39 +++++++++++++------
 2 files changed, 58 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 0728c44..0c189f6 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,35 @@ pip install -e "git+https://git.astron.nl/ldv/ldv_utils.git#egg=ldvspec-migratio
 
 ```
 
+### Configuration
+The database and tunnel configuration are in a local file on the host that can be given as a `--configuration` parameter.
+The parameter file can contain a link to a private key file, and password. 
+When those keys are not given, the script will try to read the local SSH_CONFIG file `~/.ssh/config`. (Note that this does not work on Windows)
+
+See for more documentation about the sshtunnel mechanism:
+https://pypi.org/project/sshtunnel/ 
+
+The following example shows a local configuration using private key.
+```
+[postgresql-local]
+host=localhost
+port=5433
+database=ldv-spec-db
+user=postgres
+password=xxxxx
+
+[postgresql-ldv]
+tunnelhost=dop821.astron.nl
+tunnelusername=sdco
+host=sdc-db.astron.nl
+port=5432
+database=ldvadmin
+user=ldvrbow
+password=xxxxx
+ssh_pkey = "C:\\Program Files Nico\\putty\\astron_private_key.ppk"
+ssh_private_key_password = "xxxxx"
+```
+
 ### Running
 To test if it works
 ```bash
@@ -132,4 +161,6 @@ Some examples:
      ldv_migrate --limit 50000 --max_nbr_dps_to_insert_per_request 10000
 - Import only 1000 records at production:
      ldv_migrate --limit 1000 --host prod
+
+ldv_migrate --limit 1000 --verbose --configuration ~/shared/ldv_migrate.cfg
 ```
diff --git a/ldv_migrate/ldv_migrate/connection/retrieve_db_connection.py b/ldv_migrate/ldv_migrate/connection/retrieve_db_connection.py
index 52f1184..20d13f7 100644
--- a/ldv_migrate/ldv_migrate/connection/retrieve_db_connection.py
+++ b/ldv_migrate/ldv_migrate/connection/retrieve_db_connection.py
@@ -48,21 +48,36 @@ def open_tunnel(configuration_params):
     host = configuration_params.get('host', "no host given")
     port = int(configuration_params.get('port', "no port given"))
 
-    try:
-        ssh_config_file = os.path.expanduser("~/.ssh/config")
-    except FileNotFoundError as exc:
-        raise FileNotFoundError(
-            "Ssh config file not found on standard path '~/.ssh/config'. This is mandatory for opening the ssh tunnel"
-        ) from exc
+    # check if a private key and password was given
+    ssh_pkey = configuration_params.get('ssh_pkey',None)
+    ssh_private_key_password = configuration_params.get('ssh_private_key_password',None)
 
     logging.info("Creating ssh tunnel for %s and port %s with tunnel host %s and username %s", repr(host), port,
                  repr(tunnel_host), repr(tunnel_username))
-    ssh_tunnel = SSHTunnelForwarder(
-        ssh_address_or_host=tunnel_host,
-        ssh_username=tunnel_username,
-        ssh_config_file=ssh_config_file,
-        remote_bind_address=(host, port)
-    )
+
+    if ssh_pkey:
+        ssh_tunnel = SSHTunnelForwarder(
+            ssh_address_or_host=tunnel_host,
+            ssh_username=tunnel_username,
+            remote_bind_address=(host, port),
+            ssh_pkey = ssh_pkey,
+            ssh_private_key_password = ssh_private_key_password
+        )
+    else:
+        try:
+            ssh_config_file = os.path.expanduser("~/.ssh/config")
+        except FileNotFoundError as exc:
+            raise FileNotFoundError(
+                "Ssh config file not found on standard path '~/.ssh/config'. This is mandatory for opening the ssh tunnel"
+            ) from exc
+
+        ssh_tunnel = SSHTunnelForwarder(
+            ssh_address_or_host=tunnel_host,
+            ssh_username=tunnel_username,
+            ssh_config_file=ssh_config_file,
+            remote_bind_address=(host, port),
+        )
+
     ssh_tunnel.start()
     return ssh_tunnel
 
-- 
GitLab