gwkokab.utils.train
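Functions

| Function | Description |
| --- | --- |
| bce_logits_loss_fn(model, x, y[, batch_size, eps]) | Binary cross-entropy with logits (numerically stable). |
| load_model(filename) | Load model and names from HDF5 (backward-compatible). |
| make_model(*, key, input_layer, output_layer, width_size, depth) | Build an MLP with ReLU activations. |
| mse_loss_fn(model, x, y[, batch_size]) | Mean squared error loss. |
| predict(model, x[, batch_size]) | Predict outputs for inputs x. |
| read_data(data_path, keys) | Read dataset (HDF5) into a DataFrame with columns = keys. |
| save_model(*, filepath, datafilepath, model[, names, is_log]) | Persist model weights and metadata to HDF5 (backward-compatible format). |
| train_regressor(*, input_keys, output_keys, width_size, depth, batch_size, data_path, ...) | Train an MLP regressor with stable optimization and smooth loss curves. |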
Module Contents
- gwkokab.utils.train.bce_logits_loss_fn(model: jaxtyping.PyTree, x: jaxtyping.Array, y: jaxtyping.Array, batch_size: int | None = 256, eps: float = 1e-06) -> jaxtyping.Array
Binary cross-entropy with logits (numerically stable).
Expects targets in [0, 1]; targets are clipped to [eps, 1 - eps].
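The numerically stable form of logit-based BCE can be sketched in NumPy as below. This is a minimal sketch, not the library's code: it assumes a mean reduction over examples and applies the eps clipping to the targets, and the name bce_with_logits is hypothetical.

```python
import numpy as np


def bce_with_logits(logits, targets, eps=1e-6):
    """Mean binary cross-entropy computed directly from logits z and targets y."""
    z = np.asarray(logits, dtype=np.float64)
    # Keep targets strictly inside (0, 1), mirroring the documented eps clipping.
    y = np.clip(np.asarray(targets, dtype=np.float64), eps, 1.0 - eps)
    # Algebraically equal to -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))],
    # but exp() only ever sees -|z| <= 0, so it cannot overflow.
    per_example = np.maximum(z, 0.0) - z * y + np.log1p(np.exp(-np.abs(z)))
    return per_example.mean()


print(bce_with_logits([0.0], [0.5]))  # log(2) ≈ 0.6931
print(bce_with_logits([1000.0, -1000.0], [0.0, 1.0]))  # finite, no overflow
```

Computing sigmoid(z) first and then taking logs would produce log(0) for large |z|; working on logits avoids that.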
- gwkokab.utils.train.load_model(filename: str) -> Tuple[List[str], equinox.nn.MLP]
Load the names and the model from an HDF5 file (backward-compatible). Returns a (names, model) tuple.
- gwkokab.utils.train.make_model(*, key: jaxtyping.PRNGKeyArray, input_layer: int, output_layer: int, width_size: int, depth: int) -> equinox.nn.MLP
Build an MLP with ReLU activations.
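make_model returns an equinox.nn.MLP. Assuming width_size and depth are passed straight through to equinox.nn.MLP (where depth counts hidden layers, giving depth + 1 Linear layers) and the output layer keeps equinox's default identity activation, the resulting architecture looks like this NumPy sketch. The helper names are hypothetical and the initialization is illustrative, not equinox's.

```python
import numpy as np


def mlp_layer_shapes(input_layer, output_layer, width_size, depth):
    """(fan_in, fan_out) of each Linear layer: depth hidden layers -> depth + 1 Linear layers."""
    sizes = [input_layer] + [width_size] * depth + [output_layer]
    return list(zip(sizes[:-1], sizes[1:]))


def init_params(shapes, seed=0):
    """Random weights scaled by 1/sqrt(fan_in) and zero biases (illustrative only)."""
    rng = np.random.default_rng(seed)
    return [(rng.normal(size=(fan_out, fan_in)) / np.sqrt(fan_in), np.zeros(fan_out))
            for fan_in, fan_out in shapes]


def mlp_forward(params, x):
    """ReLU after every hidden layer; the output layer stays linear."""
    for weight, bias in params[:-1]:
        x = np.maximum(weight @ x + bias, 0.0)
    weight, bias = params[-1]
    return weight @ x + bias


shapes = mlp_layer_shapes(input_layer=3, output_layer=1, width_size=64, depth=2)
print(shapes)  # [(3, 64), (64, 64), (64, 1)]
print(mlp_forward(init_params(shapes), np.ones(3)).shape)  # (1,)
```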
- gwkokab.utils.train.mse_loss_fn(model: jaxtyping.PyTree, x: jaxtyping.Array, y: jaxtyping.Array, batch_size: int | None = 256) -> jaxtyping.Array
Mean squared error loss.
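The batch_size argument suggests the loss can be evaluated in chunks of batch_size rows, with None presumably meaning a single pass; the signature alone does not confirm this. Under that assumption, a chunked MSE should accumulate squared-error sums rather than average per-batch means, as in this hypothetical NumPy helper (it also assumes the mean is taken over all elements):

```python
import numpy as np


def batched_mse(predict_fn, x, y, batch_size=256):
    """Mean squared error over all elements, evaluated in chunks of batch_size rows."""
    x = np.asarray(x)
    y = np.asarray(y)
    n = x.shape[0]
    if n == 0:
        raise ValueError("empty input")
    step = n if batch_size is None else batch_size
    total, count = 0.0, 0
    for start in range(0, n, step):
        err = predict_fn(x[start:start + step]) - y[start:start + step]
        # Accumulate sums, not per-batch means, so a short final batch is not over-weighted.
        total += float(np.sum(err ** 2))
        count += err.size
    return total / count
```

With 1000 rows and batch_size=256 the last chunk has 232 rows; averaging the four batch means would weight those rows more heavily than the others.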
- gwkokab.utils.train.predict(model: jaxtyping.PyTree, x: jaxtyping.Array, batch_size: int | None = 256) -> jaxtyping.Array
Predict outputs for inputs x.
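Chunked prediction, which is what batch_size most likely controls here (an assumption), can be sketched as follows. batched_predict is a hypothetical helper; apply_fn must accept a batch of rows, so with an equinox model, which maps a single example, one would typically pass jax.vmap(model).

```python
import numpy as np


def batched_predict(apply_fn, x, batch_size=256):
    """Apply apply_fn to x in chunks of batch_size rows and stack the results."""
    x = np.asarray(x)
    if batch_size is None or x.shape[0] <= batch_size:
        return apply_fn(x)
    chunks = [apply_fn(x[i:i + batch_size]) for i in range(0, x.shape[0], batch_size)]
    return np.concatenate(chunks, axis=0)
```

Chunking bounds peak memory without changing the result, since each row is processed independently.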
- gwkokab.utils.train.read_data(data_path: str, keys: collections.abc.Sequence[str]) -> pandas.DataFrame
Read an HDF5 dataset into a pandas DataFrame with one column per entry in keys.
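Assuming the HDF5 file stores one 1-D dataset per key at its root (the layout is not documented here), the read reduces to the hypothetical helper below. Passing an open file, as in `with h5py.File(data_path, "r") as f: df = frame_from_datasets(f, keys)`, reads from disk; the test uses a plain dict of NumPy arrays so it runs without a file.

```python
import numpy as np
import pandas as pd


def frame_from_datasets(source, keys):
    """Build a DataFrame with one column per key, in the order given.

    source maps each key to a 1-D dataset: an open h5py.File fits this shape,
    as does a dict of NumPy arrays. ds[()] reads the whole dataset into memory.
    """
    return pd.DataFrame({key: np.asarray(source[key][()]) for key in keys})
```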
- gwkokab.utils.train.save_model(*, filepath: str, datafilepath: str, model: equinox.nn.MLP, names: collections.abc.Sequence[str] | None = None, is_log: bool = False) -> None
Persist model weights and metadata to HDF5 (backward-compatible format).
- gwkokab.utils.train.train_regressor(*, input_keys: list[str], output_keys: list[str], width_size: int, depth: int, batch_size: int, data_path: str, checkpoint_path: str | None = None, epochs: int = 50, validation_split: float = 0.2, learning_rate: float = 0.001, train_in_log: bool = False, loss_type: str = 'mse', grad_clip_norm: float = 1.0, weight_decay: float = 0.0001, use_cosine_decay: bool = True, min_lr: float = 1e-06, warmup_epochs: int = 3, seed: int | None = 42) -> None
Train an MLP regressor that maps input_keys to output_keys. Gradient-norm clipping (grad_clip_norm), weight decay (weight_decay), a learning-rate warmup (warmup_epochs) and optional cosine decay down to min_lr (use_cosine_decay) keep optimization stable and the loss curves smooth.
Notes
For detection probabilities in [0, 1], prefer loss_type="bce_logits" and do NOT set train_in_log=True (BCE expects probability targets, not log-values).
Setting seed fixes the train/validation split, so the validation loss is less jittery.
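train_regressor's signature names three stabilizers: a learning-rate schedule (learning_rate, warmup_epochs, use_cosine_decay, min_lr), gradient clipping (grad_clip_norm) and a seeded validation split (seed, validation_split). The sketch below shows one common way each can work; it is not the library's implementation. The per-epoch linear warmup, the global-norm clipping rule and the NumPy shuffle are assumptions, and a JAX codebase might instead use optax helpers such as optax.warmup_cosine_decay_schedule and optax.clip_by_global_norm together with jax.random. All function names here are hypothetical.

```python
import math

import numpy as np


def lr_at_epoch(epoch, *, epochs=50, learning_rate=1e-3, min_lr=1e-6,
                warmup_epochs=3, use_cosine_decay=True):
    """Linear warmup, then (optionally) cosine decay from learning_rate to min_lr."""
    if epoch < warmup_epochs:
        return learning_rate * (epoch + 1) / warmup_epochs
    if not use_cosine_decay:
        return learning_rate
    decay_epochs = max(1, epochs - warmup_epochs - 1)
    progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
    return min_lr + (learning_rate - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    global_norm = math.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    scale = max_norm / max(global_norm, max_norm)
    return [g * scale for g in grads]


def train_val_split(n_samples, validation_split=0.2, seed=42):
    """Shuffle indices with a seeded RNG; the same seed gives the same split every run."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    n_val = int(round(n_samples * validation_split))
    return perm[n_val:], perm[:n_val]
```

With the defaults, the learning rate ramps up to 1e-3 over the first three epochs and then decays along a half cosine to exactly min_lr at the last epoch; clipping leaves small gradients untouched and only shrinks the update when the joint norm exceeds grad_clip_norm.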