小ネタ: streamlit でデフォルト値を設定する

前提

こういう感じで、複数の input に大してデフォルトの値を設定しておいて、個々の値 (foo_i_j) を微調整したいというロジックが作りたいときのレシピ。

作ったロジックでは、

  • ユーザ入力があればそれを優先する
  • ユーザ入力がない場合はデフォルト値を使う
  • ユーザ入力があっても、デフォルト値が変わったら新しいデフォルト値を使う

となっている(もちろん3つ目はユースケースによると思うので適宜調整してね)。

コード

コードはこう

import streamlit as st

NUM_COLS = 3
NUM_ROWS = 5


def on_submit(default_value):
    if st.session_state.get("default_value") != default_value:
        st.session_state["default_changed"] = True
        return
    st.session_state["default_changed"] = False


def get_default_value(key, default_value):
    if not key in st.session_state:
        # no user input yet
        return default_value

    if st.session_state.get("default_changed"):
        return default_value
    return st.session_state[key]


with st.form("form"):
    default_value = st.number_input("default value", step=1)
    st.form_submit_button("submit", on_click=lambda x: on_submit, args=(default_value,))

# ... other code
for row_idx in range(NUM_ROWS):
    cols = st.columns(NUM_COLS)
    for col_idx in range(NUM_COLS):
        col = cols[col_idx]
        key = f"foo_{row_idx}_{col_idx}"
        col.number_input(
            key, key=key, value=get_default_value(key, default_value), step=1
        )

        if col_idx == NUM_COLS - 1:
            # last column
            cols = st.columns(1)

解説

まず form の中でデフォルト値を取得する。

with st.form("form"):
    default_value = st.number_input("default value", step=1)

そして、このデフォルト値が変わったかどうかは session state に保存する。

def on_submit(default_value):
    if st.session_state.get("default_value") != default_value:
        st.session_state["default_changed"] = True
        return
    st.session_state["default_changed"] = False

# ...

st.form_submit_button("submit", on_click=lambda x: on_submit, args=(default_value,))

個々の foo_i_j の input では、デフォルト値の取得を上記の session_state を参照したロジックに切り出す。

def get_default_value(key, default_value):
    if not key in st.session_state:
        # no user input yet
        return default_value

    if st.session_state.get("default_changed"):
        return default_value
    return st.session_state[key]

# ...

    col.number_input(
            key, key=key, value=get_default_value(key, default_value), step=1
        )

おまけ

行列のグリッド形式で UI を表示させたいけど、たとえば画像とかを表示してて行ごとに高さが違うんだよな... というときは、一番最後の列で st.columns(NUM_COLUMNS) を呼んであげると、行の頭の y を揃えられる。

for row_idx in range(NUM_ROWS):
    cols = st.columns(NUM_COLS)
    for col_idx in range(NUM_COLS):

        # ... other code
    
        if col_idx == NUM_COLS - 1:
            # last column
            cols = st.columns(1)