gwkokab.utils.train
===================

.. py:module:: gwkokab.utils.train


Functions
---------

.. autoapisummary::

   gwkokab.utils.train.bce_logits_loss_fn
   gwkokab.utils.train.load_model
   gwkokab.utils.train.make_model
   gwkokab.utils.train.mse_loss_fn
   gwkokab.utils.train.predict
   gwkokab.utils.train.read_data
   gwkokab.utils.train.save_model
   gwkokab.utils.train.train_regressor


Module Contents
---------------

.. py:function:: bce_logits_loss_fn(model: jaxtyping.PyTree, x: jaxtyping.Array, y: jaxtyping.Array, batch_size: Optional[int] = 256, eps: float = 1e-06) -> jaxtyping.Array

   Binary cross-entropy between the targets ``y`` and the model's raw outputs (logits) on ``x``,
   computed in a numerically stable form.
   Targets are expected to lie in [0, 1] and are clipped to ``[eps, 1 - eps]``.


.. py:function:: load_model(filename: str) -> Tuple[List[str], equinox.nn.MLP]

   Load a model and its stored names from an HDF5 file and return them as a ``(names, model)`` tuple.
   Loading is backward-compatible with files written in the older format.


.. py:function:: make_model(*, key: jaxtyping.PRNGKeyArray, input_layer: int, output_layer: int, width_size: int, depth: int) -> equinox.nn.MLP

   Build an :class:`equinox.nn.MLP` with ReLU activations.


.. py:function:: mse_loss_fn(model: jaxtyping.PyTree, x: jaxtyping.Array, y: jaxtyping.Array, batch_size: Optional[int] = 256) -> jaxtyping.Array

   Mean squared error between the model's predictions on ``x`` and the targets ``y``.


.. py:function:: predict(model: jaxtyping.PyTree, x: jaxtyping.Array, batch_size: Optional[int] = 256) -> jaxtyping.Array

   Evaluate the model on the inputs ``x`` and return its outputs.


.. py:function:: read_data(data_path: str, keys: collections.abc.Sequence[str]) -> pandas.DataFrame

   Read an HDF5 dataset into a :class:`pandas.DataFrame` with one column per entry in ``keys``.


.. py:function:: save_model(*, filepath: str, datafilepath: str, model: equinox.nn.MLP, names: Optional[collections.abc.Sequence[str]] = None, is_log: bool = False) -> None

   Save the model weights and metadata to an HDF5 file, using a backward-compatible format.


.. py:function:: train_regressor(*, input_keys: list[str], output_keys: list[str], width_size: int, depth: int, batch_size: int, data_path: str, checkpoint_path: Optional[str] = None, epochs: int = 50, validation_split: float = 0.2, learning_rate: float = 0.001, train_in_log: bool = False, loss_type: str = 'mse', grad_clip_norm: float = 1.0, weight_decay: float = 0.0001, use_cosine_decay: bool = True, min_lr: float = 1e-06, warmup_epochs: int = 3, seed: Optional[int] = 42) -> None

   Train an MLP regressor. Gradient-norm clipping (``grad_clip_norm``), weight decay
   (``weight_decay``), and a learning-rate warmup (``warmup_epochs``) with optional cosine
   decay down to ``min_lr`` (``use_cosine_decay``) keep optimization stable and the loss
   curves smooth.

   .. rubric:: Notes

   - For detection probabilities in [0, 1], prefer ``loss_type="bce_logits"`` and do **not**
     set ``train_in_log=True``: BCE expects probability targets, not log-transformed values.
   - ``seed`` fixes the validation split, which makes the validation loss less jittery.
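The numerically stable form of the loss computed by ``bce_logits_loss_fn`` is usually obtained by
rewriting binary cross-entropy so that ``exp`` is only ever applied to non-positive numbers.
The plain-Python sketch below shows that standard identity together with the documented clipping
of targets to ``[eps, 1 - eps]``. It assumes scalar logits and a plain mean over samples; the name
``stable_bce_with_logits`` is hypothetical and this is not gwkokab's actual implementation.

.. code-block:: python

   import math
   from typing import Sequence


   def stable_bce_with_logits(
       logits: Sequence[float], targets: Sequence[float], eps: float = 1e-6
   ) -> float:
       """Hypothetical sketch: mean binary cross-entropy from raw logits.

       Uses BCE(z, y) = max(z, 0) - z * y + log(1 + exp(-|z|)), which equals
       -[y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z))] but never
       evaluates exp() of a large positive number.
       """
       if len(logits) != len(targets) or not logits:
           raise ValueError("logits and targets must be non-empty and of equal length")
       total = 0.0
       for z, y in zip(logits, targets):
           # Clip the target into [eps, 1 - eps], as the docstring describes.
           y = min(max(y, eps), 1.0 - eps)
           total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
       return total / len(logits)

With a logit of ``1000.0`` the naive formula would overflow in ``exp``, while this form stays
finite. In JAX code, ``optax.sigmoid_binary_cross_entropy`` computes the same per-element quantity.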
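``mse_loss_fn``, ``bce_logits_loss_fn`` and ``predict`` all accept a ``batch_size`` argument whose
exact role the docstrings do not state. A common reason for such an argument is to evaluate large
inputs in chunks to bound memory. The sketch below *assumes* that meaning, and that ``None`` means a
single pass over all inputs. It shows that accumulating squared errors chunk by chunk and dividing
by the total count gives the same mean squared error as one pass. ``batched_mse`` and ``model_fn``
are hypothetical names.

.. code-block:: python

   from typing import Callable, List, Optional, Sequence


   def batched_mse(
       model_fn: Callable[[Sequence[float]], List[float]],
       x: Sequence[float],
       y: Sequence[float],
       batch_size: Optional[int] = 256,
   ) -> float:
       """Hypothetical sketch: MSE accumulated over chunks of ``batch_size`` inputs."""
       n = len(x)
       if n == 0 or n != len(y):
           raise ValueError("x and y must be non-empty and of equal length")
       # Assumption: batch_size=None means "evaluate everything at once".
       step = n if batch_size is None else batch_size
       squared_error_sum = 0.0
       for start in range(0, n, step):
           preds = model_fn(x[start : start + step])  # evaluate one chunk
           targets = y[start : start + step]
           squared_error_sum += sum((p - t) ** 2 for p, t in zip(preds, targets))
       # Divide by the total count (not per chunk) so a short final chunk is weighted correctly.
       return squared_error_sum / n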
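The note on ``seed`` in ``train_regressor`` implies a seeded split: the same seed always selects
the same validation samples. A minimal sketch of such a split, assuming a shuffle-then-slice scheme
driven by ``validation_split`` (``split_indices`` is hypothetical, and gwkokab may split differently):

.. code-block:: python

   import random
   from typing import List, Optional, Tuple


   def split_indices(
       n: int, validation_split: float = 0.2, seed: Optional[int] = 42
   ) -> Tuple[List[int], List[int]]:
       """Hypothetical sketch: shuffle row indices, then slice off a validation set."""
       indices = list(range(n))
       # A dedicated Random instance keeps the split independent of global RNG state;
       # seed=None falls back to a fresh, non-reproducible shuffle.
       random.Random(seed).shuffle(indices)
       n_val = int(n * validation_split)
       return indices[n_val:], indices[:n_val]

Calling ``split_indices(100)`` twice returns identical train and validation index lists, so the
validation loss is always measured on the same 20 rows.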
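The ``learning_rate``, ``warmup_epochs``, ``use_cosine_decay`` and ``min_lr`` parameters of
``train_regressor`` describe a warmup-then-cosine learning-rate schedule. The sketch below computes
one common variant per epoch: a linear ramp from ``min_lr`` up to ``learning_rate`` over
``warmup_epochs``, then cosine decay back to ``min_lr`` by ``epochs``. The ramp shape and whether the
real schedule is stepped per epoch or per optimizer step are assumptions; ``lr_at_epoch`` is
hypothetical.

.. code-block:: python

   import math


   def lr_at_epoch(
       epoch: int,
       *,
       epochs: int = 50,
       learning_rate: float = 1e-3,
       min_lr: float = 1e-6,
       warmup_epochs: int = 3,
       use_cosine_decay: bool = True,
   ) -> float:
       """Hypothetical sketch: learning rate for a 0-based epoch index."""
       if epoch < warmup_epochs:
           # Linear warmup: reaches learning_rate at the last warmup epoch.
           frac = (epoch + 1) / warmup_epochs
           return min_lr + frac * (learning_rate - min_lr)
       if not use_cosine_decay:
           return learning_rate
       # Cosine decay from learning_rate (progress 0) down to min_lr (progress 1).
       decay_epochs = max(epochs - warmup_epochs, 1)
       progress = min((epoch - warmup_epochs) / decay_epochs, 1.0)
       return min_lr + 0.5 * (learning_rate - min_lr) * (1.0 + math.cos(math.pi * progress))

In JAX code, ``optax.warmup_cosine_decay_schedule`` builds a similar per-step schedule from
``init_value``, ``peak_value``, ``warmup_steps``, ``decay_steps`` and ``end_value``.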
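``grad_clip_norm`` in ``train_regressor`` reads as a global-norm gradient clip, the scheme
``optax.clip_by_global_norm`` implements. If the L2 norm taken over every gradient entry exceeds the
threshold, all gradients are scaled by the same factor. A plain-Python sketch, assuming gradients
arrive as a list of flat leaves (``clip_by_global_norm_sketch`` is hypothetical):

.. code-block:: python

   import math
   from typing import List


   def clip_by_global_norm_sketch(grads: List[List[float]], max_norm: float) -> List[List[float]]:
       """Hypothetical sketch: rescale all leaves so their joint L2 norm is <= max_norm."""
       global_norm = math.sqrt(sum(g * g for leaf in grads for g in leaf))
       if global_norm <= max_norm:
           # Already within bounds: return an unchanged copy.
           return [list(leaf) for leaf in grads]
       # One shared scale factor preserves the direction of the full gradient vector.
       scale = max_norm / global_norm
       return [[g * scale for g in leaf] for leaf in grads]

Scaling every leaf by the same factor, rather than clipping each leaf separately, bounds the step
size without changing the update direction.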