如何解决如何在 Python 中使用通用映射
我想构建一个(抽象的)数据集,它是通用的并且只提供加载文件的框架。
然后子类用于特定类型(这里打扰示例,注释属于 np.ndarray
类型
当我实例化一个 ImageDataset
类型的对象时,我得到
File "/home/maximilian/darts/tests/test_dataset.py",line 12,in test_simple_loading
dataset = ImageDataset(dataset_root)
File "/home/maximilian/darts/darts/datasets.py",line 72,in __init__
super(ImageDataset,self).__init__(path)
File "/home/maximilian/darts/darts/datasets.py",line 19,in __init__
self.__load_data(path)
File "/home/maximilian/darts/darts/datasets.py",line 48,in __load_data
self.annotations.update(annotations)
AttributeError: 'ImageDataset' object has no attribute 'annotations'
谁能告诉我我在这里做错了什么?
from collections import defaultdict
from abc import abstractmethod
from itertools import tee
from pathlib import Path
from typing import Iterator,TypeVar,Tuple,Dict,Mapping
import numpy as np
from cv2 import haveImageReader,imread
Key = str
Annotation = TypeVar('Annotation')
Sample = TypeVar('Sample')
AnnotatedSample = Tuple[Sample,Annotation]
class Dataset(Mapping[Key,AnnotatedSample]):
def __init__(self,path: Path):
self.__path = path
self.__load_data(path)
self.annotations: Dict[str,Annotation] = defaultdict(lambda: None)
self.samples: Dict[str,Sample] = defaultdict(lambda: None)
@abstractmethod
def _is_sample_file(self,file : Path) -> bool:
raise NotImplementedError()
@abstractmethod
def _is_annotation_file(self,file : Path) -> bool:
raise NotImplementedError()
@abstractmethod
def _load_annotation(self,file: Path) -> Annotation:
raise NotImplementedError()
@abstractmethod
def _load_sample(self,file: Path) -> Sample:
raise NotImplementedError()
def __load_data(self,path: Path):
files = filter(lambda file: not file.is_dir(),path.glob('*'))
it1,it2 = tee(files)
annotations_files = filter(lambda file: self._is_annotation_file(file),it1)
sample_files = filter(lambda file: self._is_sample_file(file),it2)
annotations = map(lambda file: (file.stem,self._load_annotation(file)),annotations_files)
samples = map(lambda file: (file.stem,self._load_sample(file)),sample_files)
self.annotations.update(annotations)
self.samples.update(samples)
annotation_keys = set(self.annotations)
samples_keys = set(self.samples)
annotations_without_sample = annotation_keys.difference(samples_keys)
if annotations_without_sample:
raise ValueError(
f"For each annotation a sample file must be given. Annotation without sample {annotations_without_sample} ")
def __getitem__(self,k: Key) -> AnnotatedSample:
return self.samples[k],self.annotations
def __len__(self) -> int:
return len(self.samples)
def __iter__(self) -> Iterator[Key]:
return self.samples.keys()
class ImageDataset(Dataset[np.ndarray,np.ndarray]):
ANNOTATION_EXTENSIONS = ['.npy']
def __init__(self,path : Path,transformations = []):
super(ImageDataset,self).__init__(path)
self.__transformations = transformations
def _is_annotation_file(self,file: Path) -> bool:
return haveImageReader(str(file))
def _is_sample_file(self,file: Path) -> bool:
return file.stem in ImageDataset.ANNOTATION_EXTENSIONS
def _load_annotation(self,file: Path) -> Annotation:
return np.load(str(file))
def _load_sample(self,file: Path) -> Sample:
image = imread(str(file))
for transform in self.__transformations:
image = transform(image)
return image
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。