可視化(part 2)#

近年、SeabornPlotlyなどmatplotlibの機能を補完するような形態のライブラリ拡張するために広く利用されています。これらのライブラリはデータの可視化をより直感的かつ簡単に行うためのツールを提供しており、matplotlib単体では実現しにくい高度なグラフィックやインタラクティブなプロットを作成することが可能です。

Seaborn#

Seabornmatplotlibの上に構築されているデータ可視化ライブラリで、複雑なグラフを簡単に作成するための高レベルのインターフェースを提供します。

  • 散布図、棒グラフ、ヒートマップ、箱ひげ図、バイオリンプロット、ペアプロットなど、多様なプロットタイプをサポートしています。

  • データフレームとの親和性が高く、データセットの操作やフィルタリング、グループ化などが容易に行えます。

  • 統計的なデータ可視化を簡単に行えるように設計されて、統計的なデータ集約や要約を自動的に行う機能を持っています。

#!pip install seaborn

基本的な使い方#

Seabornmatplotlibを補完するものなので、よくこの2つをセットで使います。

#import scienceplots
import matplotlib.pyplot as plt
from matplotlib import font_manager
#plt.style.use(['science','no-latex'])
# Path to your TTF file
ttf_path = './Noto_Sans_JP/NotoSansJP-VariableFont_wght.ttf'
# Register the font
font_manager.fontManager.addfont(ttf_path)
custom_font = font_manager.FontProperties(fname=ttf_path)
# Set the custom font as default
plt.rcParams['font.family'] = custom_font.get_name()
plt.rcParams['font.family'] = 'Hiragino Sans'
import pandas as pd
import seaborn as sns

df=pd.read_csv("https://raw.githubusercontent.com/lvzeyu/css_tohoku/master/css_tohoku/draft/Data/titanic.csv")
plt.figure(figsize=(6, 4))

ax=sns.histplot(
    data=df,
    x="age",
    kde=True,# カーネル密度推定
    hue="sex",
    multiple="dodge",  # “layer”, “dodge”, “stack”, “fill”
    palette={"male": "blue", "female": "red"},
)

ax.set_xlabel("年齢(歳)",fontsize=14)
ax.set_ylabel("人数",fontsize=14)
ax.legend(title="性別", title_fontsize='13', loc='upper right',labels=['男性', '女性'])

plt.show()
../_images/dac8af1ac6d5e9fbbf5cc3929db0a670c37f25136089eda74686d2aabc17b047.png

カテゴリ別のプロット#

plt.figure(figsize=(6, 4))

sns.boxplot(data=df, x="embarked", y="age", hue="survived",width=.5)

plt.show()
../_images/094b2a10d575a8e99878cb3d1e07d0ef636966ee8b48246fd57ba5771db0edfb.png
plt.figure(figsize=(6, 4))

ax=sns.boxplot(
    data=df, x="age", y="embarked",
    notch=True, showcaps=False,
    flierprops={"marker": "x"},
    boxprops={"facecolor": (.3, .5, .7, .5)},
    medianprops={"color": "black", "linewidth": 2},
)

plt.show()
../_images/83b7da6e6957a0e88b868525eccf6a3abbeff0dbee8a46339cabe35ceed9e003.png
tips = sns.load_dataset("tips")
sns.scatterplot(data=tips, x="total_bill", y="tip", hue="size", size="size",style="time")
<Axes: xlabel='total_bill', ylabel='tip'>
../_images/c535c0ff39cd71219a850148f9226ed6a04229ef86e94da7289d3a85450de13e.png

回帰直線#

tips = sns.load_dataset("tips")
sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips,
           markers=["o", "x"], palette="Set1")
<seaborn.axisgrid.FacetGrid at 0x15f945d60>
../_images/70b43db8576fdfa9eecec2c9a9ec34a6389416a3f77bd5db4c76343b8d33dd9f.png
plt.figure(figsize=(6, 6))

sns.lmplot(x="total_bill", y="tip", hue="smoker",
           col="time", row="sex", data=tips)

plt.show()
<Figure size 600x600 with 0 Axes>
../_images/1ce0e9a6e32af5ba2b7ae4730023c662d7ba71288c22ae306b2dc0142b486974.png
plt.figure(figsize=(6, 4))

ax=sns.jointplot(x="total_bill", y="tip", data=tips, kind="reg")

plt.show()
<Figure size 600x400 with 0 Axes>
../_images/2edd6e438dfeb72c3db558a42f0a60b01f6884fed84573eb65fcd1c9df63431d.png

ヒートマップ#

# 航空機のデータを読み込み

flights_long = sns.load_dataset("flights")

# ピボットを生成

flights = (
    flights_long
    .pivot(index="month", columns="year", values="passengers")
)

# Draw a heatmap with the numeric values in each cell
f, ax = plt.subplots(figsize=(9, 6))
sns.heatmap(flights, annot=True, fmt="d", linewidths=.5, ax=ax)
<Axes: xlabel='year', ylabel='month'>
../_images/6fd86507ca88edf69770ea442f374060caa305375051ff946ff7a42bd45393b4.png

相関関係の可視化#

penguins = sns.load_dataset("penguins")
sns.pairplot(penguins, hue="species")
<seaborn.axisgrid.PairGrid at 0x15fceeea0>
../_images/807884497db4506656aac091a72856b20624ce5662d6e56e2e783b70c2a8c6a7.png

Plotly#

Plotlyは、インタラクティブなグラフを作成するための強力なオープンソースのライブラリです。動的でインタラクティブなグラフを生成できるため、データの視覚化をより深く、豊かに表現することができます。

#!pip install plotly
import plotly.express as px
df = px.data.iris()
fig = px.scatter(df, x="sepal_width", y="sepal_length", color="species",
                 size='petal_length', hover_data=['petal_width'])
fig.show()

Plotlyのモジュール#

Plotlyにはplotly.graph_objectsplotly.expressという2つの主要なモジュールがあります。

  • plotly.graph_objects: より細かい制御やカスタマイズが可能です。グラフの構成要素を個別に設定したり、複雑なグラフを作成する際に使います。

  • plotly.express: 高レベルな関数でグラフを作成できます。より簡単に、少ないコードでシンプルなグラフを作成するのに適しています。

import pandas as pd

df = pd.DataFrame({
  "Fruit": ["Apples", "Oranges", "Bananas", "Apples", "Oranges", "Bananas"],
  "Contestant": ["Alex", "Alex", "Alex", "Jordan", "Jordan", "Jordan"],
  "Number Eaten": [2, 1, 3, 1, 3, 2],
})
import plotly.express as px

fig = px.bar(df, x="Fruit", y="Number Eaten", color="Contestant", barmode="group")
fig.show()
import plotly.graph_objects as go

fig = go.Figure()
for contestant, group in df.groupby("Contestant"):
    fig.add_trace(go.Bar(x=group["Fruit"], y=group["Number Eaten"], name=contestant,
      hovertemplate="Contestant=%s<br>Fruit=%%{x}<br>Number Eaten=%%{y}<extra></extra>"% contestant))
fig.update_layout(legend_title_text = "Contestant")
fig.update_xaxes(title_text="Fruit")
fig.update_yaxes(title_text="Number Eaten")
fig.show()

基本的な使い方#

  • plotly.expressで基本的な図を描画します

  • fig.update_でレイアウトなどを細かく設定します

iris = sns.load_dataset('iris')

fig = px.histogram(iris, x='sepal_length', color='species', 
                           nbins=19, range_x=[4,8], width=600, height=350,
                           opacity=0.4, marginal='box')
# histogram描画時にrange_yを指定すると、marginalのboxplotの描画位置が崩れる
fig.update_layout(barmode='overlay')
fig.update_yaxes(range=[0,20],row=1, col=1)
# htmlで保存、以後は省略
# fig.write_html('histogram_with_boxplot.html', auto_open=False)

補足: SeabornとMatplotlibの違いと併用#

  • Matplotlib

    • 高い柔軟性を持つため、カスタマイズの幅が非常に広いですが、コードがやや複雑になることがあります

  • Seaborn

    • Matplotlibを基盤として構築された高レベルのデータ可視化ライブラリであり、複雑なプロットも簡単に作成できます (基本的なカスタマイズは可能だが、柔軟性はやや限定的)

SeabornとMatplotlibの違い#

データフレームの扱い#

import pandas as pd

df = pd.DataFrame({
    'group': ['A', 'A', 'B', 'B', 'C', 'C'],
    'category': ['X', 'Y', 'X', 'Y', 'X', 'Y'],
    'value': [10, 15, 7, 12, 5, 9]
})
import seaborn as sns

sns.barplot(data=df, x="category", y="value", hue="group")
plt.title("Seaborn: Grouped Barplot by 'group'")
plt.show()
../_images/14d32f3b2ece525508ba52c50bc6be7dee7576f2e3a47992816865fbfab7c923.png
import numpy as np

# uniqueなカテゴリとグループを取得
categories = df['category'].unique()
groups = df['group'].unique()

# 棒の位置調整
x = np.arange(len(categories))
width = 0.25

fig, ax = plt.subplots()

# 各groupごとに描画
for i, group in enumerate(groups):
    values = df[df['group'] == group].sort_values('category')['value']
    ax.bar(x + i * width, values, width=width, label=group)

ax.set_xticks(x + width)
ax.set_xticklabels(categories)
ax.set_title("Matplotlib: Grouped Barplot by 'group'")
ax.legend(title="Group")
plt.show()
../_images/c22066d5982dfc27785b9be60855a038172ce72836938cb1fba5010f8fdb11b3.png

統計的プロット#

tips = sns.load_dataset("tips")
sns.regplot(data=tips, x="total_bill", y="tip")
plt.title("Seaborn")
plt.show()
../_images/56cf427298c0f121f8b196174b319d9fe5737d0c39648969f1dae2db2ee9ea47.png
# x, y の抽出
x = tips["total_bill"].values
y = tips["tip"].values

# 線形回帰の係数(NumPyで最小二乗法)
a, b = np.polyfit(x, y, deg=1)  # y = ax + b

# 散布図と回帰直線の描画
plt.scatter(x, y, label="Data")
plt.plot(x, a * x + b, color="red", label=f"y = {a:.2f}x + {b:.2f}")
plt.title("Matplotlib")
plt.xlabel("Total Bill")
plt.ylabel("Tip")
plt.legend()
plt.show()
../_images/59fe9462bd52f7a3c6e5abc86ebd835cd52e7404ceb4b6fac0f8e8cde56461dd.png

SeabornとMatplotlibの併用#

# データの読み込み
tips = sns.load_dataset("tips")

# Seabornで描画しつつAxesオブジェクトを取得
ax = sns.scatterplot(data=tips, x="total_bill", y="tip")

# Matplotlibで細かく調整
ax.set_title("Tip vs Total Bill", fontsize=16)
ax.axhline(y=5, color='red', linestyle='--', label="Y=5 Line")
ax.annotate("Here", xy=(30, 5), xytext=(25, 7),
            arrowprops=dict(arrowstyle="->", color='gray'))

ax.legend()
plt.show()
../_images/4e2df7122779df68cc276456f05570ce21fd6924eef92d1527cacd4c4992e68d.png
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# 左:箱ひげ図
sns.boxplot(data=tips, x="day", y="tip", ax=axes[0])
axes[0].set_title("Boxplot")

# 右:回帰線付き散布図
sns.regplot(data=tips, x="total_bill", y="tip", ax=axes[1])
axes[1].set_title("Regression Plot")

plt.tight_layout()
plt.show()
../_images/78eacfdc8643f5641268795b7abfb292dfa046ca9f780c4343f995421aadeba9.png