이 예제에서는 단일 클래스 FCDD(fully convolutional data description) 이상 감지 신경망을 사용하여 알약 영상에서
결함을 검출하는 방법을 보여줍니다.
이상 감지의 핵심적 목표는 훈련된 신경망이 영상을 이상 영상으로 분류한 이유를 관측자가 이해할 수 있도록 하는
것입니다. FCDD는 설명 가능한 분류를 수행하여 클래스 예측값에 대해 신경망이 이 분류 결정을 내리게 된 근거를
설명하는 정보를 제공합니다. FCDD 신경망은 각 픽셀이 비정상일 확률과 함께 히트맵을 반환합니다.
분류기는 이상 점수 히트맵의 평균값에 따라 영상에 정상 또는 비정상으로 레이블을 지정합니다.
1. 데이터셋에 이용할 알약 영상 다운로드
dataDir = fullfile(tempdir, 'PillDefects');
downloadPillQCData(dataDir)
|
만약 위의 오류가 발생하는 경우, 링크에서 직접 다운로드함.
다음 이미지는 각 클래스에 해당하는 영상의 예입니다. 왼쪽은 결함이 없는 정상 알약,
중간은 이물질로 오염된 알약, 오른쪽은 흠집이 있는 알약입니다.
2. 데이터를 불러온 후, 전처리하기
imageDir = fullfile(dataDir,"pillQC-main","images");
imds = imageDatastore(imageDir,IncludeSubfolders=true,LabelSource="foldernames");
|
데이터를 훈련세트, 보정세트, 테스트세트로 분할하기
(훈련 데이터 세트에 정상 영상을 50% 할당하고 각 이상 클래스를 5%의 작은 비율로 할당합니다.
보정 세트에는 정상 영상을 10% 할당하고 각 이상 클래스를 20%를 할당합니다.
나머지 영상을 테스트 세트에 할당합니다. )
normalTrainRatio = 0.5;
anomalyTrainRatio = 0.05;
normalCalRatio = 0.10;
anomalyCalRatio = 0.20;
normalTestRatio = 1 - (normalTrainRatio + normalCalRatio);
anomalyTestRatio = 1 - (anomalyTrainRatio + anomalyCalRatio);
anomalyClasses = ["chip","dirt"];
[imdsTrain,imdsCal,imdsTest] = splitAnomalyData(imds,anomalyClasses, ...
NormalLabelsRatio=[normalTrainRatio normalCalRatio normalTestRatio], ...
AnomalyLabelsRatio=[anomalyTrainRatio anomalyCalRatio anomalyTestRatio]);
|
추가로, 훈련데이터를 하나는 정상데이터만, 다른 하나는 이상데이터만 포함하는 두 개의 데이터저장소로 분할함.
[imdsNormalTrain, imdsAnomalyTrain] = splitAnomalyData(imdsTrain, anomalyClasses, ...
NormalLabelsRatio=[1 0 0], AnomalyLabelsRatio=[0 1 0], Verbose=false);
|
훈련데이터 증강하여 시각화
imdsNormalTrain = transform(imdsNormalTrain, @augmentDataForPillAnomalyDetector);
imdsAnomalyTrain = transform(imdsAnomalyTrain, @augmentDataForPillAnomalyDetector);
dsCal = transform(imdsCal, @addLabelData, IncludeInfo=true);
dsTest = transform(imdsTest, @addLabelData, IncludeInfo=true);
exampleData = readall(subset(imdsNormalTrain, 1:9));
montage(exampleData(:,1));
|
3. FCDD 모델 만들기
(FCDD의 기본 개념은 입력 영상의 각 영역에 이상 부분이 포함될 확률을 설명하는 이상 점수 맵을 생성하도록
신경망을 훈련시키는 것)
backbone = pretrainedEncoderNetwork('inceptionv3', 3);
net = fcddAnomalyDetector(backbone);
|
신경망을 훈련시키거나, 사전훈련된 신경망 다운로드하기
doTraining = true;
numEpochs = 30;
if doTraining
options = trainingOptions("adam", ...
Shuffle="every-epoch",...
MaxEpochs=numEpochs,InitialLearnRate=1e-4, ...
MiniBatchSize=32,...
BatchNormalizationStatistics="moving");
detector = trainFCDDAnomalyDetector(imdsNormalTrain,imdsAnomalyTrain,net,options);
modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
save(fullfile(dataDir,"trainedPillAnomalyDetector-"+modelDateTime+".mat"),"detector");
else
trainedPillAnomalyDetectorNet_url = "https://ssd.mathworks.com/supportfiles/"+ ...
"vision/data/trainedFCDDPillAnomalyDetectorSpkg.zip";
downloadTrainedNetwork(trainedPillAnomalyDetectorNet_url,dataDir);
load(fullfile(dataDir,"folderForSupportFilesInceptionModel", ...
"trainedPillFCDDNet.mat"));
end
|
이상 임계값 설정하기
scores = predict(detector, dsCal);
labels = imdsCal.Labels ~= 'normal';
numBins = 20;
[~, edges] = histcounts(scores, numBins);
figure
hold on
hNormal = histogram(scores(labels==0), edges);
hAnomaly = histogram(scores(labels==1), edges);
hold off
legend([hNormal, hAnomaly], 'Normal', 'Anomaly')
xlabel('Mean Anomaly Score')
ylabel('Counts')
[thresh, roc] = anomalyThreshold(labels, scores, true);
detector.Threshold = thresh;
plot(roc)
title('ROC AUC : '+ roc.AUC)
|
4. 분류모델 평가하기
% 테스트 세트의 각 영상을 정상 또는 이상 영상으로 분류 testSetOutputLabels = classify(detector, dsTest);
% 각 테스트 영상의 실측 레이블을 가져옴. testSetTargetLabels = dsTest.UnderlyingDatastores{1}.Labels;
% evaluateAnomalyDetection 함수로 성능 메트릭을 계산하여 이상 감지기를 평가. metrics = evaluateAnomalyDetection(testSetOutputLabels, testSetTargetLabels, anomalyClasses);
% 혼동행렬을 추출하고 혼동 플롯을 표시. M = metrics.ConfusionMatrix{:,:};
confusionchart(M,["Normal","Anomaly"])
acc = sum(diag(M)) / sum(M, 'all');
title('Accuracy : '+ acc)
% 전체 데이터 세트와 각 이상 클래스에 대해 메트릭을 계산. metrics.ClassMetrics
metrics.ClassMetrics(2, 'AccuracyPerSubClass').AccuracyPerSubClass{1}
|
5. 분류결정 설명하기
이상 히트맵의 표시 범위 계산하기
minMapVal = inf;
maxMapVal = -inf;
reset(dsCal)
while hasdata(dsCal) img = read(dsCal);
map = anomalyMap(detector, img{1});
minMapVal = min(min(map, [], 'all'), minMapVal);
maxMapVal = max(max(map, [], 'all'), maxMapVal);
end
displayRange = [minMapVal, maxMapVal];
|
testSetAnomalyLabels = testSetTargetLabels ~= 'normal';
idxTruePositive = find(testSetAnomalyLabels' & testSetOutputLabels, 1, 'last');
dsExample = subset(dsTest, idxTruePositive);
img = read(dsExample);
img = img{1};
map = anomalyMap(detector, img);
imshow(anomalyMapOverlay(img, map, MapRang=displayRange, Blend='equal'))
|
정상 영상의 히트맵 보기
idxTrueNegative = find(~(testSetAnomalyLabels' | testSetOutputLabels));
dsExample = subset(dsTest, idxTrueNegative);
img = read(dsExample);
img = img{1};
map = anomalyMap(detector, img);
imshow(anomalyMapOverlay(img, map, MapRange=displayRange, Blend='equal'))
|
거짓음성 영상의 히트맵 보기
falseNegativeIdx = find(testSetAnomalyLabels' & ~testSetOutputLabels);
if ~isempty(falseNegativeIdx)
fnExamples = subset(dsTest,falseNegativeIdx);
fnExamplesWithHeatmapOverlays = transform(fnExamples,@(x) {...
anomalyMapOverlay(x{1},anomalyMap(detector,x{1}), ...
MapRange=displayRange,Blend="equal")});
fnExamples = readall(fnExamples);
fnExamples = fnExamples(:,1);
fnExamplesWithHeatmapOverlays = readall(fnExamplesWithHeatmapOverlays);
montage(fnExamples)
montage(fnExamplesWithHeatmapOverlays)
else
disp("No false negatives detected.")
end
|
falsePositiveIdx = find(~testSetAnomalyLabels' & testSetOutputLabels);
if ~isempty(falsePositiveIdx) fnExamples = subset(dsTest, falsePositiveIdx);
fnExamplesWithHeatmapOverlays = transform(fpExamples, @(x) { ...
anomalyMapOverlay(x{1}, anomalyMap(detector, x{1}), ...
MapRange=displayRange, Blend="equal")});
fpExamples = readall(fpExamples);
fpExamples = fpExamples(:,1);
fnExamplesWithHeatmapOverlays = readall(fnExamplesWithHeatmapOverlays)
else
disp('No false positive detected')
end
|
지원함수
function [data,info] = addLabelData(data,info)
if info.Label == categorical("normal")
onehotencoding = 0;
else
onehotencoding = 1;
end
data = {data,onehotencoding};
end
|