# Plot the fitted regression lines from ``idata_2`` over the raw
# beauty/eval data: one panel for men, one for women, and one combining both.

# Posterior means of the regression coefficients (indexed by term name).
# NOTE(review): az.summary rounds to 2 decimals by default; pass
# round_to="none" if unrounded coefficients are wanted. Harmless for plotting.
stats_2 = az.summary(idata_2, kind="stats")["mean"]

# Both fitted lines use the same ``beauty`` slope, so they are parallel and
# differ only by the ``female`` offset added to the intercept.
x = np.linspace(beauty["beauty"].min(), beauty["beauty"].max(), 10)
y_pred_male = stats_2["Intercept"] + stats_2["beauty"] * x
y_pred_female = stats_2["Intercept"] + stats_2["female"] + stats_2["beauty"] * x

# Compute each subset once instead of re-querying per panel.
men = beauty.query("female == 0")
women = beauty.query("female == 1")

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 9))

# Men
ax1.scatter("beauty", "eval", data=men, color="grey", s=10)
ax1.plot(x, y_pred_male, color="grey")
ax1.set_title("Men")

# Women
ax2.scatter("beauty", "eval", data=women, color="black", s=10)
ax2.plot(x, y_pred_female, color="black")
ax2.set_title("Women")

# Both sexes: label the series so this panel is readable on its own,
# without cross-referencing the colours used in the panels above.
ax3.scatter("beauty", "eval", data=women, color="black", s=10, label="Women")
ax3.scatter("beauty", "eval", data=men, color="grey", s=10, label="Men")
ax3.plot(x, y_pred_female, color="black")
ax3.plot(x, y_pred_male, color="grey")
ax3.set_title("Both sexes")
ax3.legend()

# Shared axis labels for the three data panels.
for ax in (ax1, ax2, ax3):
    ax.set_xlabel("Beauty")
    ax.set_ylabel("Average teaching evaluation")

# Only three panels are needed; drop the unused fourth axis.
ax4.remove()
plt.tight_layout();