Methods of Clustering For Product Managers

Welcome! My name is Anant Agrawal and I am an Information Systems and Statistics double major at Carnegie Mellon University. I'm really excited about working as a full-time Product Manager after graduating and just wanted to write this small tutorial on the practical uses of clustering algorithms for Product Management. I hope you find it useful.

What can Clustering Algorithms Do For You?

As I will demonstrate below with the Yelp dataset, clustering algorithms can be really useful for identifying distinct clusters of user preferences for the content on your platform or for your different products. This means you can better understand:

  • Which user preferences are most dominant on your platform
  • How homogeneous your users’ behaviors are
  • How this breakdown occurs for different user demographics and product attributes
  • How to make personalized recommendations to users

In fact, this analysis can largely be replicated with the Amazon dataset, and with data from any platform that delivers categorized content (such as books or movies) and collects user preferences.

So What Exactly are Clustering Algorithms?

Clustering algorithms do exactly what you would guess they do. Given information about a set of objects (users, products, etc.), clustering algorithms group these objects based on how similar they are in the given attribute(s). At a high level, many of them work by defining centers and putting all the points closest to a center into that center's cluster. They repeat this process until they have minimized the error, defined as the sum of the distances between each point and the center of its cluster. This is not the way every clustering algorithm works, but it offers a good high-level understanding of what these algorithms are trying to do. In this tutorial I am going to highlight two clustering algorithms that can be really powerful for product management.

  • K-means: This algorithm works as described above in that it minimizes the total distance between group centers and the points assigned to them. The advantage of this algorithm is that it is really fast, as it is mostly just computing distances. However, because each point is simply assigned to its nearest center, K-means tends to find roughly circular clusters.

  • Gaussian Mixture Models: This algorithm is a little more complicated. It models the data as a mixture of Gaussian (normal) distributions and, for every point, calculates the probability that the point belongs to each cluster using the mean and covariance of that cluster. Under this model, an object can have partial membership in more than one cluster, and clusters are not restricted to being circular (they can be elliptical). A short sketch comparing the two algorithms follows below.
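
To make the difference concrete, here is a minimal sketch on made-up synthetic data (not the Yelp data used below) showing that K-means gives each point exactly one cluster label, while a Gaussian Mixture Model gives a probability of membership in each cluster. The blob parameters are placeholders chosen purely for illustration.

#Minimal sketch comparing hard vs. soft cluster assignments on toy data
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

X_toy, _ = make_blobs(n_samples=300, centers=3, cluster_std=1.0, random_state=0)

#K-means: each point gets exactly one cluster label (hard assignment)
km_labels = KMeans(n_clusters=3, random_state=0).fit_predict(X_toy)

#GMM: each point gets a probability of belonging to each cluster (soft assignment)
gmm = GaussianMixture(n_components=3, random_state=0).fit(X_toy)
gmm_probs = gmm.predict_proba(X_toy)   #shape (300, 3), each row sums to 1
gmm_labels = gmm_probs.argmax(axis=1)  #hard labels, if you need them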

So, without further ado, let's dive into examples that help show us when one algorithm is better than another.

Example 1: Yelp Open Dataset

In my opinion, this is one of the coolest datasets I could find to show the true power of clustering algorithms. This example will analyze the Yelp Open Dataset, which has almost 7 million reviews by real people of restaurants and other businesses across the US and Canada. Specifically, we are going to start by using clustering to see whether people in Pittsburgh have a stronger preference for the restaurants that serve burgers or for the restaurants that serve Chinese food.

Step 1 - Let's Look At the Data

In [1]:
#As always, let's start by importing the libraries we need
import json    
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
import helper

#And the tools from Sklearn to do our clustering
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_samples, silhouette_score
from sklearn.metrics import mean_squared_error

from scipy.sparse import csr_matrix

#To make things cleaner, let's also not display all the warnings
import warnings
warnings.filterwarnings('ignore')

The Yelp Open Dataset consists of giant JSON files with information about the users, businesses, and ratings on Yelp. Let's start by loading in the data and seeing what information we have about a business and a review:

In [2]:
businesses = []
with open('business.json') as f:
    for line in f:
        businesses.append(json.loads(line))
        
reviews = []
with open('review.json') as f:
    for line in f:
        reviews.append(json.loads(line))
        
testBusiness = businesses[1]
testReview = reviews[1]

print(testBusiness)
print('')
print(testReview)
{'business_id': 'QXAEGFB4oINsVuTFxEYKFQ', 'name': 'Emerald Chinese Restaurant', 'address': '30 Eglinton Avenue W', 'city': 'Mississauga', 'state': 'ON', 'postal_code': 'L5R 3E7', 'latitude': 43.6054989743, 'longitude': -79.652288909, 'stars': 2.5, 'review_count': 128, 'is_open': 1, 'attributes': {'RestaurantsReservations': 'True', 'GoodForMeal': "{'dessert': False, 'latenight': False, 'lunch': True, 'dinner': True, 'brunch': False, 'breakfast': False}", 'BusinessParking': "{'garage': False, 'street': False, 'validated': False, 'lot': True, 'valet': False}", 'Caters': 'True', 'NoiseLevel': "u'loud'", 'RestaurantsTableService': 'True', 'RestaurantsTakeOut': 'True', 'RestaurantsPriceRange2': '2', 'OutdoorSeating': 'False', 'BikeParking': 'False', 'Ambience': "{'romantic': False, 'intimate': False, 'classy': False, 'hipster': False, 'divey': False, 'touristy': False, 'trendy': False, 'upscale': False, 'casual': True}", 'HasTV': 'False', 'WiFi': "u'no'", 'GoodForKids': 'True', 'Alcohol': "u'full_bar'", 'RestaurantsAttire': "u'casual'", 'RestaurantsGoodForGroups': 'True', 'RestaurantsDelivery': 'False'}, 'categories': 'Specialty Food, Restaurants, Dim Sum, Imported Food, Food, Chinese, Ethnic Food, Seafood', 'hours': {'Monday': '9:0-0:0', 'Tuesday': '9:0-0:0', 'Wednesday': '9:0-0:0', 'Thursday': '9:0-0:0', 'Friday': '9:0-1:0', 'Saturday': '9:0-1:0', 'Sunday': '9:0-0:0'}}

{'review_id': 'GJXCdrto3ASJOqKeVWPi6Q', 'user_id': 'yXQM5uF2jS6es16SJzNHfg', 'business_id': 'NZnhc2sEQy3RmzKTZnqtwQ', 'stars': 5.0, 'useful': 0, 'funny': 0, 'cool': 0, 'text': "I *adore* Travis at the Hard Rock's new Kelly Cardenas Salon!  I'm always a fan of a great blowout and no stranger to the chains that offer this service; however, Travis has taken the flawless blowout to a whole new level!  \n\nTravis's greets you with his perfectly green swoosh in his otherwise perfectly styled black hair and a Vegas-worthy rockstar outfit.  Next comes the most relaxing and incredible shampoo -- where you get a full head message that could cure even the very worst migraine in minutes --- and the scented shampoo room.  Travis has freakishly strong fingers (in a good way) and use the perfect amount of pressure.  That was superb!  Then starts the glorious blowout... where not one, not two, but THREE people were involved in doing the best round-brush action my hair has ever seen.  The team of stylists clearly gets along extremely well, as it's evident from the way they talk to and help one another that it's really genuine and not some corporate requirement.  It was so much fun to be there! \n\nNext Travis started with the flat iron.  The way he flipped his wrist to get volume all around without over-doing it and making me look like a Texas pagent girl was admirable.  It's also worth noting that he didn't fry my hair -- something that I've had happen before with less skilled stylists.  At the end of the blowout & style my hair was perfectly bouncey and looked terrific.  The only thing better?  That this awesome blowout lasted for days! \n\nTravis, I will see you every single time I'm out in Vegas.  You make me feel beauuuutiful!", 'date': '2017-01-14 21:30:33'}

Let's Focus In on Pittsburgh

Since we want to know about the preferences users have for businesses in Pittsburgh, we should filter the businesses and reviews so we only keep the information regarding Pittsburgh businesses. It would be really interesting to see how this data compares to results from other cities in the dataset, such as Las Vegas or Toronto.

In [3]:
#Get the businesses from Pittsburgh
pitt_business_ids = []
pitt_business_names = []
pitt_business_categories = []

for jsonObj in businesses:
    if jsonObj['city'] == 'Pittsburgh':
        pitt_business_ids.append(jsonObj['business_id'])
        pitt_business_names.append(jsonObj['name'])
        
        #'categories' is already a comma-separated string (or None)
        pitt_business_categories.append(jsonObj['categories'])

business = pd.DataFrame({'business_id': pitt_business_ids,
                             'name': pitt_business_names,
                             'category': pitt_business_categories})
print(len(business))
business.head()
7017
Out[3]:
business_id name category
0 1RHY4K3BD22FK7Cfftn8Mg Marathon Diner Sandwiches, Salad, Restaurants, Burgers, Comfo...
1 dQj5DLZjeDK3KFysh1SYOQ Apteka Nightlife, Bars, Polish, Modern European, Rest...
2 v-scZMU6jhnmV955RSzGJw No. 1 Sushi Sushi Japanese, Sushi Bars, Restaurants
3 t-6tdxRaz7s9a0sf94Tguw Impressionz Restaurants, Caribbean
4 5WMIvoMx3l1vn1uJ3HZB6Q Subway Fast Food, Sandwiches, Restaurants
In [8]:
#Get all the review information
user_ids = []
bus_ids = []
ratings = []

for jsonObj in reviews:
    user_ids.append(jsonObj['user_id'])
    bus_ids.append(jsonObj['business_id'])
    ratings.append(jsonObj['stars'])

ratings = pd.DataFrame({'user_id': user_ids,
                             'business_id': bus_ids,
                             'rating': ratings})

print(len(ratings))
ratings.head()
6685900
Out[8]:
user_id business_id rating
0 hG7b0MtEbXx5QzbzE6C_VA ujmEBvifdJM6h6RLv4wQIg 1.0
1 yXQM5uF2jS6es16SJzNHfg NZnhc2sEQy3RmzKTZnqtwQ 5.0
2 n6-Gk65cPZL6Uz8qRm3NYw WTqjgwHlXbSFevF32_DJVw 5.0
3 dacAIZ6fTM6mqwW5uxkskg ikCg8xy5JIg_NGPx-MSIDA 5.0
4 ssoyf2_x0EQMed6fgHeMyQ b1b1eb3uo-w561D0ZfCEiQ 1.0
In [5]:
#Filter so we only keep the reviews for Pittsburgh businesses
ratings = ratings[ratings['business_id'].isin(pitt_business_ids)]
print(len(ratings))
ratings.head()
225496
Out[5]:
user_id business_id rating
5 w31MKYsNFMrjhWxxAb5wIw eU_713ec6fTGNO4BegRaww 4.0
12 GYNnVehQeXjty0xH7-6Fhw FxLfqxdYPA6Z85PFKaqLrg 4.0
28 q3GeSW9dWN9r_ocqFkhrvg 9nTF596jDvBBia2EXXiOOg 1.0
36 _o740mSNRhMNYuPjSJoPLg sMzNLdhJZGzYirIWt-fMAg 5.0
55 JjDR060LJQcNNVWKuU64fA hcFSc0OHgZJybnjQBrL_8Q 4.0
In [11]:
print(len(ratings))
225496

Time to Get Ready For Clustering

Now we have the data that we want, but we still have to restructure it to be able to cluster properly. We want every column to represent a category ('Chinese', 'Indian', 'Burgers') and every row to represent a user. Each cell holds the average rating that user gave to restaurants in that category; any category for which they have no reviews is just an NA.

In [12]:
# Compute, for each user, their average rating for restaurants in each category

def create_category_table(categories, column_names):
    category_ratings = pd.DataFrame()
    
    for category in categories:
        cat_businesses = business[business['category'].str.contains(pat=category, na=False)]
        ratings_from_cat = ratings[ratings['business_id'].isin(cat_businesses['business_id'])]
        avg_votes_per_user = ratings_from_cat.loc[:, ['user_id', 'rating']].groupby(['user_id'])['rating'].mean().round(2)
        category_ratings = pd.concat([category_ratings, avg_votes_per_user], axis=1)

    category_ratings.columns = column_names

    return category_ratings

categories = ["Sandwiches",
"Pizza",
"Chinese",
"Food Stands",
"Steakhouses",
"Mexican",
"Fast Food",
"Seafood",
"Indian",
"Gluten-Free",
"Breakfast & Brunch",
"Delis",
"Burgers",
"Salad",
"Vegan",
"Comfort Food",
"Mediterranean",
"Latin American",
"German",
"Cafes",
"Vegetarian",
"Italian",
"Middle Eastern",
"Diners",
"Hot Dogs",
"Caribbean",
"French",
"Buffets",
"Thai"]

AllCategoryRatings = create_category_table(categories, categories)
AllCategoryRatings
Out[12]:
Sandwiches Pizza Chinese Food Stands Steakhouses Mexican Fast Food Seafood Indian Gluten-Free ... Cafes Vegetarian Italian Middle Eastern Diners Hot Dogs Caribbean French Buffets Thai
--CH8yRGXhO2MmbF-4BWXg 5.0 NaN NaN NaN NaN NaN NaN NaN NaN 5.0 ... 5.0 5.0 NaN NaN NaN NaN NaN NaN NaN NaN
--EMqnd727rtC0G5Oc-Mrg NaN NaN NaN NaN NaN 5.0 NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
--OECAoqfSTBaZ3biOyzwA 5.0 5.0 5.0 NaN NaN NaN NaN 5.0 NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
--RBfYfIpx44V5Kux2fPFA NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN 4.0 NaN NaN NaN NaN NaN NaN NaN NaN
--TvGNywm2I1iwNWZmerBA NaN 1.0 NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
zzGUiwY-emOE0dkTRV0ztw NaN NaN NaN NaN NaN NaN NaN 5.0 NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
zzSujOEhYzdCduvpLisCvw 5.0 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
zzTHRzFR-a_F0YZ7c56vcA 3.0 NaN NaN NaN NaN NaN NaN 2.0 NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
zzYMFAyY5-ZpsPaknmt8tw NaN NaN NaN NaN NaN NaN 5.0 5.0 NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
zzx7k7JqCQNhhqL4VJxL0A 4.0 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN 3.0 NaN NaN NaN NaN NaN NaN NaN

44986 rows × 29 columns

In [13]:
AllCategoryRatings.to_csv(path_or_buf="yelpcf.csv", index=False)

And now we are going to build a two-column table with just the burger and Chinese averages, and create a scatterplot comparing the two.

In [7]:
#Build the two-column table we actually want to compare
BurgerChineseRatings = create_category_table(['Burgers', 'Chinese'], ['avg_burgers_ratings', 'avg_chinese_ratings'])

def draw_scatterplot(x_data, x_label, y_data, y_label):
    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(111)

    plt.xlim(0, 5)
    plt.ylim(0, 5)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.scatter(x_data, y_data, s=30)
    
draw_scatterplot(BurgerChineseRatings['avg_burgers_ratings'],'avg_burgers_ratings', BurgerChineseRatings['avg_chinese_ratings'], 'avg_chinese_ratings')

An interesting thing to note here is that when drawing the scatterplot, NAs are automatically dropped. This means the plot only shows users who ate at both restaurants that serve burgers and ones that serve Chinese food. Even so, user preferences are hard to see. To make them easier to see, we will get rid of the users who rated both categories highly or rated both categories poorly. This way we can focus on the users who actually have a strong preference for one category over the other.

In [9]:
def highPreferenceUsers(category_ratings, category1Data, category2Data, score_limit_1, score_limit_2):
    biased_dataset = category_ratings[((category1Data < score_limit_1) & (category2Data > score_limit_2)) | ((category2Data < score_limit_1) & (category1Data > score_limit_2))]
    biased_dataset = pd.concat([biased_dataset[:300], category_ratings[:2]])
    biased_dataset = pd.DataFrame(biased_dataset.to_records())

    return biased_dataset

biased_dataset = highPreferenceUsers(BurgerChineseRatings, BurgerChineseRatings['avg_burgers_ratings'], BurgerChineseRatings['avg_chinese_ratings'], 3, 3)
In [10]:
draw_scatterplot(biased_dataset['avg_burgers_ratings'],'avg_burgers_ratings', biased_dataset['avg_chinese_ratings'], 'avg_chinese_ratings')

Alright, so it's still hard to see the different clusters here. In particular, it is hard to tell whether there are more users with a strong preference for burgers over Chinese food, or more users with a strong preference for Chinese food over burgers. This is where clustering comes in.

In [11]:
def draw_clusters(biased_dataset, xlabel, ylabel, predictions, cmap='viridis'):
    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(111)
    plt.xlim(0, 5)
    plt.ylim(0, 5)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    clustered = pd.concat([biased_dataset.reset_index(), pd.DataFrame({'group':predictions})], axis=1)
    plt.scatter(clustered[xlabel], clustered[ylabel], c=clustered['group'], s=20, cmap=cmap)
In [12]:
#Create the dataset with the users who have strong preferences for a category
X = biased_dataset[['avg_burgers_ratings','avg_chinese_ratings']].values

#Drop any NA values (Can't cluster with NAs)
X = pd.DataFrame(X).dropna()
X[0]
Out[12]:
0      1.00
1      1.00
2      4.00
3      1.00
4      2.50
       ... 
295    1.00
296    2.50
297    4.00
298    2.33
299    4.00
Name: 0, Length: 300, dtype: float64

But before we start clustering, we should stop to think about which algorithm to use. Remember that the Gaussian Mixture Model assumes the data is Gaussian distributed. A good way to get some idea of whether that holds is to see if the data for each set of ratings looks normally distributed. Below is the distribution for each variable:
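
If eyeballing histograms feels too informal, one quick option (my addition, not part of the original analysis) is a formal normality test from scipy. D'Agostino's K^2 test returns a p-value; a very small p-value suggests the column is not normally distributed.

#Rough sketch of a more formal normality check on the two rating columns in X
from scipy import stats

for col in X.columns:
    stat, p = stats.normaltest(X[col])
    print(f"column {col}: statistic={stat:.2f}, p-value={p:.4f}")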

In [13]:
plt.hist(X[0], bins='auto')
plt.show()
In [14]:
plt.hist(X[1], bins='auto')
plt.show()

Looking at these histograms, it seems pretty clear that the distributions are not normal. So should we abandon the Gaussian Mixture Model? Let's see how the clusters compare between K-means and the Gaussian Mixture Model:

In [15]:
#GMM Model
gmm = GaussianMixture(n_components=2).fit(X)
predictionsGMM = gmm.predict(X)

# Plot
draw_clusters(biased_dataset, 'avg_burgers_ratings', 'avg_chinese_ratings', predictionsGMM)
In [16]:
# Use K means to define 2 cluster groups
kmeans_1 = KMeans(n_clusters=2)
predictions = kmeans_1.fit_predict(X)

# Plot
draw_clusters(biased_dataset, 'avg_burgers_ratings', 'avg_chinese_ratings', predictions)

Wow, they look pretty similar. What if we tried to find 3 clusters instead?

In [17]:
#GMM Model for 3 clusters
gmm = GaussianMixture(n_components=3).fit(X)
predictionsGMM2 = gmm.predict(X)

# Plot
draw_clusters(biased_dataset, 'avg_burgers_ratings', 'avg_chinese_ratings', predictionsGMM2)

This is one example of the scatterplot the GMM model produced with this data:

In [18]:
from IPython.display import Image
Image("GMM.png")
Out[18]:
In [19]:
# Use K means to define 3 clusters
kmeans_3 = KMeans(n_clusters=3)
predictions2 = kmeans_3.fit_predict(X)

# Plot
draw_clusters(biased_dataset, 'avg_burgers_ratings', 'avg_chinese_ratings', predictions2)

Now we can see a much bigger difference. The Gaussian Mixture Model seems way off, as it created a small cluster far above all the other clusters. Sometimes the Gaussian Mixture Model comes up with decent-looking clusters even when the Gaussian assumption does not hold, but as can be seen in this case, it is better to stick with K-means here, since K-means does not rely on that assumption.
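
If you want a number to back up the eyeball test, one rough check (my addition, not something the original cells compute) is to compare the silhouette score of the two sets of labels we just produced; a higher score means the clusters are more cleanly separated.

#Compare the two 3-cluster solutions using the silhouette score imported above
print("K-means (3 clusters):", silhouette_score(X, predictions2))
print("GMM (3 components):  ", silhouette_score(X, predictionsGMM2))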

In [20]:
# Create an instance of KMeans to find three clusters
kmeans_2 = KMeans(n_clusters=3)

# Use fit_predict to cluster the dataset
prediction3 = kmeans_2.fit_predict(X)

#Add the predicted cluster to the dataset
def userTypeTable(X, x_label, y_label, prediction):
    user_preferences = pd.DataFrame(X) 
    user_preferences.rename(columns={0:x_label,
                                     1:y_label}, 
                                     inplace=True)
    
    user_preferences['cluster'] = prediction
    return user_preferences

chinese_burger_user_preferences = userTypeTable(X, 'avg_burgers_ratings', 'avg_chinese_ratings', prediction3)

#Plot the new scatterplot with the appropriate labels

def scatterplot3Cluster(user_preferences, x_label, y_label, legend_labels):
    cluster0 = user_preferences[user_preferences['cluster'] == 0]
    cluster1 = user_preferences[user_preferences['cluster'] == 1]
    cluster2 = user_preferences[user_preferences['cluster'] == 2]
    
    c0 = plt.scatter(cluster0[x_label], cluster0[y_label], marker='o')
    c1 = plt.scatter(cluster1[x_label], cluster1[y_label], marker='s')
    c2 = plt.scatter(cluster2[x_label], cluster2[y_label], marker='x')
    
    plt.legend((c0, c1, c2),
               legend_labels,
               scatterpoints=1,
               loc='lower left',
               ncol=1,
               fontsize=14)

    plt.xlabel(x_label)
    plt.ylabel(y_label)
    
    return plt.show()

scatterplot3Cluster(chinese_burger_user_preferences, 'avg_burgers_ratings', 'avg_chinese_ratings', ('Loves Burgers', 'Loves Chinese', 'Eh About Both'))

K-means assigns cluster numbers arbitrarily on each run, so the legend may not line up with the clusters you see; the bottom right cluster should be 'Loves Burgers' and the upper left should be 'Loves Chinese'. You may need to rerun the cell above until they are aligned (the sketch below shows one way to make the labels deterministic).
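
If the shuffling legend bothers you, here is one way to make the labels deterministic. This is a sketch of my own, not something the original cell does: fix the random seed and then relabel the clusters by their mean burger rating, so that cluster 0 is always the most burger-loving group.

#Sketch: seed K-means and relabel clusters by mean burger rating (column 0)
#so the mapping from cluster id to user type is stable across runs
kmeans_seeded = KMeans(n_clusters=3, random_state=42)
seeded_predictions = kmeans_seeded.fit_predict(X)

#Order the cluster ids by descending mean burger rating, then remap them
order = X.groupby(seeded_predictions)[0].mean().sort_values(ascending=False).index
relabel = {old: new for new, old in enumerate(order)}
stable_predictions = np.array([relabel[p] for p in seeded_predictions])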

So, this clustering is a lot more interesting. It essentially shows us which users like Chinese restaurants but not burger restaurants, which users are ok with both Chinese restaurants and restaurants that serve burgers, and which users like burger restaurants but not Chinese restaurants. It seems that there are many more users who prefer Chinese food over burgers than vice versa, and also many users that don't really have a strong preference either way. But what are the exact numbers of users with each preference?

In [21]:
def countByUserType(user_type_table, colname, clusterNames):
    prefTable = user_type_table.groupby('cluster').count()[[colname]]
    prefTable['User Type'] = clusterNames
    prefTable.set_index('User Type', drop=True, inplace=True)
    prefTable.columns = ['Count']
    return prefTable

byusertable = countByUserType(chinese_burger_user_preferences, 'avg_burgers_ratings', ['Loves Burgers', 'Loves Chinese', 'Eh about Both'])
byusertable
Out[21]:
Count
User Type
Loves Burgers 87
Loves Chinese 159
Eh about Both 54

Wow, now that we have the numbers, we can see that there are nearly twice as many users who prefer Chinese over burgers as vice-versa. The group that feels lukewarm about both categories turns out to be the smallest of the three.

So, which restaurant category is more popular?

As we noted before, the scatterplot automatically drops any user who is missing a rating in one of the two categories. While this makes sense for plotting purposes, knowing how many users have rated only one category and not the other is itself pretty telling.

In [22]:
print(BurgerChineseRatings['avg_burgers_ratings'].isna().sum())
print(BurgerChineseRatings['avg_chinese_ratings'].isna().sum())
4060
6938

Here, for example, we can see that there are more users who have never rated a Chinese restaurant than users who have never rated a restaurant serving burgers. Since rating a restaurant is likely highly correlated with going to a restaurant, the number of NAs could indicate that, in Pittsburgh, restaurants with burgers are more popular than Chinese restaurants.

Thus, so far, we are able to say that more users go to restaurants that serve burgers than to Chinese restaurants, but, of the users that go to both, more prefer the Chinese restaurants than the burger restaurants.

Indian vs Chinese Restaurants

Here is another analysis of Indian Restaurants versus Chinese Restaurants:

In [23]:
IndianChineseRatings = create_category_table(['Indian', 'Chinese'], ['avg_indian_ratings', 'avg_chinese_ratings'])
highPrefUsers = highPreferenceUsers(IndianChineseRatings, IndianChineseRatings['avg_indian_ratings'], IndianChineseRatings['avg_chinese_ratings'], 3, 3)
In [24]:
print(IndianChineseRatings['avg_indian_ratings'].isna().sum())
print(IndianChineseRatings['avg_chinese_ratings'].isna().sum())
4889
1655
In [25]:
draw_scatterplot(highPrefUsers['avg_indian_ratings'],'avg_indian_ratings', highPrefUsers['avg_chinese_ratings'], 'avg_chinese_ratings')
In [26]:
X1 = highPrefUsers[['avg_indian_ratings', 'avg_chinese_ratings']].values

# Drop any NA values (can't cluster with NAs)
X1 = pd.DataFrame(X1).dropna()

# Create an instance of KMeans to find three clusters
kmeans_3 = KMeans(n_clusters=3)

# Use fit_predict to cluster the dataset
prediction4 = kmeans_3.fit_predict(X1)

indian_chinese_user_preferences = userTypeTable(X1, 'avg_indian_ratings', 'avg_chinese_ratings', prediction4)
scatterplot3Cluster(indian_chinese_user_preferences, 'avg_chinese_ratings', 'avg_indian_ratings',  ('Loves Indian', 'Loves Chinese', 'Eh about Both'))
    
In [27]:
usertypetable = countByUserType(indian_chinese_user_preferences, 'avg_indian_ratings', ['Loves Indian', 'Loves Chinese', 'Eh about Both'])
usertypetable
Out[27]:
Count
User Type
Loves Indian 33
Loves Chinese 37
Eh about Both 62

It seems that many more people go to Chinese restaurants, since there are far fewer NAs for Chinese, but of the users that have a strong preference towards one category of restaurant, the two preferences are about equally common.

Choosing the Number of Clusters

Alright, the more clusters we add, the better each cluster represents its group of users. This is because more clusters means smaller clusters, and the users within each cluster become more similar to each other. But if we have too many clusters, the clustering becomes useless because the clusters are no longer well distinguished from each other. That is why it is pivotal to select a number of clusters that groups very similar users together without breaking those groups into too many separate clusters.

One method designed to find a good number of clusters is silhouette analysis. A silhouette analysis essentially plots the average silhouette score across all points for different numbers of clusters. A silhouette score is a value from -1 to 1, where a value near 1 means that objects are very well matched to their own cluster compared to the other clusters. As can be seen below, this can be used to find a good number of clusters.
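
To make the silhouette score a little less abstract: for a single point it is (b - a) / max(a, b), where a is the point's average distance to the other points in its own cluster and b is its average distance to the points in the nearest other cluster. The snippet below (my addition) computes the per-point values for one arbitrary choice of k with silhouette_samples, which we imported at the top; silhouette_score is simply their average.

#Sketch: per-point silhouette values for one arbitrary choice of k
example_k = 3
example_predictions = KMeans(n_clusters=example_k).fit_predict(X)
per_point_scores = silhouette_samples(X, example_predictions)

#The overall silhouette score is just the mean of the per-point values
print(per_point_scores.mean())
print(silhouette_score(X, example_predictions))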

In [27]:
def clustering_errors(k, data):
    kmeans = KMeans(n_clusters=k).fit(data)
    predictions = kmeans.predict(data)
    silhouette_avg = silhouette_score(data, predictions)
    return silhouette_avg
In [28]:
df = chinese_burger_user_preferences[['avg_burgers_ratings','avg_chinese_ratings']]
X = biased_dataset[['avg_burgers_ratings','avg_chinese_ratings']].values
X = pd.DataFrame(X).dropna()

# Choose the range of k values to test.
# We added a stride of 5 to improve performance. We don't need to calculate the error for every k value
possible_k_values = range(2, len(X)+1, 5)

# Calculate error values for all k values we're interested in
errors_per_k = [clustering_errors(k, X) for k in possible_k_values]

# Plot each value of K vs. the silhouette score at that value
fig, ax = plt.subplots(figsize=(16, 6))
plt.plot(possible_k_values, errors_per_k)

# Ticks and grid
xticks = np.arange(min(possible_k_values), max(possible_k_values)+1, 5.0)
ax.set_xticks(xticks, minor=False)
ax.set_xticks(xticks, minor=True)
ax.xaxis.grid(True, which='both')
yticks = np.arange(round(min(errors_per_k), 2), max(errors_per_k), .05)
ax.set_yticks(yticks, minor=False)
ax.set_yticks(yticks, minor=True)
ax.yaxis.grid(True, which='both')
In [42]:
df = indian_chinese_user_preferences[['avg_indian_ratings','avg_chinese_ratings']]
X = highPrefUsers[['avg_indian_ratings', 'avg_chinese_ratings']].values
X = pd.DataFrame(X).dropna()
X
Out[42]:
0 1
0 4.0 1.00
1 5.0 1.00
2 4.0 2.00
3 2.0 3.78
4 4.0 2.75
... ... ...
127 4.0 1.00
128 2.0 5.00
129 4.4 2.00
130 2.5 3.50
131 4.0 2.33

132 rows × 2 columns

In [29]:
# Choose the range of k values to test.
# We added a stride of 5 to improve performance. We don't need to calculate the error for every k value
possible_k_values = range(2, len(X)+1, 5)

# Calculate error values for all k values we're interested in
errors_per_k = [clustering_errors(k, X) for k in possible_k_values]

# Plot each value of K vs. the silhouette score at that value
fig, ax = plt.subplots(figsize=(16, 6))
plt.plot(possible_k_values, errors_per_k)

# Ticks and grid
xticks = np.arange(min(possible_k_values), max(possible_k_values)+1, 5.0)
ax.set_xticks(xticks, minor=False)
ax.set_xticks(xticks, minor=True)
ax.xaxis.grid(True, which='both')
yticks = np.arange(round(min(errors_per_k), 2), max(errors_per_k), .05)
ax.set_yticks(yticks, minor=False)
ax.set_yticks(yticks, minor=True)
ax.yaxis.grid(True, which='both')

As we can see from this analysis, the number of clusters with the highest silhouette score is 62 for the Chinese vs Burgers analysis and 27 for the Indian vs Chinese analysis. What these clusters mean past 3 clusters can be really hard to interpret in this context, but choosing the right number of clusters can be really important depending on what you are trying to do. When showing a heatmap of user ratings for each business, for example, it can make a big difference. An example is shown below:

Business Level Clustering

In [28]:
# Merge the two tables then pivot so we have a Users X Businesses dataframe
business_ratings = pd.merge(ratings, business[['business_id', 'name']], on='business_id' )
user_business_ratings = pd.pivot_table(business_ratings, index='user_id', columns= 'name', values='rating')

print('dataset dimensions: ', user_business_ratings.shape, '\n\nSubset example:')
user_business_ratings
dataset dimensions:  (68907, 5976) 

Subset example:
Out[28]:
name #1 Cochran Buick GMC of Robinson #1 Cochran Hyundai of South Hills #1 Cochran Kia - Robinson #1 Cochran Nissan - Pittsburgh #1 Cochran Nissan Monroeville #1 Cochran Volkswagen of South Hills 1-800-GOT-JUNK? Pittsburgh City 1-800-Haul-Out 105.9 The X 10th Street Tattoo and Body Piercing ... la Cappella lather A Pet Bath House love, Pittsburgh lululemon athletica notion terraFITNESS the pizza company täkō uBreakiFix west elm
user_id
--6CV8BPNofy7jt1JavD-g NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
--7J8nln4XVSBjPGzrLbyA NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
--CH8yRGXhO2MmbF-4BWXg NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
--EMqnd727rtC0G5Oc-Mrg NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
--Hc1I83HDuQWI6VBeSAEA NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
zzanZzhoA3wufshSrE5Q-g NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
zzbCMW84h24boAefr_y6Aw NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
zzdA0w9_bQ1DA-Nm_nFqkw NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
zzeKc7oOhSDSqvznjt6m0A NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
zzx7k7JqCQNhhqL4VJxL0A NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

68907 rows × 5976 columns

In [34]:
n_businesses = 30
n_users = 18
most_rated_businesses_users_selection = helper.sort_by_rating_density(user_business_ratings, n_businesses, n_users)

print('dataset dimensions: ', most_rated_businesses_users_selection.shape)
most_rated_businesses_users_selection.head()
dataset dimensions:  (18, 30)
Out[34]:
name Primanti Bros Meat & Potatoes Gaucho Parrilla Argentina Burgatory täkō Noodlehead Nicky's Thai Kitchen P&G's Pamela's Diner Church Brew Works Butcher and the Rye ... The Porch at Schenley Proper Brick Oven & Tap Room Steel Cactus Point Brugge Café Deluca's Diner Smoke BBQ Taqueria The Commoner Condado Tacos NOLA PNC Park
6318 4.000000 4.666667 5.0 4.0 5.0 3.0 3.75 4.5 2.0 4.0 ... 5.0 4.0 4.0 5.0 4.0 4.666667 4.0 5.0 5.0 5.0
7810 3.666667 5.000000 NaN 4.0 4.0 NaN 5.00 3.0 4.0 4.0 ... 4.0 4.0 NaN 4.0 NaN 5.000000 4.0 NaN 5.0 4.0
2198 3.000000 5.000000 5.0 4.0 4.0 5.0 5.00 4.0 3.0 NaN ... 3.5 4.0 2.0 NaN 5.0 4.000000 NaN NaN NaN 5.0
42926 NaN 4.000000 5.0 4.0 4.0 NaN 3.00 3.0 3.0 4.0 ... 3.0 2.0 3.0 4.0 4.0 4.000000 NaN 4.0 NaN NaN
26596 3.000000 4.000000 4.0 5.0 NaN NaN 4.00 NaN 3.0 NaN ... 4.0 4.0 NaN 4.0 3.0 NaN 4.0 NaN 3.0 5.0

5 rows × 30 columns

In [35]:
helper.draw_movies_heatmap(most_rated_businesses_users_selection)
In [37]:
most_rated_businesses_1k = helper.get_most_rated_movies(user_business_ratings, 1000)
most_rated_businesses_1k
Out[37]:
name Primanti Bros Meat & Potatoes Gaucho Parrilla Argentina Burgatory täkō Noodlehead Nicky's Thai Kitchen P&G's Pamela's Diner Church Brew Works Butcher and the Rye ... Megabus The Warren Bar and Burrow T-Swirl Crepe Sphinx Cafe Pesaro's Pizza Archie's Bar 11 Taza 21 Belvedere's Emiliano's Mexican Restaurant & Bar
0 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
1 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
3 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
4 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
68902 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
68903 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
68904 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
68905 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
68906 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

68907 rows × 1000 columns

In [49]:
def get_most_rated_movies(user_movie_ratings, max_number_of_movies):
    # 1- Count
    user_movie_ratings = user_movie_ratings.append(user_movie_ratings.count(), ignore_index=True)
    # 2- sort
    user_movie_ratings_sorted = user_movie_ratings.sort_values(len(user_movie_ratings)-1, axis=1, ascending=False)
    user_movie_ratings_sorted = user_movie_ratings_sorted.drop(user_movie_ratings_sorted.tail(1).index)
    # 3- slice
    most_rated_movies = user_movie_ratings_sorted.iloc[:, :max_number_of_movies]
    return most_rated_movies

def get_users_who_rate_the_most(most_rated_movies, max_number_of_users):
    # Get most voting users
    # 1- Count
    most_rated_movies['counts'] = pd.Series(most_rated_movies.count(axis=1))
    # 2- Sort
    most_rated_movies_users = most_rated_movies.sort_values('counts', ascending=False)
    # 3- Slice
    most_rated_movies_users_selection = most_rated_movies_users.iloc[:max_number_of_users, :]
    most_rated_movies_users_selection = most_rated_movies_users_selection.drop(['counts'], axis=1)
    
    return most_rated_movies_users_selection

def sort_by_rating_density(user_movie_ratings, n_movies, n_users):
    most_rated_movies = get_most_rated_movies(user_movie_ratings, n_movies)
    most_rated_movies = get_users_who_rate_the_most(most_rated_movies, n_users)
    return most_rated_movies


def draw_movie_clusters(clustered, max_users, max_movies):
    c=1
    for cluster_id in clustered.group.unique():
        # To improve visibility, we're showing at most max_users users and max_movies movies per cluster.
        # You can change these values to see more users & movies per cluster
        d = clustered[clustered.group == cluster_id].drop(['index', 'group'], axis=1)
        n_users_in_cluster = d.shape[0]
        
        d = sort_by_rating_density(d, max_movies, max_users)
        
        d = d.reindex(d.mean().sort_values(ascending=False).index, axis=1)
        d = d.reindex(d.count(axis=1).sort_values(ascending=False).index)
        d = d.iloc[:max_users, :max_movies]
        n_users_in_plot = d.shape[0]
        
        # We're only selecting to show clusters that have more than 9 users, otherwise, they're less interesting
        if len(d) > 9:
            print('cluster # {}'.format(cluster_id))
            print('# of users in cluster: {}.'.format(n_users_in_cluster), '# of users in plot: {}'.format(n_users_in_plot))
            fig = plt.figure(figsize=(15,4))
            ax = plt.gca()

            ax.invert_yaxis()
            ax.xaxis.tick_top()
            labels = d.columns.str[:40]

            ax.set_yticks(np.arange(d.shape[0]) , minor=False)
            ax.set_xticks(np.arange(d.shape[1]) , minor=False)

            ax.set_xticklabels(labels, minor=False)
                        
            ax.get_yaxis().set_visible(False)

            # Heatmap
            heatmap = plt.imshow(d, vmin=0, vmax=5, aspect='auto')

            ax.set_xlabel('movies')
            ax.set_ylabel('User id')

            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)

            # Color bar
            cbar = fig.colorbar(heatmap, ticks=[5, 4, 3, 2, 1, 0], cax=cax)
            cbar.ax.set_yticklabels(['5 stars', '4 stars','3 stars','2 stars','1 stars','0 stars'])

            plt.setp(ax.get_xticklabels(), rotation=90, fontsize=9)
            plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', labelbottom='off', labelleft='off') 
            #print('cluster # {} \n(Showing at most {} users and {} movies)'.format(cluster_id, max_users, max_movies))

            plt.show()


            # Let's only show 5 clusters
            # Remove the next three lines if you want to see all the clusters
            # Contribution welcomed: Pythonic way of achieving this
            # c = c+1
            # if c > 6:
            #    break
In [63]:
sparse_ratings = csr_matrix(pd.SparseDataFrame(most_rated_businesses_1k).to_coo())
sparse_ratings

# 20 clusters
predictions = KMeans(n_clusters=20, algorithm='full').fit_predict(sparse_ratings)

max_users = 70
max_businesses = 50

clustered = pd.concat([most_rated_businesses_1k.reset_index(), pd.DataFrame({'group':predictions})], axis=1)
In [77]:
clustered
Out[77]:
index Primanti Bros Meat & Potatoes Gaucho Parrilla Argentina Burgatory täkō Noodlehead Nicky's Thai Kitchen P&G's Pamela's Diner Church Brew Works ... The Warren Bar and Burrow T-Swirl Crepe Sphinx Cafe Pesaro's Pizza Archie's Bar 11 Taza 21 Belvedere's Emiliano's Mexican Restaurant & Bar group
0 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
1 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
3 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
4 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
68902 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
68903 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
68904 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
68905 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
68906 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

68907 rows × 1002 columns

In [61]:
draw_movie_clusters(clustered, max_users, max_businesses)
cluster # 17
# of users in cluster: 420. # of users in plot: 70
cluster # 0
# of users in cluster: 60639. # of users in plot: 70
cluster # 16
# of users in cluster: 1136. # of users in plot: 70
cluster # 4
# of users in cluster: 1957. # of users in plot: 70
cluster # 14
# of users in cluster: 783. # of users in plot: 70
cluster # 13
# of users in cluster: 797. # of users in plot: 70
cluster # 1
# of users in cluster: 661. # of users in plot: 70
cluster # 6
# of users in cluster: 1111. # of users in plot: 70
cluster # 19
# of users in cluster: 429. # of users in plot: 70
cluster # 3
# of users in cluster: 378. # of users in plot: 70
cluster # 8
# of users in cluster: 343. # of users in plot: 70
cluster # 12
# of users in cluster: 39. # of users in plot: 39
cluster # 9
# of users in cluster: 95. # of users in plot: 70
cluster # 11
# of users in cluster: 99. # of users in plot: 70
cluster # 5
# of users in cluster: 11. # of users in plot: 11
In [51]:
# Pick a cluster ID from the clusters above
cluster_number = 12

# Let's filter to only see the region of the dataset with the most number of values 
n_users = 75
n_businesses = 300
cluster = clustered[clustered.group == cluster_number].drop(['index', 'group'], axis=1)

cluster = sort_by_rating_density(cluster, n_businesses, n_users)
helper.draw_movies_heatmap(cluster, axis_labels=False)
In [52]:
cluster.fillna('').head()
Out[52]:
Noodlehead Gaucho Parrilla Argentina Millie's Homemade Ice Cream Everyday Noodles Smiling Banana Leaf Nicky's Thai Kitchen Meat & Potatoes Choolaah Indian BBQ täkō Condado Tacos ... I Tea Cafe The Green Mango Top Shabu-Shabu & Lounge Pittsburgh Poke Ikea Murray Avenue Grill McFadden's Saloon Eddie Merlot's - Pittsburgh Vue 412 Pirata
325 3.0 5 3 ... 4
436 4.0 3 3 ...
342 5.0 3 ... 4
5 4.0 5 4 ... 4
231 4.0 2 ...

5 rows × 300 columns

In [55]:
# Pick a business from the cluster to see its average rating among the cluster's users, e.g. 'Pirata'
business_name = "Pirata"

cluster[business_name].mean()
Out[55]:
4.0
In [56]:
# The average rating of 20 businesses as rated by the users in the cluster
cluster.mean().head(20)
Out[56]:
Noodlehead                     4.633333
Gaucho Parrilla Argentina      4.769231
Millie's Homemade Ice Cream    4.437500
Everyday Noodles               4.500000
Smiling Banana Leaf            3.676471
Nicky's Thai Kitchen           4.200000
Meat & Potatoes                4.133333
Choolaah Indian BBQ            3.909091
täkō                           4.200000
Condado Tacos                  3.615385
Sienna Mercato                 3.866667
Primanti Bros                  3.000000
The Porch at Schenley          3.545455
DiAnoia's Eatery               4.250000
P&G's Pamela's Diner           3.444444
Bakersfield                    4.428571
Muddy Waters Oyster Bar        4.444444
Smallman Galley                4.000000
Church Brew Works              3.444444
Senyai Thai Kitchen            4.000000
dtype: float64
In [59]:
user_id = 5

# Get all this user's ratings
user_2_ratings  = cluster.loc[user_id, :]

# Which businesses did they not rate? (We don't want to recommend businesses they've already rated)
user_2_unrated_businesses =  user_2_ratings[user_2_ratings.isnull()]

# What are the cluster's average ratings for the businesses this user did not rate?
avg_ratings = pd.concat([user_2_unrated_businesses, cluster.mean()], axis=1, join='inner').loc[:,0]

# Let's sort by rating so the highest rated businesses are presented first
avg_ratings.sort_values(ascending=False)[:20]
Out[59]:
Seviche                               5.0
The Allegheny Wine Mixer              5.0
Adda Coffee & Tea House               5.0
Pho Minh                              5.0
Natural Eyebrows Threading            5.0
Pear and the Pickle, Cafe & Market    5.0
Miss T's Beauty Lounge                5.0
Hidden Harbor                         5.0
Constellation Coffee                  5.0
Pasha Cafe Lounge                     5.0
Acacia                                5.0
Duquesne Incline                      5.0
Bar Marco                             5.0
Kaya                                  5.0
B52 Cafe                              5.0
Arsenal Cider House & Wine Cellar     5.0
Showcase BBQ                          5.0
Randyland                             5.0
Central Diner & Grille                5.0
The Green Mango                       5.0
Name: 0, dtype: float64

Example 2 - Iris Dataset

Another really short but interesting example we can look at is the Iris dataset. This dataset has the sepal length and width and petal length and width of three different species of iris. These measurements are roughly normally distributed within each species, so the Gaussian Mixture Model fits this data well. We also know from the dataset that there are 50 data points for each species, so we can check how close the clusters come to that split, and we will see that the fit improves as the model is given more of the features.
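
Since the Iris dataset comes with the true species labels, we can also quantify how well the clusters line up with them. The snippet below is my addition and uses the adjusted Rand index, where 1.0 is a perfect match and values near 0 mean the clustering is no better than chance.

#Sketch: compare GMM clusters on all four features against the true species labels
from sklearn import datasets
from sklearn.mixture import GaussianMixture
from sklearn.metrics import adjusted_rand_score

iris = datasets.load_iris()
gmm_labels = GaussianMixture(n_components=3, random_state=0).fit_predict(iris.data)
print(adjusted_rand_score(iris.target, gmm_labels))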

In [31]:
from sklearn import datasets
import numpy as np

iris = datasets.load_iris()
iris
Out[31]:
{'data': array([[5.1, 3.5, 1.4, 0.2],
        [4.9, 3. , 1.4, 0.2],
        [4.7, 3.2, 1.3, 0.2],
        [4.6, 3.1, 1.5, 0.2],
        [5. , 3.6, 1.4, 0.2],
        [5.4, 3.9, 1.7, 0.4],
        [4.6, 3.4, 1.4, 0.3],
        [5. , 3.4, 1.5, 0.2],
        [4.4, 2.9, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.1],
        [5.4, 3.7, 1.5, 0.2],
        [4.8, 3.4, 1.6, 0.2],
        [4.8, 3. , 1.4, 0.1],
        [4.3, 3. , 1.1, 0.1],
        [5.8, 4. , 1.2, 0.2],
        [5.7, 4.4, 1.5, 0.4],
        [5.4, 3.9, 1.3, 0.4],
        [5.1, 3.5, 1.4, 0.3],
        [5.7, 3.8, 1.7, 0.3],
        [5.1, 3.8, 1.5, 0.3],
        [5.4, 3.4, 1.7, 0.2],
        [5.1, 3.7, 1.5, 0.4],
        [4.6, 3.6, 1. , 0.2],
        [5.1, 3.3, 1.7, 0.5],
        [4.8, 3.4, 1.9, 0.2],
        [5. , 3. , 1.6, 0.2],
        [5. , 3.4, 1.6, 0.4],
        [5.2, 3.5, 1.5, 0.2],
        [5.2, 3.4, 1.4, 0.2],
        [4.7, 3.2, 1.6, 0.2],
        [4.8, 3.1, 1.6, 0.2],
        [5.4, 3.4, 1.5, 0.4],
        [5.2, 4.1, 1.5, 0.1],
        [5.5, 4.2, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.2],
        [5. , 3.2, 1.2, 0.2],
        [5.5, 3.5, 1.3, 0.2],
        [4.9, 3.6, 1.4, 0.1],
        [4.4, 3. , 1.3, 0.2],
        [5.1, 3.4, 1.5, 0.2],
        [5. , 3.5, 1.3, 0.3],
        [4.5, 2.3, 1.3, 0.3],
        [4.4, 3.2, 1.3, 0.2],
        [5. , 3.5, 1.6, 0.6],
        [5.1, 3.8, 1.9, 0.4],
        [4.8, 3. , 1.4, 0.3],
        [5.1, 3.8, 1.6, 0.2],
        [4.6, 3.2, 1.4, 0.2],
        [5.3, 3.7, 1.5, 0.2],
        [5. , 3.3, 1.4, 0.2],
        [7. , 3.2, 4.7, 1.4],
        [6.4, 3.2, 4.5, 1.5],
        [6.9, 3.1, 4.9, 1.5],
        [5.5, 2.3, 4. , 1.3],
        [6.5, 2.8, 4.6, 1.5],
        [5.7, 2.8, 4.5, 1.3],
        [6.3, 3.3, 4.7, 1.6],
        [4.9, 2.4, 3.3, 1. ],
        [6.6, 2.9, 4.6, 1.3],
        [5.2, 2.7, 3.9, 1.4],
        [5. , 2. , 3.5, 1. ],
        [5.9, 3. , 4.2, 1.5],
        [6. , 2.2, 4. , 1. ],
        [6.1, 2.9, 4.7, 1.4],
        [5.6, 2.9, 3.6, 1.3],
        [6.7, 3.1, 4.4, 1.4],
        [5.6, 3. , 4.5, 1.5],
        [5.8, 2.7, 4.1, 1. ],
        [6.2, 2.2, 4.5, 1.5],
        [5.6, 2.5, 3.9, 1.1],
        [5.9, 3.2, 4.8, 1.8],
        [6.1, 2.8, 4. , 1.3],
        [6.3, 2.5, 4.9, 1.5],
        [6.1, 2.8, 4.7, 1.2],
        [6.4, 2.9, 4.3, 1.3],
        [6.6, 3. , 4.4, 1.4],
        [6.8, 2.8, 4.8, 1.4],
        [6.7, 3. , 5. , 1.7],
        [6. , 2.9, 4.5, 1.5],
        [5.7, 2.6, 3.5, 1. ],
        [5.5, 2.4, 3.8, 1.1],
        [5.5, 2.4, 3.7, 1. ],
        [5.8, 2.7, 3.9, 1.2],
        [6. , 2.7, 5.1, 1.6],
        [5.4, 3. , 4.5, 1.5],
        [6. , 3.4, 4.5, 1.6],
        [6.7, 3.1, 4.7, 1.5],
        [6.3, 2.3, 4.4, 1.3],
        [5.6, 3. , 4.1, 1.3],
        [5.5, 2.5, 4. , 1.3],
        [5.5, 2.6, 4.4, 1.2],
        [6.1, 3. , 4.6, 1.4],
        [5.8, 2.6, 4. , 1.2],
        [5. , 2.3, 3.3, 1. ],
        [5.6, 2.7, 4.2, 1.3],
        [5.7, 3. , 4.2, 1.2],
        [5.7, 2.9, 4.2, 1.3],
        [6.2, 2.9, 4.3, 1.3],
        [5.1, 2.5, 3. , 1.1],
        [5.7, 2.8, 4.1, 1.3],
        [6.3, 3.3, 6. , 2.5],
        [5.8, 2.7, 5.1, 1.9],
        [7.1, 3. , 5.9, 2.1],
        [6.3, 2.9, 5.6, 1.8],
        [6.5, 3. , 5.8, 2.2],
        [7.6, 3. , 6.6, 2.1],
        [4.9, 2.5, 4.5, 1.7],
        [7.3, 2.9, 6.3, 1.8],
        [6.7, 2.5, 5.8, 1.8],
        [7.2, 3.6, 6.1, 2.5],
        [6.5, 3.2, 5.1, 2. ],
        [6.4, 2.7, 5.3, 1.9],
        [6.8, 3. , 5.5, 2.1],
        [5.7, 2.5, 5. , 2. ],
        [5.8, 2.8, 5.1, 2.4],
        [6.4, 3.2, 5.3, 2.3],
        [6.5, 3. , 5.5, 1.8],
        [7.7, 3.8, 6.7, 2.2],
        [7.7, 2.6, 6.9, 2.3],
        [6. , 2.2, 5. , 1.5],
        [6.9, 3.2, 5.7, 2.3],
        [5.6, 2.8, 4.9, 2. ],
        [7.7, 2.8, 6.7, 2. ],
        [6.3, 2.7, 4.9, 1.8],
        [6.7, 3.3, 5.7, 2.1],
        [7.2, 3.2, 6. , 1.8],
        [6.2, 2.8, 4.8, 1.8],
        [6.1, 3. , 4.9, 1.8],
        [6.4, 2.8, 5.6, 2.1],
        [7.2, 3. , 5.8, 1.6],
        [7.4, 2.8, 6.1, 1.9],
        [7.9, 3.8, 6.4, 2. ],
        [6.4, 2.8, 5.6, 2.2],
        [6.3, 2.8, 5.1, 1.5],
        [6.1, 2.6, 5.6, 1.4],
        [7.7, 3. , 6.1, 2.3],
        [6.3, 3.4, 5.6, 2.4],
        [6.4, 3.1, 5.5, 1.8],
        [6. , 3. , 4.8, 1.8],
        [6.9, 3.1, 5.4, 2.1],
        [6.7, 3.1, 5.6, 2.4],
        [6.9, 3.1, 5.1, 2.3],
        [5.8, 2.7, 5.1, 1.9],
        [6.8, 3.2, 5.9, 2.3],
        [6.7, 3.3, 5.7, 2.5],
        [6.7, 3. , 5.2, 2.3],
        [6.3, 2.5, 5. , 1.9],
        [6.5, 3. , 5.2, 2. ],
        [6.2, 3.4, 5.4, 2.3],
        [5.9, 3. , 5.1, 1.8]]),
 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'),
 'DESCR': '.. _iris_dataset:\n\nIris plants dataset\n--------------------\n\n**Data Set Characteristics:**\n\n    :Number of Instances: 150 (50 in each of three classes)\n    :Number of Attributes: 4 numeric, predictive attributes and the class\n    :Attribute Information:\n        - sepal length in cm\n        - sepal width in cm\n        - petal length in cm\n        - petal width in cm\n        - class:\n                - Iris-Setosa\n                - Iris-Versicolour\n                - Iris-Virginica\n                \n    :Summary Statistics:\n\n    ============== ==== ==== ======= ===== ====================\n                    Min  Max   Mean    SD   Class Correlation\n    ============== ==== ==== ======= ===== ====================\n    sepal length:   4.3  7.9   5.84   0.83    0.7826\n    sepal width:    2.0  4.4   3.05   0.43   -0.4194\n    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)\n    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)\n    ============== ==== ==== ======= ===== ====================\n\n    :Missing Attribute Values: None\n    :Class Distribution: 33.3% for each of 3 classes.\n    :Creator: R.A. Fisher\n    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n    :Date: July, 1988\n\nThe famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\nfrom Fisher\'s paper. Note that it\'s the same as in R, but not as in the UCI\nMachine Learning Repository, which has two wrong data points.\n\nThis is perhaps the best known database to be found in the\npattern recognition literature.  Fisher\'s paper is a classic in the field and\nis referenced frequently to this day.  (See Duda & Hart, for example.)  The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant.  One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\n.. topic:: References\n\n   - Fisher, R.A. "The use of multiple measurements in taxonomic problems"\n     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n     Mathematical Statistics" (John Wiley, NY, 1950).\n   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.\n   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\n     Structure and Classification Rule for Recognition in Partially Exposed\n     Environments".  IEEE Transactions on Pattern Analysis and Machine\n     Intelligence, Vol. PAMI-2, No. 1, 67-71.\n   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions\n     on Information Theory, May 1972, 431-433.\n   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II\n     conceptual clustering system finds 3 classes in the data.\n   - Many, many more ...',
 'feature_names': ['sepal length (cm)',
  'sepal width (cm)',
  'petal length (cm)',
  'petal width (cm)'],
 'filename': '/usr/local/lib/python3.7/site-packages/sklearn/datasets/data/iris.csv'}

1) Use GMM to cluster only using sepal length and width

In [32]:
irisData = iris['data']

sepalData = []
for flower in irisData:
    flowerData = []
    flowerData.append(flower[0])
    flowerData.append(flower[1])

    sepalData.append(flowerData)

sepalData = np.asarray(sepalData)

plt.scatter(sepalData[:, 0], sepalData[:, 1], s=40, cmap='viridis');
In [33]:
plt.hist(sepalData[:, 0], bins='auto')
plt.show()
In [34]:
plt.hist(sepalData[:, 1], bins='auto')
plt.show()
In [35]:
from sklearn.mixture import GaussianMixture
gmm = GaussianMixture(n_components=3).fit(sepalData)
labels = gmm.predict(sepalData)
plt.scatter(sepalData[:, 0], sepalData[:, 1], c=labels, s=40, cmap='viridis');
In [36]:
sepalData = pd.DataFrame(sepalData) 
sepalData.rename(columns={0:'sepal_length',
                          1:'sepal_width'}, 
                 inplace=True)

sepalData['cluster'] = labels
sepalData

sepalData.groupby('cluster').count()[['sepal_length']]
Out[36]:
sepal_length
cluster
0 60
1 49
2 41

2) Use GMM to cluster using all data (sepal data and petal data)

In [37]:
from sklearn.mixture import GaussianMixture
gmm = GaussianMixture(n_components=3).fit(irisData)
labels = gmm.predict(irisData)

irisData = pd.DataFrame(irisData) 
irisData.rename(columns={0:'sepal_length',
                          1:'sepal_width',
                        2:'petal_length',
                        3:'petal_width'}, 
                 inplace=True)

irisData['cluster'] = labels
irisData.groupby('cluster').count()[['sepal_length']]
Out[37]:
sepal_length
cluster
0 45
1 50
2 55