From 1f75f24b7ad4591b895020c9bd5ed786f29ffe39 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Mar 2024 14:28:28 +0100 Subject: [PATCH 01/24] Add missing wandb finish for aborted jobs --- napari_cellseg3d/code_models/worker_training.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index d2789885..9144ebef 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -624,6 +624,9 @@ def train( del criterionW torch.cuda.empty_cache() + if WANDB_INSTALLED: + wandb.finish() + self.ncuts_losses.append( epoch_ncuts_loss / len(self.dataloader) ) @@ -736,7 +739,8 @@ def train( if epoch % 5 == 0: torch.save( model.state_dict(), - self.config.results_path_folder + "/wnet_.pth", + self.config.results_path_folder + + "/wnet_checkpoint.pth", ) self.log("Training finished") @@ -1523,6 +1527,9 @@ def get_patch_loader_func(num_samples): if device.type == "cuda": torch.cuda.empty_cache() + if WANDB_INSTALLED: + wandb.finish() + yield TrainingReport( show_plot=False, weights=model.state_dict(), From c73685a938dec2bc1e240badccc140cb69bfe5dd Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Mar 2024 14:44:11 +0100 Subject: [PATCH 02/24] Update utils.py --- napari_cellseg3d/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 45f50582..ab68731a 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -16,8 +16,8 @@ ############### # Global logging level setting # SET TO INFO FOR RELEASE -# LOGGER.setLevel(logging.DEBUG) -LOGGER.setLevel(logging.INFO) +LOGGER.setLevel(logging.DEBUG) +# LOGGER.setLevel(logging.INFO) ############### """ utils.py From ee8f54d89b2a5d7e890234d714444a94c583a018 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Mar 2024 14:45:52 +0100 Subject: [PATCH 03/24] Update 
plugin_model_training.py --- napari_cellseg3d/code_plugins/plugin_model_training.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 7d43265e..a34bf213 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1404,8 +1404,12 @@ def _make_csv(self): raise ValueError(err) else: ncuts_loss = self.loss_1_values["SoftNCuts"] + + logger.debug(f"Loss 1 values : {ncuts_loss}") + logger.debug(f"Loss 2 values : {self.loss_2_values}") try: dice_metric = self.loss_1_values["Dice metric"] + logger.debug(f"Dice metric : {dice_metric}") self.df = pd.DataFrame( { "Epoch": size_column, From 5329da27440fc8ab800b4e96f6eb77273e1d666b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Mar 2024 15:27:42 +0100 Subject: [PATCH 04/24] Update plugin_model_training.py --- napari_cellseg3d/code_plugins/plugin_model_training.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index a34bf213..24e31dcd 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1405,8 +1405,9 @@ def _make_csv(self): else: ncuts_loss = self.loss_1_values["SoftNCuts"] - logger.debug(f"Loss 1 values : {ncuts_loss}") - logger.debug(f"Loss 2 values : {self.loss_2_values}") + logger.debug(f"Epochs : {len(size_column)}") + logger.debug(f"Loss 1 values : {len(ncuts_loss)}") + logger.debug(f"Loss 2 values : {len(self.loss_2_values)}") try: dice_metric = self.loss_1_values["Dice metric"] logger.debug(f"Dice metric : {dice_metric}") From cc3d0869929905f33b3d5ca2c293abba13ec5f54 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Mar 2024 15:35:19 +0100 Subject: [PATCH 05/24] Fix num_epochs for aborted training in csv --- 
napari_cellseg3d/code_plugins/plugin_model_training.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 24e31dcd..d6b91a23 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1367,6 +1367,13 @@ def on_yield(self, report: TrainingReport): def _make_csv(self): size_column = range(1, self.worker_config.max_epochs + 1) + # this assumption does not hold when training is stopped + if len(size_column) != len(self.loss_2_values): + logger.info( + "Training was stopped, setting epochs for csv to", + len(self.loss_2_values), + ) + size_column = range(1, len(self.loss_2_values) + 1) if len(self.loss_1_values) == 0 or self.loss_1_values is None: logger.warning("No loss values to add to csv !") From 8f853b93e1f44ff7d26a1010b591c96f9452fcd4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Mar 2024 15:39:48 +0100 Subject: [PATCH 06/24] Fix Wnet weight loading --- napari_cellseg3d/code_models/worker_training.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index 9144ebef..7d52baa1 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -472,13 +472,17 @@ def train( if WANDB_INSTALLED: wandb.watch(model, log_freq=100) - if self.config.weights_info.use_custom: + if ( + self.config.weights_info.use_pretrained + or self.config.weights_info.use_custom + ): if self.config.weights_info.use_pretrained: weights_file = "wnet.pth" self.downloader.download_weights("WNet", weights_file) weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file) self.config.weights_info.path = weights - else: + + if self.config.weights_info.use_custom: weights = str(Path(self.config.weights_info.path)) try: From 
af4e01669d1a917ae78c28ab362bde988729de73 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Mar 2024 15:44:47 +0100 Subject: [PATCH 07/24] Fix reference to WNet weights for transfer learn --- napari_cellseg3d/code_models/worker_training.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index 7d52baa1..859d9f5c 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -477,7 +477,11 @@ def train( or self.config.weights_info.use_custom ): if self.config.weights_info.use_pretrained: - weights_file = "wnet.pth" + from napari_cellseg3d.code_models.models.model_WNet import ( + WNet_, + ) + + weights_file = WNet_.weights_file self.downloader.download_weights("WNet", weights_file) weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file) self.config.weights_info.path = weights From 0d2e74afb89533cd8b38fc2dd8374583af4dbc65 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Mar 2024 15:47:50 +0100 Subject: [PATCH 08/24] Update plugin_model_training.py --- napari_cellseg3d/code_plugins/plugin_model_training.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index d6b91a23..9944683f 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1370,8 +1370,7 @@ def _make_csv(self): # this assumption does not hold when training is stopped if len(size_column) != len(self.loss_2_values): logger.info( - "Training was stopped, setting epochs for csv to", - len(self.loss_2_values), + f"Training was stopped, setting epochs for csv to {len(self.loss_2_values)}" ) size_column = range(1, len(self.loss_2_values) + 1) From be62e21501125f20e242aba15fd7cbfd3763d127 Mon Sep 17 00:00:00 2001 From: C-Achard Date: 
Mon, 4 Mar 2024 15:54:06 +0100 Subject: [PATCH 09/24] Improved safety for closing previous wandb runs --- napari_cellseg3d/code_models/worker_training.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index 859d9f5c..e81f2ad5 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -423,6 +423,11 @@ def train( if WANDB_INSTALLED: config_dict = self.config.__dict__ logger.debug(f"wandb config : {config_dict}") + if wandb.run is not None: + logger.warning( + "A previous wandb run is still active. It will be stoppe before starting a new one." + ) + wandb.finish() wandb.init( config=config_dict, project="CellSeg3D - WNet", @@ -1132,6 +1137,11 @@ def train( config_dict = self.config.__dict__ logger.debug(f"wandb config : {config_dict}") try: + if wandb.run is not None: + logger.warning( + "A previous wandb run is still active. It will be stoppe before starting a new one." 
+ ) + wandb.finish() wandb.init( config=config_dict, project="CellSeg3D", From 97a4b35106bf7404053db6ef2bce42e6115f5468 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Mar 2024 16:04:16 +0100 Subject: [PATCH 10/24] Fix rogue Path instead of str --- napari_cellseg3d/code_models/worker_training.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index e81f2ad5..3e45eeda 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -488,7 +488,7 @@ def train( weights_file = WNet_.weights_file self.downloader.download_weights("WNet", weights_file) - weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file) + weights = str(PRETRAINED_WEIGHTS_DIR / Path(weights_file)) self.config.weights_info.path = weights if self.config.weights_info.use_custom: @@ -1432,13 +1432,13 @@ def get_patch_loader_func(num_samples): # time = utils.get_date_time() logger.debug("Weights") - if weights_config.use_custom: + if weights_config.use_custom or weights_config.use_pretrained: if weights_config.use_pretrained: weights_file = model_class.weights_file self.downloader.download_weights(model_name, weights_file) - weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file) + weights = str(PRETRAINED_WEIGHTS_DIR / Path(weights_file)) weights_config.path = weights - else: + elif weights_config.use_custom: weights = str(Path(weights_config.path)) try: From 80746ed646a140ad2b00d8bb0504e724c0c46d26 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Mar 2024 16:16:14 +0100 Subject: [PATCH 11/24] Update worker_training.py --- napari_cellseg3d/code_models/worker_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index 3e45eeda..847f5f12 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ 
b/napari_cellseg3d/code_models/worker_training.py @@ -425,7 +425,7 @@ def train( logger.debug(f"wandb config : {config_dict}") if wandb.run is not None: logger.warning( - "A previous wandb run is still active. It will be stoppe before starting a new one." + "A previous wandb run is still active. It will be stopped before starting a new one." ) wandb.finish() wandb.init( @@ -1139,7 +1139,7 @@ def train( try: if wandb.run is not None: logger.warning( - "A previous wandb run is still active. It will be stoppe before starting a new one." + "A previous wandb run is still active. It will be stopped before starting a new one." ) wandb.finish() wandb.init( From 7ff2b63ea2d3619f3452b46a32aa475a8039baee Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Mar 2024 16:17:45 +0100 Subject: [PATCH 12/24] Change bounds for WNet train parameters --- napari_cellseg3d/code_plugins/plugin_model_training.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 9944683f..3d2d7e7d 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1054,7 +1054,7 @@ def _set_worker_config( logger.debug("Loading config...") model_config = config.ModelInfo(name=self.model_choice.currentText()) - self.weights_config.path = self.weights_config.path + # self.weights_config.path = self.weights_config.path self.weights_config.use_custom = self.custom_weights_choice.isChecked() self.weights_config.use_pretrained = ( @@ -1641,7 +1641,7 @@ def __init__(self, parent): text_label="Number of classes", ) self.intensity_sigma_choice = ui.DoubleIncrementCounter( - lower=1.0, + lower=0.01, upper=100.0, default=self.default_config.intensity_sigma, parent=parent, @@ -1649,7 +1649,7 @@ def __init__(self, parent): ) self.intensity_sigma_choice.setMaximumWidth(20) self.spatial_sigma_choice = 
ui.DoubleIncrementCounter( - lower=1.0, + lower=0.01, upper=100.0, default=self.default_config.spatial_sigma, parent=parent, From 0a585c802e418b3dbb7b14ec15409d04f5de0572 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Mar 2024 16:19:32 +0100 Subject: [PATCH 13/24] Shorten file names in train log --- napari_cellseg3d/code_models/worker_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index 847f5f12..15c48e54 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -388,7 +388,7 @@ def log_parameters(self): self.log("-- Data --") self.log("Training data :\n") [ - self.log(f"{v}") + self.log(f"{Path(v).stem}") for d in self.config.train_data_dict for k, v in d.items() ] From 86a8384714247004aaec3e7963cd27dd170a1cda Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 6 Mar 2024 10:17:45 +0100 Subject: [PATCH 14/24] More range foor rec loss weight --- napari_cellseg3d/code_plugins/plugin_model_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 3d2d7e7d..f39cd2a9 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1685,8 +1685,8 @@ def __init__(self, parent): ) self.reconstruction_weight_choice.setMaximumWidth(20) self.reconstruction_weight_divide_factor_choice = ( - ui.IntIncrementCounter( - lower=1, + ui.DoubleIncrementCounter( + lower=0.01, upper=10000, default=100, parent=parent, From 530b7a373eae8137e2c0fd8fc315736b3b514331 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 6 Mar 2024 10:18:41 +0100 Subject: [PATCH 15/24] Update Dice coeff calculation --- napari_cellseg3d/code_models/worker_training.py | 3 +-- 
napari_cellseg3d/code_plugins/plugin_model_training.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index 15c48e54..e37109dd 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -873,8 +873,7 @@ def eval(self, model, epoch) -> TrainingReport: self.dice_metric( y_pred=val_outputs[ :, - max_dice_channel : (max_dice_channel + 1), - :, + max_dice_channel:, # : (max_dice_channel + 1), :, :, ], diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index f39cd2a9..ef176d7a 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1687,8 +1687,8 @@ def __init__(self, parent): self.reconstruction_weight_divide_factor_choice = ( ui.DoubleIncrementCounter( lower=0.01, - upper=10000, - default=100, + upper=10000.0, + default=100.0, parent=parent, text_label="Reconstruction weight divide factor", ) From 3a418b2d6a0d8573f027541f9571a611683b003a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 11 Mar 2024 14:22:29 +0100 Subject: [PATCH 16/24] Fix matching len check for csv --- napari_cellseg3d/code_plugins/plugin_model_training.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index ef176d7a..7c342d9d 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1368,11 +1368,11 @@ def on_yield(self, report: TrainingReport): def _make_csv(self): size_column = range(1, self.worker_config.max_epochs + 1) # this assumption does not hold when training is stopped - if len(size_column) != len(self.loss_2_values): + if len(size_column) != 
len(self.loss_1_values): logger.info( - f"Training was stopped, setting epochs for csv to {len(self.loss_2_values)}" + f"Training was stopped, setting epochs for csv to {len(self.loss_1_values)}" ) - size_column = range(1, len(self.loss_2_values) + 1) + size_column = range(1, len(self.loss_1_values) + 1) if len(self.loss_1_values) == 0 or self.loss_1_values is None: logger.warning("No loss values to add to csv !") From 8b576573590c63c42711687c51a63a2cefc21230 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 21 Mar 2024 15:36:22 +0100 Subject: [PATCH 17/24] Update training_wnet.rst --- docs/source/guides/training_wnet.rst | 56 ++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/docs/source/guides/training_wnet.rst b/docs/source/guides/training_wnet.rst index 35aedd2d..3eacab2b 100644 --- a/docs/source/guides/training_wnet.rst +++ b/docs/source/guides/training_wnet.rst @@ -10,15 +10,35 @@ For training your model, you can choose among: * The provided Jupyter notebook (locally) * Our Colab notebook (inspired by ZeroCostDL4Mic) -The WNet does not require a large amount of data to train, but during inference images should be similar to those -the model was trained on; you can retrain from our pretrained model to your image dataset to quickly reach good performance. +Selecting training data +----------------------- + +The WNet **does not require a large amount of data to train**, but **choosing the right data** to train this unsupervised model **is crucial**. +You may find below some guidelines, based on our own data and testing. +The WNet is designed to segment objects based on their brightness, and is particularly well-suited for images with a clear contrast between objects and background. + +The WNet is not suitable for images with artifacts, therefore care should be taken that the images are clean and that the objects are at least somewhat distinguishable from the background. 
+ +For optimal performance, the following should be avoided for training: +- Images with very large, bright regions +- Almost-empty and empty images +- Images with large empty regions or "holes" + +However, the model may be accomodate: +- Uneven brightness distribution +- Varied object shapes and radius +- Noisy images +- Uneven illumination across the image + +For optimal results, during inference, images should be similar to those the model was trained on; however this is not a strict requirement. +You can retrain from our pretrained model to your image dataset to quickly reach good performance, simply check "Use pre-trained weights" in the training module, and lower the learning rate. .. note:: - - The WNet relies on brightness to distinguish objects from the background. For better results, use image regions with minimal artifacts. If you notice many artifacts, consider training on one of the supervised models. - - The model has two losses, the **`SoftNCut loss`**, which clusters pixels according to brightness, and a reconstruction loss, either **`Mean Square Error (MSE)`** or **`Binary Cross Entropy (BCE)`**. Unlike the method described in the original paper, these losses are added in a weighted sum and the backward pass is performed for the whole model at once. The SoftNcuts and BCE are bounded between 0 and 1; the MSE may take large positive values. It is recommended to watch for the weighted sum of losses to be **close to one on the first epoch**, for training stability. - - For good performance, you should wait for the SoftNCut to reach a plateau; the reconstruction loss must also decrease but is generally less critical. + - The WNet relies on brightness to distinguish objects from the background. For better results, use image regions with minimal artifacts. If you notice many artifacts, consider trying one of our supervised models (for lightsheet microscopy). 
+ - The model has two losses, the **`SoftNCut loss`**, which clusters pixels according to brightness, and a reconstruction loss, either **`Mean Square Error (MSE)`** or **`Binary Cross Entropy (BCE)`**. + - For good performance, wait for the SoftNCut to reach a plateau; the reconstruction loss should also be decreasing overall, but this is generally less critical for segmentation performance. -Parameters +Parameterss ---------- .. figure:: ../images/training_tab_4.png @@ -29,7 +49,7 @@ Parameters _`When using the WNet training module`, the **Advanced** tab contains a set of additional options: -- **Number of classes** : Dictates the segmentation classes (default is 2). Increasing the number of classes will result in a more progressive segmentation according to brightness; can be useful if you have "halos" around your objects or artifacts with a significantly different brightness. +- **Number of classes** : Dictates the segmentation classes (default is 2). Increasing the number of classes will result in a more progressive segmentation according to brightness; can be useful if you have "halos" around your objects, or to approximate boundary labels. - **Reconstruction loss** : Choose between MSE or BCE (default is MSE). MSE is more precise but also sensitive to outliers; BCE is more robust against outliers at the cost of precision. - NCuts parameters: @@ -43,22 +63,26 @@ _`When using the WNet training module`, the **Advanced** tab contains a set of a - Weights for the sum of losses : - **NCuts weight** : Sets the weight of the NCuts loss (default is 0.5). - - **Reconstruction weight** : Sets the weight for the reconstruction loss (default is 0.5*1e-2). + - **Reconstruction weight** : Sets the weight for the reconstruction loss (default is 5*1e-3). -.. 
note:: - The weight of the reconstruction loss should be adjusted to ensure the weighted sum is around one during the first epoch; - ideally the reconstruction loss should be of the same order of magnitude as the NCuts loss after being multiplied by its weight. +.. important:: + The weight of the reconstruction loss should be adjusted to ensure that both losses are balanced; + this balance can be assessed using the live view of training outputs : + if the NCuts loss is "taking over", causing the segmentation to only label very large, brighter versus dimmer regions, the reconstruction loss should be increased. + This will help the model to focus on the details of the objects, rather than just the overall brightness of the volume. Common issues troubleshooting ------------------------------ -If you do not find a satisfactory answer here, please do not hesitate to `open an issue`_ on GitHub. -- **The NCuts loss explodes after a few epochs** : Lower the learning rate, first by a factor of two, then ten. +.. important:: + If you do not find a satisfactory answer here, please do not hesitate to `open an issue`_ on GitHub. + + +- **The NCuts loss "explodes" after a few epochs** : Lower the learning rate, for example start with a factor of two, then ten. -- **The NCuts loss does not converge and is unstable** : - The normalization step might not be adapted to your images. Disable normalization and change intensity_sigma according to the distribution of values in your image. For reference, by default images are remapped to values between 0 and 100, and intensity_sigma=1. +- **Reconstruction (decoder) performance is poor** : First, try increasing the weight of the reconstruction loss. If this is ineffective, switch to BCE loss and set the scaling factor of the reconstruction loss to 0.5, OR adjust the weight of the MSE loss. 
-- **Reconstruction (decoder) performance is poor** : switch to BCE and set the scaling factor of the reconstruction loss to 0.5, OR adjust the weight of the MSE loss to make it closer to 1 in the weighted sum. +- **Segmentation only separates the brighter versus dimmer regions** : Increase the weight of the reconstruction loss. .. _WNet, A Deep Model for Fully Unsupervised Image Segmentation: https://arxiv.org/abs/1711.08506 From 76218bd3ed56faa0f50fe91d6aae9835bef92f74 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 21 Mar 2024 15:48:42 +0100 Subject: [PATCH 18/24] Update WNet docs + typos --- docs/source/guides/cropping_module_guide.rst | 2 +- docs/source/guides/training_wnet.rst | 42 ++++++++++++-------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/docs/source/guides/cropping_module_guide.rst b/docs/source/guides/cropping_module_guide.rst index 8853e601..f7db8495 100644 --- a/docs/source/guides/cropping_module_guide.rst +++ b/docs/source/guides/cropping_module_guide.rst @@ -59,7 +59,7 @@ Once you have launched the review process, you will gain control over three slid you to **adjust the position** of the cropped volumes and labels in the x,y and z positions. .. note:: - * If your **cropped volume isnt visible**, consider changing the **colormap** of the image and the cropped + * If your **cropped volume isn't visible**, consider changing the **colormap** of the image and the cropped volume to improve their visibility. * You may want to adjust the **opacity** and **contrast thresholds** depending on your image. 
* If the image appears empty: diff --git a/docs/source/guides/training_wnet.rst b/docs/source/guides/training_wnet.rst index 3eacab2b..7f6401e1 100644 --- a/docs/source/guides/training_wnet.rst +++ b/docs/source/guides/training_wnet.rst @@ -4,42 +4,50 @@ Advanced : WNet training ======================== This plugin provides a reimplemented, custom version of the WNet model from `WNet, A Deep Model for Fully Unsupervised Image Segmentation`_. + For training your model, you can choose among: * Directly within the plugin * The provided Jupyter notebook (locally) -* Our Colab notebook (inspired by ZeroCostDL4Mic) +* Our Colab notebook (inspired by https://github.com/HenriquesLab/ZeroCostDL4Mic) Selecting training data ------------------------ +------------------------- The WNet **does not require a large amount of data to train**, but **choosing the right data** to train this unsupervised model **is crucial**. + You may find below some guidelines, based on our own data and testing. + The WNet is designed to segment objects based on their brightness, and is particularly well-suited for images with a clear contrast between objects and background. The WNet is not suitable for images with artifacts, therefore care should be taken that the images are clean and that the objects are at least somewhat distinguishable from the background. -For optimal performance, the following should be avoided for training: -- Images with very large, bright regions -- Almost-empty and empty images -- Images with large empty regions or "holes" -However, the model may be accomodate: -- Uneven brightness distribution -- Varied object shapes and radius -- Noisy images -- Uneven illumination across the image +.. 
important:: + For optimal performance, the following should be avoided for training: + + - Images with very large, bright regions + - Almost-empty and empty images + - Images with large empty regions or "holes" + + However, the model may accommodate: + + - Uneven brightness distribution + - Varied object shapes and radius + - Noisy images + - Uneven illumination across the image For optimal results, during inference, images should be similar to those the model was trained on; however this is not a strict requirement. -You can retrain from our pretrained model to your image dataset to quickly reach good performance, simply check "Use pre-trained weights" in the training module, and lower the learning rate. + +You may also retrain from our pretrained model to your image dataset to help quickly reach good performance; simply check "Use pre-trained weights" in the training module, and lower the learning rate. .. note:: - The WNet relies on brightness to distinguish objects from the background. For better results, use image regions with minimal artifacts. If you notice many artifacts, consider trying one of our supervised models (for lightsheet microscopy). - The model has two losses, the **`SoftNCut loss`**, which clusters pixels according to brightness, and a reconstruction loss, either **`Mean Square Error (MSE)`** or **`Binary Cross Entropy (BCE)`**. - For good performance, wait for the SoftNCut to reach a plateau; the reconstruction loss should also be decreasing overall, but this is generally less critical for segmentation performance. -Parameterss ----------- +Parameters +------------- .. figure:: ../images/training_tab_4.png :scale: 100 % @@ -66,9 +74,11 @@ _`When using the WNet training module`, the **Advanced** tab contains a set of a - **Reconstruction weight** : Sets the weight for the reconstruction loss (default is 5*1e-3). ..
important:: - The weight of the reconstruction loss should be adjusted to ensure that both losses are balanced; - this balance can be assessed using the live view of training outputs : + The weight of the reconstruction loss should be adjusted to ensure that both losses are balanced. + + This balance can be assessed using the live view of training outputs : if the NCuts loss is "taking over", causing the segmentation to only label very large, brighter versus dimmer regions, the reconstruction loss should be increased. + This will help the model to focus on the details of the objects, rather than just the overall brightness of the volume. Common issues troubleshooting From de510296ab8f34600bddb45a517beeb3a465b907 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 21 Mar 2024 15:53:05 +0100 Subject: [PATCH 19/24] Update utils.py --- napari_cellseg3d/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index ab68731a..45f50582 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -16,8 +16,8 @@ ############### # Global logging level setting # SET TO INFO FOR RELEASE -LOGGER.setLevel(logging.DEBUG) -# LOGGER.setLevel(logging.INFO) +# LOGGER.setLevel(logging.DEBUG) +LOGGER.setLevel(logging.INFO) ############### """ utils.py From 1d13fe5ddf51348fb7d19d29c3913e2df1d336eb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 25 Mar 2024 15:09:29 +0100 Subject: [PATCH 20/24] Change rec loss weight default --- napari_cellseg3d/code_plugins/plugin_model_training.py | 2 +- napari_cellseg3d/config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 7c342d9d..41c8b80b 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1688,7 +1688,7 @@ def __init__(self, parent): 
ui.DoubleIncrementCounter( lower=0.01, upper=10000.0, - default=100.0, + default=1.0, parent=parent, text_label="Reconstruction weight divide factor", ) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 1b2af90f..002c6024 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -399,7 +399,7 @@ class WNetTrainingWorkerConfig(TrainingWorkerConfig): reconstruction_loss: str = "MSE" # or "BCE" # summed losses weights n_cuts_weight: float = 0.5 - rec_loss_weight: float = 0.5 / 100 + rec_loss_weight: float = 0.5 / 1.0 # 0.5 / 100 # normalization params # normalizing_function: callable = remap_image # FIXME: call directly in worker, not a param # data params From 3507c32b8cfaa0a5e1620edf09e24379c77c776e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 25 Mar 2024 15:27:16 +0100 Subject: [PATCH 21/24] Change discrete output display for WNet training --- napari_cellseg3d/code_models/worker_training.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index e37109dd..9ad08d24 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -658,9 +658,7 @@ def train( "cmap": "turbo", }, "Encoder output (discrete)": { - "data": AsDiscrete(threshold=0.5)( - enc_out - ).numpy(), + "data": np.where(enc_out > 0.5, enc_out, 0), "cmap": "bop blue", }, "Decoder output": { From a933fd05db984416391ff936a931a5806a77ceae Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 25 Mar 2024 15:32:19 +0100 Subject: [PATCH 22/24] Fix incorrect checks for csv saving --- .../code_plugins/plugin_model_training.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 41c8b80b..b9272bb8 100644 --- 
a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1368,23 +1368,40 @@ def on_yield(self, report: TrainingReport): def _make_csv(self): size_column = range(1, self.worker_config.max_epochs + 1) # this assumption does not hold when training is stopped - if len(size_column) != len(self.loss_1_values): - logger.info( - f"Training was stopped, setting epochs for csv to {len(self.loss_1_values)}" - ) - size_column = range(1, len(self.loss_1_values) + 1) - - if len(self.loss_1_values) == 0 or self.loss_1_values is None: - logger.warning("No loss values to add to csv !") - return try: self.loss_1_values["Loss"] supervised = True + if len(size_column) != len(self.loss_1_values["Loss"]): + logger.info( + f"Training was stopped, setting epochs for csv to {len(self.loss_1_values['Loss'])}" + ) + size_column = range(1, len(self.loss_1_values["Loss"]) + 1) + + if ( + len(self.loss_1_values["Loss"]) == 0 + or self.loss_1_values["Loss"] is None + ): + logger.warning("No loss values to add to csv !") + return except KeyError: try: self.loss_1_values["SoftNCuts"] supervised = False + if len(size_column) != len(self.loss_1_values["SoftNCuts"]): + logger.info( + f"Training was stopped, setting epochs for csv to {len(self.loss_1_values['SoftNCuts'])}" + ) + size_column = range( + 1, len(self.loss_1_values["SoftNCuts"]) + 1 + ) + + if ( + len(self.loss_1_values["SoftNCuts"]) == 0 + or self.loss_1_values["SoftNCuts"] is None + ): + logger.warning("No loss values to add to csv !") + return except KeyError as e: raise KeyError( "Error when making csv. Check loss dict keys ?" 
From c22d2a6de7c91301231de41dbfdba86d1cea4060 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 25 Mar 2024 15:42:10 +0100 Subject: [PATCH 23/24] Small improvement to make_csv --- .../code_plugins/plugin_model_training.py | 70 ++++++++----------- 1 file changed, 30 insertions(+), 40 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index b9272bb8..51f6dd72 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1365,54 +1365,44 @@ def on_yield(self, report: TrainingReport): self.on_stop() self._stop_requested = False - def _make_csv(self): - size_column = range(1, self.worker_config.max_epochs + 1) - # this assumption does not hold when training is stopped + def _check_lens(self, size_column, loss_values): + if len(size_column) != len(loss_values): + logger.info( + f"Training was stopped, setting epochs for csv to {len(loss_values)}" + ) + return range(1, len(loss_values) + 1) + return size_column - try: - self.loss_1_values["Loss"] - supervised = True - if len(size_column) != len(self.loss_1_values["Loss"]): - logger.info( - f"Training was stopped, setting epochs for csv to {len(self.loss_1_values['Loss'])}" - ) - size_column = range(1, len(self.loss_1_values["Loss"]) + 1) + def _handle_loss_values(self, size_column, key): + loss_values = self.loss_1_values.get(key) + if loss_values is None: + return None - if ( - len(self.loss_1_values["Loss"]) == 0 - or self.loss_1_values["Loss"] is None - ): - logger.warning("No loss values to add to csv !") - return - except KeyError: - try: - self.loss_1_values["SoftNCuts"] - supervised = False - if len(size_column) != len(self.loss_1_values["SoftNCuts"]): - logger.info( - f"Training was stopped, setting epochs for csv to {len(self.loss_1_values['SoftNCuts'])}" - ) - size_column = range( - 1, len(self.loss_1_values["SoftNCuts"]) + 1 - ) + if len(loss_values) == 
0: + logger.warning("No loss values to add to csv !") + return None - if ( - len(self.loss_1_values["SoftNCuts"]) == 0 - or self.loss_1_values["SoftNCuts"] is None - ): - logger.warning("No loss values to add to csv !") - return - except KeyError as e: - raise KeyError( - "Error when making csv. Check loss dict keys ?" - ) from e + return self._check_lens(size_column, loss_values) + + def _make_csv(self): # TODO(cyril) design could use a good rework + size_column = range(1, self.worker_config.max_epochs + 1) + + supervised = False + size_column = self._handle_loss_values(size_column, "Loss") + if size_column is None: + size_column = self._handle_loss_values(size_column, "SoftNCuts") + if size_column is None: + raise KeyError("Error when making csv. Check loss dict keys ?") + supervised = True if supervised: - val = utils.fill_list_in_between( + val = utils.fill_list_in_between( # fills the validation list based on validation interval self.loss_2_values, self.worker_config.validation_interval - 1, "", - )[: len(size_column)] + )[ + : len(size_column) + ] self.df = pd.DataFrame( { From 74efbcf9e8ed3159f5ae7b84b6200161acdd0245 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 15 Apr 2024 18:08:37 +0200 Subject: [PATCH 24/24] Fix logic issue in make_csv --- napari_cellseg3d/code_plugins/plugin_model_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 51f6dd72..a668e155 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1387,13 +1387,13 @@ def _handle_loss_values(self, size_column, key): def _make_csv(self): # TODO(cyril) design could use a good rework size_column = range(1, self.worker_config.max_epochs + 1) - supervised = False + supervised = True size_column = self._handle_loss_values(size_column, "Loss") if size_column is None: size_column
= self._handle_loss_values(size_column, "SoftNCuts") if size_column is None: raise KeyError("Error when making csv. Check loss dict keys ?") - supervised = True + supervised = False if supervised: val = utils.fill_list_in_between( # fills the validation list based on validation interval