import numpy
import pandas

import sklearn.metrics
from sklearn.svm import SVC

from PIL import Image, ImageFilter

import matplotlib.pyplot as plt


def get_number_of_edge_pixels(img, threshold=100):
    """
    Returns the number of pixels which are detected as an edge of the given image.

    The image is passed through the FIND_EDGES filter. A pixel of the filtered image counts as an
    edge pixel if the value of its first band is greater than or equal to the threshold. Images
    with a single band (for example grayscale images) are supported as well; for those the only
    band is compared against the threshold.

    :param img:       image which is processed
    :param threshold: minimum value of the first band for a filtered pixel to be counted as an edge
    :return number of pixels which are detected as an edge
    """

    # apply edge filter and convert the filtered image to a pixel array
    edge_values = numpy.asarray(img.filter(ImageFilter.FIND_EDGES))

    # a single-band image yields a 2-dimensional array without a band axis,
    # a multi-band image yields a 3-dimensional array whose last axis holds the bands
    first_band = edge_values if edge_values.ndim == 2 else edge_values[..., 0]

    # count pixels which are bright enough in the filtered image
    return int(numpy.count_nonzero(first_band >= threshold))


def read_data_set_from_csv(csv_path, max_row_count=None):
    """
    Read the data set from the csv file which is accessible at the given path and return two arrays.
    The first array contains the images of the data set, the second array contains the
    classifications of the images in string format.

    Each row of the csv file is read as an image path followed by the classification of that image.

    The max_row_count parameter is used to limit the number of returned entries from the data set. This might be
    useful for testing, since the training of classifiers takes much longer for a high number of entries.
    Only the images of the returned entries are opened.

    :param csv_path:      Path to the csv file describing the data set which is loaded
    :param max_row_count: Maximum number of data set entries which are returned
    :return: Tuple containing the image array and the classification string array
    """

    # the column names are unchanged so the file is parsed exactly as before; only the first two columns are used
    df = pandas.read_csv(csv_path, delimiter=',', names=['image_path', 'classification', 'image_data'])

    # cut down the number of rows before any image is opened, so no file beyond the limit is touched
    if max_row_count is not None:
        df = df[:max_row_count]

    image_paths = df['image_path'].values

    # fill an object array element by element: letting numpy convert a whole list of images at once
    # could turn the images into pixel arrays instead of keeping the image objects
    images = numpy.empty(len(image_paths), dtype=object)
    for index, image_path in enumerate(image_paths):
        with Image.open(image_path) as img:
            # read the pixel data right away, so the underlying file can be closed immediately
            # instead of keeping one open file per data set entry
            img.load()
        images[index] = img

    return images, df['classification'].values


def get_classifier_input_for_image(img):
    """
    Build the feature vector for the given image which is passed to a classifier.

    The vector is the concatenation of the flattened pixel values of the unfiltered image, of the
    image with the DETAIL filter, of the image with the BLUR filter and of the image with the
    SMOOTH filter followed by the CONTOUR filter. The number of edge pixels of the image is
    appended as the last entry.

    :param img: Image which is converted to an appropriate format
    :return: data for the given image which can be passed to a classifier
    """

    def as_flat_pixels(picture):
        # one-dimensional array holding all pixel values of the picture
        return numpy.array(picture).ravel()

    smoothed = img.filter(ImageFilter.SMOOTH)

    features = [
        as_flat_pixels(img),
        as_flat_pixels(img.filter(ImageFilter.DETAIL)),
        as_flat_pixels(img.filter(ImageFilter.BLUR)),
        as_flat_pixels(smoothed.filter(ImageFilter.CONTOUR)),
        # the edge pixel count is wrapped in an array, so it can be concatenated with the pixel arrays
        numpy.array([get_number_of_edge_pixels(img)]),
    ]

    # the order of the parts determines the layout of the classifier input
    return numpy.concatenate(features)


def plot_predictions(images, predicted_classifications, correct_classifications):
    """
    Plot a figure containing at most 40 images together with their predicted classifications.
    The classification is shown in green if it is correct, and red if it isn't.

    Additionally, display a bar chart which visualizes the f1 score of the predicted classifications grouped by
    category. Every category which occurs in the correct or in the predicted classifications gets a bar.

    :param images:                    Array of the images which are shown
    :param predicted_classifications: Predicted classifications of the images, in the same order as the images
    :param correct_classifications:   Correct classifications of the images, in the same order as the images
    """

    # create a new figure with a given size
    plt.figure(figsize=(8, 6))

    # the figure is laid out as a grid of 6 rows and 10 columns:
    # the first 40 cells show the images, the last row shows the bar chart
    shown = zip(images[:40], predicted_classifications, correct_classifications)
    for position, (img, predicted, correct) in enumerate(shown, start=1):
        plt.subplot(6, 10, position)
        plt.axis('off')
        plt.imshow(img)

        # show the labels and color them depending on whether the prediction was correct
        color = 'green' if predicted == correct else 'red'
        plt.title(predicted, color=color)

    # plot the performance of the classifier for each category

    # include the categories of the correct classifications as well, so a category which was
    # never predicted is still listed in the chart
    labels = numpy.union1d(correct_classifications, predicted_classifications)
    # get the f1 score (a measure of performance) of each of the categories;
    # the correct classifications are passed first, followed by the predictions
    f1_scores = sklearn.metrics.f1_score(correct_classifications, predicted_classifications,
                                         average=None, labels=labels)
    # create a new section on the grid that spans 1 row and 10 columns on the bottom of the figure
    plt.subplot2grid((6, 10), (5, 0), rowspan=1, colspan=10)
    # plot the scores as percentages
    bar_positions = numpy.arange(len(labels))
    plt.barh(bar_positions, 100 * f1_scores)
    plt.xlabel("% Performance")
    plt.xlim(xmax=100)
    plt.yticks(bar_positions, labels)
    plt.show()


def main():
    """
    Entry point of the program: train the classifiers on the training data set, predict the
    classifications of the test data set and report the results for each classifier.
    """

    # the training data set is limited to 1000 entries, the test data set is read completely
    train_images, train_labels = read_data_set_from_csv('training.csv', 1000)
    test_images, test_labels = read_data_set_from_csv('test.csv')

    # apply the conversion to the classifier input format to each image of both data sets
    to_classifier_input = numpy.frompyfunc(get_classifier_input_for_image, 1, 1)
    train_inputs = to_classifier_input(train_images)
    test_inputs = to_classifier_input(test_images)

    # classifiers which are trained and evaluated;
    # only one is used right now, but more can be added to compare the performance of classifiers
    classifiers = [
        SVC(kernel="poly", C=0.025)
    ]

    separator = '-' * 60
    for classifier in classifiers:
        # train the classifier with the training data set
        classifier.fit(train_inputs.tolist(), train_labels.tolist())
        # predict the classifications of the test images
        predictions = classifier.predict(test_inputs.tolist())

        # print the classifier and its classification report
        print(separator)
        print(str(classifier))
        print(sklearn.metrics.classification_report(test_labels, predictions))

        # plot the test images with their predicted classifications and the bar chart
        plot_predictions(test_images, predictions, test_labels)


# call main only when this file is executed as a script, not when it is imported as a module
if __name__ == "__main__":
    main()
