diff --git a/.gitignore b/.gitignore index 9081237..06fed04 100644 --- a/.gitignore +++ b/.gitignore @@ -134,4 +134,9 @@ dmypy.json *.bin *.idx -*.pt \ No newline at end of file +*.pt + +data +data_raw +results +pretrain_data \ No newline at end of file diff --git a/README-ZH.md b/README-ZH.md new file mode 100644 index 0000000..3ce63b7 --- /dev/null +++ b/README-ZH.md @@ -0,0 +1,60 @@ +
+ +

CPM-Live

+ +**直播训练开源大模型** + +

+ 官方网站 • 计划书 • 讨论区 • English +
+

+ +
+ +## 动态 +- 2023/05/27 [CPM-Bee](https://github.com/OpenBMB/CPM-Bee) 发布了! +- 2023/04/12 CPM-Ant 可以在[HuggingFace Transformers](https://huggingface.co/openbmb/cpm-ant-10b)中使用了! +- 2022/10/12 中英双语模型 [CPM-Ant+](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant-plus/cpm-live) 已经发布!除了能够生成中文/英文文本,现在模型还可以处理问答、摘要和翻译任务! +- 2022/09/16 [CPM-Ant](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant/cpm-live) 已经发布! +- 2022/05/29 CPM-Live的训练今天启动! 详情请查看[训练动态](https://live.openbmb.org/home)。 +- 2022/05/25 CPM-Live的[训练计划](./plans/CPM-Live训练计划书.md)现已公布。期待训练开始! + + +## 里程碑 +- **CPM-Bee** (2022/10/13-2023/05/27) [[代码](https://github.com/OpenBMB/CPM-Bee)][[模型](https://github.com/OpenBMB/CPM-Bee#%E6%A8%A1%E5%9E%8B)][[计划书](./plans/CPM-Bee训练计划书.md)] +- **CPM-Ant+** (2022/08/05-2022/10/12) [[代码](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant-plus/cpm-live)][[模型](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant-plus/cpm-live#model-checkpoints)] +- **CPM-Ant** (2022/05/29-2022/08/05) [[代码](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant/cpm-live)][[模型](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant/cpm-live#model-checkpoints)][[网站](https://live.openbmb.org/ant)][[博客](https://www.openbmb.org/en/community/blogs/blogpage?id=98afef2ce45f4fe9a4bc15a66d7ccb92)][[计划书](./plans/CPM-Ant训练计划书.md)] + +## 训练计划 +考虑到数据和计算资源的规模,CPM-Live将从10B模型开始训练并持续学习。 + +### 在训练过程中,我们将进行: + +- **实时**:显示模型训练指标 +- **每天**:发布模型训练日志 +- **每周**:处理社区的讨论和反馈 +- **不定期**:在模型训练期间发布允许公开下载的检查点 + + +### 在训练期间你可以: + +- **提出你的模型倡议**:对模型架构、训练方法或数据源有好的想法?你可以在社区里提出你的模型倡议。如果该倡议得到更多的支持并且实际可行,我们将把它添加到我们正在训练的模型中,这样CPM-Live就可以在大家的帮助下不断学习和进步。 + +- **开发你的应用程序**:基于CPM-Live,你可以向社区提交你初期想法、原型、开发代码或完成的应用程序。我们将在网站上展示最受欢迎的应用程序。 + +- **在论坛上聊天**:你可以在我们的论坛上谈论任何与大模型有关的话题,如学术研究、工程实现、工具使用、应用设计等。无论你是否有经验,我们相信每个人都可以从积极和开放的讨论中受益。 + +- **下载资源**:模型训练完成后,你可以在开放使用许可下自由下载模型参数。CPM-Live使用的是包括商业化许可的开放许可。通过模型压缩和推理加速工具,你可以在自己的电脑上体验大模型的威力! + + + +## 社区 + +我们的[社区](https://github.com/OpenBMB/CPM-Live/discussions) 基于GitHub Discussions。 + +阅读[第一篇帖子](https://github.com/OpenBMB/CPM-Live/discussions/1),开始你对CPM-Live的探索吧! + + + + + diff --git a/README.md b/README.md index c809a39..4c0d942 100644 --- a/README.md +++ b/README.md @@ -5,20 +5,29 @@ **Live Training for Open-source Big Models**

- Website • Plan • Discussion + Website • Plan • Discussion • 简体中文 +

- ## What's New +- 2023/05/27 [CPM-Bee](https://github.com/OpenBMB/CPM-Bee) is released! +- 2023/04/12 CPM-Ant has been integrated into [HuggingFace Transformers](https://huggingface.co/openbmb/cpm-ant-10b)! (A brief usage sketch follows below.) +- 2022/10/12 [CPM-Ant+](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant-plus/cpm-live), a bilingual model, is released! In addition to generating Chinese/English text, you can now use our model for QA, summarization and translation tasks! +- 2022/09/16 [CPM-Ant](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant/cpm-live) is released! - 2022/05/29 The training of CPM-Live has launched today! See [training dynamics](https://live.openbmb.org/home). - 2022/05/25 The [training plan](./plans/CPM-Live训练计划书.md) for CPM-Live is now published. Look forward to the training! +## Milestones + +- **CPM-Bee** (2022/10/13-2023/05/27) [[Code](https://github.com/OpenBMB/CPM-Bee)][[Model](https://github.com/OpenBMB/CPM-Bee#%E6%A8%A1%E5%9E%8B)][[Plan](./plans/CPM-Bee训练计划书.md)] +- **CPM-Ant+** (2022/08/05-2022/10/12) [[Code](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant-plus/cpm-live)][[Model](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant-plus/cpm-live#model-checkpoints)] +- **CPM-Ant** (2022/05/29-2022/08/05) [[Code](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant/cpm-live)][[Model](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant/cpm-live#model-checkpoints)][[Website](https://live.openbmb.org/ant)][[Blog](https://www.openbmb.org/en/community/blogs/blogpage?id=98afef2ce45f4fe9a4bc15a66d7ccb92)][[Plan](./plans/CPM-Ant训练计划书.md)] ## Training Plan -Considering the scale of data and computing resources, CPM-Live will start with a 10B model training, which we named CPM-Ant. The training of CPM-Ant will start on May 29, 2022, and the entire training process is expected to last five months. +Considering the scale of data and computing resources, CPM-Live will start with a 10B model training. ### During training we will do: @@ -43,8 +52,3 @@ Considering the scale of data and computing resources, CPM-Live will start with [Our community](https://github.com/OpenBMB/CPM-Live/discussions) is based on GitHub Discussions. Read the [first post](https://github.com/OpenBMB/CPM-Live/discussions/1) and start your exploration on CPM-Live!
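The 2023/04/12 item above points to the CPM-Ant checkpoint on the HuggingFace Hub. As a rough, hedged illustration (not part of this diff), loading that `openbmb/cpm-ant-10b` checkpoint through the generic Transformers `Auto*` classes could look like the sketch below; the prompt text and generation settings are placeholder assumptions.

```python
# Hedged sketch: assumes the CPM-Ant integration in HuggingFace Transformers
# (available since v4.28) and the openbmb/cpm-ant-10b checkpoint linked above.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openbmb/cpm-ant-10b")
model = AutoModelForCausalLM.from_pretrained("openbmb/cpm-ant-10b")

# The prompt and max_new_tokens are illustrative placeholders, not repo values.
inputs = tokenizer("今天天气真好,", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0]))
```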
- - - - - diff --git a/cpm-live/.flake8 b/cpm-live/.flake8 index b357ab3..e4d2c92 100644 --- a/cpm-live/.flake8 +++ b/cpm-live/.flake8 @@ -3,4 +3,5 @@ per-file-ignores = # imported but unused __init__.py: F401 max-line-length = 100 -extend-ignore = E712E203 \ No newline at end of file +extend-ignore = E712E203 +exclude = examples/*.py \ No newline at end of file diff --git a/cpm-live/config/cpm-bee-10b.on b/cpm-live/config/cpm-bee-10b.on new file mode 100644 index 0000000..c34b2e0 --- /dev/null +++ b/cpm-live/config/cpm-bee-10b.on @@ -0,0 +1,14 @@ +{ + "vocab_size": 86583, + "dim_model": 4096, + "dim_ff" : 10240, + "num_layers" : 48, + "num_heads": 32, + "dim_head" : 128, + "dropout_p" : 0.0, + "position_bias_num_buckets" : 256, + "position_bias_num_segment_buckets": 256, + "position_bias_max_distance" : 2048, + "eps" : 1e-6, + "half" : true +} diff --git a/cpm-live/config/cpm-bee-3b.on b/cpm-live/config/cpm-bee-3b.on new file mode 100644 index 0000000..55fd0f2 --- /dev/null +++ b/cpm-live/config/cpm-bee-3b.on @@ -0,0 +1,14 @@ +{ + "vocab_size": 86580, + "dim_model": 2560, + "dim_ff" : 3072, + "num_layers" : 32, + "num_heads": 32, + "dim_head" : 80, + "dropout_p" : 0.0, + "position_bias_num_buckets" : 256, + "position_bias_num_segment_buckets": 256, + "position_bias_max_distance" : 2048, + "eps" : 1e-6, + "half" : true +} diff --git a/cpm-live/cpm_live/__init__.py b/cpm-live/cpm_live/__init__.py index 03e328b..e69de29 100644 --- a/cpm-live/cpm_live/__init__.py +++ b/cpm-live/cpm_live/__init__.py @@ -1,5 +0,0 @@ -from . import models -from . import dataset -from . import utils -from . import tokenizers -from .arguments import get_args diff --git a/cpm-live/cpm_live/arguments.py b/cpm-live/cpm_live/arguments.py index 2f3202c..b1b6d0e 100644 --- a/cpm-live/cpm_live/arguments.py +++ b/cpm-live/cpm_live/arguments.py @@ -29,13 +29,7 @@ def add_training_args(parser: argparse.ArgumentParser): group = parser.add_argument_group("train""training configurations") - group.add_argument( - "--base-path", - type=str, - default=None, - help="Path to the project base directory.", - ) - group.add_argument("--dataset_name"type=strdefault=Nonehelp="Name of the dataset") + group.add_argument("--dataset"type=strdefault="dataset.on"help="Path to dataset") group.add_argument( "--load", type=str, @@ -54,18 +48,14 @@ def add_training_args(parser: argparse.ArgumentParser): default=None, help="Output filename to save checkpoints to.", ) + group.add_argument( - "--save-iters", - type=int, - default=1000, - help="number of iterations between saves", - ) - group.add_argument( - "--log-dir", + "--tensorboard", type=str, - default="logs", - help="tensorboard log directory", + default=None, + help="tensorboard directory", ) + group.add_argument("--inspect-iters"type=intdefault=1000help="number of inspecting") group.add_argument("--batch-size"type=intdefault=32help="Data Loader batch size") group.add_argument("--clip-grad"type=floatdefault=1.0help="gradient clipping") @@ -76,33 +66,12 @@ def add_training_args(parser: argparse.ArgumentParser): help="total number of iterations to train over all training runs", ) group.add_argument("--max-length"type=intdefault=512help="max length of input") - group.add_argument( - "--max-encoder-length", - type=int, - default=512, - help="max length of encoder input", - ) - group.add_argument( - "--max-decoder-length", - type=int, - default=256, - help="max length of decoder input", - ) - group.add_argument( - "--start-step"type=intdefault=0help="step to start or continue training" - ) - 
group.add_argument("--seed"type=intdefault=1234help="random seed for reproducibility") - group.add_argument( - "--epochs", - type=int, - default=1, - help="total number of epochs to train over all training runs", - ) + group.add_argument("--seed"type=intdefault=1234help="random seed for reproducibility") # Learning rate. group.add_argument("--lr"type=floatdefault=1.0e-4help="initial learning rate") - group.add_argument("--weight-decay"type=floatdefault=1.0e-2help="weight-decay") + group.add_argument("--weight-decay"type=floatdefault=1.0e-2help="weight decay rate") group.add_argument("--loss-scale"type=floatdefault=65536help="loss scale") group.add_argument( @@ -111,13 +80,6 @@ def add_training_args(parser: argparse.ArgumentParser): default=0.01, help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01", ) - group.add_argument( - "--lr-decay-iters", - type=int, - default=None, - help="number of iterations to decay LR over," - " If None defaults to `--train-iters`*`--epochs`", - ) group.add_argument( "--lr-decay-", type=str, @@ -125,20 +87,62 @@ def add_training_args(parser: argparse.ArgumentParser): choices=["constant""linear""cosine""exponential""noam"], help="learning rate decay function", ) + group.add_argument("--lr-decay-iters"type=intdefault=Nonehelp="lr decay steps") group.add_argument( - "--local_rank", + "--start-step"type=intdefault=0help="step to start or continue training" + ) + + return parser + + +def add_pretrain_args(parser: argparse.ArgumentParser): + group = parser.add_argument_group("pretrain""pretrain configurations") + group.add_argument( + "--save-iters", type=int, + default=1000, + help="number of iterations between saves", + ) + group.add_argument( + "--log-dir", + type=str, default=None, - help="local rank passed from distributed launcher", + help="log directory", ) return parser -def get_args(): +def add_finetune_args(parser: argparse.ArgumentParser): + group = parser.add_argument_group("finetune""fintune configurations") + group.add_argument("--epoch"type=intdefault=1help="number of training epochs") + group.add_argument("--task-name"type=strdefault="task"help="name of training task") + group.add_argument( + "--use-delta", + action="store_true", + default=False, + help="use delta tuning or not" + ) + group.add_argument("--eval_dataset"type=strhelp="path to eval dataset") + group.add_argument( + "--drop-last", + action="store_true", + default=False, + help="drop data from each epoch that cannot be formed into a complete batch at the end", + ) + group.add_argument("--eval-interval"type=intdefault=500help="eval interval") + group.add_argument("--early-stop-patience"type=intdefault=5help="early stop steps") + return parser + + +def get_args(pretrain: bool = Falsefinetune: bool = False): parser = argparse.ArgumentParser() parser = add_model_config_args(parser) parser = add_training_args(parser) + if pretrain: + parser = add_pretrain_args(parser) + if finetune: + parser = add_finetune_args(parser) args = parser.parse_args() return args diff --git a/cpm-live/cpm_live/dataset/distributed_dataset.py b/cpm-live/cpm_live/dataset/distributed_dataset.py index d90c612..a9979da 100644 --- a/cpm-live/cpm_live/dataset/distributed_dataset.py +++ b/cpm-live/cpm_live/dataset/distributed_dataset.py @@ -15,29 +15,35 @@ import io import os -import pickle -from typing import List +import struct +from typing import ListOptionalSet import torch import bisect import bmtrain as bmt - +import on +from .serializer import SerializerPickleSerializer import random 
import string +import time def _random_string(): return "".join(random.choices(string.ascii_uppercase + string.digitsk=8)) +_DEFAULT_BLOCK_SIZE = 16 << 20 + + class FileInfo: def __init__( self, - file_name: str, - block_begin: int, - block_end: int, - nbytes: int, - nlines: int, + file_name: str = "", + block_begin: int = 0, + block_end: int = 0, + nbytes: int = 0, + nlines: int = 0, mask: bool = False, + block_size: int = _DEFAULT_BLOCK_SIZE, ) -> None: self.file_name = file_name self.block_begin = block_begin @@ -45,34 +51,153 @@ def __init__( self.nbytes = nbytes self.nlines = nlines self.mask = mask + self.block_size = block_size - @classmethod - def _load_from_state(clsversiondata): - if version == 1: - file_nameblock_beginblock_endnbytesnlinesmask = data - return cls(file_nameblock_beginblock_endnbytesnlinesmask) - else: - raise RuntimeError("Unsupported version %d" % version) - - def __reduce__(self): - return ( - FileInfo._load_from_state, - ( - 1, - ( - self.file_name, - self.block_begin, - self.block_end, - self.nbytes, - self.nlines, - self.mask, - ), - ), - ) + def state_dict(self): + return { + "file_name": self.file_name, + "block_begin": self.block_begin, + "block_end": self.block_end, + "nbytes": self.nbytes, + "nlines": self.nlines, + "mask": self.mask, + "block_size": self.block_size, + } + + def load_state_dict(selfd): + self.file_name = d["file_name"] + self.block_begin = d["block_begin"] + self.block_end = d["block_end"] + self.nbytes = d["nbytes"] + self.nlines = d["nlines"] + self.mask = d["mask"] + self.block_size = d["block_size"] + + def dumps(self) -> str: + return on.dumps(self.state_dict()) + + def loads(selfdata: str) -> "FileInfo": + self.load_state_dict(on.loads(data)) + return self + + def dump(selffp: io.TextIOWrapper) -> "FileInfo": + fp.write(self.dumps()) + return self + + def load(selffp: io.TextIOWrapper) -> "FileInfo": + self.loads(fp.read()) + return self + + +def _read_info_list(meta_path: str) -> List[FileInfo]: + info: List[FileInfo] = [] + while True: + try: + with open(meta_path"r"encoding="utf-8") as f: + for line in f.readlines(): + line = line.strip() + if len(line) > 0: + info.append(FileInfo().loads(line)) + return info + except Exception as e: + print("Error: reading info list in _read_info_list!,meta_path={path}err={err}". 
+ format(path=meta_patherr=str(e))) + time.sleep(10) + + +def _write_info_list(meta_path: strinfo: List[FileInfo]): + base_path = os.path.dirname(meta_path) + random_fname = os.path.join(base_path".meta.bin.%s" % _random_string()) + while True: + try: + with open(random_fname"w"encoding="utf-8") as f: + for v in info: + f.write(v.dumps() + "\n") + os.rename(random_fnamemeta_path) + return + except Exception: + print("Error: writing info list!") + time.sleep(10) -_MASK_VALUE = 0x7FFFFFFF -_DEFAULT_BLOCK_SIZE = 16 << 20 +def _filtered_range( + begin: intend: intrank: intworld_size: intfilter_set: Optional[Set[int]] = None +): + begin = begin + (rank + (world_size - (begin % world_size))) % world_size + + if filter_set is not None: + return [i for i in range(beginendworld_size) if i in filter_set] + else: + return [i for i in range(beginendworld_size)] + + +# for some bugs that may exist in hdfs +class SafeFile: + + def __init__(selffnamemode): + self.fname = None + self.mode = None + self._fp = None + self.open_file(fnamemode) + + def read(selfsize=-1): + if self._fp is None: + raise RuntimeError("Dataset is closed") + try: + res = self._fp.read(size) + self.offset = self._fp.tell() + return res + except Exception as e: + print("Error {}: reading blocks in read {}!".format(eself.fname)) + self.open_file(self.fnameself.modeself.offset) + return self.read(size) + + def tell(self): + if self._fp is None: + raise RuntimeError("Dataset is closed") + try: + res = self._fp.tell() + self.offset = res + return res + except Exception as e: + print("Error {}: reading blocks in tell {}!".format(eself.fname)) + self.open_file(self.fnameself.modeself.offset) + return self.tell() + + def seek(selfoffsetwhence=0): + if self._fp is None: + raise RuntimeError("Dataset is closed") + try: + res = self._fp.seek(offsetwhence) + self.offset = self._fp.tell() + return res + except Exception as e: + print("Error {}: reading blocks in seek {}!".format(eself.fname)) + self.open_file(self.fnameself.modeself.offset) + return self.seek(offsetwhence) + + def close(self): + if self._fp is not None: + try: + self._fp.close() + except Exception: + pass + self._fp = None + + def open_file(selffnamemodeoffset=None): + if not os.path.exists(fname): + raise RuntimeError("Dataset does not exist") + try: + self.fname = fname + self.mode = mode + self._fp = open(fnamemode) + if offset is not None: + self._fp.seek(offsetio.SEEK_SET) + self.offset = self._fp.tell() + except Exception as e: + print("Error {}: reading blocks in open_file {}!".format(eself.fname)) + time.sleep(10) + self.open_file(fnamemodeoffset) class DistributedDataset: @@ -100,16 +225,24 @@ def __init__( path: str, rank: int = 0, world_size: int = 1, - block_size=_DEFAULT_BLOCK_SIZE, + serializer: Optional[Serializer] = None, + max_repeat_times: Optional[int] = None, + shuffle: bool = True, ) -> None: # config self._path = path self._rank = rank self._world_size = world_size - self._block_size = block_size + self._max_repeat_times = max_repeat_times + self._repeat_times = 0 + self._shuffle = shuffle + + if serializer is None: + serializer = PickleSerializer() + self.serializer = serializer # dataset meta - self._block_states = torch.tensor([]dtype=torch.int) + self._unused_block: List[int] = [] self._file_info: List[FileInfo] = [] self._file_ends: List[int] = [] self._total_blocks = 0 @@ -124,12 +257,21 @@ def __init__( self._last_mod_time = 0 self._curr_fname = None - self._update_states() + self._update_states(fast_skip=False) + self._repeat_times += 1 def 
_update_states(selffast_skip: bool = True): meta_path = os.path.join(self._path"meta.bin") - mod_time = os.stat(meta_path).st_mtime + while True: + try: + mod_time = os.stat(meta_path).st_mtime + break + except Exception as e: + print("Error: reading info list in DistributedDataset._update_states" + "meta_path={path}err={err}!".format(path=meta_patherr=str(e))) + time.sleep(10) + if self._last_mod_time < mod_time: # file changed pass @@ -139,12 +281,13 @@ def _update_states(selffast_skip: bool = True): info: List[FileInfo] = [] if os.path.exists(meta_path): - with open(meta_path"rb") as f: - info = pickle.load(f) + info = _read_info_list(meta_path) old_len = len(self._file_info) if old_len > len(info): raise RuntimeError("Dataset meta file: changed unexpectly") + + mask_changed = False for i in range(old_len): if self._file_info[i].file_name != info[i].file_name: raise RuntimeError("Dataset meta file: changed unexpectly") @@ -152,6 +295,8 @@ def _update_states(selffast_skip: bool = True): raise RuntimeError("Dataset meta file: changed unexpectly") if self._file_info[i].block_end != info[i].block_end: raise RuntimeError("Dataset meta file: changed unexpectly") + if self._file_info[i].mask != info[i].mask: + mask_changed = True if info[0].block_begin != 0: raise RuntimeError("Dataset meta file: block error (0)") @@ -159,7 +304,7 @@ def _update_states(selffast_skip: bool = True): if info[i].block_end != info[i + 1].block_begin: raise RuntimeError("Dataset meta file: block error (%d)" % (i + 1)) - if old_len == len(info) and fast_skip: + if (old_len == len(info) and not mask_changed) and fast_skip: # fast skip return @@ -176,32 +321,38 @@ def _update_states(selffast_skip: bool = True): self._nlines = 0 if total_blocks > 0: - masks = torch.full( - (total_blocks,), - _MASK_VALUE, - dtype=torch.int, - device="cpu", - requires_grad=False, - ) - masks[self._rank :: self._world_size] = 0 - for v in info: - if v.mask or (not os.path.exists(self._get_file_path(v.file_name))): - masks[v.block_begin : v.block_end] = _MASK_VALUE - new_block_states = torch.zeros( - total_blocksdtype=torch.intdevice="cpu"requires_grad=False - ) - new_block_states[: self._block_states.size(0)] = self._block_states - new_block_states = torch.maximum(new_block_statesmasks) - - self._block_states = new_block_states + unused_block_set = set(self._unused_block) + nw_unused_block: List[int] = [] + for i in range(len(info)): + v = info[i] + if not v.mask: + if i < old_len: + nw_unused_block.extend( + _filtered_range( + v.block_begin, + v.block_end, + self._rank, + self._world_size, + unused_block_set, + ) + ) + else: + nw_unused_block.extend( + _filtered_range( + v.block_beginv.block_endself._rankself._world_size + ) + ) + + # re-shuffle unused blocks + if self._shuffle: + random.shuffle(nw_unused_block) + self._unused_block = nw_unused_block self._file_ends = [] for v in info: self._file_ends.append(v.block_end) else: - self._block_states = torch.tensor( - []dtype=torch.intdevice="cpu"requires_grad=False - ) + self._unused_block = [] self._file_ends = [] self._total_blocks = total_blocks self._file_info = info @@ -209,27 +360,62 @@ def _update_states(selffast_skip: bool = True): assert len(self._file_ends) == len(self._file_info) def _mask_file(selff: FileInfo): - masks = torch.full( - (self._total_blocks,)0dtype=torch.intdevice="cpu"requires_grad=False - ) - masks[f.block_begin : f.block_end] = _MASK_VALUE - self._block_states = torch.maximum(self._block_statesmasks) + self._unused_block = [ + block_id + for block_id in 
self._unused_block + if block_id < f.block_begin or block_id >= f.block_end + ] def _get_block_file(selfblock_id: int): # find block in which file file_idx = bisect.bisect_right(self._file_endsblock_id) return self._file_info[file_idx] + def _prepare_new_epoch(self): + if self._max_repeat_times is not None: + if self._repeat_times >= self._max_repeat_times: + raise EOFError("End of dataset") + nw_unused_block: List[int] = [] + for v in self._file_info: + if not v.mask: + nw_unused_block.extend( + _filtered_range(v.block_beginv.block_endself._rankself._world_size) + ) + if self._shuffle: + random.shuffle(nw_unused_block) + self._unused_block = nw_unused_block + self._repeat_times += 1 + def _get_next_block(self): self._update_states() - if self._block_states.size(0) == 0: - raise RuntimeError("Empty dataset") - mn_block: int = self._block_states.argmin().item() # type: ignore - if self._block_states[mn_block].item() == _MASK_VALUE: - raise RuntimeError("Empty dataset") - self._block_states[mn_block] += 1 + if len(self._unused_block) == 0: + self._prepare_new_epoch() + if len(self._unused_block) == 0: + raise RuntimeError("Empty dataset {}".format(self._path)) + + mn_block: int = self._unused_block.pop() return mn_block + def _state_dict(self): + self._update_states() + num_unused_block = len(self._unused_block) + if (self._fp is not None) and (self._curr_block is not None): + curr_block = self._curr_block + curr_f = self._get_block_file(curr_block) + inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * curr_f.block_size + else: + curr_block = -1 + inblock_offset = 0 + + return { + "states": torch.tensor(self._unused_blockdtype=torch.longdevice="cpu"), + "block": torch.tensor( + [curr_blockinblock_offsetnum_unused_blockself._repeat_times], + dtype=torch.long, + device="cpu", + ), + } + def state_dict(self): """Returns a state dict representing the read states of the dataset. 
@@ -238,32 +424,43 @@ def state_dict(self): >>> dataset.load_state_dict(state) """ self._update_states() - states = torch.where( - self._block_states == _MASK_VALUE, - torch.zeros(self._total_blocksdtype=torch.intdevice="cpu"requires_grad=False), - self._block_states, - ) + num_unused_block = len(self._unused_block) if (self._fp is not None) and (self._curr_block is not None): curr_block = self._curr_block curr_f = self._get_block_file(curr_block) - inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * self._block_size + inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * curr_f.block_size else: curr_block = -1 inblock_offset = 0 with torch.no_grad(): if self._world_size > 1: - gpu_states = states.cuda() - gpu_block = torch.tensor([curr_blockinblock_offset]dtype=torch.long).cuda() - global_states = bmt.distributed.all_reduce(gpu_statesop="sum").cpu() - global_block = bmt.distributed.all_gather(gpu_block).cpu() + gpu_num_unused_block = torch.tensor([num_unused_block]dtype=torch.long).cuda() + max_unused_blocks = ( + bmt.distributed.all_reduce(gpu_num_unused_blockop="max").cpu().item() + ) + gpu_states = torch.full((max_unused_blocks,)-1dtype=torch.long).cuda() + gpu_states[:num_unused_block] = torch.tensor( + self._unused_blockdtype=torch.long + ).cuda() + + gpu_block = torch.tensor( + [curr_blockinblock_offsetnum_unused_blockself._repeat_times], + dtype=torch.long, + ).cuda() + global_states = bmt.distributed.all_gather( + gpu_states + ).cpu() # (world_sizemax_unused_blocks) + global_block = bmt.distributed.all_gather(gpu_block).cpu() # (world_size4) return {"states": global_states"block": global_block} else: return { - "states": states, + "states": torch.tensor([self._unused_block]dtype=torch.longdevice="cpu"), "block": torch.tensor( - [[curr_blockinblock_offset]]dtype=torch.longdevice="cpu" + [[curr_blockinblock_offsetnum_unused_blockself._repeat_times]], + dtype=torch.long, + device="cpu", ), } @@ -278,11 +475,10 @@ def load_state_dict(selfstatestrict: bool = True): >>> state = dataset.state_dict() >>> """ + block_states: torch.LongTensor = state["states"] + block_info: torch.LongTensor = state["block"] - self._block_states = state["states"] - self._update_states(False) - - if state["block"].size(0) != self._world_size: + if block_states.size(0) != self._world_size: if strict: raise ValueError( "world_size changed (%d -> %d)" % (state["block"].size(0)self._world_size) @@ -291,20 +487,48 @@ def load_state_dict(selfstatestrict: bool = True): self._curr_block = None self._fp = None self._curr_fname = None + self._repeat_times = int(block_info[03].item()) + + # re-shuffle unused blocks + nw_unused_block: List[int] = [] + for i in range(block_states.size(0)): + # filter blocks that are not in this rank + num_unused_blocks: int = int(block_info[i2].item()) + nw_unused_block.extend( + [ + block_id + for block_id in block_states[i:num_unused_blocks].tolist() + if block_id % self._world_size == self._rank + ] + ) + if self._shuffle: + random.shuffle(nw_unused_block) + self._unused_block = nw_unused_block else: - curr_block = state["block"][self._rank][0].item() - inblock_offset = state["block"][self._rank][1].item() + curr_blockinblock_offsetnum_unused_blocksself._repeat_times = block_info[ + self._rank + ].tolist() if curr_block == -1: self._curr_block = None else: - self._curr_block = curr_block - f_info = self._get_block_file(self._curr_block) - self._open_file( - f_info.file_name, - (self._curr_block - f_info.block_begin) * self._block_size + 
inblock_offset, - ) + while True: + try: + self._curr_block = curr_block + f_info = self._get_block_file(self._curr_block) + self._open_file( + f_info.file_name, + (self._curr_block - f_info.block_begin) + * f_info.block_size + + inblock_offset, + ) + self._unused_block = block_states[self._rank:num_unused_blocks].tolist() + break + except Exception: + print("Error: reading block!") + time.sleep(10) # end + self._update_states() def _get_file_path(selffname): return os.path.join(self._pathfname) @@ -314,11 +538,11 @@ def _open_file(selffnameoffset): if self._fp is not None: self._fp.close() self._curr_fname = None - self._fp = open(self._get_file_path(fname)"rb") + # self._fp = open(self._get_file_path(fname)"rb") + self._fp = SafeFile(self._get_file_path(fname)"rb") self._curr_fname = fname else: assert self._fp is not None"Unexpected error" - self._fp.seek(offsetio.SEEK_SET) # move to block def read(self): @@ -332,10 +556,11 @@ def read(self): try: self._open_file( f_info.file_name, - (next_block_id - f_info.block_begin) * self._block_size, + (next_block_id - f_info.block_begin) * f_info.block_size, ) self._curr_block = next_block_id except FileNotFoundError: + print("ERR: reading again!") self._mask_file(f_info) return self.read() # read again @@ -345,7 +570,9 @@ def read(self): MAGIC = self._fp.read(1) if MAGIC == b"\x1F": # correct - return pickle.load(self._fp) + size = struct.unpack("I"self._fp.read(4))[0] + data = self._fp.read(size) + return self.serializer.deserialize(data) elif MAGIC == b"\x00": # end of block self._curr_block = None @@ -359,24 +586,27 @@ def nbytes(self): class SimpleDataset(DistributedDataset): - def __init__(selfpath: strblock_size=_DEFAULT_BLOCK_SIZE) -> None: - super().__init__(path01block_size) - - def _get_next_block(self): - self._update_states() - if self._block_states.size(0) == 0: - raise RuntimeError("Empty dataset") - mn_block: int = self._block_states.argmin().item() # type: ignore - if self._block_states[mn_block].item() >= 1: - raise EOFError("no more data") - self._block_states[mn_block] += 1 - return mn_block + def __init__( + self, + path: str, + serializer: Optional[Serializer] = None, + shuffle: bool = True, + ) -> None: + super().__init__( + path, + 0, + 1, + serializer=serializer, + max_repeat_times=1, + shuffle=shuffle, + ) def __iter__(self): while True: try: data = self.read() except EOFError: + self._repeat_times = 0 break yield data @@ -385,7 +615,7 @@ def __len__(self): class DatasetWriter: - def __init__(selffnameblock_size): + def __init__(selffname: strblock_size: intserializer: Optional[Serializer] = None): self._fname = fname self._block_size = block_size self._fp = open(self._fname"wb") @@ -395,6 +625,10 @@ def __init__(selffnameblock_size): self._nlines = 0 self._nblocks = 1 + if serializer is None: + serializer = PickleSerializer() + self.serializer = serializer + def write(selfdata): """Write a piece of data into dataset. 
@@ -405,7 +639,8 @@ def write(selfdata): >>> writer.write( "anything you want" ) """ - byte_data = pickle.dumps(data) + byte_data = self.serializer.serialize(data) + byte_data = struct.pack("I"len(byte_data)) + byte_data if self._inblock_offset + 2 + len(byte_data) > self._block_size: self._fp.write( b"\x00" * (self._block_size - self._inblock_offset) @@ -442,10 +677,19 @@ def close(self): class DatasetBuilder: - def __init__(selfpath: strdbname: strblock_size=_DEFAULT_BLOCK_SIZE) -> None: + def __init__( + self, + path: str, + dbname: str, + block_size=_DEFAULT_BLOCK_SIZE, + serializer: Optional[Serializer] = None, + ) -> None: self._block_size = block_size self._path = path self._dbname = dbname + if serializer is None: + serializer = PickleSerializer() + self.serializer = serializer if not os.path.exists(self._path): os.makedirs(self._path) @@ -454,8 +698,7 @@ def __init__(selfpath: strdbname: strblock_size=_DEFAULT_BLOCK_SIZE) -> No info: List[FileInfo] = [] if os.path.exists(meta_path): - with open(meta_path"rb") as f: - info = pickle.load(f) + info = _read_info_list(meta_path) for v in info: if v.file_name == dbname: @@ -466,7 +709,7 @@ def __init__(selfpath: strdbname: strblock_size=_DEFAULT_BLOCK_SIZE) -> No raise ValueError("File exists `%s`" % self._db_path) def __enter__(self): - self._writer = DatasetWriter(self._db_pathself._block_size) + self._writer = DatasetWriter(self._db_pathself._block_sizeself.serializer) return self._writer def __exit__(selfexc_typeexc_valueexc_traceback): @@ -482,8 +725,7 @@ def __exit__(selfexc_typeexc_valueexc_traceback): meta_path = os.path.join(self._path"meta.bin") info: List[FileInfo] = [] if os.path.exists(meta_path): - with open(meta_path"rb") as f: - info = pickle.load(f) + info = _read_info_list(meta_path) last_block = 0 if len(info) > 0: last_block = info[-1].block_end @@ -495,19 +737,22 @@ def __exit__(selfexc_typeexc_valueexc_traceback): self._writer.nbytes, self._writer.nlines, False, + self._block_size, ) ) # atomic write to meta file - random_fname = os.path.join(self._path".meta.bin.%s" % _random_string()) - with open(random_fname"wb") as f: - pickle.dump(infof) - os.rename(random_fnamemeta_path) + _write_info_list(meta_pathinfo) self._writer = None -def build_dataset(path: strdbname: strblock_size: int = _DEFAULT_BLOCK_SIZE): +def build_dataset( + path: str, + dbname: str, + block_size: int = _DEFAULT_BLOCK_SIZE, + serializer: Optional[Serializer] = None, +): """Open the dataset in write mode and returns a writer. Args: @@ -520,4 +765,4 @@ def build_dataset(path: strdbname: strblock_size: int = _DEFAULT_BLOCK_SIZE) >>> for i in range(10): >>> writer.write( { "anything you want" } ) """ # noqa: E501 - return DatasetBuilder(pathdbnameblock_size) + return DatasetBuilder(pathdbnameblock_size=block_sizeserializer=serializer) diff --git a/cpm-live/cpm_live/dataset/serializer.py b/cpm-live/cpm_live/dataset/serializer.py new file mode 100644 index 0000000..ef3865f --- /dev/null +++ b/cpm-live/cpm_live/dataset/serializer.py @@ -0,0 +1,61 @@ +# coding=utf-8 +# Copyright 2020 The OpenBMB team. All rights reserved. +# +# Licensed under the Apache LicenseVersion 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writingsoftware +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pickle +import on + + +class Serializer: + def __init__(self) -> None: + pass + + def serialize(selfobj) -> bytes: + raise NotImplementedError() + + def deserialize(selfdata: bytes): + raise NotImplementedError() + + +class PickleSerializer(Serializer): + def __init__(self) -> None: + pass + + def serialize(selfobj) -> bytes: + return pickle.dumps(obj) + + def deserialize(selfdata: bytes): + return pickle.loads(data) + + +class JsonSerializer(Serializer): + def __init__(self) -> None: + pass + + def serialize(selfobj) -> bytes: + return on.dumps(objensure_ascii=False).encode("utf-8") + + def deserialize(selfdata: bytes): + return on.loads(data.decode("utf-8")) + + +class RawSerializer(Serializer): + def __init__(self) -> None: + pass + + def serialize(selfobj) -> bytes: + return obj + + def deserialize(selfdata: bytes): + return data diff --git a/cpm-live/cpm_live/dataset/utils.py b/cpm-live/cpm_live/dataset/utils.py index c166a56..2a6ace6 100644 --- a/cpm-live/cpm_live/dataset/utils.py +++ b/cpm-live/cpm_live/dataset/utils.py @@ -14,16 +14,19 @@ # limitations under the License. import os -from typing import List +import struct +from typing import ListOptional from .distributed_dataset import ( SimpleDataset, build_dataset, + _read_info_list, + _write_info_list, _random_string, _DEFAULT_BLOCK_SIZE, FileInfo, ) +from .serializer import RawSerializer import random -import pickle import shutil try: @@ -42,7 +45,8 @@ def shuffle_dataset( path_tgt: str, block_size: int = _DEFAULT_BLOCK_SIZE, bucket_size: int = _DEFAULT_SHUFFLE_BUCKET_SIZE, - progress_bar=False, + progress_bar: bool = False, + output_name: Optional[str] = None, ): """Shuffle one distributed datatasetwrite results to another dataset. 
@@ -60,7 +64,7 @@ def shuffle_dataset( if progress_bar and not support_tqdm: raise RuntimeError("Requires `tqdm` to enable progress bar.") - ds = SimpleDataset(path_srcblock_size=block_size) + ds = SimpleDataset(path_srcserializer=RawSerializer()) num_buckets = (ds.nbytes + bucket_size - 1) // bucket_size tmp_files = [os.path.join(path_src".tmp.%s" % _random_string()) for _ in range(num_buckets)] @@ -74,7 +78,8 @@ def shuffle_dataset( iterator = tqdm(dsdesc="Shuffle step 1/2") for data in iterator: bucket_id = int(random.random() * num_buckets) - pickle.dump(dataf_tmp[bucket_id]) # write into a random bucket + len_data = len(data) + f_tmp[bucket_id].write(struct.pack("I"len_data) + data) finally: # close all files for fp in f_tmp: @@ -83,7 +88,14 @@ def shuffle_dataset( f_tmp = [] # Step 2: shuffle inside bucket - with build_dataset(path_tgt"%s.shuffle" % _random_string()) as writer: + if output_name is None: + output_name = "%s.shuffle" % _random_string() + with build_dataset( + path_tgt, + output_name, + block_size=block_size, + serializer=RawSerializer(), + ) as writer: iterator = tmp_files if progress_bar: iterator = tqdm(tmp_filesdesc="Shuffle step 2/2") @@ -93,7 +105,12 @@ def shuffle_dataset( data_in_bucket = [] while True: try: - data_in_bucket.append(pickle.load(fp)) + raw_data = fp.read(4) + if len(raw_data) == 0: + # EOF + break + len_data = struct.unpack("I"raw_data)[0] + data_in_bucket.append(fp.read(len_data)) except EOFError: break random.shuffle(data_in_bucket) @@ -125,12 +142,11 @@ def compact_dataset(path: str): info: List[FileInfo] = [] if os.path.exists(meta_path): - with open(meta_path"rb") as f: - info = pickle.load(f) + info = _read_info_list(meta_path) else: raise ValueError("Dataset not exists") - nw_info = [] + nw_info: List[FileInfo] = [] curr_block = 0 for v in info: if not os.path.exists(v.file_name): @@ -146,14 +162,12 @@ def compact_dataset(path: str): v.nbytes, v.nlines, v.mask, + v.block_size, ) ) curr_block += num_file_block - random_fname = os.path.join(path".meta.bin.%s" % _random_string()) - with open(random_fname"wb") as f: - pickle.dump(nw_infof) - os.rename(random_fnamemeta_path) + _write_info_list(meta_pathnw_info) def mask_dataset(path: strdbname: strmask: bool = True): @@ -173,19 +187,14 @@ def mask_dataset(path: strdbname: strmask: bool = True): info: List[FileInfo] = [] if os.path.exists(meta_path): - with open(meta_path"rb") as f: - info = pickle.load(f) + info = _read_info_list(meta_path) else: raise ValueError("Dataset not exists") for v in info: if v.file_name == dbname: v.mask = mask - - random_fname = os.path.join(path".meta.bin.%s" % _random_string()) - with open(random_fname"wb") as f: - pickle.dump(infof) - os.rename(random_fnamemeta_path) + _write_info_list(meta_pathinfo) def merge_dataset(dst: strsrc: str): @@ -195,15 +204,13 @@ def merge_dataset(dst: strsrc: str): info_src: List[FileInfo] = [] if os.path.exists(meta_path_src): - with open(meta_path_src"rb") as f: - info_src = pickle.load(f) + info_src = _read_info_list(meta_path_src) else: raise ValueError("Dataset not exists") info_dst: List[FileInfo] = [] if os.path.exists(meta_path_dst): - with open(meta_path_dst"rb") as f: - info_dst = pickle.load(f) + info_dst = _read_info_list(meta_path_dst) else: raise ValueError("Dataset not exists") @@ -219,6 +226,7 @@ def merge_dataset(dst: strsrc: str): v.nbytes, v.nlines, v.mask, + v.block_size, ) ) curr_block += num_file_block @@ -244,11 +252,9 @@ def merge_dataset(dst: strsrc: str): v.nbytes, v.nlines, v.mask, + v.block_size, ) ) 
curr_block += num_file_block - random_fname = os.path.join(dst".meta.bin.%s" % _random_string()) - with open(random_fname"wb") as f: - pickle.dump(nw_infof) - os.rename(random_fnamemeta_path_dst) + _write_info_list(meta_path_dstnw_info) diff --git a/cpm-live/cpm_live/generation/__init__.py b/cpm-live/cpm_live/generation/__init__.py new file mode 100644 index 0000000..70de125 --- /dev/null +++ b/cpm-live/cpm_live/generation/__init__.py @@ -0,0 +1 @@ +from .ant import CPMAntBeamSearchCPMAntRandomSamplingCPMAntGeneration diff --git a/cpm-live/cpm_live/generation/ant.py b/cpm-live/cpm_live/generation/ant.py new file mode 100644 index 0000000..a9e5893 --- /dev/null +++ b/cpm-live/cpm_live/generation/ant.py @@ -0,0 +1,385 @@ +import torch +import torch.nn.functional as F +from .generation_utils import BeamHypothesesapply_repetition_penaltytop_k_top_p_filtering +from ..utils import pad + + +class CPMAntGeneration: + def __init__(selfmodeltokenizerprompt_length=32): + model.eval() + self.model = model + self.tokenizer = tokenizer + self.prompt_length = prompt_length + + def _convert_to_tensors(selfinput_texttask_id=2): + model_inputs = {} + input_ids = [self.tokenizer.bos_id] + self.tokenizer.encode(input_text) + input_ids = [j for j in input_ids if j != self.tokenizer.unk_id] + + model_inputs["input"] = [ + x + self.prompt_length * task_id for x in range(self.prompt_length) + ] + input_ids + model_inputs["length"] = len(model_inputs["input"]) + model_inputs["position"] = list(range(len(model_inputs["input"]))) + model_inputs["span"] = [0] * len(model_inputs["input"]) + model_inputs["context"] = [True] * len(model_inputs["input"]) + model_inputs["segment"] = [0] * self.prompt_length + [2] * len(input_ids) + + for key in model_inputs: + model_inputs[key] = torch.tensor(model_inputs[key]).int().unsqueeze(0) + + return model_inputs + + def _process_texts(selftext_list): + input_tensors = list(map(self._convert_to_tensorstext_list)) + keys = set(input_tensors[0].keys()) + padded = {} + for key in keys: + padded[key] = pad(input_tensorskeypadding_side='left').cuda() + return padded + + def generate(selftext_list**kwargs): + model_inputs = self._process_texts(text_list) + with torch.inference_mode(): + result = self._decode(model_inputs**kwargs) + return result + + def _decode(selfmodel_inputs**kwargs): + raise NotImplementedError("_decode is not implemented.") + + +class CPMAntBeamSearch(CPMAntGeneration): + def _decode( + self, + model_inputs, + beam_size=3, + max_length=100, + repetition_penalty=1.0, + repetition_window=None, + **kwargs + ): + """ + Beam search + Args: + model_inputs (dict): input ids. + beam_size (intoptionaldefaults to 3): beam size of beam search. + generate_length (intoptionaldefaults to 100): maximum generation length. + repetition_penalty (floatoptionaldefaults to 1.0): repetition penalty coefficient1.0 means no penalty. + repetition_window (intoptionaldefaults to None): window size of repetition penaltyNone means that all output tokens are penalized. 
+ """ # noqa: E501 + # generate_length + 1 for EOS token + max_length += 1 + + # expand dimmension + batch_size = model_inputs["input"].size(0) + input = ( + model_inputs["input"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + length = ( + model_inputs["length"] + .unsqueeze(1) + .expand(batch_sizebeam_size) + .contiguous() + .view( + batch_size * beam_size, + ) + ) + context = ( + model_inputs["context"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + position = ( + model_inputs["position"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + segment = ( + model_inputs["segment"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + span = ( + model_inputs["span"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + + done = [False for _ in range(batch_size)] + + beam_scores = torch.zeros((batch_sizebeam_size)dtype=torch.floatdevice=input.device) + beam_scores[:1:] = -1e9 + beam_scores = beam_scores.view(-1) + + # generated hypotheses + generated_hyps = [ + BeamHypotheses(beam_sizemax_lengthlength_penalty=1early_stopping=False) + for _ in range(batch_size) + ] + + pred_start_index = input.size(-1) + past_key_values = None + for i in range(max_length + 1): + if i == 0: + logits_past_key_values = self.model.inference( + input=input, + length=length, + context=context, + position=position, + segment=segment, + span=span, + past_key_values=past_key_values, + ) + else: + logits_past_key_values = self.model.inference( + input=input[:-1:], + length=length, + context=context, + position=position, + segment=segment, + span=span, + past_key_values=past_key_values, + ) + + # skip all steps when we are done with each sentence + if all(done): + break + + # (batch * beamseqlenmodel_dim) + logits = logits[:-1:] + + if i == 0: + logits[:self.tokenizer.eos_id] = -float("inf") + logits[:self.tokenizer.newline_id] = -float("inf") + + apply_repetition_penalty( + logits, + batch_size, + beam_size, + input, + repetition_penalty, + pred_start_index, + input.size(-1) - 1, + repetition_window, + ) + scores = F.log_softmax(logitsdim=-1) + + next_scores = scores + beam_scores[:None].expand_as( + scores + ) # (batch_size * beam_sizevocab_size) + + # re-organize to group the beam together (we are keeping top hypothesis accross beams) + next_scores = next_scores.view(batch_size-1) # (batch_sizebeam_size * vocab_size) + next_scoresnext_words = torch.topk( + next_scores2 * beam_sizedim=1largest=Truesorted=True + ) + + assert next_scores.size() == next_words.size() == (batch_size2 * beam_size) + next_batch_beam = [] + + for sent_id in range(batch_size): + # if we are done with this sentence + done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done( + next_scores[sent_id].max().item()i + ) + if done[sent_id]: + next_batch_beam.extend( + [(0self.tokenizer.pad_id0)] * beam_size + ) # pad the batch + continue + + # next sentence beam content + next_sent_beam = [] + + # next words for this sentence + for idxvalue in zip(next_words[sent_id]next_scores[sent_id]): + + # get beam and word IDs + beam_id = torch.div(idxscores.size(-1)rounding_mode="floor") + word_id = idx % scores.size(-1) + + # end of sentenceor next word + if word_id == self.tokenizer.eos_id or i == max_length: + generated_hyps[sent_id].add( + input[sent_id * beam_size + 
beam_idpred_start_index:] + .clone() + .cpu() + .tolist(), + value.item(), + ) + else: + next_sent_beam.append((valueword_idsent_id * beam_size + beam_id)) + + # the beam for next step is full + if len(next_sent_beam) == beam_size: + break + + # update next beam content + assert len(next_sent_beam) == 0 if i == max_length else beam_size + if len(next_sent_beam) == 0: + next_sent_beam = [(0self.tokenizer.pad_id0)] * beam_size # pad the batch + next_batch_beam.extend(next_sent_beam) + assert len(next_batch_beam) == beam_size * (sent_id + 1) + + # we have reached the last step + if i == max_length: + break + + # sanity check / prepare next batch + assert len(next_batch_beam) == batch_size * beam_size + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_words = input.new([x[1] for x in next_batch_beam]) + beam_idx = length.new([x[2] for x in next_batch_beam]).long() + + # re-order batch and internal states + input = input[beam_idx:] + + past_key_values = [list(each) if each is not None else each for each in past_key_values] # type: ignore # noqa: E501 + for key_value_layer in past_key_values: + if key_value_layer is not None: + key_value_layer[0] = key_value_layer[0][beam_idx] + key_value_layer[1] = key_value_layer[1][beam_idx] + + # update input ids + input = torch.cat([inputbeam_words.unsqueeze(1)]dim=-1) + length += 1 + context = torch.cat( + [contexttorch.ones((context.size(0)1)dtype=torch.intdevice=context.device)], + dim=-1, + ) + position = torch.cat([positionposition[:-1:] + 1]dim=-1) + segment = torch.cat( + [segmentsegment[:-1:]]dim=-1 + ) # segment id always the same as the previous token + span = torch.cat([spanspan[:-1:]]dim=-1) + + # select the best hypotheses + results = [] + for ihypotheses in enumerate(generated_hyps): + best_hyp = max(hypotheses.hypkey=lambda x: x[0])[1] + results.append(best_hyp) + + result_text = list(map(self.tokenizer.decoderesults)) + return result_text + + +class CPMAntRandomSampling(CPMAntGeneration): + def _decode( + self, + model_inputs, + max_length=100, + top_k=0, + top_p=0.9, + temperature=0.9, + repetition_penalty=1.0, + repetition_window=None, + **kwargs + ): + """ + Top-k and top-p sampling. + Args: + model_inputs (dict): input ids + generate_length (intoptionaldefaults to 100): maximum generation length + top_k (intoptionaldefaults to 0): keep only top k tokens with highest probability. 0 means keeping all tokens. + top_p (intoptionaldefaults to 0.9): keep the top tokens with cumulative probability >= top_p. + temperature (intoptionaldefaults to 0.9): the value that can cool down the logits distribution. + repetition_penalty (floatoptionaldefaults to 1.0): repetition penalty coefficient1.0 means no penalty. + repetition_window (intoptionaldefaults to None): window size of repetition penaltyNone means that all output tokens are penalized. 
+ """ # noqa: E501 + # generate_length + 1 for EOS token + max_length += 1 + + input = model_inputs["input"] + length = model_inputs["length"] + context = model_inputs["context"] + position = model_inputs["position"] + segment = model_inputs["segment"] + span = model_inputs["span"] + batch_size = input.size(0) + + pred_start_index = input.size(-1) + past_key_values = None + done = [False for _ in range(batch_size)] + results = [None for _ in range(batch_size)] + for i in range(max_length): + if i == 0: + logits_past_key_values = self.model.inference( + input=input, + length=length, + context=context, + position=position, + segment=segment, + span=span, + past_key_values=past_key_values, + ) + else: + logits_past_key_values = self.model.inference( + input=input[:-1:], + length=length, + context=context, + position=position, + segment=segment, + span=span, + past_key_values=past_key_values, + ) + + logits = logits[:-1:] + + if i == 0: + logits[:self.tokenizer.eos_id] = -float("inf") + logits[:self.tokenizer.newline_id] = -float("inf") + + apply_repetition_penalty( + logits, + batch_size, + 1, + input, + repetition_penalty, + pred_start_index, + input.size(-1) - 1, + repetition_window, + ) + + logits = logits / temperature + logits = top_k_top_p_filtering(logitstop_k=top_ktop_p=top_p) + + probs = F.softmax(logitsdim=-1) + next_token = torch.multinomial(probsnum_samples=1) + + for idx in range(batch_size): + if not done[idx] and ( + next_token[idx].item() == self.tokenizer.eos_id or i == max_length - 1 + ): + done[idx] = True + results[idx] = input[idxpred_start_index:].clone().cpu().tolist() # type: ignore # noqa: E501 + + if sum(done) == batch_size: + break + + # update input ids + input = torch.cat([inputnext_token]dim=-1) + length += 1 + context = torch.cat( + [contexttorch.ones((context.size(0)1)dtype=torch.intdevice=context.device)], + dim=-1, + ) + position = torch.cat([positionposition[:-1:] + 1]dim=-1) + segment = torch.cat( + [segmentsegment[:-1:]]dim=-1 + ) # segment id always the same as the previous token + span = torch.cat([spanspan[:-1:]]dim=-1) + + result_text = list(map(self.tokenizer.decoderesults)) + return result_text diff --git a/cpm-live/cpm_live/generation/bee.py b/cpm-live/cpm_live/generation/bee.py new file mode 100644 index 0000000..5120b1f --- /dev/null +++ b/cpm-live/cpm_live/generation/bee.py @@ -0,0 +1,629 @@ +from typing import AnyDictListTuple +import numpy as np +import torch +import torch.nn.functional as F +from .generation_utils import BeamHypothesesapply_repetition_penalty +from ..tokenizers.bee import CPMBeeTokenizer +from ..models.bee import CPMBee +from ..training_tasks.bee.pretrain import convert_data_to_id +from ..utils import pad + + +class CPMBeeGeneration: + def __init__(selfmodel: CPMBeetokenizer: CPMBeeTokenizer): + model.eval() + self.model = model + self.tokenizer = tokenizer + + def _convert_to_tensors(selfdata: Anyin_context_samples: List[Any] = []): + answer_placeholders = [] + + def _put_placeholder(data: Anypath: List[str] = []): + if isinstance(datadict): + ret = {} + for kv in data.items(): + ret[k] = _put_placeholder(vpath + [k]) + return ret + else: + answer_placeholders.append(path) + return "".format(len(answer_placeholders)) + + data[""] = _put_placeholder(data[""]) + ( + input_ids, + input_id_subs, + context, + segment_ids, + segment_rel, + n_segments, + table_states, + ) = convert_data_to_id(self.tokenizerdatashuffle_answer=Falsemax_depth=8) + + sub_ans_map: Dict[intint] = {} + for fake_idtoken_sub in 
table_states["token_id_table"][""].items(): + token = table_states["ext_table"][fake_id] + if token.startswith(""): + ans_id = int(token[5:-1]) + sub_ans_map[token_sub] = ans_id + + tmp_input_ids = [] + tmp_input_sub = [] + tmp_input_seg = [] + + predict_segments: List[Tuple[intint]] = [] + for i in range(input_ids.shape[0]): + if context[i] == 0: + if input_ids[i] == self.tokenizer.encoder[""]: + # is ans + # (segment_idans_id) + predict_segments.append((segment_ids[i]sub_ans_map[input_id_subs[i]])) + else: + tmp_input_ids.append(input_ids[i]) + tmp_input_sub.append(input_id_subs[i]) + tmp_input_seg.append(segment_ids[i]) + + if len(predict_segments) == 0: + raise ValueError("No answer to predict") + + input_ids = np.array(tmp_input_idsdtype=np.int32) + input_id_subs = np.array(tmp_input_subdtype=np.int32) + context = np.full_like(tmp_input_ids1dtype=np.int8) + segment_ids = np.array(tmp_input_segdtype=np.int32) + sample_ids = np.zeros(input_ids.shapedtype=np.int32) + segment_rel_offset = np.zeros(input_ids.shapedtype=np.int32) + num_segments = np.full(input_ids.shapen_segmentsdtype=np.int32) + + for isample in enumerate(in_context_samples): + ( + sample_input_ids, + sample_id_subs, + _, + sample_segments, + sample_rel, + n_segments, + table_states, + ) = convert_data_to_id(self.tokenizersampletable_statesmax_depth=8) + input_ids = np.concatenate([input_idssample_input_ids]axis=0) + input_id_subs = np.concatenate([input_id_subssample_id_subs]axis=0) + context = np.concatenate( + [contextnp.ones(sample_input_ids.shapedtype=np.int8)]axis=0 + ) + segment_ids = np.concatenate([segment_idssample_segments]axis=0) + segment_rel_offset = np.concatenate( + [ + segment_rel_offset, + np.full(sample_input_ids.shapesegment_rel.shape[0]dtype=np.int32), + ], + axis=0, + ) + segment_rel = np.concatenate([segment_relsample_rel]axis=0) + sample_ids = np.concatenate( + [sample_idsnp.full(sample_input_ids.shapei + 1dtype=np.int32)]axis=0 + ) + num_segments = np.concatenate( + [num_segmentsnp.full(sample_input_ids.shapen_segmentsdtype=np.int32)]axis=0 + ) + input_pos = np.arange(input_ids.shape[0]dtype=np.int32) + + return ( + input_ids, + input_id_subs, + input_pos, + context, + segment_ids, + segment_rel_offset, + segment_rel, + sample_ids, + num_segments, + predict_segments, + answer_placeholders, + table_states["ext_table"], + table_states["token_id_table"], + ) + + def _process_list(selfdata_list: List[Any]): + pack_tensor = [] + other_info = [] + segment_rel_pack = [] + + batch_ext_table_map: Dict[Tuple[intint]int] = {} + batch_ext_table_ids: List[int] = [] + batch_ext_table_sub: List[int] = [] + + for data in data_list: + ( + input_ids, + input_id_subs, + input_pos, + context, + segment_ids, + segment_rel_offset, + segment_rel, + sample_ids, + num_segments, + predict_segments, + answer_placeholders, + ext_table, + token_id_table, + ) = self._convert_to_tensors(data[]) + rev_ext_table: Dict[intstr] = {} + for tokenmp in token_id_table.items(): + if token == "": + continue + token_id = self.tokenizer.encoder[token] + for fake_idtoken_sub in mp.items(): + if token_sub > 0: + if (token_idtoken_sub) not in batch_ext_table_map: + batch_ext_table_map[(token_idtoken_sub)] = ( + len(batch_ext_table_ids) + self.tokenizer.vocab_size + ) + batch_ext_table_ids.append(token_id) + batch_ext_table_sub.append(token_sub) + rev_ext_table[batch_ext_table_map[(token_idtoken_sub)]] = ext_table[ + fake_id + ] + else: + rev_ext_table[token_id] = ext_table[fake_id] + pack_tensor.append( + { + "input": 
torch.from_numpy(input_ids).unsqueeze(0), + "input_sub": torch.from_numpy(input_id_subs).unsqueeze(0), + "input_pos": torch.from_numpy(input_pos).unsqueeze(0), + "context": torch.from_numpy(context).unsqueeze(0), + "sample_idx": torch.from_numpy(sample_ids).unsqueeze(0), + "num_segments": torch.from_numpy(num_segments).unsqueeze(0), + "segment": torch.from_numpy(segment_ids).unsqueeze(0), + "segment_rel_offset": torch.from_numpy(segment_rel_offset).unsqueeze(0), + } + ) + segment_rel_pack.append(torch.from_numpy(segment_rel)) + other_info.append( + { + "predict_segments": predict_segments, + "answer_placeholders": answer_placeholders, + "ext_table": rev_ext_table, + } + ) + + keys = set(pack_tensor[0].keys()) + padded = {} + for key in keys: + padded[key] = pad(pack_tensorkey).cuda() + + max_num_rels = 0 + for rel in segment_rel_pack: + max_num_rels = max(max_num_relsrel.size(0)) + padded_rels = torch.zeros(len(segment_rel_pack)max_num_relsdtype=torch.int32) + for irel in enumerate(segment_rel_pack): + padded_rels[i: rel.size(0)] = rel + padded["segment_rel"] = padded_rels.cuda() + padded["batch_ext_table_ids"] = torch.tensor( + batch_ext_table_idsdtype=torch.int32device="cuda" + ) + padded["batch_ext_table_sub"] = torch.tensor( + batch_ext_table_subdtype=torch.int32device="cuda" + ) + return paddedother_info + + def generate(selfdata_list**kwargs): + model_inputsother_info = self._process_list(data_list) + with torch.inference_mode(): + result_ids = self._decode(model_inputsother_info**kwargs) + for sent_idresult in enumerate(result_ids): + ans_result_map: Dict[intList[int]] = {} + for raw_word_idans_id in result: + if ans_id not in ans_result_map: + ans_result_map[ans_id] = [] + ans_result_map[ans_id].append(raw_word_id) + + answer_placeholders = other_info[sent_id]["answer_placeholders"] + ext_table = other_info[sent_id]["ext_table"] + data = data_list[sent_id] + for ans_idtoken_ids in ans_result_map.items(): + if token_ids[-1] == self.tokenizer.eos_id: + token_ids = token_ids[:-1] + text = self.tokenizer.decode(token_idsext_table) + path = answer_placeholders[ans_id - 1] + + if len(path) > 0: + p = data[""] + for part in path[:-1]: + p = p[part] + p[path[-1]] = text + else: + data[""] = text + for ans_id in range(len(answer_placeholders)): + if (ans_id + 1) not in ans_result_map: + path = answer_placeholders[ans_id] + p = data[""] + for part in path[:-1]: + p = p[part] + p[path[-1]] = None + return data_list + + def _decode(selfmodel_inputsother_info**kwargs): + raise NotImplementedError("_decode is not implemented.") + + +class CPMBeeBeamSearch(CPMBeeGeneration): + def _decode( + self, + model_inputs, + other_info, + beam_size=3, + max_length=100, + repetition_penalty=1.0, + repetition_window=None, + ): + """ + Beam search + Args: + model_inputs (dict): input ids. + beam_size (intoptionaldefaults to 3): beam size of beam search. + generate_length (intoptionaldefaults to 100): maximum generation length. + repetition_penalty (floatoptionaldefaults to 1.0): repetition penalty coefficient1.0 means no penalty. + repetition_window (intoptionaldefaults to None): window size of repetition penaltyNone means that all output tokens are penalized. 
+ """ # noqa: E501 + # generate_length + 1 for EOS token + max_length += 1 + + # expand dimmension + batch_size = model_inputs["input"].size(0) + input: torch.Tensor = ( + model_inputs["input"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + input_sub: torch.Tensor = ( + model_inputs["input_sub"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + input_pos: torch.Tensor = ( + model_inputs["input_pos"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + context: torch.Tensor = ( + model_inputs["context"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + sample_ids: torch.Tensor = ( + model_inputs["sample_idx"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + num_segments: torch.Tensor = ( + model_inputs["num_segments"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + segment: torch.Tensor = ( + model_inputs["segment"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + segment_rel_offset: torch.Tensor = ( + model_inputs["segment_rel_offset"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + segment_rel: torch.Tensor = ( + model_inputs["segment_rel"] + .unsqueeze(1) + .expand(batch_sizebeam_size-1) + .contiguous() + .view(batch_size * beam_size-1) + ) + ext_table_ids: torch.Tensor = model_inputs["batch_ext_table_ids"] + ext_table_sub: torch.Tensor = model_inputs["batch_ext_table_sub"] + ext_table_ids_cpu = ext_table_ids.cpu() + ext_table_sub_cpu = ext_table_sub.cpu() + + done = [False for _ in range(batch_size)] + + beam_scores = torch.zeros((batch_sizebeam_size)dtype=torch.floatdevice=input.device) + beam_scores[:1:] = -1e9 + beam_scores = beam_scores.view(-1) + + # generated hypotheses + generated_hyps = [ + BeamHypotheses(beam_sizemax_lengthlength_penalty=1early_stopping=False) + for _ in range(batch_size) + ] + + pred_start_index = input.size(-1) + __past_key_values = self.model.inference( + input=input, + input_sub=input_sub, + position=input_pos, + context=context, + sample_ids=sample_ids, + num_segments=num_segments, + segment=segment, + segment_rel_offset=segment_rel_offset, + segment_rel=segment_rel, + ext_table_ids=ext_table_ids, + ext_table_sub=ext_table_sub, + past_key_values=None, + ) + + beam_states = [] + for sent_id in range(batch_size): + instance_beam_states = [] + + for beam_id in range(beam_size): + instance_beam_states.append( + { + "idx": 0, + "ans": [], + "nx_token_id": self.tokenizer.bos_id, + "nx_token_sub": 0, + "nx_segment_id": other_info[sent_id]["predict_segments"][0][0], + "nx_position": 0, + } + ) + beam_states.append(instance_beam_states) + for i in range(max_length + 1): + tmp_input = [] + tmp_input_sub = [] + tmp_position = [] + tmp_segment = [] + for sent_id in range(batch_size): + for beam_id in range(beam_size): + tmp_input.append(beam_states[sent_id][beam_id]["nx_token_id"]) + tmp_input_sub.append(beam_states[sent_id][beam_id]["nx_token_sub"]) + tmp_position.append(beam_states[sent_id][beam_id]["nx_position"]) + tmp_segment.append(beam_states[sent_id][beam_id]["nx_segment_id"]) + with torch.no_grad(): + input = torch.cat( + [ + input, + torch.tensor(tmp_inputdtype=torch.int32device="cuda").view( + batch_size * beam_size1 + 
), + ], + dim=-1, + ) + logits_past_key_values = self.model.inference( + input=input[:-1:], + input_sub=torch.tensor(tmp_input_subdtype=torch.int32device="cuda").view( + batch_size * beam_size1 + ), + position=torch.tensor(tmp_positiondtype=torch.int32device="cuda").view( + batch_size * beam_size1 + ), + context=torch.ones( + batch_size * beam_sizedtype=torch.booldevice="cuda" + ).view(batch_size * beam_size1), + sample_ids=torch.zeros( + batch_size * beam_sizedtype=torch.int32device="cuda" + ).view(batch_size * beam_size1), + num_segments=num_segments[:-1:], + segment=torch.tensor(tmp_segmentdtype=torch.int32device="cuda").view( + batch_size * beam_size1 + ), + segment_rel_offset=segment_rel_offset[:-1:], + segment_rel=segment_rel, + ext_table_ids=ext_table_ids, + ext_table_sub=ext_table_sub, + past_key_values=past_key_values, + ) + logits = logits[:-1:] + + # skip all steps when we are done with each sentence + if all(done): + break + + for sent_id in range(batch_size): + if self.tokenizer.unk_id not in other_info[sent_id]["ext_table"]: + # unk is not allowedmask unk + logits[ + sent_id * beam_size : (sent_id + 1) * beam_sizeself.tokenizer.unk_id + ] = -10000 + ext_ids = set() + for v in other_info[sent_id]["ext_table"].keys(): + ext_ids.add(v) + for ext_id in range( + self.tokenizer.vocab_sizeself.tokenizer.vocab_size + ext_table_ids.size(0) + ): + if ext_id not in ext_ids: + logits[sent_id * beam_size : (sent_id + 1) * beam_sizeext_id] = -10000 + + apply_repetition_penalty( + logits, + batch_size, + beam_size, + input, + repetition_penalty, + pred_start_index, + input.size(-1) - 1, + repetition_window, + ) + scores = F.log_softmax(logitsdim=-1) + next_scores = scores + beam_scores[:None].expand_as( + scores + ) # (batch_size * beam_sizevocab_size) + + # re-organize to group the beam together (we are keeping top hypothesis accross beams) + next_scores = next_scores.view(batch_size-1) # (batch_sizebeam_size * vocab_size) + next_scoresnext_words = torch.topk( + next_scores2 * beam_sizedim=1largest=Truesorted=True + ) + assert next_scores.size() == next_words.size() == (batch_size2 * beam_size) + next_beam_states = [] + + for sent_id in range(batch_size): + # if we are done with this sentence + done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done( + next_scores[sent_id].max().item()i + ) + if done[sent_id]: + next_beam_states.append( + [ + ( + { + "idx": 0, + "ans": [], + "nx_token_id": 0, + "nx_token_sub": 0, + "nx_segment_id": 0, + "nx_position": 0, + }, + 0, + 0, + ) + ] + * beam_size + ) # pad the batch + continue + + # next sentence beam content + next_instance_beam_states = [] + + # next words for this sentence + for idxvalue in zip(next_words[sent_id]next_scores[sent_id]): + + # get beam and word IDs + beam_id = torch.div(idxscores.size(-1)rounding_mode="floor").item() + word_id = (idx % scores.size(-1)).item() + + curr_info = beam_states[sent_id][beam_id] + # end of sentenceor next word + if ( + word_id == self.tokenizer.eos_id + and (curr_info["idx"] + 1 == len(other_info[sent_id]["predict_segments"])) + ) or i == max_length: + generated_hyps[sent_id].add( + beam_states[sent_id][beam_id]["ans"] + + [ + ( + word_id, + other_info[sent_id]["predict_segments"][curr_info["idx"]][1], + ) + ], + value.item(), + ) + elif word_id == self.tokenizer.eos_id: + next_instance_beam_states.append( + ( + { + "idx": curr_info["idx"] + 1, + "ans": curr_info["ans"] + + [ + ( + word_id, + other_info[sent_id]["predict_segments"][ + curr_info["idx"] + ][1], + ) + ], + "nx_token_id": 
self.tokenizer.bos_id, + "nx_token_sub": 0, + "nx_segment_id": other_info[sent_id]["predict_segments"][ + curr_info["idx"] + 1 + ][0], + "nx_position": 0, + }, + value.item(), + sent_id * beam_size + beam_id, + ) + ) + + else: + raw_word_id = word_id + word_id_sub = 0 + if word_id >= self.tokenizer.vocab_size: + word_id -= self.tokenizer.vocab_size + word_id_sub = int(ext_table_sub_cpu[word_id].item()) + word_id = int(ext_table_ids_cpu[word_id].item()) + + next_instance_beam_states.append( + ( + { + "idx": curr_info["idx"], + "ans": curr_info["ans"] + + [ + ( + raw_word_id, + other_info[sent_id]["predict_segments"][ + curr_info["idx"] + ][1], + ) + ], + "nx_token_id": word_id, + "nx_token_sub": word_id_sub, + "nx_segment_id": curr_info["nx_segment_id"], + "nx_position": curr_info["nx_position"] + 1, + }, + value.item(), + sent_id * beam_size + beam_id, + ) + ) + + # the beam for next step is full + if len(next_instance_beam_states) == beam_size: + break + + # update next beam content + assert len(next_instance_beam_states) == 0 if i == max_length else beam_size + next_beam_states.append(next_instance_beam_states) + + # we have reached the last step + if i == max_length: + break + + # sanity check / prepare next batch + beam_reorder_idx = [] + beam_new_scores = [] + beam_states = [] + for sent_id in range(batch_size): + instance_beam_states = [] + for beam_id in range(beam_size): + statevaluebeam_idx = next_beam_states[sent_id][beam_id] + beam_reorder_idx.append(beam_idx) + beam_new_scores.append(value) + instance_beam_states.append(state) + beam_states.append(instance_beam_states) + + input = input[beam_reorder_idx:] + beam_scores = torch.tensor(beam_new_scoresdtype=torch.floatdevice=input.device) + for kw in past_key_values.keys(): + if kw == "buffer": + buf_list = past_key_values[kw] + nw_buf_list = [] + for buf in buf_list: + if buf is None: + nw_buf_list.append((NoneNone)) + else: + k_bufv_buf = buf + nw_buf_list.append( + (k_buf[beam_reorder_idx:]v_buf[beam_reorder_idx:]) + ) + past_key_values[kw] = nw_buf_list + else: + past_key_values[kw] = past_key_values[kw][beam_reorder_idx:] + + # select the best hypotheses + results = [] + for sent_idhypotheses in enumerate(generated_hyps): + best_hyp = max(hypotheses.hypkey=lambda x: x[0])[1] + results.append(best_hyp) + return results diff --git a/cpm-live/cpm_live/generation/generation_utils.py b/cpm-live/cpm_live/generation/generation_utils.py new file mode 100644 index 0000000..07867d8 --- /dev/null +++ b/cpm-live/cpm_live/generation/generation_utils.py @@ -0,0 +1,112 @@ +import torch +import torch.nn.functional as F + + +def top_k_top_p_filtering(logitstop_k=0top_p=0.0filter_value=-float("inf")): + # This function has been mostly taken from huggingface conversational ai code at + # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313 + + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logitstop_k)[0][...-1None] + logits[indices_to_remove] = filter_value + + batch_size = logits.size()[0] + if top_p > 0.0: + logits = logits.view(batch_size-1).contiguous() + for index in range(len(logits)): + + sorted_logitssorted_indices = torch.sort(logits[index].view(-1)descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logitsdim=-1)dim=-1) + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the 
right to keep also the first token above the threshold
+            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+            sorted_indices_to_remove[..., 0] = 0
+            indices_to_remove = sorted_indices[sorted_indices_to_remove]
+            logits[index][indices_to_remove] = filter_value
+
+        logits = logits.view(batch_size, -1).contiguous()
+
+    return logits
+
+
+def apply_repetition_penalty(
+    logits,
+    batch_size,
+    num_beams,
+    prev_output_tokens,
+    repetition_penalty,
+    start_idx=None,
+    end_idx=None,
+    window_size=None,
+):
+    # only conduct repetition penalty for the output
+    assert repetition_penalty >= 1, "repetition penalty coefficient should >= 1"
+    # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
+    for i in range(batch_size * num_beams):
+        if start_idx is None or end_idx is None:
+            output_tokens = prev_output_tokens[i].tolist()
+        else:
+            if end_idx >= start_idx:
+                if window_size:
+                    output_tokens = prev_output_tokens[i][
+                        max(start_idx, end_idx + 1 - window_size) : end_idx + 1
+                    ].tolist()
+                else:
+                    output_tokens = prev_output_tokens[i][start_idx : end_idx + 1].tolist()
+            else:
+                output_tokens = []
+        for previous_token in set(output_tokens):
+            # if score < 0, the repetition penalty has to be
+            # multiplied to reduce the previous token probability
+            if logits[i, previous_token] < 0:
+                logits[i, previous_token] *= repetition_penalty
+            else:
+                logits[i, previous_token] /= repetition_penalty
+
+
+class BeamHypotheses:
+    def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.max_len = max_len
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
+        self.n_hyp = n_hyp
+        self.hyp = []
+        self.worst_score = 1e9
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.hyp)
+
+    def add(self, hyp, sum_logprobs):
+        """
+        Add a new hypothesis to the list.
+        """
+        score = sum_logprobs / len(hyp) ** self.length_penalty
+
+        if len(self) < self.n_hyp or score > self.worst_score:
+            self.hyp.append((score, hyp))
+            if len(self) > self.n_hyp:
+                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
+                del self.hyp[sorted_scores[0][1]]
+                self.worst_score = sorted_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs, cur_len):
+        """
+        If there are enough hypotheses and none of the hypotheses being generated
+        can become better than the worst one in the heap, then we are done with this sentence.
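+        Concretely, the check below stops the search once
+        ``worst_score >= best_sum_logprobs / cur_len ** length_penalty``.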
+        """
+        if len(self) < self.n_hyp:
+            return False
+        elif self.early_stopping:
+            return True
+        else:
+            return self.worst_score >= best_sum_logprobs / cur_len**self.length_penalty
diff --git a/cpm-live/cpm_live/layers/__init__.py b/cpm-live/cpm_live/layers/__init__.py
index 4ff22c1..dbd8dcd 100644
--- a/cpm-live/cpm_live/layers/__init__.py
+++ b/cpm-live/cpm_live/layers/__init__.py
@@ -1,5 +1,5 @@
-from .embedding import Embedding
-from .position_embedding import SegmentPositionEmbedding
+from .embedding import Embedding, EmbeddingExt
+from .position_embedding import SegmentPositionEmbedding, BucketPositionBias, RotaryEmbedding
 from .linear import Linear
 from .layernorm import LayerNorm
 from .attention import Attention
diff --git a/cpm-live/cpm_live/layers/attention.py b/cpm-live/cpm_live/layers/attention.py
index c227037..241b14e 100644
--- a/cpm-live/cpm_live/layers/attention.py
+++ b/cpm-live/cpm_live/layers/attention.py
@@ -72,8 +72,8 @@ def forward(
         len_q = hidden_q.size(1)
         len_k = hidden_kv.size(1)
 
-        h_q = self.project_q(hidden_q)
-        h_k = self.project_k(hidden_kv)
+        h_q = self.project_q(hidden_q) / math.sqrt(math.sqrt(self.dim_head))
+        h_k = self.project_k(hidden_kv) / math.sqrt(math.sqrt(self.dim_head))
         h_v = self.project_v(hidden_kv)
 
         h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
@@ -86,7 +86,9 @@ def forward(
             len_k = h_k.size(-2)
 
         # (b, n_h, len_q, d_h) @ (b, n_h, d_h, len_k) -> (b, n_h, len_q, len_k)
-        score = torch.matmul(h_q, h_k.transpose(-1, -2)) / math.sqrt(self.dim_head)
+        score = torch.matmul(
+            h_q, h_k.transpose(-1, -2)
+        )  # / math.sqrt(self.dim_head) moved to line 75~76
         score = score + position_bias
         score = torch.masked_fill(
             score,
diff --git a/cpm-live/cpm_live/layers/blocks.py b/cpm-live/cpm_live/layers/blocks.py
index e5db92e..a16abf6 100644
--- a/cpm-live/cpm_live/layers/blocks.py
+++ b/cpm-live/cpm_live/layers/blocks.py
@@ -91,7 +91,7 @@ def forward(
 
         if self.dropout is not None:
             x = self.dropout(x)
-        hidden_states = hidden_states + x
+        hidden_states = (hidden_states + x) / 1.05
 
         if use_cache:
             return hidden_states, current_key_value
@@ -154,7 +154,7 @@ def forward(
         x = self.ffn(x)
         if self.dropout is not None:
             x = self.dropout(x)
-        hidden_states = hidden_states + x
+        hidden_states = (hidden_states + x) / 1.05
         return hidden_states
diff --git a/cpm-live/cpm_live/layers/embedding.py b/cpm-live/cpm_live/layers/embedding.py
index 5b805b0..ffc305f 100644
--- a/cpm-live/cpm_live/layers/embedding.py
+++ b/cpm-live/cpm_live/layers/embedding.py
@@ -13,10 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
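+# EmbeddingExt (added below) ties the input embedding and the output projection and
+# applies a rotary transform over the per-token subscript ids. A minimal usage sketch,
+# with illustrative sizes only:
+#
+#     emb = EmbeddingExt(vocab_size=30720, embedding_size=4096)
+#     hidden = emb(ids, ids_sub)                       # (batch, seq_len, 4096)
+#     logits = emb.projection(hidden, ext_table=None)  # (batch, seq_len, 30720)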
+from typing import Optional import torch import bmtrain as bmt import math import torch.nn.functional as F +from .position_embedding import RotaryEmbedding class Embedding(bmt.DistributedModule): @@ -60,3 +62,56 @@ def projection(selfx: torch.Tensor): """ # noqa: E501 logits = F.linear(x / math.sqrt(self.dim_model)self.weight) return logits + + +class EmbeddingExt(bmt.DistributedModule): + def __init__( + self, + vocab_size: int, + embedding_size: int, + dtype: torch.dtype = torch.half, + init_mean: float = 0.0, + init_std: float = 1, + distance_scale: int = 16, + ): + + super().__init__() + + self.dim_model = embedding_size + self.rotary_emb = RotaryEmbedding( + dim=embedding_sizedistance_scale=distance_scaledtype=dtype + ) + + self.weight = bmt.DistributedParameter( + torch.empty(vocab_sizeembedding_sizedtype=dtype), + init_method=bmt.ParameterInitializer( + torch.nn.init.normal_mean=init_meanstd=init_std + ), + ) + + def forward(selfids: torch.Tensorids_sub: torch.Tensor): + """ + Args: + ids (:obj:`torch.Tensor` of shape ``(batch_sizeseq_len)``): Indices of input sequence tokens. + ids (:obj:`torch.Tensor` of shape ``(batch_size)``): Subscript of input sequence tokens. + Return: + :obj:`torch.Tensor` of shape ``(batch_sizeseq_lenembedding_size)``: The embedding output. + """ # noqa: E501 + + embeds = F.embedding(idsself.weight) / math.sqrt(self.dim_model) + return self.rotary_emb(embedsids_sub) + + def projection(selfx: torch.Tensorext_table: Optional[torch.Tensor] = None): + """ + Projection based on embedding's weight. For exampleembedding map vocab_size to embed_sizethan projection map embed_size back to vocab_size. + Args: + x (:obj:`torch.Tensor` of shape ``(batchseq_lendim_model)``): Input of projection + ext_table (:obj:`torch.Tensor` of shape ``(ext_table_sizedim_model)``): Ext vocab table. + Returns: + :obj:`torch.Tensor` of shape ``(batchseq_lenvocab_size + ext_table_size)``: The projection output. + """ # noqa: E501 + logits = F.linear(x / math.sqrt(self.dim_model)self.weight) + if ext_table is not None: + logits_ext = F.linear(xext_table) + logits = torch.cat([logitslogits_ext]dim=-1) + return logits diff --git a/cpm-live/cpm_live/layers/feedforward.py b/cpm-live/cpm_live/layers/feedforward.py index 8dd3c23..c015cf3 100644 --- a/cpm-live/cpm_live/layers/feedforward.py +++ b/cpm-live/cpm_live/layers/feedforward.py @@ -101,7 +101,7 @@ def __init__( dim_in=dim_ff, dim_out=dim_model, dtype=dtype, - scale_before=True, + scale_before=False, ) def forward(selfx: torch.Tensor): diff --git a/cpm-live/cpm_live/layers/position_embedding.py b/cpm-live/cpm_live/layers/position_embedding.py index 790b527..2cd49f2 100644 --- a/cpm-live/cpm_live/layers/position_embedding.py +++ b/cpm-live/cpm_live/layers/position_embedding.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
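+# The bias modules in this file map relative distances to a fixed number of buckets,
+# T5-style. With the defaults of _position_bucket (num_buckets=32, max_distance=128),
+# half the buckets encode direction, distances below 8 get exact buckets, distances
+# 8..127 share log-spaced buckets, and anything >= 128 is clamped to the last bucket.
+# (These are the default values; callers may configure them differently.)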
import math +from typing import Union import torch import bmtrain as bmt import torch.nn.functional as F @@ -21,12 +22,12 @@ class SegmentPositionEmbedding(bmt.DistributedModule): def __init__( self, - num_heads, - num_segments=1, - num_buckets=32, - max_distance=128, - bidirectional=False, - dtype=torch.half, + num_heads: int, + num_segments: int = 1, + num_buckets: int = 32, + max_distance: int = 128, + bidirectional: bool = False, + dtype: torch.dtype = torch.half, init_mean: float = 0.0, init_std: float = 1, ): @@ -125,3 +126,129 @@ def _position_bucket( is_smallrelative_position.to(torch.int32)relative_postion_if_large ) return relative_buckets + + +class BucketPositionBias(bmt.DistributedModule): + def __init__( + self, + num_heads: int, + num_buckets: int = 32, + num_segment_bucket: int = 32, + max_distance: int = 128, + dtype: torch.dtype = torch.half, + init_mean: float = 0.0, + init_std: float = 1, + ) -> None: + super().__init__() + + self.num_heads = num_heads + self.num_buckets = num_buckets + self.num_segment_bucket = num_segment_bucket + self.max_distance = max_distance + + self.relative_attention_bias = bmt.DistributedParameter( + torch.empty(num_buckets + num_segment_bucketnum_headsdtype=dtype), + init_method=bmt.ParameterInitializer( + torch.nn.init.normal_mean=init_meanstd=init_std + ), + ) + + def forward( + self, + query_pos: torch.Tensor # (batchlen_q) + key_pos: torch.Tensor # (batchlen_k) + rel_buckets: torch.Tensor # (batchlen_qlen_k) + ): + with torch.no_grad(): + + batch = key_pos.size(0) + keylen = key_pos.size(1) + querylen = query_pos.size(1) + + assert key_pos.size(0) == query_pos.size(0) + assert ( + rel_buckets.size(0) == batch + and rel_buckets.size(1) == querylen + and rel_buckets.size(2) == keylen + ) + + relative_position_bucket = rel_buckets - 1 + self.num_buckets # 与相对位置编码区间不重叠 + + # b*q*k + inner_segment_bucket = self._position_bucket( + key_pos[...None:] - query_pos[...:None], + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) + relative_position_bucket = torch.where( + rel_buckets == 0, + inner_segment_bucket, + relative_position_bucket, + ) + # (batchlen_qlen_k) + + # (batchlen_qlen_knum_heads) + embeds = F.embedding(relative_position_bucketself.relative_attention_bias) + # (batchnum_headslen_qlen_k) + embeds = embeds.permute(0312).contiguous() + return embeds + + def _position_bucket(selfrelative_positionnum_buckets=32max_distance=128): + relative_buckets = 0 + num_buckets //= 2 + relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets + relative_position = torch.abs(relative_position) + + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.int32) + relative_postion_if_large = torch.min( + relative_postion_if_large, + torch.full_like(relative_postion_if_largenum_buckets - 1), + ) + relative_buckets += torch.where( + is_smallrelative_position.to(torch.int32)relative_postion_if_large + ) + return relative_buckets + + +class RotaryEmbedding(bmt.DistributedModule): + def __init__( + self, + dim, + base=10000, + distance_scale: Union[intfloat] = 1, + dtype: torch.dtype = torch.half, + ): + super().__init__() + inv_freq = 1.0 / ( + base ** (torch.arange(0dim2device="cuda"dtype=torch.float32) / dim) + ) + inv_freq = inv_freq.to(dtype) + self.distance_scale = distance_scale + self.dtype = dtype + self.inv_freq = inv_freq + + 
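+    # forward() below applies the standard rotary transform: for positions p it builds
+    # theta = p * distance_scale * inv_freq, duplicates it along the last dimension, and
+    # returns x * cos(theta) + rotate_half(x) * sin(theta), where rotate_half negates the
+    # second half of x and swaps it with the first half.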
def forward(selfx: torch.Tensorx_pos: torch.Tensor): + """ + Args: + x (:obj:`torch.Tensor` of shape ``(...dim)``): Inputs. + x_pos (:obj:`torch.Tensor` of shape ``(...)``): Positions of inputs. + """ + x_pos = x_pos * self.distance_scale + freqs = x_pos[...None].to(self.dtype) * self.inv_freq[None:] # (...dim/2) + + # the same implementation as sat + emb = torch.cat((freqsfreqs)dim=-1) # (...dim) + emb_cos = emb.cos() # (...dim) + emb_sin = emb.sin() # (...dim) + + rotate_x = torch.cat( + [-x[...x.size(-1) // 2 :]x[...: x.size(-1) // 2]]dim=-1 + ) # (...dim) + + return x * emb_cos + rotate_x * emb_sin diff --git a/cpm-live/cpm_live/models/__init__.py b/cpm-live/cpm_live/models/__init__.py index d1e4033..ed9ed68 100644 --- a/cpm-live/cpm_live/models/__init__.py +++ b/cpm-live/cpm_live/models/__init__.py @@ -1 +1,4 @@ from .ant import CPMAntConfigCPMAnt +from .bee import CPMBeeConfigCPMBee +from .ant_torch import CPMAntTorch +from .bee_torch import CPMBeeTorch diff --git a/cpm-live/cpm_live/models/ant.py b/cpm-live/cpm_live/models/ant.py index dc246f6..8d2f11c 100644 --- a/cpm-live/cpm_live/models/ant.py +++ b/cpm-live/cpm_live/models/ant.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import ListOptionalTuple import torch from ..utils import Config @@ -37,6 +38,7 @@ def __init__( prompt_length: int = 32, segment_types: int = 32, mask_modules: Optional[List[Tuple[boolbool]]] = None, + **kwargs, ): super().__init__() @@ -77,12 +79,12 @@ def __init__(selfconfig: CPMAntConfig): mask_modules=config.mask_modules, ) - # self.prompt_embedding = Embedding( - # vocab_size=config.prompt_types * config.prompt_length, - # embedding_size=config.dim_model, - # dtype=config.dtype, - # init_std=0.02, - # ) + self.prompt_embedding = Embedding( + vocab_size=config.prompt_types * config.prompt_length, + embedding_size=config.dim_model, + dtype=config.dtype, + init_std=0.02, + ) self.segment_embedding = Embedding( vocab_size=config.segment_types, @@ -92,7 +94,7 @@ def __init__(selfconfig: CPMAntConfig): ) self.input_embedding = Embedding( - vocab_size=config.vocab_size + config.prompt_length * config.prompt_types, + vocab_size=config.vocab_size, embedding_size=config.dim_model, dtype=config.dtype, init_std=0.02, @@ -117,8 +119,49 @@ def forward( position: torch.Tensor # (batchseqlen) segment: torch.Tensor # (batchseqlen) span: torch.Tensor # (batchseqlen) + ): + + batch = input.size(0) + seqlen = input.size(1) + input_prompt = input[:: self.prompt_length].contiguous() + input_ids = input[:self.prompt_length :].contiguous() + + prompt_states = self.prompt_embedding(input_prompt) + hidden_states = self.input_embedding(input_ids) + segment_states = self.segment_embedding(segment) + hidden_states = torch.cat([prompt_stateshidden_states]1) + segment_states + + with torch.no_grad(): + device = input.device + directional_mask_2d = torch.arange(seqlendevice=device) <= torch.arange( + seqlendevice=device + ).view(-11) + attention_mask = context[:None:] | ( + context[::None].logical_not() & directional_mask_2d.view(1seqlenseqlen) + ) + attention_mask = attention_mask & (span[:None:] == span[::None]) + mask_1d = ( + torch.arange(seqlendevice=device)[None:].repeat(batch1) < length[:None] + ) + attention_mask = ( + mask_1d.view(batchseqlen1) & mask_1d.view(batch1seqlen) & attention_mask + ) + + position_bias = self.position_bias(positionpositionsegmentsegment) + 
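+        # At this point attention_mask combines full attention over context tokens,
+        # causal attention elsewhere, a same-span restriction, and length-based padding
+        # masking; position_bias supplies the per-head segment-relative bias.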
hidden_states = self.encoder(hidden_statesattention_maskposition_bias) + + logits = self.input_embedding.projection(hidden_states) + return logitshidden_states + + def inference( + self, + input: torch.Tensor # (batchseqlen) + length: torch.Tensor # (batch) + context: torch.Tensor # (batchseqlen) + position: torch.Tensor # (batchseqlen) + segment: torch.Tensor # (batchseqlen) + span: torch.Tensor # (batchseqlen) past_key_values=None # num_layers * 2 * (batchnum_headsseqlendim_head) - use_cache=False, ): batch = input.size(0) @@ -126,11 +169,13 @@ def forward( if past_key_values is None: past_length = 0 past_key_values = tuple([None] * self.encoder.num_layers) - input_ids = input.contiguous() + input_prompt = input[:: self.prompt_length].contiguous() + input_ids = input[:self.prompt_length :].contiguous() + prompt_states = self.prompt_embedding(input_prompt) hidden_states = self.input_embedding(input_ids) segment_states = self.segment_embedding(segment) - hidden_states = hidden_states + segment_states + hidden_states = torch.cat([prompt_stateshidden_states]1) + segment_states else: past_length = past_key_values[0][0].size(-2) @@ -148,8 +193,10 @@ def forward( context[::None].logical_not() & directional_mask_2d.view(1seqlenseqlen) ) attention_mask = attention_mask & (span[:None:] == span[::None]) + # mask for left paddding mask_1d = ( - torch.arange(seqlendevice=device)[None:].repeat(batch1) < length[:None] + torch.tensor(list(range(seqlen))[::-1]device=device)[None:].repeat(batch1) + < length[:None] ) attention_mask = ( mask_1d.view(batchseqlen1) & mask_1d.view(batch1seqlen) & attention_mask @@ -157,19 +204,11 @@ def forward( position_bias = self.position_bias(positionpositionsegmentsegment) - if past_length > 0: - attention_mask = attention_mask[:past_length::] - position_bias = position_bias[::past_length::] + attention_mask = attention_mask[:past_length::] + position_bias = position_bias[::past_length::] - if use_cache: - hidden_statespresent_key_values = self.encoder( - hidden_statesattention_maskposition_biasuse_cachepast_key_values - ) - logits = self.input_embedding.projection(hidden_states) - return logitshidden_statespresent_key_values - else: - hidden_states = self.encoder( - hidden_statesattention_maskposition_biasuse_cachepast_key_values - ) - logits = self.input_embedding.projection(hidden_states) - return logitshidden_states + hidden_statespresent_key_values = self.encoder( + hidden_statesattention_maskposition_biasTruepast_key_values + ) + logits = self.input_embedding.projection(hidden_states) + return logitshidden_statespresent_key_values diff --git a/cpm-live/cpm_live/models/ant_torch.py b/cpm-live/cpm_live/models/ant_torch.py new file mode 100644 index 0000000..44a23bf --- /dev/null +++ b/cpm-live/cpm_live/models/ant_torch.py @@ -0,0 +1,169 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache LicenseVersion 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writingsoftware +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
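+# CPMAntTorch mirrors the bmtrain-based CPMAnt but is assembled from the plain torch.nn
+# modules in ..native_layers, so it can run without a distributed bmtrain setup.
+# A minimal usage sketch (the checkpoint path is a placeholder):
+#
+#     config = CPMAntConfig()
+#     model = CPMAntTorch(config)
+#     model.load_state_dict(torch.load("cpm-ant-10b.pt"), strict=False)
+#     logits, hidden = model(input, length, context, position, segment, span)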
+import torch +from ..native_layers import EncoderEmbeddingSegmentPositionEmbedding +from .ant import CPMAntConfig + + +class CPMAntTorch(torch.nn.Module): + def __init__(selfconfig: CPMAntConfig): + + super().__init__() + + self.encoder = Encoder( + num_layers=config.num_layers, + dim_model=config.dim_model, + dim_ff=config.dim_ff, + num_heads=config.num_heads, + dim_head=config.dim_head, + dtype=config.dtype, + eps=config.eps, + dropout_p=config.dropout_p, + mask_modules=config.mask_modules, + ) + + self.prompt_embedding = Embedding( + vocab_size=config.prompt_types * config.prompt_length, + embedding_size=config.dim_model, + dtype=config.dtype, + init_std=0.02, + ) + + self.segment_embedding = Embedding( + vocab_size=config.segment_types, + embedding_size=config.dim_model, + dtype=config.dtype, + init_std=0.02, + ) + + self.input_embedding = Embedding( + vocab_size=config.vocab_size, + embedding_size=config.dim_model, + dtype=config.dtype, + init_std=0.02, + ) + + self.position_bias = SegmentPositionEmbedding( + num_heads=config.num_heads, + num_segments=config.segment_types, + num_buckets=config.position_bias_num_buckets, + max_distance=config.position_bias_max_distance, + bidirectional=True, + dtype=config.dtype, + ) + + self.prompt_length = config.prompt_length + + def forward( + self, + input: torch.Tensor # (batchseqlen) + length: torch.Tensor # (batch) + context: torch.Tensor # (batchseqlen) + position: torch.Tensor # (batchseqlen) + segment: torch.Tensor # (batchseqlen) + span: torch.Tensor # (batchseqlen) + ): + + batch = input.size(0) + seqlen = input.size(1) + input_prompt = input[:: self.prompt_length].contiguous() + input_ids = input[:self.prompt_length :].contiguous() + + prompt_states = self.prompt_embedding(input_prompt) + hidden_states = self.input_embedding(input_ids) + segment_states = self.segment_embedding(segment) + hidden_states = torch.cat([prompt_stateshidden_states]1) + segment_states + + with torch.no_grad(): + device = input.device + directional_mask_2d = torch.arange(seqlendevice=device) <= torch.arange( + seqlendevice=device + ).view(-11) + attention_mask = context[:None:] | ( + context[::None].logical_not() & directional_mask_2d.view(1seqlenseqlen) + ) + attention_mask = attention_mask & (span[:None:] == span[::None]) + mask_1d = ( + torch.arange(seqlendevice=device)[None:].repeat(batch1) < length[:None] + ) + attention_mask = ( + mask_1d.view(batchseqlen1) & mask_1d.view(batch1seqlen) & attention_mask + ) + + position_bias = self.position_bias(positionpositionsegmentsegment) + hidden_states = self.encoder(hidden_statesattention_maskposition_bias) + + logits = self.input_embedding.projection(hidden_states) + return logitshidden_states + + def inference( + self, + input: torch.Tensor # (batchseqlen) + length: torch.Tensor # (batch) + context: torch.Tensor # (batchseqlen) + position: torch.Tensor # (batchseqlen) + segment: torch.Tensor # (batchseqlen) + span: torch.Tensor # (batchseqlen) + past_key_values=None # num_layers * 2 * (batchnum_headsseqlendim_head) + ): + + batch = input.size(0) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * self.encoder.num_layers) + input_prompt = input[:: self.prompt_length].contiguous() + input_ids = input[:self.prompt_length :].contiguous() + + prompt_states = self.prompt_embedding(input_prompt) + hidden_states = self.input_embedding(input_ids) + segment_states = self.segment_embedding(segment) + hidden_states = torch.cat([prompt_stateshidden_states]1) + segment_states + + else: + 
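+            # Incremental decoding: keys/values for the already-processed prefix live in
+            # past_key_values, so only the newly fed tokens are embedded here and the
+            # prompt embedding is not recomputed.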
past_length = past_key_values[0][0].size(-2) + segment_states = self.segment_embedding(segment) + hidden_states = self.input_embedding(input) + segment_states[:-1::] + + seqlen = past_length + input.size(1) + + with torch.no_grad(): + device = input.device + directional_mask_2d = torch.arange(seqlendevice=device) <= torch.arange( + seqlendevice=device + ).view(-11) + attention_mask = context[:None:] | ( + context[::None].logical_not() & directional_mask_2d.view(1seqlenseqlen) + ) + attention_mask = attention_mask & (span[:None:] == span[::None]) + # mask for left paddding + mask_1d = ( + torch.tensor(list(range(seqlen))[::-1]device=device)[None:].repeat(batch1) + < length[:None] + ) + attention_mask = ( + mask_1d.view(batchseqlen1) & mask_1d.view(batch1seqlen) & attention_mask + ) + + position_bias = self.position_bias(positionpositionsegmentsegment) + + attention_mask = attention_mask[:past_length::] + position_bias = position_bias[::past_length::] + + hidden_statespresent_key_values = self.encoder( + hidden_statesattention_maskposition_biasTruepast_key_values + ) + logits = self.input_embedding.projection(hidden_states) + return logitshidden_statespresent_key_values diff --git a/cpm-live/cpm_live/models/bee.py b/cpm-live/cpm_live/models/bee.py new file mode 100644 index 0000000..098ecc6 --- /dev/null +++ b/cpm-live/cpm_live/models/bee.py @@ -0,0 +1,291 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache LicenseVersion 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writingsoftware +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
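+# CPMBee derives its position bias from two kinds of buckets: relative-distance buckets
+# inside a segment and per-(query-segment, key-segment) relation buckets gathered from
+# segment_rel. A rough sketch of the gather performed in forward()/inference()
+# (pseudo-code, names as in the methods below):
+#
+#     idx = q_segment * num_segments + k_segment + segment_rel_offset
+#     bucket = segment_rel.gather(1, idx)        # relation bucket per (q, k) pair
+#     bucket[~same_sample_or_span] = 1           # reserved bucket for in-context samples
+#     # where the relation is 0, BucketPositionBias falls back to distance buckets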
+ +from typing import ListOptionalTuple +from typing_extensions import TypedDict +import torch + +from ..utils import Config +from ..layers import EncoderEmbeddingExtBucketPositionBias +import bmtrain as bmt +from ..utils.gradient_shrink import gradient_shrink + + +class CPMBeeInferenceState(TypedDict): + buffer_position: torch.Tensor + buffer_context: torch.Tensor + buffer_sample_ids: torch.Tensor + buffer_num_segments: torch.Tensor + buffer_segments: torch.Tensor + buffer: List[Tuple[torch.Tensortorch.Tensor]] + + +class CPMBeeConfig(Config): + def __init__( + self, + vocab_size=30720, + dim_model=4096, + num_heads=64, + dim_head=64, + dim_ff=10240, + num_layers=32, + dropout_p=0.0, + position_bias_num_buckets=256, + position_bias_num_segment_buckets=256, + position_bias_max_distance=2048, + eps=1e-6, + half: bool = True, + mask_modules: Optional[List[Tuple[boolbool]]] = None, + ): + + super().__init__() + self.dim_model = dim_model + self.num_heads = num_heads + self.dim_head = dim_head + self.dim_ff = dim_ff + self.num_layers = num_layers + self.position_bias_num_buckets = position_bias_num_buckets + self.position_bias_num_segment_buckets = position_bias_num_segment_buckets + self.position_bias_max_distance = position_bias_max_distance + self.dropout_p = dropout_p + self.eps = eps + if half: + self.dtype = torch.half + else: + self.dtype = torch.float + self.vocab_size = vocab_size + self.mask_modules = mask_modules + + +class CPMBee(bmt.DistributedModule): + def __init__(selfconfig: CPMBeeConfig): + + super().__init__() + + self.encoder = Encoder( + num_layers=config.num_layers, + dim_model=config.dim_model, + dim_ff=config.dim_ff, + num_heads=config.num_heads, + dim_head=config.dim_head, + dtype=config.dtype, + eps=config.eps, + dropout_p=config.dropout_p, + mask_modules=config.mask_modules, + ) + + self.input_embedding = EmbeddingExt( + vocab_size=config.vocab_size, + embedding_size=config.dim_model, + dtype=config.dtype, + init_std=0.02, + ) + + self.position_bias = BucketPositionBias( + num_heads=config.num_heads, + num_buckets=config.position_bias_num_buckets, + num_segment_bucket=config.position_bias_num_segment_buckets, + max_distance=config.position_bias_max_distance, + dtype=config.dtype, + ) + + def forward( + self, + input: torch.Tensor # (batchseqlen) int32 + input_sub: torch.Tensor # (batchseqlen) int32 + length: torch.Tensor # (batch) int32 + context: torch.Tensor # (batchseqlen) bool + sample_ids: torch.Tensor # (batchseq_len) int32 + num_segments: torch.Tensor # (batchseq_len) int32 + segment: torch.Tensor # (batchseqlen) int32 + segment_rel_offset: torch.Tensor # (batchseq_len) int32 + segment_rel: torch.Tensor # (batchnum_segment_bucket) int32 + span: torch.Tensor # (batchseqlen) int32 + ext_table_ids: torch.Tensor # (ext_table_size) int32 + ext_table_sub: torch.Tensor # (ext_table_size) int32 + ): + batch = input.size(0) + seqlen = input.size(1) + # processing masks and position bias bucket + with torch.no_grad(): + device = input.device + + # calc segment bucket + segment_rel_2d = torch.masked_fill( + segment[::None] * num_segments[::None] + + segment[:None:] + + segment_rel_offset[::None], + ~( + (sample_ids[::None] == sample_ids[:None:]) + & (span[:None:] == span[::None]) + ) # not in the same span or sample + 0 # avoid torch.gather overflow + ).view(batchseqlen * seqlen) + + segment_bucket = torch.gather( + input=segment_rel, + dim=1, + index=segment_rel_2d.long(), + ).view(batchseqlenseqlen) + + segment_bucket.masked_fill_( + ~( + (sample_ids[::None] == 
sample_ids[:None:]) + & (span[:None:] == span[::None]) + ) # not in the same span or sample + 1 # bucket is used for in-context samples + ) + + # directional mask + directional_mask_2d = torch.arange(seqlendevice=device) <= torch.arange( + seqlendevice=device + ).view(-11) + # sample mask + sample_mask_2d = (sample_ids[::None] == 0) | ( + sample_ids[::None] == sample_ids[:None:] + ) + # context mask + attention_mask = context[:None:] | ( + context[::None].logical_not() & directional_mask_2d.view(1seqlenseqlen) + ) + # span mask + attention_mask = ( + attention_mask & sample_mask_2d & (span[:None:] == span[::None]) + ) + # length mask + mask_1d = ( + torch.arange(seqlendevice=device)[None:].repeat(batch1) < length[:None] + ) + attention_mask = ( + mask_1d.view(batchseqlen1) & mask_1d.view(batch1seqlen) & attention_mask + ) + position = torch.arange(seqlendevice=device).expand(batchseqlen) + + hidden_states = self.input_embedding(inputinput_sub) + position_bias = self.position_bias(positionpositionsegment_bucket) + + hidden_states = self.encoder(hidden_statesattention_maskposition_bias) + + ext_table = self.input_embedding(ext_table_idsext_table_sub) + + logits = self.input_embedding.projection(hidden_statesext_table) + return logitshidden_states + + def inference( + self, + input: torch.Tensor # (batchlen_q) int32 + input_sub: torch.Tensor # (batchlen_q) int32 + position: torch.Tensor # (batchlen_q) int32 + context: torch.Tensor # (batchlen_q) bool + sample_ids: torch.Tensor # (batchlen_q) int32 + num_segments: torch.Tensor # (batchlen_q) int32 + segment: torch.Tensor # (batchlen_q) int32 + segment_rel_offset: torch.Tensor # (batchlen_q) int32 + segment_rel: torch.Tensor # (batchnum_segment_bucket) int32 + ext_table_ids: torch.Tensor # (ext_table_size) int32 + ext_table_sub: torch.Tensor # (ext_table_size) int32 + past_key_values: Optional[CPMBeeInferenceState] = None, + ) -> Tuple[torch.Tensortorch.TensorCPMBeeInferenceState]: + with torch.no_grad(): + if past_key_values is None: + present_position = position + present_context = context + present_sample_ids = sample_ids + present_num_segments = num_segments + present_segments = segment + present_buffer = None + else: + present_position = torch.cat([past_key_values["buffer_position"]position]dim=-1) + present_context = torch.cat([past_key_values["buffer_context"]context]dim=-1) + present_sample_ids = torch.cat( + [past_key_values["buffer_sample_ids"]sample_ids]dim=-1 + ) + present_num_segments = torch.cat( + [past_key_values["buffer_num_segments"]num_segments]dim=-1 + ) + present_segments = torch.cat([past_key_values["buffer_segments"]segment]dim=-1) + present_buffer = past_key_values["buffer"] + + batch = input.size(0) + len_q = input.size(1) + len_buffer = present_position.size(1) + + segment_rel_2d = torch.masked_fill( + segment[::None] * num_segments[::None] + + present_segments[:None:] + + segment_rel_offset[::None], + ~( + (sample_ids[::None] == present_sample_ids[:None:]) + ) # not in the same sample + 0 # avoid torch.gather overflow + ).view(batchlen_q * len_buffer) + + segment_bucket = torch.gather( + input=segment_rel, + dim=1, + index=segment_rel_2d.long(), + ).view(batchlen_qlen_buffer) + + segment_bucket.masked_fill_( + ~( + (sample_ids[::None] == present_sample_ids[:None:]) + ) # not in the same span or sample + 1 # bucket is used for in-context samples + ) + + # directional mask + directional_mask_2d = present_position[:None:] <= position[::None] + # sample mask + sample_mask_2d = (sample_ids[::None] == 0) | ( + 
sample_ids[::None] == present_sample_ids[:None:] + ) + # context mask + attention_mask = present_context[:None:] | ( + context[::None].logical_not() + & directional_mask_2d.view(batchlen_qlen_buffer) + ) + # span mask + attention_mask = attention_mask & sample_mask_2d + # length mask + mask_1d = present_num_segments != 0 + attention_mask = mask_1d.view(batch1len_buffer) & attention_mask + + hidden_states = gradient_shrink(self.input_embedding(inputinput_sub)) + + position_bias = gradient_shrink( + self.position_bias(positionpresent_positionsegment_bucket) + ) + hidden_statespresent_key_values = self.encoder( + hidden_states, + attention_mask, + position_bias, + True, + present_buffer, + ) + ext_table = gradient_shrink(self.input_embedding(ext_table_idsext_table_sub)) + logits = self.input_embedding.projection(hidden_statesext_table) + + return ( + logits, + hidden_states, + { + "buffer_position": present_position, + "buffer_context": present_context, + "buffer_sample_ids": present_sample_ids, + "buffer_num_segments": present_num_segments, + "buffer_segments": present_segments, + "buffer": present_key_values, + }, + ) diff --git a/cpm-live/cpm_live/models/bee_torch.py b/cpm-live/cpm_live/models/bee_torch.py new file mode 100644 index 0000000..a2e1489 --- /dev/null +++ b/cpm-live/cpm_live/models/bee_torch.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache LicenseVersion 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writingsoftware +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
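+# CPMBeeTorch is the plain-PyTorch counterpart of CPMBee: the same forward/inference
+# logic, but built on the torch.nn modules in ..native_layers so it runs without bmtrain.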
+ +from typing import OptionalTuple +import torch + +from ..native_layers import EncoderEmbeddingExtBucketPositionBias +from .bee import CPMBeeConfigCPMBeeInferenceState + + +class CPMBeeTorch(torch.nn.Module): + def __init__(selfconfig: CPMBeeConfig): + + super().__init__() + + self.encoder = Encoder( + num_layers=config.num_layers, + dim_model=config.dim_model, + dim_ff=config.dim_ff, + num_heads=config.num_heads, + dim_head=config.dim_head, + dtype=config.dtype, + eps=config.eps, + dropout_p=config.dropout_p, + mask_modules=config.mask_modules, + ) + + self.input_embedding = EmbeddingExt( + vocab_size=config.vocab_size, + embedding_size=config.dim_model, + dtype=config.dtype, + init_std=0.02, + ) + + self.position_bias = BucketPositionBias( + num_heads=config.num_heads, + num_buckets=config.position_bias_num_buckets, + num_segment_bucket=config.position_bias_num_segment_buckets, + max_distance=config.position_bias_max_distance, + dtype=config.dtype, + ) + + def forward( + self, + input: torch.Tensor # (batchseqlen) int32 + input_sub: torch.Tensor # (batchseqlen) int32 + length: torch.Tensor # (batch) int32 + context: torch.Tensor # (batchseqlen) bool + sample_ids: torch.Tensor # (batchseq_len) int32 + num_segments: torch.Tensor # (batchseq_len) int32 + segment: torch.Tensor # (batchseqlen) int32 + segment_rel_offset: torch.Tensor # (batchseq_len) int32 + segment_rel: torch.Tensor # (batchnum_segment_bucket) int32 + span: torch.Tensor # (batchseqlen) int32 + ext_table_ids: torch.Tensor # (ext_table_size) int32 + ext_table_sub: torch.Tensor # (ext_table_size) int32 + ): + batch = input.size(0) + seqlen = input.size(1) + # processing masks and position bias bucket + with torch.no_grad(): + device = input.device + + # calc segment bucket + segment_rel_2d = torch.masked_fill( + segment[::None] * num_segments[::None] + + segment[:None:] + + segment_rel_offset[::None], + ~( + (sample_ids[::None] == sample_ids[:None:]) + & (span[:None:] == span[::None]) + ) # not in the same span or sample + 0 # avoid torch.gather overflow + ).view(batchseqlen * seqlen) + + segment_bucket = torch.gather( + input=segment_rel, + dim=1, + index=segment_rel_2d.long(), + ).view(batchseqlenseqlen) + + segment_bucket.masked_fill_( + ~( + (sample_ids[::None] == sample_ids[:None:]) + & (span[:None:] == span[::None]) + ) # not in the same span or sample + 1 # bucket is used for in-context samples + ) + + # directional mask + directional_mask_2d = torch.arange(seqlendevice=device) <= torch.arange( + seqlendevice=device + ).view(-11) + # sample mask + sample_mask_2d = (sample_ids[::None] == 0) | ( + sample_ids[::None] == sample_ids[:None:] + ) + # context mask + attention_mask = context[:None:] | ( + context[::None].logical_not() & directional_mask_2d.view(1seqlenseqlen) + ) + # span mask + attention_mask = ( + attention_mask & sample_mask_2d & (span[:None:] == span[::None]) + ) + # length mask + mask_1d = ( + torch.arange(seqlendevice=device)[None:].repeat(batch1) < length[:None] + ) + attention_mask = ( + mask_1d.view(batchseqlen1) & mask_1d.view(batch1seqlen) & attention_mask + ) + position = torch.arange(seqlendevice=device).expand(batchseqlen) + + hidden_states = self.input_embedding(inputinput_sub) + position_bias = self.position_bias(positionpositionsegment_bucket) + + hidden_states = self.encoder(hidden_statesattention_maskposition_bias) + + ext_table = self.input_embedding(ext_table_idsext_table_sub) + + logits = self.input_embedding.projection(hidden_statesext_table) + return logitshidden_states + + def 
inference( + self, + input: torch.Tensor # (batchlen_q) int32 + input_sub: torch.Tensor # (batchlen_q) int32 + position: torch.Tensor # (batchlen_q) int32 + context: torch.Tensor # (batchlen_q) bool + sample_ids: torch.Tensor # (batchlen_q) int32 + num_segments: torch.Tensor # (batchlen_q) int32 + segment: torch.Tensor # (batchlen_q) int32 + segment_rel_offset: torch.Tensor # (batchlen_q) int32 + segment_rel: torch.Tensor # (batchnum_segment_bucket) int32 + ext_table_ids: torch.Tensor # (ext_table_size) int32 + ext_table_sub: torch.Tensor # (ext_table_size) int32 + past_key_values: Optional[CPMBeeInferenceState] = None, + ) -> Tuple[torch.Tensortorch.TensorCPMBeeInferenceState]: + with torch.no_grad(): + if past_key_values is None: + present_position = position + present_context = context + present_sample_ids = sample_ids + present_num_segments = num_segments + present_segments = segment + present_buffer = None + else: + present_position = torch.cat([past_key_values["buffer_position"]position]dim=-1) + present_context = torch.cat([past_key_values["buffer_context"]context]dim=-1) + present_sample_ids = torch.cat( + [past_key_values["buffer_sample_ids"]sample_ids]dim=-1 + ) + present_num_segments = torch.cat( + [past_key_values["buffer_num_segments"]num_segments]dim=-1 + ) + present_segments = torch.cat([past_key_values["buffer_segments"]segment]dim=-1) + present_buffer = past_key_values["buffer"] + + batch = input.size(0) + len_q = input.size(1) + len_buffer = present_position.size(1) + + segment_rel_2d = torch.masked_fill( + segment[::None] * num_segments[::None] + + present_segments[:None:] + + segment_rel_offset[::None], + ~( + (sample_ids[::None] == present_sample_ids[:None:]) + ) # not in the same sample + 0 # avoid torch.gather overflow + ).view(batchlen_q * len_buffer) + + segment_bucket = torch.gather( + input=segment_rel, + dim=1, + index=segment_rel_2d.long(), + ).view(batchlen_qlen_buffer) + + segment_bucket.masked_fill_( + ~( + (sample_ids[::None] == present_sample_ids[:None:]) + ) # not in the same span or sample + 1 # bucket is used for in-context samples + ) + + # directional mask + directional_mask_2d = present_position[:None:] <= position[::None] + # sample mask + sample_mask_2d = (sample_ids[::None] == 0) | ( + sample_ids[::None] == present_sample_ids[:None:] + ) + # context mask + attention_mask = present_context[:None:] | ( + context[::None].logical_not() + & directional_mask_2d.view(batchlen_qlen_buffer) + ) + # span mask + attention_mask = attention_mask & sample_mask_2d + # length mask + mask_1d = present_num_segments != 0 + attention_mask = mask_1d.view(batch1len_buffer) & attention_mask + + hidden_states = self.input_embedding(inputinput_sub) + + position_bias = self.position_bias(positionpresent_positionsegment_bucket) + hidden_statespresent_key_values = self.encoder( + hidden_states, + attention_mask, + position_bias, + True, + present_buffer, + ) + ext_table = self.input_embedding(ext_table_idsext_table_sub) + logits = self.input_embedding.projection(hidden_statesext_table) + + return ( + logits, + hidden_states, + { + "buffer_position": present_position, + "buffer_context": present_context, + "buffer_sample_ids": present_sample_ids, + "buffer_num_segments": present_num_segments, + "buffer_segments": present_segments, + "buffer": present_key_values, + }, + ) diff --git a/cpm-live/cpm_live/native_layers/__init__.py b/cpm-live/cpm_live/native_layers/__init__.py new file mode 100644 index 0000000..dbd8dcd --- /dev/null +++ 
b/cpm-live/cpm_live/native_layers/__init__.py @@ -0,0 +1,8 @@ +from .embedding import EmbeddingEmbeddingExt +from .position_embedding import SegmentPositionEmbeddingBucketPositionBiasRotaryEmbedding +from .linear import Linear +from .layernorm import LayerNorm +from .attention import Attention +from .feedforward import FeedForward +from .blocks import TransformerBlock +from .transformer import Encoder diff --git a/cpm-live/cpm_live/native_layers/attention.py b/cpm-live/cpm_live/native_layers/attention.py new file mode 100644 index 0000000..5b9f402 --- /dev/null +++ b/cpm-live/cpm_live/native_layers/attention.py @@ -0,0 +1,117 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache LicenseVersion 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writingsoftware +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import OptionalTuple +import torch +import math +from .linear import Linear + + +class Attention(torch.nn.Module): + def __init__( + self, + dim_model: int, + num_heads: int, + dim_head: int, + dtype: torch.dtype = torch.half, + dropout_p: Optional[float] = None, + ) -> None: + + super().__init__() + + self.dim_model = dim_model + self.num_heads = num_heads + self.dim_head = dim_head + + self.project_q = Linear(self.dim_modelself.num_heads * self.dim_headdtype=dtype) + self.project_k = Linear(self.dim_modelself.num_heads * self.dim_headdtype=dtype) + self.project_v = Linear(self.dim_modelself.num_heads * self.dim_headdtype=dtype) + + self.attention_out = Linear(self.num_heads * self.dim_headself.dim_modeldtype=dtype) + + self.softmax = torch.nn.Softmax(dim=-1) + + if dropout_p is not None: + self.dropout = torch.nn.Dropout(p=dropout_p) + else: + self.dropout = None + + def forward( + self, + hidden_q: torch.Tensor, + hidden_kv: torch.Tensor, + attention_mask: torch.BoolTensor, + position_bias: torch.Tensor, + use_cache: bool = False, + past_kv: Optional[Tuple[torch.Tensortorch.Tensor]] = None, + ): + """ + Args: + hidden_q (:obj:`torch.Tensor` of shape ``(batchlen_qdim_model)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix. + hidden_kv (:obj:`torch.Tensor` of shape ``(batchlen_kdim_model)``): Length of input sequence before padding. + attention_mask (:obj:`torch.Tensor` of shape ``(batchlen_qlen_k)``): Used to avoid performing attention on padding token indices. + position_bias(:obj:`torch.Tensor` of shape ``(num_headslen_qlen_k)`` or ``(1num_headslen_klen_q)``): Provide positional information about tensor `key_value` and `query`. + Return: + out (:obj:`torch.Tensor` of shape ``(batchlen_qdim_model)``): The attention output. 
+ """ # noqa: E501 + + batch_size = hidden_q.size(0) + len_q = hidden_q.size(1) + len_k = hidden_kv.size(1) + + h_q = self.project_q(hidden_q) + h_k = self.project_k(hidden_kv) + h_v = self.project_v(hidden_kv) + + h_q = h_q.view(batch_sizelen_qself.num_headsself.dim_head).permute(0213) + h_k = h_k.view(batch_sizelen_kself.num_headsself.dim_head).permute(0213) + h_v = h_v.view(batch_sizelen_kself.num_headsself.dim_head).permute(0213) + + if past_kv is not None: + h_k = torch.cat([past_kv[0]h_k]dim=-2) + h_v = torch.cat([past_kv[1]h_v]dim=-2) + len_k = h_k.size(-2) + + # (bn_hlen_qd_h) @ (bn_hd_hlen_k) -> (bn_hlen_qlen_k) + score = torch.matmul(h_qh_k.transpose(-1-2)) / math.sqrt(self.dim_head) + score = score + position_bias + score = torch.masked_fill( + score, + attention_mask.view(batch_size1len_qlen_k) == False, + torch.scalar_tensor(float("-inf")device=score.devicedtype=score.dtype), + ) + + score = self.softmax(score) + + score = torch.masked_fill( + score, + attention_mask.view(batch_size1len_qlen_k) == False, + torch.scalar_tensor(0device=score.devicedtype=score.dtype), + ) + + if self.dropout is not None: + score = self.dropout(score) + + # (bn_hlen_qlen_k) @ (bn_hlen_kd_h) -> (bn_hlen_qd_h) + score = torch.matmul(scoreh_v) + + score = score.view(batch_sizeself.num_headslen_qself.dim_head).permute(0213) + score = score.contiguous().view(batch_sizelen_qself.num_heads * self.dim_head) + + score = self.attention_out(score) + if use_cache: + return score(h_kh_v) + else: + return score diff --git a/cpm-live/cpm_live/native_layers/blocks.py b/cpm-live/cpm_live/native_layers/blocks.py new file mode 100644 index 0000000..9c6efcb --- /dev/null +++ b/cpm-live/cpm_live/native_layers/blocks.py @@ -0,0 +1,248 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache LicenseVersion 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writingsoftware +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import OptionalTuple +import torch +from .layernorm import LayerNorm +from .attention import Attention +from .feedforward import FeedForward + + +class SelfAttentionBlock(torch.nn.Module): + """The whole cross-attention block. A sequence of operation. Consists of layernormself-attention and residual connection. + + Args: + dim_model (int): main dimension of modules in transformer blocks. + num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`. + dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`. + dtype (optional): Defaults to torch.half. + eps (floatoptional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5. + dropout_p (floatoptional): Defaults to 0. 
+ """ # noqa: E501 + + def __init__( + self, + dim_model: int, + num_heads: int, + dim_head: int, + dtype=torch.half, + eps: float = 1e-6, + dropout_p: Optional[float] = None, + ): + + super().__init__() + + self.layernorm_before_attention = LayerNorm( + dim_model, + dtype=dtype, + eps=eps, + ) + + self.self_attention = Attention( + dim_model=dim_model, + num_heads=num_heads, + dim_head=dim_head, + dtype=dtype, + dropout_p=dropout_p, + ) + + if dropout_p: + self.dropout = torch.nn.Dropout(dropout_p) + else: + self.dropout = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_bias: Optional[torch.Tensor] = None, + use_cache: bool = False, + past_key_value: Optional[Tuple[torch.Tensortorch.Tensor]] = None, + ): + """ + Args: + hidden_states (:obj:`torch.Tensor` of shape ``(batchseq_selfdim_model)``): Input of self-attention block. It can be the embedding of a batch of sequences. + attention_mask (:obj:`torch.Tensor` of shape ``(batchseq_selfseq_self)``): Avoid invalid areas to participate in the calculation. + position_bias (:obj:`torch.Tensor` of shape ``(num_headsseq_selfseq_self)``): Provide positional information to self-attention block. + + Return: + :obj:`torch.Tensor` of shape ``(batchseq_selfdim_model)``: The output of attention block. + + """ # noqa: E501 + x = self.layernorm_before_attention(hidden_states) + x = self.self_attention(xxattention_maskposition_biasuse_cachepast_key_value) + if use_cache: + xcurrent_key_value = x + else: + current_key_value = None + + if self.dropout is not None: + x = self.dropout(x) + hidden_states = (hidden_states + x) / 1.05 + + if use_cache: + return hidden_statescurrent_key_value + else: + return hidden_states + + +class FFNBlock(torch.nn.Module): + """The whole feed-forward block. A sequence of operation. Consists of layernormfeed-forward and residual connection. + + Args: + dim_model (int): main dimension of modules in transformer blocks. + dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`. + dtype (optional): Defaults to torch.half. + eps (floatoptional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5. + dropout_p (floatoptional): Defaults to 0. + """ # noqa: E501 + + def __init__( + self, + dim_model: int, + dim_ff: int, + dtype=torch.half, + eps: float = 1e-6, + dropout_p: Optional[float] = 0, + ): + super().__init__() + + self.layernorm_before_ffn = LayerNorm( + dim_model, + dtype=dtype, + eps=eps, + ) + + self.ffn = FeedForward( + dim_model, + dim_ff, + dtype=dtype, + dropout_p=dropout_p, + ) + + if dropout_p: + self.dropout = torch.nn.Dropout(dropout_p) + else: + self.dropout = None + + def forward( + self, + hidden_states: torch.Tensor, + ): + """ + Args: + hidden_states (:obj:`torch.Tensor` of shape ``(batchseq_selfdim_model)``): Hidden states before feed forward layer. + + Return: + :obj:`torch.Tensor` of shape ``(batchseq_selfdim_model)``: The output of feed-forward block + + """ # noqa: E501 + x = self.layernorm_before_ffn(hidden_states) + x = self.ffn(x) + if self.dropout is not None: + x = self.dropout(x) + hidden_states = (hidden_states + x) / 1.05 + return hidden_states + + +class TransformerBlock(torch.nn.Module): + """The whole transformer block. A sequence of operation. Consists of self-attention block[cross-attention block] and feed-forward block. + + Args: + dim_model (int): main dimension of modules in transformer blocks. + dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`. 
+ num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`. + dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`. + dtype (optional): Defaults to torch.half. + eps (floatoptional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5. + dropout_p (floatoptional): Defaults to 0. + """ # noqa: E501 + + def __init__( + self, + dim_model: int, + dim_ff: int, + num_heads: int, + dim_head: int, + dtype=torch.half, + eps: float = 1e-6, + dropout_p: Optional[float] = None, + mask_att: bool = False, + mask_ffn: bool = False, + ): + super().__init__() + self.mask_att = mask_att + self.mask_ffn = mask_ffn + + if not self.mask_att: + self.self_att = SelfAttentionBlock( + dim_model=dim_model, + num_heads=num_heads, + dim_head=dim_head, + dtype=dtype, + eps=eps, + dropout_p=dropout_p, + ) + + if not self.mask_ffn: + self.ffn = FFNBlock( + dim_model=dim_model, + dim_ff=dim_ff, + dtype=dtype, + eps=eps, + dropout_p=dropout_p, + ) + + def forward( + self, + self_hidden_states: torch.Tensor, + self_attention_mask: torch.Tensor, + self_position_bias: Optional[torch.Tensor] = None, + use_cache: bool = False, + past_key_value: Optional[Tuple[torch.Tensortorch.Tensor]] = None, + ): + """ + Args: + self_hidden_states (:obj:`torch.Tensor` of shape ``(batchseq_selfdim_model)``): Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences. + self_attention_mask (:obj:`torch.Tensor` of shape ``(batchseq_selfseq_self)``): Avoid invalid areas to participate in the calculation of self-attention. + self_position_bias (:obj:`torch.Tensor` of shape ``(num_headsseq_selfseq_self)``): Provide positional information to self-attention block. + + Return: + :obj:`torch.Tensor` of shape ``(batchseq_selfdim_model)``: The output of transformer block. + + """ # noqa: E501 + # (batchdim_modelseq_self) + current_key_value = None + if not self.mask_att: + hidden_states = self.self_att( + self_hidden_states, + attention_mask=self_attention_mask, + position_bias=self_position_bias, + use_cache=use_cache, + past_key_value=past_key_value, + ) + if use_cache: + hidden_statescurrent_key_value = hidden_states + else: + hidden_states = self_hidden_states + + # (batchdim_modelseq_self) + if not self.mask_ffn: + hidden_states = self.ffn(hidden_states) + + if use_cache: + return hidden_statescurrent_key_value + else: + return hidden_states diff --git a/cpm-live/cpm_live/native_layers/embedding.py b/cpm-live/cpm_live/native_layers/embedding.py new file mode 100644 index 0000000..7f0fb64 --- /dev/null +++ b/cpm-live/cpm_live/native_layers/embedding.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache LicenseVersion 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writingsoftware +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
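# [Editor's note: the sketch below is an editorial illustration, not part of the
#  patch above. It restates the residual pattern used by SelfAttentionBlock and
#  FFNBlock in blocks.py: normalize first (pre-LayerNorm), apply the sub-layer,
#  then add the input back with a damping factor of 1.05. `sublayer` is a
#  hypothetical stand-in for either the attention or the feed-forward module.]
import torch


def damped_pre_ln_residual(
    hidden_states: torch.Tensor,
    layernorm: torch.nn.Module,
    sublayer: torch.nn.Module,
) -> torch.Tensor:
    x = layernorm(hidden_states)         # pre-LN: normalize before the sub-layer
    x = sublayer(x)                       # attention or feed-forward
    return (hidden_states + x) / 1.05     # damped residual, as in blocks.py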
+ +import torch +import math +import torch.nn.functional as F +from .position_embedding import RotaryEmbedding +from typing import Optional + + +class Embedding(torch.nn.Module): + def __init__( + self, + vocab_size: int, + embedding_size: int, + dtype: torch.dtype = torch.half, + init_mean: float = 0.0, + init_std: float = 1, + ): + + super().__init__() + + self.dim_model = embedding_size + self.weight = torch.nn.parameter.Parameter( + torch.empty(vocab_sizeembedding_sizedtype=dtype) + ) + + def forward(selfids: torch.Tensor): + """ + Args: + ids (:obj:`torch.Tensor` of shape ``(batch_sizeseq_len)``): Indices of input sequence tokens. + Return: + :obj:`torch.Tensor` of shape ``(batch_sizeseq_lenembedding_size)``: The embedding output. + """ # noqa: E501 + + embeds = F.embedding(idsself.weight) / math.sqrt(self.dim_model) + return embeds + + def projection(selfx: torch.Tensor): + """ + Projection based on embedding's weight. For exampleembedding map vocab_size to embed_sizethan projection map embed_size back to vocab_size. + Args: + x (:obj:`torch.Tensor` of shape ``(batchseq_lendim_model)``): Input of projection + Returns: + :obj:`torch.Tensor` of shape ``(batchseq_lenvocab_output_size)``: The projection output. + """ # noqa: E501 + logits = F.linear(x / math.sqrt(self.dim_model)self.weight) + return logits + + +class EmbeddingExt(torch.nn.Module): + def __init__( + self, + vocab_size: int, + embedding_size: int, + dtype: torch.dtype = torch.half, + init_mean: float = 0.0, + init_std: float = 1, + distance_scale: int = 16, + ): + + super().__init__() + + self.dim_model = embedding_size + self.rotary_emb = RotaryEmbedding( + dim=embedding_sizedistance_scale=distance_scaledtype=dtype + ) + + self.weight = torch.nn.parameter.Parameter( + torch.empty(vocab_sizeembedding_sizedtype=dtype), + ) + + def forward(selfids: torch.Tensorids_sub: torch.Tensor): + """ + Args: + ids (:obj:`torch.Tensor` of shape ``(batch_sizeseq_len)``): Indices of input sequence tokens. + ids (:obj:`torch.Tensor` of shape ``(batch_size)``): Subscript of input sequence tokens. + Return: + :obj:`torch.Tensor` of shape ``(batch_sizeseq_lenembedding_size)``: The embedding output. + """ # noqa: E501 + + embeds = F.embedding(idsself.weight) / math.sqrt(self.dim_model) + return self.rotary_emb(embedsids_sub) + + def projection(selfx: torch.Tensorext_table: Optional[torch.Tensor] = None): + """ + Projection based on embedding's weight. For exampleembedding map vocab_size to embed_sizethan projection map embed_size back to vocab_size. + Args: + x (:obj:`torch.Tensor` of shape ``(batchseq_lendim_model)``): Input of projection + ext_table (:obj:`torch.Tensor` of shape ``(ext_table_sizedim_model)``): Ext vocab table. + Returns: + :obj:`torch.Tensor` of shape ``(batchseq_lenvocab_size + ext_table_size)``: The projection output. + """ # noqa: E501 + logits = F.linear(x / math.sqrt(self.dim_model)self.weight) + if ext_table is not None: + logits_ext = F.linear(xext_table) + logits = torch.cat([logitslogits_ext]dim=-1) + return logits diff --git a/cpm-live/cpm_live/native_layers/feedforward.py b/cpm-live/cpm_live/native_layers/feedforward.py new file mode 100644 index 0000000..b624592 --- /dev/null +++ b/cpm-live/cpm_live/native_layers/feedforward.py @@ -0,0 +1,120 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache LicenseVersion 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writingsoftware +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +import torch +from .linear import Linear + + +class DenseGatedACT(torch.nn.Module): + def __init__( + self, + dim_in: int, + dim_ff: int, + dtype=torch.half, + ): + super().__init__() + + self.w_0 = Linear( + dim_in=dim_in, + dim_out=dim_ff, + dtype=dtype, + scale_before=False, + ) + + self.w_1 = Linear( + dim_in=dim_in, + dim_out=dim_ff, + dtype=dtype, + scale_before=False, + ) + self.act = torch.nn.GELU() + + def forward(selfx: torch.Tensor): + """Transform an input tensor from one feature space to another via a nonlinear operation + + Args: + x (:obj:`torch.Tensor` of shape ``(batchseq_lendim_in)``): Tensor that will be subject to nonlinear operations. + + Return: + out (:obj:`torch.Tensor` of shape ``(batchseq_lendim_ff)``) + + """ # noqa: E501 + gate_score = self.act(self.w_0(x)) + x = self.w_1(x) + + x = gate_score * x + return x + + +class FeedForward(torch.nn.Module): + r"""FeedForward module + + Args: + dim_in (int): input dimension. + dim_ff (int): middle dimension. + dim_out (intoptional): output dimension. Defaults to Nonewhich means dim_in = dim_out. + dtype (optional): Defaults to torch.half. + init_mean (floatoptional): mean of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}\text{std}^2)` for fully-connected module used in feed-forward layer. Defaults to 0. + init_std (floatoptional): std of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}\text{std}^2)` for fully-connected module used in feed-forward layer. Defaults to 0.02. + bias (booloptional): whether to use bias term in fully-connected layers used in feed-forward module. Defaults to False. + activate_fn (stroptional): Defaults to `gated_gelu`. + dropout_p (intoptional): Defaults to 0. + """ # noqa: E501 + + def __init__( + self, + dim_model: int, + dim_ff: int, + dtype=torch.half, + dropout_p: Optional[float] = None, + ): + + super().__init__() + + self.w_in = DenseGatedACT( + dim_in=dim_model, + dim_ff=dim_ff, + dtype=dtype, + ) + + if dropout_p is not None: + self.dropout = torch.nn.Dropout(dropout_p) + else: + self.dropout = None + + self.w_out = Linear( + dim_in=dim_ff, + dim_out=dim_model, + dtype=dtype, + scale_before=False, + ) + + def forward(selfx: torch.Tensor): + """ + Args: + x (:obj:`torch.Tensor` of shape ``(batchseq_lendim_in)``): The input of feed-forward module. + + Return: + :obj:`torch.Tensor` of shape ``(batchseq_lendim_out)``: The output of feed-forward module. 
+        """  # noqa: E501
+        x = self.w_in(x)
+
+        if self.dropout is not None:
+            x = self.dropout(x)
+
+        x = self.w_out(x)
+
+        return x
diff --git a/cpm-live/cpm_live/native_layers/layernorm.py b/cpm-live/cpm_live/native_layers/layernorm.py
new file mode 100644
index 0000000..e8f19e9
--- /dev/null
+++ b/cpm-live/cpm_live/native_layers/layernorm.py
@@ -0,0 +1,37 @@
+import torch
+
+
+@torch.jit.script  # type: ignore
+def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
+    old_dtype = hidden.dtype
+    variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
+    hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
+    return hidden * weight
+
+
+class LayerNorm(torch.nn.Module):
+    """RMS LayerNorm"""
+
+    def __init__(
+        self,
+        dim_norm: int,
+        dtype: torch.dtype = torch.half,
+        eps: float = 1e-6,
+        init_var: float = 1.0,
+    ):
+
+        super().__init__()
+
+        self.eps = eps
+        self.dim_norm = dim_norm
+        self.weight = torch.nn.parameter.Parameter(torch.full((dim_norm,), init_var, dtype=dtype))
+
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+            x (:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``): Input tensor that need to be normalized.
+        Return:
+            :obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``: The layernorm output.
+        """  # noqa: E501
+        assert x.size(-1) == self.dim_norm
+        return rms_layernorm(x, self.weight, self.eps)
diff --git a/cpm-live/cpm_live/native_layers/linear.py b/cpm-live/cpm_live/native_layers/linear.py
new file mode 100644
index 0000000..fe904ae
--- /dev/null
+++ b/cpm-live/cpm_live/native_layers/linear.py
@@ -0,0 +1,51 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import math
+import torch.nn.functional as F
+
+
+class Linear(torch.nn.Module):
+    def __init__(
+        self,
+        dim_in: int,
+        dim_out: int,
+        dtype: torch.dtype = torch.half,
+        init_mean: float = 0.0,
+        init_std: float = 1,
+        scale_before: bool = False,
+    ):
+        super().__init__()
+        self.dim_in = self.in_features = dim_in
+        self.dim_out = self.out_features = dim_out
+        self.scale_before = scale_before
+
+        self.weight = torch.nn.parameter.Parameter(torch.empty((dim_out, dim_in), dtype=dtype))
+
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
+        Returns:
+            :obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
+        """  # noqa: E501
+        if self.scale_before:
+            x = x / math.sqrt(self.dim_in)
+            x = F.linear(x, self.weight)
+        else:
+            x = F.linear(x, self.weight)
+            x = x / math.sqrt(self.dim_in)
+        return x
diff --git a/cpm-live/cpm_live/native_layers/position_embedding.py b/cpm-live/cpm_live/native_layers/position_embedding.py
new file mode 100644
index 0000000..3778f3e
--- /dev/null
+++ b/cpm-live/cpm_live/native_layers/position_embedding.py
@@ -0,0 +1,247 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+# +# Licensed under the Apache LicenseVersion 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writingsoftware +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Union +import torch +import torch.nn.functional as F + + +class SegmentPositionEmbedding(torch.nn.Module): + def __init__( + self, + num_heads, + num_segments=1, + num_buckets=32, + max_distance=128, + bidirectional=False, + dtype=torch.half, + init_mean: float = 0.0, + init_std: float = 1, + ): + + super().__init__() + + self.num_heads = num_heads + self.num_buckets = num_buckets + self.max_distance = max_distance + self.bidirectional = bidirectional + self.num_segments = num_segments + + self.relative_attention_bias = torch.nn.parameter.Parameter( + torch.empty(num_segments * num_segments + num_bucketsnum_headsdtype=dtype) + ) + + def forward( + self, + key_pos: torch.Tensor, + query_pos: torch.Tensor, + key_segment: torch.Tensor, + query_segment: torch.Tensor, + ): + with torch.no_grad(): + + batch = key_pos.size(0) + keylen = key_pos.size(1) + querylen = query_pos.size(1) + + assert key_pos.size(0) == query_pos.size(0) + assert keylen == key_segment.size(1) and querylen == query_segment.size(1) + + key_pos = key_pos.view(batch-1keylen) + query_pos = query_pos.view(batchquerylen-1) + key_segment = key_segment.view(batch-1keylen) + query_segment = query_segment.view(batchquerylen-1) + + relative_position_bucket = self._segment_relative_position_bucket( + query_segmentkey_segment + ) + relative_position_bucket = relative_position_bucket + self.num_buckets # 与相对位置编码区间不重叠 + + # b*q*k + absolute_position_bucket = self._position_bucket( + torch.arange(keylendtype=torch.int32device=relative_position_bucket.device)[ + None: + ] + - torch.arange(querylendtype=torch.int32device=relative_position_bucket.device)[ + :None + ], + bidirectional=self.bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) + relative_position_bucket = torch.where( + (key_segment == query_segment), + absolute_position_bucket[None::], + relative_position_bucket, + ) + # (batchlen_qlen_k) + + # (batchlen_qlen_knum_heads) + embeds = F.embedding(relative_position_bucketself.relative_attention_bias) + # (batchnum_headslen_qlen_k) + embeds = embeds.permute(0312).contiguous() + return embeds + + def _segment_relative_position_bucket(selfquery_segmentkey_segment): + return query_segment * self.num_segments + key_segment + + def _position_bucket( + selfrelative_positionbidirectional=Truenum_buckets=32max_distance=128 + ): + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_positiontorch.zeros_like(relative_position)) + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.int32) + relative_postion_if_large = torch.min( + relative_postion_if_large, + 
torch.full_like(relative_postion_if_largenum_buckets - 1), + ) + relative_buckets += torch.where( + is_smallrelative_position.to(torch.int32)relative_postion_if_large + ) + return relative_buckets + + +class BucketPositionBias(torch.nn.Module): + def __init__( + self, + num_heads: int, + num_buckets: int = 32, + num_segment_bucket: int = 32, + max_distance: int = 128, + dtype: torch.dtype = torch.half, + init_mean: float = 0.0, + init_std: float = 1, + ) -> None: + super().__init__() + + self.num_heads = num_heads + self.num_buckets = num_buckets + self.num_segment_bucket = num_segment_bucket + self.max_distance = max_distance + + self.relative_attention_bias = torch.nn.parameter.Parameter( + torch.empty(num_buckets + num_segment_bucketnum_headsdtype=dtype) + ) + + def forward( + self, + query_pos: torch.Tensor # (batchlen_q) + key_pos: torch.Tensor # (batchlen_k) + rel_buckets: torch.Tensor # (batchlen_qlen_k) + ): + with torch.no_grad(): + + batch = key_pos.size(0) + keylen = key_pos.size(1) + querylen = query_pos.size(1) + + assert key_pos.size(0) == query_pos.size(0) + assert ( + rel_buckets.size(0) == batch + and rel_buckets.size(1) == querylen + and rel_buckets.size(2) == keylen + ) + + relative_position_bucket = rel_buckets - 1 + self.num_buckets # 与相对位置编码区间不重叠 + + # b*q*k + inner_segment_bucket = self._position_bucket( + key_pos[...None:] - query_pos[...:None], + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) + relative_position_bucket = torch.where( + rel_buckets == 0, + inner_segment_bucket, + relative_position_bucket, + ) + # (batchlen_qlen_k) + + # (batchlen_qlen_knum_heads) + embeds = F.embedding(relative_position_bucketself.relative_attention_bias) + # (batchnum_headslen_qlen_k) + embeds = embeds.permute(0312).contiguous() + return embeds + + def _position_bucket(selfrelative_positionnum_buckets=32max_distance=128): + relative_buckets = 0 + num_buckets //= 2 + relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets + relative_position = torch.abs(relative_position) + + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.int32) + relative_postion_if_large = torch.min( + relative_postion_if_large, + torch.full_like(relative_postion_if_largenum_buckets - 1), + ) + relative_buckets += torch.where( + is_smallrelative_position.to(torch.int32)relative_postion_if_large + ) + return relative_buckets + + +class RotaryEmbedding(torch.nn.Module): + def __init__( + self, + dim, + base=10000, + distance_scale: Union[intfloat] = 1, + dtype: torch.dtype = torch.half, + ): + super().__init__() + inv_freq = 1.0 / ( + base ** (torch.arange(0dim2device="cuda"dtype=torch.float32) / dim) + ) + inv_freq = inv_freq.to(dtype) + self.distance_scale = distance_scale + self.dtype = dtype + self.inv_freq = inv_freq + + def forward(selfx: torch.Tensorx_pos: torch.Tensor): + """ + Args: + x (:obj:`torch.Tensor` of shape ``(...dim)``): Inputs. + x_pos (:obj:`torch.Tensor` of shape ``(...)``): Positions of inputs. 
+ """ + x_pos = x_pos * self.distance_scale + freqs = x_pos[...None].to(self.dtype) * self.inv_freq[None:] # (...dim/2) + + # the same implementation as sat + emb = torch.cat((freqsfreqs)dim=-1) # (...dim) + emb_cos = emb.cos() # (...dim) + emb_sin = emb.sin() # (...dim) + + rotate_x = torch.cat( + [-x[...x.size(-1) // 2 :]x[...: x.size(-1) // 2]]dim=-1 + ) # (...dim) + + return x * emb_cos + rotate_x * emb_sin diff --git a/cpm-live/cpm_live/native_layers/transformer.py b/cpm-live/cpm_live/native_layers/transformer.py new file mode 100644 index 0000000..b21445c --- /dev/null +++ b/cpm-live/cpm_live/native_layers/transformer.py @@ -0,0 +1,125 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache LicenseVersion 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writingsoftware +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from typing import OptionalListTuple + +from .blocks import TransformerBlock +from .layernorm import LayerNorm + + +class Encoder(torch.nn.Module): + """Layers of encoder transformer blocks plus an final layernorm. + + Args: + num_layers (int): number of layers. + dim_model (int): main dimension of modules in transformer blocks. + dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`. + num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`. + dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`. + dtype (optional): Defaults to torch.half. + eps (floatoptional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6. + dropout_p (floatoptional): Defaults to 0. + """ # noqa: E501 + + def __init__( + self, + num_layers: int, + dim_model: int, + dim_ff: int, + num_heads: int, + dim_head: int, + dtype: torch.dtype = torch.half, + eps: float = 1e-6, + dropout_p: Optional[float] = None, + mask_modules: Optional[List[Tuple[boolbool]]] = None, + ): + + super().__init__() + + self.num_layers = num_layers + + if mask_modules is not None: + assert ( + len(mask_modules) == num_layers + )"The total number of masks should equal to num_layers" + for mask_module in mask_modules: + assert ( + len(mask_module) == 2 + )"For encodereach mask should be (mask_attmask_ffn)" + else: + mask_modules = [(FalseFalse)] * num_layers + + self.layers = torch.nn.ModuleList( + [ + TransformerBlock( + dim_model=dim_model, + dim_ff=dim_ff, + num_heads=num_heads, + dim_head=dim_head, + dtype=dtype, + eps=eps, + dropout_p=dropout_p, + mask_att=mask_modules[ith][0], + mask_ffn=mask_modules[ith][1], + ) + for ith in range(num_layers) + ] + ) + + self.output_layernorm = LayerNorm(dim_norm=dim_modeldtype=dtypeeps=eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_bias: torch.Tensor, + use_cache: bool = False, + past_key_values: Optional[List[Tuple[torch.Tensortorch.Tensor]]] = None, + ): + """ + Args: + hidden-states (:obj:`torch.Tensor` of shape ``(batchseq_encdim_model)``): Input of encodermight be the embedding of a batch of sequences. 
+ attention_mask (:obj:`torch.Tensor` of shape ``(batchseq_encseq_enc)``): Avoid invalid areas to participate in the calculation + position_bias(:obj:`torch.Tensor` of shape ``(num_headsseq_encseq_enc)``) Provides position information to attention mechanism. + + Return: + :obj:`torch.Tensor` of shape ``(batchseq_encdim_model)``: The encoder output. + + """ # noqa: E501 + if not use_cache: + for layer in self.layers: + hidden_states = layer(hidden_statesattention_maskposition_bias) + hidden_states = self.output_layernorm(hidden_states) + return hidden_states + else: + with torch.no_grad(): + current_key_values = [] + for imodule in enumerate(self.layers): + hidden_states = module( + hidden_states, + attention_mask, + position_bias, + past_key_value=past_key_values[i] if past_key_values else None, + use_cache=use_cache, + ) + if use_cache: + current_key_values.append(hidden_states[1]) + hidden_states = hidden_states[0] + hidden_states = self.output_layernorm(hidden_states) + if use_cache: + return hidden_statescurrent_key_values + else: + return hidden_states diff --git a/cpm-live/cpm_live/tokenizers/__init__.py b/cpm-live/cpm_live/tokenizers/__init__.py index 1c2bb3a..c664def 100644 --- a/cpm-live/cpm_live/tokenizers/__init__.py +++ b/cpm-live/cpm_live/tokenizers/__init__.py @@ -1 +1,2 @@ from .ant import CPMAntTokenizer +from .bee import CPMBeeTokenizer diff --git a/cpm-live/cpm_live/tokenizers/ant.py b/cpm-live/cpm_live/tokenizers/ant.py index bcbf826..d127ec6 100644 --- a/cpm-live/cpm_live/tokenizers/ant.py +++ b/cpm-live/cpm_live/tokenizers/ant.py @@ -1,4 +1,18 @@ # coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache LicenseVersion 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writingsoftware +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import jieba import pkg_resources import io @@ -111,6 +125,10 @@ def pad_id(self): def unk_id(self): return self.encoder[self.unk_token] + @property + def newline_id(self): + return self.encoder["\n"] + def __len__(self): return len(self.encoder) diff --git a/cpm-live/cpm_live/tokenizers/bee.py b/cpm-live/cpm_live/tokenizers/bee.py new file mode 100644 index 0000000..6078f6b --- /dev/null +++ b/cpm-live/cpm_live/tokenizers/bee.py @@ -0,0 +1,223 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache LicenseVersion 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writingsoftware +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
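# [Editor's note: an editorial sketch, not part of the patch above. It shows how the
#  Encoder's `use_cache` path in transformer.py can be driven for step-by-step
#  decoding: the first call returns per-layer (key, value) caches, which are fed
#  back as `past_key_values` on the next call together with a mask and position
#  bias that cover the grown key length. All sizes below are made up for
#  illustration, and the import assumes the cpm_live package from this diff is
#  installed; position_bias is left at zero purely to keep the sketch minimal.]
import torch
from cpm_live.native_layers import Encoder

enc = Encoder(num_layers=2, dim_model=64, dim_ff=128, num_heads=4, dim_head=16,
              dtype=torch.float32)

hidden = torch.randn(1, 5, 64)                      # (batch, len_q, dim_model)
mask = torch.ones(1, 5, 5, dtype=torch.bool)        # (batch, len_q, len_k)
bias = torch.zeros(1, 4, 5, 5)                      # (batch, num_heads, len_q, len_k)
out, past_kv = enc(hidden, mask, bias, use_cache=True)

step = torch.randn(1, 1, 64)                        # hidden state of the next position
mask_step = torch.ones(1, 1, 6, dtype=torch.bool)   # len_k grew from 5 to 6
bias_step = torch.zeros(1, 4, 1, 6)
out_step, past_kv = enc(step, mask_step, bias_step,
                        use_cache=True, past_key_values=past_kv)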
+ +import pkg_resources +import io +from typing import IODictListOptionalTuple + + +def load_vocab(fp: IO[bytes]) -> Dict[strint]: + """Loads a vocabulary file into a dictionary.""" + vocab: Dict[strint] = {} + + reader = io.TextIOWrapper(fpencoding="utf-8") + for token in reader.readlines(): + if token[-1] == "\n": + token = token[:-1] + if len(token) == 0: + continue + vocab[token] = len(vocab) + return vocab + + +class Token(object): + def __init__(selftoken: strstart: intis_unk: boolis_special: bool): + self.token = token + self.start = start + self.is_unk = is_unk + self.is_special = is_special + + def __str__(self): + return "Token(token={}start={}is_unk={}is_special={})".format( + self.tokenself.startself.is_unkself.is_special + ) + + def __repr__(self): + return self.__str__() + + +class CPMBeeTokenizer(object): + def __init__( + self, + ): + self.unk_token = "" + self.mask_token = "" + self.bos_token = "" + self.eos_token = "" + self.line_token = "\n" + self.space_token = " " + + self.encoder = load_vocab(pkg_resources.resource_stream("cpm_live""vocabs/bee.txt")) + self.encoder[self.line_token] = self.encoder[""] + self.encoder[self.space_token] = self.encoder[""] + del self.encoder[""] + del self.encoder[""] + + self.decoder = {v: k for kv in self.encoder.items()} + self._special_tokens = { + k: v for kv in self.encoder.items() if k.startswith("<") and k.endswith(">") + } + + self._max_word_len = max([len(x) for x in self.encoder.keys()]) + + def get_piece(selftext: str) -> str: + text = text[: self._max_word_len] + len_text = len(text) + for i in range(len(text)): + sub = text[: len_text - i] + if (sub in self.encoder) and (sub not in self._special_tokens): + return sub + return text[0] + + @property + def vocab_size(self): + return len(self.encoder) + + @property + def eos_id(self): + return self.encoder[self.eos_token] + + @property + def bos_id(self): + return self.encoder[self.bos_token] + + @property + def unk_id(self): + return self.encoder[self.unk_token] + + @property + def mask_id(self): + return self.encoder[self.mask_token] + + def __len__(self): + return len(self.encoder) + + def tokenize(selftext: str) -> List[Token]: + output_tokens: List[Token] = [] + + sentence_split = [""] + is_escape = False + is_special_token = False + for ic in enumerate(text): + if is_special_token: + if c == "<": + raise ValueError("Invalid special token at pos {}".format(i)) + elif c == ">": + # end of special token + sentence_split[-1] += c + is_special_token = False + sentence_split.append("") + else: + sentence_split[-1] += c + else: + if c == "<": + if is_escape: + # case: << + sentence_split[-1] += c + is_escape = False + else: + # case: x< + is_escape = True + else: + if is_escape: + # case str: + return text.replace("<""<<") + + @staticmethod + def unescape(text: str) -> str: + return text.replace("<<""<") + + def encode( + selftext: strpast_table: Dict[intstr] = {} + ) -> Tuple[List[int]Dict[intstr]]: + ext_table_rev: Dict[strint] = {} + ext_table: Dict[intstr] = {} + for idxval in past_table.items(): + ext_table[idx] = val + ext_table_rev[val] = idx + ret = [] + for x in self.tokenize(text): + if x.is_unk or (x.is_special and (x.token not in self.encoder)): + if x.token not in ext_table_rev: + ext_table_rev[x.token] = len(ext_table_rev) + self.vocab_size + ext_table[ext_table_rev[x.token]] = x.token + ret.append(ext_table_rev[x.token]) + elif x.token in self.encoder: + ret.append(self.encoder[x.token]) + else: + raise ValueError("Unknown token `{}` at pos 
{}".format(x.tokenx.start)) + + return retext_table + + def decode(selftokens: List[int]ext_table: Optional[Dict[intstr]] = None): + """Decode ids into a string.""" + if ext_table is None: + ext_table = {} + ret = [] + for token in tokens: + if token in ext_table: + ret.append(ext_table[token]) + else: + if token >= 0: + w = self.decoder[token] + if w in self._special_tokens: + ret.append(w) + else: + ret.append(self.escape(w)) + return "".join(ret) diff --git a/cpm-live/training_tasks/__init__.py b/cpm-live/cpm_live/training_tasks/__init__.py similarity index 50% rename from cpm-live/training_tasks/__init__.py rename to cpm-live/cpm_live/training_tasks/__init__.py index 0a502f1..3cb9fa0 100644 --- a/cpm-live/training_tasks/__init__.py +++ b/cpm-live/cpm_live/training_tasks/__init__.py @@ -1 +1,2 @@ from . import ant +from . import bee diff --git a/cpm-live/training_tasks/ant/__init__.py b/cpm-live/cpm_live/training_tasks/ant/__init__.py similarity index 100% rename from cpm-live/training_tasks/ant/__init__.py rename to cpm-live/cpm_live/training_tasks/ant/__init__.py diff --git a/cpm-live/training_tasks/ant/pretrain.py b/cpm-live/cpm_live/training_tasks/ant/pretrain.py similarity index 96% rename from cpm-live/training_tasks/ant/pretrain.py rename to cpm-live/cpm_live/training_tasks/ant/pretrain.py index 055b691..5604bfc 100644 --- a/cpm-live/training_tasks/ant/pretrain.py +++ b/cpm-live/cpm_live/training_tasks/ant/pretrain.py @@ -96,8 +96,7 @@ def __get_item_data(selfraw_data): tgt = np.concatenate((np.full(self.prompt_length-100dtype=np.int64)tgt)) inp = np.concatenate( ( - np.arange(self.prompt_lengthdtype=np.int64) + - self.prompt_length * global_task + self.tokenizer.vocab_size, + np.arange(self.prompt_lengthdtype=np.int64) + self.prompt_length * global_task, ctx, ) ) diff --git a/cpm-live/cpm_live/training_tasks/bee/__init__.py b/cpm-live/cpm_live/training_tasks/bee/__init__.py new file mode 100644 index 0000000..684f707 --- /dev/null +++ b/cpm-live/cpm_live/training_tasks/bee/__init__.py @@ -0,0 +1,2 @@ +from .pretrain import MixedDataset +from .finetune import FinetuneDataset diff --git a/cpm-live/cpm_live/training_tasks/bee/finetune.py b/cpm-live/cpm_live/training_tasks/bee/finetune.py new file mode 100644 index 0000000..561e3c2 --- /dev/null +++ b/cpm-live/cpm_live/training_tasks/bee/finetune.py @@ -0,0 +1,71 @@ +from ...tokenizers import CPMBeeTokenizer +from .pretrain import _MixedDatasetBatchPacker_MixedDatasetConfigCPMBeeBatch +from ...dataset import SimpleDataset +import bmtrain as bmt + + +class FinetuneDataset: + def __init__( + self, + dataset_path: str, + batch_size: int, + max_length: int, + tokenizer: CPMBeeTokenizer, + max_depth: int = 16, + task_name: str = "task", + drop_last: bool = False, + ) -> None: + self._world_size = bmt.world_size() + self._rank = bmt.rank() + self._batch_size = batch_size + + self._packer = _MixedDatasetBatchPacker( + batch_size * self._world_sizemax_lengthtokenizermax_depth + ) + self._drop_last = drop_last + + ds = SimpleDataset(dataset_pathshuffle=False) + self._ds_cfg: _MixedDatasetConfig = { + "weight": 1.0, + "path": dataset_path, + "transforms": [], + "task_name": task_name, + "dataset_name": "finetune", + "incontext_weight": [1.0], + "lines": len(ds), + "dataset": ds, + } + + def __batch_iter(self): + while True: + try: + batch = self._packer.add_data(self._ds_cfg) + except EOFError: + break + if batch is None: + continue + yield batch + if len(self._packer) > 0: + batch = self._packer.pack_batch(force=True) + if not 
self._drop_last: + yield batch + self._ds_cfg["dataset"]._repeat_times = 0 + + def __iter__(self): + batch_st = self._batch_size * self._rank + batch_end = self._batch_size * (self._rank + 1) + for batch in self.__batch_iter(): + batch_size = batch["inputs"].shape[0] + if batch_size <= batch_st: + yield None + else: + ret: CPMBeeBatch = { + kw: val[batch_st:batch_end] # type: ignore + for kwval in batch.items() + if kw not in ["task_names""raw_data""ext_ids""ext_sub"] + } # type: ignore + ret["task_names"] = batch["task_names"] + ret["raw_data"] = batch["raw_data"] + ret["ext_ids"] = batch["ext_ids"] + ret["ext_sub"] = batch["ext_sub"] + yield ret diff --git a/cpm-live/cpm_live/training_tasks/bee/pretrain.py b/cpm-live/cpm_live/training_tasks/bee/pretrain.py new file mode 100644 index 0000000..d3caef2 --- /dev/null +++ b/cpm-live/cpm_live/training_tasks/bee/pretrain.py @@ -0,0 +1,1041 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache LicenseVersion 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writingsoftware +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +import multiprocessing +import os +from queue import Empty +from typing import AnyCallableDictListOptionalSetTupleUnion +from typing_extensions import TypedDict +from ...dataset import DistributedDataset +from ...tokenizers import CPMBeeTokenizer +from ...utils.config import load_dataset_config +import numpy as np +import time +from numpy.typing import NDArray +import torch +import bmtrain as bmt +import importlib.machinery +import importlib.util +import types +import random + + +class _MixedDatasetConfig(TypedDict): + weight: float + path: str + transforms: Union[List[Dict[strAny]]str] + task_name: str + dataset_name: str + incontext_weight: List[float] + + lines: int + dataset: DistributedDataset + + +CPMBeeInputType = Union[strDict[str"CPMBeeInputType"]] + + +class _DictTree(TypedDict): + value: str + children: List["_DictTree"] + depth: int + segment_id: int + need_predict: bool + + +class _PrevExtTableStates(TypedDict): + ext_table: Dict[intstr] + token_id_table: Dict[strDict[intint]] + + +class _TransformFuncDict(TypedDict): + loader: importlib.machinery.SourceFileLoader + module: types.ModuleType + last_m: float + + +_TransformFunction = Callable[[CPMBeeInputTypeintrandom.Random]CPMBeeInputType] + + +class CPMBeeBatch(TypedDict): + inputs: NDArray[np.int32] + inputs_sub: NDArray[np.int32] + length: NDArray[np.int32] + context: NDArray[np.bool_] + sample_ids: NDArray[np.int32] + num_segments: NDArray[np.int32] + segment_ids: NDArray[np.int32] + segment_rel_offset: NDArray[np.int32] + segment_rel: NDArray[np.int32] + spans: NDArray[np.int32] + target: NDArray[np.int32] + ext_ids: NDArray[np.int32] + ext_sub: NDArray[np.int32] + task_ids: NDArray[np.int32] + task_names: List[str] + raw_data: List[Any] + + +def rel_to_bucket(n_up: intn_down: intmax_depth: int = 8): + ret = n_up * max_depth + n_down + if ret == 0: + return ret + else: + # bucket 1 is reserved for incontext samples + return ret + 1 + + +def convert_data_to_id( + tokenizer: CPMBeeTokenizer, + data: Any, + 
prev_ext_states: Optional[_PrevExtTableStates] = None, + shuffle_answer: bool = True, + max_depth: int = 8, +): + root: _DictTree = { + "value": "", + "children": [], + "depth": 0, + "segment_id": 0, + "need_predict": False, + } + + segments = [root] + + def _build_dict_tree(data: CPMBeeInputTypedepth: intneed_predict: bool) -> List[_DictTree]: + if isinstance(datadict): + ret_list: List[_DictTree] = [] + curr_items = list(data.items()) + if need_predict and shuffle_answer: + access_idx = np.arange(len(curr_items)) + np.random.shuffle(access_idx) + curr_items = [curr_items[idx] for idx in access_idx] + for kv in curr_items: + child_info: _DictTree = { + "value": k, + "children": [], + "depth": depth, + "segment_id": len(segments), + "need_predict": False # only leaves are contexts + } + segments.append(child_info) + child_info["children"] = _build_dict_tree( + vdepth + 1need_predict or (depth == 1 and k == "") + ) # elements in . + + ret_list.append(child_info) + return ret_list + else: + assert isinstance(datastr)"Invalid data {}".format(data) + ret: _DictTree = { + "value": data, + "children": [], + "depth": depth, + "segment_id": len(segments), + "need_predict": need_predict, + } + segments.append(ret) + return [ret] + + root["children"] = _build_dict_tree(data1False) + + num_segments = len(segments) + segment_rel = np.zeros((num_segments * num_segments,)dtype=np.int32) + + def _build_segment_rel(node: _DictTree) -> List[Tuple[intint]]: + ret: List[Tuple[intint]] = [(node["segment_id"]node["depth"])] + for child in node["children"]: + sub = _build_segment_rel(child) + for seg_id_1depth_1 in sub: + for seg_id_2depth_2 in ret: + n_up = min(depth_1 - node["depth"]max_depth - 1) + n_down = min(depth_2 - node["depth"]max_depth - 1) + segment_rel[seg_id_1 * num_segments + seg_id_2] = rel_to_bucket( + n_upn_downmax_depth=max_depth + ) + segment_rel[seg_id_2 * num_segments + seg_id_1] = rel_to_bucket( + n_downn_upmax_depth=max_depth + ) + ret.extend(sub) + return ret + + _build_segment_rel(root) + + input_ids: List[int] = [] + input_id_subs: List[int] = [] + segment_bound: List[Tuple[intint]] = [] + + ext_table: Dict[intstr] = {} + token_id_table: Dict[strDict[intint]] = {} + + if prev_ext_states is not None: + ext_table = prev_ext_states["ext_table"] + token_id_table = prev_ext_states["token_id_table"] + + for seg in segments: + tokensext_table = tokenizer.encode(seg["value"]ext_table) + + token_id_subs = [] + reid_token_ids = [] + for idx in tokens: + if idx in ext_table: + # unk or special token + token = ext_table[idx] + if token.startswith("<") and token.endswith(">"): + # special token + if "_" in token: + token_name = token[1:-1].split("_"maxsplit=1)[0] + else: + token_name = token[1:-1] + token_name = "<{}>".format(token_name) + else: + token_name = "" + + if token_name not in token_id_table: + token_id_table[token_name] = {} + if idx not in token_id_table[token_name]: + token_id_table[token_name][idx] = len(token_id_table[token_name]) + if token_name not in tokenizer.encoder: + raise ValueError("Invalid token {}".format(token)) + reid_token_ids.append(tokenizer.encoder[token_name]) + token_id_subs.append(token_id_table[token_name][idx]) + else: + reid_token_ids.append(idx) + token_id_subs.append(0) + tokens = [tokenizer.bos_id] + reid_token_ids + token_id_subs = [0] + token_id_subs + if not seg["need_predict"]: + tokens = tokens + [tokenizer.eos_id] + token_id_subs = token_id_subs + [0] + else: + # no eos + pass + begin = len(input_ids) + input_ids.extend(tokens) + 
input_id_subs.extend(token_id_subs) + end = len(input_ids) + segment_bound.append((beginend)) + + ids = np.array(input_idsdtype=np.int32) + id_subs = np.array(input_id_subsdtype=np.int32) + segs = np.zeros((ids.shape[0],)dtype=np.int32) + context = np.zeros((ids.shape[0],)dtype=np.int8) + for i(beginend) in enumerate(segment_bound): + if not segments[i]["need_predict"]: + context[begin:end] = 1 + segs[begin:end] = i + + curr_ext_table_states: _PrevExtTableStates = { + "ext_table": ext_table, + "token_id_table": token_id_table, + } + return idsid_subscontextsegssegment_relnum_segmentscurr_ext_table_states + + +def _dataset_identity(c: _MixedDatasetConfig): + return "{}.{}".format(c["task_name"]c["dataset_name"]) + + +class _MixedDatasetBatchPacker: + def __init__( + self, + batch_size: int, + max_length: int, + tokenizer: CPMBeeTokenizer, + max_depth: int = 16, + ) -> None: + self._batch_size = batch_size + self._max_length = max_length + self._max_depth = max_depth + self.tokenizer = tokenizer + self._transform_func_table: Dict[str_TransformFuncDict] = {} + + self._inputs: List[NDArray[np.int32]] = [] + self._inputs_sub: List[NDArray[np.int32]] = [] + self._context: List[NDArray[np.int8]] = [] + self._sample_ids: List[NDArray[np.int32]] = [] + self._segments: List[NDArray[np.int32]] = [] + self._num_segments: List[NDArray[np.int32]] = [] + self._segment_rel_offset: List[NDArray[np.int32]] = [] + self._segment_rel: List[NDArray[np.int32]] = [] + self._spans: List[List[int]] = [] + self._task_ids: List[List[str]] = [] + self._raw_data: List[List[Any]] = [] + + def __len__(self): + return len(self._inputs) + + def apply_transform( + self, + data: CPMBeeInputType, + transform: Union[Dict[strAny]Callable[[CPMBeeInputType]CPMBeeInputType]None], + ) -> CPMBeeInputType: + if transform is None: + return data + if not isinstance(transformdict): + # transform function + return transform(data) + + mapping_list: List[Tuple[strstr]] = [] + + def _walk_transform_dict(data: Union[Dict[strAny]str]prefix: str = ""): + if isinstance(datadict): + for kv in data.items(): + if len(prefix) > 0: + _walk_transform_dict(vprefix + "." 
+ k) + else: + _walk_transform_dict(vk) + else: + assert isinstance(datastr)"Invalid transform {}".format(data) + mapping_list.append((prefixdata)) + + _walk_transform_dict(transform) + + expanded_mapping_list: List[Tuple[strAny]] = [] + + def _expand_mapping( + data: CPMBeeInputTypestars: List[str]path: List[str]target: List[str] + ): + if len(path) == 0: + num_stars = 0 + for it in target: + if it == "*": + num_stars += 1 + if num_stars != len(stars): + raise ValueError("Invalid transform {}".format(".".join(target))) + + nw_tgt = [] + num_stars = 0 + for it in target: + if it == "*": + nw_tgt.append(stars[num_stars]) + num_stars += 1 + else: + nw_tgt.append(it) + expanded_mapping_list.append((".".join(nw_tgt)data)) + else: + if not isinstance(datadict): + raise ValueError("Invalid data {}".format(data)) + if path[0] == "*": + for kv in data.items(): + _expand_mapping(vstars + [k]path[1:]target) + else: + _expand_mapping(data[path[0]]starspath[1:]target) + + # expand mapping list + for tgtsrc in mapping_list: + if src.startswith("$"): + # copy from src + _expand_mapping(data[]src[1:].split(".")tgt.split(".")) + else: + if "*" in tgt: + raise ValueError("Constant value is not allowed to have `*` in prefix") + expanded_mapping_list.append((tgtsrc)) + + ret = {} + for tgtval in expanded_mapping_list: + tgt = tgt.split(".") + cur = ret + while len(tgt) > 1: + cur = cur[tgt[0]] + tgt = tgt[1:] + cur[tgt[0]] = val + return ret + + def data_to_id( + self, + data: Any, + prev_ext_states: Optional[_PrevExtTableStates] = None, + shuffle_answer: bool = True, + ): + return convert_data_to_id( + self.tokenizerdataprev_ext_statesshuffle_answerself._max_depth + ) + + def _ensure_transform_function( + selfmodule_name: strtransform_script_path: str + ) -> _TransformFunction: + module_name = "cpm_live.transforms.{}".format(module_name) + if transform_script_path not in self._transform_func_table: + loader = importlib.machinery.SourceFileLoader(module_nametransform_script_path) + spec = importlib.util.spec_from_loader(loader.nameloader) + if spec is None: + raise RuntimeError("spec is none! 
{}".format(module_name)) + mod = importlib.util.module_from_spec(spec) + self._transform_func_table[transform_script_path] = { + "loader": loader, + "module": mod, + "last_m": 0, + } + + transform_script_info = self._transform_func_table[transform_script_path] + curr_m_time = float( + transform_script_info["loader"].path_stats(transform_script_path)["mtime"] + ) + if curr_m_time > transform_script_info["last_m"]: + transform_script_info["last_m"] = curr_m_time + transform_script_info["loader"].exec_module(transform_script_info["module"]) + transform_func = getattr(transform_script_info["module"]"transform"None) + if transform_func is None: + + def _empty_transform_func(data: CPMBeeInputTypenum_sample: intr: random.Random): + raise NotImplementedError( + "Transform func for dataset {} not implemented".format(module_name) + ) + + return _empty_transform_func + else: + return transform_func + + def build_instance(selfconfig: _MixedDatasetConfig): + _sample_weight = np.array(config["incontext_weight"]dtype=np.float32) + _sample_weight = _sample_weight / _sample_weight.sum() + num_incontext = np.random.choice(_sample_weight.shape[0]p=_sample_weight) + ds = config["dataset"] + transforms = config["transforms"] + if isinstance(transformsstr): + while True: + try: + if not os.path.exists(transforms): + raise RuntimeError( + "transform script file {} not exists".format(transforms) + ) + # load transform script + transform_func = self._ensure_transform_function( + _dataset_identity(config)transforms + ) + seed = random.random() + break + except Exception as e: + print(e) + time.sleep(10) + + def _transform(data: CPMBeeInputType): + r = random.Random(seed) + return transform_func(datanum_incontextr) + transform = _transform + elif len(transforms) == 0: + transform = None + else: + transform = transforms[np.random.choice(len(transforms))] + + raw_data = {} + while True: + inp = ds.read() + inp = self.apply_transform(inptransform) + + ( + input_ids, + input_id_subs, + context, + segment_ids, + segment_rel, + n_segments, + table_states, + ) = self.data_to_id(inp) + if input_ids.shape[0] > self._max_length: + # too long + continue + input_ids = input_ids[: self._max_length] + context = context[: self._max_length] + segment_ids = segment_ids[: self._max_length] + raw_data["input"] = inp + raw_data["samples"] = [] + break + + sample_ids = np.zeros(input_ids.shapedtype=np.int32) + segment_rel_offset = np.zeros(input_ids.shapedtype=np.int32) + num_segments = np.full(input_ids.shapen_segmentsdtype=np.int32) + + for i in range(num_incontext): + if input_ids.shape[0] >= self._max_length: + # early break + break + + sample = ds.read() + sample = self.apply_transform(sampletransform) + ( + sample_input_ids, + sample_id_subs, + _, + sample_segments, + sample_rel, + n_segments, + table_states, + ) = self.data_to_id(sampletable_states) + + if input_ids.shape[0] + sample_input_ids.shape[0] > self._max_length: + # too longbreak + break + raw_data["samples"].append(sample) + input_ids = np.concatenate([input_idssample_input_ids]axis=0) + input_id_subs = np.concatenate([input_id_subssample_id_subs]axis=0) + context = np.concatenate( + [contextnp.ones(sample_input_ids.shapedtype=np.int8)]axis=0 + ) + segment_ids = np.concatenate([segment_idssample_segments]axis=0) + segment_rel_offset = np.concatenate( + [ + segment_rel_offset, + np.full(sample_input_ids.shapesegment_rel.shape[0]dtype=np.int32), + ], + axis=0, + ) + segment_rel = np.concatenate([segment_relsample_rel]axis=0) + sample_ids = np.concatenate( + 
[sample_idsnp.full(sample_input_ids.shapei + 1dtype=np.int32)]axis=0 + ) + num_segments = np.concatenate( + [num_segmentsnp.full(sample_input_ids.shapen_segmentsdtype=np.int32)]axis=0 + ) + return ( + input_ids, + input_id_subs, + context, + segment_ids, + segment_rel_offset, + segment_rel, + sample_ids, + num_segments, + raw_data, + ) + + def pack_batch(selfforce: bool = False) -> CPMBeeBatch: + # pack batch + if len(self._inputs) < self._batch_size: + if not force: + raise RuntimeError("Batch insufficient") + batch_size = len(self._inputs) + else: + batch_size = self._batch_size + inputs = np.zeros((batch_sizeself._max_length)dtype=np.int32) + inputs_sub = np.zeros((batch_sizeself._max_length)dtype=np.int32) + context = np.zeros((batch_sizeself._max_length)dtype=np.int8) + sample_ids = np.zeros((batch_sizeself._max_length)dtype=np.int32) + segments = np.zeros((batch_sizeself._max_length)dtype=np.int32) + num_segments = np.zeros((batch_sizeself._max_length)dtype=np.int32) + segment_rel_offset = np.zeros((batch_sizeself._max_length)dtype=np.int32) + tgt = np.full((batch_sizeself._max_length)-100dtype=np.int32) + + max_rel = 0 + for i in range(batch_size): + max_rel = max(max_relself._segment_rel[i].shape[0]) + segment_rel = np.zeros((batch_sizemax_rel)dtype=np.int32) + spans = np.zeros((batch_sizeself._max_length)dtype=np.int32) + length = np.zeros((batch_size,)dtype=np.int32) + task_ids = np.zeros((batch_sizeself._max_length)dtype=np.int32) + + all_task_names: Set[str] = set() + for i in range(batch_size): + for task_name in self._task_ids[i]: + all_task_names.add(task_name) + task_names: List[str] = list(all_task_names) + task_name_to_id = {name: i for iname in enumerate(task_names)} + + batch_ext_table_map: Dict[Tuple[intint]int] = {} + batch_ext_table_ids: List[int] = [] + batch_ext_table_sub: List[int] = [] + raw_data_list: List[Any] = [] + for i in range(batch_size): + instance_length = self._inputs[i].shape[0] + rel_size = self._segment_rel[i].shape[0] + inputs[i:instance_length] = self._inputs[i] + inputs_sub[i:instance_length] = self._inputs_sub[i] + context[i:instance_length] = self._context[i] + sample_ids[i:instance_length] = self._sample_ids[i] + segments[i:instance_length] = self._segments[i] + num_segments[i:instance_length] = self._num_segments[i] + segment_rel_offset[i:instance_length] = self._segment_rel_offset[i] + segment_rel[i:rel_size] = self._segment_rel[i] + + span_begin = 0 + for span_id(span_endtask_name) in enumerate(zip(self._spans[i]self._task_ids[i])): + spans[ispan_begin:span_end] = span_id + task_ids[ispan_begin:span_end] = task_name_to_id[task_name] + span_begin = span_end + length[i] = instance_length + raw_data_list.extend(self._raw_data[i]) + + for j in range(instance_length): + idxidx_sub = self._inputs[i][j]self._inputs_sub[i][j] + tgt_idx = idx + if idx_sub > 0: + # need to be in ext table + if (idxidx_sub) not in batch_ext_table_map: + batch_ext_table_map[(idxidx_sub)] = len(batch_ext_table_map) + batch_ext_table_ids.append(idx) + batch_ext_table_sub.append(idx_sub) + tgt_idx = batch_ext_table_map[(idxidx_sub)] + self.tokenizer.vocab_size + if j > 1 and context[ij - 1] == 0: + if idx != self.tokenizer.bos_id: + tgt[ij - 1] = tgt_idx + else: + tgt[ij - 1] = self.tokenizer.eos_id + if context[iinstance_length - 1] == 0: + tgt[iinstance_length - 1] = self.tokenizer.eos_id + + if len(batch_ext_table_map) == 0: + # placeholder + batch_ext_table_ids.append(0) + batch_ext_table_sub.append(1) + + self._inputs = self._inputs[batch_size:] + self._inputs_sub = 
self._inputs_sub[batch_size:] + self._context = self._context[batch_size:] + self._sample_ids = self._sample_ids[batch_size:] + self._segments = self._segments[batch_size:] + self._num_segments = self._num_segments[batch_size:] + self._segment_rel_offset = self._segment_rel_offset[batch_size:] + self._segment_rel = self._segment_rel[batch_size:] + self._spans = self._spans[batch_size:] + self._task_ids = self._task_ids[batch_size:] + self._raw_data = self._raw_data[batch_size:] + return { + "inputs": inputs, + "inputs_sub": inputs_sub, + "length": length, + "context": context > 0, + "sample_ids": sample_ids, + "num_segments": num_segments, + "segment_ids": segments, + "segment_rel_offset": segment_rel_offset, + "segment_rel": segment_rel, + "spans": spans, + "target": tgt, + "ext_ids": np.array(batch_ext_table_idsdtype=np.int32), + "ext_sub": np.array(batch_ext_table_subdtype=np.int32), + "task_ids": task_ids, + "task_names": task_names, + "raw_data": raw_data_list, + } + + def add_data(selfconfig: _MixedDatasetConfig) -> Optional[CPMBeeBatch]: + ( + input_ids, + input_id_subs, + context, + segment_ids, + segment_rel_offset, + segment_rel, + sample_ids, + num_segments, + raw_data, + ) = self.build_instance(config) + + # add to batch + best_fit: Union[Noneint] = None + best_fit_space: Union[Noneint] = None + for i in range(len(self._inputs)): + space = self._max_length - self._inputs[i].shape[0] + if input_ids.shape[0] <= space: + if best_fit_space is None: + best_fit = i + best_fit_space = space + elif best_fit_space > space: + best_fit = i + best_fit_space = space + if best_fit is None: + # add a new instance + self._inputs.append(input_ids) + self._inputs_sub.append(input_id_subs) + self._context.append(context) + self._sample_ids.append(sample_ids) + self._segments.append(segment_ids) + self._num_segments.append(num_segments) + self._segment_rel_offset.append(segment_rel_offset) + self._segment_rel.append(segment_rel) + self._spans.append([input_ids.shape[0]]) + self._task_ids.append([config["task_name"]]) + self._raw_data.append([raw_data]) + else: + # add to existing instance + self._inputs[best_fit] = np.concatenate([self._inputs[best_fit]input_ids]axis=0) + self._inputs_sub[best_fit] = np.concatenate( + [self._inputs_sub[best_fit]input_id_subs]axis=0 + ) + self._context[best_fit] = np.concatenate([self._context[best_fit]context]axis=0) + self._sample_ids[best_fit] = np.concatenate( + [self._sample_ids[best_fit]sample_ids]axis=0 + ) + self._segments[best_fit] = np.concatenate( + [self._segments[best_fit]segment_ids]axis=0 + ) + self._num_segments[best_fit] = np.concatenate( + [self._num_segments[best_fit]num_segments]axis=0 + ) + self._segment_rel_offset[best_fit] = np.concatenate( + [ + self._segment_rel_offset[best_fit], + segment_rel_offset + self._segment_rel[best_fit].shape[0], + ], + axis=0, + ) + self._segment_rel[best_fit] = np.concatenate( + [self._segment_rel[best_fit]segment_rel]axis=0 + ) + self._spans[best_fit].append(self._inputs[best_fit].shape[0]) + self._task_ids[best_fit].append(config["task_name"]) + self._raw_data[best_fit].append(raw_data) + + if len(self._inputs) > self._batch_size: + return self.pack_batch() + else: + # not ready + return None + + +class _MixedDatasetConfigMananger: + def __init__(selfconfig_path: str) -> None: + self._config_path: str = config_path + self._config: Union[List[_MixedDatasetConfig]None] = None + self._last_m = 0 + + def changed(self): + while True: + try: + m_time = os.stat(self._config_path).st_mtime + if m_time > self._last_m: 
+
+class _MixedDatasetConfigMananger:
+    def __init__(self, config_path: str) -> None:
+        self._config_path: str = config_path
+        self._config: Union[List[_MixedDatasetConfig], None] = None
+        self._last_m = 0
+
+    def changed(self):
+        while True:
+            try:
+                m_time = os.stat(self._config_path).st_mtime
+                if m_time > self._last_m:
+                    # try to load new config
+                    try:
+                        self._config = load_dataset_config(self._config_path)
+                    except Exception as e:
+                        # failed to load config
+                        print(
+                            "Error: load new config in changed, "
+                            "self._config_path={path}, err={err}"
+                            .format(path=self._config_path, err=str(e))
+                        )
+                        return False
+                    # new config loaded
+                    self._last_m = m_time
+                    return True
+                return False
+            except Exception as e:
+                print(
+                    "Error: reading info list in _MixedDatasetConfigMananger.changed! "
+                    "self._config_path={path}, err={err}"
+                    .format(path=self._config_path, err=str(e))
+                )
+                time.sleep(30)
+
+    def get_config(self) -> List[_MixedDatasetConfig]:
+        if self._config is None:
+            if not self.changed():
+                raise RuntimeError("Failed to load config")
+            if self._config is None:
+                raise RuntimeError("Failed to load config")
+        return self._config
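_MixedDatasetConfigMananger above hot-reloads the dataset mixture by polling the config file's modification time; changed() returns True only when a newer file also parses successfully, and get_config() loads the file on first use. A rough usage sketch inside this module's namespace, assuming a mixture file named datasets.json (the path and the polling cadence are made up):

# Illustrative only: poll a mixture config and react when it is rewritten on disk.
import time

cfg_mgr = _MixedDatasetConfigMananger("datasets.json")  # hypothetical path
config = cfg_mgr.get_config()                            # first load, raises if unreadable

for _ in range(10):
    if cfg_mgr.changed():               # mtime advanced and the new file parsed
        config = cfg_mgr.get_config()   # pick up the updated dataset list
        print("reloaded", len(config), "dataset entries")
    time.sleep(60)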
c["incontext_weight"] = [1.0] + path_ds_map[_dataset_identity(c)] = c + nw_path_set.add(_dataset_identity(c)) + + # remove unused datasets + for c in config: + if _dataset_identity(c) not in nw_path_set: + del path_ds_map[_dataset_identity(c)] + + config: List[_MixedDatasetConfig] = [] + for c in nw_config: + config.append(path_ds_map[_dataset_identity(c)]) + del path_ds_map + del nw_path_set + del nw_config + + weights = _build_sample_weights(config) + + # get cmds + while True: + try: + cmd = q_cmd.get_nowait() + except Empty: + break + if cmd == "stop": + should_stop = True + q_cmd_out.put(True) + break + elif cmd == "state_dict": + ret = OrderedDict() + for c in config: + ds_name = _dataset_identity(c) + ret[ds_name] = c["dataset"]._state_dict() + q_cmd_out.put(ret) + elif cmd == "load_state_dict": + state_dict = q_cmd.get() + missing = [] + for c in config: + ds_name = _dataset_identity(c) + if ds_name in state_dict: + c["dataset"].load_state_dict(state_dict[ds_name]strict=False) + else: + # new dataset + missing.append(ds_name) + q_cmd_out.put(missing) + elif cmd == "start": + should_start = True + q_cmd_out.put(True) + else: + raise RuntimeError("Unknown command: {}".format(cmd)) + + if should_stop: + break + + if not should_start: + # wait for start cmd + time.sleep(1) + continue + + if len(config) == 0: + # no dataset available + time.sleep(1) + continue + + if q_data.full(): + # queue full + time.sleep(1) + continue + + # sample a dataset + ds_id: int = 0 + + while True: + ds_id = np.random.choice(weights.shape[0]p=weights) + if config[ds_id]["dataset"]._nlines != config[ds_id]["lines"]: + # dataset size changed + for c in config: + c["lines"] = c["dataset"]._nlines + weights = _build_sample_weights(config) + continue + else: + break + + batch = packer.add_data(config[ds_id]) + if batch is not None: + # new batch comming + q_data.put(batch) + + # clean queue + while True: + try: + q_data.get_nowait() + except Empty: + break + + +class MixedDataset: + def __init__( + self, + config_path: str, + batch_size: int, + max_length: int, + tokenizer: CPMBeeTokenizer, + max_depth: int = 16, + ) -> None: + self._q_cmd = multiprocessing.Queue() + self._q_cmd_out = multiprocessing.Queue() + self._q_data = multiprocessing.Queue(maxsize=1) + self._packer = _MixedDatasetBatchPacker(batch_sizemax_lengthtokenizermax_depth) + self._p = multiprocessing.Process( + target=_mixed_dataset_process, + args=( + config_path, + self._q_cmd, + self._q_cmd_out, + self._q_data, + bmt.rank(), + bmt.world_size(), + self._packer, + ), + ) + self._p.start() + self._closed = False + + def close(self): + if not self._closed: + self._closed = True + self._q_cmd.put("stop") + assert self._q_cmd_out.get()"Failed to stop process" + self._p.join() + + @property + def closed(self): + return self._closed + + def start(self): + self._q_cmd.put("start") + return self._q_cmd_out.get() + + def state_dict(self): + self._q_cmd.put("state_dict") + states = self._q_cmd_out.get() + if not isinstance(statesOrderedDict): + raise RuntimeError("Invalid state dict {}".format(states)) + if bmt.world_size() == 1: + for val in states.values(): + val["states"].unsqueeze_(0) + val["block"].unsqueeze_(0) + return states + + ret = OrderedDict() + for kv in states.items(): + num_unused_block = v["states"].size(0) + gpu_num_unused_block = torch.tensor([num_unused_block]dtype=torch.long).cuda() + max_unused_blocks = ( + bmt.distributed.all_reduce(gpu_num_unused_blockop="max").cpu().item() + ) + if max_unused_blocks == 0: + max_unused_blocks = 1 + 
+
+
+class MixedDataset:
+    def __init__(
+        self,
+        config_path: str,
+        batch_size: int,
+        max_length: int,
+        tokenizer: CPMBeeTokenizer,
+        max_depth: int = 16,
+    ) -> None:
+        self._q_cmd = multiprocessing.Queue()
+        self._q_cmd_out = multiprocessing.Queue()
+        self._q_data = multiprocessing.Queue(maxsize=1)
+        self._packer = _MixedDatasetBatchPacker(batch_size, max_length, tokenizer, max_depth)
+        self._p = multiprocessing.Process(
+            target=_mixed_dataset_process,
+            args=(
+                config_path,
+                self._q_cmd,
+                self._q_cmd_out,
+                self._q_data,
+                bmt.rank(),
+                bmt.world_size(),
+                self._packer,
+            ),
+        )
+        self._p.start()
+        self._closed = False
+
+    def close(self):
+        if not self._closed:
+            self._closed = True
+            self._q_cmd.put("stop")
+            assert self._q_cmd_out.get(), "Failed to stop process"
+            self._p.join()
+
+    @property
+    def closed(self):
+        return self._closed
+
+    def start(self):
+        self._q_cmd.put("start")
+        return self._q_cmd_out.get()
+
+    def state_dict(self):
+        self._q_cmd.put("state_dict")
+        states = self._q_cmd_out.get()
+        if not isinstance(states, OrderedDict):
+            raise RuntimeError("Invalid state dict {}".format(states))
+        if bmt.world_size() == 1:
+            for val in states.values():
+                val["states"].unsqueeze_(0)
+                val["block"].unsqueeze_(0)
+            return states
+
+        ret = OrderedDict()
+        for k, v in states.items():
+            num_unused_block = v["states"].size(0)
+            gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
+            max_unused_blocks = (
+                bmt.distributed.all_reduce(gpu_num_unused_block, op="max").cpu().item()
+            )
+            if max_unused_blocks == 0:
+                max_unused_blocks = 1
+            gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
+            gpu_states[:num_unused_block] = v["states"].cuda()
+
+            gpu_block = v["block"].cuda()
+            global_states = bmt.distributed.all_gather(
+                gpu_states
+            ).cpu()  # (world_size, max_unused_blocks)
+            global_block = bmt.distributed.all_gather(gpu_block).cpu()  # (world_size, 4)
+            ret[k] = {"states": global_states, "block": global_block}
+        return ret
+
+    def load_state_dict(self, data: OrderedDict, strict: bool = False):
+        self._q_cmd.put("load_state_dict")
+        self._q_cmd.put(data)
+        missing = self._q_cmd_out.get()
+        if strict:
+            if len(missing) > 0:
+                raise RuntimeError("Missing dataset state: {}".format(missing))
+        return missing
+
+    def get(self) -> CPMBeeBatch:
+        ret: CPMBeeBatch = self._q_data.get()  # type: ignore
+        if not isinstance(ret, dict):
+            raise RuntimeError("Invalid data {}".format(ret))
+        return ret
+
+    def __iter__(self):
+        while True:
+            yield self.get()
+
+    def __del__(self):
+        if not self.closed:
+            try:
+                self.close()
+            except Exception:
+                pass
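MixedDataset above hides the packer and the producer loop behind a worker process and a size-1 queue, exposing an iterator plus state_dict/load_state_dict for resumable data loading. A rough usage sketch, assuming a mixture file datasets.json and that CPMBeeTokenizer is importable from cpm_live.tokenizers (both are assumptions, not taken from this diff):

# Illustrative only; assumes a dataset mixture file and an initialized bmtrain context.
from cpm_live.tokenizers import CPMBeeTokenizer

dataloader = MixedDataset(
    "datasets.json",            # hypothetical mixture config
    batch_size=2,
    max_length=2048,
    tokenizer=CPMBeeTokenizer(),
)
dataloader.start()                          # tell the worker process to begin producing
try:
    for step, batch in enumerate(dataloader):
        # batch is a CPMBeeBatch dict: "inputs", "segment_ids", "target", ...
        if step == 0:
            ckpt = dataloader.state_dict()  # per-dataset read positions, resumable later
        if step >= 10:
            break
finally:
    dataloader.close()                      # stop and join the worker process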
diff --git a/cpm-live/cpm_live/utils/__init__.py b/cpm-live/cpm_live/utils/__init__.py
index cca5d9b..3d9e31d 100644
--- a/cpm-live/cpm_live/utils/__init__.py
+++ b/cpm-live/cpm_live/utils/__init__.py
@@ -1 +1,5 @@
 from .config import Config
+from .data_utils import pad
+from .object import allgather_objects
+from .log import LogManager, logger
+from .config import load_dataset_config
diff --git a/cpm-live/cpm_live/utils/config.py b/cpm-live/cpm_live/utils/config.py
index 7e93ee4..28d2b26 100644
--- a/cpm-live/cpm_live/utils/config.py
+++ b/cpm-live/cpm_live/utils/config.py
@@ -17,10 +17,35 @@
 import os
 import copy
 from typing import Any, Dict, Union
+from .log import logger
+
+
+def load_dataset_config(dataset_path: str):
+    cfg = json.load(open(dataset_path, "r", encoding="utf-8"))
+
+    platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
+    if platform_config_path is None:
+        logger.info(
+            "no platform_config_path. Directly load dataset_path({dataset_path})"
+            .format(dataset_path=dataset_path)
+        )
+        return cfg
+
+    path_dict = json.load(open(platform_config_path, "r", encoding="utf-8"))["dataset_map"]
+    logger.info(
+        "load dataset_path({dataset_path}) with platform_config_path({platform_config_path})"
+        .format(dataset_path=dataset_path, platform_config_path=platform_config_path)
+    )
+    for dataset in cfg:
+        dataset["path"] = os.path.join(path_dict[dataset["dataset_name"]], dataset["path"])
+        dataset["transforms"] = os.path.join(
+            path_dict[dataset["dataset_name"]], dataset["transforms"]
+        )
+    return cfg
 
 
 class Config(object):
-    """enc_dec model configuration"""
+    """model configuration"""
 
     def __init__(self):
         super().__init__()
diff --git a/cpm-live/cpm_live/utils/data_utils.py b/cpm-live/cpm_live/utils/data_utils.py
new file mode 100644
index 0000000..913e062
--- /dev/null
+++ b/cpm-live/cpm_live/utils/data_utils.py
@@ -0,0 +1,44 @@
+import torch
+
+
+def pad(orig_items, key, padding_value=0, padding_side="left"):
+    items = []
+    if isinstance(orig_items[0][key], list):
+        assert isinstance(orig_items[0][key][0], torch.Tensor)
+        for it in orig_items:
+            for tr in it[key]:
+                items.append({key: tr})
+    else:
+        assert isinstance(orig_items[0][key], torch.Tensor)
+        items = orig_items
+
+    batch_size = len(items)
+    shape = items[0][key].shape
+    dim = len(shape)
+    assert dim <= 3
+    max_length = max(item[key].shape[-1] for item in items)
+    min_length = min(item[key].shape[-1] for item in items)
+    dtype = items[0][key].dtype
+
+    if dim == 1:
+        return torch.cat([item[key] for item in items], dim=0)
+    elif dim == 2:
+        if max_length == min_length:
+            return torch.cat([item[key] for item in items], dim=0)
+        tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
+    else:
+        tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
+
+    for i, item in enumerate(items):
+        if dim == 2:
+            if padding_side == "left":
+                tensor[i, -len(item[key][0]):] = item[key][0].clone()
+            else:
+                tensor[i, : len(item[key][0])] = item[key][0].clone()
+        elif dim == 3:
+            if padding_side == "left":
+                tensor[i, -len(item[key][0]):, :] = item[key][0].clone()
+            else:
+                tensor[i, : len(item[key][0]), :] = item[key][0].clone()
+
+    return tensor
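pad above collates a list of feature dicts into a single tensor, left- or right-padding every entry under the given key to the longest length in the batch. A small illustrative call (the tensors are arbitrary):

import torch
from cpm_live.utils import pad

items = [
    {"input_ids": torch.tensor([[1, 2, 3]])},
    {"input_ids": torch.tensor([[4, 5]])},
]
batch = pad(items, "input_ids", padding_value=0, padding_side="left")
print(batch)
# tensor([[1, 2, 3],
#         [0, 4, 5]])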
diff --git a/cpm-live/cpm_live/utils/export.py b/cpm-live/cpm_live/utils/export.py
new file mode 100644
index 0000000..6e4db68
--- /dev/null
+++ b/cpm-live/cpm_live/utils/export.py
@@ -0,0 +1,56 @@
+import os
+import time
+import functools
+import torch
+import bmtrain as bmt
+import json
+from cpm_live.models import CPMBee
+from .log import logger
+from typing import List, Optional
+
+
+def rename_if_exists(file_path):
+    if not os.path.exists(file_path):
+        return
+    timestamp = time.strftime('%Y%m%d%H%M%S')
+    file_dir, file_name = os.path.split(file_path)
+    file_root, file_ext = os.path.splitext(file_name)
+    new_file_name = f"{file_root}_bak_{timestamp}{file_ext}"
+    new_file_path = os.path.join(file_dir, new_file_name)
+    try:
+        os.rename(file_path, new_file_path)
+        logger.info(f"File '{file_name}' already exists. Renamed to '{new_file_name}'")
+    except Exception as e:
+        logger.warn(
+            "rename file failed, file_path={file_path}, new_file_path={new_file_path}, err={err}"
+            .format(file_path=file_path, new_file_path=new_file_path, err=str(e)))
+
+
+def rename_if_exists_decorator(func):
+    @functools.wraps(func)
+    def wrapper(file_path, *args, **kwargs):
+        rename_if_exists(file_path)
+        return func(file_path, *args, **kwargs)
+    return wrapper
+
+
+@rename_if_exists_decorator
+def bmt_save(file_path: str, model: CPMBee, export_files: Optional[List[str]] = None):
+    bmt.save(model, file_path)
+    if export_files is not None:
+        export_files.append(file_path)
+
+
+@rename_if_exists_decorator
+def torch_save(file_path: str, obj: object, export_files: Optional[List[str]] = None):
+    torch.save(obj, file_path)
+    if export_files is not None:
+        export_files.append(file_path)
+
+
+@rename_if_exists_decorator
+def json_save(file_path: str, obj: object, export_files: Optional[List[str]] = None):
+    with open(file_path, "w") as data_f:
+        json.dump(obj, data_f)
+    if export_files is not None:
+        export_files.append(file_path)
diff --git a/cpm-live/cpm_live/utils/gradient_shrink.py b/cpm-live/cpm_live/utils/gradient_shrink.py
new file mode 100644
index 0000000..7354d73
--- /dev/null
+++ b/cpm-live/cpm_live/utils/gradient_shrink.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class OpGradientShrink(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, alpha: float):
+        ctx.alpha = alpha
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output * ctx.alpha, None
+
+
+def gradient_shrink(x: torch.Tensor, alpha: float = 0.1):
+    return OpGradientShrink.apply(x, alpha)
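gradient_shrink above is an identity in the forward pass and scales the incoming gradient by alpha in the backward pass, a common way to damp gradients flowing out of an auxiliary branch. A quick check of that behaviour:

import torch
from cpm_live.utils.gradient_shrink import gradient_shrink

x = torch.ones(3, requires_grad=True)
y = gradient_shrink(x, alpha=0.1).sum()  # forward value is unchanged
y.backward()
print(x.grad)  # tensor([0.1000, 0.1000, 0.1000]) instead of all ones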
diff --git a/cpm-live/cpm_live/utils/log.py b/cpm-live/cpm_live/utils/log.py
new file mode 100644
index 0000000..55873ad
--- /dev/null
+++ b/cpm-live/cpm_live/utils/log.py
@@ -0,0 +1,98 @@
+import os
+import sys
+from typing import Any, Dict, Optional, Tuple, Union
+import datetime
+import json
+import logging
+import bmtrain as bmt
+
+
+# Set up the common logger
+def _get_logger():
+    log = logging.getLogger(__name__)
+    log.setLevel(logging.INFO)
+    console_handle = logging.StreamHandler(sys.stdout)
+    node_name = os.getenv("NODE_NAME", str(bmt.rank()))
+    console_handle.setFormatter(
+        logging.Formatter(
+            '[%(levelname)s][%(asctime)s][{}][%(filename)s:%(lineno)d:%(process)d] - %(message)s'
+            .format(node_name),
+            datefmt='%Y-%m-%d %H:%M:%S'
+        )
+    )
+    log.addHandler(console_handle)
+    return log
+
+
+logger = _get_logger()
+
+
+class LogManager:
+    def __init__(self, path: str):
+        if not os.path.exists(path):
+            os.makedirs(path)
+        self.path = path
+
+        now = self.get_log_time()
+        latest_log: Union[Dict[str, Any], None] = None
+        for _ in range(15):
+            log_name = self.get_log_name(now)
+            if os.path.exists(log_name):
+                with open(log_name, "r") as flog:
+                    latest_log = json.loads(flog.readlines()[-1])  # get last log
+                break
+            now -= datetime.timedelta(days=1)
+
+        if latest_log is None:
+            self.global_token_pass = 0
+        else:
+            self.global_token_pass = latest_log["token pass"]
+
+    def get_log_time(self) -> datetime.datetime:
+        return datetime.datetime.utcnow() + datetime.timedelta(hours=16)
+
+    def get_log_name(self, now: Optional[datetime.datetime] = None):
+        if now is None:
+            now = self.get_log_time()
+        return os.path.join(self.path, "log.%s.txt" % now.strftime("%Y%m%d"))
+
+    def write(
+        self,
+        time: float,
+        iteration: int,
+        loss: float,
+        lr: float,
+        lr_scale: float,
+        time_usage: Dict[str, float],
+        mem_usage: Dict[str, Tuple[float, float]],
+        avg_time: float,
+        token_max: float,
+        token_pass: float,
+        throughput: float,
+        grad_norm: float,
+        mask_max: float,
+        num_gpus: int,
+        task_loss: Dict[str, float],
+        model_inspect: Optional[Any] = None,
+    ):
+        with open(self.get_log_name(), "a") as fp:
+            ret = {
+                "time": time,
+                "iter": iteration,
+                "loss": loss,
+                "lr": lr,
+                "lr scale": int(lr_scale),
+                "time usage": time_usage,
+                "mem usage": mem_usage,
+                "avg time (s)": avg_time,
+                "token/max": token_max,
+                "token pass": token_pass + self.global_token_pass,
+                "throughput (token/s)": throughput,
+                "grad_norm": grad_norm,
+                "mask/max": mask_max,
+                "num_gpus": num_gpus,
+                "task_loss": task_loss,
+            }
+            if model_inspect is not None:
+                ret["model_inspect"] = model_inspect
+            fp.write(json.dumps(ret, ensure_ascii=False) + "\n")
diff --git a/cpm-live/cpm_live/utils/object.py b/cpm-live/cpm_live/utils/object.py
new file mode 100644
index 0000000..0b32e68
--- /dev/null
+++ b/cpm-live/cpm_live/utils/object.py
@@ -0,0 +1,28 @@
+import bmtrain as bmt
+import pickle
+import torch
+
+
+def allgather_objects(obj):
+    if bmt.world_size() == 1:
+        return [obj]
+
+    with torch.no_grad():
+        data_bytes: bytes = pickle.dumps(obj)
+        data_length: int = len(data_bytes)
+
+        gpu_data_length = torch.tensor([data_length], device="cuda", dtype=torch.long)
+        gathered_length = bmt.distributed.all_gather(gpu_data_length).view(-1).cpu()
+        max_data_length = gathered_length.max().item()
+
+        gpu_data_bytes = torch.zeros(max_data_length, dtype=torch.uint8, device="cuda")
+        byte_storage = torch.ByteStorage.from_buffer(data_bytes)
+        gpu_data_bytes[:data_length] = torch.ByteTensor(byte_storage)
+
+        gathered_data = bmt.distributed.all_gather(gpu_data_bytes).cpu()
+
+        ret = []
+        for i in range(gathered_data.size(0)):
+            data_bytes = gathered_data[i, : gathered_length[i].item()].numpy().tobytes()
+            ret.append(pickle.loads(data_bytes))
+        return ret
diff --git a/cpm-live/cpm_live/vocabs/bee.txt b/cpm-live/cpm_live/vocabs/bee.txt
new file mode 100644
index 0000000..b977bda
--- /dev/null
+++ b/cpm-live/cpm_live/vocabs/bee.txt
@@ -0,0 +1,86583 @@
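allgather_objects in cpm_live/utils/object.py above pickles an arbitrary Python object, pads the byte buffers to a common length, and all-gathers them so every rank ends up with every rank's copy. A rough usage sketch, assuming an initialized bmtrain process group (the gathered dict is illustrative):

# Illustrative only; requires bmt.init_distributed() to have been called.
import bmtrain as bmt
from cpm_live.utils import allgather_objects

bmt.init_distributed()
local_stats = {"rank": bmt.rank(), "loss": 2.31}  # any picklable object
all_stats = allgather_objects(local_stats)         # list with one entry per rank
if bmt.rank() == 0:
    print(all_stats)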