Skip to content

Commit

Permalink
add plot
Browse files Browse the repository at this point in the history
  • Loading branch information
ShomyLiu committed Jul 12, 2018
1 parent 6d39ad5 commit 79478eb
Showing 1 changed file with 111 additions and 0 deletions.
111 changes: 111 additions & 0 deletions plot.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import os\n",
"import matplotlib as mpl\n",
"# mpl.rcParams['figure.figsize'] = (15,15)\n",
"\n",
"mpl.rcParams['figure.figsize'] = (9,6)\n",
"plt.ioff()\n",
"\n",
"color = [ 'grey', 'r', 'b', 'black','teal','cornflowerblue', 'g', 'gray', 'c', 'r','m', 'y', 'k']\n",
"label_font_size = 18\n",
"marker = ['>', 'v', '^', 'o', 's']\n",
"\n",
"d = 'out/'\n",
"def getXY(file_name, n = 3000):\n",
" x, y = [], []\n",
" with open(file_name) as f:\n",
" for i, line in enumerate(f):\n",
" entries = line.split()\n",
" y.append(float(entries[0]))\n",
" x.append(float(entries[1]))\n",
" if i >= n:\n",
" break\n",
" return x, y\n",
"\n",
"def plot_mul(files):\n",
" '''\n",
" 对比不同模型的PR曲线\n",
" 传入列表,元素为文件完整名\n",
" '''\n",
" print d\n",
" for i, f in enumerate(files):\n",
" path = './{}/{}'.format(d,f)\n",
" if not os.path.exists(path):\n",
" print('{} is not exists'.format(f))\n",
" continue\n",
"\n",
" x, y = getXY(path)\n",
" plt.plot(x,y, marker = marker[i%len(marker)], markevery = 100, markersize = 5, color = color[i%len(color)])\n",
" legend = ['_'.join(i.split('_')[:-2]) for i in files]\n",
" plt.legend(legend, prop={'size':12})\n",
" plt.ylim([0.3, 1])\n",
" plt.xlim([0.0, 0.5])\n",
" plt.xlabel('Recall', fontsize=label_font_size)\n",
" plt.ylabel('Precision', fontsize=label_font_size)\n",
" plt.gca().tick_params(labelsize=16)\n",
" plt.grid(linestyle='dashdot')\n",
" plt.show() \n",
" \n",
"\n",
"def plot_one(prefix, flag=True):\n",
" '''\n",
" 绘制同一个模型不同epoch的PR曲线\n",
" 传入模型前缀即可(如: PCNN_ATT_DEF)\n",
" '''\n",
" fid = []\n",
" print d\n",
" for i in range(1, 18):\n",
" if flag:\n",
" path = './{}/{}_{}_PR.txt'.format(d, prefix, i)\n",
" else:\n",
" path = './{}/{}_{}.txt'.format(d, prefix, i)\n",
" if not os.path.exists(path):\n",
" #print path\n",
" continue\n",
" fid.append(i)\n",
" x, y = getXY(path)\n",
" plt.plot(x,y, marker = '>', markevery = 100, markersize = 5, color = color[(i-1)%len(color)])\n",
" \n",
" plt.legend([prefix + str(i) for i in fid],prop={'size':10})\n",
" plt.ylim([0.2, 1])\n",
" plt.xlim([0.0, 0.5])\n",
" plt.xlabel('Recall', fontsize=label_font_size,)\n",
" plt.ylabel('Precision', fontsize=label_font_size)\n",
" plt.gca().tick_params(labelsize=16)\n",
" plt.grid(linestyle='dashdot')\n",
" plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 79478eb

Please sign in to comment.