Multi-task image classification (no bbox): gender (binary) + age (4 classes) | Data: PA-100K + MSP60K mix (160K crops) | 13 ablations + SWA-4 | 訓練 (trained): 2026-05-07 | 5090-2 dual-GPU agent | 5 h wall-clock
import torch, torch.nn as nn, timm
from PIL import Image
import torchvision.transforms as T
class MultiHead(nn.Module):
    """Shared timm backbone feeding a single-logit gender head and an age-class head.

    Args:
        backbone_name: timm model name passed to ``timm.create_model``.
        drop_rate: dropout rate forwarded to the timm backbone.
        num_age: number of age classes (output width of ``age_head``).
        probe_size: (H, W) of the dummy input used once to discover the
            backbone's pooled feature width. Defaults to (64, 64), the size
            previously hard-coded; backbones that require a fixed input
            resolution need their native size here.
    """

    def __init__(self, backbone_name, drop_rate=0.3, num_age=4, probe_size=(64, 64)):
        super().__init__()
        # pretrained=False: no weights are downloaded here; callers are expected
        # to load them afterwards (this file does so via load_state_dict).
        # num_classes=0 with global_pool="avg" makes timm return pooled features
        # instead of classification logits.
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=False,
            num_classes=0,
            global_pool="avg",
            drop_rate=drop_rate,
        )
        # Discover the pooled feature width with one dummy forward pass. The
        # probe runs in eval mode for two reasons. In train mode, BatchNorm
        # would update its running statistics from the all-zero input. Also,
        # a batch of 1 raises in train-mode BatchNorm when the final feature
        # map is 1x1. The caller's train/eval mode is restored afterwards.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            feat_dim = self.backbone(torch.zeros(1, 3, *probe_size)).shape[-1]
        self.backbone.train(was_training)

        self.gender_head = nn.Linear(feat_dim, 1)  # one logit; sigmoid applied by caller
        self.age_head = nn.Linear(feat_dim, num_age)  # num_age logits; argmax by caller

    def forward(self, x):
        """Return ``(gender_logit, age_logits)`` for a batch of images.

        Assuming the backbone yields features of shape (B, feat_dim), the
        gender logit has shape (B,) (the size-1 output axis is squeezed) and
        the age logits have shape (B, num_age).
        """
        features = self.backbone(x)
        return self.gender_head(features).squeeze(-1), self.age_head(features)
# Inference example: load the SWA-4 checkpoint and classify one person crop.
#
# map_location="cpu" lets a checkpoint whose tensors were saved on CUDA (the
# model was trained on GPUs) load on a CPU-only host.
# NOTE(security): weights_only=False unpickles arbitrary Python objects, so
# only load checkpoints from a trusted source.
ckpt = torch.load(
    "age_gender_v20260507_swa4.pt",
    map_location="cpu",
    weights_only=False,
)
model = MultiHead(ckpt["args"]["backbone"]).eval()
# NOTE(review): drop_rate and num_age use the constructor defaults. If the
# checkpoint was trained with a different num_age, the strict load below
# raises on the age_head size mismatch.
model.load_state_dict(ckpt["model_state"])

# Standard ImageNet mean/std normalization.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# Resize takes (height, width): 384x192 is a tall, pedestrian-shaped crop.
tf = T.Compose([T.Resize((384, 192)), T.ToTensor(), T.Normalize(mean, std)])

# The context manager closes the image file once the pixels are converted.
with Image.open("person_crop.jpg") as img:
    x = tf(img.convert("RGB")).unsqueeze(0)  # add batch dim -> (1, 3, 384, 192)

with torch.no_grad():
    g_logit, a_logit = model(x)
# A single gender logit: sigmoid probability > 0.5 is read as "female".
gender = "female" if torch.sigmoid(g_logit).item() > 0.5 else "male"
age = ["child", "young", "adult", "elder"][a_logit.argmax(dim=-1).item()]
# Note: "young" has no supervision in either PA-100K or MSP60K, so inference
# does not output "young" in practice.
已上線：https://ppe-demo.intemotech.com/ ，在 dropdown 選擇「👥 年齡+性別 | v20260507 SWA-4 ⭐ + BoT-SORT」。
另含 BoT-SORT 追蹤 + 跨 frame majority vote：每個 person bbox 會顯示類似 `T#7 ♂ 0.92 成人 0.97 (95f)` 的標籤，其中 95 是該 track 累積的 frame 數。Track 越長，參與 majority vote 的 frame 越多，預測就越穩定。載入新影片時會自動 reset tracker state。
Generated 2026-05-07 | rai-vision-training | 8 h autonomous research on 5090-2 | kaggle-reports.pages.dev