diff --git a/nbfetch/handlers.py b/nbfetch/handlers.py index cf59291..3cd4b4c 100644 --- a/nbfetch/handlers.py +++ b/nbfetch/handlers.py @@ -1,26 +1,24 @@ -from tornado import gen, web, locks -from tornado.escape import url_escape, url_unescape -import traceback -from urllib.parse import urljoin - -import threading import json import os +import pickle +import threading +import traceback from pathlib import Path -from queue import Queue, Empty -import jinja2 +from queue import Empty, Queue +from typing import Optional, Tuple +from urllib.parse import urljoin + +from cryptography.fernet import Fernet from hsclient import HydroShare -from .pull import GitPuller, HSPuller -from .version import __version__ -import pickle from jupyter_server.base.handlers import JupyterHandler from jupyter_server.extension.handler import ( - ExtensionHandlerMixin, - ExtensionHandlerJinjaMixin, + ExtensionHandlerJinjaMixin, ExtensionHandlerMixin, ) +from tornado import gen, locks, web +from tornado.escape import url_escape, url_unescape -# typing imports -from typing import Optional, Tuple +from .pull import GitPuller, HSPuller +from .version import __version__ class ExtensionHandler( @@ -52,12 +50,10 @@ def get(self): @gen.coroutine def post(self): - pwfile = os.path.expanduser("~/.hs_pass") - userfile = os.path.expanduser("~/.hs_user") - with open(userfile, "w") as f: - f.write(self.get_argument("name")) - with open(pwfile, "w") as f: - f.write(self.get_argument("pass")) + username = self.get_argument("name") + password = self.get_argument("pass") + # cache the user account + UserAccountHandler.save_account(username, password) self.redirect(url_unescape(self.get_argument("next", "/"))) @@ -341,9 +337,9 @@ def login(self): if hs is None: # If oauth fails, we can log in using - # user name and password. These are saved in + # username and password. These are saved in # files in the home directory. - username, password = _get_user_authentication() + username, password = UserAccountHandler.get_account() try: hs = self.check_auth(username=username, password=password) if hs is None: @@ -441,32 +437,56 @@ def get(self): self.flush() -def _get_user_authentication() -> Tuple[Optional[str], Optional[str]]: - """retrieve HS authentication from standard locations(see below) as tuple of username and - password. If either cannot be located, both tuple members are be None. +class UserAccountHandler: + """Handler for saving and retrieving user account information - Note: files take precedence over environment variables + Note: files take precedence over environment variables. If the keyfile does not exist, the user + and password will be retrieved from environment variables. standard file locations: + `~/.hs_key` `~/.hs_user` `~/.hs_pass` standard environment variables: `HS_USER` `HS_PASS` - - Returns - ------- - Tuple[Optional[str], Optional[str]] - username, password """ + keyfile = Path("~/.hs_key").expanduser() userfile = Path("~/.hs_user").expanduser() pwfile = Path("~/.hs_pass").expanduser() - user = ( - userfile.read_text().strip() - if userfile.exists() - else os.getenv("HS_USER", None) - ) + @classmethod + def save_account(cls, username: str, password: str) -> None: + key = Fernet.generate_key() + with open(cls.keyfile, "wb") as f: + f.write(key) + cypher = Fernet(key) + encrypted_user_name = cypher.encrypt(username.encode()) + encrypted_password = cypher.encrypt(password.encode()) + with open(cls.userfile, "wb") as f: + f.write(encrypted_user_name) + with open(cls.pwfile, "wb") as f: + f.write(encrypted_password) + + @classmethod + def get_account(cls) -> Tuple[Optional[str], Optional[str]]: + if cls.keyfile.exists(): + key = cls.keyfile.read_bytes() + else: + # if the keyfile does not exist, we cannot decrypt the user and password + # check for user and password in environment variables + return os.getenv("HS_USER", None), os.getenv("HS_PASS", None) + + cipher = Fernet(key) + if cls.userfile.exists(): + user_encrypted = cls.userfile.read_bytes() + user = cipher.decrypt(user_encrypted).decode() + else: + user = os.getenv("HS_USER", None) - pw = pwfile.read_text().strip() if pwfile.exists() else os.getenv("HS_PASS", None) + if cls.pwfile.exists(): + pw_encrypted = cls.pwfile.read_bytes() + pw = cipher.decrypt(pw_encrypted).decode() + else: + pw = os.getenv("HS_PASS", None) - return user, pw + return user, pw diff --git a/setup.py b/setup.py index 23759ee..43aa641 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ packages=find_packages(), include_package_data=True, platforms="any", - install_requires=["notebook==6.4.6", "tornado", "hsclient", "jupyter_server"], + install_requires=["notebook==6.4.6", "tornado", "hsclient", "jupyter_server", "cryptography==3.4.8"], extras_require={"develop": ["pytest", "pytest-jupyter"]}, data_files=[ (