슬기로운 에러 생활

[ BERT ] NotImplementedError: Layers with arguments in __init__ must override get_config

vhrehfdl 2020. 5. 26. 22:24

* 에러 메세지 : NotImplementedError: Layers with arguments in `__init__` must override `get_config`

 

* 에러 원인 : callback 함수를 사용해 모델을 저장하려 하는데 위에 에러가 발생하였다.

callback 함수를 사용하지 않고 model.fit을 하는 경우에는 위에 에러가 발생하지 않았다.

 

* 해결 방법 : BERT Layer Class 안에 get_config 함수를 추가해주니 문제가 해결되었다.


class BertLayer(tf.keras.layers.Layer):
    def __init__(self, n_fine_tune_layers=10, **kwargs):
        self.n_fine_tune_layers = n_fine_tune_layers
        self.trainable = True
        self.output_size = 768
        super(BertLayer, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        return config

    def build(self, input_shape):
        self.bert = hub.Module(
            "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1",
            trainable=self.trainable,
            name="{}_module".format(self.name)
        )

해당 문제가 transformer등에서 자주 일어나는 것 같은데...
이 때 아래의 URL에 나온 것처럼 self.model 등의 파라미터를 업데이트 해주면 해결이 된다.

 

* 참고 URL : stackoverflow.com/questions/58678836/notimplementederror-layers-with-arguments-in-init-must-override-get-conf