SemiSupervisedTrackerMixin
- class lightning_pose.models.base.SemiSupervisedTrackerMixin
Bases: object
Mixin class providing the training step for semi-supervised models.
Methods Summary
evaluate_unlabeled(batch_dict[, stage, ...]): Compute and log the losses on a batch of unlabeled data (frames only).
get_loss_inputs_unlabeled(batch_dict): Return predicted heatmaps and their softmaxes (estimated keypoints).
training_step(batch_dict, batch_idx): Training step computes and logs both supervised and unsupervised losses.
Methods Documentation
- evaluate_unlabeled(batch_dict: UnlabeledBatchDict | MultiviewUnlabeledBatchDict, stage: Literal['train', 'val', 'test'] | None = None, anneal_weight: float | Tensor = 1.0) -> TensorType[(), torch.float32]
Compute and log the losses on a batch of unlabeled data (frames only).
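The `anneal_weight` argument suggests that unsupervised losses can be down-weighted, for example early in training. The pure-Python sketch below assumes a linear warm-up schedule and assumes each unsupervised loss term is multiplied by the current weight before summing. `linear_anneal_weight` and `weighted_unsupervised_total` are hypothetical helpers, not lightning_pose functions, and the loss names and values are made up.

```python
def linear_anneal_weight(epoch: int, warmup_epochs: int, max_weight: float = 1.0) -> float:
    """Hypothetical schedule: ramp linearly from 0 to max_weight over warmup_epochs."""
    if warmup_epochs <= 0:
        return max_weight
    return max_weight * min(epoch / warmup_epochs, 1.0)


def weighted_unsupervised_total(losses: dict[str, float], anneal_weight: float = 1.0) -> float:
    """Scale every unsupervised loss term by anneal_weight and sum them (assumed behavior)."""
    return sum(anneal_weight * value for value in losses.values())


# Made-up loss values for illustration only.
unsup_losses = {"temporal": 0.8, "pca_singleview": 0.4}
for epoch in (0, 5, 10, 20):
    w = linear_anneal_weight(epoch, warmup_epochs=10)
    print(epoch, w, round(weighted_unsupervised_total(unsup_losses, w), 3))
```

With a weight of 0 the unsupervised terms contribute nothing, so the model initially trains on labeled frames alone; once the warm-up ends, the full unsupervised loss is applied.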
- get_loss_inputs_unlabeled(batch_dict: UnlabeledBatchDict) -> dict
Return predicted heatmaps and their softmaxes (estimated keypoints).
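One common way to turn a predicted heatmap into an estimated keypoint via a softmax is soft-argmax: apply a spatial softmax to the heatmap, then take the probability-weighted mean of the pixel coordinates. The pure-Python sketch below assumes that technique on a single 2D heatmap stored as a list of rows. `spatial_softmax`, `soft_argmax_2d`, and the `temperature` parameter are hypothetical names for illustration, not lightning_pose functions, and the real method operates on batched tensors.

```python
import math


def spatial_softmax(heatmap: list[list[float]], temperature: float = 1.0) -> list[list[float]]:
    """Normalize a 2D heatmap so all entries are positive and sum to 1."""
    flat = [v / temperature for row in heatmap for v in row]
    peak = max(flat)  # subtract the max before exp() for numerical stability
    exps = [math.exp(v - peak) for v in flat]
    total = sum(exps)
    width = len(heatmap[0])
    return [[exps[r * width + c] / total for c in range(width)] for r in range(len(heatmap))]


def soft_argmax_2d(heatmap: list[list[float]], temperature: float = 1.0) -> tuple[float, float]:
    """Return the expected (x, y) location under the softmax-normalized heatmap."""
    probs = spatial_softmax(heatmap, temperature)
    x = sum(p * c for row in probs for c, p in enumerate(row))  # column index is x
    y = sum(p * r for r, row in enumerate(probs) for p in row)  # row index is y
    return x, y


# A 3x3 heatmap with a sharp peak at column 2, row 1.
hm = [[0.0, 0.0, 0.0],
      [0.0, 0.0, 20.0],
      [0.0, 0.0, 0.0]]
print(soft_argmax_2d(hm))
```

Unlike a hard argmax, this expectation is differentiable and can land between pixels, which is why heatmap-based trackers often use it to produce keypoints that unsupervised losses can act on.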
- training_step(batch_dict: SemiSupervisedBatchDict | SemiSupervisedHeatmapBatchDict, batch_idx: int) -> dict[str, TensorType[(), torch.float32]]
Training step computes and logs both supervised and unsupervised losses.
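To illustrate how a mixin can supply `training_step` to a model class, the pure-Python sketch below assumes the host (supervised) class provides `evaluate_labeled` and the raw unsupervised loss terms, while the mixin adds `evaluate_unlabeled` and a `training_step` that sums both losses. All class names, the `"labeled"`/`"unlabeled"` batch keys, and the loss values are hypothetical. The returned dict uses a `"loss"` key because PyTorch Lightning accepts a dict from `training_step` as long as it contains that key; the real method works on tensors and logs through Lightning rather than an attribute.

```python
class SketchSemiSupervisedMixin:
    """Hypothetical stand-in for a semi-supervised mixin.

    Assumes the host class defines evaluate_labeled(), unsupervised_terms()
    and an anneal_weight attribute.
    """

    def evaluate_unlabeled(self, batch: dict, anneal_weight: float = 1.0) -> float:
        # Scale the summed unsupervised terms by the current annealing weight.
        return anneal_weight * sum(self.unsupervised_terms(batch).values())

    def training_step(self, batch_dict: dict, batch_idx: int) -> dict:
        sup = self.evaluate_labeled(batch_dict["labeled"])
        unsup = self.evaluate_unlabeled(batch_dict["unlabeled"], self.anneal_weight)
        total = sup + unsup
        # Stand-in for Lightning's self.log(...) calls.
        self.logged = {"supervised_loss": sup, "unsupervised_loss": unsup, "total_loss": total}
        return {"loss": total}


class ToySupervisedTracker:
    """Hypothetical supervised base class providing the labeled-data loss."""

    def __init__(self, anneal_weight: float = 1.0):
        self.anneal_weight = anneal_weight
        self.logged: dict = {}

    def evaluate_labeled(self, batch: dict) -> float:
        return batch["heatmap_loss"]

    def unsupervised_terms(self, batch: dict) -> dict:
        return {"temporal": batch["temporal"], "pca": batch["pca"]}

    def training_step(self, batch_dict: dict, batch_idx: int) -> dict:
        # Supervised-only step; shadowed by the mixin's version in the subclass below.
        return {"loss": self.evaluate_labeled(batch_dict["labeled"])}


# The mixin is listed first so its training_step wins in the method resolution order.
class ToySemiSupervisedTracker(SketchSemiSupervisedMixin, ToySupervisedTracker):
    pass


model = ToySemiSupervisedTracker(anneal_weight=0.5)
batch = {"labeled": {"heatmap_loss": 1.0}, "unlabeled": {"temporal": 0.4, "pca": 0.2}}
out = model.training_step(batch, batch_idx=0)
print(out, model.logged)
```

Putting the shared semi-supervised logic in a mixin lets several supervised base classes gain the same training step without duplicating it; only the order of base classes decides which `training_step` is used.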
- __init__()