SemiSupervisedTrackerMixin

class lightning_pose.models.base.SemiSupervisedTrackerMixin

Bases: object

Mixin class providing training step function for semi-supervised models.

Always mixed in with BaseSupervisedTracker, which provides the LightningModule methods. At runtime the mixin inherits only from object. It inherits from BaseSupervisedTracker only when TYPE_CHECKING is true, which gives pyright visibility into log(), device, evaluate_labeled(), loss_factory, etc.
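The conditional-inheritance pattern described above can be sketched in a self-contained way. In this minimal sketch, BaseTrackerStub, MixinSketch and ComposedTracker are hypothetical stand-ins, not lightning_pose classes. Listing the mixin first in the composed class is also an assumption about how such mixins are typically combined.

```python
from typing import TYPE_CHECKING


class BaseTrackerStub:
    """Hypothetical stand-in for BaseSupervisedTracker (not the real class)."""

    device = "cpu"

    def log(self, name: str, value: float) -> None:
        print(f"{name}={value}")


if TYPE_CHECKING:
    # Static checkers such as pyright take this branch, so the mixin is seen
    # as a subclass of the base and calls like self.log() type-check.
    _MixinBase = BaseTrackerStub
else:
    # At runtime the mixin derives only from object ("Bases: object").
    _MixinBase = object


class MixinSketch(_MixinBase):
    """Hypothetical mixin that relies on methods supplied by the base."""

    def report(self, value: float) -> None:
        self.log("unsupervised_loss", value)


class ComposedTracker(MixinSketch, BaseTrackerStub):
    """Mixin listed first so its methods take precedence in the MRO."""


tracker = ComposedTracker()
tracker.report(0.5)  # prints "unsupervised_loss=0.5"
```

Because the runtime base is object, the mixin adds no second LightningModule to the hierarchy. The real methods come from whichever tracker class it is combined with.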

Methods Summary

evaluate_unlabeled(batch_dict[, stage, ...])

Compute and log the losses on a batch of unlabeled data (frames only).

get_loss_inputs_unlabeled(batch_dict)

Return predicted heatmaps and their softmaxes (estimated keypoints).

training_step(batch_dict, batch_idx)

Training step computes and logs both supervised and unsupervised losses.

Methods Documentation

evaluate_unlabeled(batch_dict: UnlabeledBatchDict | MultiviewUnlabeledBatchDict, stage: Literal['train', 'val', 'test'] | None = None, anneal_weight: float | Tensor = 1.0) → Float[Tensor, '']

Compute and log the losses on a batch of unlabeled data (frames only).
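The sketch below shows one plausible way a method with this signature could combine several unsupervised losses. It is written in plain Python with floats instead of tensors. Every name in it (evaluate_unlabeled_sketch, the losses and weights dicts, the "{stage}_{name}_loss" log keys) is hypothetical. Scaling each weighted loss by anneal_weight is an assumption, reflecting the common practice of ramping unsupervised terms in during training; it is not a description of lightning_pose's actual code.

```python
from typing import Callable, Literal, Optional

# Hypothetical type: a loss maps the model's loss inputs to a scalar.
LossFn = Callable[[dict], float]


def evaluate_unlabeled_sketch(
    loss_inputs: dict,
    losses: dict[str, LossFn],
    weights: dict[str, float],
    log: Callable[[str, float], None],
    stage: Optional[Literal["train", "val", "test"]] = None,
    anneal_weight: float = 1.0,
) -> float:
    """Hypothetical sketch: weight, anneal, log, and sum unsupervised losses."""
    # Assumption: log names are prefixed with the stage when one is given.
    prefix = f"{stage}_" if stage is not None else ""
    total = 0.0
    for name, loss_fn in losses.items():
        value = loss_fn(loss_inputs)
        log(f"{prefix}{name}_loss", value)
        # Assumption: each loss is scaled by its own weight and by the
        # shared annealing factor before being added to the total.
        total += anneal_weight * weights[name] * value
    return total


logged: dict[str, float] = {}
total = evaluate_unlabeled_sketch(
    loss_inputs={"pred": 2.0},
    losses={"temporal": lambda d: d["pred"], "pca": lambda d: 0.5},
    weights={"temporal": 1.0, "pca": 2.0},
    log=lambda k, v: logged.__setitem__(k, v),
    stage="train",
    anneal_weight=0.5,
)
# total == 0.5 * (1.0 * 2.0 + 2.0 * 0.5) == 1.5
```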

get_loss_inputs_unlabeled(batch_dict: UnlabeledBatchDict | MultiviewUnlabeledBatchDict) → dict

Return predicted heatmaps and their softmaxes (estimated keypoints).
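A common way to turn a predicted heatmap into a keypoint estimate is to apply a spatial softmax and then take a soft argmax, which is the expected pixel coordinate under the resulting probability map. The numpy sketch below illustrates that general technique on a single 2D heatmap. It is an assumed illustration, not lightning_pose's implementation. The library works on batched tensors and may use details, such as a temperature, that are not shown here.

```python
import numpy as np


def spatial_softmax(heatmap: np.ndarray) -> np.ndarray:
    """Normalize a (H, W) heatmap into a probability map that sums to 1."""
    flat = heatmap.reshape(-1)
    flat = flat - flat.max()  # subtract the max for numerical stability
    probs = np.exp(flat)
    probs /= probs.sum()
    return probs.reshape(heatmap.shape)


def soft_argmax(heatmap: np.ndarray) -> tuple[float, float]:
    """Return the expected (x, y) location under the spatial softmax."""
    probs = spatial_softmax(heatmap)
    rows, cols = np.indices(heatmap.shape)
    y = float((probs * rows).sum())  # expected row index
    x = float((probs * cols).sum())  # expected column index
    return x, y


heatmap = np.zeros((5, 5))
heatmap[2, 3] = 100.0  # a sharp peak at row 2, column 3
x, y = soft_argmax(heatmap)  # x is approximately 3.0, y approximately 2.0
```

Unlike a hard argmax, the expectation is differentiable, so the estimated keypoints can feed losses that are backpropagated through the heatmaps.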

training_step(batch_dict: SemiSupervisedBatchDict | SemiSupervisedHeatmapBatchDict, batch_idx: int) → dict[str, Float[Tensor, '']]

Training step computes and logs both supervised and unsupervised losses.
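A minimal sketch of how a training step might combine the two kinds of loss follows, using plain floats in place of tensors. The "labeled" and "unlabeled" batch keys and the callables passed in are hypothetical. The one fact relied on is PyTorch Lightning's convention that training_step may return a dict, provided it contains a "loss" entry; that entry is the value used for backpropagation.

```python
from typing import Callable


def training_step_sketch(
    batch_dict: dict,
    evaluate_labeled: Callable[[dict], float],
    evaluate_unlabeled: Callable[[dict, float], float],
    anneal_weight: float,
    log: Callable[[str, float], None],
) -> dict[str, float]:
    """Hypothetical sketch: combine supervised and unsupervised losses."""
    # Assumption: the semi-supervised batch holds labeled and unlabeled
    # sub-batches under these (hypothetical) keys.
    supervised = evaluate_labeled(batch_dict["labeled"])
    unsupervised = evaluate_unlabeled(batch_dict["unlabeled"], anneal_weight)
    total = supervised + unsupervised
    log("total_loss", total)
    # Lightning backpropagates the value stored under "loss".
    return {"loss": total}


logged: dict[str, float] = {}
out = training_step_sketch(
    batch_dict={"labeled": {"loss": 0.75}, "unlabeled": {"loss": 0.5}},
    evaluate_labeled=lambda b: b["loss"],
    evaluate_unlabeled=lambda b, w: w * b["loss"],
    anneal_weight=0.5,
    log=lambda k, v: logged.__setitem__(k, v),
)
# out == {"loss": 1.0}, since 0.75 + 0.5 * 0.5 == 1.0
```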

__init__()
__new__(**kwargs)