Refactoring Jupyter notebooks

12 March 2021

Data scientists love Jupyter notebooks, and for good reason. Often, though, once a concept or approach is found to be viable, the code will need to be refactored into scripts so it can be ‘productionised’ and run behind an API[1], from a scheduler[2], or by a CD automation tool[3].

Initial steps

Typically I follow the steps below as an aid to development:

  • Extract the Python code from the Jupyter notebook: this is made easy with jupyter nbconvert --to script notebook.ipynb
  • Remove lines containing only whitespace[4]: sed -i '/^[[:space:]]*$/d' script.py[5]
  • Remove the comment lines nbconvert adds to mark the cell numbers: sed -i '/^# In/d' script.py
  • See what the formatted output will look like: black --diff --color script.py
  • If happy with that, run the formatter so the code is nicely laid out: black script.py
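
If you'd rather avoid shell and sed portability quirks, the two clean-up steps above can be replicated in pure Python; this is a rough sketch, and the function name is mine:

```python
import re

def clean_converted_script(source: str) -> str:
    """Replicate the sed clean-up steps on an nbconvert-extracted script:
    drop whitespace-only lines and the '# In[..]' cell-number comments."""
    kept = [
        line
        for line in source.splitlines()
        if line.strip()                           # drop whitespace-only lines
        and not line.lstrip().startswith("# In")  # drop cell-number comments
    ]
    return "\n".join(kept) + "\n"
```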

It’s also handy to have the notebook to hand for reference, of course; GitHub and GitLab will render notebooks in the web interface, or a standalone tool like nbviewer-app can be used. Note that if you prefer to format within the Jupyter notebook itself, https://github.com/drillan/jupyter-black may be of interest, as it allows that to be done conveniently.

Because it’s usually a bit harder to interrogate dataframes from a script than from a notebook, there are a couple of handy additions that can be made.

Basic data validation of dataframe contents

Firstly, basic data validation on dataframe contents as the script runs.

By way of example: in scripts that process data, e.g. before feeding it into ML models for training or inference, it is common to drop rows whose values are not as we expect. With a malfunctioning join one can end up with a column of entirely missing values; dropping those rows then produces a dataframe with 0 rows, which is something we never want. Similarly, we never want a dataframe with duplicate column names. So it makes sense to check for these conditions, and to trigger the check from a commonly-used function like print.
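
To make the failure mode concrete, here is a hypothetical toy example (the variable names are mine) of a join gone wrong producing a 0-row dataframe:

```python
import pandas as pd

# A left join where no keys match fills the joined column with NaN;
# dropping those rows then leaves a dataframe with 0 rows.
features = pd.DataFrame({"id": [1, 2, 3], "x": [0.1, 0.2, 0.3]})
labels = pd.DataFrame({"id": [4, 5, 6], "y": [0, 1, 0]})  # misaligned keys

merged = features.merge(labels, on="id", how="left")
cleaned = merged.dropna()  # every row had a NaN in 'y', so nothing survives

print(merged.shape)   # (3, 3)
print(cleaned.shape)  # (0, 3)
```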

I tend to use a slightly enhanced version of print in scripts, shown below, which tells me which script and which line is being executed when the print statement runs:

import inspect
import time

devmode = True
ts = None
prev_ts = time.time()

def dprint(*args, **kwargs):
    global ts, prev_ts

    if devmode:
        cf = inspect.currentframe()
        ts = time.time()
        elapsed = (ts - prev_ts)
        filename = inspect.stack()[1][1].split('/')[-1]
        lineno = cf.f_back.f_lineno
        print(
            f"{filename}",
            f"{lineno}",
            sep=" : ",
            end=" : ",
        )
        if elapsed > 1:
            print(f"{elapsed:.0f} seconds", end=" : ")
        prev_ts = ts
        print(*args, sep="\n", **kwargs)
        check_dataframes()
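
To see the resulting prefix in action, here is a standalone sketch; it is a pared-down copy of the function above, with the check_dataframes() call omitted so it runs on its own:

```python
import inspect
import time

devmode = True
prev_ts = time.time()

def dprint(*args, **kwargs):
    # Same as the dprint above, minus the check_dataframes() call,
    # so this snippet is self-contained.
    global prev_ts
    if devmode:
        cf = inspect.currentframe()
        ts = time.time()
        filename = inspect.stack()[1][1].split("/")[-1]
        lineno = cf.f_back.f_lineno
        print(filename, lineno, sep=" : ", end=" : ")
        if ts - prev_ts > 1:
            print(f"{ts - prev_ts:.0f} seconds", end=" : ")
        prev_ts = ts
        print(*args, sep="\n", **kwargs)

dprint("loading features")
# prints something like:  my_script.py : 23 : loading features
```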

And here is our function to check the dataframes:

def check_dataframes():
    issue = False
    for k, v in globals().copy().items():
        if isinstance(v, pd.DataFrame):
            dups = duplicate_column_names(v)
            if dups is not None:
                issue = True
                print("\tDuplicated col names are", dups)
            if v.shape[0] == 0:
                issue = True
                print("\t", k, v.shape)
    if issue:
        bail("See above for dataframe issue")

Leaving a function to check the column names:

def duplicate_column_names(df: pd.DataFrame):
    from collections import Counter

    count = Counter(df.columns)
    ps_dups = pd.Series(count)

    if (ps_dups > 1).any():
        return ps_dups[ps_dups > 1].index.tolist()
    return None
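
As an aside, pandas can flag duplicated column names directly, which makes for a handy cross-check; a quick sketch on a toy frame (the names are mine):

```python
import pandas as pd

# A dataframe with a duplicated column name, e.g. from a careless concat:
df = pd.concat(
    [pd.DataFrame({"id": [1, 2], "x": [3, 4]}),
     pd.DataFrame({"id": [5, 6]})],
    axis=1,
)

# Index.duplicated marks second and later occurrences of each name:
dups = df.columns[df.columns.duplicated()].tolist()
print(dups)  # ['id']
```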

And bail is something along the lines of print(s); sys.exit(1).
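
A minimal sketch of such a bail, assuming all that is needed is a message and a non-zero exit:

```python
import sys

def bail(s: str) -> None:
    # Print the reason to stderr and exit non-zero, so a scheduler or
    # CI job records the run as failed.
    print(s, file=sys.stderr)
    sys.exit(1)
```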

Checking proportion isnull() of dataframe columns

Secondly, it is handy to have a function which tells you the shape of a dataframe and the proportion of NaN or otherwise missing values in each of its columns:

def df_inspect(df, name):
    dprint(
        f"\n\n\t{name}",
        "\tTop 20 columns based on proportion of values missing",
        (df.isnull().sum() / df.shape[0]).sort_values(ascending=False).head(20),
        f"\tThere are {len(df.columns)} columns and the mean isnull proportion is {(df.isnull().sum() / df.shape[0]).mean():.2f}",
        f"\tShape of {name} is {df.shape}",
    )
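
The core of that report, shown standalone on a toy frame (the column names are mine):

```python
import pandas as pd

df = pd.DataFrame({"a": [1, None, None, None], "b": [1, 2, 3, 4]})

# Per-column proportion of missing values, worst first:
missing = (df.isnull().sum() / df.shape[0]).sort_values(ascending=False)
print(missing["a"])             # 0.75
print(f"{missing.mean():.2f}")  # 0.38
```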

  1. Typically a REST API.
  2. Like Airflow, or cron.
  3. Like Jenkins.
  4. i.e. not containing any code, but potentially containing spaces, newlines (\n), tabs (\t) etc.
  5. GNU sed syntax; on macOS you can brew install gnu-sed and use gsed instead of sed and it will work.