2008-05-15

Refactoring a big if block into a simple command processor using attributes

Recently someone had a problem where they had a massive control block full of if statements, each looking at a string and dispatching to one of a variety of functions. The if block was massive: hundreds of if statements, hundreds of magic strings.

Interestingly, all the functions had the same signature... So I gave him this example of how to use attributes on the methods to specify the corresponding token. We then use Reflection to scan the assembly for all the functions with that attribute and create a function table keyed by their token, to provide fast lookup. This example shows how to create an object instance and then invoke the method via reflection, but this could be made much simpler if the methods were all static and the function prototype was part of an interface instead of just an unspoken convention.

Here's the "Before" example from the original question...


// Values handed unchanged to whichever handler is selected below.
string tag;
string cmdLine;
State state;
// Passed by reference to every handler call, so the handler can assign it.
string outData;

...

// Dispatch on the token string: each branch constructs a new handler object
// and calls one method on it with the same four arguments.
if (token == "ABCSearch") {
ABC abc = new ABC();
abc.SearchFor(tag, state, cmdLine, ref outData);
}
else if (token == "JklmDoSomething") {
JKLM jklm = new JKLM();
jklm.Dowork1(tag, state, cmdLine, ref outData);
}


A couple of notes:

  • There is no correlation between the token and the class name (ABC, JKLM, ...) or the method (SearchFor, Dowork1).
  • The methods do have the same signature:
    void func(string tag, State state, string cmdLine, ref string outData)
  • The if ()... block is 500+ lines and growing



And here is my example command processor (as a console app):


using System;
using System.Collections.Generic;
using System.Reflection;

namespace ConsoleApplication2
{
/// <summary>
/// Interactive console front end for <see cref="CommandProcessor"/>.
/// </summary>
public class Program
{
    /// <summary>
    /// Prompts for a single key in a loop: "x" executes a command, "t" lists
    /// the known tokens, "q" exits. Any other key reports an unknown command.
    /// </summary>
    static void Main(string[] args)
    {
        for (;;)
        {
            Console.Write("[e(x)ecute, (t)okens, (q)uit] -> ");
            string choice = Console.ReadKey().KeyChar.ToString().ToLower();
            Console.WriteLine();

            if (choice == "q")
            {
                Console.WriteLine("Finished.");
                return;
            }

            if (choice == "t")
            {
                ListTokens();
            }
            else if (choice == "x")
            {
                ExecuteCommand();
            }
            else
            {
                Console.WriteLine("Unknown command: {0}", choice);
            }
        }
    }

    // Prints every token the command processor reports, one per line.
    private static void ListTokens()
    {
        Console.WriteLine("Known tokens:");
        foreach (string tokenName in CommandProcessor.GetTokens())
        {
            Console.WriteLine(tokenName);
        }
    }

    // Reads the four command inputs from the console, runs the command and
    // prints either its output or the failure message.
    private static void ExecuteCommand()
    {
        Console.Write("token: ");
        string token = Console.ReadLine();
        Console.Write("tag: ");
        string tag = Console.ReadLine();
        Console.Write("cmdLine: ");
        string cmdLine = Console.ReadLine();
        Console.Write("state: ");
        string state = Console.ReadLine();

        try
        {
            string output = CommandProcessor.DoCommand(token, tag, cmdLine, State.GetStateFromString(state));
            Console.WriteLine("Output:");
            Console.WriteLine(output);
        }
        catch (TokenNotFoundException ex)
        {
            // An unknown token is an expected user error: show its message as-is.
            Console.WriteLine(ex.Message);
        }
        catch (Exception ex)
        {
            Console.WriteLine("Unknown error occured during execution. Exception was: " + ex.Message);
        }
    }
}

/// <summary>
/// Dispatches string tokens to the methods marked with <see cref="TokenAttribute"/>.
/// The token-to-method table is built once, by reflection over the executing
/// assembly, when the static constructor runs.
/// </summary>
public class CommandProcessor
{
    // Token name -> method registered for it. Filled by SetupMethodCallDictionary.
    internal static Dictionary<string, MethodInfo> availableFunctions = new Dictionary<string, MethodInfo>();

    static CommandProcessor()
    {
        SetupMethodCallDictionary();
    }

    /// <summary>
    /// Scans every type in the executing assembly and records each method that
    /// carries one or more <see cref="TokenAttribute"/>s under each of its token names.
    /// </summary>
    private static void SetupMethodCallDictionary()
    {
        Assembly assembly = Assembly.GetExecutingAssembly();

        foreach (Type type in assembly.GetTypes())
        {
            foreach (MethodInfo method in type.GetMethods())
            {
                // A method may carry several Token attributes (leaves room for
                // backwards compatibility if tokens are renamed or methods are
                // consolidated); an empty array simply skips the inner loop.
                object[] tokens = method.GetCustomAttributes(typeof(TokenAttribute), true);

                foreach (TokenAttribute token in tokens)
                {
                    // Indexer assignment adds or overwrites: if more than one
                    // method is registered for the same token, the last one
                    // found wins.
                    availableFunctions[token.TokenName] = method;
                }
            }
        }
    }

    /// <summary>
    /// Runs the method registered for <paramref name="token"/> and returns the
    /// text it assigned to its <c>ref outData</c> parameter.
    /// </summary>
    /// <param name="token">Token name identifying the method to run.</param>
    /// <param name="tag">Passed through as the method's first argument.</param>
    /// <param name="cmdLine">Passed through as the method's third argument.</param>
    /// <param name="state">Passed through as the method's second argument.</param>
    /// <returns>The value of the method's <c>outData</c> argument after the call.</returns>
    /// <exception cref="TokenNotFoundException">No method is registered for <paramref name="token"/>.</exception>
    /// <remarks>
    /// Per the <c>MethodBase.Invoke</c> documentation, an exception thrown by the
    /// invoked method surfaces wrapped in a <c>TargetInvocationException</c>.
    /// </remarks>
    public static string DoCommand(string token, string tag, string cmdLine, State state)
    {
        MethodInfo method;
        if (!availableFunctions.TryGetValue(token, out method))
        {
            throw new TokenNotFoundException(string.Format("Token {0} not found. Cannot execute.", token));
        }

        // Instance methods need a target object; this uses the default
        // constructor. Use another CreateInstance overload if the type needs
        // constructor arguments. Static methods are invoked with a null target.
        object instance = null;
        if (!method.IsStatic)
        {
            instance = Activator.CreateInstance(method.ReflectedType);
        }

        // Argument order matches the handler signature (tag, state, cmdLine, ref outData).
        // The last slot is the ref parameter: after Invoke it holds whatever the
        // method assigned to outData.
        object[] args = new object[] { tag, state, cmdLine, string.Empty };
        method.Invoke(instance, args);

        return (string)args[3];
    }

    /// <summary>
    /// Lazily enumerates the token names currently registered.
    /// </summary>
    public static IEnumerable<string> GetTokens()
    {
        foreach (KeyValuePair<string, MethodInfo> entry in availableFunctions)
        {
            yield return entry.Key;
        }
    }
}

/// <summary>
/// Carries the state text that is handed to a command method.
/// </summary>
public class State
{
    // Backing store for Text.
    private string _text;

    /// <summary>Creates a state holding <paramref name="text"/>.</summary>
    public State(string text)
    {
        this.Text = text;
    }

    /// <summary>The raw state text.</summary>
    public string Text
    {
        get { return this._text; }
        set { this._text = value; }
    }

    /// <summary>
    /// Builds a <see cref="State"/> from its string form. Currently the string
    /// is stored as-is.
    /// </summary>
    public static State GetStateFromString(string state)
    {
        // implement parsing of string to build State object here.
        State parsed = new State(state);
        return parsed;
    }
}

/// <summary>
/// Marks a method as the handler for a command token.
/// </summary>
/// <remarks>
/// AllowMultiple is enabled so one method can be tagged with several token
/// names; without it, applying the attribute twice to the same method is a
/// compile-time error.
/// </remarks>
[AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
public class TokenAttribute : Attribute
{
    /// <summary>Creates the attribute for the given token name.</summary>
    /// <param name="tokenName">The token string the marked method handles.</param>
    public TokenAttribute(string tokenName)
    {
        _tokenName = tokenName;
    }

    private string _tokenName;

    /// <summary>The token string the marked method handles.</summary>
    public string TokenName
    {
        get { return _tokenName; }
        set { _tokenName = value; }
    }
}

/// <summary>
/// Thrown when a command token has no method registered for it.
/// Provides the standard set of exception constructors.
/// </summary>
[Serializable]
public class TokenNotFoundException : Exception
{
    /// <summary>Creates the exception with the default message.</summary>
    public TokenNotFoundException()
    {
    }

    /// <summary>Creates the exception with a specific message.</summary>
    public TokenNotFoundException(string message)
        : base(message)
    {
    }

    /// <summary>Creates the exception with a message and the exception that caused it.</summary>
    public TokenNotFoundException(string message, Exception inner)
        : base(message, inner)
    {
    }

    /// <summary>Serialization constructor; forwards both arguments to the base class.</summary>
    protected TokenNotFoundException(
        System.Runtime.Serialization.SerializationInfo info,
        System.Runtime.Serialization.StreamingContext context)
        : base(info, context)
    {
    }
}

/// <summary>
/// Sample command class; its method is registered under the "ABCSearch" token.
/// </summary>
public class ABC
{
    /// <summary>
    /// Writes a description of the call and its arguments into <paramref name="outData"/>.
    /// </summary>
    [Token("ABCSearch")]
    public void SearchFor(string tag, State state, string cmdLine, ref string outData)
    {
        // do some stuff.
        const string reportFormat = "You called ABC.SearchFor. Parameters were [tag: {0}, state: {1}, cmdLine: {2}]";
        string stateText = state.Text;
        outData = string.Format(reportFormat, tag, stateText, cmdLine);
    }
}

/// <summary>
/// Sample command class; its method is registered under the "JklmDoSomething" token.
/// </summary>
public class JKLM
{
    /// <summary>
    /// Writes a description of the call and its arguments into <paramref name="outData"/>.
    /// </summary>
    [Token("JklmDoSomething")]
    public void Dowork1(string tag, State state, string cmdLine, ref string outData)
    {
        // do some other stuff.
        const string reportFormat = "You called JKLM.Dowork1. Parameters were [tag: {0}, state: {1}, cmdLine: {2}]";
        string stateText = state.Text;
        outData = string.Format(reportFormat, tag, stateText, cmdLine);
    }
}
}

How to get information about your current culture.

Instead of doing a college survey and asking a bunch of probing questions about the lives of twenty-somethings, there's an easier way to get information about your current culture. Just look at CultureInfo.CurrentCulture.

Here's a quick program that shows how to do that. This can be very useful in debugging and troubleshooting how your program behaves on machines that are set up for other languages or regions.


using System;
using System.Collections.Generic;
using System.Text;
using System.Globalization;

namespace ConsoleApplication1
{
/// <summary>
/// Prints a short report about the current culture to the console.
/// </summary>
class Program
{
    /// <summary>
    /// Writes the display name, name and LCID of the current culture, then
    /// its decimal separator and native digits.
    /// </summary>
    static void Main(string[] args)
    {
        CultureInfo culture = CultureInfo.CurrentCulture;

        // General identification of the culture.
        Console.WriteLine("CultureInfo");
        Console.WriteLine("-----------");
        Console.WriteLine("DisplayName: {0}", culture.DisplayName);
        Console.WriteLine("Name: {0}", culture.Name);
        Console.WriteLine("LCID: {0}", culture.LCID);
        Console.WriteLine();

        // Number formatting details.
        Console.WriteLine("NumberFormatInfo");
        Console.WriteLine("----------------");
        NumberFormatInfo numberFormat = culture.NumberFormat;
        Console.WriteLine("Decimal Seperator: {0}", numberFormat.NumberDecimalSeparator);
        Console.Write("Digits: ");

        // Each digit is followed by a single space, including the last one.
        string[] digits = numberFormat.NativeDigits;
        for (int i = 0; i < digits.Length; i++)
        {
            Console.Write(digits[i] + " ");
        }

        Console.WriteLine();
    }
}
}




Base output should look like:


CultureInfo
-----------
DisplayName: English (United States)
Name: en-US
LCID: 1033



NumberFormatInfo
----------------
Decimal Seperator: .
Digits: 0 1 2 3 4 5 6 7 8 9

Filtering a network stream using a wrapper

So not that long ago, someone posted a question asking how to deal with a certain situation. The situation is such that there is a network file stream coming from somewhere that has certain data you want to keep and certain data you don't want to keep: control blocks, extra header information, weirdo protocol, too much data coming back from an API, etc.

My suggestion was to create a simple container object (aka wrapper) to the existing network stream, that operates the same as the network stream, but does the necessary filtering.

Here's an example of how you'd use it, followed by an example base class implementation for the filters. In the actual problem case, he was dealing with a NetworkStream that contained Xml data in irregular chunks, with control blocks as fixed headers. Each header indicates how much Xml data follows. The filter will remove the headers as needed, presenting a simple stream of Xml data to the XmlReader to parse.

I've left out the concrete implementation that actually parses the stream, and here you just have the FilteredNetworkStream base class and an idea of how to use it once you implement it. All that's left for the implementer is to override the abstract method FilterBeforeRead, which contains the customized filtering logic for the particular situation.



// Copies the wanted parts of a filtered network stream into a local file.
using (NetworkStream inputStream = GetNetworkStreamFromSomewhere())
using (StreamWriter outputStream = new StreamWriter(@"C:\Path\To\File.xml", false))
// The reader is disposed together with the streams instead of being left open.
// NOTE(review): FilteredNetworkStream is declared abstract, so real code has
// to construct a concrete subclass that implements FilterBeforeRead here.
using (XmlReader reader = XmlReader.Create(new FilteredNetworkStream(inputStream)))
{
    while (reader.Read())
    {
        // method returns empty string if current data is discardable
        string outputData = GetDesiredDataFromReader(reader);

        if (!string.IsNullOrEmpty(outputData))
        {
            // save desired data to local file
            outputStream.Write(outputData);
        }
    }
}


Here's the base class:


/// <summary>
/// Stream wrapper around a <see cref="NetworkStream"/> that gives a subclass
/// the chance to filter the underlying data before each read. Every other
/// member forwards directly to the wrapped stream.
/// </summary>
public abstract class FilteredNetworkStream : Stream
{
    /// <summary>Wraps <paramref name="baseStream"/>.</summary>
    /// <param name="baseStream">The stream all calls are forwarded to.</param>
    /// <exception cref="System.ArgumentNullException"><paramref name="baseStream"/> is null.</exception>
    public FilteredNetworkStream(NetworkStream baseStream)
    {
        // Fail at construction time instead of with a NullReferenceException
        // on the first forwarded call.
        if (baseStream == null)
        {
            throw new System.ArgumentNullException("baseStream");
        }

        _baseStream = baseStream;
    }

    /// <summary>The wrapped stream; available to subclasses for their filtering logic.</summary>
    protected NetworkStream _baseStream;

    /// <summary>
    /// Called at the start of every <see cref="Read"/>, before the read is
    /// forwarded to the wrapped stream. Implementations put their filtering
    /// logic here.
    /// </summary>
    public abstract void FilterBeforeRead();

    #region Stream Implementation

    /// <summary>Forwards to the wrapped stream.</summary>
    public override bool CanRead
    {
        get { return _baseStream.CanRead; }
    }

    /// <summary>Forwards to the wrapped stream.</summary>
    public override bool CanSeek
    {
        get { return _baseStream.CanSeek; }
    }

    /// <summary>Forwards to the wrapped stream.</summary>
    public override bool CanWrite
    {
        get { return _baseStream.CanWrite; }
    }

    /// <summary>Forwards to the wrapped stream.</summary>
    public override void Flush()
    {
        _baseStream.Flush();
    }

    /// <summary>Forwards to the wrapped stream.</summary>
    public override long Length
    {
        get { return _baseStream.Length; }
    }

    /// <summary>Forwards to the wrapped stream.</summary>
    public override long Position
    {
        get
        {
            return _baseStream.Position;
        }
        set
        {
            _baseStream.Position = value;
        }
    }

    /// <summary>
    /// Runs <see cref="FilterBeforeRead"/> and then reads from the wrapped stream.
    /// </summary>
    /// <returns>The value returned by the wrapped stream's Read.</returns>
    public override int Read(byte[] buffer, int offset, int count)
    {
        this.FilterBeforeRead();
        return _baseStream.Read(buffer, offset, count);
    }

    /// <summary>Forwards to the wrapped stream.</summary>
    public override long Seek(long offset, SeekOrigin origin)
    {
        return _baseStream.Seek(offset, origin);
    }

    /// <summary>Forwards to the wrapped stream.</summary>
    public override void SetLength(long value)
    {
        _baseStream.SetLength(value);
    }

    /// <summary>Forwards to the wrapped stream; writes are not filtered.</summary>
    public override void Write(byte[] buffer, int offset, int count)
    {
        _baseStream.Write(buffer, offset, count);
    }

    #endregion
}