Wednesday 4 November 2015

Using Chain Of Responsibility instead of if/else statement

Currently I am involved in building a web application. One of its features lets users upload a photo. At the moment the system accepts only three image formats (JPG, BMP and PNG), but other formats may need to be supported in the future.

Each image file starts with a header that, among other things, identifies the file's format. For our three formats the headers are as follows:

File format   Header
JPG           [0xff, 0xd8]
BMP           [0x42, 0x4D]
PNG           [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]
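These byte sequences are the formats' file signatures (often called magic numbers). Two of them are partly readable as ASCII: the BMP header is the text "BM", and bytes 1 to 3 of the PNG header spell "PNG". The sketch below uses a hypothetical ImageSignatures helper, which is not part of the application's code, to keep the table as data and to check whether a buffer starts with a given signature. It rejects a buffer that is shorter than the signature instead of reading past its end.

```csharp
using System;
using System.Linq;
using System.Text;

// Hypothetical helper (not part of the article's code) holding the
// signatures from the table above.
public static class ImageSignatures
{
    public static readonly byte[] Jpg = { 0xff, 0xd8 };
    public static readonly byte[] Bmp = { 0x42, 0x4D };                 // ASCII "BM"
    public static readonly byte[] Png = { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A };

    // True when the buffer begins with the given signature.
    public static bool HasSignature(byte[] buffer, byte[] signature)
    {
        // Too short (or missing) buffers can never match.
        if (buffer == null || buffer.Length < signature.Length)
        {
            return false;
        }

        // Compare only the leading bytes of the buffer.
        return buffer.Take(signature.Length).SequenceEqual(signature);
    }

    // Shows the readable part of a signature, e.g. "BM" for BMP.
    public static string AsAscii(byte[] signature, int start, int count)
    {
        return Encoding.ASCII.GetString(signature, start, count);
    }
}
```

For example, `ImageSignatures.AsAscii(ImageSignatures.Png, 1, 3)` returns "PNG", and `HasSignature` returns false for a one-byte buffer rather than throwing.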

Our image decoding service is defined as follows:

public interface IImageDecodingService
{
    ImageFormat DecodeImage(byte[] imageBuffer);
}

ImageFormat is defined as an enumeration:

public enum ImageFormat
{
    Unknown,
    Bmp,
    Png,
    Jpeg
}

One possible implementation is:

public class ImageDecodingService : IImageDecodingService
{
    private readonly byte[] _jpgHeader = { 0xff, 0xd8 };
    private readonly byte[] _pngHeader = { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A };
    private readonly byte[] _bmpHeader = { 0x42, 0x4D };

    public ImageFormat DecodeImage(byte[] imageBuffer)
    {
        if (ContainsHeader(imageBuffer, _jpgHeader))
            return ImageFormat.Jpeg;

        if (ContainsHeader(imageBuffer, _pngHeader))
            return ImageFormat.Png;

        if (ContainsHeader(imageBuffer, _bmpHeader))
            return ImageFormat.Bmp;

        return ImageFormat.Unknown;
    }

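    protected static bool ContainsHeader(byte[] buffer, byte[] header)
    {
        // A buffer shorter than the header cannot match, and indexing
        // past its end would throw an IndexOutOfRangeException.
        if (buffer == null || buffer.Length < header.Length)
        {
            return false;
        }

        for (int i = 0; i < header.Length; i += 1)
        {
            if (header[i] != buffer[i])
            {
                return false;
            }
        }

        return true;
    }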
}

The problem with this approach is that every time we need to support a new image format, we have to modify this class. This violates the Open/Closed Principle, which says that a class should be open for extension but closed for modification.

A better approach is to implement each decoder as its own class and chain the decoders together using the Chain of Responsibility pattern.
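Without the image-specific details, the pattern is a sequence of handlers in which each one either handles a request or forwards it to the next handler. Here is a minimal generic sketch. The Handler<TRequest, TResult> and PrefixHandler classes are hypothetical and appear nowhere else in this post.

```csharp
using System;

// Hypothetical generic Chain of Responsibility: each link either
// produces a result or passes the request to the next link.
public abstract class Handler<TRequest, TResult>
{
    private readonly Handler<TRequest, TResult> _next; // null marks the end of the chain

    protected Handler(Handler<TRequest, TResult> next)
    {
        _next = next;
    }

    public TResult Handle(TRequest request, TResult fallback)
    {
        TResult result;
        if (TryHandle(request, out result))
        {
            return result; // this link dealt with the request
        }

        // Not ours: forward it, or return the fallback at the end of the chain.
        return _next != null ? _next.Handle(request, fallback) : fallback;
    }

    protected abstract bool TryHandle(TRequest request, out TResult result);
}

// Example link: recognises strings that start with a fixed prefix.
public sealed class PrefixHandler : Handler<string, string>
{
    private readonly string _prefix;

    public PrefixHandler(string prefix, Handler<string, string> next) : base(next)
    {
        _prefix = prefix;
    }

    protected override bool TryHandle(string request, out string result)
    {
        bool matches = request.StartsWith(_prefix, StringComparison.Ordinal);
        result = matches ? _prefix : string.Empty;
        return matches;
    }
}
```

Calling `Handle` on the first link walks the chain until some link's `TryHandle` succeeds; if none does, the caller's fallback is returned. The image decoders follow the same shape, with `ImageFormat.Unknown` playing the role of the fallback.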

To achieve this, we first need to implement the decoders. The decoder interface looks like this:

public interface IImageDecoder
{
    ImageFormat DecodeImage(byte[] buffer);
}

Since the decoders are very similar, we can extract a base class that implements the common logic:

public abstract class BaseDecoder : IImageDecoder
{
    private ImageFormat _decodingFormat;

    protected BaseDecoder(ImageFormat decodingFormat)
    {
        _decodingFormat = decodingFormat;
    }

    protected abstract byte[] Header { get; }

    public ImageFormat DecodeImage(byte[] buffer)
    {
        if(ContainsHeader(buffer, Header))
        {
            return _decodingFormat;
        }

        return ImageFormat.Unknown;
    }

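    private static bool ContainsHeader(byte[] buffer, byte[] header)
    {
        // A buffer shorter than the header cannot match.
        if (buffer == null || buffer.Length < header.Length)
        {
            return false;
        }

        for (int i = 0; i < header.Length; i += 1)
        {
            if (header[i] != buffer[i])
            {
                return false;
            }
        }

        return true;
    }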
}

Our decoders now look like this:

public sealed class JpegDecoder : BaseDecoder, IImageDecoder
{
    public JpegDecoder() : base(ImageFormat.Jpeg)
    { }

    protected override byte[] Header
    {
        get { return new byte[] { 0xff, 0xd8 }; }
    }
}

public sealed class BmpDecoder : BaseDecoder, IImageDecoder
{
    public BmpDecoder() : base(ImageFormat.Bmp)
    { }

    protected override byte[] Header
    {
        get { return new byte[] { 0x42, 0x4D }; } // BMP header ("BM")
    }
}

public sealed class PngDecoder : BaseDecoder, IImageDecoder
{
    public PngDecoder() : base(ImageFormat.Png)
    { }

    protected override byte[] Header
    {
        get { return new byte[] { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A }; }
    }
}

The last decoder simply returns ImageFormat.Unknown. Its implementation looks like this:

public class UnknownImageDecoder : IImageDecoder
{
    public ImageFormat DecodeImage(byte[] buffer)
    {
        return ImageFormat.Unknown;
    }
}

The next step is to refactor the base class so that it supports chaining. The refactored class looks like this:

using System;

public abstract class BaseDecoder : IImageDecoder
{
    private readonly ImageFormat _decodingFormat;
    private readonly IImageDecoder _nextChain;

    protected BaseDecoder(ImageFormat decodingFormat)
    {
        _decodingFormat = decodingFormat;
    }

    protected BaseDecoder(IImageDecoder nextChain, ImageFormat decodingFormat) : this(decodingFormat)
    {
        if (nextChain == null)
        {
            throw new ArgumentNullException("nextChain");
        }

        _nextChain = nextChain;
    }

    protected abstract byte[] Header { get; }

    public ImageFormat DecodeImage(byte[] buffer)
    {
        if (ContainsHeader(buffer, Header))
        {
            return _decodingFormat;
        }

        if (_nextChain != null)
        {
            return _nextChain.DecodeImage(buffer);
        }

        return ImageFormat.Unknown;
    }

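    private static bool ContainsHeader(byte[] buffer, byte[] header)
    {
        // A buffer shorter than the header cannot match.
        if (buffer == null || buffer.Length < header.Length)
        {
            return false;
        }

        for (int i = 0; i < header.Length; i += 1)
        {
            if (header[i] != buffer[i])
            {
                return false;
            }
        }

        return true;
    }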
}

As you can see, there are now two constructors: one that takes only the ImageFormat, and another that takes an IImageDecoder (the next link in the chain) plus the ImageFormat. The one-parameter constructor lets a decoder be used on its own, whereas the two-parameter constructor is used to build the chain.

Pay attention to the DecodeImage(...) method. If the decoder does not recognize the header and a next link has been specified, it passes the responsibility on to that next link.

We also need to add the second constructor to each of our decoders:

public sealed class BmpDecoder : BaseDecoder
{
    public BmpDecoder() 
        : base(ImageFormat.Bmp)
    { }

    public BmpDecoder(IImageDecoder nextChain) 
        : base(nextChain, ImageFormat.Bmp)
    { } 

    protected override byte[] Header
    {
        get { return new byte[] { 0x42, 0x4D }; } // BMP header ("BM")
    }
}

public sealed class PngDecoder : BaseDecoder
{
    public PngDecoder() 
        : base(ImageFormat.Png)
    { }

    public PngDecoder(IImageDecoder nextChain) 
        : base(nextChain, ImageFormat.Png)
    { }            

    protected override byte[] Header
    {
        get { return new byte[] { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A }; }
    }
}

public sealed class JpegDecoder : BaseDecoder
{
    public JpegDecoder()
        : base(ImageFormat.Jpeg)
    { }

    public JpegDecoder(IImageDecoder nextChain)
        : base(nextChain, ImageFormat.Jpeg)
    { } 

    protected override byte[] Header
    {
        get { return new byte[] { 0xff, 0xd8 }; }
    }
}
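With both constructors in place, the delegation can be exercised directly. The sketch below is self-contained: it repeats the decoder types in compact form, with the BMP header set to 0x42, 0x4D and a length check in ContainsHeader, and it is meant to be driven with small hand-made byte arrays rather than real image files. For a PNG buffer, BmpDecoder and JpegDecoder find no match and forward it, and PngDecoder recognizes it. A decoder built with the one-parameter constructor has no next link, so it returns ImageFormat.Unknown for anything it does not recognize.

```csharp
using System;

public enum ImageFormat { Unknown, Bmp, Png, Jpeg }

public interface IImageDecoder
{
    ImageFormat DecodeImage(byte[] buffer);
}

public abstract class BaseDecoder : IImageDecoder
{
    private readonly ImageFormat _decodingFormat;
    private readonly IImageDecoder _nextChain; // stays null for a standalone decoder

    protected BaseDecoder(ImageFormat decodingFormat)
    {
        _decodingFormat = decodingFormat;
    }

    protected BaseDecoder(IImageDecoder nextChain, ImageFormat decodingFormat) : this(decodingFormat)
    {
        if (nextChain == null) throw new ArgumentNullException("nextChain");
        _nextChain = nextChain;
    }

    protected abstract byte[] Header { get; }

    public ImageFormat DecodeImage(byte[] buffer)
    {
        if (ContainsHeader(buffer, Header)) return _decodingFormat;     // this link recognizes it
        if (_nextChain != null) return _nextChain.DecodeImage(buffer);  // forward to the next link
        return ImageFormat.Unknown;                                     // standalone, no match
    }

    private static bool ContainsHeader(byte[] buffer, byte[] header)
    {
        if (buffer == null || buffer.Length < header.Length) return false;
        for (int i = 0; i < header.Length; i++)
        {
            if (header[i] != buffer[i]) return false;
        }
        return true;
    }
}

public sealed class BmpDecoder : BaseDecoder
{
    public BmpDecoder() : base(ImageFormat.Bmp) { }
    public BmpDecoder(IImageDecoder nextChain) : base(nextChain, ImageFormat.Bmp) { }
    protected override byte[] Header { get { return new byte[] { 0x42, 0x4D }; } }
}

public sealed class JpegDecoder : BaseDecoder
{
    public JpegDecoder() : base(ImageFormat.Jpeg) { }
    public JpegDecoder(IImageDecoder nextChain) : base(nextChain, ImageFormat.Jpeg) { }
    protected override byte[] Header { get { return new byte[] { 0xff, 0xd8 }; } }
}

public sealed class PngDecoder : BaseDecoder
{
    public PngDecoder() : base(ImageFormat.Png) { }
    public PngDecoder(IImageDecoder nextChain) : base(nextChain, ImageFormat.Png) { }
    protected override byte[] Header
    {
        get { return new byte[] { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A }; }
    }
}

// Terminal link: always reports an unknown format.
public sealed class UnknownImageDecoder : IImageDecoder
{
    public ImageFormat DecodeImage(byte[] buffer) { return ImageFormat.Unknown; }
}
```

`new BmpDecoder(new JpegDecoder(new PngDecoder(new UnknownImageDecoder())))` builds a chain that tries BMP, then JPEG, then PNG, and finally falls through to the unknown decoder; given the eight PNG signature bytes it returns `ImageFormat.Png` after two forwards.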

To assemble the chain, we need a factory that builds it and returns its first decoder. The interface for the factory looks like this:

public interface IImageDecoderFactory
{
    IImageDecoder Create();
}

And the implementation looks like:

public class ImageDecoderFactory : IImageDecoderFactory
{
    public IImageDecoder Create()
    {
        return new BmpDecoder(new JpegDecoder(new PngDecoder(new UnknownImageDecoder())));
    }
}

Now our ImageDecodingService looks like:

public class ImageDecodingService : IImageDecodingService
{
    private readonly IImageDecoderFactory _imageDecoderFactory;

    public ImageDecodingService(IImageDecoderFactory imageDecoderFactory)
    {
        _imageDecoderFactory = imageDecoderFactory;
    }

    public ImageFormat DecodeImage(byte[] imageBuffer)
    {
        var decoder = _imageDecoderFactory.Create();
        return decoder.DecodeImage(imageBuffer);
    }
}

So, if we need to support another format, we implement a decoder for it and add it to the factory. In a real-world application you would register the decoders with a DI container, which would pass them to the factory, and the factory would chain them together. That way, apart from adding the new value to ImageFormat, supporting another format requires no changes to existing code.
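One way to let a DI container feed the factory is to register one "link builder" per format: a function that takes the next decoder and returns a new decoder wrapping it. The factory then folds those builders around an UnknownImageDecoder. Below is a minimal sketch of that idea. The ChainingImageDecoderFactory class, which receives an IEnumerable<Func<IImageDecoder, IImageDecoder>>, is hypothetical. The decoder types are trimmed to the chaining constructor, and container registration is left out because it differs from container to container.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public enum ImageFormat { Unknown, Bmp, Png, Jpeg }

public interface IImageDecoder { ImageFormat DecodeImage(byte[] buffer); }

public interface IImageDecoderFactory { IImageDecoder Create(); }

// Trimmed base class: only the chaining constructor is kept.
public abstract class BaseDecoder : IImageDecoder
{
    private readonly ImageFormat _format;
    private readonly IImageDecoder _next;

    protected BaseDecoder(IImageDecoder next, ImageFormat format)
    {
        if (next == null) throw new ArgumentNullException("next");
        _next = next;
        _format = format;
    }

    protected abstract byte[] Header { get; }

    public ImageFormat DecodeImage(byte[] buffer)
    {
        byte[] header = Header;
        bool matches = buffer != null
            && buffer.Length >= header.Length
            && buffer.Take(header.Length).SequenceEqual(header);
        return matches ? _format : _next.DecodeImage(buffer); // otherwise forward
    }
}

public sealed class BmpDecoder : BaseDecoder
{
    public BmpDecoder(IImageDecoder next) : base(next, ImageFormat.Bmp) { }
    protected override byte[] Header { get { return new byte[] { 0x42, 0x4D }; } }
}

public sealed class JpegDecoder : BaseDecoder
{
    public JpegDecoder(IImageDecoder next) : base(next, ImageFormat.Jpeg) { }
    protected override byte[] Header { get { return new byte[] { 0xff, 0xd8 }; } }
}

public sealed class PngDecoder : BaseDecoder
{
    public PngDecoder(IImageDecoder next) : base(next, ImageFormat.Png) { }
    protected override byte[] Header
    {
        get { return new byte[] { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A }; }
    }
}

public sealed class UnknownImageDecoder : IImageDecoder
{
    public ImageFormat DecodeImage(byte[] buffer) { return ImageFormat.Unknown; }
}

// Hypothetical factory: receives one link builder per format (e.g. from a DI container).
public sealed class ChainingImageDecoderFactory : IImageDecoderFactory
{
    private readonly Func<IImageDecoder, IImageDecoder>[] _links;

    public ChainingImageDecoderFactory(IEnumerable<Func<IImageDecoder, IImageDecoder>> links)
    {
        if (links == null) throw new ArgumentNullException("links");
        _links = links.ToArray();
    }

    public IImageDecoder Create()
    {
        // Start from the terminal decoder and wrap it from the last builder
        // to the first, so the first builder's decoder ends up at the head.
        IImageDecoder head = new UnknownImageDecoder();
        for (int i = _links.Length - 1; i >= 0; i--)
        {
            head = _links[i](head);
        }
        return head;
    }
}
```

With the builders ordered BMP, JPEG, PNG, `Create()` returns the same chain as the hand-written ImageDecoderFactory. Supporting a new format then means writing its decoder class and registering one more builder; the factory and the existing decoders stay untouched.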