前提
こういう感じで、複数の input に大してデフォルトの値を設定しておいて、個々の値 (foo_i_j) を微調整したいというロジックが作りたいときのレシピ。
作ったロジックでは、
- ユーザ入力があればそれを優先する
- ユーザ入力がない場合はデフォルト値を使う
- ユーザ入力があっても、デフォルト値が変わったら新しいデフォルト値を使う
となっている(もちろん3つ目はユースケースによると思うので適宜調整してね)。
コード
コードはこう
import streamlit as st NUM_COLS = 3 NUM_ROWS = 5 def on_submit(default_value): if st.session_state.get("default_value") != default_value: st.session_state["default_changed"] = True return st.session_state["default_changed"] = False def get_default_value(key, default_value): if not key in st.session_state: # no user input yet return default_value if st.session_state.get("default_changed"): return default_value return st.session_state[key] with st.form("form"): default_value = st.number_input("default value", step=1) st.form_submit_button("submit", on_click=lambda x: on_submit, args=(default_value,)) # ... other code for row_idx in range(NUM_ROWS): cols = st.columns(NUM_COLS) for col_idx in range(NUM_COLS): col = cols[col_idx] key = f"foo_{row_idx}_{col_idx}" col.number_input( key, key=key, value=get_default_value(key, default_value), step=1 ) if col_idx == NUM_COLS - 1: # last column cols = st.columns(1)
解説
まず form の中でデフォルト値を取得する。
with st.form("form"): default_value = st.number_input("default value", step=1)
そして、このデフォルト値が変わったかどうかは session state に保存する。
def on_submit(default_value): if st.session_state.get("default_value") != default_value: st.session_state["default_changed"] = True return st.session_state["default_changed"] = False # ... st.form_submit_button("submit", on_click=lambda x: on_submit, args=(default_value,))
個々の foo_i_j
の input では、デフォルト値の取得を上記の session_state を参照したロジックに切り出す。
def get_default_value(key, default_value): if not key in st.session_state: # no user input yet return default_value if st.session_state.get("default_changed"): return default_value return st.session_state[key] # ... col.number_input( key, key=key, value=get_default_value(key, default_value), step=1 )
おまけ
行列のグリッド形式で UI を表示させたいけど、たとえば画像とかを表示してて行ごとに高さが違うんだよな... というときは、一番最後の列で st.columns(NUM_COLUMNS)
を呼んであげると、行の頭の y を揃えられる。
for row_idx in range(NUM_ROWS): cols = st.columns(NUM_COLS) for col_idx in range(NUM_COLS): # ... other code if col_idx == NUM_COLS - 1: # last column cols = st.columns(1)