KalelPark's LAB

[ Pytorch ] What are register_buffer and register_parameter?


kalelpark 2022. 12. 29. 15:38

torch.nn.Module.register_buffer

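         - Registers a buffer, not a parameter, on the module.

         - A buffer is saved in state_dict (unless it is registered with persistent=False), but it is not returned by parameters() and is therefore not updated by the optimizer. In other words, it is a plain tensor that belongs to the module's state, not a submodule.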

def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
    ...  # body omitted; validates name and tensor, then stores the tensor in self._buffers
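To make the buffer behavior concrete, here is a minimal sketch. The `RunningMean` class and its buffer names are hypothetical, made up for illustration; only `register_buffer`, `state_dict`, `parameters`, and `named_buffers` come from PyTorch itself.

```python
import torch
import torch.nn as nn


class RunningMean(nn.Module):
    """Hypothetical module that tracks a running mean of its inputs."""

    def __init__(self, num_features: int):
        super().__init__()
        # Persistent buffer (default): saved in state_dict, never returned by parameters().
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Non-persistent buffer: still part of the module, but left out of state_dict.
        self.register_buffer("num_batches", torch.tensor(0), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            self.num_batches += 1
            # Incremental update of the mean over the batches seen so far.
            self.running_mean += (x.mean(dim=0) - self.running_mean) / self.num_batches
        return x - self.running_mean


module = RunningMean(3)
module(torch.ones(4, 3))

print(list(module.state_dict().keys()))              # ['running_mean']
print(len(list(module.parameters())))                # 0
print([name for name, _ in module.named_buffers()])  # ['running_mean', 'num_batches']
```

Because both tensors are buffers, calls such as `module.to(device)` move them together with the module, yet an optimizer built from `module.parameters()` never sees them.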

torch.nn.Module.register_parameter

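         - Adds a parameter to the module under the given name.

         - Unlike a buffer added with register_buffer, a registered parameter is returned by parameters(), so it can be passed to an optimizer and updated during training.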
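As a rough illustration of the difference, the sketch below registers a parameter and lets an optimizer update it. The `Scale` class, the learning rate, and the input values are assumptions chosen for the example, not something from the PyTorch source.

```python
import torch
import torch.nn as nn


class Scale(nn.Module):
    """Hypothetical module that multiplies its input by a learnable weight."""

    def __init__(self, num_features: int):
        super().__init__()
        # Explicit registration; has the same effect as `self.weight = nn.Parameter(...)`.
        self.register_parameter("weight", nn.Parameter(torch.ones(num_features)))
        # Registering None reserves the name; it is skipped by parameters() and state_dict().
        self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight


module = Scale(3)
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)

# One training step: the gradient of the sum w.r.t. each weight is 2 (two rows of ones).
loss = module(torch.ones(2, 3)).sum()
loss.backward()
optimizer.step()

print([name for name, _ in module.named_parameters()])  # ['weight']
print(module.bias)                                      # None
print(module.weight)                                    # each entry is now about 0.8
```

In everyday code, assigning an `nn.Parameter` as an attribute is the more common spelling; `register_parameter` is mainly useful when the name is built dynamically or when a slot should be set to `None`.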

def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        r"""Adds a parameter to the module.

        The parameter can be accessed as an attribute using given name.

        Args:
            name (str): name of the parameter. The parameter can be accessed
                from this module using the given name
            param (Parameter or None): parameter to be added to the module. If
                ``None``, then operations that run on parameters, such as :attr:`cuda`,
                are ignored. If ``None``, the parameter is **not** included in the
                module's :attr:`state_dict`.
        """
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")

        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("parameter name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("parameter name can't contain \".\"")
        elif name == '':
            raise KeyError("parameter name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError("attribute '{}' already exists".format(name))

        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError("cannot assign '{}' object to parameter '{}' "
                            "(torch.nn.Parameter or None required)"
                            .format(torch.typename(param), name))
        elif param.grad_fn:
            raise ValueError(
                "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
                "parameters must be created explicitly. To express '{0}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method.".format(name))
        else:
            self._parameters[name] = param
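The checks in the listing above can be exercised directly. This is a small sketch that triggers a few of them on a bare `nn.Module`; the `errors` dictionary and its keys are hypothetical names used only to collect the results.

```python
import torch
import torch.nn as nn

module = nn.Module()
errors = {}

# A plain tensor is not an nn.Parameter, so registration raises TypeError.
try:
    module.register_parameter("weight", torch.ones(3))
except TypeError as exc:
    errors["plain_tensor"] = type(exc).__name__

# Names containing "." are rejected with KeyError.
try:
    module.register_parameter("layer.weight", nn.Parameter(torch.ones(3)))
except KeyError as exc:
    errors["dotted_name"] = type(exc).__name__

# The empty string is rejected with KeyError as well.
try:
    module.register_parameter("", nn.Parameter(torch.ones(3)))
except KeyError as exc:
    errors["empty_name"] = type(exc).__name__

# A name already used by a non-parameter attribute (here a buffer) is rejected.
module.register_buffer("scale", torch.ones(3))
try:
    module.register_parameter("scale", nn.Parameter(torch.ones(3)))
except KeyError as exc:
    errors["name_clash"] = type(exc).__name__

print(errors)
```

The non-leaf check (`param.grad_fn`) is not shown here, because an `nn.Parameter` constructed in the usual way is a leaf tensor.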

 

Reference

https://pytorch.org/docs/stable/generated/torch.nn.Module.html

 

Module — PyTorch 1.13 documentation


https://aigong.tistory.com/429

 

[Pytorch] Differences between nn.Module.register_buffer, nn.Module.register_parameter, and nn.Parameters
