Plotting with Matplotlib
You can watch and practice these videos
You must import these:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
Matplotlib has two popular APIs, and the more widely used one is the global pyplot API.
Matplotlib's default pyplot API has a global, MATLAB-style interface, as the examples below show:
MATLAB-Style API
Code:
x = np.arange(-10, 11)
plt.figure(figsize=(12, 6))
plt.title('My Nice Plot')
plt.plot(x, x ** 2)
plt.plot(x, -1 * (x ** 2))
Naming the lines: adding a legend.
plt.figure(figsize=(12, 6))
plt.title('My Nice Plot')
plt.plot(x, x ** 2)
plt.plot(x, -1 * (x ** 2))
plt.legend(["plusbar", "minusbar"])
If you want to add another line, try this code now:
plt.figure(figsize=(12, 6))
plt.title('My Nice Plot')
plt.plot(x, x ** 2)
plt.plot(x, -1 * (x ** 2))
plt.plot(-x, x**2)
plt.legend(["plusbar", "minusbar", "-xbar"])
#you will notice the legend entries are in the same order as the plt.plot calls.
Do you want to add one straight line to the existing design?
Try this code:
plt.figure(figsize=(12, 6))
plt.title('My Nice Plot')
plt.plot(x, x ** 2)
plt.plot(x, -1 * (x ** 2))
plt.plot(-x, x**2)
plt.plot([0,5,10,15],[0,0,0,0])
plt.legend(["plusbar", "minusbar", "-xbar", "strightbar"])
Now if you want to label the x and y axes, you can do that using the following code:
plt.figure(figsize=(12, 6))
plt.title('My Nice Plot')
plt.plot(x, x ** 2)
plt.plot(x, -1 * (x ** 2))
plt.plot(-x, x**2)
plt.plot([0,5,10,15],[0,0,0,0])
plt.legend(["plusbar", "minusbar", "-xbar", "strightbar"])
#let's label the axes now.
plt.xlabel("x's value")
plt.ylabel("y's value")
How about making subplots?
A subplot is one of several plots drawn in the same matplotlib figure.
First we plot the first of the two subplots:
plt.figure(figsize=(12, 6))
plt.suptitle('My Nice Plot') #figure-level title; plt.title here would create an extra axes before the subplots exist
#lets make the first subplot
plt.subplot(1,2,1) #1 row, 2 columns and the 1st subplot graph
plt.plot(x, x ** 2)
plt.plot(x, -1 * (x ** 2))
plt.plot(-x, x**2)
plt.plot([0,5,10,15],[0,0,0,0])
plt.legend(["plusbar", "minusbar", "-xbar", "strightbar"])
#lets label the axis now.
plt.xlabel("x's value")
plt.ylabel("y's value")
#we used 1 for the rows because this figure has a single row, with the two subplots as its two columns. (You will see a picture about this soon.)
Now let's plot the second subplot as well:
plt.figure(figsize=(12, 6))
plt.suptitle('My Nice Plot') #figure-level title; plt.title here would create an extra axes before the subplots exist
#lets make the first subplot
plt.subplot(1,2,1) #1 row, 2 columns and the 1st subplot graph
plt.plot(x, x ** 2)
plt.plot(x, -1 * (x ** 2))
plt.plot(-x, x**2)
plt.plot([0,5,10,15],[0,0,0,0])
plt.legend(["plusbar", "minusbar", "-xbar", "strightbar"])
#lets label the axis now.
plt.xlabel("x's value")
plt.ylabel("y's value")
#lets make the second subplot
plt.subplot(1,2,2) #1 row, 2 columns and the 2nd subplot graph
plt.plot(-x, -x ** 2)
plt.plot(x, -1 * (x ** 2))
plt.plot(-x, x**2)
plt.plot([-0,-5,-10,-15],[0,0,0,0])
plt.legend(["plusbar", "minusbar", "-xbar", "strightbar"])
#lets label the axis now.
plt.xlabel("x's value")
plt.ylabel("y's value")
Study the second picture to understand the layout: it has two rows, two columns and four plots. There is a better way to do all of this: the object-oriented (OOP) API.
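Before moving on, the two-row, two-column layout described above can be sketched with the same global API. This is only an illustration: the four curves are made up for the example and are not taken from the original picture.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(-10, 11)

fig = plt.figure(figsize=(12, 6))
# Four hypothetical curves, one per cell of a 2 x 2 grid.
curves = [x ** 2, -(x ** 2), x ** 3, np.abs(x)]
for i, y in enumerate(curves, start=1):
    plt.subplot(2, 2, i)  # 2 rows, 2 columns, i-th cell (counted row by row)
    plt.plot(x, y)
    plt.title(f"subplot {i}")
plt.tight_layout()
```

The third argument of plt.subplot counts cells from 1, left to right and then top to bottom.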
OOP API:
The fig object is the entire figure, the whole drawing area.
An Axes object is the region of the image with the data space. A given figure can contain many Axes, but a given Axes object can only be in one Figure. The Axes contains two (or three in the case of 3D) Axis objects.
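The relationship between these objects can be checked directly. The following is a small sketch (the names left and right are chosen only for this example) that creates one figure with two axes and inspects how they refer to each other.

```python
import matplotlib.pyplot as plt

fig, (left, right) = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))

print(len(fig.axes))       # how many Axes the Figure contains
print(left.figure is fig)  # each Axes points back to its single Figure
# Every 2D Axes owns one x Axis object and one y Axis object.
print(type(left.xaxis).__name__, type(left.yaxis).__name__)
```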
Let's start with a figure and only one axes in the figure:
fig, axes = plt.subplots(figsize=(16,6))
Do you notice the change? Compared with the global API, plt.subplots has an extra s here. The figsize argument sets the size of the whole figure in inches, and the single axes fills that figure. Don't forget, we are creating only one axes here.
Now let's draw four lines using the OOP API:
fig, axes = plt.subplots(figsize=(16,6))
axes.plot(x, x + 0, linestyle='solid')
axes.plot(x, x + 1, linestyle='dashed')
axes.plot(x, x + 2, linestyle='dashdot')
axes.plot(x, x + 3, linestyle='dotted')
axes.set_title("My Nice Plot")
What if I want to label these lines and show the labels in a legend?
fig, axes = plt.subplots(figsize=(16,8))
axes.plot(x, x + 0, linestyle='solid', label = 'firni')
axes.plot(x, x + 1, linestyle='dashed', label = 'jorda')
axes.plot(x, x + 2, linestyle='dashdot', label = 'alu')
axes.plot(x, x + 3, linestyle='dotted', label = 'gajor')
axes.legend() #without this call, the labels are not shown in a legend.
axes.set_title("My Nice Plot")
Okay, now I will show you a way of setting the color and the line style together, without writing linestyle=' ' in the code.
fig, axes = plt.subplots(figsize=(12, 6))
axes.plot(x, x + 0, '-og', label="solid green") #'-' is a solid line, 'o' adds circle markers and 'g' makes it green.
axes.plot(x, x + 1, '--c', label="dashed cyan")
axes.plot(x, x + 2, '-.b', label="dashdot blue")
axes.plot(x, x + 3, ':r', label="dotted red")
axes.set_title("My Nice Plot")
axes.legend()
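To see what a format string such as '-og' stands for, you can ask the resulting line object for its properties. This is a short sketch, not part of the original notes; it also draws a second line with the equivalent keyword arguments for comparison.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(-10, 11)
fig, ax = plt.subplots()

# Format string: linestyle '-', marker 'o', color 'g'.
(line,) = ax.plot(x, x, '-og')
# The same styling spelled out with keyword arguments.
(line_kw,) = ax.plot(x, x + 1, linestyle='-', marker='o', color='g')

print(line.get_linestyle(), line.get_marker(), line.get_color())
print(line_kw.get_linestyle(), line_kw.get_marker(), line_kw.get_color())
```

Both lines report the same line style, marker and color, so the format string is just a shorthand.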
You can also set the color with the color=" " keyword argument:
axes.plot(x, x + 3, color="black", label="solid black")
Another way of writing all the previous code: plt.subplots returns a (fig, axes) tuple, which you can store first and unpack afterwards.
plot_objects = plt.subplots(figsize=(14, 8))
fig, axes = plot_objects
axes.plot(x, x + 0, '-og', label="solid green")
axes.plot(x, x + 1, '--c', label="dashed cyan")
axes.plot(x, x + 2, '-.b', label="dashdot blue")
axes.plot(x, x + 3, color="black", label="solid black")
axes.set_title("My Nice Plot")
axes.legend()
plot_objects
But now, let's make subplots with this OOP API:
plot_objects = plt.subplots(nrows=3, ncols=2, figsize=(14, 8))
fig, ((ax1,ax2), (ax3,ax4), (ax5,ax6)) = plot_objects
# ((ax1,ax2), (ax3,ax4), (ax5,ax6)): one tuple per row, each holding the axes of the two columns.
ax1.plot(x, x + 0, '-og', label="solid green")
ax2.plot(x, x + 1, '--c', label="dashed cyan")
ax3.plot(x, x + 2, '-.b', label="dashdot blue")
ax4.plot(x, x + 3, color="black", label="solid black")
ax5.plot(x, x + 4, color="black", label="solid black")
ax6.plot(x, x + 5, color="black", label="solid black")
ax1.set_title("ax1")
ax2.set_title("ax2")
ax3.set_title("ax3")
ax4.set_title("ax4")
ax5.set_title("ax5")
ax6.set_title("ax6")
ax1.legend()
plot_objects
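When the grid gets larger, unpacking every axes by name becomes tedious. As a sketch of an alternative (not the method used in the notes above), you can keep the axes as the NumPy array that plt.subplots returns and loop over it.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(-10, 11)
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(14, 8))

# axes is a 3 x 2 NumPy array of Axes; .flat walks it row by row.
for i, ax in enumerate(axes.flat):
    ax.plot(x, x + i, label=f"x + {i}")
    ax.set_title(f"ax{i + 1}")
    ax.legend()
fig.tight_layout()
```

You can still reach a single plot by position, for example axes[0, 1] for the first row and second column.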
Subplot2grid: placing subplots on a grid
plt.figure(figsize=(14, 6))
#calling plt.subplot2grid((nrows, ncols), (row, col) of the top-left cell, colspan=?, rowspan=?)
ax1 = plt.subplot2grid((3,3), (0,0), colspan=3)
ax2 = plt.subplot2grid((3,3), (1,0), colspan=2)
ax3 = plt.subplot2grid((3,3), (1,2), rowspan=2)
ax4 = plt.subplot2grid((3,3), (2,0))
ax5 = plt.subplot2grid((3,3), (2,1))
#calling axes.plot(x values, y values, format string or color=?, label=?)
ax1.plot(x, x + 0, '-og', label="ax1 0,0")
ax2.plot(x, x + 1, '--c', label="ax2 1,0")
ax3.plot(x, x + 2, '-.b', label="ax3 1,2 ")
ax4.plot(x, x + 3, color="black", label="ax4 2,0")
ax5.plot(x, x + 4, color="black", label="ax5 2,1")
#setting title for each of the plots.
ax1.set_title("ax1")
ax2.set_title("ax2")
ax3.set_title("ax3")
ax4.set_title("ax4")
ax5.set_title("ax5")
#calling legend labels for each of the plots.
ax1.legend()
ax2.legend()
ax3.legend()
ax4.legend()
ax5.legend()
Let's see how you can make a zigzag line (a sine curve sampled at the integer x values):
fig = plt.figure()
ax = plt.axes()
ax.plot(x, np.sin(x));
Another way of doing it:
n2=np.sin(x)
plt.plot(n2) #with only y values given, the x axis is the index 0, 1, 2, ...
Scatter Plot
Steps to make a scatter plot, one by one:
Let's make a simple scatter plot:
x = np.arange(1,20)
y = np.arange(1,20)
colors = np.arange(1,20)
#make a scatter plot.
plt.figure(figsize=(14, 6))
plt.scatter(x, y, c=colors) #c=colors gives the colorbar below some values to map
plt.colorbar()
plt.show()
You can change the colormap, change the size of the markers, and have subplots as well.
This is how you do it:
#first create the data, then set the colors and the sizes of the scatter dots.
N = 50
x = np.random.rand(N)
y = np.random.rand(N)
colors = np.random.rand(N)
area = np.pi * (20 * np.random.rand(N))**2 # marker areas for radii of 0 to 20 points
#now start with the figure and the plots.
fig = plt.figure(figsize=(14, 6))
ax1 = fig.add_subplot(1,2,1)
plt.scatter(x, y, s=area, c=colors, alpha=0.5, cmap='Pastel1')
plt.colorbar()
ax2 = fig.add_subplot(1,2,2)
plt.scatter(x, y, s=area, c=colors, alpha=0.5, cmap='Pastel2')
plt.colorbar()
plt.show()
Histograms
First of all, we need to know what a histogram is.
A video link: https://www.youtube.com/watch?v=YLPDPglvePY
import matplotlib.pyplot as plt
x = [1,1,2,3,3,5,7,8,9,10,
     10,11,11,13,13,15,16,17,18,18,
     18,19,20,21,21,23,24,24,25,25,
     25,25,26,26,26,27,27,27,27,27,
     29,30,30,31,33,34,34,34,35,36,
     36,37,37,38,38,39,40,41,41,42,
     43,44,45,45,46,47,48,48,49,50,
     51,52,53,54,55,55,56,57,58,60,
     61,63,64,65,66,68,70,71,72,74,
     75,77,81,83,84,87,89,90,90,91
     ]
plt.style.use('ggplot')
plt.hist(x, bins=10)
plt.show()
Using our formulas:
- n = number of observations = 100
- Range = maximum value – minimum value = 91 – 1 = 90
- # of intervals = √n = √100 = 10
- Width of intervals = Range / (# of intervals) = 90/10 = 9
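The formulas above can be checked with a few lines of NumPy. This is a sketch that reuses the same 100 observations; np.histogram is used here only to confirm the bin edges and counts that plt.hist(x, bins=10) would draw.

```python
import numpy as np

x = [1,1,2,3,3,5,7,8,9,10,
     10,11,11,13,13,15,16,17,18,18,
     18,19,20,21,21,23,24,24,25,25,
     25,25,26,26,26,27,27,27,27,27,
     29,30,30,31,33,34,34,34,35,36,
     36,37,37,38,38,39,40,41,41,42,
     43,44,45,45,46,47,48,48,49,50,
     51,52,53,54,55,55,56,57,58,60,
     61,63,64,65,66,68,70,71,72,74,
     75,77,81,83,84,87,89,90,90,91]

n = len(x)                          # number of observations
value_range = max(x) - min(x)       # maximum value minus minimum value
n_bins = int(round(np.sqrt(n)))     # square-root rule for the number of intervals
width = value_range / n_bins        # width of each interval

counts, edges = np.histogram(x, bins=n_bins)
print(n, value_range, n_bins, width)
print(edges)   # bin edges, spaced `width` apart from min(x) to max(x)
print(counts)  # observations per bin
```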
You can also specify the bins in the form of a list of bin edges.
plt.style.use('ggplot')
plt.hist(x, bins=[0,10,20,30,40,50,60,70,80,90,99])
plt.show()
Barplots
Here is how to create a bar plot:
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('ggplot')
x = ['Nuclear', 'Hydro', 'Gas', 'Oil', 'Coal', 'Biofuel']
energy = [5, 6, 15, 22, 24, 8]
x_pos = [i for i, _ in enumerate(x)]
plt.bar(x_pos, energy, color='green')
plt.xlabel("Energy Source")
plt.ylabel("Energy Output (GJ)")
plt.title("Energy output from various fuel sources")
plt.xticks(x_pos, x) #because we are plotting x on the x axis.
plt.show()
#let's say you want to make a horizontal bar plot:
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('ggplot')
x = ['Nuclear', 'Hydro', 'Gas', 'Oil', 'Coal', 'Biofuel']
energy = [5, 6, 15, 22, 24, 8]
x_pos = [i for i, _ in enumerate(x)]
plt.barh(x_pos, energy, color='green') #plt.barh() shows the exact same chart horizontally
plt.xlabel("Energy Output (GJ)") #the values are now on the x axis
plt.ylabel("Energy Source")
plt.title("Energy output from various fuel sources")
plt.yticks(x_pos, x) #the category names now go on the y axis.
plt.show()
Stacked Bar Plotting:
countries = ['USA', 'GB', 'China', 'Russia', 'Germany']
bronzes = np.array([38, 17, 26, 19, 15])
silvers = np.array([37, 23, 18, 18, 10])
golds = np.array([46, 27, 26, 19, 17])
ind = [x for x, _ in enumerate(countries)]
plt.bar(ind, golds, width=0.8, label='golds', color='gold', bottom=silvers+bronzes)
plt.bar(ind, silvers, width=0.8, label='silvers', color='silver', bottom=bronzes)
plt.bar(ind, bronzes, width=0.8, label='bronzes', color='#CD853F')
plt.xticks(ind, countries)
plt.ylabel("Medals")
plt.xlabel("Countries")
plt.legend(loc="upper right")
plt.title("2012 Olympics Top Scorers")
plt.show()
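The key to a stacked bar plot is the bottom argument: each layer starts where the layers below it end. As a sketch of one way to avoid writing the sums by hand (the variable names layers, tops and bottoms are made up for this example), you can compute the starting heights with a cumulative sum.

```python
import numpy as np
import matplotlib.pyplot as plt

countries = ['USA', 'GB', 'China', 'Russia', 'Germany']
bronzes = np.array([38, 17, 26, 19, 15])
silvers = np.array([37, 23, 18, 18, 10])
golds = np.array([46, 27, 26, 19, 17])

layers = np.vstack([bronzes, silvers, golds])  # bottom layer first
tops = np.cumsum(layers, axis=0)               # running total per country
bottoms = tops - layers                        # height at which each layer starts

fig, ax = plt.subplots()
ind = np.arange(len(countries))
styles = [('bronzes', '#CD853F'), ('silvers', 'silver'), ('golds', 'gold')]
for values, bottom, (name, color) in zip(layers, bottoms, styles):
    ax.bar(ind, values, width=0.8, bottom=bottom, label=name, color=color)
ax.set_xticks(ind)
ax.set_xticklabels(countries)
ax.legend(loc="upper right")
```

Here bottoms holds zeros for the bronzes, the bronzes for the silvers, and bronzes plus silvers for the golds, which matches the hand-written version above.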
KDE (kernel density estimation)
Example 1:
from scipy import stats
values = np.random.randn(1000)
density = stats.gaussian_kde(values) #gaussian_kde lives directly in scipy.stats; the stats.kde path is deprecated
density
plt.subplots(figsize=(12, 6))
values2 = np.linspace(min(values)-10, max(values)+10, 100)
plt.plot(values2, density(values2), color='#FF7F00')
plt.fill_between(values2, 0, density(values2), alpha=0.5, color='#FF7F00')
plt.xlim(xmin=-5, xmax=5)
plt.show()
Example 2:
import matplotlib.pyplot as plt
import numpy
from scipy import stats
data = [1.5]*7 + [2.5]*2 + [3.5]*8 + [4.5]*3 + [5.5]*1 + [6.5]*8
density = stats.gaussian_kde(data) #pass in the data you want the KDE of.
x = numpy.arange(0., 8, .1) #x values from 0 up to (but not including) 8.0, in steps of 0.1
plt.plot(x, density(x)) #evaluate the estimated density at each x value.
plt.show()
Sounds too difficult? Let's make it simple.
Let's begin with something a bit easier: take a dataset in a pandas Series and plot it.
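The original example for this step is missing from these notes, so the following is only a sketch of what it might look like. It assumes the same data as Example 2 and uses the pandas plotting method with kind='kde', which relies on SciPy being installed.

```python
import pandas as pd
import matplotlib.pyplot as plt

# Same data as Example 2 above, wrapped in a pandas Series.
data = [1.5]*7 + [2.5]*2 + [3.5]*8 + [4.5]*3 + [5.5]*1 + [6.5]*8
series = pd.Series(data)

fig, ax = plt.subplots(figsize=(8, 4))
series.plot(kind='kde', ax=ax)  # pandas estimates the density and draws it as one line
ax.set_xlabel("value")
```

With this approach you do not have to build the x values or call the density function yourself.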
If you want to read more: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.gaussian_kde.html#scipy.stats.gaussian_kde
If you want to get these done with seaborn, see: visualising with seaborn
Reading from a DataFrame and exporting charts to PDF:
from pandas import DataFrame
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages #calling the pdf functions package
Data1 = {'Unemployment_Rate': [6.1,5.8,5.7,5.7,5.8,5.6,5.5,5.3,5.2,5.2],
'Stock_Index_Price': [1500,1520,1525,1523,1515,1540,1545,1560,1555,1565]
}
df1 = DataFrame(Data1,columns=['Unemployment_Rate','Stock_Index_Price'])
with PdfPages(r'Chartstry.pdf') as export_pdf: #the file name (and optionally the folder) of the PDF
    plt.scatter(df1['Unemployment_Rate'], df1['Stock_Index_Price'], color='green')
    plt.title('Unemployment Rate Vs Stock Index Price', fontsize=10)
    plt.xlabel('Unemployment Rate', fontsize=8)
    plt.ylabel('Stock Index Price', fontsize=8)
    plt.grid(True)
    export_pdf.savefig() #write the current figure as a page of the PDF
    plt.close()
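PdfPages becomes more useful when you want several charts in one file, because every savefig call adds a new page. The following is a minimal sketch with made-up data; the file name charts_sketch.pdf and the temporary folder are assumptions for the example only.

```python
import os
import tempfile
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# Hypothetical output location: a file in the system's temporary folder.
pdf_path = os.path.join(tempfile.gettempdir(), "charts_sketch.pdf")

with PdfPages(pdf_path) as pdf:
    for title, ys in [("Page 1", [1, 2, 3]), ("Page 2", [3, 2, 1])]:
        fig, ax = plt.subplots()
        ax.plot(ys)
        ax.set_title(title)
        pdf.savefig(fig)  # each savefig call appends one page
        plt.close(fig)    # free the figure once it is written
    n_pages = pdf.get_pagecount()

print(n_pages)
```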
Effective ways of using Matplotlib:
Effectively Using Matplotlib
Posted by Chris Moffitt in articles
Introduction
The python visualization world can be a frustrating place for a new user. There are many different options and choosing the right one is a challenge. For example, even after 2 years, this article is one of the top posts that lead people to this site. In that article, I threw some shade at matplotlib and dismissed it during the analysis. However, after using tools such as pandas, scikit-learn, seaborn and the rest of the data science stack in python - I think I was a little premature in dismissing matplotlib. To be honest, I did not quite understand it and how to use it effectively in my workflow.
Now that I have taken the time to learn some of these tools and how to use them with matplotlib, I have started to see matplotlib as an indispensable tool. This post will show how I use matplotlib and provide some recommendations for users getting started or users who have not taken the time to learn matplotlib. I do firmly believe matplotlib is an essential part of the python data science stack and hope this article will help people understand how to use it for their own visualizations.
Why all the negativity towards matplotlib?
In my opinion, there are a couple of reasons why matplotlib is challenging for the new user to learn.
First, matplotlib has two interfaces. The first is based on MATLAB and uses a state-based interface. The second option is an object-oriented interface. The why’s of this dual approach are outside the scope of this post but knowing that there are two approaches is vitally important when plotting with matplotlib.
The reason two interfaces cause confusion is that in the world of stack overflow and tons of information available via google searches, new users will stumble across multiple solutions to problems that look somewhat similar but are not the same. I can speak from experience. Looking back on some of my old code, I can tell that there is a mishmash of matplotlib code - which is confusing to me (even if I wrote it).
Another historic challenge with matplotlib is that some of the default style choices were rather unattractive. In a world where R could generate some really cool plots with ggplot, the matplotlib options tended to look a bit ugly in comparison. The good news is that matplotlib 2.0 has much nicer styling capabilities and ability to theme your visualizations with minimal effort.
The third challenge I see with matplotlib is that there is confusion as to when you should use pure matplotlib to plot something vs. a tool like pandas or seaborn that is built on top of matplotlib. Anytime there can be more than one way to do something, it is challenging for the new or infrequent user to follow the right path. Couple this confusion with the two different API’s and it is a recipe for frustration.
Why stick with matplotlib?
Despite some of these issues, I have come to appreciate matplotlib because it is extremely powerful. The library allows you to create almost any visualization you could imagine. Additionally, there is a rich ecosystem of python tools built around it and many of the more advanced visualization tools use matplotlib as the base library. If you do any work in the python data science stack, you will need to develop some basic familiarity with how to use matplotlib. That is the focus of the rest of this post - developing a basic approach for effectively using matplotlib.
Basic Premises
If you take nothing else away from this post, I recommend the following steps for learning how to use matplotlib:
- Learn the basic matplotlib terminology, specifically what is a Figure and an Axes.
- Always use the object-oriented interface. Get in the habit of using it from the start of your analysis.
- Start your visualizations with basic pandas plotting.
- Use seaborn for the more complex statistical visualizations.
- Use matplotlib to customize the pandas or seaborn visualization.
This graphic from the matplotlib faq is gold. Keep it handy to understand the different terminology of a plot.
Most of the terms are straightforward but the main thing to remember is that the Figure is the final image that may contain 1 or more axes. The Axes represent an individual plot. Once you understand what these are and how to access them through the object oriented API, the rest of the process starts to fall into place.
The other benefit of this knowledge is that you have a starting point when you see things on the web. If you take the time to understand this point, the rest of the matplotlib API will start to make sense. Also, many of the advanced python packages like seaborn and ggplot rely on matplotlib so understanding the basics will make those more powerful frameworks much easier to learn.
Finally, I am not saying that you should avoid the other good options like ggplot (aka ggpy), bokeh, plotly or altair. I just think you’ll need a basic understanding of matplotlib + pandas + seaborn to start. Once you understand the basic visualization stack, you can explore the other options and make informed choices based on your needs.
Getting Started
The rest of this post will be a primer on how to do the basic visualization creation in pandas and customize the most common items using matplotlib. Once you understand the basic process, further customizations are relatively straightforward.
I have focused on the most common plotting tasks I encounter such as labeling axes, adjusting limits, updating plot titles, saving figures and adjusting legends. If you would like to follow along, the notebook includes additional detail that should be helpful.
To get started, I am going to setup my imports and read in some data:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
df = pd.read_excel("https://github.com/chris1610/pbpython/blob/master/data/sample-salesv3.xlsx?raw=true")
df.head()
| | account number | name | sku | quantity | unit price | ext price | date |
|---|---|---|---|---|---|---|---|
| 0 | 740150 | Barton LLC | B1-20000 | 39 | 86.69 | 3380.91 | 2014-01-01 07:21:51 |
| 1 | 714466 | Trantow-Barrows | S2-77896 | -1 | 63.16 | -63.16 | 2014-01-01 10:00:47 |
| 2 | 218895 | Kulas Inc | B1-69924 | 23 | 90.70 | 2086.10 | 2014-01-01 13:24:58 |
| 3 | 307599 | Kassulke, Ondricka and Metz | S1-65481 | 41 | 21.05 | 863.05 | 2014-01-01 15:05:22 |
| 4 | 412290 | Jerde-Hilpert | S2-34077 | 6 | 83.21 | 499.26 | 2014-01-01 23:26:55 |
The data consists of sales transactions for 2014. In order to make this post a little shorter, I’m going to summarize the data so we can see the total number of purchases and total sales for the top 10 customers. I am also going to rename columns for clarity during plots.
# Double brackets select a list of columns; newer pandas rejects the bare 'ext price', 'quantity' form.
top_10 = (df.groupby('name')[['ext price', 'quantity']]
          .agg({'ext price': 'sum', 'quantity': 'count'})
          .sort_values(by='ext price', ascending=False))[:10].reset_index()
top_10.rename(columns={'name': 'Name', 'ext price': 'Sales', 'quantity': 'Purchases'}, inplace=True)
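If you cannot download the spreadsheet, the same summarizing steps can be tried on a tiny hand-made table. This is a sketch with invented rows (df_small and its values are not from the real sales data); it only shows what the groupby, aggregation, sort and rename do.

```python
import pandas as pd

# Hypothetical miniature version of the sales data.
df_small = pd.DataFrame({
    "name": ["A", "B", "A", "B", "A"],
    "ext price": [10.0, 5.0, 20.0, 5.0, 30.0],
    "quantity": [1, 2, 3, 4, 5],
})

summary = (df_small.groupby("name")[["ext price", "quantity"]]
           .agg({"ext price": "sum", "quantity": "count"})  # total sales, number of purchases
           .sort_values(by="ext price", ascending=False)
           .reset_index()
           .rename(columns={"name": "Name", "ext price": "Sales", "quantity": "Purchases"}))
print(summary)
```

Customer A ends up first with three purchases and total sales of 60.0, followed by customer B with two purchases and total sales of 10.0.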
Here is what the data looks like.
| | Name | Purchases | Sales |
|---|---|---|---|
| 0 | Kulas Inc | 94 | 137351.96 |
| 1 | White-Trantow | 86 | 135841.99 |
| 2 | Trantow-Barrows | 94 | 123381.38 |
| 3 | Jerde-Hilpert | 89 | 112591.43 |
| 4 | Fritsch, Russel and Anderson | 81 | 112214.71 |
| 5 | Barton LLC | 82 | 109438.50 |
| 6 | Will LLC | 74 | 104437.60 |
| 7 | Koepp Ltd | 82 | 103660.54 |
| 8 | Frami, Hills and Schmidt | 72 | 103569.59 |
| 9 | Keeling LLC | 74 | 100934.30 |
Now that the data is formatted in a simple table, let’s talk about plotting these results as a bar chart.
As I mentioned earlier, matplotlib has many different styles available for rendering plots. You can see which ones are available on your system using plt.style.available.
plt.style.available
['seaborn-dark', 'seaborn-dark-palette', 'fivethirtyeight', 'seaborn-whitegrid', 'seaborn-darkgrid', 'seaborn', 'bmh', 'classic', 'seaborn-colorblind', 'seaborn-muted', 'seaborn-white', 'seaborn-talk', 'grayscale', 'dark_background', 'seaborn-deep', 'seaborn-bright', 'ggplot', 'seaborn-paper', 'seaborn-notebook', 'seaborn-poster', 'seaborn-ticks', 'seaborn-pastel']
Using a style is as simple as:
plt.style.use('ggplot')
I encourage you to play around with different styles and see which ones you like.
Now that we have a nicer style in place, the first step is to plot the data using the standard pandas plotting function:
top_10.plot(kind='barh', y="Sales", x="Name")
The reason I recommend using pandas plotting first is that it is a quick and easy way to prototype your visualization. Since most people are probably already doing some level of data manipulation/analysis in pandas as a first step, go ahead and use the basic plots to get started.
Customizing the Plot
Assuming you are comfortable with the gist of this plot, the next step is to customize it. Some of the customizations (like adding titles and labels) are very simple to use with the pandas plot function. However, you will probably find yourself needing to move outside of that functionality at some point. That’s why I recommend getting in the habit of doing this:
fig, ax = plt.subplots()
top_10.plot(kind='barh', y="Sales", x="Name", ax=ax)
The resulting plot looks exactly the same as the original but we added an additional call to plt.subplots() and passed the ax to the plotting function. Why should you do this? Remember when I said it is critical to get access to the axes and figures in matplotlib? That’s what we have accomplished here. Any future customization will be done via the ax or fig objects.
We have the benefit of a quick plot from pandas but access to all the power from matplotlib now. An example should show what we can do now. Also, by using this naming convention, it is fairly straightforward to adapt others’ solutions to your unique needs.
Suppose we want to tweak the x limits and change some axis labels? Now that we have the axes in the ax variable, we have a lot of control:
fig, ax = plt.subplots()
top_10.plot(kind='barh', y="Sales", x="Name", ax=ax)
ax.set_xlim([-10000, 140000])
ax.set_xlabel('Total Revenue')
ax.set_ylabel('Customer');
Here’s another shortcut we can use to change the title and both labels:
fig, ax = plt.subplots()
top_10.plot(kind='barh', y="Sales", x="Name", ax=ax)
ax.set_xlim([-10000, 140000])
ax.set(title='2014 Revenue', xlabel='Total Revenue', ylabel='Customer')
To further demonstrate this approach, we can also adjust the size of this image. By using the plt.subplots() function, we can define the figsize in inches. We can also remove the legend using ax.legend().set_visible(False):
fig, ax = plt.subplots(figsize=(5, 6))
top_10.plot(kind='barh', y="Sales", x="Name", ax=ax)
ax.set_xlim([-10000, 140000])
ax.set(title='2014 Revenue', xlabel='Total Revenue')
ax.legend().set_visible(False)
There are plenty of things you probably want to do to clean up this plot. One of the biggest eye sores is the formatting of the Total Revenue numbers. Matplotlib can help us with this through the use of the FuncFormatter. This versatile function can apply a user defined function to a value and return a nicely formatted string to place on the axis.
Here is a currency formatting function to gracefully handle US dollars in the several hundred thousand dollar range:
def currency(x, pos):
    'The two args are the value and tick position'
    if x >= 1000000:
        return '${:1.1f}M'.format(x*1e-6)
    return '${:1.0f}K'.format(x*1e-3)
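To get a feel for what the formatter returns before attaching it to an axis, you can call it by hand. This sketch repeats the function so that it runs on its own; the second argument is the tick position, which the function ignores, so None is passed here.

```python
def currency(x, pos):
    """Format a tick value as dollars in thousands (K) or millions (M)."""
    if x >= 1000000:
        return '${:1.1f}M'.format(x * 1e-6)
    return '${:1.0f}K'.format(x * 1e-3)

print(currency(137351.96, None))  # a value in the hundred-thousand range
print(currency(2500000, None))    # a value above one million
```

The first call prints $137K and the second prints $2.5M.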
Now that we have a formatter function, we need to wrap it in a FuncFormatter and apply it to the x axis. Here is the full code:
fig, ax = plt.subplots()
top_10.plot(kind='barh', y="Sales", x="Name", ax=ax)
ax.set_xlim([-10000, 140000])
ax.set(title='2014 Revenue', xlabel='Total Revenue', ylabel='Customer')
formatter = FuncFormatter(currency)
ax.xaxis.set_major_formatter(formatter)
ax.legend().set_visible(False)
That’s much nicer and shows a good example of the flexibility to define your own solution to the problem.
The final customization feature I will go through is the ability to add annotations to the plot. In order to draw a vertical line, you can use ax.axvline() and to add custom text, you can use ax.text().
For this example, we’ll draw a line showing an average and include labels showing three new customers. Here is the full code with comments to pull it all together.
# Create the figure and the axes
fig, ax = plt.subplots()
# Plot the data and get the average
top_10.plot(kind='barh', y="Sales", x="Name", ax=ax)
avg = top_10['Sales'].mean()
# Set limits and labels
ax.set_xlim([-10000, 140000])
ax.set(title='2014 Revenue', xlabel='Total Revenue', ylabel='Customer')
# Add a line for the average
ax.axvline(x=avg, color='b', label='Average', linestyle='--', linewidth=1)
# Annotate the new customers
for cust in [3, 5, 8]:
    ax.text(115000, cust, "New Customer")
# Format the currency
formatter = FuncFormatter(currency)
ax.xaxis.set_major_formatter(formatter)
# Hide the legend
ax.legend().set_visible(False)
While this may not be the most exciting plot it does show how much power you have when following this approach.
Figures and Plots
Up until now, all the changes we have made have been with the individual plot. Fortunately, we also have the ability to add multiple plots on a figure as well as save the entire figure using various options.
If we decided that we wanted to put two plots on the same figure, we should have a basic understanding of how to do it. First, create the figure, then the axes, then plot it all together. We can accomplish this using plt.subplots():
fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(7, 4))
In this example, I’m using nrows and ncols to specify the size because this is very clear to the new user. In sample code you will frequently just see variables like 1,2. I think using the named parameters is a little easier to interpret later on when you’re looking at your code.
I am also using sharey=True so that the yaxis will share the same labels.
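The effect of sharey=True can be seen in a small sketch with made-up numbers: once the y axis is shared, changing the limits on one axes changes them on the other as well.

```python
import matplotlib.pyplot as plt

fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(7, 4))
ax0.plot([0, 1, 2], [0, 10, 20])
ax1.plot([0, 1, 2], [5, 5, 5])

ax0.set_ylim(0, 50)    # set the limits on the left plot only
print(ax1.get_ylim())  # the right plot follows because the y axis is shared
```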
This example is also kind of nifty because the various axes get unpacked to ax0 and ax1. Now that we have these axes, you can plot them like the examples above but put one plot on ax0 and the other on ax1.
# Get the figure and the axes
fig, (ax0, ax1) = plt.subplots(nrows=1,ncols=2, sharey=True, figsize=(7, 4))
top_10.plot(kind='barh', y="Sales", x="Name", ax=ax0)
ax0.set_xlim([-10000, 140000])
ax0.set(title='Revenue', xlabel='Total Revenue', ylabel='Customers')
# Plot the average as a vertical line
avg = top_10['Sales'].mean()
ax0.axvline(x=avg, color='b', label='Average', linestyle='--', linewidth=1)
# Repeat for the unit plot
top_10.plot(kind='barh', y="Purchases", x="Name", ax=ax1)
avg = top_10['Purchases'].mean()
ax1.set(title='Units', xlabel='Total Units', ylabel='')
ax1.axvline(x=avg, color='b', label='Average', linestyle='--', linewidth=1)
# Title the figure
fig.suptitle('2014 Sales Analysis', fontsize=14, fontweight='bold');
# Hide the legends
ax1.legend().set_visible(False)
ax0.legend().set_visible(False)
Up until now, I have been relying on the jupyter notebook to display the figures by virtue of the %matplotlib inline directive. However, there are going to be plenty of times where you have the need to save a figure in a specific format and integrate it with some other presentation.
Matplotlib supports many different formats for saving files. You can use fig.canvas.get_supported_filetypes() to see what your system supports:
fig.canvas.get_supported_filetypes()
{'eps': 'Encapsulated Postscript', 'jpeg': 'Joint Photographic Experts Group', 'jpg': 'Joint Photographic Experts Group', 'pdf': 'Portable Document Format', 'pgf': 'PGF code for LaTeX', 'png': 'Portable Network Graphics', 'ps': 'Postscript', 'raw': 'Raw RGBA bitmap', 'rgba': 'Raw RGBA bitmap', 'svg': 'Scalable Vector Graphics', 'svgz': 'Scalable Vector Graphics', 'tif': 'Tagged Image File Format', 'tiff': 'Tagged Image File Format'}
Since we have the fig object, we can save the figure using multiple options:
fig.savefig('sales.png', transparent=False, dpi=80, bbox_inches="tight")
This version saves the plot as a png with opaque background. I have also specified the dpi and bbox_inches="tight" in order to minimize excess white space.
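Because savefig picks the output format from the file extension, the same figure can be written in several formats with a short loop. The sketch below uses a made-up plot and a temporary folder; the file name sales_sketch is an assumption for the example.

```python
import os
import tempfile
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [2, 4, 8])
ax.set_title("Example figure")

out_dir = tempfile.mkdtemp()  # hypothetical output folder
saved = []
for ext in ("png", "pdf", "svg"):
    path = os.path.join(out_dir, f"sales_sketch.{ext}")
    # The format is inferred from the extension of the file name.
    fig.savefig(path, dpi=80, bbox_inches="tight")
    saved.append(path)
print(saved)
```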
Conclusion
Hopefully, this process has helped you understand how to more effectively use matplotlib in your daily data analysis. If you get in the habit of using this approach when doing your analysis, you should be able to quickly find out how to do whatever you need to do to customize your plot.
As a final bonus, I am including a quick guide to unify all the concepts. I hope this helps bring this post together and proves a handy reference for future use.