1
use std::collections::BTreeMap;
2

            
3
use bonsaidb::core::async_trait::async_trait;
4
use mongodb::bson::oid::ObjectId;
5
use mongodb::bson::{doc, Document};
6
use mongodb::options::{
7
    Acknowledgment, ClientOptions, CreateCollectionOptions, IndexOptions, WriteConcern,
8
};
9
use mongodb::{Client, IndexModel};
10
use serde::{Deserialize, Serialize};
11

            
12
use crate::execute::{Backend, BackendOperator, Measurements, Metric, Operator};
13
use crate::model::{Cart, Order, Product, ProductReview};
14
use crate::plan::{
15
    AddProductToCart, Checkout, CreateCart, FindProduct, Load, LookupProduct, OperationResult,
16
    ReviewProduct,
17
};
18

            
19
pub struct MongoBackend {
20
    client: mongodb::Client,
21
}
22

            
23
pub struct MongoOperator {
24
    client: mongodb::Client,
25
}
26

            
27
#[async_trait]
28
impl Backend for MongoBackend {
29
    type Config = String;
30
    type Operator = MongoOperator;
31

            
32
    fn label(&self) -> &'static str {
33
        "mongodb"
34
    }
35

            
36
    async fn new(url: Self::Config) -> Self {
37
        async fn create_collection(db: &mongodb::Database, name: &str) {
38
            db.create_collection(
39
                name,
40
                CreateCollectionOptions::builder()
41
                    .write_concern(
42
                        WriteConcern::builder()
43
                            .w(Acknowledgment::Majority)
44
                            .journal(true)
45
                            .build(),
46
                    )
47
                    .build(),
48
            )
49
            .await
50
            .unwrap();
51
        }
52
        let mut client_options = ClientOptions::parse(&url).await.unwrap();
53
        client_options.retry_writes = Some(false);
54
        client_options.app_name = Some("Commerce Bench".to_string());
55
        let client = Client::with_options(client_options).unwrap();
56
        // Clean up any old data
57
        client.database("commerce").drop(None).await.unwrap();
58

            
59
        let db = client.database("commerce");
60
        create_collection(&db, "categories").await;
61
        create_collection(&db, "products").await;
62
        create_collection(&db, "customers").await;
63
        create_collection(&db, "orders").await;
64
        create_collection(&db, "reviews").await;
65
        create_collection(&db, "carts").await;
66

            
67
        // Unique index of product by name
68
        db.collection::<MongoDoc<Product>>("products")
69
            .create_index(
70
                IndexModel::builder()
71
                    .keys(doc! { "name": 1 })
72
                    .options(IndexOptions::builder().unique(true).build())
73
                    .build(),
74
                None,
75
            )
76
            .await
77
            .unwrap();
78

            
79
        // unique index of product reviews, first by product id and then by
80
        // customer id. This index is used both for grouping in aggregation as
81
        // well as upserting reviews.
82
        db.collection::<ProductReview>("reviews")
83
            .create_index(
84
                IndexModel::builder()
85
                    .keys(doc! { "product_id": 1, "customer_id": 1 })
86
                    .options(IndexOptions::builder().unique(true).build())
87
                    .build(),
88
                None,
89
            )
90
            .await
91
            .unwrap();
92

            
93
        MongoBackend { client }
94
    }
95

            
96
    async fn new_operator_async(&self) -> Self::Operator {
97
        MongoOperator {
98
            client: self.client.clone(),
99
        }
100
    }
101
}
102

            
103
impl BackendOperator for MongoOperator {
104
    type Id = ObjectId;
105
}
106

            
107
#[derive(Serialize, Deserialize)]
108
pub struct MongoDoc<T> {
109
    #[serde(rename = "_id")]
110
    pub id: u32,
111
    #[serde(flatten)]
112
    pub contents: T,
113
}
114

            
115
#[async_trait]
116
impl Operator<Load, ObjectId> for MongoOperator {
117
    async fn operate(
118
        &mut self,
119
        operation: &Load,
120
        _results: &[OperationResult<ObjectId>],
121
        measurements: &Measurements,
122
    ) -> OperationResult<ObjectId> {
123
        async fn insert_mongo_docs<T: Serialize + Clone>(
124
            db: &mongodb::Database,
125
            collection_name: &str,
126
            initial_data: &BTreeMap<u32, T>,
127
        ) {
128
            db.collection::<MongoDoc<T>>(collection_name)
129
                .insert_many(
130
                    initial_data.iter().map(|(id, contents)| MongoDoc {
131
                        id: *id,
132
                        contents: contents.clone(),
133
                    }),
134
                    None,
135
                )
136
                .await
137
                .unwrap();
138
        }
139
        let measurement = measurements.begin("mongodb", Metric::Load);
140
        let db = self.client.database("commerce");
141

            
142
        insert_mongo_docs(&db, "categories", &operation.initial_data.categories).await;
143
        insert_mongo_docs(&db, "products", &operation.initial_data.products).await;
144
        insert_mongo_docs(&db, "customers", &operation.initial_data.customers).await;
145
        insert_mongo_docs(&db, "orders", &operation.initial_data.orders).await;
146

            
147
        db.collection::<ProductReview>("reviews")
148
            .insert_many(operation.initial_data.reviews.iter().cloned(), None)
149
            .await
150
            .unwrap();
151

            
152
        measurement.finish();
153
        OperationResult::Ok
154
    }
155
}
156

            
157
#[async_trait]
158
impl Operator<FindProduct, ObjectId> for MongoOperator {
159
    async fn operate(
160
        &mut self,
161
        operation: &FindProduct,
162
        _results: &[OperationResult<ObjectId>],
163
        measurements: &Measurements,
164
    ) -> OperationResult<ObjectId> {
165
        let measurement = measurements.begin("mongodb", Metric::FindProduct);
166
        let database = self.client.database("commerce");
167
        let product = database
168
            .collection::<MongoDoc<Product>>("products")
169
            .find_one(doc! { "name": &operation.name }, None)
170
            .await
171
            .unwrap()
172
            .unwrap();
173
        let rating = database
174
            .collection::<Document>("ratings-by-product")
175
            .find_one(doc!("_id": product.id), None)
176
            .await
177
            .unwrap()
178
            .and_then(|doc| doc.get_f64("rating").ok().map(|f| f as f32));
179

            
180
        measurement.finish();
181
        OperationResult::Product {
182
            id: product.id,
183
            product: product.contents,
184
            rating,
185
        }
186
    }
187
}
188

            
189
#[async_trait]
190
impl Operator<LookupProduct, ObjectId> for MongoOperator {
191
    async fn operate(
192
        &mut self,
193
        operation: &LookupProduct,
194
        _results: &[OperationResult<ObjectId>],
195
        measurements: &Measurements,
196
    ) -> OperationResult<ObjectId> {
197
        let measurement = measurements.begin("mongodb", Metric::LookupProduct);
198
        let database = self.client.database("commerce");
199
        let product = database
200
            .collection::<MongoDoc<Product>>("products")
201
            .find_one(doc! { "_id": operation.id }, None)
202
            .await
203
            .unwrap()
204
            .unwrap();
205
        let rating = database
206
            .collection::<Document>("ratings-by-product")
207
            .find_one(doc!("_id": product.id), None)
208
            .await
209
            .unwrap()
210
            .and_then(|doc| doc.get_f64("rating").ok().map(|f| f as f32));
211

            
212
        measurement.finish();
213
        OperationResult::Product {
214
            id: product.id,
215
            product: product.contents,
216
            rating,
217
        }
218
    }
219
}
220

            
221
#[async_trait]
222
impl Operator<CreateCart, ObjectId> for MongoOperator {
223
    async fn operate(
224
        &mut self,
225
        _operation: &CreateCart,
226
        _results: &[OperationResult<ObjectId>],
227
        measurements: &Measurements,
228
    ) -> OperationResult<ObjectId> {
229
        let measurement = measurements.begin("mongodb", Metric::CreateCart);
230
        let carts = self.client.database("commerce").collection::<Cart>("carts");
231
        let result = carts.insert_one(Cart::default(), None).await.unwrap();
232

            
233
        measurement.finish();
234
        OperationResult::Cart {
235
            id: result.inserted_id.as_object_id().unwrap(),
236
        }
237
    }
238
}
239

            
240
#[async_trait]
241
impl Operator<AddProductToCart, ObjectId> for MongoOperator {
242
    async fn operate(
243
        &mut self,
244
        operation: &AddProductToCart,
245
        results: &[OperationResult<ObjectId>],
246
        measurements: &Measurements,
247
    ) -> OperationResult<ObjectId> {
248
        let cart_id = match &results[operation.cart.0] {
249
            OperationResult::Cart { id } => *id,
250
            _ => unreachable!("Invalid operation result"),
251
        };
252
        let product = match &results[operation.product.0] {
253
            OperationResult::Product { id, .. } => *id,
254
            _ => unreachable!("Invalid operation result"),
255
        };
256

            
257
        let measurement = measurements.begin("mongodb", Metric::AddProductToCart);
258
        let carts = self.client.database("commerce").collection::<Cart>("carts");
259

            
260
        carts
261
            .update_one(
262
                doc! {"_id": cart_id },
263
                doc! { "$push": { "product_ids": product } },
264
                None,
265
            )
266
            .await
267
            .unwrap();
268
        measurement.finish();
269

            
270
        OperationResult::CartProduct { id: product }
271
    }
272
}
273

            
274
#[async_trait]
275
impl Operator<Checkout, ObjectId> for MongoOperator {
276
    async fn operate(
277
        &mut self,
278
        operation: &Checkout,
279
        results: &[OperationResult<ObjectId>],
280
        measurements: &Measurements,
281
    ) -> OperationResult<ObjectId> {
282
        let cart_id = match &results[operation.cart.0] {
283
            OperationResult::Cart { id } => *id,
284
            _ => unreachable!("Invalid operation result"),
285
        };
286

            
287
        let measurement = measurements.begin("mongodb", Metric::Checkout);
288
        let carts = self.client.database("commerce").collection::<Cart>("carts");
289
        let cart = carts
290
            .find_one_and_delete(doc! { "_id": cart_id }, None)
291
            .await
292
            .unwrap()
293
            .unwrap();
294

            
295
        let orders = self
296
            .client
297
            .database("commerce")
298
            .collection::<Order>("orders");
299

            
300
        orders
301
            .insert_one(
302
                Order {
303
                    customer_id: operation.customer_id,
304
                    product_ids: cart.product_ids,
305
                },
306
                None,
307
            )
308
            .await
309
            .unwrap();
310

            
311
        measurement.finish();
312

            
313
        OperationResult::Ok
314
    }
315
}
316

            
317
#[async_trait]
318
impl Operator<ReviewProduct, ObjectId> for MongoOperator {
319
    async fn operate(
320
        &mut self,
321
        operation: &ReviewProduct,
322
        results: &[OperationResult<ObjectId>],
323
        measurements: &Measurements,
324
    ) -> OperationResult<ObjectId> {
325
        let product_id = match &results[operation.product_id.0] {
326
            OperationResult::Product { id, .. } => *id,
327
            OperationResult::CartProduct { id, .. } => *id,
328
            other => unreachable!("Invalid operation result {:?}", other),
329
        };
330

            
331
        let measurement = measurements.begin("mongodb", Metric::RateProduct);
332
        let reviews = self
333
            .client
334
            .database("commerce")
335
            .collection::<ProductReview>("reviews");
336

            
337
        // Upsert the review for the customer.
338
        reviews
339
            .update_one(
340
                doc! { "product_id": product_id, "customer_id": operation.customer_id },
341
                doc! { "$set": {
342
                    "review": operation.review.clone(),
343
                    "rating": u32::from(operation.rating)
344
                }},
345
                None,
346
            )
347
            .await
348
            .unwrap();
349

            
350
        // Re-aggregate the ratings for the product that was reviewed. This
351
        // updates the materialized view, "ratings-by-product", but only on
352
        // the $match.
353
        reviews
354
            .aggregate(
355
                [
356
                    doc! {
357
                        "$match": { "product_id": product_id },
358
                    },
359
                    doc! {
360
                        "$group": { "_id": "$product_id", "rating": { "$avg": "$rating" } },
361
                    },
362
                    doc! {
363
                        "$merge": {"into": "ratings-by-product", "whenMatched": "replace" }
364
                    },
365
                ],
366
                None,
367
            )
368
            .await
369
            .unwrap();
370
        measurement.finish();
371

            
372
        OperationResult::Ok
373
    }
374
}