def summary(model: torch.nn.Module) -> tuple[str, int]:
    """Counts the number of parameters in each model layer.

    Builds a ``repr``-like, indented tree of ``model`` in which every module
    line is suffixed with the number of parameters that module and all of
    its submodules hold. Buffers are not counted. A parameter or submodule
    reachable through several attributes is counted once per occurrence
    (no de-duplication is performed).

    :param model: model to summarize
    :return: repr: a multiline string representation of the network
             nparam: number of parameters
    """

    def _summarize(module: torch.nn.Module) -> tuple[str, int]:
        # Mirror torch.nn.Module.__repr__: treat the extra repr like a
        # sub-module, one item per line. An empty string would split into
        # [''], so only split a non-empty one.
        extra_repr = module.extra_repr()
        extra_lines = extra_repr.split("\n") if extra_repr else []

        child_lines = []
        total_params = 0
        for key, child in module._modules.items():
            if child is None:
                # Slots registered with add_module(name, None) hold no
                # module; render them the way the default repr does.
                child_lines.append(f"({key}): None")
                continue
            child_str, child_params = _summarize(child)
            child_lines.append(f"({key}): " + _addindent(child_str, 2))
            total_params += child_params

        # Only this module's own parameters; children were counted above.
        # numel() also handles 0-dim (scalar) parameters, for which a
        # product over an empty shape without an initializer would fail.
        for param in module._parameters.values():
            if param is not None:
                total_params += param.numel()

        lines = extra_lines + child_lines
        main_str = module._get_name() + "("
        if lines:
            if len(extra_lines) == 1 and not child_lines:
                # simple one-liner info, which most builtin Modules will use
                main_str += extra_lines[0]
            else:
                # Two-space indent matches the _addindent(..., 2) applied to
                # nested children, so nested closing parentheses line up.
                main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        main_str += f", {total_params:,} params"
        return main_str, total_params

    return _summarize(model)