Python: Refactoring jupyter notebooks

12 March 2021

Data scientists love Jupyter notebooks, and for good reason. Often, though, once a concept or approach is found to be viable, the code needs to be refactored into scripts so that it can be ‘productionised’ and run behind an API[1], from a scheduler[2], or from a CD automation tool[3].

Initial steps

Typically I follow the steps below as an aid to development:

  • Extract the Python code from the Jupyter notebook: this is made easy with jupyter nbconvert --to script notebook.ipynb, which writes notebook.py (called script.py in the steps below)
  • Remove lines containing only whitespace[4]: sed -i '/^[[:space:]]*$/d' script.py[5]
  • Remove the comment lines that mark the cell numbers: sed -i '/# In/d' script.py
  • See what the formatted output will look like: black --diff --color script.py
  • If happy with that, format the code so that it is nicely laid out: black script.py
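
For anyone without GNU sed to hand, the two clean-up steps above can also be sketched in Python. This is only a sketch and not part of the workflow described here: the names clean_script and CELL_MARKER are hypothetical, and the pattern assumes the cell markers look like # In[3]: or # In[ ]:, which is stricter than the sed expression above (that one deletes any line containing # In).

```python
import re

# Assumed shape of the cell markers written by nbconvert, e.g. "# In[3]:" or "# In[ ]:".
CELL_MARKER = re.compile(r"^# In\[[0-9 ]*\]:")


def clean_script(source: str) -> str:
    """Drop whitespace-only lines and cell-marker comments from a script's text."""
    kept = []
    for line in source.splitlines():
        if not line.strip():
            continue  # whitespace-only line
        if CELL_MARKER.match(line):
            continue  # cell-number comment
        kept.append(line)
    return "\n".join(kept) + "\n" if kept else ""


# Possible usage (the file name is an assumption):
#   from pathlib import Path
#   path = Path("script.py")
#   path.write_text(clean_script(path.read_text()))
```

Running black afterwards restores the blank lines between functions, so stripping them all at this stage loses nothing.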

It’s also handy to have the notebook to hand for reference, of course; GitHub and GitLab will render notebooks in the web interface, or a standalone tool like nbviewer-app can be used. Note that if you prefer to format the code within the Jupyter notebook itself, there are extensions that allow that to be done conveniently.

Because it’s usually a bit harder to interrogate dataframes from a script than from a notebook, there are a couple of handy additions that can be made.

Basic data validation of dataframe contents

Firstly, basic data validation on dataframe contents as the script runs.

By way of example: in scripts that process data, e.g. before feeding it into ML models for training or inference, it is common to drop rows where the values are not as we expect. If a join malfunctions, one can end up with a column made up entirely of missing values; dropping those rows then leaves a dataframe with 0 rows, which is something we never want. Similarly, we never want a dataframe with duplicate column names. So it makes sense to check for these conditions, and to trigger the check from a commonly used function like print.
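
To make the failure mode concrete, here is a small illustration with made-up data. The frames, column names and the case mismatch in the join key are all assumptions chosen for the example; the point is only that every step runs without an error, yet the result is a dataframe with 0 rows.

```python
import pandas as pd

# Hypothetical data: the join keys differ in case, so nothing matches.
left = pd.DataFrame({"id": ["a", "b", "c"], "value": [10.0, 20.0, 30.0]})
right = pd.DataFrame({"id": ["A", "B", "C"], "label": [0, 1, 0]})

# The left join succeeds, but "label" is missing on every row.
merged = left.merge(right, on="id", how="left")

# Dropping rows with a missing label then silently removes everything.
cleaned = merged.dropna(subset=["label"])
print(cleaned.shape)

# Duplicate column names appear just as quietly, e.g. after a careless concat.
wide = pd.concat([left, left[["value"]]], axis=1)
duplicated = wide.columns[wide.columns.duplicated()].tolist()
print(duplicated)
```

In a notebook the empty frame would be spotted as soon as it was displayed; in a script it usually only surfaces later, as a confusing error somewhere downstream.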

I tend to use a slightly enhanced version of print in scripts, below, which tells me which script and which line is being executed when the print statement is run, and how long has passed since the previous call if that is more than a second:

import inspect
import time

devmode = True
ts = None
prev_ts = time.time()

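def dprint(*args, **kwargs):
    global ts, prev_ts

    if devmode:
        cf = inspect.currentframe()
        ts = time.time()
        elapsed = (ts - prev_ts)
        filename = inspect.stack()[1][1].split('/')[-1]
        lineno = cf.f_back.f_lineno
        # Prefix the output with the calling script and line number
        print(
            filename,
            lineno,
            sep=" : ",
            end=" : ",
        )
        # Report the time since the previous call only if it exceeds a second
        if elapsed > 1:
            elapsed = f"{(ts - prev_ts):.0f} seconds"
            print(elapsed, end=" : ", sep="")
        prev_ts = ts
        # Print each argument on its own line unless the caller overrides sep
        kwargs.setdefault("sep", "\n")
        print(*args, **kwargs)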

And here is our function to check the dataframes:

import pandas as pd

def check_dataframes():
    # Check every dataframe held in a module-level variable
    issue = False
    for k, v in globals().copy().items():
        if isinstance(v, pd.DataFrame):
            dups = duplicate_column_names(v)
            if dups is not None:
                issue = True
                print("\t", k, "has duplicated col names:", dups)
            if v.shape[0] == 0:
                issue = True
                print("\t", k, v.shape)
    if issue:
        bail("See above for dataframe issue")

That leaves a function to check the column names:

def duplicate_column_names(df: pd.DataFrame):
    from collections import Counter
    # Count how many times each column name occurs
    count = Counter()
    for col in df.columns:
        count[col] += 1
    ps_dups = pd.Series(count)

    # Return the names that occur more than once, or None if there are none
    if (ps_dups > 1).sum() > 0:
        return ps_dups[ps_dups > 1].index.tolist()
    return None

And bail is something along the lines of print(s); sys.exit(1).
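
A minimal sketch of bail along those lines. The parameter name and the exit status of 1 simply follow the description above; a real script might prefer to log the message or write it to stderr instead.

```python
import sys


def bail(message: str) -> None:
    """Print the message, then stop the script with a non-zero exit status."""
    print(message)
    sys.exit(1)
```

Because sys.exit raises SystemExit rather than killing the process outright, any finally blocks and context managers still get a chance to clean up, and a scheduler or CD tool sees the non-zero status as a failed run.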

Checking proportion isnull() of dataframe columns

Secondly, it would be handy to have a function which will tell you the shape and the proportion of e.g. NaN or otherwise missing values in the columns:

def df_inspect(df, name):
    # dprint (defined earlier) prints each argument on its own line
    dprint(
        "\tTop 20 columns based on proportion of values missing",
        (df.isnull().sum() / df.shape[0]).sort_values(ascending=False).head(20),
        f"\tThere are {len(df.columns)} columns and the mean isnull proportion is {(df.isnull().sum() / df.shape[0]).mean():.2f}",
        f"\tShape of {name} is {df.shape}",
    )
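
To show what the missing-value expression in df_inspect computes, here is a small worked example on a made-up dataframe (the data and variable names are assumptions for illustration). It also shows that df.isnull().mean() gives the same proportions in a slightly shorter form.

```python
import pandas as pd

# Hypothetical frame: "a" is half missing, "c" a quarter missing, "b" complete.
df = pd.DataFrame(
    {
        "a": [1.0, None, None, 4.0],
        "b": [1, 2, 3, 4],
        "c": [None, "x", "y", "z"],
    }
)

# Proportion of missing values per column, as computed in df_inspect.
by_sum = df.isnull().sum() / df.shape[0]

# Equivalent shorthand: the mean of a boolean column is its proportion of True.
by_mean = df.isnull().mean()

# Columns ordered from most to least missing.
top = by_sum.sort_values(ascending=False)
print(top)
print(f"mean isnull proportion: {by_sum.mean():.2f}")
```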

  1. Typically a REST API.
  2. Like Airflow, or cron.
  3. Like Jenkins.
  4. i.e. not containing any code, but potentially containing spaces, newlines (\n), tabs (\t) etc.
  5. GNU sed syntax; on macOS you can brew install gnu-sed and use gsed instead of sed.