gwkokab.utils.train

Functions

bce_logits_loss_fn(→ jaxtyping.Array)

Binary cross-entropy with logits (numerically stable).

load_model(→ Tuple[List[str], equinox.nn.MLP])

Load model and names from HDF5 (backward-compatible).

make_model(→ equinox.nn.MLP)

Build an MLP with ReLU activations.

mse_loss_fn(→ jaxtyping.Array)

Mean squared error loss.

predict(→ jaxtyping.Array)

Predict outputs for inputs x.

read_data(→ pandas.DataFrame)

Read an HDF5 dataset into a DataFrame with one column per entry in keys.

save_model(→ None)

Persist model weights and metadata to HDF5 (backward-compatible format).

train_regressor(→ None)

Train an MLP regressor with stable optimization and smooth loss curves.

Module Contents

gwkokab.utils.train.bce_logits_loss_fn(model: jaxtyping.PyTree, x: jaxtyping.Array, y: jaxtyping.Array, batch_size: int | None = 256, eps: float = 1e-06) → jaxtyping.Array

Binary cross-entropy with logits (numerically stable).

Expects targets in [0, 1] and clips them to [eps, 1 - eps].
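The numerical stability mentioned above usually comes from never evaluating the sigmoid and its logarithm as separate steps. The sketch below is a NumPy illustration, not gwkokab's implementation. bce_with_logits is a hypothetical name, the model outputs are assumed to be raw logits, and, as the docstring's wording suggests, the eps clipping is assumed to apply to the targets. It relies on the identity max(z, 0) - z*y + log(1 + exp(-|z|)), which equals the usual cross-entropy but never exponentiates a large positive number.

```python
import numpy as np


def bce_with_logits(logits, targets, eps: float = 1e-6) -> float:
    """Mean binary cross-entropy computed directly from logits (hypothetical helper).

    Uses -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))]
         = max(z, 0) - z*y + log(1 + exp(-|z|)),
    so exp() is only ever applied to non-positive arguments.
    """
    z = np.asarray(logits, dtype=np.float64)
    # Assumption: clip the targets (not the predictions) to [eps, 1 - eps].
    y = np.clip(np.asarray(targets, dtype=np.float64), eps, 1.0 - eps)
    per_example = np.maximum(z, 0.0) - z * y + np.log1p(np.exp(-np.abs(z)))
    return float(per_example.mean())
```

With logits of ±1000, the naive formula computes sigmoid(z) as exactly 0 or 1 and then takes log(0), producing an infinite loss. The rewritten form stays finite.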

gwkokab.utils.train.load_model(filename: str) → Tuple[List[str], equinox.nn.MLP]

Load a model and its names from an HDF5 file (backward-compatible format). Returns a (names, model) tuple.

gwkokab.utils.train.make_model(*, key: jaxtyping.PRNGKeyArray, input_layer: int, output_layer: int, width_size: int, depth: int) → equinox.nn.MLP

Build an MLP with ReLU activations.
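make_model's keyword arguments line up with the constructor of equinox.nn.MLP (in_size, out_size, width_size, depth). In that constructor, depth counts hidden layers, so the network has depth + 1 linear layers. The NumPy sketch below spells out the resulting architecture, assuming make_model follows that convention. init_mlp and mlp_forward are hypothetical helpers written for illustration, and the uniform initialization is only a common scheme, not necessarily the one equinox uses.

```python
import numpy as np


def init_mlp(rng: np.random.Generator, in_size: int, out_size: int,
             width_size: int, depth: int) -> list:
    """Create (weight, bias) pairs: `depth` hidden layers plus one output layer."""
    sizes = [in_size] + [width_size] * depth + [out_size]
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        # Uniform init scaled by 1/sqrt(fan_in), a common default for linear layers.
        bound = 1.0 / np.sqrt(fan_in)
        w = rng.uniform(-bound, bound, size=(fan_out, fan_in))
        b = rng.uniform(-bound, bound, size=(fan_out,))
        params.append((w, b))
    return params


def mlp_forward(params: list, x: np.ndarray) -> np.ndarray:
    """Apply one unbatched input; ReLU after each hidden layer, linear output layer."""
    h = x
    for w, b in params[:-1]:
        h = np.maximum(w @ h + b, 0.0)  # ReLU activation
    w_out, b_out = params[-1]
    return w_out @ h + b_out
```

For example, depth=2 with width_size=8 maps a 3-vector through two 8-unit ReLU layers to the output, which is three linear layers in total.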

gwkokab.utils.train.mse_loss_fn(model: jaxtyping.PyTree, x: jaxtyping.Array, y: jaxtyping.Array, batch_size: int | None = 256) → jaxtyping.Array

Mean squared error loss.
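The batch_size argument suggests that the loss is evaluated over mini-batches, which bounds peak memory on large datasets. An exact mean must weight each batch by its size, because the last batch is usually smaller than the others. The sketch below assumes two things: batch_size=None means a single pass over all rows, and pred_fn maps a batch of inputs to a batch of outputs. batched_mse is a hypothetical name, and NumPy stands in for JAX.

```python
from typing import Callable, Optional

import numpy as np


def batched_mse(
    pred_fn: Callable[[np.ndarray], np.ndarray],
    x: np.ndarray,
    y: np.ndarray,
    batch_size: Optional[int] = 256,
) -> float:
    """Mean squared error over all elements, accumulated one mini-batch at a time."""
    n = x.shape[0]
    if batch_size is None:
        batch_size = max(n, 1)  # assumption: None means one pass over everything
    total_sq_err = 0.0
    total_count = 0
    for start in range(0, n, batch_size):
        xb = x[start:start + batch_size]
        yb = y[start:start + batch_size]
        err = pred_fn(xb) - yb
        # Sum (not mean) per batch, so a short final batch is weighted correctly.
        total_sq_err += float(np.sum(err ** 2))
        total_count += err.size
    return total_sq_err / total_count
```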

gwkokab.utils.train.predict(model: jaxtyping.PyTree, x: jaxtyping.Array, batch_size: int | None = 256) → jaxtyping.Array

Predict outputs for inputs x.
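A batched predict can be sketched as three steps: slice x along its first axis, apply the model to each slice, and concatenate the results. This sketch assumes that batch_size=None means no batching and that fn accepts a whole batch. An equinox.nn.MLP acts on a single unbatched input, so with one, fn would typically be jax.vmap(model). batched_predict is a hypothetical helper, written with NumPy so that it runs without JAX.

```python
from typing import Callable, Optional

import numpy as np


def batched_predict(
    fn: Callable[[np.ndarray], np.ndarray],
    x: np.ndarray,
    batch_size: Optional[int] = 256,
) -> np.ndarray:
    """Apply fn to x in slices of batch_size rows and stitch the outputs together."""
    n = x.shape[0]
    if batch_size is None or batch_size >= n:
        return fn(x)  # small input or batching disabled: one call
    chunks = [fn(x[start:start + batch_size]) for start in range(0, n, batch_size)]
    return np.concatenate(chunks, axis=0)
```

Batching should change only peak memory, not the values returned.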

gwkokab.utils.train.read_data(data_path: str, keys: collections.abc.Sequence[str]) → pandas.DataFrame

Read an HDF5 dataset into a DataFrame with one column per entry in keys.

gwkokab.utils.train.save_model(*, filepath: str, datafilepath: str, model: equinox.nn.MLP, names: collections.abc.Sequence[str] | None = None, is_log: bool = False) → None

Persist model weights and metadata to HDF5 (backward-compatible format).

gwkokab.utils.train.train_regressor(*, input_keys: list[str], output_keys: list[str], width_size: int, depth: int, batch_size: int, data_path: str, checkpoint_path: str | None = None, epochs: int = 50, validation_split: float = 0.2, learning_rate: float = 0.001, train_in_log: bool = False, loss_type: str = 'mse', grad_clip_norm: float = 1.0, weight_decay: float = 0.0001, use_cosine_decay: bool = True, min_lr: float = 1e-06, warmup_epochs: int = 3, seed: int | None = 42) → None

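Train an MLP regressor with stable optimization and smooth loss curves. Stability comes from gradient-norm clipping (grad_clip_norm), weight decay (weight_decay), and a learning-rate warmup over the first warmup_epochs epochs. When use_cosine_decay is set, the warmup is followed by cosine decay toward min_lr.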
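The parameters learning_rate, warmup_epochs, use_cosine_decay, and min_lr suggest a learning-rate schedule that ramps up over the first few epochs and then decays along a cosine curve. The function below is a minimal sketch of one such schedule. It assumes three details that the docstring does not specify: the warmup is linear, the rate is updated per epoch rather than per step, and the decay reaches min_lr on the final epoch. warmup_cosine_lr is a hypothetical name. Optax provides a ready-made per-step variant, optax.warmup_cosine_decay_schedule.

```python
import math


def warmup_cosine_lr(epoch: int, total_epochs: int, learning_rate: float = 1e-3,
                     min_lr: float = 1e-6, warmup_epochs: int = 3) -> float:
    """Learning rate for a 0-based epoch: linear warmup, then cosine decay to min_lr."""
    if warmup_epochs > 0 and epoch < warmup_epochs:
        # Ramp linearly from learning_rate / warmup_epochs up to learning_rate.
        return learning_rate * (epoch + 1) / warmup_epochs
    # Number of epochs over which the cosine falls from learning_rate to min_lr.
    decay_epochs = max(total_epochs - warmup_epochs - 1, 1)
    progress = min((epoch - warmup_epochs) / decay_epochs, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 down to 0
    return min_lr + (learning_rate - min_lr) * cosine
```

A small rate at the start avoids large updates while the weights are still random. The slow cosine tail is one common way to get the smooth late-training loss curves the docstring mentions.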

Notes

  • For detection probabilities in [0, 1], prefer loss_type="bce_logits" and do NOT set train_in_log=True: BCE expects probability targets, not log-values.

  • seed fixes the train/validation split, so the validation set is the same on every run and the validation loss is less jittery.
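The effect of seed can be illustrated with a reproducible split. With a fixed seed, the same rows land in the validation set on every run, so changes in validation loss reflect the model rather than a reshuffled validation set. This sketch assumes the split is a random permutation of row indices with a validation_split fraction held out. split_indices is a hypothetical helper.

```python
from typing import Optional, Tuple

import numpy as np


def split_indices(n: int, validation_split: float = 0.2,
                  seed: Optional[int] = 42) -> Tuple[np.ndarray, np.ndarray]:
    """Return (train_idx, val_idx) as a seeded random partition of range(n)."""
    rng = np.random.default_rng(seed)  # seed=None draws fresh entropy on each call
    perm = rng.permutation(n)
    n_val = int(round(n * validation_split))
    return perm[n_val:], perm[:n_val]
```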