Multi-task: gender (binary) + age (4-class) image classification (no bbox) | PA-100K + MSP60K mix (160K crops)
⏳ 8 hr research in progress
Training date: 2026-05-07 | 5090-2 dual-GPU agent | 16 ablations completed; SWA + cross-dataset not done yet
import torch
import torch.nn as nn
import timm
from PIL import Image
import torchvision.transforms as T
# Model structure (extracted from the train script)
class MultiHead(nn.Module):
    def __init__(self, backbone_name, drop_rate=0.3, num_age=4):
        super().__init__()
        self.backbone = timm.create_model(
            backbone_name, pretrained=False, num_classes=0,
            global_pool="avg", drop_rate=drop_rate)
        # Probe feat_dim with a dummy forward pass: for mnv3/effb0 the
        # pooled output size does not match backbone.num_features
        with torch.no_grad():
            feat_dim = self.backbone(torch.zeros(1, 3, 64, 64)).shape[-1]
        self.feat_dim = feat_dim
        self.gender_head = nn.Linear(feat_dim, 1)      # single binary gender logit
        self.age_head = nn.Linear(feat_dim, num_age)   # one logit per age group
    def forward(self, x):
        f = self.backbone(x)
        return self.gender_head(f).squeeze(-1), self.age_head(f)
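# Load (map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine;
# weights_only=False unpickles arbitrary objects, so only load checkpoints you trust)
ckpt = torch.load("age_gender_v20260507E_convnext_tiny.pt", map_location="cpu", weights_only=False)
# ckpt["args"]["backbone"] == "convnext_tiny.fb_in22k_ft_in1k"
# ckpt["args"]["img_h"] == 384, ckpt["args"]["img_w"] == 192
model = MultiHead(ckpt["args"]["backbone"]).eval()
model.load_state_dict(ckpt["model_state"])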
# Inference (input: person crop)
mean = [0.485, 0.456, 0.406]; std = [0.229, 0.224, 0.225]
tf = T.Compose([T.Resize((384, 192)), T.ToTensor(), T.Normalize(mean, std)])
img = Image.open("person_crop.jpg").convert("RGB")
x = tf(img).unsqueeze(0)
with torch.no_grad():
    g_logit, a_logit = model(x)
gender_prob = torch.sigmoid(g_logit).item() # > 0.5 = female
gender = "female" if gender_prob > 0.5 else "male"
age_idx = a_logit.argmax(dim=-1).item()
age_group = ["child", "young", "adult", "elder"][age_idx]
print(f"gender: {gender} ({gender_prob:.2f})")
print(f"age: {age_group}")
# Note: the young class has no supervised data in either PA-100K or MSP60K,
# so in practice the model never predicts young; it only outputs child/adult/elder
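The note above says `young` is never a training target. If you would rather enforce that at decode time than rely on the trained logit staying low, one option is to mask that slot before the softmax so the reported probability is renormalised over the three supervised classes. This is a minimal sketch, assuming the `child, young, adult, elder` index order used in the snippet and a batch of one crop; `decode_age` and `YOUNG_IDX` are hypothetical names, not part of the train script.

```python
import torch

AGE_GROUPS = ["child", "young", "adult", "elder"]
YOUNG_IDX = 1  # assumption: "young" sits at index 1 and is never a training target


def decode_age(a_logit: torch.Tensor) -> tuple[str, float]:
    """Decode a (1, 4) age-logit tensor for a single crop, ignoring the 'young' slot."""
    masked = a_logit.clone()
    masked[..., YOUNG_IDX] = float("-inf")  # exp(-inf) = 0, so 'young' gets zero probability mass
    probs = torch.softmax(masked, dim=-1)   # renormalised over child/adult/elder
    idx = int(probs.argmax(dim=-1).item())
    return AGE_GROUPS[idx], float(probs[0, idx].item())


# Fake logits where the untrained 'young' slot happens to be the largest
label, p = decode_age(torch.tensor([[0.1, 5.0, 1.0, 0.2]]))
print(label, round(p, 3))
```

With the real model you would call `decode_age(a_logit)` in place of the plain `argmax` line.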
Ablation results:
| Variant | backbone | aug/loss | img (H×W) | ep | g_acc | a_acc | a_f1 | child | adult | elder | Notes |
|---|---|---|---|---|---|---|---|---|---|---|---|
| A baseline mnv3l | mobilenetv3_l | camaug | 384×192 | 12 | 0.838 | 0.886 | 0.621 | 0.54 | 0.95 | 0.38 | |
| B strongaug mnv3l | mobilenetv3_l | strong | 384×192 | 12 | 0.835 | 0.882 | 0.614 | 0.60 | 0.94 | 0.36 | |
| C balsamp mnv3l | mobilenetv3_l | camaug+wsamp | 384×192 | 12 | 0.830 | 0.913 | 0.617 | 0.54 | 0.98 | 0.17 | weighted sampler |
| D focal mnv3l | mobilenetv3_l | camaug+focal | 384×192 | 11 | 0.853 | 0.852 | 0.608 | 0.78 | 0.87 | 0.51 | focal γ=2 boosts elder |
| E convnext_tiny ⭐ | convnext_tiny | camaug | 384×192 | 12 | 0.857 | 0.924 | 0.683 | 0.73 | 0.96 | 0.34 | overall best (lucky seed) |
| E2 convnext_tiny seed2026 | convnext_tiny | camaug | 384×192 | 11 | 0.860 | 0.923 | 0.679 | 0.74 | 0.96 | 0.32 | seed control |
| E3 convnext_tiny seed7 | convnext_tiny | camaug | 384×192 | 12 | 0.857 | 0.923 | 0.689 | 0.73 | 0.96 | 0.38 | best a_f1 |
| E4 convnext_tiny seed2024 | convnext_tiny | camaug | 384×192 | 12 | 0.855 | 0.919 | 0.679 | 0.71 | 0.96 | 0.36 | |
| F efficientnet_b0 | efficientnet_b0 | camaug | 384×192 | 12 | 0.850 | 0.905 | 0.659 | 0.67 | 0.95 | 0.37 | |
| G convnext_small | convnext_small | camaug | 384×192 | 12 | 0.860 | 0.923 | 0.680 | 0.74 | 0.96 | 0.32 | no gain over tiny |
| H img224 mnv3l | mobilenetv3_l | camaug | 224×224 | 12 | 0.823 | 0.874 | 0.593 | 0.55 | 0.94 | 0.32 | lowering the resolution hurts |
| I ema mnv3l | mobilenetv3_l | camaug+ema | 384×192 | 12 | 0.835 | 0.885 | 0.617 | 0.57 | 0.94 | 0.36 | |
| J pa100k only | mobilenetv3_l | camaug | 384×192 | 7 | 0.841 | 0.924 | 0.608 | 0.65 | 0.98 | 0.11 | PA-100K only |
| K msp60k only | mobilenetv3_l | camaug | 384×192 | 12 | 0.798 | 0.888 | 0.657 | 0.82 | 0.91 | 0.36 | MSP60K only; elder high but g_acc low |
| M vanilla mnv3l | mobilenetv3_l | no-aug | 384×192 | 3 | 0.833 | 0.918 | 0.634 | 0.58 | 0.99 | 0.16 | no-aug control |
| N sqrtinv mnv3l | mobilenetv3_l | camaug+sqrtinv | 384×192 | 6 | 0.837 | 0.906 | 0.647 | 0.60 | 0.96 | 0.33 | sqrt-inv class weight |
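Variants C, D and N are the three class-imbalance remedies in the table (weighted sampler, focal loss with γ=2, sqrt-inverse class weights). The train script's exact versions are not shown here, so what follows is only a minimal sketch of the textbook form of each. It assumes integer age labels 0 to 3 in the `child, young, adult, elder` order. `sqrt_inv_weights`, `focal_ce` and `balanced_sampler` are hypothetical helper names.

```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler


def sqrt_inv_weights(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Class weights proportional to 1/sqrt(count), rescaled to mean 1 over classes present."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    present = counts > 0  # 'young' has no samples, so its weight stays 0
    w = torch.zeros(num_classes)
    w[present] = counts[present].rsqrt()
    return w * present.sum() / w.sum()


def focal_ce(logits: torch.Tensor, target: torch.Tensor,
             gamma: float = 2.0, weight: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Multi-class focal loss: per-sample CE scaled by (1 - p_t)^gamma to down-weight easy crops."""
    log_p = F.log_softmax(logits, dim=-1)
    log_pt = log_p.gather(1, target.unsqueeze(1)).squeeze(1)  # log-prob of the true class
    ce = F.nll_loss(log_p, target, weight=weight, reduction="none")
    return ((1.0 - log_pt.exp()) ** gamma * ce).mean()


def balanced_sampler(labels: torch.Tensor, num_classes: int) -> WeightedRandomSampler:
    """Draw each crop with probability proportional to 1 / count(its class)."""
    counts = torch.bincount(labels, minlength=num_classes).double()
    per_sample = 1.0 / counts[labels]
    return WeightedRandomSampler(per_sample, num_samples=len(labels), replacement=True)


# Toy age labels: 4 child, 16 adult, 1 elder, no 'young'
labels = torch.tensor([0] * 4 + [2] * 16 + [3])
print(sqrt_inv_weights(labels, num_classes=4))
```

For a variant like N, the weights would go to `F.cross_entropy(..., weight=w)`. For a variant like C, the sampler replaces `shuffle=True` in the `DataLoader`, which rejects being given both. With `gamma=0` the focal loss reduces to plain cross-entropy. Whether the train script normalises the weights this way, or combines focal loss with class weights, is not stated.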
Effect of training source (rows taken from the ablation table above):
| Training source | g_acc | a_acc | child | adult | elder |
|---|---|---|---|---|---|
| PA-100K only (J) | 0.841 | 0.924 | 0.65 | 0.98 | 0.11 |
| MSP60K only (K) | 0.798 | 0.888 | 0.82 | 0.91 | 0.36 |
| PA+MSP mix (A baseline) | 0.838 | 0.886 | 0.54 | 0.95 | 0.38 |
| PA+MSP mix (E convnext) | 0.857 | 0.924 | 0.73 | 0.96 | 0.34 |
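E, E2, E3 and E4 share backbone, augmentation and resolution and differ by seed. Their spread is therefore a rough yardstick for which gaps in the tables above are larger than run-to-run noise. The following is a minimal sketch that uses only the numbers copied from those four rows. Four seeds is a small sample, so treat the result as a ballpark.

```python
from statistics import mean, stdev

# Values copied from the E / E2 / E3 / E4 rows (convnext_tiny, camaug, 384x192; seeds differ)
seed_runs = {
    "g_acc": [0.857, 0.860, 0.857, 0.855],
    "a_acc": [0.924, 0.923, 0.923, 0.919],
    "a_f1": [0.683, 0.679, 0.689, 0.679],
    "elder": [0.34, 0.32, 0.38, 0.36],
}

for metric, vals in seed_runs.items():
    spread = max(vals) - min(vals)  # max-min range across seeds
    print(f"{metric:6s} mean={mean(vals):.4f} sd={stdev(vals):.4f} range={spread:.3f}")
```

On these runs, elder moves by 0.06 (0.32 to 0.38) from seed alone. The E vs A baseline elder gap (0.34 vs 0.38) is therefore smaller than convnext_tiny's own seed range, while the drop to 0.11 for PA-100K only is far outside it. The seed runs are all convnext_tiny, so the mobilenetv3_l spread may differ.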
Generated 2026-05-07 | Training in progress (agent running 8 hr of autonomous research) | rai-vision-training | kaggle-reports.pages.dev