CVPR 2022에서 발표된 WSSS (Weakly Supervised Semantic Segmentation) 관련 논문(링크)입니다. 'Semantic' 사전적으로 '의미'라는 뜻을 가지고 있는데, 말 그대로 이미지 안에서 '의미' 단위로 '분할'하는 것이 Semantic Segmentation 이 되겠습니다.
특히나 저는 요즘 의료이미지를 다루고 있기 때문에, Medical Image에 적합한 모델을 찾던 도중 이 논문을 발견했는데요. 코드도 공개했다길래 보니까 지금은 닫혀있네요;
Weakly Supervised Learning 이 그 방법상 인류에게 더 도움이 될 거 같다는 생각에 최근에 많이 연구하고 있습니다. Segmentation, Detection을 위해 일일이 사람이 annotation 하는 수고를 할 필요가 없기 때문이죠. 여기서는 Causality(인과 관계?)라는 개념을 이용해 성능을 높였다고 합니다.
1. Introduction
- 최근에 pixel 단위 예측 모델인 Segmentation 연구가 활발하지만, pixel-level labeled data를 얻기위해선 시간과 노력(돈)이 많이 필요하다. 이 문제를 해결하기 위해 Weakly Supervised Semantic Segmentation가 등장한다.
- WSSS는 보통 image-level label, point, scribble(낙서), bounding box 등의 약한 annotation들을 활용하는데, 이 중에서 image-level label이 가장 데이터 마련하기 쉬운 동시에 가장 segmentation하기 힘든 데이터 유형이다.
- WSSS의 가장 큰 문제점은 'Location information'이 없다는 것. CNN 네트워크에서 그런 정보를 얻을 수 있는 Class activation mapping (CAM) 기법이 있으나, 이것만으론 segmentation 성능을 보장할 수 없다. 몇몇 기법들은 CAM을 이용해 seed를 생성하고, 이를 정제해서 object를 덮으려고 노력한다.
- 그러나, 대부분의 CAM-based 기법들은 '의료 이미지'에서 잘 동작하지 않을 것임. Medical Image에 WSSS를 적용하기 위해서는 넘어야 할 장애물이 두 가지가 있다.
- Foreground와 Background의 경계가 모호하다 (뭐가 앞에있고, 뒤에 있는지 애매하다)
- co-occurrence가 자주 발생한다 (natural image의 경우 "사람"이 항상 "말"과 동시에 나오는 것은 아니다). image-level로 레이블링 되어있다면 "콩팥" 부분을 보고 "간"이라고 할 수 있다.
- 위 장애물들을 극복하기 위해 C-CAM (casual-CAM)을 소개한다. C-CAM은 두 가지의 인과 사슬?(causality chain)로 이루어져 있다.
- category causality. X-> Y : X는 이미지 콘텐츠(cause), Y는 카테고리(effect)를 가리킨다. 여기서는 C-CAM이 예측한 분류의 진짜 원인을 파악하기 위해 causal intervention(원인 중재)을 사용한다.
- anatomy causality. Z -> S : Z는 해부학적 구조 (cause), S는 organ segmentation (effect)를 가리킨다. 여기서는 C- CAM이 segmentation의 진짜 원인을 파악하기 위해 해부학적 제약(anatomical constraint)이 합쳐진다. 이게 co-occurrence 문제를 해결하는 실마리다.
Causal Intervention?
본 논문에서 '인과 추론'에 대한 개념이 나오는데, 확률 모형으로 데이터를 분석하는 기법 중 하나라고 합니다. 본 논문에서는 "Judea Pearl, Madelyn Glymour, and Nicholas P Jewell. Causal inference in statistics: A primer. John Wiley & Sons, 2016"를 참고했습니다. 사실 어떤 내용인지 자세히는 모르지만 링크에서 쉽게 설명해주고 있습니다.
본 논문의 category causality는 결국 X를 바꿔가면서 학습해 Y에 대한 추론을 하는 Intervention을 사용한다는 것 같습니다.
따라서 본 논문의 contribution은 다음 세 가지다.
- Medical image 분야에서 더 정교하고 정확한 모양의 mask를 생성하는 C-CAM 기법을 소개한다. (medical image WSSS 분야에서 causality를 도입한 것은 이게 최초라고 함)
- Medical image WSSS의 장애물을 극복하기 위해 두 개의 causailty chain을 결합했다.
- 3개의 공개 데이터 셋에서 실험을 진행했고, SOTA 기법을 뛰어넘는 성능을 보였다. (결과는 4. Experiment에)
2. Related Work
2.1. Weakly Supervised Semantic Segmentation
- 보통의 CAM 기반 파이프라인은 세 단계로 이루어져 있다.
- CAM으로 seed region을 생성한다
- seed region을 정제해 pseudo masks (가짜 마스크)를 생성한다.
- segmentation 모델로 pseudo masks를 학습한다.
- 대부분 연구는 2번에 집중하고 있다. (AffinityNet, BES, DSRG 등). 최근에는 mask를 직접 생성하는 모델들도 연구되었다 (FickleNet, MCIS, SEAM 등). 그러나 이 기법들은 medical image에선 잘 안 먹힌다 (위에서 언급한 두 가지 장애물 때문)
2.2. Anatomical Prior
- prior 지식을 결합하는 것은 natural image나 medical image 둘 다 성능을 높이는 좋은 방법이다.
Prior란?
확률 분포 관점에서 특정 사건에 대한 초기의 믿음 (initial belief)를 말합니다. 사건이 생기기 전에 "대충 이런 결과가 나오겠구먼~"이라는 생각입니다. (출처: https://cs.stackexchange.com/questions/76647/what-is-meant-by-the-term-prior-in-machine-learning)
- CNN 기반의 기법은 pixel 단위 loss함수가 있기 때문에 출력에 제약을 고려하지 않는다. 좋은 prior 설계는 더 좋은 구조 제약을 제공한다. 특히 medical image가 보통 해부학적 정보가 많기 때문에 prior가 더 효과가 있다. 그러나 현존하는 기법들은 전문 지식과 복잡한 모델이 필요하다.
- C-CAM은 모델 스스로 anatomical information을 추출하고 anatomy-causality chain을 통해 통합한다.
2.3. Causality in Computer Vision
- causality는 최근에 컴퓨터 비전 task에서 많이 쓰인다. causality의 등장은 학습도 더 잘되고 설명하기도 좋은 모델을 만들었다. 그러나 medical image WSSS 분야에서 적용된 적은 없다.
3. Method
3.1. Motivation
- medical image WSSS에서 causality는 아래 두 질문에 대답하여 분석할 수 있다.
- Q1. 왜 classification 모델은 정확한데 CAM이 생성한 region은 부정확한가?
- A. classification은 필수적으로 Association 모델이기 때문이다. (pearl's causal hierarchy 참고)
- ex. 일부 비 전립선 영역은 전립선 영역과 상관관계가 높아 biased category information을 제공할 수 있다.
- Q1. 왜 classification 모델은 정확한데 CAM이 생성한 region은 부정확한가?
-
- Q2. 생성된 region의 모양이 ground-truth와 왜 다른가?
- A. WSSS에서는 pixel 단위의 loss를 적용하는 게 불가능하기 때문이다.
- Q2. 생성된 region의 모양이 ground-truth와 왜 다른가?
- 따라서 본 논문에서 제안한 두 개의 causality chain으로 위 문제를 해결할 수 있다.
3.2. Global Sampling Module
- CAM이 생성한 Saliency map (heat map)은 segmentation에 정확하진 않지만, category와 anatomy에 연관된 가치 있는 정보를 전달한다. 따라서 그런 정보(global context)를 추출할 수 있는 global sampling(GS) 모듈을 설계했다.
- 데이터는 곧바로 P-CAM (pure CAM: CNN 백본, classification 헤드, mapping, upsample 연산이 들어있는 기초적인 CAM 모델)에 들어간다. training 단계에서는 training과 classification head만 사용된다.
- Mapping 단계에서는 argmax 연산을 이용해 가장 큰 pixel 값의 좌표를 추출하고, 각 클래스별 saliency map이 생성된 후 upscaling 된다 -> Coarse Mask
- 이후 Coarse Mask를 결합해 global context map (MGC)을 생성한다.
3.3. Causality in medical image WSSS
X: input image
Y: 분류된 카테고리
C: context confounder (모델이 헷갈리게 하는 요인)
Z: anatomical structure (해부학적 구조)
S: segmentation의 형태
P: 가짜 마스크
결론적으로 P는 Y와 S에 의해 결정된다.
3.4. Causality Module
Category-Causality Chain.
- Coarse mask와 MGC는 각각 convolution 레이어를 거쳐 같은 차원으로 변환된다. 이 후 두 행렬을 곱한 값에 softmax를 이용해 각 카테고리별 확률을 매핑한다.
- 이 카테고리별 확률을 내포한 벡터를 Mgc와 곱한 뒤 Downsampling을 통해 Mc를 도출한다.
Anatomy-Causality Chain.
- 복부 스캔같이 장기가 여러 개 있는 이미지에서 Saliency map은 좌-우 신장을 구분하지 못한다. 이걸 막기 위해 anatomy-causality chain을 설계했다. anatomy causality map은 각 카테고리의 가능 위치를 얻기 위해 Mgc가 양수인 부분을 1로 매핑한다.
- 이후 최종 saliency map인 CAMac를 얻기 위해, anatomy causality map과 CAMcc를 곱한 뒤, argmax 및 upscaling을 통해 최종 pseudo mask를 얻는다. 이 최종 mask는 다음 full-supervision 단계의 U-Net 모델의 학습에 사용된다.
4. Experiments
4.1. Dataset
- ProMRI: 전립선 데이터셋
- ACDC: 좌심실 심장 내막 데이터셋
- CHAOS: 복부 기관 데이터셋 (간, 좌 신장, 우 신장, 비장)
4.2. Implementation Details
- python, pytorch, ubuntu 16.04.1, 2 Nvidia GTX 1080Ti
- P-CAM 선 학습 (negative 이미지 포함), C-CAM 후 학습 (negative 이미지 미포함)
- SGD optimizer, Lr = 0.1(P-CAM), 0.001(C-CAM)
- pseudo segmentation mask를 이용해 U-Net 학습, Adam optimizer, Lr=0.0005, 100 epoch
4.3 Ablation Studies for C-CAM
4.4. Comparison with other CAM-like methods
4.5. Parameter Sensitivity
- 적당한 background threshold를 정하는 것이 가장 기본적이면서 크리티컬 한 영향을 미친다. 따라서 background threshold의 영향력을 파악하기 위해 광범위한 실험을 진행했다.
- 그 결과 다른 CAM기반 모델들은 threshold에 따라 DSC가 많이 달라졌다. 그러나 C-CAM은 거기에 덜 민감했다. -> background threshold를 정하기가 쉽다.
4.6 Visualization of saliency maps in C-CAM
4.7. Comparison with other WSSS methods
- (natural image segmentation에 쓰이는) 다른 WSSS 알고리즘들과 비교해도 좋은 성능을 보였다.
- 전립선을 세 부분으로 나누어서 측정해보았는데, Apex에서 SizeLoss와 작은 차이로 성능이 낮았고, 이는 SizeLoss는 weak label을 생성하는데 Ground truth를 사용하기 때문이다.
5. Conclusion and future work
- 본 논문에서 Causality 분석 기법에 기반한 medical image WSSS 기법을 제시했으며, 두 causal chain으로 이루어진 C-CAM이 또렷한 경계뿐만 아니라, 해부학적 지식에 상응하는 결과를 내었다.
- 그럼에도 불구하고 C-CAM은 복잡한 모양의 object를 segmentation 하기 어려워한다.
- 적은 양의 strong annotation과 많은 양의 weak annotation을 결합해 더 정확한 카테고리와 해부학적 정보를 제공하게 만들 수 있을 것이다.
Causality라는 분석 기법을 네트워크에 도입해 WSSS 분야에서 압도적인 스코어를 달성했습니다. 그러나 복잡한 모양의 object segmentation은 아직 한계가 있으므로, 추후에 연구하게 된다면 global feature부터 local feature까지 고려하는 pyramid 형식의 모델을 구상해볼 수도 있겠습니다.
Ref