diff --git a/.gitignore b/.gitignore
index 9081237..06fed04 100644
--- a/.gitignore
+++ b/.gitignore
@@ -134,4 +134,9 @@ dmypy.json
*.bin
*.idx
-*.pt
\ No newline at end of file
+*.pt
+
+data
+data_raw
+results
+pretrain_data
\ No newline at end of file
diff --git a/README-ZH.md b/README-ZH.md
new file mode 100644
index 0000000..3ce63b7
--- /dev/null
+++ b/README-ZH.md
@@ -0,0 +1,60 @@
+
+
+## What's New
+- 2023/05/27 [CPM-Bee](https://github.com/OpenBMB/CPM-Bee) is released!
+- 2023/04/12 CPM-Ant is now available in [HuggingFace Transformers](https://huggingface.co/openbmb/cpm-ant-10b)!
+- 2022/10/12 [CPM-Ant+](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant-plus/cpm-live), a Chinese-English bilingual model, is released! Besides generating Chinese/English text, the model can now also handle QA, summarization and translation tasks!
+- 2022/09/16 [CPM-Ant](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant/cpm-live) is released!
+- 2022/05/29 The training of CPM-Live launches today! See the [training dynamics](https://live.openbmb.org/home) for details.
+- 2022/05/25 The [training plan](./plans/CPM-Live训练计划书.md) for CPM-Live is now published. Look forward to the start of training!
+
+
+## Milestones
+- **CPM-Bee** (2022/10/13-2023/05/27) [[Code](https://github.com/OpenBMB/CPM-Bee)][[Model](https://github.com/OpenBMB/CPM-Bee#%E6%A8%A1%E5%9E%8B)][[Plan](./plans/CPM-Bee训练计划书.md)]
+- **CPM-Ant+** (2022/08/05-2022/10/12) [[Code](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant-plus/cpm-live)][[Model](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant-plus/cpm-live#model-checkpoints)]
+- **CPM-Ant** (2022/05/29-2022/08/05) [[Code](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant/cpm-live)][[Model](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant/cpm-live#model-checkpoints)][[Website](https://live.openbmb.org/ant)][[Blog](https://www.openbmb.org/en/community/blogs/blogpage?id=98afef2ce45f4fe9a4bc15a66d7ccb92)][[Plan](./plans/CPM-Ant训练计划书.md)]
+
+## Training Plan
+Considering the scale of data and computing resources, CPM-Live will start training from a 10B model and keep learning continuously.
+
+### During training, we will:
+
+- **In real time**: display model training metrics
+- **Daily**: release model training logs
+- **Weekly**: handle discussions and feedback from the community
+- **Irregularly**: release publicly downloadable checkpoints during model training
+
+
+### During training, you can:
+
+- **Propose your model initiative**: Have a good idea about the model architecture, training method, or data source? You can propose your model initiative in the community. If the initiative gains wide support and is practically feasible, we will add it to the model being trained, so that CPM-Live can keep learning and improving with everyone's help.
+
+- **Develop your application**: Based on CPM-Live, you can submit your early ideas, prototypes, development code, or finished applications to the community. We will showcase the most popular applications on the website.
+
+- **Chat on the forum**: You can talk about any topic related to big models on our forum, such as academic research, engineering implementation, tool usage, application design, and more. Whether you are experienced or not, we believe everyone can benefit from active and open discussions.
+
+- **Download resources**: After the model training is completed, you can freely download the model parameters under an open-use license. CPM-Live uses an open license that allows commercial use. With model compression and inference acceleration tools, you can experience the power of a big model on your own computer!
+
+
+
+## Community
+
+Our [community](https://github.com/OpenBMB/CPM-Live/discussions) is based on GitHub Discussions.
+
+Read the [first post](https://github.com/OpenBMB/CPM-Live/discussions/1) and start your exploration of CPM-Live!
+
+
+
+
+
diff --git a/README.md b/README.md
index c809a39..4c0d942 100644
--- a/README.md
+++ b/README.md
@@ -5,20 +5,29 @@
**Live Training for Open-source Big Models**
- Website • Plan • Discussion
+ Website • Plan • Discussion • 简体中文
+
-
## What's New
+- 2023/05/27 [CPM-Bee](https://github.com/OpenBMB/CPM-Bee) is released!
+- 2023/04/12 CPM-Ant has been integrated into [HuggingFace Transformers](https://huggingface.co/openbmb/cpm-ant-10b)!
+- 2022/10/12 [CPM-Ant+](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant-plus/cpm-live), a bilingual model, is released! In addition to generating Chinese/English text, you can now use our model for QA, summarization and translation tasks!
+- 2022/09/16 [CPM-Ant](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant/cpm-live) is released!
- 2022/05/29 The training of CPM-Live has launched today! See [training dynamics](https://live.openbmb.org/home).
- 2022/05/25 The [training plan](./plans/CPM-Live训练计划书.md) for CPM-Live is now published. Look forward to the training!
+## Milestones
+
+- **CPM-Bee** (2022/10/13-2023/05/27) [[Code](https://github.com/OpenBMB/CPM-Bee)][[Model](https://github.com/OpenBMB/CPM-Bee#%E6%A8%A1%E5%9E%8B)][[Plan](./plans/CPM-Bee训练计划书.md)]
+- **CPM-Ant+** (2022/08/05-2022/10/12) [[Code](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant-plus/cpm-live)][[Model](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant-plus/cpm-live#model-checkpoints)]
+- **CPM-Ant** (2022/05/29-2022/08/05) [[Code](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant/cpm-live)][[Model](https://github.com/OpenBMB/CPM-Live/tree/cpm-ant/cpm-live#model-checkpoints)][[Website](https://live.openbmb.org/ant)][[Blog](https://www.openbmb.org/en/community/blogs/blogpage?id=98afef2ce45f4fe9a4bc15a66d7ccb92)][[Plan](./plans/CPM-Ant训练计划书.md)]
## Training Plan
-Considering the scale of data and computing resources, CPM-Live will start with a 10B model training, which we named CPM-Ant. The training of CPM-Ant will start on May 29, 2022, and the entire training process is expected to last five months.
+Considering the scale of data and computing resources, CPM-Live will start with a 10B model training.
### During training we will do:
@@ -43,8 +52,3 @@ Considering the scale of data and computing resources, CPM-Live will start with
[Our community](https://github.com/OpenBMB/CPM-Live/discussions) is based on GitHub Discussions.
Read the [first post](https://github.com/OpenBMB/CPM-Live/discussions/1) and start your exploration on CPM-Live!
-
-
-
-
-
diff --git a/cpm-live/.flake8 b/cpm-live/.flake8
index b357ab3..e4d2c92 100644
--- a/cpm-live/.flake8
+++ b/cpm-live/.flake8
@@ -3,4 +3,5 @@ per-file-ignores =
# imported but unused
__init__.py: F401
max-line-length = 100
-extend-ignore = E712, E203
\ No newline at end of file
+extend-ignore = E712, E203
+exclude = examples/*.py
\ No newline at end of file
diff --git a/cpm-live/config/cpm-bee-10b.json b/cpm-live/config/cpm-bee-10b.json
new file mode 100644
index 0000000..c34b2e0
--- /dev/null
+++ b/cpm-live/config/cpm-bee-10b.json
@@ -0,0 +1,14 @@
+{
+ "vocab_size": 86583,
+ "dim_model": 4096,
+ "dim_ff" : 10240,
+ "num_layers" : 48,
+ "num_heads": 32,
+ "dim_head" : 128,
+ "dropout_p" : 0.0,
+ "position_bias_num_buckets" : 256,
+ "position_bias_num_segment_buckets": 256,
+ "position_bias_max_distance" : 2048,
+ "eps" : 1e-6,
+ "half" : true
+}
diff --git a/cpm-live/config/cpm-bee-3b.json b/cpm-live/config/cpm-bee-3b.json
new file mode 100644
index 0000000..55fd0f2
--- /dev/null
+++ b/cpm-live/config/cpm-bee-3b.json
@@ -0,0 +1,14 @@
+{
+ "vocab_size": 86580,
+ "dim_model": 2560,
+ "dim_ff" : 3072,
+ "num_layers" : 32,
+ "num_heads": 32,
+ "dim_head" : 80,
+ "dropout_p" : 0.0,
+ "position_bias_num_buckets" : 256,
+ "position_bias_num_segment_buckets": 256,
+ "position_bias_max_distance" : 2048,
+ "eps" : 1e-6,
+ "half" : true
+}
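The two configuration files added above are plain JSON dictionaries of model hyperparameters. As a minimal illustrative sketch (not part of this patch; the local path below is an assumption), such a file can be read with the standard json module:

import json

# Hypothetical checkout-relative path to the new 10B config.
with open("cpm-live/config/cpm-bee-10b.json", "r", encoding="utf-8") as f:
    config = json.load(f)

# e.g. a 4096-dim model with 48 layers and 32 heads of dim 128, stored in fp16
print(config["dim_model"], config["num_layers"], config["num_heads"], config["half"])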
diff --git a/cpm-live/cpm_live/__init__.py b/cpm-live/cpm_live/__init__.py
index 03e328b..e69de29 100644
--- a/cpm-live/cpm_live/__init__.py
+++ b/cpm-live/cpm_live/__init__.py
@@ -1,5 +0,0 @@
-from . import models
-from . import dataset
-from . import utils
-from . import tokenizers
-from .arguments import get_args
diff --git a/cpm-live/cpm_live/arguments.py b/cpm-live/cpm_live/arguments.py
index 2f3202c..b1b6d0e 100644
--- a/cpm-live/cpm_live/arguments.py
+++ b/cpm-live/cpm_live/arguments.py
@@ -29,13 +29,7 @@ def add_training_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("train""training configurations")
- group.add_argument(
- "--base-path",
- type=str,
- default=None,
- help="Path to the project base directory.",
- )
- group.add_argument("--dataset_name"type=strdefault=Nonehelp="Name of the dataset")
+ group.add_argument("--dataset"type=strdefault="dataset.on"help="Path to dataset")
group.add_argument(
"--load",
type=str,
@@ -54,18 +48,14 @@ def add_training_args(parser: argparse.ArgumentParser):
default=None,
help="Output filename to save checkpoints to.",
)
+
group.add_argument(
- "--save-iters",
- type=int,
- default=1000,
- help="number of iterations between saves",
- )
- group.add_argument(
- "--log-dir",
+ "--tensorboard",
type=str,
- default="logs",
- help="tensorboard log directory",
+ default=None,
+ help="tensorboard directory",
)
+
group.add_argument("--inspect-iters"type=intdefault=1000help="number of inspecting")
group.add_argument("--batch-size"type=intdefault=32help="Data Loader batch size")
group.add_argument("--clip-grad"type=floatdefault=1.0help="gradient clipping")
@@ -76,33 +66,12 @@ def add_training_args(parser: argparse.ArgumentParser):
help="total number of iterations to train over all training runs",
)
group.add_argument("--max-length"type=intdefault=512help="max length of input")
- group.add_argument(
- "--max-encoder-length",
- type=int,
- default=512,
- help="max length of encoder input",
- )
- group.add_argument(
- "--max-decoder-length",
- type=int,
- default=256,
- help="max length of decoder input",
- )
- group.add_argument(
- "--start-step"type=intdefault=0help="step to start or continue training"
- )
- group.add_argument("--seed"type=intdefault=1234help="random seed for reproducibility")
- group.add_argument(
- "--epochs",
- type=int,
- default=1,
- help="total number of epochs to train over all training runs",
- )
+ group.add_argument("--seed"type=intdefault=1234help="random seed for reproducibility")
# Learning rate.
group.add_argument("--lr"type=floatdefault=1.0e-4help="initial learning rate")
- group.add_argument("--weight-decay"type=floatdefault=1.0e-2help="weight-decay")
+ group.add_argument("--weight-decay"type=floatdefault=1.0e-2help="weight decay rate")
group.add_argument("--loss-scale"type=floatdefault=65536help="loss scale")
group.add_argument(
@@ -111,13 +80,6 @@ def add_training_args(parser: argparse.ArgumentParser):
default=0.01,
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
)
- group.add_argument(
- "--lr-decay-iters",
- type=int,
- default=None,
- help="number of iterations to decay LR over,"
- " If None defaults to `--train-iters`*`--epochs`",
- )
group.add_argument(
"--lr-decay-",
type=str,
@@ -125,20 +87,62 @@ def add_training_args(parser: argparse.ArgumentParser):
choices=["constant""linear""cosine""exponential""noam"],
help="learning rate decay function",
)
+ group.add_argument("--lr-decay-iters"type=intdefault=Nonehelp="lr decay steps")
group.add_argument(
- "--local_rank",
+ "--start-step"type=intdefault=0help="step to start or continue training"
+ )
+
+ return parser
+
+
+def add_pretrain_args(parser: argparse.ArgumentParser):
+ group = parser.add_argument_group("pretrain""pretrain configurations")
+ group.add_argument(
+ "--save-iters",
type=int,
+ default=1000,
+ help="number of iterations between saves",
+ )
+ group.add_argument(
+ "--log-dir",
+ type=str,
default=None,
- help="local rank passed from distributed launcher",
+ help="log directory",
)
return parser
-def get_args():
+def add_finetune_args(parser: argparse.ArgumentParser):
+ group = parser.add_argument_group("finetune""fintune configurations")
+ group.add_argument("--epoch"type=intdefault=1help="number of training epochs")
+ group.add_argument("--task-name"type=strdefault="task"help="name of training task")
+ group.add_argument(
+ "--use-delta",
+ action="store_true",
+ default=False,
+ help="use delta tuning or not"
+ )
+ group.add_argument("--eval_dataset"type=strhelp="path to eval dataset")
+ group.add_argument(
+ "--drop-last",
+ action="store_true",
+ default=False,
+ help="drop data from each epoch that cannot be formed into a complete batch at the end",
+ )
+ group.add_argument("--eval-interval"type=intdefault=500help="eval interval")
+ group.add_argument("--early-stop-patience"type=intdefault=5help="early stop steps")
+ return parser
+
+
+def get_args(pretrain: bool = False, finetune: bool = False):
parser = argparse.ArgumentParser()
parser = add_model_config_args(parser)
parser = add_training_args(parser)
+ if pretrain:
+ parser = add_pretrain_args(parser)
+ if finetune:
+ parser = add_finetune_args(parser)
args = parser.parse_args()
return args
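With this refactor, get_args() assembles the model and training argument groups and then optionally appends the pretrain or finetune groups. A minimal usage sketch (not part of the patch; the command line shown is hypothetical):

# e.g. python pretrain.py --dataset /path/to/data --lr 1e-4 --save-iters 500
from cpm_live.arguments import get_args

args = get_args(pretrain=True)   # adds --save-iters and --log-dir on top of the training args
print(args.dataset, args.lr, args.save_iters)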
diff --git a/cpm-live/cpm_live/dataset/distributed_dataset.py b/cpm-live/cpm_live/dataset/distributed_dataset.py
index d90c612..a9979da 100644
--- a/cpm-live/cpm_live/dataset/distributed_dataset.py
+++ b/cpm-live/cpm_live/dataset/distributed_dataset.py
@@ -15,29 +15,35 @@
import io
import os
-import pickle
-from typing import List
+import struct
+from typing import List, Optional, Set
import torch
import bisect
import bmtrain as bmt
-
+import json
+from .serializer import Serializer, PickleSerializer
import random
import string
+import time
def _random_string():
return "".join(random.choices(string.ascii_uppercase + string.digitsk=8))
+_DEFAULT_BLOCK_SIZE = 16 << 20
+
+
class FileInfo:
def __init__(
self,
- file_name: str,
- block_begin: int,
- block_end: int,
- nbytes: int,
- nlines: int,
+ file_name: str = "",
+ block_begin: int = 0,
+ block_end: int = 0,
+ nbytes: int = 0,
+ nlines: int = 0,
mask: bool = False,
+ block_size: int = _DEFAULT_BLOCK_SIZE,
) -> None:
self.file_name = file_name
self.block_begin = block_begin
@@ -45,34 +51,153 @@ def __init__(
self.nbytes = nbytes
self.nlines = nlines
self.mask = mask
+ self.block_size = block_size
- @classmethod
- def _load_from_state(cls, version, data):
- if version == 1:
- file_name, block_begin, block_end, nbytes, nlines, mask = data
- return cls(file_name, block_begin, block_end, nbytes, nlines, mask)
- else:
- raise RuntimeError("Unsupported version %d" % version)
-
- def __reduce__(self):
- return (
- FileInfo._load_from_state,
- (
- 1,
- (
- self.file_name,
- self.block_begin,
- self.block_end,
- self.nbytes,
- self.nlines,
- self.mask,
- ),
- ),
- )
+ def state_dict(self):
+ return {
+ "file_name": self.file_name,
+ "block_begin": self.block_begin,
+ "block_end": self.block_end,
+ "nbytes": self.nbytes,
+ "nlines": self.nlines,
+ "mask": self.mask,
+ "block_size": self.block_size,
+ }
+
+ def load_state_dict(self, d):
+ self.file_name = d["file_name"]
+ self.block_begin = d["block_begin"]
+ self.block_end = d["block_end"]
+ self.nbytes = d["nbytes"]
+ self.nlines = d["nlines"]
+ self.mask = d["mask"]
+ self.block_size = d["block_size"]
+
+ def dumps(self) -> str:
+ return json.dumps(self.state_dict())
+
+ def loads(self, data: str) -> "FileInfo":
+ self.load_state_dict(json.loads(data))
+ return self
+
+ def dump(self, fp: io.TextIOWrapper) -> "FileInfo":
+ fp.write(self.dumps())
+ return self
+
+ def load(self, fp: io.TextIOWrapper) -> "FileInfo":
+ self.loads(fp.read())
+ return self
+
+
+def _read_info_list(meta_path: str) -> List[FileInfo]:
+ info: List[FileInfo] = []
+ while True:
+ try:
+ with open(meta_path"r"encoding="utf-8") as f:
+ for line in f.readlines():
+ line = line.strip()
+ if len(line) > 0:
+ info.append(FileInfo().loads(line))
+ return info
+ except Exception as e:
+ print("Error: reading info list in _read_info_list!,meta_path={path}err={err}".
+ format(path=meta_patherr=str(e)))
+ time.sleep(10)
+
+
+def _write_info_list(meta_path: str, info: List[FileInfo]):
+ base_path = os.path.dirname(meta_path)
+ random_fname = os.path.join(base_path, ".meta.bin.%s" % _random_string())
+ while True:
+ try:
+ with open(random_fname, "w", encoding="utf-8") as f:
+ for v in info:
+ f.write(v.dumps() + "\n")
+ os.rename(random_fname, meta_path)
+ return
+ except Exception:
+ print("Error: writing info list!")
+ time.sleep(10)
-_MASK_VALUE = 0x7FFFFFFF
-_DEFAULT_BLOCK_SIZE = 16 << 20
+def _filtered_range(
+ begin: int, end: int, rank: int, world_size: int, filter_set: Optional[Set[int]] = None
+):
+ begin = begin + (rank + (world_size - (begin % world_size))) % world_size
+
+ if filter_set is not None:
+ return [i for i in range(begin, end, world_size) if i in filter_set]
+ else:
+ return [i for i in range(begin, end, world_size)]
+
+
+# for some bugs that may exist in hdfs
+class SafeFile:
+
+ def __init__(self, fname, mode):
+ self.fname = None
+ self.mode = None
+ self._fp = None
+ self.open_file(fname, mode)
+
+ def read(self, size=-1):
+ if self._fp is None:
+ raise RuntimeError("Dataset is closed")
+ try:
+ res = self._fp.read(size)
+ self.offset = self._fp.tell()
+ return res
+ except Exception as e:
+ print("Error {}: reading blocks in read {}!".format(eself.fname))
+ self.open_file(self.fnameself.modeself.offset)
+ return self.read(size)
+
+ def tell(self):
+ if self._fp is None:
+ raise RuntimeError("Dataset is closed")
+ try:
+ res = self._fp.tell()
+ self.offset = res
+ return res
+ except Exception as e:
+ print("Error {}: reading blocks in tell {}!".format(eself.fname))
+ self.open_file(self.fnameself.modeself.offset)
+ return self.tell()
+
+ def seek(self, offset, whence=0):
+ if self._fp is None:
+ raise RuntimeError("Dataset is closed")
+ try:
+ res = self._fp.seek(offset, whence)
+ self.offset = self._fp.tell()
+ return res
+ except Exception as e:
+ print("Error {}: reading blocks in seek {}!".format(e, self.fname))
+ self.open_file(self.fname, self.mode, self.offset)
+ return self.seek(offset, whence)
+
+ def close(self):
+ if self._fp is not None:
+ try:
+ self._fp.close()
+ except Exception:
+ pass
+ self._fp = None
+
+ def open_file(self, fname, mode, offset=None):
+ if not os.path.exists(fname):
+ raise RuntimeError("Dataset does not exist")
+ try:
+ self.fname = fname
+ self.mode = mode
+ self._fp = open(fname, mode)
+ if offset is not None:
+ self._fp.seek(offset, io.SEEK_SET)
+ self.offset = self._fp.tell()
+ except Exception as e:
+ print("Error {}: reading blocks in open_file {}!".format(eself.fname))
+ time.sleep(10)
+ self.open_file(fnamemodeoffset)
class DistributedDataset:
@@ -100,16 +225,24 @@ def __init__(
path: str,
rank: int = 0,
world_size: int = 1,
- block_size=_DEFAULT_BLOCK_SIZE,
+ serializer: Optional[Serializer] = None,
+ max_repeat_times: Optional[int] = None,
+ shuffle: bool = True,
) -> None:
# config
self._path = path
self._rank = rank
self._world_size = world_size
- self._block_size = block_size
+ self._max_repeat_times = max_repeat_times
+ self._repeat_times = 0
+ self._shuffle = shuffle
+
+ if serializer is None:
+ serializer = PickleSerializer()
+ self.serializer = serializer
# dataset meta
- self._block_states = torch.tensor([], dtype=torch.int)
+ self._unused_block: List[int] = []
self._file_info: List[FileInfo] = []
self._file_ends: List[int] = []
self._total_blocks = 0
@@ -124,12 +257,21 @@ def __init__(
self._last_mod_time = 0
self._curr_fname = None
- self._update_states()
+ self._update_states(fast_skip=False)
+ self._repeat_times += 1
 def _update_states(self, fast_skip: bool = True):
meta_path = os.path.join(self._path"meta.bin")
- mod_time = os.stat(meta_path).st_mtime
+ while True:
+ try:
+ mod_time = os.stat(meta_path).st_mtime
+ break
+ except Exception as e:
+ print("Error: reading info list in DistributedDataset._update_states"
+ "meta_path={path}err={err}!".format(path=meta_patherr=str(e)))
+ time.sleep(10)
+
if self._last_mod_time < mod_time:
# file changed
pass
@@ -139,12 +281,13 @@ def _update_states(self, fast_skip: bool = True):
info: List[FileInfo] = []
if os.path.exists(meta_path):
- with open(meta_path"rb") as f:
- info = pickle.load(f)
+ info = _read_info_list(meta_path)
old_len = len(self._file_info)
if old_len > len(info):
raise RuntimeError("Dataset meta file: changed unexpectly")
+
+ mask_changed = False
for i in range(old_len):
if self._file_info[i].file_name != info[i].file_name:
raise RuntimeError("Dataset meta file: changed unexpectly")
@@ -152,6 +295,8 @@ def _update_states(selffast_skip: bool = True):
raise RuntimeError("Dataset meta file: changed unexpectly")
if self._file_info[i].block_end != info[i].block_end:
raise RuntimeError("Dataset meta file: changed unexpectly")
+ if self._file_info[i].mask != info[i].mask:
+ mask_changed = True
if info[0].block_begin != 0:
raise RuntimeError("Dataset meta file: block error (0)")
@@ -159,7 +304,7 @@ def _update_states(selffast_skip: bool = True):
if info[i].block_end != info[i + 1].block_begin:
raise RuntimeError("Dataset meta file: block error (%d)" % (i + 1))
- if old_len == len(info) and fast_skip:
+ if (old_len == len(info) and not mask_changed) and fast_skip:
# fast skip
return
@@ -176,32 +321,38 @@ def _update_states(selffast_skip: bool = True):
self._nlines = 0
if total_blocks > 0:
- masks = torch.full(
- (total_blocks,),
- _MASK_VALUE,
- dtype=torch.int,
- device="cpu",
- requires_grad=False,
- )
- masks[self._rank :: self._world_size] = 0
- for v in info:
- if v.mask or (not os.path.exists(self._get_file_path(v.file_name))):
- masks[v.block_begin : v.block_end] = _MASK_VALUE
- new_block_states = torch.zeros(
- total_blocks, dtype=torch.int, device="cpu", requires_grad=False
- )
- new_block_states[: self._block_states.size(0)] = self._block_states
- new_block_states = torch.maximum(new_block_states, masks)
-
- self._block_states = new_block_states
+ unused_block_set = set(self._unused_block)
+ nw_unused_block: List[int] = []
+ for i in range(len(info)):
+ v = info[i]
+ if not v.mask:
+ if i < old_len:
+ nw_unused_block.extend(
+ _filtered_range(
+ v.block_begin,
+ v.block_end,
+ self._rank,
+ self._world_size,
+ unused_block_set,
+ )
+ )
+ else:
+ nw_unused_block.extend(
+ _filtered_range(
+ v.block_begin, v.block_end, self._rank, self._world_size
+ )
+ )
+
+ # re-shuffle unused blocks
+ if self._shuffle:
+ random.shuffle(nw_unused_block)
+ self._unused_block = nw_unused_block
self._file_ends = []
for v in info:
self._file_ends.append(v.block_end)
else:
- self._block_states = torch.tensor(
- []dtype=torch.intdevice="cpu"requires_grad=False
- )
+ self._unused_block = []
self._file_ends = []
self._total_blocks = total_blocks
self._file_info = info
@@ -209,27 +360,62 @@ def _update_states(selffast_skip: bool = True):
assert len(self._file_ends) == len(self._file_info)
 def _mask_file(self, f: FileInfo):
- masks = torch.full(
- (self._total_blocks,), 0, dtype=torch.int, device="cpu", requires_grad=False
- )
- masks[f.block_begin : f.block_end] = _MASK_VALUE
- self._block_states = torch.maximum(self._block_states, masks)
+ self._unused_block = [
+ block_id
+ for block_id in self._unused_block
+ if block_id < f.block_begin or block_id >= f.block_end
+ ]
 def _get_block_file(self, block_id: int):
 # find block in which file
 file_idx = bisect.bisect_right(self._file_ends, block_id)
return self._file_info[file_idx]
+ def _prepare_new_epoch(self):
+ if self._max_repeat_times is not None:
+ if self._repeat_times >= self._max_repeat_times:
+ raise EOFError("End of dataset")
+ nw_unused_block: List[int] = []
+ for v in self._file_info:
+ if not v.mask:
+ nw_unused_block.extend(
+ _filtered_range(v.block_begin, v.block_end, self._rank, self._world_size)
+ )
+ if self._shuffle:
+ random.shuffle(nw_unused_block)
+ self._unused_block = nw_unused_block
+ self._repeat_times += 1
+
def _get_next_block(self):
self._update_states()
- if self._block_states.size(0) == 0:
- raise RuntimeError("Empty dataset")
- mn_block: int = self._block_states.argmin().item() # type: ignore
- if self._block_states[mn_block].item() == _MASK_VALUE:
- raise RuntimeError("Empty dataset")
- self._block_states[mn_block] += 1
+ if len(self._unused_block) == 0:
+ self._prepare_new_epoch()
+ if len(self._unused_block) == 0:
+ raise RuntimeError("Empty dataset {}".format(self._path))
+
+ mn_block: int = self._unused_block.pop()
return mn_block
+ def _state_dict(self):
+ self._update_states()
+ num_unused_block = len(self._unused_block)
+ if (self._fp is not None) and (self._curr_block is not None):
+ curr_block = self._curr_block
+ curr_f = self._get_block_file(curr_block)
+ inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * curr_f.block_size
+ else:
+ curr_block = -1
+ inblock_offset = 0
+
+ return {
+ "states": torch.tensor(self._unused_blockdtype=torch.longdevice="cpu"),
+ "block": torch.tensor(
+ [curr_blockinblock_offsetnum_unused_blockself._repeat_times],
+ dtype=torch.long,
+ device="cpu",
+ ),
+ }
+
def state_dict(self):
"""Returns a state dict representing the read states of the dataset.
@@ -238,32 +424,43 @@ def state_dict(self):
>>> dataset.load_state_dict(state)
"""
self._update_states()
- states = torch.where(
- self._block_states == _MASK_VALUE,
- torch.zeros(self._total_blocks, dtype=torch.int, device="cpu", requires_grad=False),
- self._block_states,
- )
+ num_unused_block = len(self._unused_block)
if (self._fp is not None) and (self._curr_block is not None):
curr_block = self._curr_block
curr_f = self._get_block_file(curr_block)
- inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * self._block_size
+ inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * curr_f.block_size
else:
curr_block = -1
inblock_offset = 0
with torch.no_grad():
if self._world_size > 1:
- gpu_states = states.cuda()
- gpu_block = torch.tensor([curr_block, inblock_offset], dtype=torch.long).cuda()
- global_states = bmt.distributed.all_reduce(gpu_states, op="sum").cpu()
- global_block = bmt.distributed.all_gather(gpu_block).cpu()
+ gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
+ max_unused_blocks = (
+ bmt.distributed.all_reduce(gpu_num_unused_block, op="max").cpu().item()
+ )
+ gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
+ gpu_states[:num_unused_block] = torch.tensor(
+ self._unused_block, dtype=torch.long
+ ).cuda()
+
+ gpu_block = torch.tensor(
+ [curr_block, inblock_offset, num_unused_block, self._repeat_times],
+ dtype=torch.long,
+ ).cuda()
+ global_states = bmt.distributed.all_gather(
+ gpu_states
+ ).cpu() # (world_size, max_unused_blocks)
+ global_block = bmt.distributed.all_gather(gpu_block).cpu() # (world_size, 4)
+ return {"states": global_states, "block": global_block}
else:
return {
- "states": states,
+ "states": torch.tensor([self._unused_block]dtype=torch.longdevice="cpu"),
"block": torch.tensor(
- [[curr_blockinblock_offset]]dtype=torch.longdevice="cpu"
+ [[curr_blockinblock_offsetnum_unused_blockself._repeat_times]],
+ dtype=torch.long,
+ device="cpu",
),
}
@@ -278,11 +475,10 @@ def load_state_dict(self, state, strict: bool = True):
>>> state = dataset.state_dict()
>>>
"""
+ block_states: torch.LongTensor = state["states"]
+ block_info: torch.LongTensor = state["block"]
- self._block_states = state["states"]
- self._update_states(False)
-
- if state["block"].size(0) != self._world_size:
+ if block_states.size(0) != self._world_size:
if strict:
raise ValueError(
"world_size changed (%d -> %d)" % (state["block"].size(0)self._world_size)
@@ -291,20 +487,48 @@ def load_state_dict(selfstatestrict: bool = True):
self._curr_block = None
self._fp = None
self._curr_fname = None
+ self._repeat_times = int(block_info[03].item())
+
+ # re-shuffle unused blocks
+ nw_unused_block: List[int] = []
+ for i in range(block_states.size(0)):
+ # filter blocks that are not in this rank
+ num_unused_blocks: int = int(block_info[i, 2].item())
+ nw_unused_block.extend(
+ [
+ block_id
+ for block_id in block_states[i, :num_unused_blocks].tolist()
+ if block_id % self._world_size == self._rank
+ ]
+ )
+ if self._shuffle:
+ random.shuffle(nw_unused_block)
+ self._unused_block = nw_unused_block
else:
- curr_block = state["block"][self._rank][0].item()
- inblock_offset = state["block"][self._rank][1].item()
+ curr_block, inblock_offset, num_unused_blocks, self._repeat_times = block_info[
+ self._rank
+ ].tolist()
if curr_block == -1:
self._curr_block = None
else:
- self._curr_block = curr_block
- f_info = self._get_block_file(self._curr_block)
- self._open_file(
- f_info.file_name,
- (self._curr_block - f_info.block_begin) * self._block_size + inblock_offset,
- )
+ while True:
+ try:
+ self._curr_block = curr_block
+ f_info = self._get_block_file(self._curr_block)
+ self._open_file(
+ f_info.file_name,
+ (self._curr_block - f_info.block_begin)
+ * f_info.block_size
+ + inblock_offset,
+ )
+ self._unused_block = block_states[self._rank, :num_unused_blocks].tolist()
+ break
+ except Exception:
+ print("Error: reading block!")
+ time.sleep(10)
# end
+ self._update_states()
def _get_file_path(selffname):
return os.path.join(self._pathfname)
@@ -314,11 +538,11 @@ def _open_file(self, fname, offset):
if self._fp is not None:
self._fp.close()
self._curr_fname = None
- self._fp = open(self._get_file_path(fname)"rb")
+ # self._fp = open(self._get_file_path(fname)"rb")
+ self._fp = SafeFile(self._get_file_path(fname)"rb")
self._curr_fname = fname
else:
assert self._fp is not None"Unexpected error"
-
self._fp.seek(offsetio.SEEK_SET) # move to block
def read(self):
@@ -332,10 +556,11 @@ def read(self):
try:
self._open_file(
f_info.file_name,
- (next_block_id - f_info.block_begin) * self._block_size,
+ (next_block_id - f_info.block_begin) * f_info.block_size,
)
self._curr_block = next_block_id
except FileNotFoundError:
+ print("ERR: reading again!")
self._mask_file(f_info)
return self.read() # read again
@@ -345,7 +570,9 @@ def read(self):
MAGIC = self._fp.read(1)
if MAGIC == b"\x1F":
# correct
- return pickle.load(self._fp)
+ size = struct.unpack("I"self._fp.read(4))[0]
+ data = self._fp.read(size)
+ return self.serializer.deserialize(data)
elif MAGIC == b"\x00":
# end of block
self._curr_block = None
@@ -359,24 +586,27 @@ def nbytes(self):
class SimpleDataset(DistributedDataset):
- def __init__(self, path: str, block_size=_DEFAULT_BLOCK_SIZE) -> None:
- super().__init__(path, 0, 1, block_size)
-
- def _get_next_block(self):
- self._update_states()
- if self._block_states.size(0) == 0:
- raise RuntimeError("Empty dataset")
- mn_block: int = self._block_states.argmin().item() # type: ignore
- if self._block_states[mn_block].item() >= 1:
- raise EOFError("no more data")
- self._block_states[mn_block] += 1
- return mn_block
+ def __init__(
+ self,
+ path: str,
+ serializer: Optional[Serializer] = None,
+ shuffle: bool = True,
+ ) -> None:
+ super().__init__(
+ path,
+ 0,
+ 1,
+ serializer=serializer,
+ max_repeat_times=1,
+ shuffle=shuffle,
+ )
def __iter__(self):
while True:
try:
data = self.read()
except EOFError:
+ self._repeat_times = 0
break
yield data
@@ -385,7 +615,7 @@ def __len__(self):
class DatasetWriter:
- def __init__(self, fname, block_size):
+ def __init__(self, fname: str, block_size: int, serializer: Optional[Serializer] = None):
self._fname = fname
self._block_size = block_size
self._fp = open(self._fname"wb")
@@ -395,6 +625,10 @@ def __init__(selffnameblock_size):
self._nlines = 0
self._nblocks = 1
+ if serializer is None:
+ serializer = PickleSerializer()
+ self.serializer = serializer
+
 def write(self, data):
 """Write a piece of data into dataset.
@@ -405,7 +639,8 @@ def write(self, data):
>>> writer.write( "anything you want" )
"""
- byte_data = pickle.dumps(data)
+ byte_data = self.serializer.serialize(data)
+ byte_data = struct.pack("I"len(byte_data)) + byte_data
if self._inblock_offset + 2 + len(byte_data) > self._block_size:
self._fp.write(
b"\x00" * (self._block_size - self._inblock_offset)
@@ -442,10 +677,19 @@ def close(self):
class DatasetBuilder:
- def __init__(self, path: str, dbname: str, block_size=_DEFAULT_BLOCK_SIZE) -> None:
+ def __init__(
+ self,
+ path: str,
+ dbname: str,
+ block_size=_DEFAULT_BLOCK_SIZE,
+ serializer: Optional[Serializer] = None,
+ ) -> None:
self._block_size = block_size
self._path = path
self._dbname = dbname
+ if serializer is None:
+ serializer = PickleSerializer()
+ self.serializer = serializer
if not os.path.exists(self._path):
os.makedirs(self._path)
@@ -454,8 +698,7 @@ def __init__(self, path: str, dbname: str, block_size=_DEFAULT_BLOCK_SIZE) -> No
info: List[FileInfo] = []
if os.path.exists(meta_path):
- with open(meta_path"rb") as f:
- info = pickle.load(f)
+ info = _read_info_list(meta_path)
for v in info:
if v.file_name == dbname:
@@ -466,7 +709,7 @@ def __init__(self, path: str, dbname: str, block_size=_DEFAULT_BLOCK_SIZE) -> No
raise ValueError("File exists `%s`" % self._db_path)
def __enter__(self):
- self._writer = DatasetWriter(self._db_path, self._block_size)
+ self._writer = DatasetWriter(self._db_path, self._block_size, self.serializer)
return self._writer
 def __exit__(self, exc_type, exc_value, exc_traceback):
@@ -482,8 +725,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
meta_path = os.path.join(self._path"meta.bin")
info: List[FileInfo] = []
if os.path.exists(meta_path):
- with open(meta_path"rb") as f:
- info = pickle.load(f)
+ info = _read_info_list(meta_path)
last_block = 0
if len(info) > 0:
last_block = info[-1].block_end
@@ -495,19 +737,22 @@ def __exit__(selfexc_typeexc_valueexc_traceback):
self._writer.nbytes,
self._writer.nlines,
False,
+ self._block_size,
)
)
# atomic write to meta file
- random_fname = os.path.join(self._path".meta.bin.%s" % _random_string())
- with open(random_fname"wb") as f:
- pickle.dump(infof)
- os.rename(random_fnamemeta_path)
+ _write_info_list(meta_pathinfo)
self._writer = None
-def build_dataset(path: str, dbname: str, block_size: int = _DEFAULT_BLOCK_SIZE):
+def build_dataset(
+ path: str,
+ dbname: str,
+ block_size: int = _DEFAULT_BLOCK_SIZE,
+ serializer: Optional[Serializer] = None,
+):
"""Open the dataset in write mode and returns a writer.
Args:
@@ -520,4 +765,4 @@ def build_dataset(path: str, dbname: str, block_size: int = _DEFAULT_BLOCK_SIZE)
>>> for i in range(10):
>>> writer.write( { "anything you want" } )
""" # noqa: E501
- return DatasetBuilder(path, dbname, block_size)
+ return DatasetBuilder(path, dbname, block_size=block_size, serializer=serializer)
diff --git a/cpm-live/cpm_live/dataset/serializer.py b/cpm-live/cpm_live/dataset/serializer.py
new file mode 100644
index 0000000..ef3865f
--- /dev/null
+++ b/cpm-live/cpm_live/dataset/serializer.py
@@ -0,0 +1,61 @@
+# coding=utf-8
+# Copyright 2020 The OpenBMB team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pickle
+import json
+
+
+class Serializer:
+ def __init__(self) -> None:
+ pass
+
+ def serialize(self, obj) -> bytes:
+ raise NotImplementedError()
+
+ def deserialize(self, data: bytes):
+ raise NotImplementedError()
+
+
+class PickleSerializer(Serializer):
+ def __init__(self) -> None:
+ pass
+
+ def serialize(self, obj) -> bytes:
+ return pickle.dumps(obj)
+
+ def deserialize(self, data: bytes):
+ return pickle.loads(data)
+
+
+class JsonSerializer(Serializer):
+ def __init__(self) -> None:
+ pass
+
+ def serialize(self, obj) -> bytes:
+ return json.dumps(obj, ensure_ascii=False).encode("utf-8")
+
+ def deserialize(self, data: bytes):
+ return json.loads(data.decode("utf-8"))
+
+
+class RawSerializer(Serializer):
+ def __init__(self) -> None:
+ pass
+
+ def serialize(self, obj) -> bytes:
+ return obj
+
+ def deserialize(self, data: bytes):
+ return data
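The new Serializer interface decouples how individual records are encoded from how the dataset stores them in blocks: writers call serialize() before length-prefixing the bytes, and readers call deserialize() after the bytes are read back. A small illustrative round trip (not part of the patch):

from cpm_live.dataset.serializer import JsonSerializer, RawSerializer

s = JsonSerializer()
blob = s.serialize({"text": "hello"})          # UTF-8 encoded JSON bytes
assert s.deserialize(blob) == {"text": "hello"}

r = RawSerializer()                            # pass-through for already-encoded bytes
assert r.deserialize(r.serialize(b"raw bytes")) == b"raw bytes"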
diff --git a/cpm-live/cpm_live/dataset/utils.py b/cpm-live/cpm_live/dataset/utils.py
index c166a56..2a6ace6 100644
--- a/cpm-live/cpm_live/dataset/utils.py
+++ b/cpm-live/cpm_live/dataset/utils.py
@@ -14,16 +14,19 @@
# limitations under the License.
import os
-from typing import List
+import struct
+from typing import List, Optional
from .distributed_dataset import (
SimpleDataset,
build_dataset,
+ _read_info_list,
+ _write_info_list,
_random_string,
_DEFAULT_BLOCK_SIZE,
FileInfo,
)
+from .serializer import RawSerializer
import random
-import pickle
import shutil
try:
@@ -42,7 +45,8 @@ def shuffle_dataset(
path_tgt: str,
block_size: int = _DEFAULT_BLOCK_SIZE,
bucket_size: int = _DEFAULT_SHUFFLE_BUCKET_SIZE,
- progress_bar=False,
+ progress_bar: bool = False,
+ output_name: Optional[str] = None,
):
"""Shuffle one distributed datatasetwrite results to another dataset.
@@ -60,7 +64,7 @@ def shuffle_dataset(
if progress_bar and not support_tqdm:
raise RuntimeError("Requires `tqdm` to enable progress bar.")
- ds = SimpleDataset(path_src, block_size=block_size)
+ ds = SimpleDataset(path_src, serializer=RawSerializer())
num_buckets = (ds.nbytes + bucket_size - 1) // bucket_size
tmp_files = [os.path.join(path_src".tmp.%s" % _random_string()) for _ in range(num_buckets)]
@@ -74,7 +78,8 @@ def shuffle_dataset(
iterator = tqdm(dsdesc="Shuffle step 1/2")
for data in iterator:
bucket_id = int(random.random() * num_buckets)
- pickle.dump(data, f_tmp[bucket_id]) # write into a random bucket
+ len_data = len(data)
+ f_tmp[bucket_id].write(struct.pack("I", len_data) + data)
finally:
# close all files
for fp in f_tmp:
@@ -83,7 +88,14 @@ def shuffle_dataset(
f_tmp = []
# Step 2: shuffle inside bucket
- with build_dataset(path_tgt"%s.shuffle" % _random_string()) as writer:
+ if output_name is None:
+ output_name = "%s.shuffle" % _random_string()
+ with build_dataset(
+ path_tgt,
+ output_name,
+ block_size=block_size,
+ serializer=RawSerializer(),
+ ) as writer:
iterator = tmp_files
if progress_bar:
iterator = tqdm(tmp_filesdesc="Shuffle step 2/2")
@@ -93,7 +105,12 @@ def shuffle_dataset(
data_in_bucket = []
while True:
try:
- data_in_bucket.append(pickle.load(fp))
+ raw_data = fp.read(4)
+ if len(raw_data) == 0:
+ # EOF
+ break
+ len_data = struct.unpack("I"raw_data)[0]
+ data_in_bucket.append(fp.read(len_data))
except EOFError:
break
random.shuffle(data_in_bucket)
@@ -125,12 +142,11 @@ def compact_dataset(path: str):
info: List[FileInfo] = []
if os.path.exists(meta_path):
- with open(meta_path"rb") as f:
- info = pickle.load(f)
+ info = _read_info_list(meta_path)
else:
raise ValueError("Dataset not exists")
- nw_info = []
+ nw_info: List[FileInfo] = []
curr_block = 0
for v in info:
if not os.path.exists(v.file_name):
@@ -146,14 +162,12 @@ def compact_dataset(path: str):
v.nbytes,
v.nlines,
v.mask,
+ v.block_size,
)
)
curr_block += num_file_block
- random_fname = os.path.join(path".meta.bin.%s" % _random_string())
- with open(random_fname"wb") as f:
- pickle.dump(nw_infof)
- os.rename(random_fnamemeta_path)
+ _write_info_list(meta_pathnw_info)
 def mask_dataset(path: str, dbname: str, mask: bool = True):
@@ -173,19 +187,14 @@ def mask_dataset(path: str, dbname: str, mask: bool = True):
info: List[FileInfo] = []
if os.path.exists(meta_path):
- with open(meta_path"rb") as f:
- info = pickle.load(f)
+ info = _read_info_list(meta_path)
else:
raise ValueError("Dataset not exists")
for v in info:
if v.file_name == dbname:
v.mask = mask
-
- random_fname = os.path.join(path".meta.bin.%s" % _random_string())
- with open(random_fname"wb") as f:
- pickle.dump(infof)
- os.rename(random_fnamemeta_path)
+ _write_info_list(meta_pathinfo)
 def merge_dataset(dst: str, src: str):
@@ -195,15 +204,13 @@ def merge_dataset(dst: str, src: str):
info_src: List[FileInfo] = []
if os.path.exists(meta_path_src):
- with open(meta_path_src"rb") as f:
- info_src = pickle.load(f)
+ info_src = _read_info_list(meta_path_src)
else:
raise ValueError("Dataset not exists")
info_dst: List[FileInfo] = []
if os.path.exists(meta_path_dst):
- with open(meta_path_dst"rb") as f:
- info_dst = pickle.load(f)
+ info_dst = _read_info_list(meta_path_dst)
else:
raise ValueError("Dataset not exists")
@@ -219,6 +226,7 @@ def merge_dataset(dst: strsrc: str):
v.nbytes,
v.nlines,
v.mask,
+ v.block_size,
)
)
curr_block += num_file_block
@@ -244,11 +252,9 @@ def merge_dataset(dst: strsrc: str):
v.nbytes,
v.nlines,
v.mask,
+ v.block_size,
)
)
curr_block += num_file_block
- random_fname = os.path.join(dst".meta.bin.%s" % _random_string())
- with open(random_fname"wb") as f:
- pickle.dump(nw_infof)
- os.rename(random_fnamemeta_path_dst)
+ _write_info_list(meta_path_dstnw_info)
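shuffle_dataset now frames each record in its temporary bucket files as a 4-byte unsigned length (struct format "I") followed by the raw payload, instead of a pickle stream. A self-contained sketch of that framing (illustrative only, not part of the patch):

import io
import struct

buf = io.BytesIO()
for payload in [b"first", b"second record"]:
    buf.write(struct.pack("I", len(payload)) + payload)

buf.seek(0)
records = []
while True:
    header = buf.read(4)
    if len(header) < 4:                     # EOF
        break
    (length,) = struct.unpack("I", header)
    records.append(buf.read(length))
assert records == [b"first", b"second record"]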
diff --git a/cpm-live/cpm_live/generation/__init__.py b/cpm-live/cpm_live/generation/__init__.py
new file mode 100644
index 0000000..70de125
--- /dev/null
+++ b/cpm-live/cpm_live/generation/__init__.py
@@ -0,0 +1 @@
+from .ant import CPMAntBeamSearch, CPMAntRandomSampling, CPMAntGeneration
diff --git a/cpm-live/cpm_live/generation/ant.py b/cpm-live/cpm_live/generation/ant.py
new file mode 100644
index 0000000..a9e5893
--- /dev/null
+++ b/cpm-live/cpm_live/generation/ant.py
@@ -0,0 +1,385 @@
+import torch
+import torch.nn.functional as F
+from .generation_utils import BeamHypotheses, apply_repetition_penalty, top_k_top_p_filtering
+from ..utils import pad
+
+
+class CPMAntGeneration:
+ def __init__(self, model, tokenizer, prompt_length=32):
+ model.eval()
+ self.model = model
+ self.tokenizer = tokenizer
+ self.prompt_length = prompt_length
+
+ def _convert_to_tensors(self, input_text, task_id=2):
+ model_inputs = {}
+ input_ids = [self.tokenizer.bos_id] + self.tokenizer.encode(input_text)
+ input_ids = [j for j in input_ids if j != self.tokenizer.unk_id]
+
+ model_inputs["input"] = [
+ x + self.prompt_length * task_id for x in range(self.prompt_length)
+ ] + input_ids
+ model_inputs["length"] = len(model_inputs["input"])
+ model_inputs["position"] = list(range(len(model_inputs["input"])))
+ model_inputs["span"] = [0] * len(model_inputs["input"])
+ model_inputs["context"] = [True] * len(model_inputs["input"])
+ model_inputs["segment"] = [0] * self.prompt_length + [2] * len(input_ids)
+
+ for key in model_inputs:
+ model_inputs[key] = torch.tensor(model_inputs[key]).int().unsqueeze(0)
+
+ return model_inputs
+
+ def _process_texts(self, text_list):
+ input_tensors = list(map(self._convert_to_tensors, text_list))
+ keys = set(input_tensors[0].keys())
+ padded = {}
+ for key in keys:
+ padded[key] = pad(input_tensors, key, padding_side='left').cuda()
+ return padded
+
+ def generate(self, text_list, **kwargs):
+ model_inputs = self._process_texts(text_list)
+ with torch.inference_mode():
+ result = self._decode(model_inputs, **kwargs)
+ return result
+
+ def _decode(self, model_inputs, **kwargs):
+ raise NotImplementedError("_decode is not implemented.")
+
+
+class CPMAntBeamSearch(CPMAntGeneration):
+ def _decode(
+ self,
+ model_inputs,
+ beam_size=3,
+ max_length=100,
+ repetition_penalty=1.0,
+ repetition_window=None,
+ **kwargs
+ ):
+ """
+ Beam search
+ Args:
+ model_inputs (dict): input ids.
+ beam_size (int, optional, defaults to 3): beam size of beam search.
+ generate_length (int, optional, defaults to 100): maximum generation length.
+ repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
+ repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
+ """ # noqa: E501
+ # generate_length + 1 for EOS token
+ max_length += 1
+
+ # expand dimension
+ batch_size = model_inputs["input"].size(0)
+ input = (
+ model_inputs["input"]
+ .unsqueeze(1)
+ .expand(batch_size, beam_size, -1)
+ .contiguous()
+ .view(batch_size * beam_size, -1)
+ )
+ length = (
+ model_inputs["length"]
+ .unsqueeze(1)
+ .expand(batch_size, beam_size)
+ .contiguous()
+ .view(
+ batch_size * beam_size,
+ )
+ )
+ context = (
+ model_inputs["context"]
+ .unsqueeze(1)
+ .expand(batch_size, beam_size, -1)
+ .contiguous()
+ .view(batch_size * beam_size, -1)
+ )
+ position = (
+ model_inputs["position"]
+ .unsqueeze(1)
+ .expand(batch_size, beam_size, -1)
+ .contiguous()
+ .view(batch_size * beam_size, -1)
+ )
+ segment = (
+ model_inputs["segment"]
+ .unsqueeze(1)
+ .expand(batch_size, beam_size, -1)
+ .contiguous()
+ .view(batch_size * beam_size, -1)
+ )
+ span = (
+ model_inputs["span"]
+ .unsqueeze(1)
+ .expand(batch_size, beam_size, -1)
+ .contiguous()
+ .view(batch_size * beam_size, -1)
+ )
+
+ done = [False for _ in range(batch_size)]
+
+ beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input.device)
+ beam_scores[:, 1:] = -1e9
+ beam_scores = beam_scores.view(-1)
+
+ # generated hypotheses
+ generated_hyps = [
+ BeamHypotheses(beam_size, max_length, length_penalty=1, early_stopping=False)
+ for _ in range(batch_size)
+ ]
+
+ pred_start_index = input.size(-1)
+ past_key_values = None
+ for i in range(max_length + 1):
+ if i == 0:
+ logits, past_key_values = self.model.inference(
+ input=input,
+ length=length,
+ context=context,
+ position=position,
+ segment=segment,
+ span=span,
+ past_key_values=past_key_values,
+ )
+ else:
+ logits, past_key_values = self.model.inference(
+ input=input[:, -1:],
+ length=length,
+ context=context,
+ position=position,
+ segment=segment,
+ span=span,
+ past_key_values=past_key_values,
+ )
+
+ # skip all steps when we are done with each sentence
+ if all(done):
+ break
+
+ # (batch * beam, seqlen, model_dim)
+ logits = logits[:, -1, :]
+
+ if i == 0:
+ logits[:, self.tokenizer.eos_id] = -float("inf")
+ logits[:, self.tokenizer.newline_id] = -float("inf")
+
+ apply_repetition_penalty(
+ logits,
+ batch_size,
+ beam_size,
+ input,
+ repetition_penalty,
+ pred_start_index,
+ input.size(-1) - 1,
+ repetition_window,
+ )
+ scores = F.log_softmax(logitsdim=-1)
+
+ next_scores = scores + beam_scores[:, None].expand_as(
+ scores
+ ) # (batch_size * beam_size, vocab_size)
+
+ # re-organize to group the beam together (we are keeping top hypothesis across beams)
+ next_scores = next_scores.view(batch_size, -1) # (batch_size, beam_size * vocab_size)
+ next_scores, next_words = torch.topk(
+ next_scores, 2 * beam_size, dim=1, largest=True, sorted=True
+ )
+
+ assert next_scores.size() == next_words.size() == (batch_size, 2 * beam_size)
+ next_batch_beam = []
+
+ for sent_id in range(batch_size):
+ # if we are done with this sentence
+ done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(
+ next_scores[sent_id].max().item(), i
+ )
+ if done[sent_id]:
+ next_batch_beam.extend(
+ [(0, self.tokenizer.pad_id, 0)] * beam_size
+ ) # pad the batch
+ continue
+
+ # next sentence beam content
+ next_sent_beam = []
+
+ # next words for this sentence
+ for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
+
+ # get beam and word IDs
+ beam_id = torch.div(idx, scores.size(-1), rounding_mode="floor")
+ word_id = idx % scores.size(-1)
+
+ # end of sentence, or next word
+ if word_id == self.tokenizer.eos_id or i == max_length:
+ generated_hyps[sent_id].add(
+ input[sent_id * beam_size + beam_id, pred_start_index:]
+ .clone()
+ .cpu()
+ .tolist(),
+ value.item(),
+ )
+ else:
+ next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
+
+ # the beam for next step is full
+ if len(next_sent_beam) == beam_size:
+ break
+
+ # update next beam content
+ assert len(next_sent_beam) == 0 if i == max_length else beam_size
+ if len(next_sent_beam) == 0:
+ next_sent_beam = [(0, self.tokenizer.pad_id, 0)] * beam_size # pad the batch
+ next_batch_beam.extend(next_sent_beam)
+ assert len(next_batch_beam) == beam_size * (sent_id + 1)
+
+ # we have reached the last step
+ if i == max_length:
+ break
+
+ # sanity check / prepare next batch
+ assert len(next_batch_beam) == batch_size * beam_size
+ beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
+ beam_words = input.new([x[1] for x in next_batch_beam])
+ beam_idx = length.new([x[2] for x in next_batch_beam]).long()
+
+ # re-order batch and internal states
+ input = input[beam_idx, :]
+
+ past_key_values = [list(each) if each is not None else each for each in past_key_values] # type: ignore # noqa: E501
+ for key_value_layer in past_key_values:
+ if key_value_layer is not None:
+ key_value_layer[0] = key_value_layer[0][beam_idx]
+ key_value_layer[1] = key_value_layer[1][beam_idx]
+
+ # update input ids
+ input = torch.cat([input, beam_words.unsqueeze(1)], dim=-1)
+ length += 1
+ context = torch.cat(
+ [context, torch.ones((context.size(0), 1), dtype=torch.int, device=context.device)],
+ dim=-1,
+ )
+ position = torch.cat([position, position[:, -1:] + 1], dim=-1)
+ segment = torch.cat(
+ [segment, segment[:, -1:]], dim=-1
+ ) # segment id always the same as the previous token
+ span = torch.cat([span, span[:, -1:]], dim=-1)
+
+ # select the best hypotheses
+ results = []
+ for i, hypotheses in enumerate(generated_hyps):
+ best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
+ results.append(best_hyp)
+
+ result_text = list(map(self.tokenizer.decode, results))
+ return result_text
+
+
+class CPMAntRandomSampling(CPMAntGeneration):
+ def _decode(
+ self,
+ model_inputs,
+ max_length=100,
+ top_k=0,
+ top_p=0.9,
+ temperature=0.9,
+ repetition_penalty=1.0,
+ repetition_window=None,
+ **kwargs
+ ):
+ """
+ Top-k and top-p sampling.
+ Args:
+ model_inputs (dict): input ids
+ generate_length (int, optional, defaults to 100): maximum generation length
+ top_k (int, optional, defaults to 0): keep only top k tokens with highest probability. 0 means keeping all tokens.
+ top_p (int, optional, defaults to 0.9): keep the top tokens with cumulative probability >= top_p.
+ temperature (int, optional, defaults to 0.9): the value that can cool down the logits distribution.
+ repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
+ repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
+ """ # noqa: E501
+ # generate_length + 1 for EOS token
+ max_length += 1
+
+ input = model_inputs["input"]
+ length = model_inputs["length"]
+ context = model_inputs["context"]
+ position = model_inputs["position"]
+ segment = model_inputs["segment"]
+ span = model_inputs["span"]
+ batch_size = input.size(0)
+
+ pred_start_index = input.size(-1)
+ past_key_values = None
+ done = [False for _ in range(batch_size)]
+ results = [None for _ in range(batch_size)]
+ for i in range(max_length):
+ if i == 0:
+ logits, past_key_values = self.model.inference(
+ input=input,
+ length=length,
+ context=context,
+ position=position,
+ segment=segment,
+ span=span,
+ past_key_values=past_key_values,
+ )
+ else:
+ logits, past_key_values = self.model.inference(
+ input=input[:, -1:],
+ length=length,
+ context=context,
+ position=position,
+ segment=segment,
+ span=span,
+ past_key_values=past_key_values,
+ )
+
+ logits = logits[:, -1, :]
+
+ if i == 0:
+ logits[:, self.tokenizer.eos_id] = -float("inf")
+ logits[:, self.tokenizer.newline_id] = -float("inf")
+
+ apply_repetition_penalty(
+ logits,
+ batch_size,
+ 1,
+ input,
+ repetition_penalty,
+ pred_start_index,
+ input.size(-1) - 1,
+ repetition_window,
+ )
+
+ logits = logits / temperature
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+
+ probs = F.softmax(logits, dim=-1)
+ next_token = torch.multinomial(probs, num_samples=1)
+
+ for idx in range(batch_size):
+ if not done[idx] and (
+ next_token[idx].item() == self.tokenizer.eos_id or i == max_length - 1
+ ):
+ done[idx] = True
+ results[idx] = input[idx, pred_start_index:].clone().cpu().tolist() # type: ignore # noqa: E501
+
+ if sum(done) == batch_size:
+ break
+
+ # update input ids
+ input = torch.cat([input, next_token], dim=-1)
+ length += 1
+ context = torch.cat(
+ [context, torch.ones((context.size(0), 1), dtype=torch.int, device=context.device)],
+ dim=-1,
+ )
+ position = torch.cat([position, position[:, -1:] + 1], dim=-1)
+ segment = torch.cat(
+ [segment, segment[:, -1:]], dim=-1
+ ) # segment id always the same as the previous token
+ span = torch.cat([span, span[:, -1:]], dim=-1)
+
+ result_text = list(map(self.tokenizer.decode, results))
+ return result_text
diff --git a/cpm-live/cpm_live/generation/bee.py b/cpm-live/cpm_live/generation/bee.py
new file mode 100644
index 0000000..5120b1f
--- /dev/null
+++ b/cpm-live/cpm_live/generation/bee.py
@@ -0,0 +1,629 @@
+from typing import Any, Dict, List, Tuple
+import numpy as np
+import torch
+import torch.nn.functional as F
+from .generation_utils import BeamHypotheses, apply_repetition_penalty
+from ..tokenizers.bee import CPMBeeTokenizer
+from ..models.bee import CPMBee
+from ..training_tasks.bee.pretrain import convert_data_to_id
+from ..utils import pad
+
+
+class CPMBeeGeneration:
+ def __init__(self, model: CPMBee, tokenizer: CPMBeeTokenizer):
+ model.eval()
+ self.model = model
+ self.tokenizer = tokenizer
+
+ def _convert_to_tensors(self, data: Any, in_context_samples: List[Any] = []):
+ answer_placeholders = []
+
+ def _put_placeholder(data: Any, path: List[str] = []):
+ if isinstance(data, dict):
+ ret = {}
+ for k, v in data.items():
+ ret[k] = _put_placeholder(v, path + [k])
+ return ret
+ else:
+ answer_placeholders.append(path)
+ return "
".format(len(answer_placeholders))
+
+ data[""] = _put_placeholder(data[""])
+ (
+ input_ids,
+ input_id_subs,
+ context,
+ segment_ids,
+ segment_rel,
+ n_segments,
+ table_states,
+ ) = convert_data_to_id(self.tokenizer, data, shuffle_answer=False, max_depth=8)
+
+ sub_ans_map: Dict[int, int] = {}
+ for fake_id, token_sub in table_states["token_id_table"]["<ans>"].items():
+ token = table_states["ext_table"][fake_id]
+ if token.startswith("<ans_"):
+ ans_id = int(token[5:-1])
+ sub_ans_map[token_sub] = ans_id
+
+ tmp_input_ids = []
+ tmp_input_sub = []
+ tmp_input_seg = []
+
+ predict_segments: List[Tuple[int, int]] = []
+ for i in range(input_ids.shape[0]):
+ if context[i] == 0:
+ if input_ids[i] == self.tokenizer.encoder[""]:
+ # is ans
+ # (segment_idans_id)
+ predict_segments.append((segment_ids[i], sub_ans_map[input_id_subs[i]]))
+ else:
+ tmp_input_ids.append(input_ids[i])
+ tmp_input_sub.append(input_id_subs[i])
+ tmp_input_seg.append(segment_ids[i])
+
+ if len(predict_segments) == 0:
+ raise ValueError("No answer to predict")
+
+ input_ids = np.array(tmp_input_ids, dtype=np.int32)
+ input_id_subs = np.array(tmp_input_sub, dtype=np.int32)
+ context = np.full_like(tmp_input_ids, 1, dtype=np.int8)
+ segment_ids = np.array(tmp_input_seg, dtype=np.int32)
+ sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
+ segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32)
+ num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32)
+
+ for i, sample in enumerate(in_context_samples):
+ (
+ sample_input_ids,
+ sample_id_subs,
+ _,
+ sample_segments,
+ sample_rel,
+ n_segments,
+ table_states,
+ ) = convert_data_to_id(self.tokenizer, sample, table_states, max_depth=8)
+ input_ids = np.concatenate([input_ids, sample_input_ids], axis=0)
+ input_id_subs = np.concatenate([input_id_subs, sample_id_subs], axis=0)
+ context = np.concatenate(
+ [context, np.ones(sample_input_ids.shape, dtype=np.int8)], axis=0
+ )
+ segment_ids = np.concatenate([segment_ids, sample_segments], axis=0)
+ segment_rel_offset = np.concatenate(
+ [
+ segment_rel_offset,
+ np.full(sample_input_ids.shape, segment_rel.shape[0], dtype=np.int32),
+ ],
+ axis=0,
+ )
+ segment_rel = np.concatenate([segment_rel, sample_rel], axis=0)
+ sample_ids = np.concatenate(
+ [sample_ids, np.full(sample_input_ids.shape, i + 1, dtype=np.int32)], axis=0
+ )
+ num_segments = np.concatenate(
+ [num_segments, np.full(sample_input_ids.shape, n_segments, dtype=np.int32)], axis=0
+ )
+ input_pos = np.arange(input_ids.shape[0], dtype=np.int32)
+
+ return (
+ input_ids,
+ input_id_subs,
+ input_pos,
+ context,
+ segment_ids,
+ segment_rel_offset,
+ segment_rel,
+ sample_ids,
+ num_segments,
+ predict_segments,
+ answer_placeholders,
+ table_states["ext_table"],
+ table_states["token_id_table"],
+ )
+
+ def _process_list(self, data_list: List[Any]):
+ pack_tensor = []
+ other_info = []
+ segment_rel_pack = []
+
+ batch_ext_table_map: Dict[Tuple[int, int], int] = {}
+ batch_ext_table_ids: List[int] = []
+ batch_ext_table_sub: List[int] = []
+
+ for data in data_list:
+ (
+ input_ids,
+ input_id_subs,
+ input_pos,
+ context,
+ segment_ids,
+ segment_rel_offset,
+ segment_rel,
+ sample_ids,
+ num_segments,
+ predict_segments,
+ answer_placeholders,
+ ext_table,
+ token_id_table,
+ ) = self._convert_to_tensors(data, [])
+ rev_ext_table: Dict[int, str] = {}
+ for token, mp in token_id_table.items():
+ if token == "<ans>":
+ continue
+ token_id = self.tokenizer.encoder[token]
+ for fake_id, token_sub in mp.items():
+ if token_sub > 0:
+ if (token_id, token_sub) not in batch_ext_table_map:
+ batch_ext_table_map[(token_id, token_sub)] = (
+ len(batch_ext_table_ids) + self.tokenizer.vocab_size
+ )
+ batch_ext_table_ids.append(token_id)
+ batch_ext_table_sub.append(token_sub)
+ rev_ext_table[batch_ext_table_map[(token_id, token_sub)]] = ext_table[
+ fake_id
+ ]
+ else:
+ rev_ext_table[token_id] = ext_table[fake_id]
+ pack_tensor.append(
+ {
+ "input": torch.from_numpy(input_ids).unsqueeze(0),
+ "input_sub": torch.from_numpy(input_id_subs).unsqueeze(0),
+ "input_pos": torch.from_numpy(input_pos).unsqueeze(0),
+ "context": torch.from_numpy(context).unsqueeze(0),
+ "sample_idx": torch.from_numpy(sample_ids).unsqueeze(0),
+ "num_segments": torch.from_numpy(num_segments).unsqueeze(0),
+ "segment": torch.from_numpy(segment_ids).unsqueeze(0),
+ "segment_rel_offset": torch.from_numpy(segment_rel_offset).unsqueeze(0),
+ }
+ )
+ segment_rel_pack.append(torch.from_numpy(segment_rel))
+ other_info.append(
+ {
+ "predict_segments": predict_segments,
+ "answer_placeholders": answer_placeholders,
+ "ext_table": rev_ext_table,
+ }
+ )
+
+ keys = set(pack_tensor[0].keys())
+ padded = {}
+ for key in keys:
+ padded[key] = pad(pack_tensor, key).cuda()
+
+ max_num_rels = 0
+ for rel in segment_rel_pack:
+ max_num_rels = max(max_num_rels, rel.size(0))
+ padded_rels = torch.zeros(len(segment_rel_pack), max_num_rels, dtype=torch.int32)
+ for i, rel in enumerate(segment_rel_pack):
+ padded_rels[i, : rel.size(0)] = rel
+ padded["segment_rel"] = padded_rels.cuda()
+ padded["batch_ext_table_ids"] = torch.tensor(
+ batch_ext_table_idsdtype=torch.int32device="cuda"
+ )
+ padded["batch_ext_table_sub"] = torch.tensor(
+ batch_ext_table_subdtype=torch.int32device="cuda"
+ )
+ return paddedother_info
+
+ def generate(self, data_list, **kwargs):
+ model_inputs, other_info = self._process_list(data_list)
+ with torch.inference_mode():
+ result_ids = self._decode(model_inputs, other_info, **kwargs)
+ for sent_id, result in enumerate(result_ids):
+ ans_result_map: Dict[int, List[int]] = {}
+ for raw_word_id, ans_id in result:
+ if ans_id not in ans_result_map:
+ ans_result_map[ans_id] = []
+ ans_result_map[ans_id].append(raw_word_id)
+
+ answer_placeholders = other_info[sent_id]["answer_placeholders"]
+ ext_table = other_info[sent_id]["ext_table"]
+ data = data_list[sent_id]
+ for ans_idtoken_ids in ans_result_map.items():
+ if token_ids[-1] == self.tokenizer.eos_id:
+ token_ids = token_ids[:-1]
+ text = self.tokenizer.decode(token_idsext_table)
+ path = answer_placeholders[ans_id - 1]
+
+ if len(path) > 0:
+ p = data[""]
+ for part in path[:-1]:
+ p = p[part]
+ p[path[-1]] = text
+ else:
+ data[""] = text
+ for ans_id in range(len(answer_placeholders)):
+ if (ans_id + 1) not in ans_result_map:
+ path = answer_placeholders[ans_id]
+ p = data[""]
+ for part in path[:-1]:
+ p = p[part]
+ p[path[-1]] = None
+ return data_list
+
+ def _decode(selfmodel_inputsother_info**kwargs):
+ raise NotImplementedError("_decode is not implemented.")
+
+
+class CPMBeeBeamSearch(CPMBeeGeneration):
+ def _decode(
+ self,
+ model_inputs,
+ other_info,
+ beam_size=3,
+ max_length=100,
+ repetition_penalty=1.0,
+ repetition_window=None,
+ ):
+ """
+ Beam search
+ Args:
+ model_inputs (dict): input ids.
+ beam_size (intoptionaldefaults to 3): beam size of beam search.
+ generate_length (intoptionaldefaults to 100): maximum generation length.
+ repetition_penalty (floatoptionaldefaults to 1.0): repetition penalty coefficient1.0 means no penalty.
+ repetition_window (intoptionaldefaults to None): window size of repetition penaltyNone means that all output tokens are penalized.
+ """ # noqa: E501
+ # generate_length + 1 for EOS token
+ max_length += 1
+
+        # expand dimension
+ batch_size = model_inputs["input"].size(0)
+        input: torch.Tensor = (
+            model_inputs["input"]
+            .unsqueeze(1)
+            .expand(batch_size, beam_size, -1)
+            .contiguous()
+            .view(batch_size * beam_size, -1)
+        )
+        input_sub: torch.Tensor = (
+            model_inputs["input_sub"]
+            .unsqueeze(1)
+            .expand(batch_size, beam_size, -1)
+            .contiguous()
+            .view(batch_size * beam_size, -1)
+        )
+        input_pos: torch.Tensor = (
+            model_inputs["input_pos"]
+            .unsqueeze(1)
+            .expand(batch_size, beam_size, -1)
+            .contiguous()
+            .view(batch_size * beam_size, -1)
+        )
+        context: torch.Tensor = (
+            model_inputs["context"]
+            .unsqueeze(1)
+            .expand(batch_size, beam_size, -1)
+            .contiguous()
+            .view(batch_size * beam_size, -1)
+        )
+        sample_ids: torch.Tensor = (
+            model_inputs["sample_idx"]
+            .unsqueeze(1)
+            .expand(batch_size, beam_size, -1)
+            .contiguous()
+            .view(batch_size * beam_size, -1)
+        )
+        num_segments: torch.Tensor = (
+            model_inputs["num_segments"]
+            .unsqueeze(1)
+            .expand(batch_size, beam_size, -1)
+            .contiguous()
+            .view(batch_size * beam_size, -1)
+        )
+        segment: torch.Tensor = (
+            model_inputs["segment"]
+            .unsqueeze(1)
+            .expand(batch_size, beam_size, -1)
+            .contiguous()
+            .view(batch_size * beam_size, -1)
+        )
+        segment_rel_offset: torch.Tensor = (
+            model_inputs["segment_rel_offset"]
+            .unsqueeze(1)
+            .expand(batch_size, beam_size, -1)
+            .contiguous()
+            .view(batch_size * beam_size, -1)
+        )
+        segment_rel: torch.Tensor = (
+            model_inputs["segment_rel"]
+            .unsqueeze(1)
+            .expand(batch_size, beam_size, -1)
+            .contiguous()
+            .view(batch_size * beam_size, -1)
+        )
+ ext_table_ids: torch.Tensor = model_inputs["batch_ext_table_ids"]
+ ext_table_sub: torch.Tensor = model_inputs["batch_ext_table_sub"]
+ ext_table_ids_cpu = ext_table_ids.cpu()
+ ext_table_sub_cpu = ext_table_sub.cpu()
+
+ done = [False for _ in range(batch_size)]
+
+        beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input.device)
+        beam_scores[:, 1:] = -1e9
+        beam_scores = beam_scores.view(-1)
+
+        # generated hypotheses
+        generated_hyps = [
+            BeamHypotheses(beam_size, max_length, length_penalty=1, early_stopping=False)
+ for _ in range(batch_size)
+ ]
+
+ pred_start_index = input.size(-1)
+        _, _, past_key_values = self.model.inference(
+ input=input,
+ input_sub=input_sub,
+ position=input_pos,
+ context=context,
+ sample_ids=sample_ids,
+ num_segments=num_segments,
+ segment=segment,
+ segment_rel_offset=segment_rel_offset,
+ segment_rel=segment_rel,
+ ext_table_ids=ext_table_ids,
+ ext_table_sub=ext_table_sub,
+ past_key_values=None,
+ )
+
+ beam_states = []
+ for sent_id in range(batch_size):
+ instance_beam_states = []
+
+ for beam_id in range(beam_size):
+ instance_beam_states.append(
+ {
+ "idx": 0,
+ "ans": [],
+ "nx_token_id": self.tokenizer.bos_id,
+ "nx_token_sub": 0,
+ "nx_segment_id": other_info[sent_id]["predict_segments"][0][0],
+ "nx_position": 0,
+ }
+ )
+ beam_states.append(instance_beam_states)
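+        # Each entry of `beam_states` tracks, per hypothesis: `idx` (which predict
+        # segment is currently being filled), `ans` (tokens generated so far as
+        # (token_id, answer_id) pairs), and the next token / sub-token / segment /
+        # position to feed to the model on the following step.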
+ for i in range(max_length + 1):
+ tmp_input = []
+ tmp_input_sub = []
+ tmp_position = []
+ tmp_segment = []
+ for sent_id in range(batch_size):
+ for beam_id in range(beam_size):
+ tmp_input.append(beam_states[sent_id][beam_id]["nx_token_id"])
+ tmp_input_sub.append(beam_states[sent_id][beam_id]["nx_token_sub"])
+ tmp_position.append(beam_states[sent_id][beam_id]["nx_position"])
+ tmp_segment.append(beam_states[sent_id][beam_id]["nx_segment_id"])
+ with torch.no_grad():
+                input = torch.cat(
+                    [
+                        input,
+                        torch.tensor(tmp_input, dtype=torch.int32, device="cuda").view(
+                            batch_size * beam_size, 1
+                        ),
+                    ],
+                    dim=-1,
+                )
+                logits, _, past_key_values = self.model.inference(
+                    input=input[:, -1:],
+                    input_sub=torch.tensor(tmp_input_sub, dtype=torch.int32, device="cuda").view(
+                        batch_size * beam_size, 1
+                    ),
+                    position=torch.tensor(tmp_position, dtype=torch.int32, device="cuda").view(
+                        batch_size * beam_size, 1
+                    ),
+                    context=torch.ones(
+                        batch_size * beam_size, dtype=torch.bool, device="cuda"
+                    ).view(batch_size * beam_size, 1),
+                    sample_ids=torch.zeros(
+                        batch_size * beam_size, dtype=torch.int32, device="cuda"
+                    ).view(batch_size * beam_size, 1),
+                    num_segments=num_segments[:, -1:],
+                    segment=torch.tensor(tmp_segment, dtype=torch.int32, device="cuda").view(
+                        batch_size * beam_size, 1
+                    ),
+                    segment_rel_offset=segment_rel_offset[:, -1:],
+                    segment_rel=segment_rel,
+                    ext_table_ids=ext_table_ids,
+                    ext_table_sub=ext_table_sub,
+                    past_key_values=past_key_values,
+                )
+            logits = logits[:, -1, :]
+
+ # skip all steps when we are done with each sentence
+ if all(done):
+ break
+
+ for sent_id in range(batch_size):
+ if self.tokenizer.unk_id not in other_info[sent_id]["ext_table"]:
+                    # unk is not allowed, mask unk
+                    logits[
+                        sent_id * beam_size : (sent_id + 1) * beam_size, self.tokenizer.unk_id
+                    ] = -10000
+                ext_ids = set()
+                for v in other_info[sent_id]["ext_table"].keys():
+                    ext_ids.add(v)
+                for ext_id in range(
+                    self.tokenizer.vocab_size, self.tokenizer.vocab_size + ext_table_ids.size(0)
+                ):
+                    if ext_id not in ext_ids:
+                        logits[sent_id * beam_size : (sent_id + 1) * beam_size, ext_id] = -10000
+
+ apply_repetition_penalty(
+ logits,
+ batch_size,
+ beam_size,
+ input,
+ repetition_penalty,
+ pred_start_index,
+ input.size(-1) - 1,
+ repetition_window,
+ )
+            scores = F.log_softmax(logits, dim=-1)
+            next_scores = scores + beam_scores[:, None].expand_as(
+                scores
+            ) # (batch_size * beam_size, vocab_size)
+
+            # re-organize to group the beam together (we are keeping top hypothesis across beams)
+            next_scores = next_scores.view(batch_size, -1) # (batch_size, beam_size * vocab_size)
+            next_scores, next_words = torch.topk(
+                next_scores, 2 * beam_size, dim=1, largest=True, sorted=True
+            )
+            assert next_scores.size() == next_words.size() == (batch_size, 2 * beam_size)
+ next_beam_states = []
+
+ for sent_id in range(batch_size):
+ # if we are done with this sentence
+ done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(
+                    next_scores[sent_id].max().item(), i
+ )
+ if done[sent_id]:
+ next_beam_states.append(
+ [
+ (
+ {
+ "idx": 0,
+ "ans": [],
+ "nx_token_id": 0,
+ "nx_token_sub": 0,
+ "nx_segment_id": 0,
+ "nx_position": 0,
+ },
+ 0,
+ 0,
+ )
+ ]
+ * beam_size
+ ) # pad the batch
+ continue
+
+ # next sentence beam content
+ next_instance_beam_states = []
+
+ # next words for this sentence
+                for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
+
+                    # get beam and word IDs
+                    beam_id = torch.div(idx, scores.size(-1), rounding_mode="floor").item()
+                    word_id = (idx % scores.size(-1)).item()
+
+                    curr_info = beam_states[sent_id][beam_id]
+                    # end of sentence, or next word
+ if (
+ word_id == self.tokenizer.eos_id
+ and (curr_info["idx"] + 1 == len(other_info[sent_id]["predict_segments"]))
+ ) or i == max_length:
+ generated_hyps[sent_id].add(
+ beam_states[sent_id][beam_id]["ans"]
+ + [
+ (
+ word_id,
+ other_info[sent_id]["predict_segments"][curr_info["idx"]][1],
+ )
+ ],
+ value.item(),
+ )
+ elif word_id == self.tokenizer.eos_id:
+ next_instance_beam_states.append(
+ (
+ {
+ "idx": curr_info["idx"] + 1,
+ "ans": curr_info["ans"]
+ + [
+ (
+ word_id,
+ other_info[sent_id]["predict_segments"][
+ curr_info["idx"]
+ ][1],
+ )
+ ],
+ "nx_token_id": self.tokenizer.bos_id,
+ "nx_token_sub": 0,
+ "nx_segment_id": other_info[sent_id]["predict_segments"][
+ curr_info["idx"] + 1
+ ][0],
+ "nx_position": 0,
+ },
+ value.item(),
+ sent_id * beam_size + beam_id,
+ )
+ )
+
+ else:
+ raw_word_id = word_id
+ word_id_sub = 0
+ if word_id >= self.tokenizer.vocab_size:
+ word_id -= self.tokenizer.vocab_size
+ word_id_sub = int(ext_table_sub_cpu[word_id].item())
+ word_id = int(ext_table_ids_cpu[word_id].item())
+
+ next_instance_beam_states.append(
+ (
+ {
+ "idx": curr_info["idx"],
+ "ans": curr_info["ans"]
+ + [
+ (
+ raw_word_id,
+ other_info[sent_id]["predict_segments"][
+ curr_info["idx"]
+ ][1],
+ )
+ ],
+ "nx_token_id": word_id,
+ "nx_token_sub": word_id_sub,
+ "nx_segment_id": curr_info["nx_segment_id"],
+ "nx_position": curr_info["nx_position"] + 1,
+ },
+ value.item(),
+ sent_id * beam_size + beam_id,
+ )
+ )
+
+ # the beam for next step is full
+ if len(next_instance_beam_states) == beam_size:
+ break
+
+ # update next beam content
+ assert len(next_instance_beam_states) == 0 if i == max_length else beam_size
+ next_beam_states.append(next_instance_beam_states)
+
+ # we have reached the last step
+ if i == max_length:
+ break
+
+ # sanity check / prepare next batch
+ beam_reorder_idx = []
+ beam_new_scores = []
+ beam_states = []
+ for sent_id in range(batch_size):
+ instance_beam_states = []
+ for beam_id in range(beam_size):
+                state, value, beam_idx = next_beam_states[sent_id][beam_id]
+ beam_reorder_idx.append(beam_idx)
+ beam_new_scores.append(value)
+ instance_beam_states.append(state)
+ beam_states.append(instance_beam_states)
+
+            input = input[beam_reorder_idx, :]
+            beam_scores = torch.tensor(beam_new_scores, dtype=torch.float, device=input.device)
+ for kw in past_key_values.keys():
+ if kw == "buffer":
+ buf_list = past_key_values[kw]
+ nw_buf_list = []
+ for buf in buf_list:
+ if buf is None:
+                            nw_buf_list.append((None, None))
+                        else:
+                            k_buf, v_buf = buf
+                            nw_buf_list.append(
+                                (k_buf[beam_reorder_idx, :], v_buf[beam_reorder_idx, :])
+                            )
+                past_key_values[kw] = nw_buf_list
+            else:
+                past_key_values[kw] = past_key_values[kw][beam_reorder_idx, :]
+
+ # select the best hypotheses
+ results = []
+        for sent_id, hypotheses in enumerate(generated_hyps):
+            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
+ results.append(best_hyp)
+ return results
diff --git a/cpm-live/cpm_live/generation/generation_utils.py b/cpm-live/cpm_live/generation/generation_utils.py
new file mode 100644
index 0000000..07867d8
--- /dev/null
+++ b/cpm-live/cpm_live/generation/generation_utils.py
@@ -0,0 +1,112 @@
+import torch
+import torch.nn.functional as F
+
+
+def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
+    # This function has been mostly taken from huggingface conversational ai code at
+    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    batch_size = logits.size()[0]
+    if top_p > 0.0:
+        logits = logits.view(batch_size, -1).contiguous()
+        for index in range(len(logits)):
+
+            sorted_logits, sorted_indices = torch.sort(logits[index].view(-1), descending=True)
+            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+            # Remove tokens with cumulative probability above the threshold
+            sorted_indices_to_remove = cumulative_probs > top_p
+            # Shift the indices to the right to keep also the first token above the threshold
+            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+            sorted_indices_to_remove[..., 0] = 0
+            indices_to_remove = sorted_indices[sorted_indices_to_remove]
+            logits[index][indices_to_remove] = filter_value
+
+        logits = logits.view(batch_size, -1).contiguous()
+
+ return logits
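+
+# Minimal usage sketch (illustrative values):
+#   logits = torch.randn(2, 32000)                       # (batch, vocab)
+#   filtered = top_k_top_p_filtering(logits, top_k=50, top_p=0.9)
+#   probs = F.softmax(filtered, dim=-1)
+#   next_token = torch.multinomial(probs, num_samples=1)
+# Tokens outside the top-k / top-p nucleus keep -inf logits and thus zero probability.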
+
+
+def apply_repetition_penalty(
+ logits,
+ batch_size,
+ num_beams,
+ prev_output_tokens,
+ repetition_penalty,
+ start_idx=None,
+ end_idx=None,
+ window_size=None,
+):
+ # only conduct repetition penalty for the output
+    assert repetition_penalty >= 1, "repetition penalty coefficient should be >= 1"
+ # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
+ for i in range(batch_size * num_beams):
+ if start_idx is None or end_idx is None:
+ output_tokens = prev_output_tokens[i].tolist()
+ else:
+ if end_idx >= start_idx:
+ if window_size:
+ output_tokens = prev_output_tokens[i][
+                        max(start_idx, end_idx + 1 - window_size) : end_idx + 1
+ ].tolist()
+ else:
+ output_tokens = prev_output_tokens[i][start_idx : end_idx + 1].tolist()
+ else:
+ output_tokens = []
+ for previous_token in set(output_tokens):
+            # if score < 0 then repetition penalty has to be
+            # multiplied to reduce the previous token probability
+            if logits[i, previous_token] < 0:
+                logits[i, previous_token] *= repetition_penalty
+            else:
+                logits[i, previous_token] /= repetition_penalty
+
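+# Worked example of the CTRL-style penalty above: with repetition_penalty=1.2, a
+# previously generated token with logit 2.0 is rescaled to 2.0 / 1.2 (about 1.67),
+# while one with logit -2.0 becomes -2.0 * 1.2 = -2.4; in both cases the repeated
+# token becomes less likely on the next step.
+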
+
+class BeamHypotheses:
+    def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
+ """
+ Initialize n-best list of hypotheses.
+ """
+ self.max_len = max_len
+ self.length_penalty = length_penalty
+ self.early_stopping = early_stopping
+ self.n_hyp = n_hyp
+ self.hyp = []
+ self.worst_score = 1e9
+
+ def __len__(self):
+ """
+ Number of hypotheses in the list.
+ """
+ return len(self.hyp)
+
+    def add(self, hyp, sum_logprobs):
+ """
+ Add a new hypothesis to the list.
+ """
+ score = sum_logprobs / len(hyp) ** self.length_penalty
+
+ if len(self) < self.n_hyp or score > self.worst_score:
+            self.hyp.append((score, hyp))
+            if len(self) > self.n_hyp:
+                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
+                del self.hyp[sorted_scores[0][1]]
+                self.worst_score = sorted_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs, cur_len):
+        """
+        If there are enough hypotheses and that none of the hypotheses being generated
+        can become better than the worst one in the heap, then we are done with this sentence.
+ """
+ if len(self) < self.n_hyp:
+ return False
+ elif self.early_stopping:
+ return True
+ else:
+ return self.worst_score >= best_sum_logprobs / cur_len**self.length_penalty
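+
+
+# Illustrative use (mirrors how CPMBeeBeamSearch consumes this class):
+#   hyps = BeamHypotheses(n_hyp=3, max_len=100, length_penalty=1, early_stopping=False)
+#   hyps.add(hyp=[(token_id, ans_id), ...], sum_logprobs=-4.2)
+#   if hyps.is_done(best_sum_logprobs=-3.0, cur_len=10): ...
+#   best = max(hyps.hyp, key=lambda x: x[0])[1]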
diff --git a/cpm-live/cpm_live/layers/__init__.py b/cpm-live/cpm_live/layers/__init__.py
index 4ff22c1..dbd8dcd 100644
--- a/cpm-live/cpm_live/layers/__init__.py
+++ b/cpm-live/cpm_live/layers/__init__.py
@@ -1,5 +1,5 @@
-from .embedding import Embedding
-from .position_embedding import SegmentPositionEmbedding
+from .embedding import Embedding, EmbeddingExt
+from .position_embedding import SegmentPositionEmbedding, BucketPositionBias, RotaryEmbedding
from .linear import Linear
from .layernorm import LayerNorm
from .attention import Attention
diff --git a/cpm-live/cpm_live/layers/attention.py b/cpm-live/cpm_live/layers/attention.py
index c227037..241b14e 100644
--- a/cpm-live/cpm_live/layers/attention.py
+++ b/cpm-live/cpm_live/layers/attention.py
@@ -72,8 +72,8 @@ def forward(
len_q = hidden_q.size(1)
len_k = hidden_kv.size(1)
- h_q = self.project_q(hidden_q)
- h_k = self.project_k(hidden_kv)
+ h_q = self.project_q(hidden_q) / math.sqrt(math.sqrt(self.dim_head))
+ h_k = self.project_k(hidden_kv) / math.sqrt(math.sqrt(self.dim_head))
h_v = self.project_v(hidden_kv)
h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
@@ -86,7 +86,9 @@ def forward(
len_k = h_k.size(-2)
# (bn_hlen_qd_h) @ (bn_hd_hlen_k) -> (bn_hlen_qlen_k)
- score = torch.matmul(h_q, h_k.transpose(-1, -2)) / math.sqrt(self.dim_head)
+ score = torch.matmul(
+ h_q, h_k.transpose(-1, -2)
+ ) # / math.sqrt(self.dim_head) moved to line 75~76
score = score + position_bias
score = torch.masked_fill(
score,
diff --git a/cpm-live/cpm_live/layers/blocks.py b/cpm-live/cpm_live/layers/blocks.py
index e5db92e..a16abf6 100644
--- a/cpm-live/cpm_live/layers/blocks.py
+++ b/cpm-live/cpm_live/layers/blocks.py
@@ -91,7 +91,7 @@ def forward(
if self.dropout is not None:
x = self.dropout(x)
- hidden_states = hidden_states + x
+ hidden_states = (hidden_states + x) / 1.05
if use_cache:
return hidden_statescurrent_key_value
@@ -154,7 +154,7 @@ def forward(
x = self.ffn(x)
if self.dropout is not None:
x = self.dropout(x)
- hidden_states = hidden_states + x
+ hidden_states = (hidden_states + x) / 1.05
return hidden_states
diff --git a/cpm-live/cpm_live/layers/embedding.py b/cpm-live/cpm_live/layers/embedding.py
index 5b805b0..ffc305f 100644
--- a/cpm-live/cpm_live/layers/embedding.py
+++ b/cpm-live/cpm_live/layers/embedding.py
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
import torch
import bmtrain as bmt
import math
import torch.nn.functional as F
+from .position_embedding import RotaryEmbedding
class Embedding(bmt.DistributedModule):
@@ -60,3 +62,56 @@ def projection(selfx: torch.Tensor):
""" # noqa: E501
logits = F.linear(x / math.sqrt(self.dim_model), self.weight)
return logits
+
+
+class EmbeddingExt(bmt.DistributedModule):
+ def __init__(
+ self,
+ vocab_size: int,
+ embedding_size: int,
+ dtype: torch.dtype = torch.half,
+ init_mean: float = 0.0,
+ init_std: float = 1,
+ distance_scale: int = 16,
+ ):
+
+ super().__init__()
+
+ self.dim_model = embedding_size
+ self.rotary_emb = RotaryEmbedding(
+            dim=embedding_size, distance_scale=distance_scale, dtype=dtype
+        )
+
+        self.weight = bmt.DistributedParameter(
+            torch.empty(vocab_size, embedding_size, dtype=dtype),
+            init_method=bmt.ParameterInitializer(
+                torch.nn.init.normal_, mean=init_mean, std=init_std
+            ),
+        )
+
+    def forward(self, ids: torch.Tensor, ids_sub: torch.Tensor):
+        """
+        Args:
+            ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
+            ids_sub (:obj:`torch.Tensor` of shape ``(batch_size)``): Subscript of input sequence tokens.
+        Return:
+            :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
+        """ # noqa: E501
+
+        embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
+        return self.rotary_emb(embeds, ids_sub)
+
+    def projection(self, x: torch.Tensor, ext_table: Optional[torch.Tensor] = None):
+        """
+        Projection based on embedding's weight. For example, the embedding maps vocab_size to embed_size, then the projection maps embed_size back to vocab_size.
+        Args:
+            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
+            ext_table (:obj:`torch.Tensor` of shape ``(ext_table_size, dim_model)``): Ext vocab table.
+        Returns:
+            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_size + ext_table_size)``: The projection output.
+        """ # noqa: E501
+        logits = F.linear(x / math.sqrt(self.dim_model), self.weight)
+        if ext_table is not None:
+            logits_ext = F.linear(x, ext_table)
+            logits = torch.cat([logits, logits_ext], dim=-1)
+ return logits
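+
+
+# Note on the tied projection above: logits over the base vocabulary come from the
+# shared embedding weight, and the optional `ext_table` (embeddings of out-of-vocab
+# tokens collected per batch) is concatenated on the right, giving scores of shape
+# (batch, seq_len, vocab_size + ext_table_size).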
diff --git a/cpm-live/cpm_live/layers/feedforward.py b/cpm-live/cpm_live/layers/feedforward.py
index 8dd3c23..c015cf3 100644
--- a/cpm-live/cpm_live/layers/feedforward.py
+++ b/cpm-live/cpm_live/layers/feedforward.py
@@ -101,7 +101,7 @@ def __init__(
dim_in=dim_ff,
dim_out=dim_model,
dtype=dtype,
- scale_before=True,
+ scale_before=False,
)
def forward(selfx: torch.Tensor):
diff --git a/cpm-live/cpm_live/layers/position_embedding.py b/cpm-live/cpm_live/layers/position_embedding.py
index 790b527..2cd49f2 100644
--- a/cpm-live/cpm_live/layers/position_embedding.py
+++ b/cpm-live/cpm_live/layers/position_embedding.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
+from typing import Union
import torch
import bmtrain as bmt
import torch.nn.functional as F
@@ -21,12 +22,12 @@
class SegmentPositionEmbedding(bmt.DistributedModule):
def __init__(
self,
- num_heads,
- num_segments=1,
- num_buckets=32,
- max_distance=128,
- bidirectional=False,
- dtype=torch.half,
+ num_heads: int,
+ num_segments: int = 1,
+ num_buckets: int = 32,
+ max_distance: int = 128,
+ bidirectional: bool = False,
+ dtype: torch.dtype = torch.half,
init_mean: float = 0.0,
init_std: float = 1,
):
@@ -125,3 +126,129 @@ def _position_bucket(
is_small, relative_position.to(torch.int32), relative_postion_if_large
)
return relative_buckets
+
+
+class BucketPositionBias(bmt.DistributedModule):
+ def __init__(
+ self,
+ num_heads: int,
+ num_buckets: int = 32,
+ num_segment_bucket: int = 32,
+ max_distance: int = 128,
+ dtype: torch.dtype = torch.half,
+ init_mean: float = 0.0,
+ init_std: float = 1,
+ ) -> None:
+ super().__init__()
+
+ self.num_heads = num_heads
+ self.num_buckets = num_buckets
+ self.num_segment_bucket = num_segment_bucket
+ self.max_distance = max_distance
+
+ self.relative_attention_bias = bmt.DistributedParameter(
+            torch.empty(num_buckets + num_segment_bucket, num_heads, dtype=dtype),
+            init_method=bmt.ParameterInitializer(
+                torch.nn.init.normal_, mean=init_mean, std=init_std
+            ),
+        )
+
+    def forward(
+        self,
+        query_pos: torch.Tensor,  # (batch, len_q)
+        key_pos: torch.Tensor,  # (batch, len_k)
+        rel_buckets: torch.Tensor,  # (batch, len_q, len_k)
+ ):
+ with torch.no_grad():
+
+ batch = key_pos.size(0)
+ keylen = key_pos.size(1)
+ querylen = query_pos.size(1)
+
+ assert key_pos.size(0) == query_pos.size(0)
+ assert (
+ rel_buckets.size(0) == batch
+ and rel_buckets.size(1) == querylen
+ and rel_buckets.size(2) == keylen
+ )
+
+            relative_position_bucket = rel_buckets - 1 + self.num_buckets # does not overlap with the relative-position buckets
+
+ # b*q*k
+ inner_segment_bucket = self._position_bucket(
+                key_pos[..., None, :] - query_pos[..., :, None],
+ num_buckets=self.num_buckets,
+ max_distance=self.max_distance,
+ )
+ relative_position_bucket = torch.where(
+ rel_buckets == 0,
+ inner_segment_bucket,
+ relative_position_bucket,
+ )
+            # (batch, len_q, len_k)
+
+        # (batch, len_q, len_k, num_heads)
+        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
+        # (batch, num_heads, len_q, len_k)
+        embeds = embeds.permute(0, 3, 1, 2).contiguous()
+        return embeds
+
+    def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
+ relative_buckets = 0
+ num_buckets //= 2
+ relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
+ relative_position = torch.abs(relative_position)
+
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+ relative_postion_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.int32)
+ relative_postion_if_large = torch.min(
+ relative_postion_if_large,
+            torch.full_like(relative_postion_if_large, num_buckets - 1),
+ )
+ relative_buckets += torch.where(
+            is_small, relative_position.to(torch.int32), relative_postion_if_large
+ )
+ return relative_buckets
+
+
+class RotaryEmbedding(bmt.DistributedModule):
+ def __init__(
+ self,
+ dim,
+ base=10000,
+        distance_scale: Union[int, float] = 1,
+ dtype: torch.dtype = torch.half,
+ ):
+ super().__init__()
+ inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim)
+ )
+ inv_freq = inv_freq.to(dtype)
+ self.distance_scale = distance_scale
+ self.dtype = dtype
+ self.inv_freq = inv_freq
+
+    def forward(self, x: torch.Tensor, x_pos: torch.Tensor):
+        """
+        Args:
+            x (:obj:`torch.Tensor` of shape ``(..., dim)``): Inputs.
+            x_pos (:obj:`torch.Tensor` of shape ``(...)``): Positions of inputs.
+        """
+        x_pos = x_pos * self.distance_scale
+        freqs = x_pos[..., None].to(self.dtype) * self.inv_freq[None, :] # (..., dim/2)
+
+        # the same implementation as sat
+        emb = torch.cat((freqs, freqs), dim=-1) # (..., dim)
+        emb_cos = emb.cos() # (..., dim)
+        emb_sin = emb.sin() # (..., dim)
+
+        rotate_x = torch.cat(
+            [-x[..., x.size(-1) // 2 :], x[..., : x.size(-1) // 2]], dim=-1
+        ) # (..., dim)
+
+ return x * emb_cos + rotate_x * emb_sin
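+
+
+# The return above is the usual rotary identity: out = x * cos(pos * inv_freq) +
+# rotate_half(x) * sin(pos * inv_freq), where positions are first scaled by
+# `distance_scale` and rotate_half(x) concatenates (-x[second half], x[first half]).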
diff --git a/cpm-live/cpm_live/models/__init__.py b/cpm-live/cpm_live/models/__init__.py
index d1e4033..ed9ed68 100644
--- a/cpm-live/cpm_live/models/__init__.py
+++ b/cpm-live/cpm_live/models/__init__.py
@@ -1 +1,4 @@
from .ant import CPMAntConfig, CPMAnt
+from .bee import CPMBeeConfig, CPMBee
+from .ant_torch import CPMAntTorch
+from .bee_torch import CPMBeeTorch
diff --git a/cpm-live/cpm_live/models/ant.py b/cpm-live/cpm_live/models/ant.py
index dc246f6..8d2f11c 100644
--- a/cpm-live/cpm_live/models/ant.py
+++ b/cpm-live/cpm_live/models/ant.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import List, Optional, Tuple
import torch
from ..utils import Config
@@ -37,6 +38,7 @@ def __init__(
prompt_length: int = 32,
segment_types: int = 32,
mask_modules: Optional[List[Tuple[bool, bool]]] = None,
+ **kwargs,
):
super().__init__()
@@ -77,12 +79,12 @@ def __init__(selfconfig: CPMAntConfig):
mask_modules=config.mask_modules,
)
- # self.prompt_embedding = Embedding(
- # vocab_size=config.prompt_types * config.prompt_length,
- # embedding_size=config.dim_model,
- # dtype=config.dtype,
- # init_std=0.02,
- # )
+ self.prompt_embedding = Embedding(
+ vocab_size=config.prompt_types * config.prompt_length,
+ embedding_size=config.dim_model,
+ dtype=config.dtype,
+ init_std=0.02,
+ )
self.segment_embedding = Embedding(
vocab_size=config.segment_types,
@@ -92,7 +94,7 @@ def __init__(selfconfig: CPMAntConfig):
)
self.input_embedding = Embedding(
- vocab_size=config.vocab_size + config.prompt_length * config.prompt_types,
+ vocab_size=config.vocab_size,
embedding_size=config.dim_model,
dtype=config.dtype,
init_std=0.02,
@@ -117,8 +119,49 @@ def forward(
position: torch.Tensor # (batchseqlen)
segment: torch.Tensor # (batchseqlen)
span: torch.Tensor # (batchseqlen)
+ ):
+
+ batch = input.size(0)
+ seqlen = input.size(1)
+ input_prompt = input[:: self.prompt_length].contiguous()
+ input_ids = input[:self.prompt_length :].contiguous()
+
+ prompt_states = self.prompt_embedding(input_prompt)
+ hidden_states = self.input_embedding(input_ids)
+ segment_states = self.segment_embedding(segment)
+ hidden_states = torch.cat([prompt_stateshidden_states]1) + segment_states
+
+ with torch.no_grad():
+ device = input.device
+ directional_mask_2d = torch.arange(seqlendevice=device) <= torch.arange(
+ seqlendevice=device
+ ).view(-11)
+ attention_mask = context[:None:] | (
+ context[::None].logical_not() & directional_mask_2d.view(1seqlenseqlen)
+ )
+ attention_mask = attention_mask & (span[:None:] == span[::None])
+ mask_1d = (
+ torch.arange(seqlendevice=device)[None:].repeat(batch1) < length[:None]
+ )
+ attention_mask = (
+ mask_1d.view(batchseqlen1) & mask_1d.view(batch1seqlen) & attention_mask
+ )
+
+ position_bias = self.position_bias(positionpositionsegmentsegment)
+ hidden_states = self.encoder(hidden_statesattention_maskposition_bias)
+
+ logits = self.input_embedding.projection(hidden_states)
+ return logitshidden_states
+
+ def inference(
+ self,
+ input: torch.Tensor # (batchseqlen)
+ length: torch.Tensor # (batch)
+ context: torch.Tensor # (batchseqlen)
+ position: torch.Tensor # (batchseqlen)
+ segment: torch.Tensor # (batchseqlen)
+ span: torch.Tensor # (batchseqlen)
past_key_values=None # num_layers * 2 * (batchnum_headsseqlendim_head)
- use_cache=False,
):
batch = input.size(0)
@@ -126,11 +169,13 @@ def forward(
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * self.encoder.num_layers)
- input_ids = input.contiguous()
+ input_prompt = input[:: self.prompt_length].contiguous()
+ input_ids = input[:self.prompt_length :].contiguous()
+ prompt_states = self.prompt_embedding(input_prompt)
hidden_states = self.input_embedding(input_ids)
segment_states = self.segment_embedding(segment)
- hidden_states = hidden_states + segment_states
+ hidden_states = torch.cat([prompt_stateshidden_states]1) + segment_states
else:
past_length = past_key_values[0][0].size(-2)
@@ -148,8 +193,10 @@ def forward(
context[::None].logical_not() & directional_mask_2d.view(1seqlenseqlen)
)
attention_mask = attention_mask & (span[:None:] == span[::None])
+ # mask for left paddding
mask_1d = (
- torch.arange(seqlendevice=device)[None:].repeat(batch1) < length[:None]
+ torch.tensor(list(range(seqlen))[::-1]device=device)[None:].repeat(batch1)
+ < length[:None]
)
attention_mask = (
mask_1d.view(batchseqlen1) & mask_1d.view(batch1seqlen) & attention_mask
@@ -157,19 +204,11 @@ def forward(
position_bias = self.position_bias(positionpositionsegmentsegment)
- if past_length > 0:
- attention_mask = attention_mask[:past_length::]
- position_bias = position_bias[::past_length::]
+ attention_mask = attention_mask[:past_length::]
+ position_bias = position_bias[::past_length::]
- if use_cache:
- hidden_statespresent_key_values = self.encoder(
- hidden_statesattention_maskposition_biasuse_cachepast_key_values
- )
- logits = self.input_embedding.projection(hidden_states)
- return logitshidden_statespresent_key_values
- else:
- hidden_states = self.encoder(
- hidden_statesattention_maskposition_biasuse_cachepast_key_values
- )
- logits = self.input_embedding.projection(hidden_states)
- return logitshidden_states
+ hidden_statespresent_key_values = self.encoder(
+ hidden_statesattention_maskposition_biasTruepast_key_values
+ )
+ logits = self.input_embedding.projection(hidden_states)
+ return logitshidden_statespresent_key_values
diff --git a/cpm-live/cpm_live/models/ant_torch.py b/cpm-live/cpm_live/models/ant_torch.py
new file mode 100644
index 0000000..44a23bf
--- /dev/null
+++ b/cpm-live/cpm_live/models/ant_torch.py
@@ -0,0 +1,169 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache LicenseVersion 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writingsoftware
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from ..native_layers import EncoderEmbeddingSegmentPositionEmbedding
+from .ant import CPMAntConfig
+
+
+class CPMAntTorch(torch.nn.Module):
+ def __init__(selfconfig: CPMAntConfig):
+
+ super().__init__()
+
+ self.encoder = Encoder(
+ num_layers=config.num_layers,
+ dim_model=config.dim_model,
+ dim_ff=config.dim_ff,
+ num_heads=config.num_heads,
+ dim_head=config.dim_head,
+ dtype=config.dtype,
+ eps=config.eps,
+ dropout_p=config.dropout_p,
+ mask_modules=config.mask_modules,
+ )
+
+ self.prompt_embedding = Embedding(
+ vocab_size=config.prompt_types * config.prompt_length,
+ embedding_size=config.dim_model,
+ dtype=config.dtype,
+ init_std=0.02,
+ )
+
+ self.segment_embedding = Embedding(
+ vocab_size=config.segment_types,
+ embedding_size=config.dim_model,
+ dtype=config.dtype,
+ init_std=0.02,
+ )
+
+ self.input_embedding = Embedding(
+ vocab_size=config.vocab_size,
+ embedding_size=config.dim_model,
+ dtype=config.dtype,
+ init_std=0.02,
+ )
+
+ self.position_bias = SegmentPositionEmbedding(
+ num_heads=config.num_heads,
+ num_segments=config.segment_types,
+ num_buckets=config.position_bias_num_buckets,
+ max_distance=config.position_bias_max_distance,
+ bidirectional=True,
+ dtype=config.dtype,
+ )
+
+ self.prompt_length = config.prompt_length
+
+ def forward(
+ self,
+ input: torch.Tensor # (batchseqlen)
+ length: torch.Tensor # (batch)
+ context: torch.Tensor # (batchseqlen)
+ position: torch.Tensor # (batchseqlen)
+ segment: torch.Tensor # (batchseqlen)
+ span: torch.Tensor # (batchseqlen)
+ ):
+
+ batch = input.size(0)
+ seqlen = input.size(1)
+ input_prompt = input[:: self.prompt_length].contiguous()
+ input_ids = input[:self.prompt_length :].contiguous()
+
+ prompt_states = self.prompt_embedding(input_prompt)
+ hidden_states = self.input_embedding(input_ids)
+ segment_states = self.segment_embedding(segment)
+ hidden_states = torch.cat([prompt_stateshidden_states]1) + segment_states
+
+ with torch.no_grad():
+ device = input.device
+ directional_mask_2d = torch.arange(seqlendevice=device) <= torch.arange(
+ seqlendevice=device
+ ).view(-11)
+ attention_mask = context[:None:] | (
+ context[::None].logical_not() & directional_mask_2d.view(1seqlenseqlen)
+ )
+ attention_mask = attention_mask & (span[:None:] == span[::None])
+ mask_1d = (
+ torch.arange(seqlendevice=device)[None:].repeat(batch1) < length[:None]
+ )
+ attention_mask = (
+ mask_1d.view(batchseqlen1) & mask_1d.view(batch1seqlen) & attention_mask
+ )
+
+ position_bias = self.position_bias(positionpositionsegmentsegment)
+ hidden_states = self.encoder(hidden_statesattention_maskposition_bias)
+
+ logits = self.input_embedding.projection(hidden_states)
+ return logitshidden_states
+
+ def inference(
+ self,
+ input: torch.Tensor # (batchseqlen)
+ length: torch.Tensor # (batch)
+ context: torch.Tensor # (batchseqlen)
+ position: torch.Tensor # (batchseqlen)
+ segment: torch.Tensor # (batchseqlen)
+ span: torch.Tensor # (batchseqlen)
+ past_key_values=None # num_layers * 2 * (batchnum_headsseqlendim_head)
+ ):
+
+ batch = input.size(0)
+
+ if past_key_values is None:
+ past_length = 0
+ past_key_values = tuple([None] * self.encoder.num_layers)
+ input_prompt = input[:: self.prompt_length].contiguous()
+ input_ids = input[:self.prompt_length :].contiguous()
+
+ prompt_states = self.prompt_embedding(input_prompt)
+ hidden_states = self.input_embedding(input_ids)
+ segment_states = self.segment_embedding(segment)
+ hidden_states = torch.cat([prompt_stateshidden_states]1) + segment_states
+
+ else:
+ past_length = past_key_values[0][0].size(-2)
+ segment_states = self.segment_embedding(segment)
+ hidden_states = self.input_embedding(input) + segment_states[:-1::]
+
+ seqlen = past_length + input.size(1)
+
+ with torch.no_grad():
+ device = input.device
+ directional_mask_2d = torch.arange(seqlendevice=device) <= torch.arange(
+ seqlendevice=device
+ ).view(-11)
+ attention_mask = context[:None:] | (
+ context[::None].logical_not() & directional_mask_2d.view(1seqlenseqlen)
+ )
+ attention_mask = attention_mask & (span[:None:] == span[::None])
+ # mask for left paddding
+ mask_1d = (
+ torch.tensor(list(range(seqlen))[::-1]device=device)[None:].repeat(batch1)
+ < length[:None]
+ )
+ attention_mask = (
+ mask_1d.view(batchseqlen1) & mask_1d.view(batch1seqlen) & attention_mask
+ )
+
+ position_bias = self.position_bias(positionpositionsegmentsegment)
+
+ attention_mask = attention_mask[:past_length::]
+ position_bias = position_bias[::past_length::]
+
+ hidden_statespresent_key_values = self.encoder(
+ hidden_statesattention_maskposition_biasTruepast_key_values
+ )
+ logits = self.input_embedding.projection(hidden_states)
+ return logitshidden_statespresent_key_values
diff --git a/cpm-live/cpm_live/models/bee.py b/cpm-live/cpm_live/models/bee.py
new file mode 100644
index 0000000..098ecc6
--- /dev/null
+++ b/cpm-live/cpm_live/models/bee.py
@@ -0,0 +1,291 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache LicenseVersion 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writingsoftware
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple
+from typing_extensions import TypedDict
+import torch
+
+from ..utils import Config
+from ..layers import Encoder, EmbeddingExt, BucketPositionBias
+import bmtrain as bmt
+from ..utils.gradient_shrink import gradient_shrink
+
+
+class CPMBeeInferenceState(TypedDict):
+ buffer_position: torch.Tensor
+ buffer_context: torch.Tensor
+ buffer_sample_ids: torch.Tensor
+ buffer_num_segments: torch.Tensor
+ buffer_segments: torch.Tensor
+    buffer: List[Tuple[torch.Tensor, torch.Tensor]]
+
+
+class CPMBeeConfig(Config):
+ def __init__(
+ self,
+ vocab_size=30720,
+ dim_model=4096,
+ num_heads=64,
+ dim_head=64,
+ dim_ff=10240,
+ num_layers=32,
+ dropout_p=0.0,
+ position_bias_num_buckets=256,
+ position_bias_num_segment_buckets=256,
+ position_bias_max_distance=2048,
+ eps=1e-6,
+ half: bool = True,
+        mask_modules: Optional[List[Tuple[bool, bool]]] = None,
+ ):
+
+ super().__init__()
+ self.dim_model = dim_model
+ self.num_heads = num_heads
+ self.dim_head = dim_head
+ self.dim_ff = dim_ff
+ self.num_layers = num_layers
+ self.position_bias_num_buckets = position_bias_num_buckets
+ self.position_bias_num_segment_buckets = position_bias_num_segment_buckets
+ self.position_bias_max_distance = position_bias_max_distance
+ self.dropout_p = dropout_p
+ self.eps = eps
+ if half:
+ self.dtype = torch.half
+ else:
+ self.dtype = torch.float
+ self.vocab_size = vocab_size
+ self.mask_modules = mask_modules
+
+
+class CPMBee(bmt.DistributedModule):
+    def __init__(self, config: CPMBeeConfig):
+
+ super().__init__()
+
+ self.encoder = Encoder(
+ num_layers=config.num_layers,
+ dim_model=config.dim_model,
+ dim_ff=config.dim_ff,
+ num_heads=config.num_heads,
+ dim_head=config.dim_head,
+ dtype=config.dtype,
+ eps=config.eps,
+ dropout_p=config.dropout_p,
+ mask_modules=config.mask_modules,
+ )
+
+ self.input_embedding = EmbeddingExt(
+ vocab_size=config.vocab_size,
+ embedding_size=config.dim_model,
+ dtype=config.dtype,
+ init_std=0.02,
+ )
+
+ self.position_bias = BucketPositionBias(
+ num_heads=config.num_heads,
+ num_buckets=config.position_bias_num_buckets,
+ num_segment_bucket=config.position_bias_num_segment_buckets,
+ max_distance=config.position_bias_max_distance,
+ dtype=config.dtype,
+ )
+
+ def forward(
+ self,
+        input: torch.Tensor,  # (batch, seqlen) int32
+        input_sub: torch.Tensor,  # (batch, seqlen) int32
+        length: torch.Tensor,  # (batch) int32
+        context: torch.Tensor,  # (batch, seqlen) bool
+        sample_ids: torch.Tensor,  # (batch, seq_len) int32
+        num_segments: torch.Tensor,  # (batch, seq_len) int32
+        segment: torch.Tensor,  # (batch, seqlen) int32
+        segment_rel_offset: torch.Tensor,  # (batch, seq_len) int32
+        segment_rel: torch.Tensor,  # (batch, num_segment_bucket) int32
+        span: torch.Tensor,  # (batch, seqlen) int32
+        ext_table_ids: torch.Tensor,  # (ext_table_size) int32
+        ext_table_sub: torch.Tensor,  # (ext_table_size) int32
+    ):
+        batch = input.size(0)
+        seqlen = input.size(1)
+        # processing masks and position bias bucket
+        with torch.no_grad():
+            device = input.device
+
+            # calc segment bucket
+            segment_rel_2d = torch.masked_fill(
+                segment[:, :, None] * num_segments[:, :, None]
+                + segment[:, None, :]
+                + segment_rel_offset[:, :, None],
+                ~(
+                    (sample_ids[:, :, None] == sample_ids[:, None, :])
+                    & (span[:, None, :] == span[:, :, None])
+                ),  # not in the same span or sample
+                0,  # avoid torch.gather overflow
+            ).view(batch, seqlen * seqlen)
+
+            segment_bucket = torch.gather(
+                input=segment_rel,
+                dim=1,
+                index=segment_rel_2d.long(),
+            ).view(batch, seqlen, seqlen)
+
+            segment_bucket.masked_fill_(
+                ~(
+                    (sample_ids[:, :, None] == sample_ids[:, None, :])
+                    & (span[:, None, :] == span[:, :, None])
+                ),  # not in the same span or sample
+                1,  # bucket is used for in-context samples
+            )
+
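+            # Worked example (illustrative): for a query token in segment i attending
+            # to a key token in segment j of the same sample, the flat index is
+            # i * num_segments + j + segment_rel_offset, which is gathered from
+            # `segment_rel` to obtain the learned relative-segment bucket; pairs from
+            # different samples or spans fall back to bucket 1 instead.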
+            # directional mask
+            directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(
+                seqlen, device=device
+            ).view(-1, 1)
+            # sample mask
+            sample_mask_2d = (sample_ids[:, :, None] == 0) | (
+                sample_ids[:, :, None] == sample_ids[:, None, :]
+            )
+            # context mask
+            attention_mask = context[:, None, :] | (
+                context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
+            )
+            # span mask
+            attention_mask = (
+                attention_mask & sample_mask_2d & (span[:, None, :] == span[:, :, None])
+            )
+            # length mask
+            mask_1d = (
+                torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
+            )
+            attention_mask = (
+                mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
+            )
+            position = torch.arange(seqlen, device=device).expand(batch, seqlen)
+
+        hidden_states = self.input_embedding(input, input_sub)
+        position_bias = self.position_bias(position, position, segment_bucket)
+
+        hidden_states = self.encoder(hidden_states, attention_mask, position_bias)
+
+        ext_table = self.input_embedding(ext_table_ids, ext_table_sub)
+
+        logits = self.input_embedding.projection(hidden_states, ext_table)
+        return logits, hidden_states
+
+ def inference(
+ self,
+ input: torch.Tensor # (batchlen_q) int32
+ input_sub: torch.Tensor # (batchlen_q) int32
+ position: torch.Tensor # (batchlen_q) int32
+ context: torch.Tensor # (batchlen_q) bool
+ sample_ids: torch.Tensor # (batchlen_q) int32
+ num_segments: torch.Tensor # (batchlen_q) int32
+ segment: torch.Tensor # (batchlen_q) int32
+ segment_rel_offset: torch.Tensor # (batchlen_q) int32
+ segment_rel: torch.Tensor # (batchnum_segment_bucket) int32
+ ext_table_ids: torch.Tensor # (ext_table_size) int32
+ ext_table_sub: torch.Tensor # (ext_table_size) int32
+ past_key_values: Optional[CPMBeeInferenceState] = None,
+ ) -> Tuple[torch.Tensortorch.TensorCPMBeeInferenceState]:
+ with torch.no_grad():
+ if past_key_values is None:
+ present_position = position
+ present_context = context
+ present_sample_ids = sample_ids
+ present_num_segments = num_segments
+ present_segments = segment
+ present_buffer = None
+ else:
+ present_position = torch.cat([past_key_values["buffer_position"]position]dim=-1)
+ present_context = torch.cat([past_key_values["buffer_context"]context]dim=-1)
+ present_sample_ids = torch.cat(
+ [past_key_values["buffer_sample_ids"]sample_ids]dim=-1
+ )
+ present_num_segments = torch.cat(
+ [past_key_values["buffer_num_segments"]num_segments]dim=-1
+ )
+ present_segments = torch.cat([past_key_values["buffer_segments"]segment]dim=-1)
+ present_buffer = past_key_values["buffer"]
+
+ batch = input.size(0)
+ len_q = input.size(1)
+ len_buffer = present_position.size(1)
+
+ segment_rel_2d = torch.masked_fill(
+ segment[::None] * num_segments[::None]
+ + present_segments[:None:]
+ + segment_rel_offset[::None],
+ ~(
+ (sample_ids[::None] == present_sample_ids[:None:])
+ ) # not in the same sample
+ 0 # avoid torch.gather overflow
+ ).view(batchlen_q * len_buffer)
+
+ segment_bucket = torch.gather(
+ input=segment_rel,
+ dim=1,
+ index=segment_rel_2d.long(),
+ ).view(batchlen_qlen_buffer)
+
+ segment_bucket.masked_fill_(
+ ~(
+ (sample_ids[::None] == present_sample_ids[:None:])
+ ) # not in the same span or sample
+ 1 # bucket is used for in-context samples
+ )
+
+ # directional mask
+ directional_mask_2d = present_position[:None:] <= position[::None]
+ # sample mask
+ sample_mask_2d = (sample_ids[::None] == 0) | (
+ sample_ids[::None] == present_sample_ids[:None:]
+ )
+ # context mask
+ attention_mask = present_context[:None:] | (
+ context[::None].logical_not()
+ & directional_mask_2d.view(batchlen_qlen_buffer)
+ )
+ # span mask
+ attention_mask = attention_mask & sample_mask_2d
+ # length mask
+ mask_1d = present_num_segments != 0
+ attention_mask = mask_1d.view(batch1len_buffer) & attention_mask
+
+ hidden_states = gradient_shrink(self.input_embedding(inputinput_sub))
+
+ position_bias = gradient_shrink(
+ self.position_bias(positionpresent_positionsegment_bucket)
+ )
+ hidden_statespresent_key_values = self.encoder(
+ hidden_states,
+ attention_mask,
+ position_bias,
+ True,
+ present_buffer,
+ )
+ ext_table = gradient_shrink(self.input_embedding(ext_table_idsext_table_sub))
+ logits = self.input_embedding.projection(hidden_statesext_table)
+
+ return (
+ logits,
+ hidden_states,
+ {
+ "buffer_position": present_position,
+ "buffer_context": present_context,
+ "buffer_sample_ids": present_sample_ids,
+ "buffer_num_segments": present_num_segments,
+ "buffer_segments": present_segments,
+ "buffer": present_key_values,
+ },
+ )
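+
+
+# Typical incremental decoding (sketch; argument tensors as annotated above):
+#   logits, hidden, state = model.inference(..., past_key_values=None)      # prefill
+#   logits, hidden, state = model.inference(next_inputs, ..., past_key_values=state)
+# The returned dict carries the per-layer key/value buffers plus the bookkeeping
+# tensors (positions, context, sample ids, segments) needed to extend the masks.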
diff --git a/cpm-live/cpm_live/models/bee_torch.py b/cpm-live/cpm_live/models/bee_torch.py
new file mode 100644
index 0000000..a2e1489
--- /dev/null
+++ b/cpm-live/cpm_live/models/bee_torch.py
@@ -0,0 +1,240 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache LicenseVersion 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writingsoftware
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import OptionalTuple
+import torch
+
+from ..native_layers import EncoderEmbeddingExtBucketPositionBias
+from .bee import CPMBeeConfigCPMBeeInferenceState
+
+
+class CPMBeeTorch(torch.nn.Module):
+ def __init__(selfconfig: CPMBeeConfig):
+
+ super().__init__()
+
+ self.encoder = Encoder(
+ num_layers=config.num_layers,
+ dim_model=config.dim_model,
+ dim_ff=config.dim_ff,
+ num_heads=config.num_heads,
+ dim_head=config.dim_head,
+ dtype=config.dtype,
+ eps=config.eps,
+ dropout_p=config.dropout_p,
+ mask_modules=config.mask_modules,
+ )
+
+ self.input_embedding = EmbeddingExt(
+ vocab_size=config.vocab_size,
+ embedding_size=config.dim_model,
+ dtype=config.dtype,
+ init_std=0.02,
+ )
+
+ self.position_bias = BucketPositionBias(
+ num_heads=config.num_heads,
+ num_buckets=config.position_bias_num_buckets,
+ num_segment_bucket=config.position_bias_num_segment_buckets,
+ max_distance=config.position_bias_max_distance,
+ dtype=config.dtype,
+ )
+
+ def forward(
+ self,
+ input: torch.Tensor # (batchseqlen) int32
+ input_sub: torch.Tensor # (batchseqlen) int32
+ length: torch.Tensor # (batch) int32
+ context: torch.Tensor # (batchseqlen) bool
+ sample_ids: torch.Tensor # (batchseq_len) int32
+ num_segments: torch.Tensor # (batchseq_len) int32
+ segment: torch.Tensor # (batchseqlen) int32
+ segment_rel_offset: torch.Tensor # (batchseq_len) int32
+ segment_rel: torch.Tensor # (batchnum_segment_bucket) int32
+ span: torch.Tensor # (batchseqlen) int32
+ ext_table_ids: torch.Tensor # (ext_table_size) int32
+ ext_table_sub: torch.Tensor # (ext_table_size) int32
+ ):
+ batch = input.size(0)
+ seqlen = input.size(1)
+ # processing masks and position bias bucket
+ with torch.no_grad():
+ device = input.device
+
+ # calc segment bucket
+ segment_rel_2d = torch.masked_fill(
+ segment[::None] * num_segments[::None]
+ + segment[:None:]
+ + segment_rel_offset[::None],
+ ~(
+ (sample_ids[::None] == sample_ids[:None:])
+ & (span[:None:] == span[::None])
+ ) # not in the same span or sample
+ 0 # avoid torch.gather overflow
+ ).view(batchseqlen * seqlen)
+
+ segment_bucket = torch.gather(
+ input=segment_rel,
+ dim=1,
+ index=segment_rel_2d.long(),
+ ).view(batchseqlenseqlen)
+
+ segment_bucket.masked_fill_(
+ ~(
+ (sample_ids[::None] == sample_ids[:None:])
+ & (span[:None:] == span[::None])
+ ) # not in the same span or sample
+ 1 # bucket is used for in-context samples
+ )
+
+ # directional mask
+ directional_mask_2d = torch.arange(seqlendevice=device) <= torch.arange(
+ seqlendevice=device
+ ).view(-11)
+ # sample mask
+ sample_mask_2d = (sample_ids[::None] == 0) | (
+ sample_ids[::None] == sample_ids[:None:]
+ )
+ # context mask
+ attention_mask = context[:None:] | (
+ context[::None].logical_not() & directional_mask_2d.view(1seqlenseqlen)
+ )
+ # span mask
+ attention_mask = (
+ attention_mask & sample_mask_2d & (span[:None:] == span[::None])
+ )
+ # length mask
+ mask_1d = (
+ torch.arange(seqlendevice=device)[None:].repeat(batch1) < length[:None]
+ )
+ attention_mask = (
+ mask_1d.view(batchseqlen1) & mask_1d.view(batch1seqlen) & attention_mask
+ )
+ position = torch.arange(seqlendevice=device).expand(batchseqlen)
+
+ hidden_states = self.input_embedding(inputinput_sub)
+ position_bias = self.position_bias(positionpositionsegment_bucket)
+
+ hidden_states = self.encoder(hidden_statesattention_maskposition_bias)
+
+ ext_table = self.input_embedding(ext_table_idsext_table_sub)
+
+ logits = self.input_embedding.projection(hidden_statesext_table)
+ return logitshidden_states
+
+ def inference(
+ self,
+ input: torch.Tensor # (batchlen_q) int32
+ input_sub: torch.Tensor # (batchlen_q) int32
+ position: torch.Tensor # (batchlen_q) int32
+ context: torch.Tensor # (batchlen_q) bool
+ sample_ids: torch.Tensor # (batchlen_q) int32
+ num_segments: torch.Tensor # (batchlen_q) int32
+ segment: torch.Tensor # (batchlen_q) int32
+ segment_rel_offset: torch.Tensor # (batchlen_q) int32
+ segment_rel: torch.Tensor # (batchnum_segment_bucket) int32
+ ext_table_ids: torch.Tensor # (ext_table_size) int32
+ ext_table_sub: torch.Tensor # (ext_table_size) int32
+ past_key_values: Optional[CPMBeeInferenceState] = None,
+ ) -> Tuple[torch.Tensortorch.TensorCPMBeeInferenceState]:
+ with torch.no_grad():
+ if past_key_values is None:
+ present_position = position
+ present_context = context
+ present_sample_ids = sample_ids
+ present_num_segments = num_segments
+ present_segments = segment
+ present_buffer = None
+ else:
+ present_position = torch.cat([past_key_values["buffer_position"]position]dim=-1)
+ present_context = torch.cat([past_key_values["buffer_context"]context]dim=-1)
+ present_sample_ids = torch.cat(
+ [past_key_values["buffer_sample_ids"]sample_ids]dim=-1
+ )
+ present_num_segments = torch.cat(
+ [past_key_values["buffer_num_segments"]num_segments]dim=-1
+ )
+ present_segments = torch.cat([past_key_values["buffer_segments"]segment]dim=-1)
+ present_buffer = past_key_values["buffer"]
+
+ batch = input.size(0)
+ len_q = input.size(1)
+ len_buffer = present_position.size(1)
+
+ segment_rel_2d = torch.masked_fill(
+ segment[::None] * num_segments[::None]
+ + present_segments[:None:]
+ + segment_rel_offset[::None],
+ ~(
+ (sample_ids[::None] == present_sample_ids[:None:])
+ ) # not in the same sample
+ 0 # avoid torch.gather overflow
+ ).view(batchlen_q * len_buffer)
+
+ segment_bucket = torch.gather(
+ input=segment_rel,
+ dim=1,
+ index=segment_rel_2d.long(),
+ ).view(batchlen_qlen_buffer)
+
+ segment_bucket.masked_fill_(
+ ~(
+ (sample_ids[::None] == present_sample_ids[:None:])
+ ) # not in the same span or sample
+ 1 # bucket is used for in-context samples
+ )
+
+ # directional mask
+ directional_mask_2d = present_position[:None:] <= position[::None]
+ # sample mask
+ sample_mask_2d = (sample_ids[::None] == 0) | (
+ sample_ids[::None] == present_sample_ids[:None:]
+ )
+ # context mask
+ attention_mask = present_context[:None:] | (
+ context[::None].logical_not()
+ & directional_mask_2d.view(batchlen_qlen_buffer)
+ )
+ # span mask
+ attention_mask = attention_mask & sample_mask_2d
+ # length mask
+ mask_1d = present_num_segments != 0
+ attention_mask = mask_1d.view(batch1len_buffer) & attention_mask
+
+ hidden_states = self.input_embedding(inputinput_sub)
+
+ position_bias = self.position_bias(positionpresent_positionsegment_bucket)
+ hidden_statespresent_key_values = self.encoder(
+ hidden_states,
+ attention_mask,
+ position_bias,
+ True,
+ present_buffer,
+ )
+ ext_table = self.input_embedding(ext_table_idsext_table_sub)
+ logits = self.input_embedding.projection(hidden_statesext_table)
+
+ return (
+ logits,
+ hidden_states,
+ {
+ "buffer_position": present_position,
+ "buffer_context": present_context,
+ "buffer_sample_ids": present_sample_ids,
+ "buffer_num_segments": present_num_segments,
+ "buffer_segments": present_segments,
+ "buffer": present_key_values,
+ },
+ )
diff --git a/cpm-live/cpm_live/native_layers/__init__.py b/cpm-live/cpm_live/native_layers/__init__.py
new file mode 100644
index 0000000..dbd8dcd
--- /dev/null
+++ b/cpm-live/cpm_live/native_layers/__init__.py
@@ -0,0 +1,8 @@
+from .embedding import Embedding, EmbeddingExt
+from .position_embedding import SegmentPositionEmbedding, BucketPositionBias, RotaryEmbedding
+from .linear import Linear
+from .layernorm import LayerNorm
+from .attention import Attention
+from .feedforward import FeedForward
+from .blocks import TransformerBlock
+from .transformer import Encoder
diff --git a/cpm-live/cpm_live/native_layers/attention.py b/cpm-live/cpm_live/native_layers/attention.py
new file mode 100644
index 0000000..5b9f402
--- /dev/null
+++ b/cpm-live/cpm_live/native_layers/attention.py
@@ -0,0 +1,117 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache LicenseVersion 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writingsoftware
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KINDeither express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+import torch
+import math
+from .linear import Linear
+
+
+class Attention(torch.nn.Module):
+ def __init__(
+ self,
+ dim_model: int,
+ num_heads: int,
+ dim_head: int,
+ dtype: torch.dtype = torch.half,
+ dropout_p: Optional[float] = None,
+ ) -> None:
+
+ super().__init__()
+
+ self.dim_model = dim_model
+ self.num_heads = num_heads
+ self.dim_head = dim_head
+
+        self.project_q = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype)
+        self.project_k = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype)
+        self.project_v = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype)
+
+        self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype)
+
+ self.softmax = torch.nn.Softmax(dim=-1)
+
+ if dropout_p is not None:
+ self.dropout = torch.nn.Dropout(p=dropout_p)
+ else:
+ self.dropout = None
+
+ def forward(
+ self,
+ hidden_q: torch.Tensor,
+ hidden_kv: torch.Tensor,
+ attention_mask: torch.BoolTensor,
+ position_bias: torch.Tensor,
+ use_cache: bool = False,
+ past_kv: Optional[Tuple[torch.Tensortorch.Tensor]] = None,
+ ):
+ """
+ Args:
+ hidden_q (:obj:`torch.Tensor` of shape ``(batchlen_qdim_model)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix.
+ hidden_kv (:obj:`torch.Tensor` of shape ``(batchlen_kdim_model)``): Length of input sequence before padding.
+ attention_mask (:obj:`torch.Tensor` of shape ``(batchlen_qlen_k)``): Used to avoid performing attention on padding token indices.
+ position_bias(:obj:`torch.Tensor` of shape ``(num_headslen_qlen_k)`` or ``(1num_headslen_klen_q)``): Provide positional information about tensor `key_value` and `query`.
+ Return:
+ out (:obj:`torch.Tensor` of shape ``(batchlen_qdim_model)``): The attention output.
+ """ # noqa: E501
+
+ batch_size = hidden_q.size(0)
+ len_q = hidden_q.size(1)
+ len_k = hidden_kv.size(1)
+
+ h_q = self.project_q(hidden_q)
+ h_k = self.project_k(hidden_kv)
+ h_v = self.project_v(hidden_kv)
+
+        h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
+        h_k = h_k.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
+        h_v = h_v.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
+
+ if past_kv is not None:
+            h_k = torch.cat([past_kv[0], h_k], dim=-2)
+            h_v = torch.cat([past_kv[1], h_v], dim=-2)
+ len_k = h_k.size(-2)
+
+ # (bn_hlen_qd_h) @ (bn_hd_hlen_k) -> (bn_hlen_qlen_k)
+        score = torch.matmul(h_q, h_k.transpose(-1, -2)) / math.sqrt(self.dim_head)
+ score = score + position_bias
+ score = torch.masked_fill(
+ score,
+            attention_mask.view(batch_size, 1, len_q, len_k) == False,
+            torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
+ )
+
+ score = self.softmax(score)
+
+ score = torch.masked_fill(
+ score,
+            attention_mask.view(batch_size, 1, len_q, len_k) == False,
+            torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
+ )
+
+ if self.dropout is not None:
+ score = self.dropout(score)
+
+ # (bn_hlen_qlen_k) @ (bn_hlen_kd_h) -> (bn_hlen_qd_h)
+        score = torch.matmul(score, h_v)
+
+        score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
+        score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
+
+ score = self.attention_out(score)
+ if use_cache:
+            return score, (h_k, h_v)
+ else:
+ return score
diff --git a/cpm-live/cpm_live/native_layers/blocks.py b/cpm-live/cpm_live/native_layers/blocks.py
new file mode 100644
index 0000000..9c6efcb
--- /dev/null
+++ b/cpm-live/cpm_live/native_layers/blocks.py
@@ -0,0 +1,248 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+import torch
+from .layernorm import LayerNorm
+from .attention import Attention
+from .feedforward import FeedForward
+
+
+class SelfAttentionBlock(torch.nn.Module):
+ """The whole cross-attention block. A sequence of operation. Consists of layernormself-attention and residual connection.
+
+ Args:
+ dim_model (int): main dimension of modules in transformer blocks.
+ num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
+ dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
+ dtype (optional): Defaults to torch.half.
+ eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
+ dropout_p (float, optional): Defaults to 0.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ dim_model: int,
+ num_heads: int,
+ dim_head: int,
+ dtype=torch.half,
+ eps: float = 1e-6,
+ dropout_p: Optional[float] = None,
+ ):
+
+ super().__init__()
+
+ self.layernorm_before_attention = LayerNorm(
+ dim_model,
+ dtype=dtype,
+ eps=eps,
+ )
+
+ self.self_attention = Attention(
+ dim_model=dim_model,
+ num_heads=num_heads,
+ dim_head=dim_head,
+ dtype=dtype,
+ dropout_p=dropout_p,
+ )
+
+ if dropout_p:
+ self.dropout = torch.nn.Dropout(dropout_p)
+ else:
+ self.dropout = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_bias: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ):
+ """
+ Args:
+ hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of self-attention block. It can be the embedding of a batch of sequences.
+ attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation.
+ position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block.
+
+ Return:
+ :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of attention block.
+
+ """ # noqa: E501
+ x = self.layernorm_before_attention(hidden_states)
+ x = self.self_attention(x, x, attention_mask, position_bias, use_cache, past_key_value)
+ if use_cache:
+ x, current_key_value = x
+ else:
+ current_key_value = None
+
+ if self.dropout is not None:
+ x = self.dropout(x)
+ hidden_states = (hidden_states + x) / 1.05
+
+ if use_cache:
+ return hidden_states, current_key_value
+ else:
+ return hidden_states
+
+
+class FFNBlock(torch.nn.Module):
+ """The whole feed-forward block. A sequence of operation. Consists of layernormfeed-forward and residual connection.
+
+ Args:
+ dim_model (int): main dimension of modules in transformer blocks.
+ dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
+ dtype (optional): Defaults to torch.half.
+ eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
+ dropout_p (float, optional): Defaults to 0.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ dim_model: int,
+ dim_ff: int,
+ dtype=torch.half,
+ eps: float = 1e-6,
+ dropout_p: Optional[float] = 0,
+ ):
+ super().__init__()
+
+ self.layernorm_before_ffn = LayerNorm(
+ dim_model,
+ dtype=dtype,
+ eps=eps,
+ )
+
+ self.ffn = FeedForward(
+ dim_model,
+ dim_ff,
+ dtype=dtype,
+ dropout_p=dropout_p,
+ )
+
+ if dropout_p:
+ self.dropout = torch.nn.Dropout(dropout_p)
+ else:
+ self.dropout = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ):
+ """
+ Args:
+ hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Hidden states before feed forward layer.
+
+ Return:
+ :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of feed-forward block
+
+ """ # noqa: E501
+ x = self.layernorm_before_ffn(hidden_states)
+ x = self.ffn(x)
+ if self.dropout is not None:
+ x = self.dropout(x)
+ hidden_states = (hidden_states + x) / 1.05
+ return hidden_states
+
+
+class TransformerBlock(torch.nn.Module):
+ """The whole transformer block. A sequence of operation. Consists of self-attention block[cross-attention block] and feed-forward block.
+
+ Args:
+ dim_model (int): main dimension of modules in transformer blocks.
+ dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
+ num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
+ dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
+ dtype (optional): Defaults to torch.half.
+ eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
+ dropout_p (float, optional): Defaults to 0.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ dim_model: int,
+ dim_ff: int,
+ num_heads: int,
+ dim_head: int,
+ dtype=torch.half,
+ eps: float = 1e-6,
+ dropout_p: Optional[float] = None,
+ mask_att: bool = False,
+ mask_ffn: bool = False,
+ ):
+ super().__init__()
+ self.mask_att = mask_att
+ self.mask_ffn = mask_ffn
+
+ if not self.mask_att:
+ self.self_att = SelfAttentionBlock(
+ dim_model=dim_model,
+ num_heads=num_heads,
+ dim_head=dim_head,
+ dtype=dtype,
+ eps=eps,
+ dropout_p=dropout_p,
+ )
+
+ if not self.mask_ffn:
+ self.ffn = FFNBlock(
+ dim_model=dim_model,
+ dim_ff=dim_ff,
+ dtype=dtype,
+ eps=eps,
+ dropout_p=dropout_p,
+ )
+
+ def forward(
+ self,
+ self_hidden_states: torch.Tensor,
+ self_attention_mask: torch.Tensor,
+ self_position_bias: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ):
+ """
+ Args:
+ self_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of transformer block (self-attention block). It can be the raw embedding of a batch of sequences.
+ self_attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation of self-attention.
+ self_position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block.
+
+ Return:
+ :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of transformer block.
+
+ """ # noqa: E501
+ # (batch, dim_model, seq_self)
+ current_key_value = None
+ if not self.mask_att:
+ hidden_states = self.self_att(
+ self_hidden_states,
+ attention_mask=self_attention_mask,
+ position_bias=self_position_bias,
+ use_cache=use_cache,
+ past_key_value=past_key_value,
+ )
+ if use_cache:
+ hidden_states, current_key_value = hidden_states
+ else:
+ hidden_states = self_hidden_states
+
+ # (batch, dim_model, seq_self)
+ if not self.mask_ffn:
+ hidden_states = self.ffn(hidden_states)
+
+ if use_cache:
+ return hidden_states, current_key_value
+ else:
+ return hidden_states
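A minimal usage sketch of how the blocks compose (import path assumed, weights uninitialized); with `use_cache=True` a block also returns its `(key, value)` cache:

```python
import torch
from cpm_live.native_layers import TransformerBlock  # assumed export path

block = TransformerBlock(dim_model=64, dim_ff=256, num_heads=4, dim_head=16, dtype=torch.float32)
x = torch.randn(2, 8, 64)
mask = torch.ones(2, 8, 8, dtype=torch.bool)
bias = torch.zeros(4, 8, 8)

y = block(x, mask, bias)                          # (2, 8, 64)
y, (k, v) = block(x, mask, bias, use_cache=True)  # cached keys/values for incremental decoding
```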
diff --git a/cpm-live/cpm_live/native_layers/embedding.py b/cpm-live/cpm_live/native_layers/embedding.py
new file mode 100644
index 0000000..7f0fb64
--- /dev/null
+++ b/cpm-live/cpm_live/native_layers/embedding.py
@@ -0,0 +1,110 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import math
+import torch.nn.functional as F
+from .position_embedding import RotaryEmbedding
+from typing import Optional
+
+
+class Embedding(torch.nn.Module):
+ def __init__(
+ self,
+ vocab_size: int,
+ embedding_size: int,
+ dtype: torch.dtype = torch.half,
+ init_mean: float = 0.0,
+ init_std: float = 1,
+ ):
+
+ super().__init__()
+
+ self.dim_model = embedding_size
+ self.weight = torch.nn.parameter.Parameter(
+ torch.empty(vocab_size, embedding_size, dtype=dtype)
+ )
+
+ def forward(self, ids: torch.Tensor):
+ """
+ Args:
+ ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
+ Return:
+ :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
+ """ # noqa: E501
+
+ embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
+ return embeds
+
+ def projection(self, x: torch.Tensor):
+ """
+ Projection based on embedding's weight. For example, the embedding maps vocab_size to embed_size; the projection maps embed_size back to vocab_size.
+ Args:
+ x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
+ Returns:
+ :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
+ """ # noqa: E501
+ logits = F.linear(x / math.sqrt(self.dim_model), self.weight)
+ return logits
+
+
+class EmbeddingExt(torch.nn.Module):
+ def __init__(
+ self,
+ vocab_size: int,
+ embedding_size: int,
+ dtype: torch.dtype = torch.half,
+ init_mean: float = 0.0,
+ init_std: float = 1,
+ distance_scale: int = 16,
+ ):
+
+ super().__init__()
+
+ self.dim_model = embedding_size
+ self.rotary_emb = RotaryEmbedding(
+ dim=embedding_size, distance_scale=distance_scale, dtype=dtype
+ )
+
+ self.weight = torch.nn.parameter.Parameter(
+ torch.empty(vocab_size, embedding_size, dtype=dtype),
+ )
+
+ def forward(self, ids: torch.Tensor, ids_sub: torch.Tensor):
+ """
+ Args:
+ ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
+ ids_sub (:obj:`torch.Tensor` of shape ``(batch_size)``): Subscript of input sequence tokens.
+ Return:
+ :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
+ """ # noqa: E501
+
+ embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
+ return self.rotary_emb(embeds, ids_sub)
+
+ def projection(self, x: torch.Tensor, ext_table: Optional[torch.Tensor] = None):
+ """
+ Projection based on embedding's weight. For example, the embedding maps vocab_size to embed_size; the projection maps embed_size back to vocab_size.
+ Args:
+ x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
+ ext_table (:obj:`torch.Tensor` of shape ``(ext_table_size, dim_model)``): Ext vocab table.
+ Returns:
+ :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_size + ext_table_size)``: The projection output.
+ """ # noqa: E501
+ logits = F.linear(x / math.sqrt(self.dim_model), self.weight)
+ if ext_table is not None:
+ logits_ext = F.linear(x, ext_table)
+ logits = torch.cat([logits, logits_ext], dim=-1)
+ return logits
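The projection above ties the output layer to the embedding weight and appends logits for out-of-vocabulary "extension" tokens; a small functional sketch of that tying (shapes and values are illustrative):

```python
import torch
import torch.nn.functional as F

vocab, dim, ext = 100, 16, 3
weight = torch.randn(vocab, dim)       # shared input/output embedding matrix
hidden = torch.randn(2, 5, dim)
ext_table = torch.randn(ext, dim)      # embeddings of tokens outside the base vocab

logits = F.linear(hidden / dim ** 0.5, weight)   # (2, 5, vocab)
logits_ext = F.linear(hidden, ext_table)         # (2, 5, ext)
full = torch.cat([logits, logits_ext], dim=-1)   # (2, 5, vocab + ext)
```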
diff --git a/cpm-live/cpm_live/native_layers/feedforward.py b/cpm-live/cpm_live/native_layers/feedforward.py
new file mode 100644
index 0000000..b624592
--- /dev/null
+++ b/cpm-live/cpm_live/native_layers/feedforward.py
@@ -0,0 +1,120 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+import torch
+from .linear import Linear
+
+
+class DenseGatedACT(torch.nn.Module):
+ def __init__(
+ self,
+ dim_in: int,
+ dim_ff: int,
+ dtype=torch.half,
+ ):
+ super().__init__()
+
+ self.w_0 = Linear(
+ dim_in=dim_in,
+ dim_out=dim_ff,
+ dtype=dtype,
+ scale_before=False,
+ )
+
+ self.w_1 = Linear(
+ dim_in=dim_in,
+ dim_out=dim_ff,
+ dtype=dtype,
+ scale_before=False,
+ )
+ self.act = torch.nn.GELU()
+
+ def forward(self, x: torch.Tensor):
+ """Transform an input tensor from one feature space to another via a nonlinear operation
+
+ Args:
+ x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): Tensor that will be subject to nonlinear operations.
+
+ Return:
+ out (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_ff)``)
+
+ """ # noqa: E501
+ gate_score = self.act(self.w_0(x))
+ x = self.w_1(x)
+
+ x = gate_score * x
+ return x
+
+
+class FeedForward(torch.nn.Module):
+ r"""FeedForward module
+
+ Args:
+ dim_in (int): input dimension.
+ dim_ff (int): middle dimension.
+ dim_out (int, optional): output dimension. Defaults to None, which means dim_in = dim_out.
+ dtype (optional): Defaults to torch.half.
+ init_mean (float, optional): mean of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)` for fully-connected module used in feed-forward layer. Defaults to 0.
+ init_std (float, optional): std of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)` for fully-connected module used in feed-forward layer. Defaults to 0.02.
+ bias (bool, optional): whether to use bias term in fully-connected layers used in feed-forward module. Defaults to False.
+ activate_fn (str, optional): Defaults to `gated_gelu`.
+ dropout_p (int, optional): Defaults to 0.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ dim_model: int,
+ dim_ff: int,
+ dtype=torch.half,
+ dropout_p: Optional[float] = None,
+ ):
+
+ super().__init__()
+
+ self.w_in = DenseGatedACT(
+ dim_in=dim_model,
+ dim_ff=dim_ff,
+ dtype=dtype,
+ )
+
+ if dropout_p is not None:
+ self.dropout = torch.nn.Dropout(dropout_p)
+ else:
+ self.dropout = None
+
+ self.w_out = Linear(
+ dim_in=dim_ff,
+ dim_out=dim_model,
+ dtype=dtype,
+ scale_before=False,
+ )
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of feed-forward module.
+
+ Return:
+ :obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of feed-forward module.
+ """ # noqa: E501
+ x = self.w_in(x)
+
+ if self.dropout is not None:
+ x = self.dropout(x)
+
+ x = self.w_out(x)
+
+ return x
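A functional sketch of the gated-GELU feed-forward computed by `DenseGatedACT` plus the output `Linear` (dropout omitted; each projection here divides by sqrt(dim_in) after the matmul, matching `scale_before=False`):

```python
import torch
import torch.nn.functional as F

def gated_gelu_ffn(x, w0, w1, w_out):
    dim_in, dim_ff = w0.size(1), w_out.size(1)
    gate = F.gelu(F.linear(x, w0) / dim_in ** 0.5)   # gating branch
    h = gate * (F.linear(x, w1) / dim_in ** 0.5)     # gated hidden state
    return F.linear(h, w_out) / dim_ff ** 0.5

x = torch.randn(2, 8, 64)
w0, w1, w_out = torch.randn(256, 64), torch.randn(256, 64), torch.randn(64, 256)
print(gated_gelu_ffn(x, w0, w1, w_out).shape)  # torch.Size([2, 8, 64])
```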
diff --git a/cpm-live/cpm_live/native_layers/layernorm.py b/cpm-live/cpm_live/native_layers/layernorm.py
new file mode 100644
index 0000000..e8f19e9
--- /dev/null
+++ b/cpm-live/cpm_live/native_layers/layernorm.py
@@ -0,0 +1,37 @@
+import torch
+
+
+@torch.jit.script  # type: ignore
+def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
+ old_dtype = hidden.dtype
+ variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
+ hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
+ return hidden * weight
+
+
+class LayerNorm(torch.nn.Module):
+ """RMS LayerNorm"""
+
+ def __init__(
+ self,
+ dim_norm: int,
+ dtype: torch.dtype = torch.half,
+ eps: float = 1e-6,
+ init_var: float = 1.0,
+ ):
+
+ super().__init__()
+
+ self.eps = eps
+ self.dim_norm = dim_norm
+ self.weight = torch.nn.parameter.Parameter(torch.full((dim_norm,), init_var, dtype=dtype))
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``): Input tensor that need to be normalized.
+ Return:
+ :obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``: The layernorm output.
+ """ # noqa: E501
+ assert x.size(-1) == self.dim_norm
+ return rms_layernorm(x, self.weight, self.eps)
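For reference, the same RMS normalization written as a plain function (y = x / sqrt(mean(x²) + eps) · weight, with the variance computed in float32 for stability):

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = x.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    return (x * torch.rsqrt(variance + eps)).to(x.dtype) * weight

x = torch.randn(2, 4, 8)
print(rms_norm(x, torch.ones(8)).shape)  # torch.Size([2, 4, 8])
```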
diff --git a/cpm-live/cpm_live/native_layers/linear.py b/cpm-live/cpm_live/native_layers/linear.py
new file mode 100644
index 0000000..fe904ae
--- /dev/null
+++ b/cpm-live/cpm_live/native_layers/linear.py
@@ -0,0 +1,51 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import math
+import torch.nn.functional as F
+
+
+class Linear(torch.nn.Module):
+ def __init__(
+ self,
+ dim_in: int,
+ dim_out: int,
+ dtype: torch.dtype = torch.half,
+ init_mean: float = 0.0,
+ init_std: float = 1,
+ scale_before: bool = False,
+ ):
+ super().__init__()
+ self.dim_in = self.in_features = dim_in
+ self.dim_out = self.out_features = dim_out
+ self.scale_before = scale_before
+
+ self.weight = torch.nn.parameter.Parameter(torch.empty((dim_out, dim_in), dtype=dtype))
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
+ Returns:
+ :obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
+ """ # noqa: E501
+ if self.scale_before:
+ x = x / math.sqrt(self.dim_in)
+ x = F.linear(x, self.weight)
+ else:
+ x = F.linear(x, self.weight)
+ x = x / math.sqrt(self.dim_in)
+ return x
diff --git a/cpm-live/cpm_live/native_layers/position_embedding.py b/cpm-live/cpm_live/native_layers/position_embedding.py
new file mode 100644
index 0000000..3778f3e
--- /dev/null
+++ b/cpm-live/cpm_live/native_layers/position_embedding.py
@@ -0,0 +1,247 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Union
+import torch
+import torch.nn.functional as F
+
+
+class SegmentPositionEmbedding(torch.nn.Module):
+ def __init__(
+ self,
+ num_heads,
+ num_segments=1,
+ num_buckets=32,
+ max_distance=128,
+ bidirectional=False,
+ dtype=torch.half,
+ init_mean: float = 0.0,
+ init_std: float = 1,
+ ):
+
+ super().__init__()
+
+ self.num_heads = num_heads
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.bidirectional = bidirectional
+ self.num_segments = num_segments
+
+ self.relative_attention_bias = torch.nn.parameter.Parameter(
+ torch.empty(num_segments * num_segments + num_buckets, num_heads, dtype=dtype)
+ )
+
+ def forward(
+ self,
+ key_pos: torch.Tensor,
+ query_pos: torch.Tensor,
+ key_segment: torch.Tensor,
+ query_segment: torch.Tensor,
+ ):
+ with torch.no_grad():
+
+ batch = key_pos.size(0)
+ keylen = key_pos.size(1)
+ querylen = query_pos.size(1)
+
+ assert key_pos.size(0) == query_pos.size(0)
+ assert keylen == key_segment.size(1) and querylen == query_segment.size(1)
+
+ key_pos = key_pos.view(batch, -1, keylen)
+ query_pos = query_pos.view(batch, querylen, -1)
+ key_segment = key_segment.view(batch, -1, keylen)
+ query_segment = query_segment.view(batch, querylen, -1)
+
+ relative_position_bucket = self._segment_relative_position_bucket(
+ query_segment, key_segment
+ )
+ relative_position_bucket = relative_position_bucket + self.num_buckets  # do not overlap with the relative-position buckets
+
+ # b*q*k
+ absolute_position_bucket = self._position_bucket(
+ torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[
+ None, :
+ ]
+ - torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[
+ :, None
+ ],
+ bidirectional=self.bidirectional,
+ num_buckets=self.num_buckets,
+ max_distance=self.max_distance,
+ )
+ relative_position_bucket = torch.where(
+ (key_segment == query_segment),
+ absolute_position_bucket[None, :, :],
+ relative_position_bucket,
+ )
+ # (batch, len_q, len_k)
+
+ # (batch, len_q, len_k, num_heads)
+ embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
+ # (batch, num_heads, len_q, len_k)
+ embeds = embeds.permute(0, 3, 1, 2).contiguous()
+ return embeds
+
+ def _segment_relative_position_bucket(self, query_segment, key_segment):
+ return query_segment * self.num_segments + key_segment
+
+ def _position_bucket(
+ self, relative_position, bidirectional=True, num_buckets=32, max_distance=128
+ ):
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
+ relative_position = torch.abs(relative_position)
+ else:
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+ relative_postion_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.int32)
+ relative_postion_if_large = torch.min(
+ relative_postion_if_large,
+ torch.full_like(relative_postion_if_large, num_buckets - 1),
+ )
+ relative_buckets += torch.where(
+ is_small, relative_position.to(torch.int32), relative_postion_if_large
+ )
+ return relative_buckets
+
+
+class BucketPositionBias(torch.nn.Module):
+ def __init__(
+ self,
+ num_heads: int,
+ num_buckets: int = 32,
+ num_segment_bucket: int = 32,
+ max_distance: int = 128,
+ dtype: torch.dtype = torch.half,
+ init_mean: float = 0.0,
+ init_std: float = 1,
+ ) -> None:
+ super().__init__()
+
+ self.num_heads = num_heads
+ self.num_buckets = num_buckets
+ self.num_segment_bucket = num_segment_bucket
+ self.max_distance = max_distance
+
+ self.relative_attention_bias = torch.nn.parameter.Parameter(
+ torch.empty(num_buckets + num_segment_bucket, num_heads, dtype=dtype)
+ )
+
+ def forward(
+ self,
+ query_pos: torch.Tensor,  # (batch, len_q)
+ key_pos: torch.Tensor,  # (batch, len_k)
+ rel_buckets: torch.Tensor,  # (batch, len_q, len_k)
+ ):
+ with torch.no_grad():
+
+ batch = key_pos.size(0)
+ keylen = key_pos.size(1)
+ querylen = query_pos.size(1)
+
+ assert key_pos.size(0) == query_pos.size(0)
+ assert (
+ rel_buckets.size(0) == batch
+ and rel_buckets.size(1) == querylen
+ and rel_buckets.size(2) == keylen
+ )
+
+ relative_position_bucket = rel_buckets - 1 + self.num_buckets  # do not overlap with the relative-position buckets
+
+ # b*q*k
+ inner_segment_bucket = self._position_bucket(
+ key_pos[..., None, :] - query_pos[..., :, None],
+ num_buckets=self.num_buckets,
+ max_distance=self.max_distance,
+ )
+ relative_position_bucket = torch.where(
+ rel_buckets == 0,
+ inner_segment_bucket,
+ relative_position_bucket,
+ )
+ # (batch, len_q, len_k)
+
+ # (batch, len_q, len_k, num_heads)
+ embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
+ # (batch, num_heads, len_q, len_k)
+ embeds = embeds.permute(0, 3, 1, 2).contiguous()
+ return embeds
+
+ def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
+ relative_buckets = 0
+ num_buckets //= 2
+ relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
+ relative_position = torch.abs(relative_position)
+
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+ relative_postion_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.int32)
+ relative_postion_if_large = torch.min(
+ relative_postion_if_large,
+ torch.full_like(relative_postion_if_large, num_buckets - 1),
+ )
+ relative_buckets += torch.where(
+ is_small, relative_position.to(torch.int32), relative_postion_if_large
+ )
+ return relative_buckets
+
+
+class RotaryEmbedding(torch.nn.Module):
+ def __init__(
+ self,
+ dim,
+ base=10000,
+ distance_scale: Union[int, float] = 1,
+ dtype: torch.dtype = torch.half,
+ ):
+ super().__init__()
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim)
+ )
+ inv_freq = inv_freq.to(dtype)
+ self.distance_scale = distance_scale
+ self.dtype = dtype
+ self.inv_freq = inv_freq
+
+ def forward(self, x: torch.Tensor, x_pos: torch.Tensor):
+ """
+ Args:
+ x (:obj:`torch.Tensor` of shape ``(..., dim)``): Inputs.
+ x_pos (:obj:`torch.Tensor` of shape ``(...)``): Positions of inputs.
+ """
+ x_pos = x_pos * self.distance_scale
+ freqs = x_pos[..., None].to(self.dtype) * self.inv_freq[None, :]  # (..., dim/2)
+
+ # the same implementation as sat
+ emb = torch.cat((freqs, freqs), dim=-1)  # (..., dim)
+ emb_cos = emb.cos()  # (..., dim)
+ emb_sin = emb.sin()  # (..., dim)
+
+ rotate_x = torch.cat(
+ [-x[..., x.size(-1) // 2 :], x[..., : x.size(-1) // 2]], dim=-1
+ )  # (..., dim)
+
+ return x * emb_cos + rotate_x * emb_sin
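`RotaryEmbedding` uses the rotate-half formulation; a self-contained sketch of the same math on CPU (the module above builds `inv_freq` on CUDA and scales positions by `distance_scale`):

```python
import torch

def apply_rotary(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    dim = x.size(-1)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    freqs = pos[..., None].float() * inv_freq            # (..., dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)              # (..., dim)
    rotate_x = torch.cat([-x[..., dim // 2:], x[..., :dim // 2]], dim=-1)
    return x * emb.cos() + rotate_x * emb.sin()

x = torch.randn(2, 6, 16)
pos = torch.arange(6).expand(2, 6)
print(apply_rotary(x, pos).shape)  # torch.Size([2, 6, 16])
```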
diff --git a/cpm-live/cpm_live/native_layers/transformer.py b/cpm-live/cpm_live/native_layers/transformer.py
new file mode 100644
index 0000000..b21445c
--- /dev/null
+++ b/cpm-live/cpm_live/native_layers/transformer.py
@@ -0,0 +1,125 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from typing import Optional, List, Tuple
+
+from .blocks import TransformerBlock
+from .layernorm import LayerNorm
+
+
+class Encoder(torch.nn.Module):
+ """Layers of encoder transformer blocks plus an final layernorm.
+
+ Args:
+ num_layers (int): number of layers.
+ dim_model (int): main dimension of modules in transformer blocks.
+ dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
+ num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
+ dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
+ dtype (optional): Defaults to torch.half.
+ eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
+ dropout_p (float, optional): Defaults to 0.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ num_layers: int,
+ dim_model: int,
+ dim_ff: int,
+ num_heads: int,
+ dim_head: int,
+ dtype: torch.dtype = torch.half,
+ eps: float = 1e-6,
+ dropout_p: Optional[float] = None,
+ mask_modules: Optional[List[Tuple[bool, bool]]] = None,
+ ):
+
+ super().__init__()
+
+ self.num_layers = num_layers
+
+ if mask_modules is not None:
+ assert (
+ len(mask_modules) == num_layers
+ )"The total number of masks should equal to num_layers"
+ for mask_module in mask_modules:
+ assert (
+ len(mask_module) == 2
+ )"For encodereach mask should be (mask_attmask_ffn)"
+ else:
+ mask_modules = [(False, False)] * num_layers
+
+ self.layers = torch.nn.ModuleList(
+ [
+ TransformerBlock(
+ dim_model=dim_model,
+ dim_ff=dim_ff,
+ num_heads=num_heads,
+ dim_head=dim_head,
+ dtype=dtype,
+ eps=eps,
+ dropout_p=dropout_p,
+ mask_att=mask_modules[ith][0],
+ mask_ffn=mask_modules[ith][1],
+ )
+ for ith in range(num_layers)
+ ]
+ )
+
+ self.output_layernorm = LayerNorm(dim_norm=dim_model, dtype=dtype, eps=eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_bias: torch.Tensor,
+ use_cache: bool = False,
+ past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+ ):
+ """
+ Args:
+ hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``): Input of encoder, might be the embedding of a batch of sequences.
+ attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_enc, seq_enc)``): Avoid invalid areas to participate in the calculation
+ position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_enc, seq_enc)``): Provides position information to attention mechanism.
+
+ Return:
+ :obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``: The encoder output.
+
+ """ # noqa: E501
+ if not use_cache:
+ for layer in self.layers:
+ hidden_states = layer(hidden_states, attention_mask, position_bias)
+ hidden_states = self.output_layernorm(hidden_states)
+ return hidden_states
+ else:
+ with torch.no_grad():
+ current_key_values = []
+ for i, module in enumerate(self.layers):
+ hidden_states = module(
+ hidden_states,
+ attention_mask,
+ position_bias,
+ past_key_value=past_key_values[i] if past_key_values else None,
+ use_cache=use_cache,
+ )
+ if use_cache:
+ current_key_values.append(hidden_states[1])
+ hidden_states = hidden_states[0]
+ hidden_states = self.output_layernorm(hidden_states)
+ if use_cache:
+ return hidden_states, current_key_values
+ else:
+ return hidden_states
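A shape-level usage sketch of `Encoder` (import path assumed, weights uninitialized); with `use_cache=True` it also returns one `(key, value)` pair per layer:

```python
import torch
from cpm_live.native_layers import Encoder  # assumed export path

enc = Encoder(num_layers=2, dim_model=64, dim_ff=256, num_heads=4, dim_head=16, dtype=torch.float32)
x = torch.randn(2, 8, 64)
mask = torch.ones(2, 8, 8, dtype=torch.bool)
bias = torch.zeros(4, 8, 8)

out = enc(x, mask, bias)                            # (2, 8, 64)
out, presents = enc(x, mask, bias, use_cache=True)  # len(presents) == num_layers
```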
diff --git a/cpm-live/cpm_live/tokenizers/__init__.py b/cpm-live/cpm_live/tokenizers/__init__.py
index 1c2bb3a..c664def 100644
--- a/cpm-live/cpm_live/tokenizers/__init__.py
+++ b/cpm-live/cpm_live/tokenizers/__init__.py
@@ -1 +1,2 @@
from .ant import CPMAntTokenizer
+from .bee import CPMBeeTokenizer
diff --git a/cpm-live/cpm_live/tokenizers/ant.py b/cpm-live/cpm_live/tokenizers/ant.py
index bcbf826..d127ec6 100644
--- a/cpm-live/cpm_live/tokenizers/ant.py
+++ b/cpm-live/cpm_live/tokenizers/ant.py
@@ -1,4 +1,18 @@
# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import jieba
import pkg_resources
import io
@@ -111,6 +125,10 @@ def pad_id(self):
def unk_id(self):
return self.encoder[self.unk_token]
+ @property
+ def newline_id(self):
+ return self.encoder["\n"]
+
def __len__(self):
return len(self.encoder)
diff --git a/cpm-live/cpm_live/tokenizers/bee.py b/cpm-live/cpm_live/tokenizers/bee.py
new file mode 100644
index 0000000..6078f6b
--- /dev/null
+++ b/cpm-live/cpm_live/tokenizers/bee.py
@@ -0,0 +1,223 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pkg_resources
+import io
+from typing import IO, Dict, List, Optional, Tuple
+
+
+def load_vocab(fp: IO[bytes]) -> Dict[str, int]:
+ """Loads a vocabulary file into a dictionary."""
+ vocab: Dict[str, int] = {}
+
+ reader = io.TextIOWrapper(fp, encoding="utf-8")
+ for token in reader.readlines():
+ if token[-1] == "\n":
+ token = token[:-1]
+ if len(token) == 0:
+ continue
+ vocab[token] = len(vocab)
+ return vocab
+
+
+class Token(object):
+ def __init__(self, token: str, start: int, is_unk: bool, is_special: bool):
+ self.token = token
+ self.start = start
+ self.is_unk = is_unk
+ self.is_special = is_special
+
+ def __str__(self):
+ return "Token(token={}start={}is_unk={}is_special={})".format(
+ self.tokenself.startself.is_unkself.is_special
+ )
+
+ def __repr__(self):
+ return self.__str__()
+
+
+class CPMBeeTokenizer(object):
+ def __init__(
+ self,
+ ):
+ self.unk_token = ""
+ self.mask_token = ""
+ self.bos_token = ""
+ self.eos_token = ""
+ self.line_token = "\n"
+ self.space_token = " "
+
+ self.encoder = load_vocab(pkg_resources.resource_stream("cpm_live", "vocabs/bee.txt"))
+ self.encoder[self.line_token] = self.encoder[""]
+ self.encoder[self.space_token] = self.encoder[""]
+ del self.encoder[""]
+ del self.encoder[""]
+
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self._special_tokens = {
+ k: v for k, v in self.encoder.items() if k.startswith("<") and k.endswith(">")
+ }
+
+ self._max_word_len = max([len(x) for x in self.encoder.keys()])
+
+ def get_piece(self, text: str) -> str:
+ text = text[: self._max_word_len]
+ len_text = len(text)
+ for i in range(len(text)):
+ sub = text[: len_text - i]
+ if (sub in self.encoder) and (sub not in self._special_tokens):
+ return sub
+ return text[0]
+
+ @property
+ def vocab_size(self):
+ return len(self.encoder)
+
+ @property
+ def eos_id(self):
+ return self.encoder[self.eos_token]
+
+ @property
+ def bos_id(self):
+ return self.encoder[self.bos_token]
+
+ @property
+ def unk_id(self):
+ return self.encoder[self.unk_token]
+
+ @property
+ def mask_id(self):
+ return self.encoder[self.mask_token]
+
+ def __len__(self):
+ return len(self.encoder)
+
+ def tokenize(self, text: str) -> List[Token]:
+ output_tokens: List[Token] = []
+
+ sentence_split = [""]
+ is_escape = False
+ is_special_token = False
+ for i, c in enumerate(text):
+ if is_special_token:
+ if c == "<":
+ raise ValueError("Invalid special token at pos {}".format(i))
+ elif c == ">":
+ # end of special token
+ sentence_split[-1] += c
+ is_special_token = False
+ sentence_split.append("")
+ else:
+ sentence_split[-1] += c
+ else:
+ if c == "<":
+ if is_escape:
+ # case: <<
+ sentence_split[-1] += c
+ is_escape = False
+ else:
+ # case: x<
+ is_escape = True
+ else:
+ if is_escape:
+ # case str:
+ return text.replace("<""<<")
+
+ @staticmethod
+ def unescape(text: str) -> str:
+ return text.replace("<<""<")
+
+ def encode(
+ self, text: str, past_table: Dict[int, str] = {}
+ ) -> Tuple[List[int], Dict[int, str]]:
+ ext_table_rev: Dict[str, int] = {}
+ ext_table: Dict[int, str] = {}
+ for idx, val in past_table.items():
+ ext_table[idx] = val
+ ext_table_rev[val] = idx
+ ret = []
+ for x in self.tokenize(text):
+ if x.is_unk or (x.is_special and (x.token not in self.encoder)):
+ if x.token not in ext_table_rev:
+ ext_table_rev[x.token] = len(ext_table_rev) + self.vocab_size
+ ext_table[ext_table_rev[x.token]] = x.token
+ ret.append(ext_table_rev[x.token])
+ elif x.token in self.encoder:
+ ret.append(self.encoder[x.token])
+ else:
+ raise ValueError("Unknown token `{}` at pos {}".format(x.tokenx.start))
+
+ return ret, ext_table
+
+ def decode(self, tokens: List[int], ext_table: Optional[Dict[int, str]] = None):
+ """Decode ids into a string."""
+ if ext_table is None:
+ ext_table = {}
+ ret = []
+ for token in tokens:
+ if token in ext_table:
+ ret.append(ext_table[token])
+ else:
+ if token >= 0:
+ w = self.decoder[token]
+ if w in self._special_tokens:
+ ret.append(w)
+ else:
+ ret.append(self.escape(w))
+ return "".join(ret)
diff --git a/cpm-live/training_tasks/__init__.py b/cpm-live/cpm_live/training_tasks/__init__.py
similarity index 50%
rename from cpm-live/training_tasks/__init__.py
rename to cpm-live/cpm_live/training_tasks/__init__.py
index 0a502f1..3cb9fa0 100644
--- a/cpm-live/training_tasks/__init__.py
+++ b/cpm-live/cpm_live/training_tasks/__init__.py
@@ -1 +1,2 @@
from . import ant
+from . import bee
diff --git a/cpm-live/training_tasks/ant/__init__.py b/cpm-live/cpm_live/training_tasks/ant/__init__.py
similarity index 100%
rename from cpm-live/training_tasks/ant/__init__.py
rename to cpm-live/cpm_live/training_tasks/ant/__init__.py
diff --git a/cpm-live/training_tasks/ant/pretrain.py b/cpm-live/cpm_live/training_tasks/ant/pretrain.py
similarity index 96%
rename from cpm-live/training_tasks/ant/pretrain.py
rename to cpm-live/cpm_live/training_tasks/ant/pretrain.py
index 055b691..5604bfc 100644
--- a/cpm-live/training_tasks/ant/pretrain.py
+++ b/cpm-live/cpm_live/training_tasks/ant/pretrain.py
@@ -96,8 +96,7 @@ def __get_item_data(self, raw_data):
tgt = np.concatenate((np.full(self.prompt_length, -100, dtype=np.int64), tgt))
inp = np.concatenate(
(
- np.arange(self.prompt_length, dtype=np.int64) +
- self.prompt_length * global_task + self.tokenizer.vocab_size,
+ np.arange(self.prompt_length, dtype=np.int64) + self.prompt_length * global_task,
ctx,
)
)
diff --git a/cpm-live/cpm_live/training_tasks/bee/__init__.py b/cpm-live/cpm_live/training_tasks/bee/__init__.py
new file mode 100644
index 0000000..684f707
--- /dev/null
+++ b/cpm-live/cpm_live/training_tasks/bee/__init__.py
@@ -0,0 +1,2 @@
+from .pretrain import MixedDataset
+from .finetune import FinetuneDataset
diff --git a/cpm-live/cpm_live/training_tasks/bee/finetune.py b/cpm-live/cpm_live/training_tasks/bee/finetune.py
new file mode 100644
index 0000000..561e3c2
--- /dev/null
+++ b/cpm-live/cpm_live/training_tasks/bee/finetune.py
@@ -0,0 +1,71 @@
+from ...tokenizers import CPMBeeTokenizer
+from .pretrain import _MixedDatasetBatchPacker, _MixedDatasetConfig, CPMBeeBatch
+from ...dataset import SimpleDataset
+import bmtrain as bmt
+
+
+class FinetuneDataset:
+ def __init__(
+ self,
+ dataset_path: str,
+ batch_size: int,
+ max_length: int,
+ tokenizer: CPMBeeTokenizer,
+ max_depth: int = 16,
+ task_name: str = "task",
+ drop_last: bool = False,
+ ) -> None:
+ self._world_size = bmt.world_size()
+ self._rank = bmt.rank()
+ self._batch_size = batch_size
+
+ self._packer = _MixedDatasetBatchPacker(
+ batch_size * self._world_size, max_length, tokenizer, max_depth
+ )
+ self._drop_last = drop_last
+
+ ds = SimpleDataset(dataset_path, shuffle=False)
+ self._ds_cfg: _MixedDatasetConfig = {
+ "weight": 1.0,
+ "path": dataset_path,
+ "transforms": [],
+ "task_name": task_name,
+ "dataset_name": "finetune",
+ "incontext_weight": [1.0],
+ "lines": len(ds),
+ "dataset": ds,
+ }
+
+ def __batch_iter(self):
+ while True:
+ try:
+ batch = self._packer.add_data(self._ds_cfg)
+ except EOFError:
+ break
+ if batch is None:
+ continue
+ yield batch
+ if len(self._packer) > 0:
+ batch = self._packer.pack_batch(force=True)
+ if not self._drop_last:
+ yield batch
+ self._ds_cfg["dataset"]._repeat_times = 0
+
+ def __iter__(self):
+ batch_st = self._batch_size * self._rank
+ batch_end = self._batch_size * (self._rank + 1)
+ for batch in self.__batch_iter():
+ batch_size = batch["inputs"].shape[0]
+ if batch_size <= batch_st:
+ yield None
+ else:
+ ret: CPMBeeBatch = {
+ kw: val[batch_st:batch_end] # type: ignore
+ for kw, val in batch.items()
+ if kw not in ["task_names", "raw_data", "ext_ids", "ext_sub"]
+ } # type: ignore
+ ret["task_names"] = batch["task_names"]
+ ret["raw_data"] = batch["raw_data"]
+ ret["ext_ids"] = batch["ext_ids"]
+ ret["ext_sub"] = batch["ext_sub"]
+ yield ret
diff --git a/cpm-live/cpm_live/training_tasks/bee/pretrain.py b/cpm-live/cpm_live/training_tasks/bee/pretrain.py
new file mode 100644
index 0000000..d3caef2
--- /dev/null
+++ b/cpm-live/cpm_live/training_tasks/bee/pretrain.py
@@ -0,0 +1,1041 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+import multiprocessing
+import os
+from queue import Empty
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing_extensions import TypedDict
+from ...dataset import DistributedDataset
+from ...tokenizers import CPMBeeTokenizer
+from ...utils.config import load_dataset_config
+import numpy as np
+import time
+from numpy.typing import NDArray
+import torch
+import bmtrain as bmt
+import importlib.machinery
+import importlib.util
+import types
+import random
+
+
+class _MixedDatasetConfig(TypedDict):
+ weight: float
+ path: str
+ transforms: Union[List[Dict[str, Any]], str]
+ task_name: str
+ dataset_name: str
+ incontext_weight: List[float]
+
+ lines: int
+ dataset: DistributedDataset
+
+
+CPMBeeInputType = Union[str, Dict[str, "CPMBeeInputType"]]
+
+
+class _DictTree(TypedDict):
+ value: str
+ children: List["_DictTree"]
+ depth: int
+ segment_id: int
+ need_predict: bool
+
+
+class _PrevExtTableStates(TypedDict):
+ ext_table: Dict[int, str]
+ token_id_table: Dict[str, Dict[int, int]]
+
+
+class _TransformFuncDict(TypedDict):
+ loader: importlib.machinery.SourceFileLoader
+ module: types.ModuleType
+ last_m: float
+
+
+_TransformFunction = Callable[[CPMBeeInputType, int, random.Random], CPMBeeInputType]
+
+
+class CPMBeeBatch(TypedDict):
+ inputs: NDArray[np.int32]
+ inputs_sub: NDArray[np.int32]
+ length: NDArray[np.int32]
+ context: NDArray[np.bool_]
+ sample_ids: NDArray[np.int32]
+ num_segments: NDArray[np.int32]
+ segment_ids: NDArray[np.int32]
+ segment_rel_offset: NDArray[np.int32]
+ segment_rel: NDArray[np.int32]
+ spans: NDArray[np.int32]
+ target: NDArray[np.int32]
+ ext_ids: NDArray[np.int32]
+ ext_sub: NDArray[np.int32]
+ task_ids: NDArray[np.int32]
+ task_names: List[str]
+ raw_data: List[Any]
+
+
+def rel_to_bucket(n_up: int, n_down: int, max_depth: int = 8):
+ ret = n_up * max_depth + n_down
+ if ret == 0:
+ return ret
+ else:
+ # bucket 1 is reserved for incontext samples
+ return ret + 1
+
+
+def convert_data_to_id(
+ tokenizer: CPMBeeTokenizer,
+ data: Any,
+ prev_ext_states: Optional[_PrevExtTableStates] = None,
+ shuffle_answer: bool = True,
+ max_depth: int = 8,
+):
+ root: _DictTree = {
+ "value": "",
+ "children": [],
+ "depth": 0,
+ "segment_id": 0,
+ "need_predict": False,
+ }
+
+ segments = [root]
+
+ def _build_dict_tree(data: CPMBeeInputType, depth: int, need_predict: bool) -> List[_DictTree]:
+ if isinstance(data, dict):
+ ret_list: List[_DictTree] = []
+ curr_items = list(data.items())
+ if need_predict and shuffle_answer:
+ access_idx = np.arange(len(curr_items))
+ np.random.shuffle(access_idx)
+ curr_items = [curr_items[idx] for idx in access_idx]
+ for k, v in curr_items:
+ child_info: _DictTree = {
+ "value": k,
+ "children": [],
+ "depth": depth,
+ "segment_id": len(segments),
+ "need_predict": False # only leaves are contexts
+ }
+ segments.append(child_info)
+ child_info["children"] = _build_dict_tree(
+ v, depth + 1, need_predict or (depth == 1 and k == "")
+ ) # elements in .
+
+ ret_list.append(child_info)
+ return ret_list
+ else:
+ assert isinstance(data, str), "Invalid data {}".format(data)
+ ret: _DictTree = {
+ "value": data,
+ "children": [],
+ "depth": depth,
+ "segment_id": len(segments),
+ "need_predict": need_predict,
+ }
+ segments.append(ret)
+ return [ret]
+
+ root["children"] = _build_dict_tree(data1False)
+
+ num_segments = len(segments)
+ segment_rel = np.zeros((num_segments * num_segments,), dtype=np.int32)
+
+ def _build_segment_rel(node: _DictTree) -> List[Tuple[int, int]]:
+ ret: List[Tuple[int, int]] = [(node["segment_id"], node["depth"])]
+ for child in node["children"]:
+ sub = _build_segment_rel(child)
+ for seg_id_1, depth_1 in sub:
+ for seg_id_2, depth_2 in ret:
+ n_up = min(depth_1 - node["depth"], max_depth - 1)
+ n_down = min(depth_2 - node["depth"], max_depth - 1)
+ segment_rel[seg_id_1 * num_segments + seg_id_2] = rel_to_bucket(
+ n_up, n_down, max_depth=max_depth
+ )
+ segment_rel[seg_id_2 * num_segments + seg_id_1] = rel_to_bucket(
+ n_down, n_up, max_depth=max_depth
+ )
+ ret.extend(sub)
+ return ret
+
+ _build_segment_rel(root)
+
+ input_ids: List[int] = []
+ input_id_subs: List[int] = []
+ segment_bound: List[Tuple[int, int]] = []
+
+ ext_table: Dict[int, str] = {}
+ token_id_table: Dict[str, Dict[int, int]] = {}
+
+ if prev_ext_states is not None:
+ ext_table = prev_ext_states["ext_table"]
+ token_id_table = prev_ext_states["token_id_table"]
+
+ for seg in segments:
+ tokens, ext_table = tokenizer.encode(seg["value"], ext_table)
+
+ token_id_subs = []
+ reid_token_ids = []
+ for idx in tokens:
+ if idx in ext_table:
+ # unk or special token
+ token = ext_table[idx]
+ if token.startswith("<") and token.endswith(">"):
+ # special token
+ if "_" in token:
+ token_name = token[1:-1].split("_"maxsplit=1)[0]
+ else:
+ token_name = token[1:-1]
+ token_name = "<{}>".format(token_name)
+ else:
+ token_name = ""
+
+ if token_name not in token_id_table:
+ token_id_table[token_name] = {}
+ if idx not in token_id_table[token_name]:
+ token_id_table[token_name][idx] = len(token_id_table[token_name])
+ if token_name not in tokenizer.encoder:
+ raise ValueError("Invalid token {}".format(token))
+ reid_token_ids.append(tokenizer.encoder[token_name])
+ token_id_subs.append(token_id_table[token_name][idx])
+ else:
+ reid_token_ids.append(idx)
+ token_id_subs.append(0)
+ tokens = [tokenizer.bos_id] + reid_token_ids
+ token_id_subs = [0] + token_id_subs
+ if not seg["need_predict"]:
+ tokens = tokens + [tokenizer.eos_id]
+ token_id_subs = token_id_subs + [0]
+ else:
+ # no eos
+ pass
+ begin = len(input_ids)
+ input_ids.extend(tokens)
+ input_id_subs.extend(token_id_subs)
+ end = len(input_ids)
+ segment_bound.append((beginend))
+
+ ids = np.array(input_ids, dtype=np.int32)
+ id_subs = np.array(input_id_subs, dtype=np.int32)
+ segs = np.zeros((ids.shape[0],), dtype=np.int32)
+ context = np.zeros((ids.shape[0],), dtype=np.int8)
+ for i, (begin, end) in enumerate(segment_bound):
+ if not segments[i]["need_predict"]:
+ context[begin:end] = 1
+ segs[begin:end] = i
+
+ curr_ext_table_states: _PrevExtTableStates = {
+ "ext_table": ext_table,
+ "token_id_table": token_id_table,
+ }
+ return ids, id_subs, context, segs, segment_rel, num_segments, curr_ext_table_states
+
+
+def _dataset_identity(c: _MixedDatasetConfig):
+ return "{}.{}".format(c["task_name"]c["dataset_name"])
+
+
+class _MixedDatasetBatchPacker:
+ def __init__(
+ self,
+ batch_size: int,
+ max_length: int,
+ tokenizer: CPMBeeTokenizer,
+ max_depth: int = 16,
+ ) -> None:
+ self._batch_size = batch_size
+ self._max_length = max_length
+ self._max_depth = max_depth
+ self.tokenizer = tokenizer
+ self._transform_func_table: Dict[str, _TransformFuncDict] = {}
+
+ self._inputs: List[NDArray[np.int32]] = []
+ self._inputs_sub: List[NDArray[np.int32]] = []
+ self._context: List[NDArray[np.int8]] = []
+ self._sample_ids: List[NDArray[np.int32]] = []
+ self._segments: List[NDArray[np.int32]] = []
+ self._num_segments: List[NDArray[np.int32]] = []
+ self._segment_rel_offset: List[NDArray[np.int32]] = []
+ self._segment_rel: List[NDArray[np.int32]] = []
+ self._spans: List[List[int]] = []
+ self._task_ids: List[List[str]] = []
+ self._raw_data: List[List[Any]] = []
+
+ def __len__(self):
+ return len(self._inputs)
+
+ def apply_transform(
+ self,
+ data: CPMBeeInputType,
+ transform: Union[Dict[str, Any], Callable[[CPMBeeInputType], CPMBeeInputType], None],
+ ) -> CPMBeeInputType:
+ if transform is None:
+ return data
+ if not isinstance(transform, dict):
+ # transform function
+ return transform(data)
+
+ mapping_list: List[Tuple[str, str]] = []
+
+ def _walk_transform_dict(data: Union[Dict[str, Any], str], prefix: str = ""):
+ if isinstance(data, dict):
+ for k, v in data.items():
+ if len(prefix) > 0:
+ _walk_transform_dict(v, prefix + "." + k)
+ else:
+ _walk_transform_dict(v, k)
+ else:
+ assert isinstance(data, str), "Invalid transform {}".format(data)
+ mapping_list.append((prefix, data))
+
+ _walk_transform_dict(transform)
+
+ expanded_mapping_list: List[Tuple[str, Any]] = []
+
+ def _expand_mapping(
+ data: CPMBeeInputType, stars: List[str], path: List[str], target: List[str]
+ ):
+ if len(path) == 0:
+ num_stars = 0
+ for it in target:
+ if it == "*":
+ num_stars += 1
+ if num_stars != len(stars):
+ raise ValueError("Invalid transform {}".format(".".join(target)))
+
+ nw_tgt = []
+ num_stars = 0
+ for it in target:
+ if it == "*":
+ nw_tgt.append(stars[num_stars])
+ num_stars += 1
+ else:
+ nw_tgt.append(it)
+ expanded_mapping_list.append((".".join(nw_tgt)data))
+ else:
+ if not isinstance(datadict):
+ raise ValueError("Invalid data {}".format(data))
+ if path[0] == "*":
+ for kv in data.items():
+ _expand_mapping(vstars + [k]path[1:]target)
+ else:
+ _expand_mapping(data[path[0]]starspath[1:]target)
+
+ # expand mapping list
+ for tgt, src in mapping_list:
+ if src.startswith("$"):
+ # copy from src
+ _expand_mapping(data, [], src[1:].split("."), tgt.split("."))
+ else:
+ if "*" in tgt:
+ raise ValueError("Constant value is not allowed to have `*` in prefix")
+ expanded_mapping_list.append((tgt, src))
+
+ ret = {}
+ for tgt, val in expanded_mapping_list:
+ tgt = tgt.split(".")
+ cur = ret
+ while len(tgt) > 1:
+ cur = cur[tgt[0]]
+ tgt = tgt[1:]
+ cur[tgt[0]] = val
+ return ret
+
+ def data_to_id(
+ self,
+ data: Any,
+ prev_ext_states: Optional[_PrevExtTableStates] = None,
+ shuffle_answer: bool = True,
+ ):
+ return convert_data_to_id(
+ self.tokenizer, data, prev_ext_states, shuffle_answer, self._max_depth
+ )
+
+ def _ensure_transform_function(
+ self, module_name: str, transform_script_path: str
+ ) -> _TransformFunction:
+ module_name = "cpm_live.transforms.{}".format(module_name)
+ if transform_script_path not in self._transform_func_table:
+ loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
+ spec = importlib.util.spec_from_loader(loader.name, loader)
+ if spec is None:
+ raise RuntimeError("spec is none! {}".format(module_name))
+ mod = importlib.util.module_from_spec(spec)
+ self._transform_func_table[transform_script_path] = {
+ "loader": loader,
+ "module": mod,
+ "last_m": 0,
+ }
+
+ transform_script_info = self._transform_func_table[transform_script_path]
+ curr_m_time = float(
+ transform_script_info["loader"].path_stats(transform_script_path)["mtime"]
+ )
+ if curr_m_time > transform_script_info["last_m"]:
+ transform_script_info["last_m"] = curr_m_time
+ transform_script_info["loader"].exec_module(transform_script_info["module"])
+ transform_func = getattr(transform_script_info["module"]"transform"None)
+ if transform_func is None:
+
+ def _empty_transform_func(data: CPMBeeInputType, num_sample: int, r: random.Random):
+ raise NotImplementedError(
+ "Transform func for dataset {} not implemented".format(module_name)
+ )
+
+ return _empty_transform_func
+ else:
+ return transform_func
+
+ def build_instance(self, config: _MixedDatasetConfig):
+ _sample_weight = np.array(config["incontext_weight"], dtype=np.float32)
+ _sample_weight = _sample_weight / _sample_weight.sum()
+ num_incontext = np.random.choice(_sample_weight.shape[0], p=_sample_weight)
+ ds = config["dataset"]
+ transforms = config["transforms"]
+ if isinstance(transforms, str):
+ while True:
+ try:
+ if not os.path.exists(transforms):
+ raise RuntimeError(
+ "transform script file {} not exists".format(transforms)
+ )
+ # load transform script
+ transform_func = self._ensure_transform_function(
+ _dataset_identity(config), transforms
+ )
+ seed = random.random()
+ break
+ except Exception as e:
+ print(e)
+ time.sleep(10)
+
+ def _transform(data: CPMBeeInputType):
+ r = random.Random(seed)
+ return transform_func(data, num_incontext, r)
+ transform = _transform
+ elif len(transforms) == 0:
+ transform = None
+ else:
+ transform = transforms[np.random.choice(len(transforms))]
+
+ raw_data = {}
+ while True:
+ inp = ds.read()
+ inp = self.apply_transform(inp, transform)
+
+ (
+ input_ids,
+ input_id_subs,
+ context,
+ segment_ids,
+ segment_rel,
+ n_segments,
+ table_states,
+ ) = self.data_to_id(inp)
+ if input_ids.shape[0] > self._max_length:
+ # too long
+ continue
+ input_ids = input_ids[: self._max_length]
+ context = context[: self._max_length]
+ segment_ids = segment_ids[: self._max_length]
+ raw_data["input"] = inp
+ raw_data["samples"] = []
+ break
+
+ sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
+ segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32)
+ num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32)
+
+ for i in range(num_incontext):
+ if input_ids.shape[0] >= self._max_length:
+ # early break
+ break
+
+ sample = ds.read()
+ sample = self.apply_transform(sample, transform)
+ (
+ sample_input_ids,
+ sample_id_subs,
+ _,
+ sample_segments,
+ sample_rel,
+ n_segments,
+ table_states,
+ ) = self.data_to_id(sample, table_states)
+
+ if input_ids.shape[0] + sample_input_ids.shape[0] > self._max_length:
+ # too long, break
+ break
+ raw_data["samples"].append(sample)
+ input_ids = np.concatenate([input_ids, sample_input_ids], axis=0)
+ input_id_subs = np.concatenate([input_id_subs, sample_id_subs], axis=0)
+ context = np.concatenate(
+ [context, np.ones(sample_input_ids.shape, dtype=np.int8)], axis=0
+ )
+ segment_ids = np.concatenate([segment_ids, sample_segments], axis=0)
+ segment_rel_offset = np.concatenate(
+ [
+ segment_rel_offset,
+ np.full(sample_input_ids.shape, segment_rel.shape[0], dtype=np.int32),
+ ],
+ axis=0,
+ )
+ segment_rel = np.concatenate([segment_rel, sample_rel], axis=0)
+ sample_ids = np.concatenate(
+ [sample_ids, np.full(sample_input_ids.shape, i + 1, dtype=np.int32)], axis=0
+ )
+ num_segments = np.concatenate(
+ [num_segments, np.full(sample_input_ids.shape, n_segments, dtype=np.int32)], axis=0
+ )
+ return (
+ input_ids,
+ input_id_subs,
+ context,
+ segment_ids,
+ segment_rel_offset,
+ segment_rel,
+ sample_ids,
+ num_segments,
+ raw_data,
+ )
+
+ def pack_batch(self, force: bool = False) -> CPMBeeBatch:
+ # pack batch
+ if len(self._inputs) < self._batch_size:
+ if not force:
+ raise RuntimeError("Batch insufficient")
+ batch_size = len(self._inputs)
+ else:
+ batch_size = self._batch_size
+ inputs = np.zeros((batch_size, self._max_length), dtype=np.int32)
+ inputs_sub = np.zeros((batch_size, self._max_length), dtype=np.int32)
+ context = np.zeros((batch_size, self._max_length), dtype=np.int8)
+ sample_ids = np.zeros((batch_size, self._max_length), dtype=np.int32)
+ segments = np.zeros((batch_size, self._max_length), dtype=np.int32)
+ num_segments = np.zeros((batch_size, self._max_length), dtype=np.int32)
+ segment_rel_offset = np.zeros((batch_size, self._max_length), dtype=np.int32)
+ tgt = np.full((batch_size, self._max_length), -100, dtype=np.int32)
+
+ max_rel = 0
+ for i in range(batch_size):
+ max_rel = max(max_rel, self._segment_rel[i].shape[0])
+ segment_rel = np.zeros((batch_size, max_rel), dtype=np.int32)
+ spans = np.zeros((batch_size, self._max_length), dtype=np.int32)
+ length = np.zeros((batch_size,), dtype=np.int32)
+ task_ids = np.zeros((batch_size, self._max_length), dtype=np.int32)
+
+ all_task_names: Set[str] = set()
+ for i in range(batch_size):
+ for task_name in self._task_ids[i]:
+ all_task_names.add(task_name)
+ task_names: List[str] = list(all_task_names)
+ task_name_to_id = {name: i for i, name in enumerate(task_names)}
+
+ batch_ext_table_map: Dict[Tuple[int, int], int] = {}
+ batch_ext_table_ids: List[int] = []
+ batch_ext_table_sub: List[int] = []
+ raw_data_list: List[Any] = []
+ for i in range(batch_size):
+ instance_length = self._inputs[i].shape[0]
+ rel_size = self._segment_rel[i].shape[0]
+ inputs[i, :instance_length] = self._inputs[i]
+ inputs_sub[i, :instance_length] = self._inputs_sub[i]
+ context[i, :instance_length] = self._context[i]
+ sample_ids[i, :instance_length] = self._sample_ids[i]
+ segments[i, :instance_length] = self._segments[i]
+ num_segments[i, :instance_length] = self._num_segments[i]
+ segment_rel_offset[i, :instance_length] = self._segment_rel_offset[i]
+ segment_rel[i, :rel_size] = self._segment_rel[i]
+
+ span_begin = 0
+ for span_id, (span_end, task_name) in enumerate(zip(self._spans[i], self._task_ids[i])):
+ spans[i, span_begin:span_end] = span_id
+ task_ids[i, span_begin:span_end] = task_name_to_id[task_name]
+ span_begin = span_end
+ length[i] = instance_length
+ raw_data_list.extend(self._raw_data[i])
+
+ for j in range(instance_length):
+ idx, idx_sub = self._inputs[i][j], self._inputs_sub[i][j]
+ tgt_idx = idx
+ if idx_sub > 0:
+ # need to be in ext table
+ if (idx, idx_sub) not in batch_ext_table_map:
+ batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
+ batch_ext_table_ids.append(idx)
+ batch_ext_table_sub.append(idx_sub)
+ tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.tokenizer.vocab_size
+ if j > 1 and context[i, j - 1] == 0:
+ if idx != self.tokenizer.bos_id:
+ tgt[i, j - 1] = tgt_idx
+ else:
+ tgt[i, j - 1] = self.tokenizer.eos_id
+ if context[i, instance_length - 1] == 0:
+ tgt[i, instance_length - 1] = self.tokenizer.eos_id
+
+ if len(batch_ext_table_map) == 0:
+ # placeholder
+ batch_ext_table_ids.append(0)
+ batch_ext_table_sub.append(1)
+
+ self._inputs = self._inputs[batch_size:]
+ self._inputs_sub = self._inputs_sub[batch_size:]
+ self._context = self._context[batch_size:]
+ self._sample_ids = self._sample_ids[batch_size:]
+ self._segments = self._segments[batch_size:]
+ self._num_segments = self._num_segments[batch_size:]
+ self._segment_rel_offset = self._segment_rel_offset[batch_size:]
+ self._segment_rel = self._segment_rel[batch_size:]
+ self._spans = self._spans[batch_size:]
+ self._task_ids = self._task_ids[batch_size:]
+ self._raw_data = self._raw_data[batch_size:]
+ return {
+ "inputs": inputs,
+ "inputs_sub": inputs_sub,
+ "length": length,
+ "context": context > 0,
+ "sample_ids": sample_ids,
+ "num_segments": num_segments,
+ "segment_ids": segments,
+ "segment_rel_offset": segment_rel_offset,
+ "segment_rel": segment_rel,
+ "spans": spans,
+ "target": tgt,
+ "ext_ids": np.array(batch_ext_table_idsdtype=np.int32),
+ "ext_sub": np.array(batch_ext_table_subdtype=np.int32),
+ "task_ids": task_ids,
+ "task_names": task_names,
+ "raw_data": raw_data_list,
+ }
+
+ def add_data(self, config: _MixedDatasetConfig) -> Optional[CPMBeeBatch]:
+ (
+ input_ids,
+ input_id_subs,
+ context,
+ segment_ids,
+ segment_rel_offset,
+ segment_rel,
+ sample_ids,
+ num_segments,
+ raw_data,
+ ) = self.build_instance(config)
+
+ # add to batch
+ best_fit: Union[None, int] = None
+ best_fit_space: Union[None, int] = None
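+ # best-fit packing: place the new instance into the open row with the least remaining space that still fits it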
+ for i in range(len(self._inputs)):
+ space = self._max_length - self._inputs[i].shape[0]
+ if input_ids.shape[0] <= space:
+ if best_fit_space is None:
+ best_fit = i
+ best_fit_space = space
+ elif best_fit_space > space:
+ best_fit = i
+ best_fit_space = space
+ if best_fit is None:
+ # add a new instance
+ self._inputs.append(input_ids)
+ self._inputs_sub.append(input_id_subs)
+ self._context.append(context)
+ self._sample_ids.append(sample_ids)
+ self._segments.append(segment_ids)
+ self._num_segments.append(num_segments)
+ self._segment_rel_offset.append(segment_rel_offset)
+ self._segment_rel.append(segment_rel)
+ self._spans.append([input_ids.shape[0]])
+ self._task_ids.append([config["task_name"]])
+ self._raw_data.append([raw_data])
+ else:
+ # add to existing instance
+ self._inputs[best_fit] = np.concatenate([self._inputs[best_fit], input_ids], axis=0)
+ self._inputs_sub[best_fit] = np.concatenate(
+ [self._inputs_sub[best_fit], input_id_subs], axis=0
+ )
+ self._context[best_fit] = np.concatenate([self._context[best_fit], context], axis=0)
+ self._sample_ids[best_fit] = np.concatenate(
+ [self._sample_ids[best_fit], sample_ids], axis=0
+ )
+ self._segments[best_fit] = np.concatenate(
+ [self._segments[best_fit], segment_ids], axis=0
+ )
+ self._num_segments[best_fit] = np.concatenate(
+ [self._num_segments[best_fit], num_segments], axis=0
+ )
+ self._segment_rel_offset[best_fit] = np.concatenate(
+ [
+ self._segment_rel_offset[best_fit],
+ segment_rel_offset + self._segment_rel[best_fit].shape[0],
+ ],
+ axis=0,
+ )
+ self._segment_rel[best_fit] = np.concatenate(
+ [self._segment_rel[best_fit], segment_rel], axis=0
+ )
+ self._spans[best_fit].append(self._inputs[best_fit].shape[0])
+ self._task_ids[best_fit].append(config["task_name"])
+ self._raw_data[best_fit].append(raw_data)
+
+ if len(self._inputs) > self._batch_size:
+ return self.pack_batch()
+ else:
+ # not ready
+ return None
+
+
+class _MixedDatasetConfigMananger:
+ def __init__(self, config_path: str) -> None:
+ self._config_path: str = config_path
+ self._config: Union[List[_MixedDatasetConfig], None] = None
+ self._last_m = 0
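+ # the config file is polled by mtime; changed() reloads it when the file was modified since the last successful load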
+
+ def changed(self):
+ while True:
+ try:
+ m_time = os.stat(self._config_path).st_mtime
+ if m_time > self._last_m:
+ # try to load new config
+ try:
+ self._config = load_dataset_config(self._config_path)
+ except Exception as e:
+ # failed to load config
+ print(
+ "Error: load new config in changed"
+ "self._config_path={path}err={err}"
+ .format(path=self._config_patherr=str(e))
+ )
+
+ return False
+ # new config loaded
+ self._last_m = m_time
+ return True
+ return False
+ except Exception as e:
+ print("Error: reading info list in _MixedDatasetConfigMananger.changed!"
+ "self._config_path={path}err={err}"
+ .format(path=self._config_patherr=str(e)))
+ time.sleep(30)
+
+ def get_config(self) -> List[_MixedDatasetConfig]:
+ if self._config is None:
+ if not self.changed():
+ raise RuntimeError("Failed to load config")
+ if self._config is None:
+ raise RuntimeError("Failed to load config")
+ return self._config
+
+
+def _mixed_dataset_process(
+ config_path: str,
+ q_cmd: multiprocessing.Queue,
+ q_cmd_out: multiprocessing.Queue,
+ q_data: multiprocessing.Queue,
+ rank: int,
+ world_size: int,
+ packer: _MixedDatasetBatchPacker,
+):
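+ # dataset worker process: serves control commands from q_cmd and streams packed batches to the trainer through q_data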
+ # ignore SIGINT
+ import signal
+
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
+ config_base_path = os.path.dirname(os.path.abspath(config_path))
+
+ def _convert_to_abs_path(transform_path: str):
+ if transform_path.startswith("/"):
+ return transform_path
+ else:
+ return os.path.join(config_base_path, transform_path)
+
+ def _build_sample_weights(config: List[_MixedDatasetConfig]):
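+ # sampling probability of each dataset is proportional to its configured weight times its number of lines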
+ if len(config) == 0:
+ return np.array([], dtype=np.float32)
+ weights = [c["weight"] * c["lines"] for c in config]
+ weights = np.array(weights, dtype=np.float32)
+ sm_weight = weights.sum()
+ if sm_weight > 0:
+ weights = weights / sm_weight
+ return weights
+ else:
+ raise RuntimeError("Empty datasets")
+
+ cfg_mgr = _MixedDatasetConfigMananger(config_path)
+ config = cfg_mgr.get_config()
+
+ for c in config:
+ ds = DistributedDataset(
+ _convert_to_abs_path(c["path"]),
+ rank,
+ world_size,
+ )
+
+ c["lines"] = ds._nlines
+ c["dataset"] = ds
+ if "weight" not in c:
+ c["weight"] = 1.0
+ if "transforms" not in c:
+ c["transforms"] = []
+ elif isinstance(c["transforms"]str):
+ c["transforms"] = _convert_to_abs_path(c["transforms"])
+ if "incontext_weight" not in c:
+ c["incontext_weight"] = [1.0]
+
+ weights = _build_sample_weights(config)
+
+ should_stop = False
+ should_start = False
+
+ while not should_stop:
+ # update config first
+ if cfg_mgr.changed():
+ path_ds_map: Dict[str, _MixedDatasetConfig] = {}
+ nw_path_set: Set[str] = set()
+
+ # load new config
+ nw_config = cfg_mgr.get_config()
+
+ # build path -> dataset map
+ for c in config:
+ path_ds_map[_dataset_identity(c)] = c
+
+ # add new datasets
+ for c in nw_config:
+ if _dataset_identity(c) in path_ds_map:
+ # update values only
+ if "weight" in c:
+ path_ds_map[_dataset_identity(c)]["weight"] = c["weight"]
+ if "transform" in c:
+ if isinstance(c["transforms"]str):
+ path_ds_map[_dataset_identity(c)]["transforms"] = _convert_to_abs_path(
+ c["transforms"]
+ )
+ else:
+ path_ds_map[_dataset_identity(c)]["transforms"] = c["transforms"]
+ if "incontext_weight" in c:
+ path_ds_map[_dataset_identity(c)]["incontext_weight"] = c[
+ "incontext_weight"
+ ]
+ else:
+ # new dataset
+ ds = DistributedDataset(
+ _convert_to_abs_path(c["path"]),
+ rank,
+ world_size,
+ )
+ c["lines"] = ds._nlines
+ c["dataset"] = ds
+ if "weight" not in c:
+ c["weight"] = 1.0
+ if "transforms" not in c:
+ c["transforms"] = []
+ elif isinstance(c["transforms"]str):
+ c["transforms"] = _convert_to_abs_path(c["transforms"])
+ if "incontext_weight" not in c:
+ c["incontext_weight"] = [1.0]
+ path_ds_map[_dataset_identity(c)] = c
+ nw_path_set.add(_dataset_identity(c))
+
+ # remove unused datasets
+ for c in config:
+ if _dataset_identity(c) not in nw_path_set:
+ del path_ds_map[_dataset_identity(c)]
+
+ config: List[_MixedDatasetConfig] = []
+ for c in nw_config:
+ config.append(path_ds_map[_dataset_identity(c)])
+ del path_ds_map
+ del nw_path_set
+ del nw_config
+
+ weights = _build_sample_weights(config)
+
+ # get cmds
+ while True:
+ try:
+ cmd = q_cmd.get_nowait()
+ except Empty:
+ break
+ if cmd == "stop":
+ should_stop = True
+ q_cmd_out.put(True)
+ break
+ elif cmd == "state_dict":
+ ret = OrderedDict()
+ for c in config:
+ ds_name = _dataset_identity(c)
+ ret[ds_name] = c["dataset"]._state_dict()
+ q_cmd_out.put(ret)
+ elif cmd == "load_state_dict":
+ state_dict = q_cmd.get()
+ missing = []
+ for c in config:
+ ds_name = _dataset_identity(c)
+ if ds_name in state_dict:
+ c["dataset"].load_state_dict(state_dict[ds_name]strict=False)
+ else:
+ # new dataset
+ missing.append(ds_name)
+ q_cmd_out.put(missing)
+ elif cmd == "start":
+ should_start = True
+ q_cmd_out.put(True)
+ else:
+ raise RuntimeError("Unknown command: {}".format(cmd))
+
+ if should_stop:
+ break
+
+ if not should_start:
+ # wait for start cmd
+ time.sleep(1)
+ continue
+
+ if len(config) == 0:
+ # no dataset available
+ time.sleep(1)
+ continue
+
+ if q_data.full():
+ # queue full
+ time.sleep(1)
+ continue
+
+ # sample a dataset
+ ds_id: int = 0
+
+ while True:
+ ds_id = np.random.choice(weights.shape[0], p=weights)
+ if config[ds_id]["dataset"]._nlines != config[ds_id]["lines"]:
+ # dataset size changed
+ for c in config:
+ c["lines"] = c["dataset"]._nlines
+ weights = _build_sample_weights(config)
+ continue
+ else:
+ break
+
+ batch = packer.add_data(config[ds_id])
+ if batch is not None:
+ # new batch coming
+ q_data.put(batch)
+
+ # clean queue
+ while True:
+ try:
+ q_data.get_nowait()
+ except Empty:
+ break
+
+
+class MixedDataset:
+ def __init__(
+ self,
+ config_path: str,
+ batch_size: int,
+ max_length: int,
+ tokenizer: CPMBeeTokenizer,
+ max_depth: int = 16,
+ ) -> None:
+ self._q_cmd = multiprocessing.Queue()
+ self._q_cmd_out = multiprocessing.Queue()
+ self._q_data = multiprocessing.Queue(maxsize=1)
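+ # batches are produced in a background process; maxsize=1 keeps the producer at most one batch ahead of the consumer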
+ self._packer = _MixedDatasetBatchPacker(batch_size, max_length, tokenizer, max_depth)
+ self._p = multiprocessing.Process(
+ target=_mixed_dataset_process,
+ args=(
+ config_path,
+ self._q_cmd,
+ self._q_cmd_out,
+ self._q_data,
+ bmt.rank(),
+ bmt.world_size(),
+ self._packer,
+ ),
+ )
+ self._p.start()
+ self._closed = False
+
+ def close(self):
+ if not self._closed:
+ self._closed = True
+ self._q_cmd.put("stop")
+ assert self._q_cmd_out.get(), "Failed to stop process"
+ self._p.join()
+
+ @property
+ def closed(self):
+ return self._closed
+
+ def start(self):
+ self._q_cmd.put("start")
+ return self._q_cmd_out.get()
+
+ def state_dict(self):
+ self._q_cmd.put("state_dict")
+ states = self._q_cmd_out.get()
+ if not isinstance(states, OrderedDict):
+ raise RuntimeError("Invalid state dict {}".format(states))
+ if bmt.world_size() == 1:
+ for val in states.values():
+ val["states"].unsqueeze_(0)
+ val["block"].unsqueeze_(0)
+ return states
+
+ ret = OrderedDict()
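+ # pad each dataset's per-rank unused-block list to a common length, then all_gather so every rank holds the global state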
+ for k, v in states.items():
+ num_unused_block = v["states"].size(0)
+ gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
+ max_unused_blocks = (
+ bmt.distributed.all_reduce(gpu_num_unused_block, op="max").cpu().item()
+ )
+ if max_unused_blocks == 0:
+ max_unused_blocks = 1
+ gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
+ gpu_states[:num_unused_block] = v["states"].cuda()
+
+ gpu_block = v["block"].cuda()
+ global_states = bmt.distributed.all_gather(
+ gpu_states
+ ).cpu() # (world_size, max_unused_blocks)
+ global_block = bmt.distributed.all_gather(gpu_block).cpu() # (world_size, 4)
+ ret[k] = {"states": global_states, "block": global_block}
+ return ret
+
+ def load_state_dict(self, data: OrderedDict, strict: bool = False):
+ self._q_cmd.put("load_state_dict")
+ self._q_cmd.put(data)
+ missing = self._q_cmd_out.get()
+ if strict:
+ if len(missing) > 0:
+ raise RuntimeError("Missing dataset state: {}".format(missing))
+ return missing
+
+ def get(self) -> CPMBeeBatch:
+ ret: CPMBeeBatch = self._q_data.get() # type: ignore
+ if not isinstance(ret, dict):
+ raise RuntimeError("Invalid data {}".format(ret))
+ return ret
+
+ def __iter__(self):
+ while True:
+ yield self.get()
+
+ def __del__(self):
+ if not self.closed:
+ try:
+ self.close()
+ except Exception:
+ pass
diff --git a/cpm-live/cpm_live/utils/__init__.py b/cpm-live/cpm_live/utils/__init__.py
index cca5d9b..3d9e31d 100644
--- a/cpm-live/cpm_live/utils/__init__.py
+++ b/cpm-live/cpm_live/utils/__init__.py
@@ -1 +1,5 @@
from .config import Config
+from .data_utils import pad
+from .object import allgather_objects
+from .log import LogManager, logger
+from .config import load_dataset_config
diff --git a/cpm-live/cpm_live/utils/config.py b/cpm-live/cpm_live/utils/config.py
index 7e93ee4..28d2b26 100644
--- a/cpm-live/cpm_live/utils/config.py
+++ b/cpm-live/cpm_live/utils/config.py
@@ -17,10 +17,35 @@
import os
import copy
from typing import Any, Dict, Union
+from .log import logger
+
+
+def load_dataset_config(dataset_path: str):
+ cfg = json.load(open(dataset_path, "r", encoding="utf-8"))
+
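+ # when PLATFORM_CONFIG_PATH is set, each dataset's path (and transforms script) is remapped through that file's dataset_map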
+ platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
+ if platform_config_path is None:
+ logger.info(
+ "no platform_config_path. Directly load dataset_path({dataset_path})"
+ .format(dataset_path=dataset_path)
+ )
+ return cfg
+
+ path_dict = json.load(open(platform_config_path, "r", encoding="utf-8"))["dataset_map"]
+ logger.info(
+ "load dataset_path({dataset_path}) with platform_config_path({platform_config_path})"
+ .format(dataset_path=dataset_path, platform_config_path=platform_config_path)
+ )
+ for dataset in cfg:
+ dataset["path"] = os.path.join(path_dict[dataset["dataset_name"]]dataset["path"])
+ dataset["transforms"] = os.path.join(
+ path_dict[dataset["dataset_name"]]dataset["transforms"]
+ )
+ return cfg
class Config(object):
- """enc_dec model configuration"""
+ """model configuration"""
def __init__(self):
super().__init__()
diff --git a/cpm-live/cpm_live/utils/data_utils.py b/cpm-live/cpm_live/utils/data_utils.py
new file mode 100644
index 0000000..913e062
--- /dev/null
+++ b/cpm-live/cpm_live/utils/data_utils.py
@@ -0,0 +1,44 @@
+import torch
+
+
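+# stack items[key] tensors into one batch tensor, padding the length dimension to the batch maximum on the given side;
+# 1-D entries are simply concatenated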
+def pad(orig_items, key, padding_value=0, padding_side="left"):
+ items = []
+ if isinstance(orig_items[0][key], list):
+ assert isinstance(orig_items[0][key][0], torch.Tensor)
+ for it in orig_items:
+ for tr in it[key]:
+ items.append({key: tr})
+ else:
+ assert isinstance(orig_items[0][key], torch.Tensor)
+ items = orig_items
+
+ batch_size = len(items)
+ shape = items[0][key].shape
+ dim = len(shape)
+ assert dim <= 3
+ max_length = max(item[key].shape[-1] for item in items)
+ min_length = min(item[key].shape[-1] for item in items)
+ dtype = items[0][key].dtype
+
+ if dim == 1:
+ return torch.cat([item[key] for item in items], dim=0)
+ elif dim == 2:
+ if max_length == min_length:
+ return torch.cat([item[key] for item in items], dim=0)
+ tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
+ else:
+ tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
+
+ for i, item in enumerate(items):
+ if dim == 2:
+ if padding_side == "left":
+ tensor[i, -len(item[key][0]):] = item[key][0].clone()
+ else:
+ tensor[i, : len(item[key][0])] = item[key][0].clone()
+ elif dim == 3:
+ if padding_side == "left":
+ tensor[i, -len(item[key][0]):, :] = item[key][0].clone()
+ else:
+ tensor[i, : len(item[key][0]), :] = item[key][0].clone()
+
+ return tensor
diff --git a/cpm-live/cpm_live/utils/export.py b/cpm-live/cpm_live/utils/export.py
new file mode 100644
index 0000000..6e4db68
--- /dev/null
+++ b/cpm-live/cpm_live/utils/export.py
@@ -0,0 +1,56 @@
+import os
+import time
+import functools
+import torch
+import bmtrain as bmt
+import json
+from cpm_live.models import CPMBee
+from .log import logger
+from typing import List, Optional
+
+
+def rename_if_exists(file_path):
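+ # if file_path already exists, rename it with a _bak_<timestamp> suffix so the new file does not overwrite it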
+ if not os.path.exists(file_path):
+ return
+ timestamp = time.strftime('%Y%m%d%H%M%S')
+ file_dir, file_name = os.path.split(file_path)
+ file_root, file_ext = os.path.splitext(file_name)
+ new_file_name = f"{file_root}_bak_{timestamp}{file_ext}"
+ new_file_path = os.path.join(file_dir, new_file_name)
+ try:
+ os.rename(file_path, new_file_path)
+ logger.info(f"File '{file_name}' already exists. Renamed to '{new_file_name}'")
+ except Exception as e:
+ logger.warn(
+ "rename file failed,file_path={file_path}new_file_path={new_file_path},err={err}"
+ .format(file_path=file_pathnew_file_path=new_file_patherr=str(e)))
+
+
+def rename_if_exists_decorator(func):
+ @functools.wraps(func)
+ def wrapper(file_path, *args, **kwargs):
+ rename_if_exists(file_path)
+ return func(file_path, *args, **kwargs)
+ return wrapper
+
+
+@rename_if_exists_decorator
+def bmt_save(file_path: str, model: CPMBee, export_files: Optional[List[str]] = None):
+ bmt.save(model, file_path)
+ if export_files is not None:
+ export_files.append(file_path)
+
+
+@rename_if_exists_decorator
+def torch_save(file_path: str, obj: object, export_files: Optional[List[str]] = None):
+ torch.save(obj, file_path)
+ if export_files is not None:
+ export_files.append(file_path)
+
+
+@rename_if_exists_decorator
+def json_save(file_path: str, obj: object, export_files: Optional[List[str]] = None):
+ with open(file_path, "w") as data_f:
+ json.dump(obj, data_f)
+ if export_files is not None:
+ export_files.append(file_path)
diff --git a/cpm-live/cpm_live/utils/gradient_shrink.py b/cpm-live/cpm_live/utils/gradient_shrink.py
new file mode 100644
index 0000000..7354d73
--- /dev/null
+++ b/cpm-live/cpm_live/utils/gradient_shrink.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class OpGradientShrink(torch.autograd.Function):
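+ # identity in the forward pass; scales the incoming gradient by alpha in the backward pass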
+ @staticmethod
+ def forward(ctx, x: torch.Tensor, alpha: float):
+ ctx.alpha = alpha
+ return x
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output * ctx.alpha, None
+
+
+def gradient_shrink(x: torch.Tensor, alpha: float = 0.1):
+ return OpGradientShrink.apply(x, alpha)
diff --git a/cpm-live/cpm_live/utils/log.py b/cpm-live/cpm_live/utils/log.py
new file mode 100644
index 0000000..55873ad
--- /dev/null
+++ b/cpm-live/cpm_live/utils/log.py
@@ -0,0 +1,98 @@
+import os
+import sys
+from typing import Any, Dict, Optional, Tuple, Union
+import datetime
+import json
+import logging
+import bmtrain as bmt
+
+
+# Set up the common logger
+def _get_logger():
+ log = logging.getLogger(__name__)
+ log.setLevel(logging.INFO)
+ console_handle = logging.StreamHandler(sys.stdout)
+ node_name = os.getenv("NODE_NAME"str(bmt.rank()))
+ console_handle.setFormatter(
+ logging.Formatter(
+ '[%(levelname)s][%(asctime)s][{}][%(filename)s:%(lineno)d:%(process)d] - %(message)s'
+ .format(node_name),
+ datefmt='%Y-%m-%d %H:%M:%S'
+ )
+ )
+ log.addHandler(console_handle)
+ return log
+
+
+logger = _get_logger()
+
+
+class LogManager:
+ def __init__(self, path: str):
+ if not os.path.exists(path):
+ os.makedirs(path)
+ self.path = path
+
+ now = self.get_log_time()
+ latest_log: Union[Dict[str, Any], None] = None
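+ # look back up to 15 days for the most recent daily log file and resume the global token counter from its last entry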
+ for _ in range(15):
+ log_name = self.get_log_name(now)
+ if os.path.exists(log_name):
+ with open(log_name"r") as flog:
+ latest_log = on.loads(flog.readlines()[-1]) # get last log
+ break
+ now -= datetime.timedelta(days=1)
+
+ if latest_log is None:
+ self.global_token_pass = 0
+ else:
+ self.global_token_pass = latest_log["token pass"]
+
+ def get_log_time(self) -> datetime.datetime:
+ return datetime.datetime.utcnow() + datetime.timedelta(hours=16)
+
+ def get_log_name(self, now: Optional[datetime.datetime] = None):
+ if now is None:
+ now = self.get_log_time()
+ return os.path.join(self.path"log.%s.txt" % now.strftime("%Y%m%d"))
+
+ def write(
+ self,
+ time: float,
+ iteration: int,
+ loss: float,
+ lr: float,
+ lr_scale: float,
+ time_usage: Dict[str, float],
+ mem_usage: Dict[str, Tuple[float, float]],
+ avg_time: float,
+ token_max: float,
+ token_pass: float,
+ throughout: float,
+ grad_norm: float,
+ mask_max: float,
+ num_gpus: int,
+ task_loss: Dict[str, float],
+ model_inspect: Optional[Any] = None,
+ ):
+ with open(self.get_log_name()"a") as fp:
+ ret = {
+ "time": time,
+ "iter": iteration,
+ "loss": loss,
+ "lr": lr,
+ "lr scale": int(lr_scale),
+ "time usage": time_usage,
+ "mem usage": mem_usage,
+ "avg time (s)": avg_time,
+ "token/max": token_max,
+ "token pass": token_pass + self.global_token_pass,
+ "throughout (token/s)": throughout,
+ "grad_norm": grad_norm,
+ "mask/max": mask_max,
+ "num_gpus": num_gpus,
+ "task_loss": task_loss,
+ }
+ if model_inspect is not None:
+ ret["model_inspect"] = model_inspect
+ fp.write(json.dumps(ret, ensure_ascii=False) + "\n")
diff --git a/cpm-live/cpm_live/utils/object.py b/cpm-live/cpm_live/utils/object.py
new file mode 100644
index 0000000..0b32e68
--- /dev/null
+++ b/cpm-live/cpm_live/utils/object.py
@@ -0,0 +1,28 @@
+import bmtrain as bmt
+import pickle
+import torch
+
+
+def allgather_objects(obj):
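+ # pickle the object, pad the byte string to the longest length across ranks, all_gather the bytes, then unpickle one object per rank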
+ if bmt.world_size() == 1:
+ return [obj]
+
+ with torch.no_grad():
+ data_bytes: bytes = pickle.dumps(obj)
+ data_length: int = len(data_bytes)
+
+ gpu_data_length = torch.tensor([data_length], device="cuda", dtype=torch.long)
+ gathered_length = bmt.distributed.all_gather(gpu_data_length).view(-1).cpu()
+ max_data_length = gathered_length.max().item()
+
+ gpu_data_bytes = torch.zeros(max_data_length, dtype=torch.uint8, device="cuda")
+ byte_storage = torch.ByteStorage.from_buffer(data_bytes)
+ gpu_data_bytes[:data_length] = torch.ByteTensor(byte_storage)
+
+ gathered_data = bmt.distributed.all_gather(gpu_data_bytes).cpu()
+
+ ret = []
+ for i in range(gathered_data.size(0)):
+ data_bytes = gathered_data[i, : gathered_length[i].item()].numpy().tobytes()
+ ret.append(pickle.loads(data_bytes))
+ return ret
diff --git a/cpm-live/cpm_live/vocabs/bee.txt b/cpm-live/cpm_live/vocabs/bee.txt
new file mode 100644
index 0000000..b977bda
--- /dev/null
+++ b/cpm-live/cpm_live/vocabs/bee.txt
@@ -0,0 +1,86583 @@
+
+
+
+
+
+
+
+
+
+
+