Source code for simulai.parallel

# (C) Copyright IBM Corp. 2019, 2020, 2021, 2022.

#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at

#           http://www.apache.org/licenses/LICENSE-2.0

#     Unless required by applicable law or agreed to in writing, software
#     distributed under the License is distributed on an "AS IS" BASIS,
#     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#     See the License for the specific language governing permissions and
#     limitations under the License.

try:
    from mpi4py import MPI
except ImportError:
    print("The mpi4py package is not installed; it must be installed and configured in order to use PipelineMPI.")

# Pipeline for executing independent MPI jobs
class PipelineMPI:

    def __init__(self, exec: callable = None, extra_params: dict = None,
                 collect: bool = None, show_log: bool = True) -> None:

        self.exec = exec
        self.show_log = show_log

        if extra_params is not None:
            self.extra_params = extra_params
        else:
            self.extra_params = {}

        self.collect = collect

        self.comm = MPI.COMM_WORLD
        self.n_procs = self.comm.Get_size()

        self.status = (self.n_procs - 1) * [False]
        self.status_dict = dict()

    # Check if the provided datasets are consistent (lists of equal length)
    def _check_kwargs_consistency(self, kwargs: dict = None) -> int:

        types = [type(value) for value in kwargs.values()]
        lengths = [len(value) for value in kwargs.values()]

        assert all([t == list for t in types]), (
            f"All the elements in kwargs must be lists,"
            f" but received {types}."
        )

        assert len(set(lengths)) == 1, (
            f"All the elements in kwargs must have the same length,"
            f" but received {lengths}."
        )

        print("kwargs is alright.")

        return lengths[0]

    # The workload is executed serially within each worker node
    def _split_kwargs(self, kwargs: dict, rank: int, size: int, total_size: int) -> (dict, int):

        size -= 1
        rank -= 1

        batch_size_float = total_size / size

        # If the number of instances cannot be equally distributed among
        # the ranks, redistribute the residual
        if batch_size_float % size != 0:

            res = total_size % size
            batch_size = int((total_size - res) / size)

            if (total_size - res) == (rank + 1) * batch_size:
                append = res
            else:
                append = 0

            kwargs_batch = {
                key: value[rank * batch_size:(rank + 1) * batch_size + append]
                for key, value in kwargs.items()
            }

            batch_size += append

        else:

            batch_size = int(batch_size_float)

            kwargs_batch = {
                key: value[rank * batch_size:(rank + 1) * batch_size]
                for key, value in kwargs.items()
            }

        return kwargs_batch, batch_size

    def _attribute_dict_output(self, dicts: list = None) -> None:

        root = dict()
        for e in dicts:
            root.update(e)

        for key, value in root.items():
            self.status_dict[key] = value
    @staticmethod
    def inner_type(obj: list = None):

        types_list = [type(o) for o in obj]

        assert len(set(types_list)) == 1, "Composed types are not supported."

        return types_list[0]
    def _exec_wrapper(self, kwargs: dict, total_size: int) -> None:

        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        size = comm.Get_size()
        size_ = size

        # Rank 0 is the 'master' node.
        # The worker nodes execute their workload and send a message to
        # the master.
        if rank != 0:

            print(f"Executing rank {rank}.")

            kwargs_batch, batch_size = self._split_kwargs(kwargs, rank, size_, total_size)

            kwargs_batch_list = [
                {key: value[j] for key, value in kwargs_batch.items()}
                for j in range(batch_size)
            ]

            out = list()

            for i in kwargs_batch_list:

                print(f"Executing batch {i['key']} in rank {rank}")

                # Merge the extra parameters into the keyword arguments
                i.update(self.extra_params)

                # Append the result of the operation self.exec to the partial list
                out.append(self.exec(**i))

            if self.collect is True:
                msg = out
            else:
                msg = 1

            if self.show_log:
                print(f"Sending the output {msg} to rank 0")

            comm.send(msg, dest=0)

            print(f"Execution concluded for rank {rank}.")

        # The master awaits the responses of each worker node
        elif rank == 0:

            for r in range(1, size):

                msg = comm.recv(source=r)

                self.status[r - 1] = msg

                # Only collected outputs (lists of dicts) are aggregated into status_dict
                if isinstance(msg, list) and self.inner_type(msg) == dict:
                    self._attribute_dict_output(dicts=msg)

                if self.show_log:
                    print(f"Rank 0 received {msg} from rank {r}")

        comm.barrier()

    @property
    def success(self):
        return all(self.status)
    def run(self, kwargs: dict = None) -> None:

        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()

        total_size = 0

        # Checking if the dataset dimensions are in accordance with the expected ones
        if rank == 0:
            total_size = self._check_kwargs_consistency(kwargs=kwargs)

        total_size = comm.bcast(total_size, root=0)
        comm.barrier()

        # Executing a wrapper containing the parallelized operation
        self._exec_wrapper(kwargs, total_size)

        comm.barrier()
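
A minimal usage sketch follows; it is not part of the module. The `model_run` function, the `key`/`value` argument names, the job count, and the `mpiexec` invocation are illustrative assumptions, chosen so that each job returns a dict that rank 0 can aggregate into `status_dict`.

# Hypothetical usage sketch (not part of simulai.parallel), assumed to be
# launched with an MPI runner, e.g.: mpiexec -n 5 python example.py
from simulai.parallel import PipelineMPI

def model_run(key: str = None, value: float = None, factor: float = 1.0) -> dict:
    # Each independent job returns a dict so that the collected results
    # can be merged into pipeline.status_dict on rank 0.
    return {key: value * factor}

if __name__ == "__main__":

    pipeline = PipelineMPI(exec=model_run, extra_params={"factor": 2.0}, collect=True)

    # One list entry per independent job; all lists must have the same length,
    # and a 'key' entry is expected by the per-batch log messages.
    keys = [f"case_{i}" for i in range(8)]
    values = [float(i) for i in range(8)]

    pipeline.run(kwargs={"key": keys, "value": values})

    # Only rank 0 receives the workers' outputs and fills status_dict.
    if pipeline.comm.Get_rank() == 0 and pipeline.success:
        print(pipeline.status_dict)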