SemiSupervisedTrackerMixin

class lightning_pose.models.base.SemiSupervisedTrackerMixin

Bases: object

Mixin class that provides the training step function for semi-supervised models.

Methods Summary

evaluate_unlabeled(batch_dict[, stage, ...])

Compute and log the losses on a batch of unlabeled data (frames only).

get_loss_inputs_unlabeled(batch_dict)

Return predicted heatmaps and their softmaxes (estimated keypoints).

training_step(batch_dict, batch_idx)

Training step computes and logs both supervised and unsupervised losses.

Methods Documentation

evaluate_unlabeled(batch_dict: UnlabeledBatchDict | MultiviewUnlabeledBatchDict, stage: Literal['train', 'val', 'test'] | None = None, anneal_weight: float | Tensor = 1.0) → Tensor[Tensor]

Compute and log the losses on a batch of unlabeled data (frames only).
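
As a rough illustration of what an anneal_weight argument typically does, the sketch below scales a weighted sum of unsupervised loss terms and collects per-term log entries when a stage is given. This is not the library's implementation: the function name combine_unlabeled_losses, the loss names, the per-loss weights, and the log-entry format are all hypothetical, and plain floats stand in for tensors.

```python
from typing import Dict, List, Optional, Tuple


def combine_unlabeled_losses(
    losses: Dict[str, float],
    weights: Dict[str, float],
    anneal_weight: float = 1.0,
    stage: Optional[str] = None,
) -> Tuple[float, List[dict]]:
    """Hypothetical helper: scale and sum unsupervised loss terms."""
    total = 0.0
    log_list: List[dict] = []
    for name, value in losses.items():
        # each term is scaled by its own weight and by the shared anneal weight
        total += anneal_weight * weights[name] * value
        # only build log entries when a stage ("train", "val", "test") is given
        if stage is not None:
            log_list.append({"name": f"{stage}_{name}_loss", "value": value})
    return total, log_list


# illustrative loss names and values, not taken from the library
total, logs = combine_unlabeled_losses(
    losses={"temporal": 2.0, "pca": 4.0},
    weights={"temporal": 0.5, "pca": 0.25},
    anneal_weight=0.5,
    stage="train",
)
```

With anneal_weight below 1.0 the unsupervised terms contribute less to the total, which is the usual reason for ramping such a weight up over the early epochs.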

get_loss_inputs_unlabeled(batch_dict: UnlabeledBatchDict) → dict

Return predicted heatmaps and their softmaxes (estimated keypoints).
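
One common way to turn a heatmap into a keypoint estimate is a soft argmax: apply a spatial softmax to the heatmap, then take the expected (x, y) coordinate under the resulting distribution. The sketch below shows that idea on a single 2D heatmap in pure Python. It assumes this is what "softmaxes (estimated keypoints)" refers to; the function name soft_argmax_2d and the temperature parameter are hypothetical and not part of the documented API.

```python
import math
from typing import List, Tuple


def soft_argmax_2d(
    heatmap: List[List[float]], temperature: float = 1.0
) -> Tuple[float, float]:
    """Hypothetical helper: expected (x, y) under a spatial softmax."""
    scaled = [[v * temperature for v in row] for row in heatmap]
    # subtract the maximum before exponentiating for numerical stability
    max_v = max(v for row in scaled for v in row)
    exps = [[math.exp(v - max_v) for v in row] for row in scaled]
    total = sum(sum(row) for row in exps)
    # x is the column index, y is the row index
    x = sum(p * col for row in exps for col, p in enumerate(row)) / total
    y = sum(p * r for r, row in enumerate(exps) for p in row) / total
    return x, y


# a sharply peaked heatmap: the estimate lands close to the peak location
peaked = [
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 50.0],
    [0.0, 0.0, 0.0],
]
x, y = soft_argmax_2d(peaked)
```

Unlike a hard argmax, this estimate is differentiable and can fall between pixels, which is why it is a frequent choice when keypoints feed into a loss.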

training_step(batch_dict: SemiSupervisedBatchDict | SemiSupervisedHeatmapBatchDict, batch_idx: int) → Dict[str, Tensor[Tensor]]

Training step computes and logs both supervised and unsupervised losses.
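
The mixin pattern described here can be sketched with toy classes: a base tracker supplies a supervised evaluation method, the mixin adds an unsupervised one plus a training_step that sums the two. Everything below is a simplified assumption for illustration, not the library's code: the Toy* class names, the "labeled" and "unlabeled" batch keys, the total_unsupervised_importance attribute, and the toy loss formulas are hypothetical, and floats stand in for tensors.

```python
from typing import Dict, Optional


class ToyTracker:
    """Hypothetical stand-in for a fully supervised tracker base class."""

    def evaluate_labeled(self, batch_dict: dict, stage: Optional[str] = None) -> float:
        # toy supervised loss: mean squared error between predictions and targets
        preds, targets = batch_dict["preds"], batch_dict["targets"]
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


class ToySemiSupervisedMixin:
    """Hypothetical mixin adding an unsupervised loss and a combined training step."""

    # assumed attribute; a callback could adjust it during training
    total_unsupervised_importance = 1.0

    def evaluate_unlabeled(
        self, batch_dict: dict, stage: Optional[str] = None, anneal_weight: float = 1.0
    ) -> float:
        # toy unsupervised loss: mean absolute change between consecutive frames
        frames = batch_dict["preds"]
        diffs = [abs(b - a) for a, b in zip(frames, frames[1:])]
        return anneal_weight * sum(diffs) / len(diffs)

    def training_step(self, batch_dict: dict, batch_idx: int) -> Dict[str, float]:
        loss_super = self.evaluate_labeled(batch_dict["labeled"], stage="train")
        loss_unsuper = self.evaluate_unlabeled(
            batch_dict["unlabeled"],
            stage="train",
            anneal_weight=self.total_unsupervised_importance,
        )
        # the training loop only needs the summed loss
        return {"loss": loss_super + loss_unsuper}


class ToySemiSupervisedTracker(ToySemiSupervisedMixin, ToyTracker):
    """The mixin comes first so its training_step takes precedence."""


batch = {
    "labeled": {"preds": [1.0, 2.0], "targets": [1.0, 4.0]},
    "unlabeled": {"preds": [0.0, 1.0, 3.0]},
}
out = ToySemiSupervisedTracker().training_step(batch, batch_idx=0)
```

Listing the mixin before the base class in the inheritance order lets its training_step override any supervised-only version while still reusing the base class's evaluate_labeled.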