

import pandas as pd
import matplotlib.pyplot as plt

DIM_N = 4
POINT_N = 150
CLUST_N = 3

class POINT:
    """A data point: a coordinate sequence plus the id of its assigned cluster.

    Attributes set by ``__init__``:
        X:     the coordinate sequence passed in (stored as-is, not copied).
        clust: id of the cluster the point is assigned to. It starts at
               ``CLUST_N``, which ``ToClust`` also uses when no cluster
               could be chosen.
    """

    def __init__(self, x):
        """Store the coordinates ``x`` and mark the point as unassigned."""
        self.clust = CLUST_N
        self.X = x

    def ToClust(self, CC):
        """Assign this point to the nearest cluster in ``CC``.

        Args:
            CC: iterable of objects exposing ``Dist(point)`` and ``clust``.

        The result is written to ``self.clust``; nothing is returned.
        Ties keep the earliest cluster in ``CC`` because the comparison is
        strict. If ``CC`` is empty (or no distance compares smaller than
        infinity), ``clust`` is set to ``CLUST_N``.
        """
        nearest = None
        # An infinite sentinel accepts any finite distance; the previous
        # 1.0e+99 literal left the point unassigned for larger distances.
        best = float("inf")
        for c in CC:
            d = c.Dist(self)  # evaluated once per candidate, then reused
            if d < best:
                nearest, best = c, d
        self.clust = CLUST_N if nearest is None else nearest.clust

class CLUSTER(POINT):
    """A cluster centroid: a POINT whose ``clust`` is the cluster's own id.

    Attributes:
        X: centroid coordinates; updated in place by ``Eval``.
        N: number of points counted as members by the latest ``Eval`` call.
    """

    def __init__(self, cl, x):
        """Create the cluster with id ``cl`` and initial centroid ``x``."""
        super().__init__(x)
        self.clust = cl
        self.N = 0

    def Dist(self, p):
        """Return the squared Euclidean distance from the centroid to ``p``.

        Only the first ``DIM_N`` coordinates of each are used. No square
        root is taken.
        """
        dd = 0.0
        for i in range(DIM_N):
            diff = self.X[i] - p.X[i]
            dd += diff * diff
        return dd

    def Eval(self, P):
        """Recompute the centroid as the mean of the member points in ``P``.

        A point is a member when its ``clust`` equals this cluster's id.
        ``self.N`` is set to the member count. When the cluster has no
        members the previous centroid is kept: dividing by zero would
        otherwise raise ``ZeroDivisionError`` after the centroid had already
        been zeroed.
        """
        # Accumulate into temporaries so the current centroid survives if
        # the cluster turns out to be empty.
        sums = [0.0] * DIM_N
        count = 0
        for p in P:
            if p.clust == self.clust:
                count += 1
                for i in range(DIM_N):
                    sums[i] += p.X[i]
        self.N = count
        if count == 0:
            return
        # Write element by element so self.X stays the same list object.
        for i in range(DIM_N):
            self.X[i] = sums[i] / count

# Initial centroids: cluster ids 0, 1, 2, each with its starting coordinates.
Cl = [
    CLUSTER(cluster_id, start)
    for cluster_id, start in enumerate([
        [3.5, 5.0, 0.0, -1.2],
        [5.5, 1.0, 4.0, 1.2],
        [8.0, 3.5, 7.5, 3.5],
    ])
]

# Load the measurements; the path is relative to the current working directory.
iris = pd.read_csv('../../DATASET/iris.csv')
# DataFrame.info() prints its summary itself and returns None, so wrapping it
# in print() emitted a stray "None". Call it directly, then print a blank line.
iris.info()
print()

# Ground-truth class id for every row, aligned index-for-index with PP.
# A variety missing from the table maps to CLUST_N (the id POINT uses for
# "unassigned") instead of being skipped, which would shift all later labels.
variety_codes = {"Setosa": 0, "Versicolor": 1, "Virginica": 2}
V = [variety_codes.get(name, CLUST_N) for name in iris['variety']]

# One POINT per row, with the four measurements in a fixed column order.
feature_columns = ['sepal.length', 'sepal.width', 'petal.length', 'petal.width']
PP = [POINT(list(row)) for row in zip(*(iris[col] for col in feature_columns))]

# Show data: one scatter panel per pair of coordinate axes, coloured by the
# ground-truth class in V.

colors = ['blue', 'green', 'magenta', 'black']
fig, axes = plt.subplots(2, 3, figsize=(10, 6))

# Pair each point with its label once; zip stops at the shorter sequence, so
# no hard-coded point count is needed.
labelled = list(zip(PP, V))
point_colors = [colors[v] for _, v in labelled]

n = 0
for i in range(DIM_N):
    for j in range(i + 1, DIM_N):
        ax = axes[n // 3][n % 3]
        # One scatter call per panel instead of one per point: the picture
        # is the same but far fewer artists are created.
        ax.scatter(
            [p.X[i] for p, _ in labelled],
            [p.X[j] for p, _ in labelled],
            c=point_colors,
            s=50,
        )
        # To overlay the initial centroids as well, uncomment:
        # ax.scatter([c.X[i] for c in Cl], [c.X[j] for c in Cl],
        #            c='red', marker='*', s=200)
        ax.set_xlabel('$Axis: ('+str(i)+', '+str(j)+')$', fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        n += 1
fig.tight_layout()
plt.show()

# K-means: alternate the assignment and centroid-update steps until the
# assignments stop changing.
#
# The loop used to stop when the printed accuracy changed by less than 0.001
# between iterations. Accuracy can stay the same while individual points are
# still switching clusters, and it depends on ground-truth labels, so the
# stopping test now compares the assignments themselves.
MAX_ITER = 100  # safety cap in case the assignments never settle
Prec0 = 0.0
for _ in range(MAX_ITER):
    previous = [p.clust for p in PP]

    # Assignment step: each point takes the id of its nearest centroid.
    for p in PP:
        p.ToClust(Cl)

    # Update step: each centroid is recomputed from its current members.
    for cl in Cl:
        cl.Eval(PP)

    # Fraction of points whose cluster id equals the label in V.
    # Printed for information only.
    n = sum(1 for p, v in zip(PP, V) if p.clust == v)
    Prec = float(n)/len(iris)
    print('=====> ', Prec)

    if [p.clust for p in PP] == previous:
        break

    # No longer part of the stopping test; still updated so it keeps holding
    # the previous iteration's accuracy.
    Prec0 = Prec

# Show the clustering result: one panel per pair of coordinate axes, points
# coloured by assigned cluster, centroids drawn as red stars on top.
fig, axes = plt.subplots(2, 3, figsize=(10, 6))
cluster_colors = [colors[p.clust] for p in PP]
n = 0
for i in range(DIM_N):
    for j in range(i + 1, DIM_N):
        ax = axes[n // 3][n % 3]
        # One scatter call for all points and one for all centroids, instead
        # of one call per marker.
        ax.scatter(
            [p.X[i] for p in PP],
            [p.X[j] for p in PP],
            c=cluster_colors,
            s=50,
        )
        ax.scatter(
            [c.X[i] for c in Cl],
            [c.X[j] for c in Cl],
            c='red',
            marker='*',
            s=200,
        )
        ax.set_xlabel('$Axis: ('+str(i)+', '+str(j)+')$', fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        n += 1
fig.tight_layout()
plt.show()

   
