// SimpleID3 - An implementation of the ID3 decision-tree algorithm
// (c) Devendra Laulkar <devendralaulkar@yahoo.com>
// This file is released under GNU GPL v 2.0
// Some parts of the code in this file are taken from the Weka Project
// which is also licensed under the GNU GPL v 2.0

package weka.classifiers.trees;
// Place this file in weka/classifiers/trees
// To run this file, use it like any other classifier

import java.util.*;

import org.omg.PortableInterceptor.SUCCESSFUL;

import weka.classifiers.Classifier;
import weka.core.*;

// Each instance of this class is itself a node in the tree;
// its child nodes (branches) are created from it.
public class SimpleID3 extends Classifier {

	/** Child nodes, one per value of m_Attribute; null when this node is a leaf. */
	private SimpleID3[] m_SubNodes;

	/** The attribute tested at this node; null when this node is a leaf. */
	private Attribute m_Attribute;

	/** Parent node; null for the root. */
	private SimpleID3 m_Parent;

	/**
	 * Class index predicted by this node. At a leaf this is the prediction.
	 * At an internal node it is the majority class of the training examples
	 * that reached the node. It is returned when the tested attribute is
	 * missing and passed down as the default for empty child subsets.
	 */
	private double m_leafValue;

	/** Class attribute of the training data; used to print class labels. */
	private Attribute m_ClassAttribute;

	/**
	 * Builds the ID3 tree from the given training data.
	 *
	 * @param data training instances; the class must be set and nominal, and
	 *             every other attribute must be nominal with no missing values
	 * @throws IllegalArgumentException if the class or an attribute is not
	 *             nominal, or an attribute value is missing
	 * @throws Exception propagated from Weka's Instances methods (presumably
	 *             when no class attribute is assigned -- TODO confirm)
	 */
	public void buildClassifier(Instances data) throws Exception {
		// Work on a copy so the caller's dataset is not modified
		Instances localData = new Instances(data);
		// Instances without a class label carry no information for ID3
		localData.deleteWithMissingClass();
		checkTrainingData(localData);
		createTree(localData);
	}

	/**
	 * Rejects data this ID3 implementation cannot handle.
	 *
	 * Without this check, numeric values and missing values (NaN) would be
	 * silently truncated into branch indices by the (int) cast in splitData.
	 */
	private static void checkTrainingData(Instances data) {
		if (!data.classAttribute().isNominal()) {
			throw new IllegalArgumentException("SimpleID3: class attribute must be nominal");
		}
		for (int j = 0; j < data.numAttributes(); j++) {
			if (j == data.classIndex()) {
				continue;
			}
			Attribute attr = data.attribute(j);
			if (!attr.isNominal()) {
				throw new IllegalArgumentException(
					"SimpleID3: attribute '" + attr.name() + "' is not nominal");
			}
			for (int i = 0; i < data.numInstances(); i++) {
				if (data.instance(i).isMissing(j)) {
					throw new IllegalArgumentException(
						"SimpleID3: attribute '" + attr.name() + "' has missing values");
				}
			}
		}
	}

	/**
	 * Grows the subtree rooted at this node from the given examples.
	 *
	 * If data is empty, there is no majority class to fall back on, so this
	 * node becomes a leaf predicting class index 0.
	 *
	 * @param data examples reaching this node (class attribute must be set)
	 */
	public void createTree(Instances data) {
		createTree(data, 0.0);
	}

	/**
	 * Grows the subtree rooted at this node.
	 *
	 * @param data         examples reaching this node
	 * @param defaultClass class index to predict if data is empty
	 *                     (the parent's majority class)
	 */
	private void createTree(Instances data, double defaultClass) {
		m_ClassAttribute = data.classAttribute();
		m_Attribute = null;
		m_SubNodes = null;

		// No examples reached this node: fall back to the parent's majority class
		if (data.numInstances() == 0) {
			m_leafValue = defaultClass;
			return;
		}

		double[] counts = classCounts(data);
		// Majority class (ties go to the lowest class index)
		m_leafValue = Utils.maxIndex(counts);

		// All examples share one class: this is a pure leaf
		if (counts[(int) m_leafValue] == data.numInstances()) {
			return;
		}

		Attribute best = maxGainAttr(data);
		// No attribute improves purity: stay a majority-class leaf
		if (best == null) {
			return;
		}

		m_Attribute = best;
		Instances[] subsets = splitData(data, best);
		m_SubNodes = new SimpleID3[best.numValues()];
		for (int i = 0; i < m_SubNodes.length; i++) {
			m_SubNodes[i] = new SimpleID3();
			m_SubNodes[i].m_Parent = this;
			m_SubNodes[i].createTree(subsets[i], m_leafValue);
		}
	}

	/**
	 * Returns the attribute from data.enumerateAttributes() with the highest
	 * information gain.
	 *
	 * Returns null when no attribute's gain exceeds zero by more than Weka's
	 * comparison tolerance (Utils.gr). Ties keep the first attribute seen.
	 */
	Attribute maxGainAttr(Instances data) {
		Attribute best = null;
		double bestGain = 0.0;

		// Weka documents enumerateAttributes() as skipping the class attribute
		for (Enumeration a = data.enumerateAttributes(); a.hasMoreElements();) {
			Attribute attr = (Attribute) a.nextElement();
			double gain = calcGain(attr, data);
			// Tolerant comparison: a split whose true gain is zero is not chosen
			// just because rounding made it a hair above zero
			if (Utils.gr(gain, bestGain)) {
				bestGain = gain;
				best = attr;
			}
		}
		return best;
	}

	/**
	 * Information gain of splitting data on attr:
	 * Entropy(S) - sum over values v of |S_v| / |S| * Entropy(S_v).
	 *
	 * @return the gain, or 0 for an empty dataset
	 */
	double calcGain(Attribute attr, Instances data) {
		int total = data.numInstances();
		if (total == 0) {
			return 0.0;
		}

		double gain = calcEntropy(data);
		Instances[] subsets = splitData(data, attr);
		for (int i = 0; i < subsets.length; i++) {
			gain -= (subsets[i].numInstances() / (double) total) * calcEntropy(subsets[i]);
		}
		return gain;
	}

	/**
	 * Partitions data by the value of a nominal attribute.
	 *
	 * @return an array indexed by attribute value; element v holds the
	 *         instances whose value of attr is v (possibly empty)
	 */
	Instances[] splitData(Instances data, Attribute attr) {
		Instances[] subsets = new Instances[attr.numValues()];

		// Each subset shares the parent's header, presized to the parent's size
		for (int i = 0; i < subsets.length; i++) {
			subsets[i] = new Instances(data, data.numInstances());
		}

		for (Enumeration e = data.enumerateInstances(); e.hasMoreElements();) {
			Instance instance = (Instance) e.nextElement();
			subsets[(int) instance.value(attr)].add(instance);
		}
		return subsets;
	}

	/** Counts instances per class index (instances' class values used as indices). */
	private static double[] classCounts(Instances data) {
		double[] counts = new double[data.numClasses()];
		for (Enumeration e = data.enumerateInstances(); e.hasMoreElements();) {
			Instance instance = (Instance) e.nextElement();
			counts[(int) instance.classValue()]++;
		}
		return counts;
	}

	/**
	 * Shannon entropy (base 2) of the class distribution in data:
	 * -sum over classes c of p_c * log2(p_c), with 0 log 0 defined as 0.
	 *
	 * @return the entropy, or 0 for an empty or pure dataset
	 */
	double calcEntropy(Instances data) {
		double[] counts = classCounts(data);
		double entropy = 0;

		for (int i = counts.length - 1; i >= 0; i--) {
			// Define 0 log 0 as 0
			if (counts[i] == 0) {
				continue;
			}
			entropy -= counts[i] * Utils.log2(counts[i] / data.numInstances());
		}

		// Also avoids 0 / 0 for an empty dataset
		if (entropy == 0) {
			return 0.0;
		}
		return entropy / data.numInstances();
	}

	/** Command-line entry point, using Weka's standard classifier runner. */
	public static void main(String[] args) {
		runClassifier(new SimpleID3(), args);
	}

	/**
	 * Classifies an instance by following the branches that match its
	 * attribute values down to a leaf.
	 *
	 * If the attribute tested at a node is missing in the instance, that
	 * node's majority training class is returned instead of guessing a branch.
	 *
	 * @return the predicted class index
	 */
	public double classifyInstance(Instance instance) {
		if (m_SubNodes == null || instance.isMissing(m_Attribute)) {
			return m_leafValue;
		}
		return m_SubNodes[(int) instance.value(m_Attribute)].classifyInstance(instance);
	}

	/** Renders the tree, one line per branch, indented by depth. */
	public String toString() {
		if (m_ClassAttribute == null) {
			return "SimpleID3: No model built yet.";
		}
		return toString(0);
	}

	// Following code is from Weka
	private String toString(int level) {
		StringBuffer text = new StringBuffer();

		// A leaf prints the label of its predicted class
		if (m_SubNodes == null) {
			text.append(": ").append(m_ClassAttribute.value((int) m_leafValue));
			return text.toString();
		}

		for (int j = 0; j < m_Attribute.numValues(); j++) {
			text.append("\n");
			for (int i = 0; i < level; i++) {
				text.append("|  ");
			}
			text.append(m_Attribute.name()).append(" = ").append(m_Attribute.value(j));
			text.append(m_SubNodes[j].toString(level + 1));
		}
		return text.toString();
	}

}
