Medical AI · Computer Vision · Explainability

Convolutional Neural Network for Pneumonia Detection - Part 1 (Baseline)

This project is part of a series on training a Convolutional Neural Network (CNN) for chest X-ray classification. What began as a standard classification challenge evolved into a broader investigation: I built a model, inspected what it learned, and corrected course when explainability revealed problems. The series is a case study of one core argument: high classification scores do not guarantee that the model is learning the right features.
TL;DR: accuracy without explainability can be misleading.

Do high classification scores mean the model learned the right features?

Medical image classification is often presented as a race to the perfect score. Scoring high matters a great deal for any model used in a medical context, but it is not a panacea. In my opinion, a neural model is truly useful in real-world settings only if its predictions are both accurate and explainable. By accurate, I mean the model performs its task with as few errors as possible. By explainable, I mean it should be clear which features the model relies on to make its decisions.
This article is mostly a reflection on, and a walk-through of, my work on the project. At the very beginning my aim was simply to build a model that scores high. Very quickly I realised that it matters just as much to explain what makes the model good and, even more importantly, what the model is focusing on when it makes a prediction.

Problem to solve

The task was to classify chest X-rays into three categories: Covid, Normal, and Pneumonia.

The most important lesson of this project was not how to build a classifier. It was how quickly a convincing metric can stop being convincing once you inspect what the model is actually using.

Building the Baseline

I always start by examining the dataset’s basic properties (see Figure 1). For neural network training, checking class imbalance is especially important. The class counts confirmed the dataset description: this is a multi-class classification task with grayscale chest X-rays.

Bar chart showing the number of X-ray images across the Covid, Normal, and Pneumonia classes.
Figure 1. Dataset distribution across the three classes used in the baseline experiment. There is a small imbalance, but I do not expect it to affect the final result much.
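The class counts behind a chart like Figure 1 can be gathered with a few lines of standard-library Python. This is a minimal sketch, not the code used for the figure: the helper names `count_images_per_class` and `imbalance_ratio` and the list of accepted file extensions are my own assumptions, and it only assumes one sub-folder per class, which is the layout the Keras generators expect.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to whatever the dataset contains.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def count_images_per_class(root):
    """Count image files in each class sub-folder of `root`."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts


def imbalance_ratio(counts):
    """Largest class size divided by the smallest (1.0 = perfectly balanced)."""
    smallest = min(counts.values())
    if smallest == 0:
        return float("inf")
    return max(counts.values()) / smallest


if __name__ == "__main__":
    train_dir = Path("Covid19-dataset/train")
    if train_dir.is_dir():
        counts = count_images_per_class(train_dir)
        print(counts, f"imbalance ratio: {imbalance_ratio(counts):.2f}")
```

A ratio close to 1 supports treating the imbalance as minor; a much larger ratio would be a reason to revisit that assumption.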

The first pipeline was built around a small set of deliberate preprocessing decisions. Each one was chosen for a specific reason, not as a default.

  • Reading images directly from the folder structure using Keras image generators. This keeps the pipeline reproducible and tightly coupled to the dataset's original layout.
  • Resizing to 256 × 256 pixels. A consistent spatial resolution is required before batching. The size also needs to preserve enough anatomical detail while keeping memory and computation manageable for a modest training setup.
  • Batching with a small batch size of 8. Think of it like studying for an exam by reviewing a few flashcards at a time rather than the entire stack at once. Small batches make each weight update a little noisy, which discourages the model from taking shortcuts and pushes it to generalise rather than memorise.
  • Applying only rescaling in the data generators, then performing augmentation inside the model graph. This keeps the validation and test streams clean while still exposing the training path to meaningful stochastic variation.

Data Pipeline

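from tensorflow.keras.preprocessing.image import ImageDataGenerator

BATCH_SIZE = 8  # the small batch size described above
SEED = 42       # assumed value: the original seed is not shown, any fixed integer works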

# single generator for train/validation split: rescale only
train_val_generator = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
)

test_data_generator = ImageDataGenerator(rescale=1./255)

train_data_iter = train_val_generator.flow_from_directory(
    directory="Covid19-dataset/train/",
    class_mode='categorical',
    color_mode='grayscale',
    target_size=(256, 256),
    batch_size=BATCH_SIZE,
    seed=SEED,
    shuffle=True,
    subset='training'
)

validation_data_iter = train_val_generator.flow_from_directory(
    directory="Covid19-dataset/train/",
    class_mode='categorical',
    color_mode='grayscale',
    target_size=(256, 256),
    batch_size=BATCH_SIZE,
    seed=SEED,
    shuffle=False,
    subset='validation'
)

test_data_iter = test_data_generator.flow_from_directory(
    directory="Covid19-dataset/test/",
    class_mode='categorical',
    color_mode='grayscale',
    target_size=(256, 256),
    batch_size=BATCH_SIZE,
    seed=SEED,
    shuffle=False,
)

In this setup, all iterators perform only rescaling — no augmentation in the generators. The training iterator uses shuffle=True for stochastic batching, while the validation and test iterators use shuffle=False so that prediction order aligns with the ground-truth labels. Augmentation is handled inside the model with Keras random layers, which are active only during model.fit() and automatically disabled during evaluation and prediction.

Architecture and training of the baseline model

The baseline architecture was intentionally compact: a small custom CNN with three convolutional stages built from 3×3 convolutions [2], with max-pooling [3] after the first two stages for progressive spatial compression. The design is inspired by the small convnet patterns described in Chollet's Deep Learning with Python [8].

Instead of heavy preprocessing in the iterator, I placed augmentation layers directly in the model graph. This way, random flips, rotations, translations, and zoom are applied only while fitting, and stay disabled during validation and testing.

I also used a single-channel input. Chest X-rays are greyscale, so a three-channel RGB model would mean either duplicating the channel artificially or wasting two input channels on zeros. Using (256, 256, 1) keeps the input honest.

Finally, the classification head is intentionally small (Flatten → Dropout → Dense(64) → Softmax) so that the model remains a practical reference point rather than an attempt at peak performance.

Layer-by-layer visualization of the baseline CNN architecture from 1x256x256 input to 3-class dense output.
Figure 2. Visual summary of the baseline CNN architecture used in this experiment.

The architecture diagram in Figure 2 provides a compact view of the feature-map transitions from input to the final three-class output.
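The feature-map transitions can also be checked by hand. The sketch below redoes the shape and parameter arithmetic for the layers listed in the baseline model (3×3 convolutions with padding='valid', 2×2 max-pooling, Dense(64), Dense(3)). The helper names are hypothetical and the numbers are derived from the layer definitions, not copied from the author's model.summary() output.

```python
def conv_valid(size, kernel=3):
    """Side length after a stride-1 convolution with padding='valid'."""
    return size - kernel + 1


def max_pool(size, pool=2):
    """Side length after non-overlapping max-pooling (floor division)."""
    return size // pool


def conv_params(kernel, in_ch, out_ch):
    """Weights plus biases of a Conv2D layer."""
    return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(in_units, out_units):
    """Weights plus biases of a Dense layer."""
    return in_units * out_units + out_units


def baseline_summary(side=256):
    """Trace (layer, side, channels) and parameter counts for the baseline."""
    trace = [("input", side, 1)]
    params = {}
    channels = 1
    # (layer name, output channels, whether a 2x2 max-pool follows)
    for name, out_ch, pooled in [("conv1", 32, True), ("conv2", 64, True), ("conv3", 64, False)]:
        side = conv_valid(side)
        params[name] = conv_params(3, channels, out_ch)
        channels = out_ch
        trace.append((name, side, channels))
        if pooled:
            side = max_pool(side)
            trace.append((name.replace("conv", "pool"), side, channels))
    flat = side * side * channels
    params["dense64"] = dense_params(flat, 64)
    params["dense3"] = dense_params(64, 3)
    return trace, flat, params


trace, flat, params = baseline_summary()
for name, side, channels in trace:
    print(f"{name:6s} {side}x{side}x{channels}")
print("flattened units:", flat)
print("total parameters:", sum(params.values()))
```

Under these assumptions the spatial size goes 256 → 254 → 127 → 125 → 62 → 60, and almost all of the parameters sit in the first Dense layer, which follows directly from flattening a 60×60×64 feature map.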

Selected Code Snippet: Baseline Model

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(256, 256, 1)),

    # in-model augmentation (active only during training)
    tf.keras.layers.RandomFlip("horizontal", seed=SEED),
    tf.keras.layers.RandomRotation(15 / 360, seed=SEED),
    tf.keras.layers.RandomTranslation(0.1, 0.1, seed=SEED),
    tf.keras.layers.RandomZoom(0.1, 0.1, seed=SEED),

    # convolutional backbone
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='valid'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='valid'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='valid'),

    # classification head
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax'),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[
        tf.keras.metrics.CategoricalAccuracy(),
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.Precision(),
        tf.keras.metrics.Recall(),
    ],
)

model.summary()  # prints the table itself; wrapping it in print() adds a stray "None"

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=20,
    restore_best_weights=True
)

During training, I tracked more than categorical accuracy: I also monitored AUC, precision, and recall. That matters in medical classification because a model can look acceptable on one aggregate metric while collapsing on a class that matters clinically. I trained with early stopping and ReduceLROnPlateau, then performed a one-shot evaluation on the held-out test set after training completed.
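To make the early-stopping rule concrete (monitor='val_loss', patience=20, restore_best_weights=True), here is a simplified pure-Python sketch of the patience logic. It is an illustration only: the function name `early_stopping_trace` is hypothetical, it assumes a zero improvement threshold, and it is not the Keras implementation.

```python
def early_stopping_trace(val_losses, patience=20):
    """Return (stop_epoch, best_epoch), both 0-indexed.

    Training stops once the validation loss has failed to improve for
    `patience` consecutive epochs. `best_epoch` is the epoch whose weights
    restore_best_weights=True is meant to bring back.
    """
    best_loss = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best validation loss: remember it and reset the counter.
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Patience never ran out: training used every available epoch.
    return len(val_losses) - 1, best_epoch


# Toy validation-loss curve (made-up numbers) with a short patience.
losses = [1.0, 0.8, 0.9, 0.85, 0.95]
print(early_stopping_trace(losses, patience=2))  # (3, 1)
```

With a patience of 20, a bouncy validation loss in the first epochs does not end training prematurely, because the counter resets every time a new best value appears.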

Three line plots showing baseline train, validation, and test loss, categorical accuracy, and AUC across training epochs.
Figure 3. Training, validation, and test curves for the baseline.

Figure 3 shows the hallmarks of a healthy fit. Training and validation loss decrease together from around 1.1 to roughly 0.2 by epoch 50, with no diverging gap between them. Validation loss is somewhat bouncy during the first 10–15 epochs, which in my opinion is unavoidable with a small dataset and mini-batch noise, but the two curves track each other closely from epoch 20 onward.

Categorical accuracy tells the same story. Training accuracy climbs from about 0.50 to 0.93, while validation accuracy rises in parallel and settles in the 0.88–0.93 range. Validation occasionally edges slightly above training, which is the expected fingerprint of in-graph augmentation: the model sees distorted images at training time but clean ones at validation time. The two curves meeting and staying together indicates the model is learning generalisable patterns rather than memorising the training set.

AUC climbs fastest of all three metrics, passing 0.95 within the first 10 epochs and saturating near 0.99 for both training and validation by epoch 20. The held-out test set confirms this behaviour: test accuracy of 0.85, test AUC of 0.95, and test loss of 0.43. The small drop from validation to test is normal sampling variability on a 66-image test set, not a generalisation failure. For a custom CNN trained from scratch on a few hundred images per class, on a 3-way task that includes the genuinely subtle Normal-versus-Pneumonia distinction, this is a very good starting point.
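The sampling-variability argument can be checked with a rough calculation. The sketch below treats each test prediction as an independent trial and uses the normal approximation to the binomial, which is only approximate at n = 66. The function name `accuracy_interval` is hypothetical, and the inputs are the test accuracy and test-set size quoted above.

```python
import math


def accuracy_interval(accuracy, n, z=1.96):
    """Approximate 95% interval for an accuracy measured on n samples."""
    se = math.sqrt(accuracy * (1 - accuracy) / n)  # binomial standard error
    return accuracy - z * se, accuracy + z * se, se


low, high, se = accuracy_interval(0.85, 66)
print(f"standard error = {se:.3f}")
print(f"approx. 95% interval = [{low:.2f}, {high:.2f}]")
```

Under these assumptions the interval spans roughly 0.76 to 0.94, so the validation range of 0.88–0.93 sits comfortably inside it.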

How Did the Baseline Model Perform?

After training, the baseline reached approximately 0.85 test accuracy, 0.43 test loss, and around 0.95 test AUC. This is clearly above random chance for a three-class task, so the model is learning useful signal. At the same time, the gap between accuracy and AUC suggests that the model ranks the classes better than it makes final decisions. In practical terms: the baseline is informative, but still not reliable enough for high-confidence medical use.

Selected Code Snippet: Baseline Evaluation

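import numpy as np
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix

# rewind the iterator so predictions start at the first test image
test_data_iter.reset()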
predictions = model.predict(test_data_iter)
predicted_classes = np.argmax(predictions, axis=1)
true_classes = test_data_iter.classes
class_labels = list(test_data_iter.class_indices.keys())

assert len(predicted_classes) == len(true_classes), (
    "Prediction/label length mismatch - check that shuffle=False on test_data_iter."
)

report = classification_report(true_classes, predicted_classes, target_names=class_labels)
print(report)

cm = confusion_matrix(true_classes, predicted_classes)
sns.heatmap(
    cm / cm.sum(axis=1, keepdims=True),
    annot=True,
    fmt=".2f",
    cmap="viridis",
    xticklabels=class_labels,
    yticklabels=class_labels,
)

Confusion matrix for the baseline chest X-ray classifier across the Covid, Normal, and Pneumonia classes.
Figure 4. The baseline confusion matrix of the test set.

Looking at Figure 4 allows us to break down the performance class by class. From top to bottom:

Covid (row 1): The model correctly identifies 88% of Covid cases, the strongest result across all three classes. The remaining 12% are misclassified as Normal, and nothing leaks into Pneumonia. That zero means the model never confuses a Covid X-ray with Pneumonia on this test set. The 12% Covid-to-Normal confusion is not ideal in a medical context. However, for now I am only evaluating the baseline model and gathering as many insights as possible, which I will use in the second article of this series to improve performance.

Normal (row 2): 85% of Normal cases are correctly identified. The concerning figure here is the 15% that get pushed into Pneumonia. A healthy patient being flagged as potentially ill is a false positive. In a medical context, that is not necessarily a critical failure, since it is generally better to follow up with additional tests rather than miss a condition. However, for the patients themselves, it can be stressful and costly. It is still a pattern worth tracking as the model is improved.

Pneumonia (row 3): This is where the model struggles most. Only 80% of true Pneumonia cases are caught, and 20% are misclassified as Normal. That is a false negative rate of one in five, meaning the model declares one in five Pneumonia patients "healthy". In a real screening pipeline, that is the error type the model can least afford to make. Unlike the false positives discussed above, a missed Pneumonia diagnosis can be life-threatening.

Taken together, the picture is consistent with what the aggregate metrics already hinted at: the model has learned something useful, but it is not yet reliable enough across all classes. Covid is handled well; Pneumonia recall is the weak link. That asymmetry will guide what to improve next.
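The row-by-row reading above corresponds to per-class recall, and the column-wise view gives precision (notes 6 and 7). The sketch below computes both from a count matrix. The counts are hypothetical: I chose them only so that the row percentages come out close to the 88%, 85%, and 80% discussed above and the total equals 66 images. The per-class split and the function names are assumptions, not the author's actual matrix.

```python
LABELS = ["Covid", "Normal", "Pneumonia"]

# Hypothetical counts (rows = true class, columns = predicted class).
CM = [
    [23, 3, 0],   # Covid
    [0, 17, 3],   # Normal
    [0, 4, 16],   # Pneumonia
]


def per_class_metrics(cm, labels):
    """Precision = TP / (TP + FP) and recall = TP / (TP + FN) per class."""
    metrics = {}
    for k, label in enumerate(labels):
        tp = cm[k][k]
        fn = sum(cm[k]) - tp                   # rest of the row
        fp = sum(row[k] for row in cm) - tp    # rest of the column
        metrics[label] = {
            "precision": tp / (tp + fp) if tp + fp else 0.0,
            "recall": tp / (tp + fn) if tp + fn else 0.0,
        }
    return metrics


def accuracy(cm):
    """Share of all samples that lie on the diagonal."""
    correct = sum(cm[k][k] for k in range(len(cm)))
    return correct / sum(sum(row) for row in cm)


for label, m in per_class_metrics(CM, LABELS).items():
    print(f"{label:10s} precision={m['precision']:.2f} recall={m['recall']:.2f}")
print(f"accuracy={accuracy(CM):.2f}")
```

Overall accuracy is a weighted average of the per-class recalls, so it can never fall below the worst recall or rise above the best one. That is why a single aggregate number can hide a weak Pneumonia row.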

Conclusions

In this article, I built and evaluated a baseline CNN for chest X-ray classification across three classes: Covid, Normal, and Pneumonia. The aggregate metrics (0.85 test accuracy and 0.95 test AUC) place the model clearly above random chance and confirm that it is picking up genuine signal from the X-rays. For a custom architecture trained from scratch on a small dataset, that is an honest and useful starting point.

The confusion matrix, however, tells a more sobering story. Covid is handled well at 88% recall, but Pneumonia reaches only 80%. One in five true Pneumonia cases is quietly sent into the Normal class. In a clinical screening context, that is precisely the error the model cannot afford to make: a missed Covid case is serious, and a missed Pneumonia case in a fragile patient can be life-threatening. The aggregate accuracy hides this, which is exactly why reading the confusion matrix row by row is important.

Before trying to fix those numbers, though, I think there is another important question to ask: what is the model actually looking at? In other words, which parts of the lungs have the highest discriminative power between classes? This is important to figure out because a model can score well for the wrong reasons, for instance by basing its predictions on irrelevant artefacts in the image rather than on genuine anatomical features. Adding this one layer of explainability already achieves something important: good performance, an explanation of that performance, and a clearer view of where to improve.

  1. VGG16 is a 16-layer deep convolutional neural network developed by the Visual Geometry Group at the University of Oxford. It follows a simple and uniform architecture: repeated blocks of 3×3 convolutional filters stacked in increasing depth, followed by max-pooling layers and a fully connected classification head. Despite its age, it remains a common baseline and backbone for image classification tasks. Keras VGG16 documentation ↗
  2. 3×3 convolutions are the core operation in most modern CNNs. A small 3×3 filter slides across the image and computes a weighted sum of each 3×3 neighbourhood of pixels, producing a new feature map that captures local patterns such as edges or textures. Using many small filters stacked in depth is more expressive and parameter-efficient than using a single large filter.
  3. Max-pooling is a downsampling operation that divides the feature map into small regions (typically 2×2) and keeps only the maximum value from each region. Think of it like summarising a paragraph by keeping only the most important word from each sentence. It reduces the spatial size of the representation, which lowers computation and introduces a degree of translation invariance — small shifts in the input produce the same output.
  4. Skip connections (used in architectures like ResNet) are shortcut paths that feed the output of an earlier layer directly into a later layer, bypassing one or more intermediate layers. They help gradients flow during training and allow the network to learn residual corrections rather than full transformations. VGG16 has none — each layer must process the signal in sequence without shortcuts.
  5. Complex branching refers to architectures where the network splits into multiple parallel paths that process the input differently and then merge — as seen in Inception or EfficientNet. Each branch can capture features at different scales or with different operations simultaneously. VGG16 is a single linear stack with no such parallel paths, which is why it is easy to reason about and interpret.
  6. Precision answers: "Of all samples predicted as a class, how many were actually that class?" A high precision means fewer false positives. Formula: Precision = TP / (TP + FP), where TP is true positives and FP is false positives.
  7. Recall answers: "Of all real samples of a class, how many did the model correctly find?" A high recall means fewer false negatives. Formula: Recall = TP / (TP + FN), where TP is true positives and FN is false negatives.
  8. Chollet, F. Deep Learning with Python. Chapter 8: Image classification. deeplearningwithpython.io