SemiSupervisedTrackerMixin
- class lightning_pose.models.base.SemiSupervisedTrackerMixin
Bases: object
Mixin class providing the training step function for semi-supervised models.
This mixin is always combined with BaseSupervisedTracker, which provides the LightningModule methods. When typing.TYPE_CHECKING is true, the mixin inherits from BaseSupervisedTracker, which gives pyright visibility into log(), device, evaluate_labeled(), loss_factory, and similar attributes. At runtime its only base is object.
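The conditional-inheritance trick can be illustrated with a minimal, self-contained sketch. `BaseTracker`, `SemiSupervisedMixin`, and `SemiSupervisedModel` are hypothetical stand-ins, not lightning_pose classes, and this is one common way to write the pattern rather than the library's exact code.

```python
from typing import TYPE_CHECKING


class BaseTracker:
    """Hypothetical stand-in for BaseSupervisedTracker."""

    def log(self, name: str, value: float) -> None:
        print(f"{name}={value}")

    def evaluate_labeled(self) -> float:
        return 1.0


if TYPE_CHECKING:
    # Static type checkers (e.g. pyright) treat the mixin as a BaseTracker
    # subclass, so self.log() and self.evaluate_labeled() type-check.
    _MixinBase = BaseTracker
else:
    # At runtime the mixin only inherits from object.
    _MixinBase = object


class SemiSupervisedMixin(_MixinBase):
    def training_step(self) -> float:
        loss = self.evaluate_labeled()
        self.log("total_loss", loss)
        return loss


class SemiSupervisedModel(SemiSupervisedMixin, BaseTracker):
    """Listing the mixin first puts its training_step ahead of BaseTracker in the MRO."""
```

Because the real base is only visible to the type checker, the concrete model class must still list the base tracker explicitly, as `SemiSupervisedModel` does here.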
Methods Summary
evaluate_unlabeled(batch_dict[, stage, ...]): Compute and log the losses on a batch of unlabeled data (frames only).
get_loss_inputs_unlabeled(batch_dict): Return predicted heatmaps and their softmaxes (estimated keypoints).
training_step(batch_dict, batch_idx): Training step computes and logs both supervised and unsupervised losses.
Methods Documentation
- evaluate_unlabeled(batch_dict: UnlabeledBatchDict | MultiviewUnlabeledBatchDict, stage: Literal['train', 'val', 'test'] | None = None, anneal_weight: float | Tensor = 1.0) → Float[Tensor, '']
Compute and log the losses on a batch of unlabeled data (frames only).
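As a rough illustration of what an unlabeled-data evaluation does, the sketch below applies weighted unsupervised losses to predicted keypoints, logs each raw loss under a stage-prefixed key, and scales the sum by `anneal_weight`. It assumes the anneal weight simply multiplies the total unsupervised loss and ramps up linearly during a warm-up period; `temporal_loss`, `linear_anneal`, `evaluate_unlabeled_sketch`, and the log-key format are all hypothetical.

```python
from typing import Callable, Dict, List, Optional

Keypoints = List[List[float]]  # one flat coordinate list per frame


def temporal_loss(keypoints: Keypoints) -> float:
    """Hypothetical loss: mean L1 jump in coordinates between consecutive frames."""
    jumps = [
        sum(abs(a - b) for a, b in zip(prev, curr))
        for prev, curr in zip(keypoints, keypoints[1:])
    ]
    return sum(jumps) / max(len(jumps), 1)


def linear_anneal(epoch: int, warmup_epochs: int) -> float:
    """Hypothetical schedule: ramp the weight from 0 to 1 over warmup_epochs."""
    return min(1.0, epoch / warmup_epochs)


def evaluate_unlabeled_sketch(
    keypoints: Keypoints,
    losses: Dict[str, Callable[[Keypoints], float]],
    weights: Dict[str, float],
    stage: Optional[str] = None,
    anneal_weight: float = 1.0,
    log: Optional[Dict[str, float]] = None,
) -> float:
    total = 0.0
    for name, loss_fn in losses.items():
        value = loss_fn(keypoints)
        total += weights[name] * value
        # Record each raw loss under a stage-prefixed key, e.g. "train_temporal_loss".
        if stage is not None and log is not None:
            log[f"{stage}_{name}_loss"] = value
    # Scale the combined unsupervised loss by the current anneal weight.
    return anneal_weight * total
```

A typical motivation for a small initial anneal weight is to let the supervised loss shape the network before the unsupervised terms have much influence.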
- get_loss_inputs_unlabeled(batch_dict: UnlabeledBatchDict | MultiviewUnlabeledBatchDict) → dict
Return predicted heatmaps and their softmaxes (estimated keypoints).
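One standard way to turn a heatmap into a keypoint estimate is a spatial softmax followed by a soft-argmax (the probability-weighted mean pixel location). The pure-Python sketch below shows that computation on a single heatmap; whether lightning_pose uses exactly this formulation, temperature, or coordinate convention is an assumption, and `spatial_softmax` and `soft_argmax` are hypothetical names.

```python
import math
from typing import List, Tuple

Heatmap = List[List[float]]


def spatial_softmax(heatmap: Heatmap, temperature: float = 1.0) -> Heatmap:
    """Softmax over all pixels of a 2D heatmap, so the output sums to 1."""
    height, width = len(heatmap), len(heatmap[0])
    scaled = [v / temperature for row in heatmap for v in row]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)
    return [[exps[r * width + c] / total for c in range(width)] for r in range(height)]


def soft_argmax(heatmap: Heatmap, temperature: float = 1.0) -> Tuple[float, float]:
    """Return the expected (x, y) pixel location under the softmaxed heatmap."""
    probs = spatial_softmax(heatmap, temperature)
    x = sum(p * c for row in probs for c, p in enumerate(row))
    y = sum(p * r for r, row in enumerate(probs) for p in row)
    return x, y
```

Lowering the temperature sharpens the distribution, pushing the soft-argmax toward the hard argmax while keeping it differentiable.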
- training_step(batch_dict: SemiSupervisedBatchDict | SemiSupervisedHeatmapBatchDict, batch_idx: int) → dict[str, Float[Tensor, '']]
Training step computes and logs both supervised and unsupervised losses.
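The overall shape of a semi-supervised training step can be sketched as a supervised loss on labeled data plus an annealed unsupervised loss on unlabeled data, summed into the value PyTorch Lightning optimizes. `ToySemiSupervisedModel` and the `"labeled"`/`"unlabeled"` batch keys are hypothetical; the real SemiSupervisedBatchDict layout may differ.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Mapping


@dataclass
class ToySemiSupervisedModel:
    # Hypothetical callables standing in for evaluate_labeled / evaluate_unlabeled.
    evaluate_labeled: Callable[[Any], float]
    evaluate_unlabeled: Callable[[Any, float], float]
    anneal_weight: float = 1.0
    logged: Dict[str, float] = field(default_factory=dict)

    def training_step(self, batch_dict: Mapping[str, Any], batch_idx: int) -> Dict[str, float]:
        # Supervised loss on the labeled part of the batch.
        supervised = self.evaluate_labeled(batch_dict["labeled"])
        # Unsupervised loss on the unlabeled frames, scaled by the anneal weight.
        unsupervised = self.evaluate_unlabeled(batch_dict["unlabeled"], self.anneal_weight)
        total = supervised + unsupervised
        self.logged["total_loss"] = total
        # Lightning backpropagates through the "loss" entry of the returned dict.
        return {"loss": total}
```

Keeping the two losses in separate methods lets the supervised path be reused unchanged by purely supervised models.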
- __init__()
- __new__(**kwargs)