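DJL (Deep Java Library) is an open-source deep learning framework released by AWS. DJL supports several deep learning backends, including but not limited to:
- MXNet: an open-source deep learning framework supported by the Apache Software Foundation.
- PyTorch: a widely used open-source machine learning library developed by Facebook's AI Research team.
- TensorFlow: another popular open-source machine learning framework, developed by Google.
DJL integrates closely with the Java ecosystem and works alongside Java frameworks such as Spring Boot and Quarkus.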
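Maven dependencies (pom.xml):
```xml
<dependencies>
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
        <version>0.28.0</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-engine</artifactId>
        <version>0.28.0</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-model-zoo</artifactId>
        <version>0.28.0</version>
    </dependency>
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>basicdataset</artifactId>
        <version>0.28.0</version>
    </dependency>
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>model-zoo</artifactId>
        <version>0.28.0</version>
    </dependency>
</dependencies>
```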
Java DJL architecture diagram
```
┌──────────────────────────────┐
│           ModelZoo           │
├──────────────────────────────┤
│            Model             │
└───────────────┬──────────────┘
                │
      ┌─────────▼─────────┐
      │      Engine       │
      └─────────┬─────────┘
                │
      ┌─────────▼─────────┐
      │     NDManager     │
      └─────────┬─────────┘
                │
    ┌───────────▼───────────┐
    │        Dataset        │
    └───────────┬───────────┘
                │
    ┌───────────▼───────────┐
    │  Trainer / Predictor  │
    └───────────────────────┘
```
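Main components in detail
1. ModelZoo and Model
- ModelZoo: provides a collection of pre-trained models.
ModelZoo features
- Model discovery and download: a model is located by criteria such as its application, engine, and filter properties, and its files are downloaded to a local cache the first time it is loaded.
Example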
```kotlin
import ai.djl.Application
import ai.djl.ModelException
import ai.djl.modality.Classifications
import ai.djl.modality.cv.Image
import ai.djl.repository.zoo.Criteria
import ai.djl.repository.zoo.ModelZoo
import ai.djl.repository.zoo.ZooModel
import ai.djl.translate.TranslateException
import java.io.IOException

object ModelZooExample {
    @Throws(IOException::class, ModelException::class, TranslateException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Define the criteria the model must satisfy
        val criteria: Criteria<Image, Classifications> = Criteria.builder()
            .optApplication(Application.CV.IMAGE_CLASSIFICATION) // task: image classification
            .setTypes(Image::class.java, Classifications::class.java) // input and output types
            .optFilter("backbone", "resnet50") // must match the properties of models in the zoo
            .build()

        // Load a matching model from the ModelZoo; use {} closes it afterwards
        ModelZoo.loadModel(criteria).use { model: ZooModel<Image, Classifications> ->
            // Run inference with the model
            // ...
        }
    }
}
```
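ModelZoo classes and interfaces
- ModelZoo: the core class, which provides model download and loading.
- Criteria: defines the criteria and filters for loading a model, such as the application, the input and output types, and the engine.
- ModelLoader: performs the actual download and loading of a model.
- Model: the interface that represents a deep learning model. It covers loading and saving the model and creating the Predictor and Trainer instances that run it.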
ai.djl.repository.zoo.ModelZoo Class
Key Methods:
- static <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria): Loads a model that matches the given criteria. Throws ModelNotFoundException if no registered zoo provides one.
- static ModelZoo getModelZoo(String groupId): Returns the model zoo registered under the given group ID, for example ai.djl.pytorch.
- listModels(): Lists the models available in all registered zoos, grouped by Application.
ai.djl.repository.Artifact Class
An Artifact describes one model in a zoo: its name, its version, and its properties. The properties are the key/value pairs that Criteria filters are matched against.
Key Methods:
- String getName(): Returns the name of the model.
- String getVersion(): Returns the version of the model.
- Map<String, String> getProperties(): Returns the properties used for filtering.
Input and output shapes are not part of this metadata. Once a model is loaded, Model.describeInput() and Model.describeOutput() report them if the engine provides that information.
ai.djl.repository.MRL Class
An MRL (machine learning resource locator) identifies where a model lives in a repository, using its application, group ID, and artifact ID.
Key Methods:
- String getGroupId(): Gets the group ID of the model, for example ai.djl.pytorch.
- String getArtifactId(): Gets the artifact ID, which is the model's name within its group.
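ai.djl.Application Class
Application identifies the kind of task a model performs, such as image classification or object detection. It is a class rather than an enum. The predefined values are constants grouped in nested interfaces such as Application.CV and Application.NLP.
Key Values:
- Application.CV.IMAGE_CLASSIFICATION
- Application.CV.OBJECT_DETECTION
- Application.NLP.TEXT_CLASSIFICATION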
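ai.djl.repository.zoo.Criteria Class
Criteria describes the model to load and is used to filter and load models. It is created through Criteria.Builder.
Key Methods:
- static Builder<?, ?> builder(): Creates a new builder instance.
- <P, Q> Builder<P, Q> setTypes(Class<P> input, Class<Q> output): Sets the input and output types of the model.
- Builder<I, O> optApplication(Application application): Sets the application type.
- Builder<I, O> optEngine(String engine): Specifies the engine to use (e.g. "MXNet", "PyTorch").
- Builder<I, O> optFilter(String key, String value): Keeps only models whose property with this key has this value.
- Builder<I, O> optModelPath(Path modelPath): Loads the model from a local path instead of a remote zoo.
- Criteria<I, O> build(): Creates the Criteria object.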
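Conceptually, loading by criteria means filtering a catalog of model metadata by application, engine, and property filters. The plain-Java sketch below shows only that idea, under simplifying assumptions. ZooEntry, SimpleCriteria, and the sample entries are hypothetical. DJL's real ModelZoo and ModelLoader also take input/output types and translators into account, so their matching logic is more involved.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical metadata for one model in a zoo (not a DJL class).
final class ZooEntry {
    final String name;
    final String application; // e.g. "cv/image_classification"
    final String engine;      // e.g. "PyTorch"
    final Map<String, String> properties;

    ZooEntry(String name, String application, String engine, Map<String, String> properties) {
        this.name = name;
        this.application = application;
        this.engine = engine;
        this.properties = properties;
    }
}

// Hypothetical criteria: an unset application or engine matches anything.
final class SimpleCriteria {
    private String application;
    private String engine;
    private final Map<String, String> filters = new HashMap<>();

    SimpleCriteria optApplication(String application) {
        this.application = application;
        return this;
    }

    SimpleCriteria optEngine(String engine) {
        this.engine = engine;
        return this;
    }

    SimpleCriteria optFilter(String key, String value) {
        filters.put(key, value);
        return this;
    }

    boolean matches(ZooEntry entry) {
        if (application != null && !application.equals(entry.application)) {
            return false;
        }
        if (engine != null && !engine.equals(entry.engine)) {
            return false;
        }
        // Every filter must equal the entry's property with the same key.
        for (Map.Entry<String, String> f : filters.entrySet()) {
            if (!f.getValue().equals(entry.properties.get(f.getKey()))) {
                return false;
            }
        }
        return true;
    }

    List<ZooEntry> select(List<ZooEntry> zoo) {
        List<ZooEntry> result = new ArrayList<>();
        for (ZooEntry entry : zoo) {
            if (matches(entry)) {
                result.add(entry);
            }
        }
        return result;
    }
}

final class CriteriaSketch {
    // A made-up zoo; names and properties are illustrative only.
    static List<ZooEntry> sampleZoo() {
        return List.of(
                new ZooEntry("resnet18", "cv/image_classification", "PyTorch", Map.of("layers", "18")),
                new ZooEntry("resnet50", "cv/image_classification", "PyTorch", Map.of("layers", "50")),
                new ZooEntry("resnet50-mx", "cv/image_classification", "MXNet", Map.of("layers", "50")),
                new ZooEntry("ssd", "cv/object_detection", "PyTorch", Map.of("backbone", "resnet50")));
    }

    public static void main(String[] args) {
        List<ZooEntry> hits = new SimpleCriteria()
                .optApplication("cv/image_classification")
                .optEngine("PyTorch")
                .optFilter("layers", "50")
                .select(sampleZoo());
        for (ZooEntry e : hits) {
            System.out.println(e.name);
        }
    }
}
```

Running CriteriaSketch prints resnet50, the only PyTorch image-classification entry whose layers property is 50. A filter key that no entry carries matches nothing. This is why the filter values in a real Criteria have to agree with the zoo's metadata.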
Example
```kotlin
import ai.djl.Model
import ai.djl.ModelException
import ai.djl.modality.Classifications
import ai.djl.modality.cv.Image
import ai.djl.modality.cv.ImageFactory
import ai.djl.modality.cv.util.NDImageUtils
import ai.djl.ndarray.NDList
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths

object DjlExample {
    @JvmStatic
    fun main(args: Array<String>) {
        // Model directory and model name (for PyTorch, e.g. models/resnet18.pt)
        val modelDir = Paths.get("models")
        val modelName = "resnet18"
        try {
            Model.newInstance(modelName).use { model ->
                // Load the model
                model.load(modelDir)
                // Load the input image
                val img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"))
                // Create a predictor (closed after use) and run inference
                model.newPredictor(MyTranslator()).use { predictor ->
                    val result = predictor.predict(img)
                    println(result)
                }
            }
        } catch (e: IOException) {
            e.printStackTrace()
        } catch (e: ModelException) {
            e.printStackTrace()
        } catch (e: TranslateException) {
            e.printStackTrace()
        }
    }

    // Custom Translator: Image -> model input, model output -> Classifications
    private class MyTranslator : Translator<Image, Classifications> {
        override fun processInput(ctx: TranslatorContext, input: Image): NDList {
            // HWC uint8 image -> 224x224 -> CHW float32 in [0, 1]; adjust to what the model expects
            var array = input.toNDArray(ctx.ndManager)
            array = NDImageUtils.resize(array, 224, 224)
            array = NDImageUtils.toTensor(array)
            return NDList(array)
        }

        override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {
            // The default batchifier removes the batch dimension, so the output is 1-D: softmax over axis 0
            val probabilities = list.singletonOrThrow().softmax(0)
            // Placeholder labels, one per output class
            val labels = List(probabilities.size().toInt()) { "class$it" }
            return Classifications(labels, probabilities)
        }
    }
}
```
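2. Dataset
Common dataset types:
- RandomAccessDataset:
  - A basic dataset type for data that can be accessed at random, such as arrays or lists.
  - It supports batching and slicing and suits most supervised learning tasks.
- IterableDataset:
  - For data that cannot be accessed at random, such as streams or data generated on the fly.
  - It delivers data through an iterator, which suits data sources that are generated or processed dynamically.
- RecordDataset:
  - A dataset format based on record files, commonly used for large-scale data processing.
  - It loads and processes data records efficiently, which suits distributed training and large datasets.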
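The practical difference between the iterable and random-access styles shows up in how batches are formed. For an iterable source, a batch is simply the next few items in arrival order. That approach also works for an unbounded stream, but it cannot shuffle the whole dataset. Below is a minimal plain-Java sketch. StreamBatcher is a hypothetical class, not a DJL type.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.stream.Stream;

// Groups consecutive items from a source iterator into batches; the source may be unbounded.
final class StreamBatcher<T> implements Iterator<List<T>> {
    private final Iterator<T> source;
    private final int batchSize;

    StreamBatcher(Iterator<T> source, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        this.source = source;
        this.batchSize = batchSize;
    }

    @Override
    public boolean hasNext() {
        return source.hasNext();
    }

    @Override
    public List<T> next() {
        if (!source.hasNext()) {
            throw new NoSuchElementException();
        }
        List<T> batch = new ArrayList<>(batchSize);
        // Items are taken in arrival order; without an index there is no global shuffle.
        while (batch.size() < batchSize && source.hasNext()) {
            batch.add(source.next());
        }
        return batch; // the final batch of a finite source may be shorter
    }
}

final class StreamBatcherDemo {
    public static void main(String[] args) {
        // An endless stream of readings: 0, 1, 2, ...
        Iterator<Integer> readings = Stream.iterate(0, i -> i + 1).iterator();
        StreamBatcher<Integer> batches = new StreamBatcher<>(readings, 3);
        for (int i = 0; i < 2; i++) {
            System.out.println(batches.next()); // [0, 1, 2] then [3, 4, 5]
        }
    }
}
```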
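DJL's dataset package is built around two types:
- Dataset: the interface that every dataset implements. Its getData(NDManager) method returns the batches to iterate over.
- RandomAccessDataset: an abstract class that implements Dataset. To write a custom dataset, subclass it and implement get(NDManager, long index), which returns one Record, and availableSize().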
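Random access makes shuffling cheap: because any item can be fetched by index, shuffling the data reduces to shuffling a list of indices. The sketch below assumes a hypothetical base class, SketchRandomAccessDataset, whose get(index) and availableSize() loosely mirror DJL's RandomAccessDataset contract. DJL's own get also takes an NDManager and returns a Record of NDArrays. With this design, a custom dataset implements only the two abstract methods and inherits the batching.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical base class with a get(index) / availableSize() contract (not DJL's RandomAccessDataset).
abstract class SketchRandomAccessDataset<T> {
    abstract T get(long index);

    abstract long availableSize();

    // Shuffle the indices with a fixed seed, then fetch the items of each batch by index.
    List<List<T>> shuffledBatches(int batchSize, long seed) {
        List<Long> indices = new ArrayList<>();
        for (long i = 0; i < availableSize(); i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random(seed));
        List<List<T>> batches = new ArrayList<>();
        for (int start = 0; start < indices.size(); start += batchSize) {
            int end = Math.min(start + batchSize, indices.size());
            List<T> batch = new ArrayList<>();
            for (int j = start; j < end; j++) {
                batch.add(get(indices.get(j)));
            }
            batches.add(batch);
        }
        return batches;
    }
}

// A custom dataset only implements the two abstract methods.
final class InMemoryDataset extends SketchRandomAccessDataset<float[]> {
    private final float[][] rows;

    InMemoryDataset(float[][] rows) {
        this.rows = rows;
    }

    @Override
    float[] get(long index) {
        return rows[(int) index];
    }

    @Override
    long availableSize() {
        return rows.length;
    }
}

final class InMemoryDatasetDemo {
    public static void main(String[] args) {
        float[][] rows = new float[10][];
        for (int i = 0; i < rows.length; i++) {
            rows[i] = new float[] {i, i * i};
        }
        for (List<float[]> batch : new InMemoryDataset(rows).shuffledBatches(4, 42L)) {
            System.out.println("batch of " + batch.size()); // batch sizes: 4, 4, 2
        }
    }
}
```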
Example: inference with a model loaded from a local path, using the TensorFlow engine:
```java
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Paths;

public class DjlExample {
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        // Describe the model; "TensorFlow" needs the tensorflow-engine dependency on the classpath
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .optEngine("TensorFlow") // choose the engine
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get("path/to/model"))
                // if no translator is found for Image -> Classifications, set one with optTranslator(...)
                .build();

        // ZooModel.newPredictor() uses the translator resolved from the criteria
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            // Load the image
            Image img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"));
            // Run inference
            Classifications result = predictor.predict(img);
            System.out.println(result);
        }
    }
}
```
Example: training on the FashionMNIST dataset from basicdataset:
```java
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.FashionMnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.TranslateException;
import java.io.IOException;

public class DJLDatasetExample {
    public static void main(String[] args) throws IOException, TranslateException {
        FashionMnist fashionMnist = FashionMnist.builder()
                .optUsage(Dataset.Usage.TRAIN)
                .setSampling(32, true) // batch size 32, shuffled
                .optLimit(Long.MAX_VALUE) // lower this to train on a subset
                .build();
        fashionMnist.prepare();

        try (Model model = Model.newInstance("fashion-mnist-model")) {
            // A model needs a Block before it can be trained: a small MLP for 28x28 images, 10 classes
            model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));

            TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                    .optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build())
                    .addTrainingListeners(TrainingListener.Defaults.logging());

            try (Trainer trainer = model.newTrainer(config)) {
                // initialize() takes Shapes; the MLP flattens each image to 28 * 28 values
                trainer.initialize(new Shape(1, 28 * 28));
                trainer.setMetrics(new Metrics());

                // One epoch: forward and backward pass per batch, then an optimizer step
                for (Batch batch : trainer.iterateDataset(fashionMnist)) {
                    EasyTrain.trainBatch(trainer, batch);
                    trainer.step();
                    batch.close();
                }
                trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer));
            }
        }
    }
}
```
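3. Engine and NDManager
- Engine: DJL supports several deep learning engines, such as MXNet, PyTorch, ONNX Runtime, and TensorFlow. The Engine abstraction gives them a uniform interface, so the underlying engine can be switched easily. The available engines are the engine artifacts on the classpath. Engine.getInstance() returns the default engine, and Engine.getEngine(name) returns a specific one.
- NDManager: manages NDArrays (multi-dimensional arrays) and wraps the underlying array operations.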
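The point of the Engine abstraction is that calling code depends on a single interface while the implementation behind it can change. The plain-Java sketch below shows that pattern. The names (Backend, BackendRegistry) are hypothetical, and there is only one toy operation. DJL itself finds engine implementations among the engine artifacts on the classpath, rather than through explicit registration as here.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.IntStream;

// Hypothetical operation set that every backend implements (DJL's real Engine API is far larger).
interface Backend {
    String name();

    float[] add(float[] a, float[] b);
}

// A backend written with a plain loop.
final class LoopBackend implements Backend {
    public String name() {
        return "loop";
    }

    public float[] add(float[] a, float[] b) {
        checkSameLength(a, b);
        float[] out = new float[a.length];
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i] + b[i];
        }
        return out;
    }

    static void checkSameLength(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("length mismatch");
        }
    }
}

// A second backend: a different implementation of the same contract.
final class StreamBackend implements Backend {
    public String name() {
        return "stream";
    }

    public float[] add(float[] a, float[] b) {
        LoopBackend.checkSameLength(a, b);
        float[] out = new float[a.length];
        IntStream.range(0, a.length).forEach(i -> out[i] = a[i] + b[i]);
        return out;
    }
}

// Callers look backends up by name or take the default, so switching backends
// does not change the calling code.
final class BackendRegistry {
    private final Map<String, Backend> backends = new LinkedHashMap<>();

    void register(Backend backend) {
        backends.put(backend.name(), backend);
    }

    Backend get(String name) {
        Backend backend = backends.get(name);
        if (backend == null) {
            throw new IllegalArgumentException("No backend named " + name);
        }
        return backend;
    }

    // The first registered backend acts as the default.
    Backend getDefault() {
        if (backends.isEmpty()) {
            throw new IllegalStateException("No backend registered");
        }
        return backends.values().iterator().next();
    }
}
```

Code that receives a Backend from registry.get(...) or registry.getDefault() runs unchanged whichever implementation it gets. Engine provides the same property in DJL.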
Using DJL Engine
```kotlin
import ai.djl.Model
import ai.djl.ModelException
import ai.djl.engine.Engine
import ai.djl.ndarray.NDList
import ai.djl.ndarray.types.Shape
import ai.djl.translate.Batchifier
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths

object DJLEngineExample {
    @Throws(ModelException::class, TranslateException::class, IOException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        println("Default engine: ${Engine.getInstance().engineName}")
        // The second argument is the engine name ("PyTorch"), not a package name
        Model.newInstance("model-name", "PyTorch").use { model ->
            // Load a pre-trained model; make sure the path is correct
            model.load(Paths.get("path/to/your/model"))

            // Translator for pre- and post-processing
            val translator = object : Translator<Array<Float>, Float> {
                override fun processInput(ctx: TranslatorContext, input: Array<Float>): NDList {
                    // No batchifier is used, so the (1, n) batch shape is created here
                    val array = ctx.ndManager.create(input.toFloatArray())
                        .reshape(Shape(1, input.size.toLong()))
                    return NDList(array)
                }

                override fun processOutput(ctx: TranslatorContext, list: NDList): Float {
                    // Assumes the model returns a single scalar value
                    return list[0].getFloat()
                }

                // Returning null disables batching
                override fun getBatchifier(): Batchifier? = null
            }

            model.newPredictor(translator).use { predictor ->
                val input = arrayOf(1.0f, 2.0f, 3.0f) // must match the model's expected input size
                val output = predictor.predict(input)
                println("Prediction: $output")
            }
        }
    }
}
```
Overview of NDManager
NDArrays live in native memory owned by the engine, not on the JVM heap, so the garbage collector cannot reclaim them promptly. NDManager makes their lifetime explicit.
Key Features of NDManager:
- Memory Management: Allocates NDArrays and frees their native memory when they are no longer needed.
- Resource Scope: NDArrays created by an NDManager are tied to that manager's lifecycle. When the manager is closed, all associated NDArrays are released as well.
- Hierarchical Structure: An NDManager can create child managers with newSubManager(), and each child manages its own NDArrays. This is useful for scoping temporary arrays in complex workflows.
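The ownership model described above can be imitated in plain Java to show what closing a manager does. In this sketch, NativeBuffer stands in for an NDArray and ScopeManager for an NDManager. Both are hypothetical. No memory is actually freed (a flag is set instead), and DJL's real implementation differs in detail.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for an array held in native memory.
final class NativeBuffer implements AutoCloseable {
    private final float[] data;
    private boolean released;

    NativeBuffer(float[] data) {
        this.data = data.clone();
    }

    float get(int i) {
        if (released) {
            throw new IllegalStateException("buffer already released");
        }
        return data[i];
    }

    boolean isReleased() {
        return released;
    }

    @Override
    public void close() {
        released = true; // a real implementation would free the native memory here
    }
}

// Hypothetical manager: owns buffers and sub-managers; closing it releases everything it owns.
final class ScopeManager implements AutoCloseable {
    private final List<Runnable> releaseActions = new ArrayList<>();
    private boolean closed;

    NativeBuffer create(float... values) {
        ensureOpen();
        NativeBuffer buffer = new NativeBuffer(values);
        releaseActions.add(buffer::close);
        return buffer;
    }

    ScopeManager newSubManager() {
        ensureOpen();
        ScopeManager child = new ScopeManager();
        releaseActions.add(child::close);
        return child;
    }

    boolean isClosed() {
        return closed;
    }

    private void ensureOpen() {
        if (closed) {
            throw new IllegalStateException("manager is closed");
        }
    }

    @Override
    public void close() {
        if (closed) {
            return; // closing twice is a no-op
        }
        closed = true;
        // Release owned resources in reverse creation order.
        for (int i = releaseActions.size() - 1; i >= 0; i--) {
            releaseActions.get(i).run();
        }
        releaseActions.clear();
    }
}
```

Closing a sub-manager frees only its own buffers. Closing the root frees everything, including buffers of sub-managers that were never closed explicitly. Wrapping the root in try-with-resources, or Kotlin's use { } as in the NDManager example, makes sure close() runs even if an exception is thrown.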
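Using NDManager
```kotlin
import ai.djl.ndarray.NDManager

object NDManagerExample {
    @JvmStatic
    fun main(args: Array<String>) {
        NDManager.newBaseManager().use { manager ->
            val array = manager.create(floatArrayOf(1.0f, 2.0f, 3.0f))
            println("Array: $array")
            // Perform operations; results are new NDArrays attached to the same manager
            val result = array.add(2.0f)
            println("Result: $result")
            // A sub-manager scopes temporary arrays; closing it frees only those
            manager.newSubManager().use { temp ->
                val squared = temp.create(floatArrayOf(1.0f, 2.0f, 3.0f)).square()
                println("Squared: $squared")
            }
        }
        // No need to free memory explicitly: closing the manager released it
    }
}
```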
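4. Trainer and Predictor
Trainer class
Trainer provides the interface for training deep learning models and brings together the optimizer, the loss function, and the training loop. It encapsulates the common steps of training: the forward pass, the backward pass, and the parameter update.
Main features:
- Training and validating models
- Managing the optimizer and the loss function
- An easy-to-use training loop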
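To see what Trainer automates, here is one training loop written out by hand for the simplest possible model: a line y = w * x + b fitted with mini-batch SGD on mean squared error. This is a plain-Java sketch with hand-derived gradients, not DJL code. In DJL, EasyTrain.trainBatch runs the forward pass and computes gradients through the engine's autograd, and trainer.step() applies the optimizer update.

```java
import java.util.Arrays;

// Fits y = w * x + b with mini-batch SGD on mean squared error; gradients are derived by hand.
final class SgdSketch {
    double w = 0.0;
    double b = 0.0;

    // One training step: forward pass, loss, backward pass (gradients), parameter update.
    double trainBatch(double[] xs, double[] ys, double learningRate) {
        int n = xs.length;
        double loss = 0.0;
        double gradW = 0.0;
        double gradB = 0.0;
        for (int i = 0; i < n; i++) {
            double pred = w * xs[i] + b;  // forward pass
            double err = pred - ys[i];
            loss += err * err / n;        // mean squared error
            gradW += 2 * err * xs[i] / n; // dLoss/dw
            gradB += 2 * err / n;         // dLoss/db
        }
        w -= learningRate * gradW;        // the "step": plain SGD update
        b -= learningRate * gradB;
        return loss;
    }

    // Trains on 20 noise-free points from y = 2x + 1 and returns the fitted model.
    static SgdSketch fit(int epochs, int batchSize, double learningRate) {
        double[] xs = new double[20];
        double[] ys = new double[20];
        for (int i = 0; i < xs.length; i++) {
            xs[i] = i / 10.0;
            ys[i] = 2 * xs[i] + 1;
        }
        SgdSketch model = new SgdSketch();
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int start = 0; start < xs.length; start += batchSize) {
                int end = Math.min(start + batchSize, xs.length);
                model.trainBatch(Arrays.copyOfRange(xs, start, end),
                        Arrays.copyOfRange(ys, start, end), learningRate);
            }
        }
        return model;
    }

    public static void main(String[] args) {
        SgdSketch model = fit(500, 5, 0.1);
        System.out.printf("w = %.4f, b = %.4f%n", model.w, model.b);
    }
}
```

The loop has the same shape as a DJL training loop: iterate over batches, compute the loss and gradients, then step. The difference is that the gradients here are written by hand instead of being recorded by an autograd engine.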
Code example
The following example trains a simple neural network with DJL's Trainer class:
```kotlin
import ai.djl.Model
import ai.djl.basicdataset.cv.classification.FashionMnist
import ai.djl.basicmodelzoo.basic.Mlp
import ai.djl.metric.Metrics
import ai.djl.ndarray.types.Shape
import ai.djl.training.DefaultTrainingConfig
import ai.djl.training.EasyTrain
import ai.djl.training.TrainingConfig
import ai.djl.training.dataset.Dataset
import ai.djl.training.dataset.RandomAccessDataset
import ai.djl.training.listener.TrainingListener
import ai.djl.training.loss.Loss
import ai.djl.training.optimizer.Optimizer
import ai.djl.training.tracker.Tracker
import ai.djl.training.util.ProgressBar
import ai.djl.translate.TranslateException
import java.io.IOException
import java.nio.file.Paths

object DjlTrainerDemo {
    @Throws(IOException::class, TranslateException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Load the dataset: batches of 32, shuffled
        val trainDataset: RandomAccessDataset = FashionMnist.builder()
            .optUsage(Dataset.Usage.TRAIN)
            .setSampling(32, true)
            .build()
        trainDataset.prepare(ProgressBar())

        Model.newInstance("mlp").use { model ->
            // Define the model: 784 inputs, hidden layers of 128 and 64, 10 outputs
            model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))

            // Training configuration: loss, SGD with a fixed learning rate, logging listeners
            val config: TrainingConfig = DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.01f)).build())
                .addTrainingListeners(*TrainingListener.Defaults.logging())

            model.newTrainer(config).use { trainer ->
                trainer.initialize(Shape(1, (28 * 28).toLong()))
                trainer.metrics = Metrics()
                for (epoch in 0 until 10) {
                    for (batch in trainer.iterateDataset(trainDataset)) {
                        // Forward pass, loss, and backward pass, then the parameter update
                        EasyTrain.trainBatch(trainer, batch)
                        trainer.step()
                        batch.close()
                    }
                    trainer.notifyListeners { listener -> listener.onEpoch(trainer) }
                }
                model.save(Paths.get("model"), "mlp")
            }
        }
    }
}
```
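Predictor class
Predictor runs inference with a trained model. It offers a simple interface: pass in the input data, and it runs the model and returns the prediction.
Main features:
- Loading a model for inference
- Converting input and output data (done by a Translator)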
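A Predictor is essentially a pipeline: preprocessing, then the model's forward pass, then postprocessing, with the pre- and postprocessing supplied by a Translator. The plain-Java sketch below reproduces that pipeline using hypothetical types (SketchTranslator, SketchPredictor, TopLabelTranslator) and a lambda standing in for the network. DJL's real Translator works on NDList and TranslatorContext, and a Batchifier sits between the translator and the model.

```java
import java.util.List;
import java.util.function.Function;

// Hypothetical translator contract: prepare the model input, interpret the raw output.
interface SketchTranslator<I, O> {
    float[] processInput(I input);

    O processOutput(float[] rawOutput);
}

// Hypothetical predictor: translator -> model -> translator.
final class SketchPredictor<I, O> {
    private final Function<float[], float[]> model; // stands in for the network's forward pass
    private final SketchTranslator<I, O> translator;

    SketchPredictor(Function<float[], float[]> model, SketchTranslator<I, O> translator) {
        this.model = model;
        this.translator = translator;
    }

    O predict(I input) {
        float[] modelInput = translator.processInput(input);
        float[] rawOutput = model.apply(modelInput);
        return translator.processOutput(rawOutput);
    }
}

// Scales 0-255 pixel values to [0, 1] and maps logits to the most probable label.
final class TopLabelTranslator implements SketchTranslator<float[], String> {
    private final List<String> labels;

    TopLabelTranslator(List<String> labels) {
        this.labels = labels;
    }

    @Override
    public float[] processInput(float[] pixels) {
        float[] scaled = new float[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            scaled[i] = pixels[i] / 255f;
        }
        return scaled;
    }

    @Override
    public String processOutput(float[] logits) {
        double[] probabilities = softmax(logits);
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return labels.get(best);
    }

    // Softmax with the maximum subtracted first, which avoids overflow in exp().
    static double[] softmax(float[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        double[] p = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }
}
```

Because the translator is separate from the model, the same model can serve different input and output types simply by swapping translators. DJL's Predictor follows the same design.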
Code example
```kotlin
import ai.djl.MalformedModelException
import ai.djl.Model
import ai.djl.basicmodelzoo.basic.Mlp
import ai.djl.modality.Classifications
import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDList
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.Shape
import ai.djl.translate.Batchifier
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths

object DjlPredictorDemo {
    // Fashion-MNIST class names, in label order
    private val LABELS = listOf(
        "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
        "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
    )

    @Throws(IOException::class, TranslateException::class, MalformedModelException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        Model.newInstance("mlp").use { model ->
            // Parameters saved by DJL need the same Block to be set before loading
            model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))
            model.load(Paths.get("model"), "mlp")

            // Define the Translator
            val translator = object : Translator<NDArray, Classifications> {
                override fun processInput(ctx: TranslatorContext, input: NDArray): NDList {
                    // One flattened 28x28 image; the STACK batchifier adds the batch dimension
                    return NDList(input.reshape(Shape(28L * 28)))
                }

                override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {
                    // The network outputs logits; softmax turns them into probabilities
                    val probabilities = list.singletonOrThrow().softmax(0)
                    return Classifications(LABELS, probabilities)
                }

                override fun getBatchifier(): Batchifier = Batchifier.STACK
            }

            model.newPredictor(translator).use { predictor ->
                NDManager.newBaseManager().use { manager ->
                    // Dummy input: an all-ones 28x28 image
                    val image = manager.ones(Shape(28, 28))
                    val classifications = predictor.predict(image)
                    println(classifications)
                }
            }
        }
    }
}
```