Details | Last modification | View Log | RSS feed
Rev | Author | Line No. | Line |
---|---|---|---|
14 | pmbaty | 1 | import sys |
2 | import multiprocessing |
||
3 | |||
4 | |||
# Shared progress counters. They are None until pmap (or the pool
# initializer _init, inside worker processes) installs multiprocessing.Value
# objects here.
_current = _total = None
||
7 | |||
8 | |||
def _init(current, total):
    """Pool initializer: bind the shared progress counters in this process.

    pmap passes this function as the Pool ``initializer`` with the two shared
    counters as ``initargs``. Each worker therefore stores them in the
    module globals that ``_wrapped_func`` reads.
    """
    global _current, _total
    _current, _total = current, total
||
14 | |||
15 | |||
def _wrapped_func(func_and_args):
    """Run a single pmap task and optionally report progress.

    ``func_and_args`` is a ``(func, argument, should_print_progress, filter_)``
    tuple. ``func(argument, filter_)`` is called and its result is returned.

    When ``should_print_progress`` is true, the shared counter ``_current`` is
    incremented *after* ``func`` returns. A ``'N of M'`` line is then rewritten
    in place on stdout, so N counts finished tasks (as pmap documents), not
    started ones.
    """
    func, argument, should_print_progress, filter_ = func_and_args

    # Run the task first so the progress counter reflects completed work.
    result = func(argument, filter_)

    if should_print_progress:
        # Hold the counter's lock for the increment, the read of the new value
        # and the write. Otherwise concurrent workers could print stale or
        # duplicate counts, or interleave their output.
        with _current.get_lock():
            _current.value += 1
            sys.stdout.write('\r\t{} of {}'.format(_current.value, _total.value))
            sys.stdout.flush()

    return result
||
26 | |||
27 | |||
def pmap(func, iterable, processes, should_print_progress, filter_=None, *args, **kwargs):
    """
    A parallel map function that reports on its progress.

    Applies `func` to every item of `iterable` and returns a list of the
    results in input order; each call is ``func(item, filter_)``. `iterable`
    may be any iterable, including a generator. It is consumed once, up front.

    If `processes` is 1 the calls run serially in this process. Otherwise a
    ``multiprocessing.Pool`` with `processes` workers runs them in parallel.
    In that case `func` and the items must be picklable.

    `should_print_progress` is a boolean value that indicates whether a
    string 'N of M' should be printed to indicate how many of the functions
    have finished being run.

    Extra positional/keyword arguments (e.g. ``chunksize``) are forwarded to
    ``Pool.map``. They have no meaning for the serial path and are ignored
    there.
    """
    global _current
    global _total

    # Materialise the work list first so the total can be taken from it,
    # which also works for iterables without len() (e.g. generators).
    func_and_args = [(func, arg, should_print_progress, filter_) for arg in iterable]
    _current = multiprocessing.Value('i', 0)
    _total = multiprocessing.Value('i', len(func_and_args))

    if processes == 1:
        # Builtin map() accepts no keyword arguments, and extra positionals
        # would be zipped in as additional iterables. *args/**kwargs only make
        # sense for Pool.map, so they are not passed here.
        result = [_wrapped_func(item) for item in func_and_args]
    else:
        pool = multiprocessing.Pool(initializer=_init,
                                    initargs=(_current, _total,),
                                    processes=processes)
        try:
            result = pool.map(_wrapped_func, func_and_args, *args, **kwargs)
        except BaseException:
            # A task raised or the run was interrupted: stop the workers
            # instead of leaving them running.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()

    if should_print_progress:
        # Return the cursor to the start of the progress line.
        sys.stdout.write('\r')
    return result