matplotlib: How to solve the Python multiprocessing matplotlib savefig() problem

sky-heaven, 2024-04-12 10:01:50

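I want to speed up matplotlib.savefig() for many figures by using the multiprocessing module, and to benchmark the performance of saving in parallel versus in sequence.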

Here is the code:

# -*- coding: utf-8 -*- 
""" 
Compare the time of matplotlib savefig() in parallel and sequence 
""" 
 
import numpy as np 
import matplotlib.pyplot as plt 
import multiprocessing 
import time 
 
 
def gen_fig_list(n): 
    ''' generate a list to contain n demo scatter figure object ''' 
    plt.ioff() 
    fig_list = [] 
    for i in range(n): 
        plt.figure(); 
        dt = np.random.randn(5, 4); 
        fig = plt.scatter(dt[:,0], dt[:,1], s=abs(dt[:,2]*1000), c=abs(dt[:,3]*100)).get_figure() 
        fig.FM_figname = "img"+str(i) 
        fig_list.append(fig) 
    plt.ion() 
    return fig_list 
 
 
def savefig_worker(fig, img_type, folder): 
    file_name = folder+"\\"+fig.FM_figname+"."+img_type 
    fig.savefig(file_name, format=img_type, dpi=fig.dpi) 
    return file_name 
 
 
def parallel_savefig(fig_list, folder): 
    proclist = [] 
    for fig in fig_list: 
        print fig.FM_figname, 
        p = multiprocessing.Process(target=savefig_worker, args=(fig, 'png', folder)) # cause error 
        proclist.append(p) 
        p.start() 
 
    for i in proclist: 
        i.join() 
 
 
 
if __name__ == '__main__': 
    folder_1, folder_2 = 'Z:\\A1', 'Z:\\A2' 
    fig_list = gen_fig_list(10) 
 
    t1 = time.time() 
    parallel_savefig(fig_list,folder_1) 
    t2 = time.time() 
    print '\nMultiprocessing time    : %0.3f' % (t2 - t1)
 
    t3 = time.time() 
    for fig in fig_list: 
        savefig_worker(fig, 'png', folder_2) 
    t4 = time.time() 
    print 'Non_Multiprocessing time: %0.3f' % (t4 - t3)

I get the error "This application has requested the Runtime to terminate it in an unusual way. Please contact the application's support team for more information.", which is caused by p = multiprocessing.Process(target=savefig_worker, args=(fig, 'png', folder)).

Why? And how can I fix it?

(Windows XP + Python:2.6.1 + Numpy:1.6.2 + Matplotlib:1.2.0)

Edit: (added the error message from Python 2.7.3)

When run in IDLE under Python 2.7.3, it gives the following error message:
>>>  
img0 
 
Traceback (most recent call last): 
  File "C:\Documents and Settings\Administrator\desktop\mulsavefig_pilot.py", line 61, in <module> 
    proc.start() 
  File "d:\Python27\lib\multiprocessing\process.py", line 130, in start 
 
  File "d:\Python27\lib\pickle.py", line 286, in save 
    f(self, obj) # Call unbound method with explicit self 
  File "d:\Python27\lib\pickle.py", line 748, in save_global 
    (obj, module, name)) 
PicklingError: Can't pickle <function notify_axes_change at 0x029F5030>: it's not found as matplotlib.backends.backend_qt4.notify_axes_change 
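The Python 2.7.3 traceback suggests the cause. On Windows there is no fork(), so multiprocessing pickles each Process's target and arguments to send them to a freshly started interpreter. A pyplot Figure created under the Qt4 backend keeps a reference to notify_axes_change, and the message shows pickle trying to find that function as a module-level name in matplotlib.backends.backend_qt4 and failing, which is what happens with a function defined inside another function. Below is a minimal, standard-library sketch of checking arguments before handing them to a process. The helper is_picklable and the nested make_callback are hypothetical stand-ins for illustration, not matplotlib code.

```python
import math
import pickle

import numpy as np


def make_callback():
    # A function defined inside another function, similar in shape to a
    # GUI backend's per-window observer callback (illustrative stand-in only)
    def notify_axes_change(fig):
        return fig
    return notify_axes_change


def is_picklable(obj):
    """Return True if obj can be pickled, as Process/Pool arguments must be on Windows."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


if __name__ == "__main__":
    print(is_picklable(np.random.randn(5, 4)))  # raw plot data: picklable
    print(is_picklable(math.sqrt))              # module-level function: pickled by name
    print(is_picklable(make_callback()))        # nested function: no importable name
```

This is why passing raw data (arrays, numbers, file names) to the workers and building the figure inside each worker avoids the error, while passing Figure objects does not.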

Edit: (a demo of my solution)

Inspired by "Matplotlib: simultaneous plotting in multiple threads"
# -*- coding: utf-8 -*- 
""" 
Compare the time of matplotlib savefig() in parallel and sequence 
""" 
 
import numpy as np 
import matplotlib.pyplot as plt 
import multiprocessing 
import time 
 
 
def gen_data(fig_qty, bubble_qty): 
    ''' generate data for fig drawing ''' 
    dt = np.random.randn(fig_qty, bubble_qty, 4) 
    return dt 
 
 
def parallel_savefig(draw_data, folder): 
    ''' prepare data and pass to worker ''' 
 
    pool = multiprocessing.Pool() 
 
    fig_qty = len(draw_data) 
    fig_para = zip(range(fig_qty), draw_data, [folder]*fig_qty) 
 
    pool.map(fig_draw_save_worker, fig_para) 
    return None 
 
 
def fig_draw_save_worker(args): 
    seq, dt, folder = args 
    plt.figure() 
    fig = plt.scatter(dt[:,0], dt[:,1], s=abs(dt[:,2]*1000), c=abs(dt[:,3]*100), alpha=0.7).get_figure() 
    plt.title('Plot of a scatter of %i' % seq) 
    fig.savefig(folder+"\\"+'fig_%02i.png' % seq) 
    plt.close() 
    return None 
 
 
if __name__ == '__main__': 
    folder_1, folder_2 = 'A1', 'A2' 
    fig_qty, bubble_qty =  500, 100 
    draw_data = gen_data(fig_qty, bubble_qty) 
 
    print 'Multiprocessing  ...   ',
    t1 = time.time() 
    parallel_savefig(draw_data, folder_1) 
    t2 = time.time() 
    print 'Time : %0.3f'%((t2-t1)) 
 
    print 'Non_Multiprocessing .. ',
    t3 = time.time() 
    for para in zip(range(fig_qty), draw_data, [folder_2]*fig_qty): 
        fig_draw_save_worker(para) 
    t4 = time.time() 
    print 'Time : %0.3f'%((t4-t3)) 
 
    print 'Speed Up: %0.1fx'%(((t4-t3)/(t2-t1))) 
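For readers on Python 3, here is a sketch of the same solution pattern: the parent only generates NumPy data, and each worker builds, saves, and closes its own figure. It assumes Python 3.3+ (for Pool as a context manager) and a recent Matplotlib. It also makes three changes that are assumptions rather than part of the original demo: it forces the non-interactive Agg backend, replaces folder+"\\" with os.path.join, and writes into temporary directories. The figure count of 50 is arbitrary, and any speed-up depends on the machine.

```python
import os
import tempfile
import time
from multiprocessing import Pool

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no GUI toolkit in any process
import matplotlib.pyplot as plt
import numpy as np


def fig_draw_save_worker(args):
    """Draw one scatter plot from raw data and save it to folder."""
    seq, dt, folder = args
    fig, ax = plt.subplots()
    ax.scatter(dt[:, 0], dt[:, 1], s=np.abs(dt[:, 2] * 1000),
               c=np.abs(dt[:, 3] * 100), alpha=0.7)
    ax.set_title("Plot of a scatter of %i" % seq)
    path = os.path.join(folder, "fig_%02i.png" % seq)  # portable, unlike folder + "\\"
    fig.savefig(path)
    plt.close(fig)  # pool workers are reused, so release each figure explicitly
    return path


def parallel_savefig(draw_data, folder, processes=None):
    """Only (int, ndarray, str) tuples are pickled and sent to the workers."""
    tasks = [(i, dt, folder) for i, dt in enumerate(draw_data)]
    with Pool(processes) as pool:
        return pool.map(fig_draw_save_worker, tasks)


def serial_savefig(draw_data, folder):
    return [fig_draw_save_worker((i, dt, folder)) for i, dt in enumerate(draw_data)]


if __name__ == "__main__":
    draw_data = np.random.randn(50, 100, 4)  # 50 figures x 100 bubbles x 4 columns
    with tempfile.TemporaryDirectory() as dir_1, tempfile.TemporaryDirectory() as dir_2:
        t1 = time.time()
        parallel_savefig(draw_data, dir_1)
        t2 = time.time()
        serial_savefig(draw_data, dir_2)
        t3 = time.time()
    print("Multiprocessing     : %0.3f s" % (t2 - t1))
    print("Non-multiprocessing : %0.3f s" % (t3 - t2))
    print("Speed up: %0.1fx" % ((t3 - t2) / (t2 - t1)))
```

The if __name__ == "__main__" guard matters on Windows, where each worker re-imports the script. Without the guard, every child would try to start its own pool.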

Please refer to the following approach:

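You can try moving all of the matplotlib code (including the imports) into a function.

  • Make sure there is no import matplotlib or import matplotlib.pyplot as plt at the top of your code.
  • Create a function that does all of the matplotlib work, including the imports.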

  • Example:
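    import numpy as np
    from collections import namedtuple
    from multiprocessing import Pool

    # Hypothetical job container: any picklable object with .data and .filename works
    GraphData = namedtuple('GraphData', ['data', 'filename'])

    def graphing_function(graph_data):
        # Import matplotlib inside the worker so the parent never loads it
        import matplotlib
        matplotlib.use('Agg')  # non-interactive backend; workers need no GUI
        import matplotlib.pyplot as plt
        plt.figure()
        plt.hist(graph_data.data)
        plt.savefig(graph_data.filename)
        plt.close()

    if __name__ == '__main__':  # required on Windows, where children re-import this module
        data_list = [GraphData(np.random.randn(100), 'hist_%02i.png' % i) for i in range(8)]
        pool = Pool(4)
        pool.map(graphing_function, data_list)
        pool.close()
        pool.join()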
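On Linux or macOS the original failure may not reproduce, because the default start method there can differ from Windows. One way to check the advice on any OS is to force the "spawn" start method, which launches each worker as a fresh interpreter the way Windows always does. This is a minimal sketch. The function run_spawned and the (data, filename) job tuples are hypothetical. The final print only reports whether the parent process has matplotlib loaded, which is expected to be False when every matplotlib import lives inside the worker.

```python
import multiprocessing as mp
import os
import sys
import tempfile

import numpy as np


def graphing_function(job):
    """Worker: imports matplotlib lazily, so only child processes load it."""
    data, filename = job
    import matplotlib
    matplotlib.use("Agg")  # no GUI backend in worker processes
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.hist(data)
    fig.savefig(filename)
    plt.close(fig)
    return filename


def run_spawned(jobs, processes=2):
    # "spawn" starts fresh interpreters, as on Windows, so every job
    # (here an ndarray plus a str) must be picklable
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes) as pool:
        return pool.map(graphing_function, jobs)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as folder:
        jobs = [(np.random.randn(200), os.path.join(folder, "hist_%02i.png" % i))
                for i in range(4)]
        saved = run_spawned(jobs)
        print(all(os.path.isfile(f) for f in saved))
    print("matplotlib" in sys.modules)  # parent-side check for a lazy import
```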
    

