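# Density of data points: train a classifier to separate real samples (class 0)
# from uniformly distributed noise (class 1); its class-0 output is high where
# the data is dense relative to the noise.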

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.preprocessing
import random

import keras
from keras.models import Sequential
from keras.layers import Dense
#from keras.utils import np_utils


# Number of real samples drawn for training, and number of input features.
POINT_N = 1000
DIM_N = 2

data0 = pd.read_csv("./class_splash_3_w100_PCA.csv", sep=',')
# DataFrame.info() prints its summary itself and returns None, so it is called
# on its own rather than having its (None) result passed to print().
print("Info of the data:")
data0.info()

plt.figure(figsize=(8, 8))
plt.scatter(data0["Par1"], data0["Par2"], c='blue', marker='o')
plt.show()

# Keep only the plotted feature columns so the array width equals DIM_N, the
# network's input size; any other CSV columns would otherwise be carried along.
# NOTE(review): assumes Par1/Par2 are the intended model features — confirm.
data0 = data0[["Par1", "Par2"]].to_numpy(dtype=float)
# Scale each feature to [-1, 1], the same box the uniform noise points are drawn from.
scaler = sklearn.preprocessing.MinMaxScaler(feature_range=(-1.0, 1.0))
data0 = scaler.fit_transform(data0)

# ---------------------------------------------------------------------------------------------
# Small MLP: DIM_N inputs -> 30 ReLU units -> 2-way softmax
# (class 0 = real data point, class 1 = uniform noise point).
model = Sequential([
    Dense(30, input_dim=DIM_N, activation='relu'),
    Dense(2, activation='softmax'),
])

opt = keras.optimizers.RMSprop(learning_rate=0.001, momentum=0.9)
# NOTE(review): 'mse' on softmax outputs is unusual; 'categorical_crossentropy'
# is the conventional loss for one-hot targets. Kept as-is to preserve training.
model.compile(optimizer=opt, loss='mse', metrics=['accuracy'])


# Twice as many uniform noise points as real samples.
DD = int(2 * POINT_N)

# Real samples, labelled 0. `data` is the very same array object as `data0`,
# so this shuffle reorders data0 in place.
data = data0
np.random.shuffle(data)
data = data[:POINT_N, :]
data = np.hstack((data, np.zeros((len(data), 1))))

# Uniform noise over [-1, 1)^DIM_N, labelled 1.
data_noise = np.random.rand(DD, DIM_N) * 2.0 - 1.0
data_noise = np.hstack((data_noise, np.ones((DD, 1))))

# Stack real and noise rows, then mix them.
data = np.vstack((data, data_noise))
np.random.shuffle(data)

# First DIM_N columns are the features; the trailing column is the 0/1 label,
# one-hot encoded for the 2-unit softmax output.
x = data[:, :DIM_N]
y = keras.utils.to_categorical(data[:, DIM_N].astype(int), num_classes=2)

# Stop once the training loss has not improved by at least 0.001 for 100 epochs.
callback = keras.callbacks.EarlyStopping(
    monitor="loss",
    mode="min",
    min_delta=0.001,
    patience=100,
    verbose=1,
)
history = model.fit(
    x,
    y,
    epochs=1000,
    batch_size=100,
    callbacks=[callback],
    verbose=1,
)

# -----------------------------------------------------------------------------------------------

# Draw up to 3000 real (scaled) points and classify them. They come from the
# full dataset, so they can include points that were used for training.
np.random.shuffle(data0)
data = data0[0:3000, :]

# Feed exactly the DIM_N feature columns the network was built for.
pred = model.predict(data[:, :DIM_N])
res = np.argmax(pred, axis=1)

# One vectorized scatter call instead of one call per point: same colours,
# same drawing order, far fewer matplotlib artists.
plt.figure(figsize=(8, 8))
plt.scatter(data[:, 0], data[:, 1], c=np.where(res == 0, 'blue', 'red'), marker='o')
plt.show()

# Every evaluated point is real data (class 0), so the "accuracy" is the
# fraction of points the classifier assigns to class 0.
a = np.mean(res == 0)
print("\n", "Accuracy of the dataset", a)
